1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for WALSMatrixFactorization.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import itertools 22 import json 23 import numpy as np 24 25 from tensorflow.contrib.factorization.python.ops import factorization_ops_test_utils 26 from tensorflow.contrib.factorization.python.ops import wals as wals_lib 27 from tensorflow.contrib.learn.python.learn import run_config 28 from tensorflow.contrib.learn.python.learn.estimators import model_fn 29 from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib 30 from tensorflow.python.framework import constant_op 31 from tensorflow.python.framework import dtypes 32 from tensorflow.python.framework import sparse_tensor 33 from tensorflow.python.ops import array_ops 34 from tensorflow.python.ops import control_flow_ops 35 from tensorflow.python.ops import embedding_ops 36 from tensorflow.python.ops import math_ops 37 from tensorflow.python.ops import sparse_ops 38 from tensorflow.python.ops import state_ops 39 from tensorflow.python.ops import variables 40 from tensorflow.python.platform import test 41 from tensorflow.python.training import input as input_lib 42 from tensorflow.python.training import monitored_session 43 44 45 class WALSMatrixFactorizationTest(test.TestCase): 46 INPUT_MATRIX = factorization_ops_test_utils.INPUT_MATRIX 47 48 def np_array_to_sparse(self, np_array): 49 """Transforms an np.array to a tf.SparseTensor.""" 50 return factorization_ops_test_utils.np_matrix_to_tf_sparse(np_array) 51 52 def calculate_loss(self): 53 """Calculates the loss of the current (trained) model.""" 54 current_rows = embedding_ops.embedding_lookup( 55 self._model.get_row_factors(), math_ops.range(self._num_rows), 56 partition_strategy='div') 57 current_cols = embedding_ops.embedding_lookup( 58 self._model.get_col_factors(), math_ops.range(self._num_cols), 59 partition_strategy='div') 60 row_wts = embedding_ops.embedding_lookup( 61 self._row_weights, math_ops.range(self._num_rows), 62 partition_strategy='div') 63 col_wts = embedding_ops.embedding_lookup( 64 self._col_weights, math_ops.range(self._num_cols), 65 partition_strategy='div') 66 sp_inputs = self.np_array_to_sparse(self.INPUT_MATRIX) 67 return factorization_ops_test_utils.calculate_loss( 68 sp_inputs, current_rows, current_cols, self._regularization_coeff, 69 self._unobserved_weight, row_wts, col_wts) 70 71 # TODO(walidk): Replace with input_reader_utils functions once open sourced. 72 def remap_sparse_tensor_rows(self, sp_x, row_ids, shape): 73 """Remaps the row ids of a tf.SparseTensor.""" 74 old_row_ids, old_col_ids = array_ops.split( 75 value=sp_x.indices, num_or_size_splits=2, axis=1) 76 new_row_ids = array_ops.gather(row_ids, old_row_ids) 77 new_indices = array_ops.concat([new_row_ids, old_col_ids], 1) 78 return sparse_tensor.SparseTensor( 79 indices=new_indices, values=sp_x.values, dense_shape=shape) 80 81 # TODO(walidk): Add an option to shuffle inputs. 82 def input_fn(self, np_matrix, batch_size, mode, 83 project_row=None, projection_weights=None, 84 remove_empty_rows_columns=False): 85 """Returns an input_fn that selects row and col batches from np_matrix. 86 87 This simple utility creates an input function from a numpy_array. The 88 following transformations are performed: 89 * The empty rows and columns in np_matrix are removed (if 90 remove_empty_rows_columns is true) 91 * np_matrix is converted to a SparseTensor. 92 * The rows of the sparse matrix (and the rows of its transpose) are batched. 93 * A features dictionary is created, which contains the row / column batches. 94 95 In TRAIN mode, one only needs to specify the np_matrix and the batch_size. 96 In INFER and EVAL modes, one must also provide project_row, a boolean which 97 specifies whether we are projecting rows or columns. 98 99 Args: 100 np_matrix: A numpy array. The input matrix to use. 101 batch_size: Integer. 102 mode: Can be one of model_fn.ModeKeys.{TRAIN, INFER, EVAL}. 103 project_row: A boolean. Used in INFER and EVAL modes. Specifies whether 104 to project rows or columns. 105 projection_weights: A float numpy array. Used in INFER mode. Specifies 106 the weights to use in the projection (the weights are optional, and 107 default to 1.). 108 remove_empty_rows_columns: A boolean. When true, this will remove empty 109 rows and columns in the np_matrix. Note that this will result in 110 modifying the indices of the input matrix. The mapping from new indices 111 to old indices is returned in the form of two numpy arrays. 112 113 Returns: 114 A tuple consisting of: 115 _fn: A callable. Calling _fn returns a features dict. 116 nz_row_ids: A numpy array of the ids of non-empty rows, such that 117 nz_row_ids[i] is the old row index corresponding to new index i. 118 nz_col_ids: A numpy array of the ids of non-empty columns, such that 119 nz_col_ids[j] is the old column index corresponding to new index j. 120 """ 121 if remove_empty_rows_columns: 122 np_matrix, nz_row_ids, nz_col_ids = ( 123 factorization_ops_test_utils.remove_empty_rows_columns(np_matrix)) 124 else: 125 nz_row_ids = np.arange(np.shape(np_matrix)[0]) 126 nz_col_ids = np.arange(np.shape(np_matrix)[1]) 127 128 def extract_features(row_batch, col_batch, shape): 129 row_ids = row_batch[0] 130 col_ids = col_batch[0] 131 rows = self.remap_sparse_tensor_rows(row_batch[1], row_ids, shape) 132 cols = self.remap_sparse_tensor_rows(col_batch[1], col_ids, shape) 133 features = { 134 wals_lib.WALSMatrixFactorization.INPUT_ROWS: rows, 135 wals_lib.WALSMatrixFactorization.INPUT_COLS: cols, 136 } 137 return features 138 139 def _fn(): 140 num_rows = np.shape(np_matrix)[0] 141 num_cols = np.shape(np_matrix)[1] 142 row_ids = math_ops.range(num_rows, dtype=dtypes.int64) 143 col_ids = math_ops.range(num_cols, dtype=dtypes.int64) 144 sp_mat = self.np_array_to_sparse(np_matrix) 145 sp_mat_t = sparse_ops.sparse_transpose(sp_mat) 146 row_batch = input_lib.batch( 147 [row_ids, sp_mat], 148 batch_size=min(batch_size, num_rows), 149 capacity=10, 150 enqueue_many=True) 151 col_batch = input_lib.batch( 152 [col_ids, sp_mat_t], 153 batch_size=min(batch_size, num_cols), 154 capacity=10, 155 enqueue_many=True) 156 157 features = extract_features(row_batch, col_batch, sp_mat.dense_shape) 158 159 if mode == model_fn.ModeKeys.INFER or mode == model_fn.ModeKeys.EVAL: 160 self.assertTrue( 161 project_row is not None, 162 msg='project_row must be specified in INFER or EVAL mode.') 163 features[wals_lib.WALSMatrixFactorization.PROJECT_ROW] = ( 164 constant_op.constant(project_row)) 165 166 if mode == model_fn.ModeKeys.INFER and projection_weights is not None: 167 weights_batch = input_lib.batch( 168 projection_weights, 169 batch_size=batch_size, 170 capacity=10, 171 enqueue_many=True) 172 features[wals_lib.WALSMatrixFactorization.PROJECTION_WEIGHTS] = ( 173 weights_batch) 174 175 labels = None 176 return features, labels 177 178 return _fn, nz_row_ids, nz_col_ids 179 180 @property 181 def input_matrix(self): 182 return self.INPUT_MATRIX 183 184 @property 185 def row_steps(self): 186 return np.ceil(self._num_rows / self.batch_size) 187 188 @property 189 def col_steps(self): 190 return np.ceil(self._num_cols / self.batch_size) 191 192 @property 193 def batch_size(self): 194 return 5 195 196 @property 197 def use_cache(self): 198 return False 199 200 @property 201 def max_sweeps(self): 202 return None 203 204 def setUp(self): 205 self._num_rows = 5 206 self._num_cols = 7 207 self._embedding_dimension = 3 208 self._unobserved_weight = 0.1 209 self._num_row_shards = 2 210 self._num_col_shards = 3 211 self._regularization_coeff = 0.01 212 self._col_init = [ 213 # Shard 0. 214 [[-0.36444709, -0.39077035, -0.32528427], 215 [1.19056475, 0.07231052, 2.11834812], 216 [0.93468881, -0.71099287, 1.91826844]], 217 # Shard 1. 218 [[1.18160152, 1.52490723, -0.50015002], 219 [1.82574749, -0.57515913, -1.32810032]], 220 # Shard 2. 221 [[-0.15515432, -0.84675711, 0.13097958], 222 [-0.9246484, 0.69117504, 1.2036494]], 223 ] 224 self._row_weights = [[0.1, 0.2, 0.3], [0.4, 0.5]] 225 self._col_weights = [[0.1, 0.2, 0.3], [0.4, 0.5], [0.6, 0.7]] 226 227 # Values of row and column factors after running one iteration or factor 228 # updates. 229 self._row_factors_0 = [[0.097689, -0.219293, -0.020780], 230 [0.50842, 0.64626, 0.22364], 231 [0.401159, -0.046558, -0.192854]] 232 self._row_factors_1 = [[1.20597, -0.48025, 0.35582], 233 [1.5564, 1.2528, 1.0528]] 234 self._col_factors_0 = [[2.4725, -1.2950, -1.9980], 235 [0.44625, 1.50771, 1.27118], 236 [1.39801, -2.10134, 0.73572]] 237 self._col_factors_1 = [[3.36509, -0.66595, -3.51208], 238 [0.57191, 1.59407, 1.33020]] 239 self._col_factors_2 = [[3.3459, -1.3341, -3.3008], 240 [0.57366, 1.83729, 1.26798]] 241 self._model = wals_lib.WALSMatrixFactorization( 242 self._num_rows, 243 self._num_cols, 244 self._embedding_dimension, 245 self._unobserved_weight, 246 col_init=self._col_init, 247 regularization_coeff=self._regularization_coeff, 248 num_row_shards=self._num_row_shards, 249 num_col_shards=self._num_col_shards, 250 row_weights=self._row_weights, 251 col_weights=self._col_weights, 252 max_sweeps=self.max_sweeps, 253 use_factors_weights_cache_for_training=self.use_cache, 254 use_gramian_cache_for_training=self.use_cache) 255 256 def test_fit(self): 257 # Row sweep. 258 input_fn = self.input_fn(np_matrix=self.input_matrix, 259 batch_size=self.batch_size, 260 mode=model_fn.ModeKeys.TRAIN, 261 remove_empty_rows_columns=True)[0] 262 self._model.fit(input_fn=input_fn, steps=self.row_steps) 263 row_factors = self._model.get_row_factors() 264 self.assertAllClose(row_factors[0], self._row_factors_0, atol=1e-3) 265 self.assertAllClose(row_factors[1], self._row_factors_1, atol=1e-3) 266 267 # Col sweep. 268 # Running fit a second time will resume training from the checkpoint. 269 input_fn = self.input_fn(np_matrix=self.input_matrix, 270 batch_size=self.batch_size, 271 mode=model_fn.ModeKeys.TRAIN, 272 remove_empty_rows_columns=True)[0] 273 self._model.fit(input_fn=input_fn, steps=self.col_steps) 274 col_factors = self._model.get_col_factors() 275 self.assertAllClose(col_factors[0], self._col_factors_0, atol=1e-3) 276 self.assertAllClose(col_factors[1], self._col_factors_1, atol=1e-3) 277 self.assertAllClose(col_factors[2], self._col_factors_2, atol=1e-3) 278 279 def test_predict(self): 280 input_fn = self.input_fn(np_matrix=self.input_matrix, 281 batch_size=self.batch_size, 282 mode=model_fn.ModeKeys.TRAIN, 283 remove_empty_rows_columns=True, 284 )[0] 285 # Project rows 1 and 4 from the input matrix. 286 proj_input_fn = self.input_fn( 287 np_matrix=self.INPUT_MATRIX[[1, 4], :], 288 batch_size=2, 289 mode=model_fn.ModeKeys.INFER, 290 project_row=True, 291 projection_weights=[[0.2, 0.5]])[0] 292 293 self._model.fit(input_fn=input_fn, steps=self.row_steps) 294 projections = self._model.get_projections(proj_input_fn) 295 projected_rows = list(itertools.islice(projections, 2)) 296 297 self.assertAllClose( 298 projected_rows, 299 [self._row_factors_0[1], self._row_factors_1[1]], 300 atol=1e-3) 301 302 # Project columns 5, 3, 1 from the input matrix. 303 proj_input_fn = self.input_fn( 304 np_matrix=self.INPUT_MATRIX[:, [5, 3, 1]], 305 batch_size=3, 306 mode=model_fn.ModeKeys.INFER, 307 project_row=False, 308 projection_weights=[[0.6, 0.4, 0.2]])[0] 309 310 self._model.fit(input_fn=input_fn, steps=self.col_steps) 311 projections = self._model.get_projections(proj_input_fn) 312 projected_cols = list(itertools.islice(projections, 3)) 313 self.assertAllClose( 314 projected_cols, 315 [self._col_factors_2[0], self._col_factors_1[0], 316 self._col_factors_0[1]], 317 atol=1e-3) 318 319 def test_eval(self): 320 # Do a row sweep then evaluate the model on row inputs. 321 # The evaluate function returns the loss of the projected rows, but since 322 # projection is idempotent, the eval loss must match the model loss. 323 input_fn = self.input_fn(np_matrix=self.input_matrix, 324 batch_size=self.batch_size, 325 mode=model_fn.ModeKeys.TRAIN, 326 remove_empty_rows_columns=True, 327 )[0] 328 self._model.fit(input_fn=input_fn, steps=self.row_steps) 329 eval_input_fn_row = self.input_fn(np_matrix=self.input_matrix, 330 batch_size=1, 331 mode=model_fn.ModeKeys.EVAL, 332 project_row=True, 333 remove_empty_rows_columns=True)[0] 334 loss = self._model.evaluate( 335 input_fn=eval_input_fn_row, steps=self._num_rows)['loss'] 336 337 with self.test_session(): 338 true_loss = self.calculate_loss() 339 340 self.assertNear( 341 loss, true_loss, err=.001, 342 msg="""After row update, eval loss = {}, does not match the true 343 loss = {}.""".format(loss, true_loss)) 344 345 # Do a col sweep then evaluate the model on col inputs. 346 self._model.fit(input_fn=input_fn, steps=self.col_steps) 347 eval_input_fn_col = self.input_fn(np_matrix=self.input_matrix, 348 batch_size=1, 349 mode=model_fn.ModeKeys.EVAL, 350 project_row=False, 351 remove_empty_rows_columns=True)[0] 352 loss = self._model.evaluate( 353 input_fn=eval_input_fn_col, steps=self._num_cols)['loss'] 354 355 with self.test_session(): 356 true_loss = self.calculate_loss() 357 358 self.assertNear( 359 loss, true_loss, err=.001, 360 msg="""After col update, eval loss = {}, does not match the true 361 loss = {}.""".format(loss, true_loss)) 362 363 364 class WALSMatrixFactorizationTestSweeps(WALSMatrixFactorizationTest): 365 366 @property 367 def max_sweeps(self): 368 return 2 369 370 # We set the column steps to None so that we rely only on max_sweeps to stop 371 # training. 372 @property 373 def col_steps(self): 374 return None 375 376 377 class WALSMatrixFactorizationTestCached(WALSMatrixFactorizationTest): 378 379 @property 380 def use_cache(self): 381 return True 382 383 384 class WALSMatrixFactorizaiontTestPaddedInput(WALSMatrixFactorizationTest): 385 PADDED_INPUT_MATRIX = np.pad( 386 WALSMatrixFactorizationTest.INPUT_MATRIX, 387 [(1, 0), (1, 0)], mode='constant') 388 389 @property 390 def input_matrix(self): 391 return self.PADDED_INPUT_MATRIX 392 393 394 class WALSMatrixFactorizationUnsupportedTest(test.TestCase): 395 396 def setUp(self): 397 pass 398 399 def testDistributedWALSUnsupported(self): 400 tf_config = { 401 'cluster': { 402 run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], 403 run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4'] 404 }, 405 'task': { 406 'type': run_config_lib.TaskType.WORKER, 407 'index': 1 408 } 409 } 410 with test.mock.patch.dict('os.environ', 411 {'TF_CONFIG': json.dumps(tf_config)}): 412 config = run_config.RunConfig() 413 self.assertEqual(config.num_worker_replicas, 2) 414 with self.assertRaises(ValueError): 415 self._model = wals_lib.WALSMatrixFactorization(1, 1, 1, config=config) 416 417 418 class SweepHookTest(test.TestCase): 419 420 def test_sweeps(self): 421 is_row_sweep_var = variables.Variable(True) 422 is_sweep_done_var = variables.Variable(False) 423 init_done = variables.Variable(False) 424 row_prep_done = variables.Variable(False) 425 col_prep_done = variables.Variable(False) 426 row_train_done = variables.Variable(False) 427 col_train_done = variables.Variable(False) 428 429 init_op = state_ops.assign(init_done, True) 430 row_prep_op = state_ops.assign(row_prep_done, True) 431 col_prep_op = state_ops.assign(col_prep_done, True) 432 row_train_op = state_ops.assign(row_train_done, True) 433 col_train_op = state_ops.assign(col_train_done, True) 434 train_op = control_flow_ops.no_op() 435 switch_op = control_flow_ops.group( 436 state_ops.assign(is_sweep_done_var, False), 437 state_ops.assign(is_row_sweep_var, 438 math_ops.logical_not(is_row_sweep_var))) 439 mark_sweep_done = state_ops.assign(is_sweep_done_var, True) 440 441 with self.test_session() as sess: 442 sweep_hook = wals_lib._SweepHook( 443 is_row_sweep_var, 444 is_sweep_done_var, 445 init_op, 446 [row_prep_op], 447 [col_prep_op], 448 row_train_op, 449 col_train_op, 450 switch_op) 451 mon_sess = monitored_session._HookedSession(sess, [sweep_hook]) 452 sess.run([variables.global_variables_initializer()]) 453 454 # Row sweep. 455 mon_sess.run(train_op) 456 self.assertTrue(sess.run(init_done), 457 msg='init op not run by the Sweephook') 458 self.assertTrue(sess.run(row_prep_done), 459 msg='row_prep_op not run by the SweepHook') 460 self.assertTrue(sess.run(row_train_done), 461 msg='row_train_op not run by the SweepHook') 462 self.assertTrue( 463 sess.run(is_row_sweep_var), 464 msg='Row sweep is not complete but is_row_sweep_var is False.') 465 # Col sweep. 466 mon_sess.run(mark_sweep_done) 467 mon_sess.run(train_op) 468 self.assertTrue(sess.run(col_prep_done), 469 msg='col_prep_op not run by the SweepHook') 470 self.assertTrue(sess.run(col_train_done), 471 msg='col_train_op not run by the SweepHook') 472 self.assertFalse( 473 sess.run(is_row_sweep_var), 474 msg='Col sweep is not complete but is_row_sweep_var is True.') 475 # Row sweep. 476 mon_sess.run(mark_sweep_done) 477 mon_sess.run(train_op) 478 self.assertTrue( 479 sess.run(is_row_sweep_var), 480 msg='Col sweep is complete but is_row_sweep_var is False.') 481 482 483 class StopAtSweepHookTest(test.TestCase): 484 485 def test_stop(self): 486 hook = wals_lib._StopAtSweepHook(last_sweep=10) 487 completed_sweeps = variables.Variable( 488 8, name=wals_lib.WALSMatrixFactorization.COMPLETED_SWEEPS) 489 train_op = state_ops.assign_add(completed_sweeps, 1) 490 hook.begin() 491 492 with self.test_session() as sess: 493 sess.run([variables.global_variables_initializer()]) 494 mon_sess = monitored_session._HookedSession(sess, [hook]) 495 mon_sess.run(train_op) 496 # completed_sweeps is 9 after running train_op. 497 self.assertFalse(mon_sess.should_stop()) 498 mon_sess.run(train_op) 499 # completed_sweeps is 10 after running train_op. 500 self.assertTrue(mon_sess.should_stop()) 501 502 503 if __name__ == '__main__': 504 test.main() 505