Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for WALSMatrixFactorization."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import itertools
     22 import json
     23 import numpy as np
     24 
     25 from tensorflow.contrib.factorization.python.ops import factorization_ops_test_utils
     26 from tensorflow.contrib.factorization.python.ops import wals as wals_lib
     27 from tensorflow.contrib.learn.python.learn import run_config
     28 from tensorflow.contrib.learn.python.learn.estimators import model_fn
     29 from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
     30 from tensorflow.python.framework import constant_op
     31 from tensorflow.python.framework import dtypes
     32 from tensorflow.python.framework import sparse_tensor
     33 from tensorflow.python.ops import array_ops
     34 from tensorflow.python.ops import control_flow_ops
     35 from tensorflow.python.ops import embedding_ops
     36 from tensorflow.python.ops import math_ops
     37 from tensorflow.python.ops import sparse_ops
     38 from tensorflow.python.ops import state_ops
     39 from tensorflow.python.ops import variables
     40 from tensorflow.python.platform import test
     41 from tensorflow.python.training import input as input_lib
     42 from tensorflow.python.training import monitored_session
     43 
     44 
     45 class WALSMatrixFactorizationTest(test.TestCase):
     46   INPUT_MATRIX = factorization_ops_test_utils.INPUT_MATRIX
     47 
     48   def np_array_to_sparse(self, np_array):
     49     """Transforms an np.array to a tf.SparseTensor."""
     50     return factorization_ops_test_utils.np_matrix_to_tf_sparse(np_array)
     51 
     52   def calculate_loss(self):
     53     """Calculates the loss of the current (trained) model."""
     54     current_rows = embedding_ops.embedding_lookup(
     55         self._model.get_row_factors(), math_ops.range(self._num_rows),
     56         partition_strategy='div')
     57     current_cols = embedding_ops.embedding_lookup(
     58         self._model.get_col_factors(), math_ops.range(self._num_cols),
     59         partition_strategy='div')
     60     row_wts = embedding_ops.embedding_lookup(
     61         self._row_weights, math_ops.range(self._num_rows),
     62         partition_strategy='div')
     63     col_wts = embedding_ops.embedding_lookup(
     64         self._col_weights, math_ops.range(self._num_cols),
     65         partition_strategy='div')
     66     sp_inputs = self.np_array_to_sparse(self.INPUT_MATRIX)
     67     return factorization_ops_test_utils.calculate_loss(
     68         sp_inputs, current_rows, current_cols, self._regularization_coeff,
     69         self._unobserved_weight, row_wts, col_wts)
     70 
     71   # TODO(walidk): Replace with input_reader_utils functions once open sourced.
     72   def remap_sparse_tensor_rows(self, sp_x, row_ids, shape):
     73     """Remaps the row ids of a tf.SparseTensor."""
     74     old_row_ids, old_col_ids = array_ops.split(
     75         value=sp_x.indices, num_or_size_splits=2, axis=1)
     76     new_row_ids = array_ops.gather(row_ids, old_row_ids)
     77     new_indices = array_ops.concat([new_row_ids, old_col_ids], 1)
     78     return sparse_tensor.SparseTensor(
     79         indices=new_indices, values=sp_x.values, dense_shape=shape)
     80 
     81   # TODO(walidk): Add an option to shuffle inputs.
     82   def input_fn(self, np_matrix, batch_size, mode,
     83                project_row=None, projection_weights=None,
     84                remove_empty_rows_columns=False):
     85     """Returns an input_fn that selects row and col batches from np_matrix.
     86 
     87     This simple utility creates an input function from a numpy_array. The
     88     following transformations are performed:
     89     * The empty rows and columns in np_matrix are removed (if
     90       remove_empty_rows_columns is true)
     91     * np_matrix is converted to a SparseTensor.
     92     * The rows of the sparse matrix (and the rows of its transpose) are batched.
     93     * A features dictionary is created, which contains the row / column batches.
     94 
     95     In TRAIN mode, one only needs to specify the np_matrix and the batch_size.
     96     In INFER and EVAL modes, one must also provide project_row, a boolean which
     97     specifies whether we are projecting rows or columns.
     98 
     99     Args:
    100       np_matrix: A numpy array. The input matrix to use.
    101       batch_size: Integer.
    102       mode: Can be one of model_fn.ModeKeys.{TRAIN, INFER, EVAL}.
    103       project_row: A boolean. Used in INFER and EVAL modes. Specifies whether
    104         to project rows or columns.
    105       projection_weights: A float numpy array. Used in INFER mode. Specifies
    106         the weights to use in the projection (the weights are optional, and
    107         default to 1.).
    108       remove_empty_rows_columns: A boolean. When true, this will remove empty
    109         rows and columns in the np_matrix. Note that this will result in
    110         modifying the indices of the input matrix. The mapping from new indices
    111         to old indices is returned in the form of two numpy arrays.
    112 
    113     Returns:
    114       A tuple consisting of:
    115       _fn: A callable. Calling _fn returns a features dict.
    116       nz_row_ids: A numpy array of the ids of non-empty rows, such that
    117         nz_row_ids[i] is the old row index corresponding to new index i.
    118       nz_col_ids: A numpy array of the ids of non-empty columns, such that
    119         nz_col_ids[j] is the old column index corresponding to new index j.
    120     """
    121     if remove_empty_rows_columns:
    122       np_matrix, nz_row_ids, nz_col_ids = (
    123           factorization_ops_test_utils.remove_empty_rows_columns(np_matrix))
    124     else:
    125       nz_row_ids = np.arange(np.shape(np_matrix)[0])
    126       nz_col_ids = np.arange(np.shape(np_matrix)[1])
    127 
    128     def extract_features(row_batch, col_batch, shape):
    129       row_ids = row_batch[0]
    130       col_ids = col_batch[0]
    131       rows = self.remap_sparse_tensor_rows(row_batch[1], row_ids, shape)
    132       cols = self.remap_sparse_tensor_rows(col_batch[1], col_ids, shape)
    133       features = {
    134           wals_lib.WALSMatrixFactorization.INPUT_ROWS: rows,
    135           wals_lib.WALSMatrixFactorization.INPUT_COLS: cols,
    136       }
    137       return features
    138 
    139     def _fn():
    140       num_rows = np.shape(np_matrix)[0]
    141       num_cols = np.shape(np_matrix)[1]
    142       row_ids = math_ops.range(num_rows, dtype=dtypes.int64)
    143       col_ids = math_ops.range(num_cols, dtype=dtypes.int64)
    144       sp_mat = self.np_array_to_sparse(np_matrix)
    145       sp_mat_t = sparse_ops.sparse_transpose(sp_mat)
    146       row_batch = input_lib.batch(
    147           [row_ids, sp_mat],
    148           batch_size=min(batch_size, num_rows),
    149           capacity=10,
    150           enqueue_many=True)
    151       col_batch = input_lib.batch(
    152           [col_ids, sp_mat_t],
    153           batch_size=min(batch_size, num_cols),
    154           capacity=10,
    155           enqueue_many=True)
    156 
    157       features = extract_features(row_batch, col_batch, sp_mat.dense_shape)
    158 
    159       if mode == model_fn.ModeKeys.INFER or mode == model_fn.ModeKeys.EVAL:
    160         self.assertTrue(
    161             project_row is not None,
    162             msg='project_row must be specified in INFER or EVAL mode.')
    163         features[wals_lib.WALSMatrixFactorization.PROJECT_ROW] = (
    164             constant_op.constant(project_row))
    165 
    166       if mode == model_fn.ModeKeys.INFER and projection_weights is not None:
    167         weights_batch = input_lib.batch(
    168             projection_weights,
    169             batch_size=batch_size,
    170             capacity=10,
    171             enqueue_many=True)
    172         features[wals_lib.WALSMatrixFactorization.PROJECTION_WEIGHTS] = (
    173             weights_batch)
    174 
    175       labels = None
    176       return features, labels
    177 
    178     return _fn, nz_row_ids, nz_col_ids
    179 
    180   @property
    181   def input_matrix(self):
    182     return self.INPUT_MATRIX
    183 
    184   @property
    185   def row_steps(self):
    186     return np.ceil(self._num_rows / self.batch_size)
    187 
    188   @property
    189   def col_steps(self):
    190     return np.ceil(self._num_cols / self.batch_size)
    191 
    192   @property
    193   def batch_size(self):
    194     return 5
    195 
    196   @property
    197   def use_cache(self):
    198     return False
    199 
    200   @property
    201   def max_sweeps(self):
    202     return None
    203 
    204   def setUp(self):
    205     self._num_rows = 5
    206     self._num_cols = 7
    207     self._embedding_dimension = 3
    208     self._unobserved_weight = 0.1
    209     self._num_row_shards = 2
    210     self._num_col_shards = 3
    211     self._regularization_coeff = 0.01
    212     self._col_init = [
    213         # Shard 0.
    214         [[-0.36444709, -0.39077035, -0.32528427],
    215          [1.19056475, 0.07231052, 2.11834812],
    216          [0.93468881, -0.71099287, 1.91826844]],
    217         # Shard 1.
    218         [[1.18160152, 1.52490723, -0.50015002],
    219          [1.82574749, -0.57515913, -1.32810032]],
    220         # Shard 2.
    221         [[-0.15515432, -0.84675711, 0.13097958],
    222          [-0.9246484, 0.69117504, 1.2036494]],
    223     ]
    224     self._row_weights = [[0.1, 0.2, 0.3], [0.4, 0.5]]
    225     self._col_weights = [[0.1, 0.2, 0.3], [0.4, 0.5], [0.6, 0.7]]
    226 
    227     # Values of row and column factors after running one iteration or factor
    228     # updates.
    229     self._row_factors_0 = [[0.097689, -0.219293, -0.020780],
    230                            [0.50842, 0.64626, 0.22364],
    231                            [0.401159, -0.046558, -0.192854]]
    232     self._row_factors_1 = [[1.20597, -0.48025, 0.35582],
    233                            [1.5564, 1.2528, 1.0528]]
    234     self._col_factors_0 = [[2.4725, -1.2950, -1.9980],
    235                            [0.44625, 1.50771, 1.27118],
    236                            [1.39801, -2.10134, 0.73572]]
    237     self._col_factors_1 = [[3.36509, -0.66595, -3.51208],
    238                            [0.57191, 1.59407, 1.33020]]
    239     self._col_factors_2 = [[3.3459, -1.3341, -3.3008],
    240                            [0.57366, 1.83729, 1.26798]]
    241     self._model = wals_lib.WALSMatrixFactorization(
    242         self._num_rows,
    243         self._num_cols,
    244         self._embedding_dimension,
    245         self._unobserved_weight,
    246         col_init=self._col_init,
    247         regularization_coeff=self._regularization_coeff,
    248         num_row_shards=self._num_row_shards,
    249         num_col_shards=self._num_col_shards,
    250         row_weights=self._row_weights,
    251         col_weights=self._col_weights,
    252         max_sweeps=self.max_sweeps,
    253         use_factors_weights_cache_for_training=self.use_cache,
    254         use_gramian_cache_for_training=self.use_cache)
    255 
    256   def test_fit(self):
    257     # Row sweep.
    258     input_fn = self.input_fn(np_matrix=self.input_matrix,
    259                              batch_size=self.batch_size,
    260                              mode=model_fn.ModeKeys.TRAIN,
    261                              remove_empty_rows_columns=True)[0]
    262     self._model.fit(input_fn=input_fn, steps=self.row_steps)
    263     row_factors = self._model.get_row_factors()
    264     self.assertAllClose(row_factors[0], self._row_factors_0, atol=1e-3)
    265     self.assertAllClose(row_factors[1], self._row_factors_1, atol=1e-3)
    266 
    267     # Col sweep.
    268     # Running fit a second time will resume training from the checkpoint.
    269     input_fn = self.input_fn(np_matrix=self.input_matrix,
    270                              batch_size=self.batch_size,
    271                              mode=model_fn.ModeKeys.TRAIN,
    272                              remove_empty_rows_columns=True)[0]
    273     self._model.fit(input_fn=input_fn, steps=self.col_steps)
    274     col_factors = self._model.get_col_factors()
    275     self.assertAllClose(col_factors[0], self._col_factors_0, atol=1e-3)
    276     self.assertAllClose(col_factors[1], self._col_factors_1, atol=1e-3)
    277     self.assertAllClose(col_factors[2], self._col_factors_2, atol=1e-3)
    278 
    279   def test_predict(self):
    280     input_fn = self.input_fn(np_matrix=self.input_matrix,
    281                              batch_size=self.batch_size,
    282                              mode=model_fn.ModeKeys.TRAIN,
    283                              remove_empty_rows_columns=True,
    284                             )[0]
    285     # Project rows 1 and 4 from the input matrix.
    286     proj_input_fn = self.input_fn(
    287         np_matrix=self.INPUT_MATRIX[[1, 4], :],
    288         batch_size=2,
    289         mode=model_fn.ModeKeys.INFER,
    290         project_row=True,
    291         projection_weights=[[0.2, 0.5]])[0]
    292 
    293     self._model.fit(input_fn=input_fn, steps=self.row_steps)
    294     projections = self._model.get_projections(proj_input_fn)
    295     projected_rows = list(itertools.islice(projections, 2))
    296 
    297     self.assertAllClose(
    298         projected_rows,
    299         [self._row_factors_0[1], self._row_factors_1[1]],
    300         atol=1e-3)
    301 
    302     # Project columns 5, 3, 1 from the input matrix.
    303     proj_input_fn = self.input_fn(
    304         np_matrix=self.INPUT_MATRIX[:, [5, 3, 1]],
    305         batch_size=3,
    306         mode=model_fn.ModeKeys.INFER,
    307         project_row=False,
    308         projection_weights=[[0.6, 0.4, 0.2]])[0]
    309 
    310     self._model.fit(input_fn=input_fn, steps=self.col_steps)
    311     projections = self._model.get_projections(proj_input_fn)
    312     projected_cols = list(itertools.islice(projections, 3))
    313     self.assertAllClose(
    314         projected_cols,
    315         [self._col_factors_2[0], self._col_factors_1[0],
    316          self._col_factors_0[1]],
    317         atol=1e-3)
    318 
    319   def test_eval(self):
    320     # Do a row sweep then evaluate the model on row inputs.
    321     # The evaluate function returns the loss of the projected rows, but since
    322     # projection is idempotent, the eval loss must match the model loss.
    323     input_fn = self.input_fn(np_matrix=self.input_matrix,
    324                              batch_size=self.batch_size,
    325                              mode=model_fn.ModeKeys.TRAIN,
    326                              remove_empty_rows_columns=True,
    327                             )[0]
    328     self._model.fit(input_fn=input_fn, steps=self.row_steps)
    329     eval_input_fn_row = self.input_fn(np_matrix=self.input_matrix,
    330                                       batch_size=1,
    331                                       mode=model_fn.ModeKeys.EVAL,
    332                                       project_row=True,
    333                                       remove_empty_rows_columns=True)[0]
    334     loss = self._model.evaluate(
    335         input_fn=eval_input_fn_row, steps=self._num_rows)['loss']
    336 
    337     with self.test_session():
    338       true_loss = self.calculate_loss()
    339 
    340     self.assertNear(
    341         loss, true_loss, err=.001,
    342         msg="""After row update, eval loss = {}, does not match the true
    343         loss = {}.""".format(loss, true_loss))
    344 
    345     # Do a col sweep then evaluate the model on col inputs.
    346     self._model.fit(input_fn=input_fn, steps=self.col_steps)
    347     eval_input_fn_col = self.input_fn(np_matrix=self.input_matrix,
    348                                       batch_size=1,
    349                                       mode=model_fn.ModeKeys.EVAL,
    350                                       project_row=False,
    351                                       remove_empty_rows_columns=True)[0]
    352     loss = self._model.evaluate(
    353         input_fn=eval_input_fn_col, steps=self._num_cols)['loss']
    354 
    355     with self.test_session():
    356       true_loss = self.calculate_loss()
    357 
    358     self.assertNear(
    359         loss, true_loss, err=.001,
    360         msg="""After col update, eval loss = {}, does not match the true
    361         loss = {}.""".format(loss, true_loss))
    362 
    363 
    364 class WALSMatrixFactorizationTestSweeps(WALSMatrixFactorizationTest):
    365 
    366   @property
    367   def max_sweeps(self):
    368     return 2
    369 
    370   # We set the column steps to None so that we rely only on max_sweeps to stop
    371   # training.
    372   @property
    373   def col_steps(self):
    374     return None
    375 
    376 
    377 class WALSMatrixFactorizationTestCached(WALSMatrixFactorizationTest):
    378 
    379   @property
    380   def use_cache(self):
    381     return True
    382 
    383 
    384 class WALSMatrixFactorizaiontTestPaddedInput(WALSMatrixFactorizationTest):
    385   PADDED_INPUT_MATRIX = np.pad(
    386       WALSMatrixFactorizationTest.INPUT_MATRIX,
    387       [(1, 0), (1, 0)], mode='constant')
    388 
    389   @property
    390   def input_matrix(self):
    391     return self.PADDED_INPUT_MATRIX
    392 
    393 
    394 class WALSMatrixFactorizationUnsupportedTest(test.TestCase):
    395 
    396   def setUp(self):
    397     pass
    398 
    399   def testDistributedWALSUnsupported(self):
    400     tf_config = {
    401         'cluster': {
    402             run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
    403             run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
    404         },
    405         'task': {
    406             'type': run_config_lib.TaskType.WORKER,
    407             'index': 1
    408         }
    409     }
    410     with test.mock.patch.dict('os.environ',
    411                               {'TF_CONFIG': json.dumps(tf_config)}):
    412       config = run_config.RunConfig()
    413     self.assertEqual(config.num_worker_replicas, 2)
    414     with self.assertRaises(ValueError):
    415       self._model = wals_lib.WALSMatrixFactorization(1, 1, 1, config=config)
    416 
    417 
    418 class SweepHookTest(test.TestCase):
    419 
    420   def test_sweeps(self):
    421     is_row_sweep_var = variables.Variable(True)
    422     is_sweep_done_var = variables.Variable(False)
    423     init_done = variables.Variable(False)
    424     row_prep_done = variables.Variable(False)
    425     col_prep_done = variables.Variable(False)
    426     row_train_done = variables.Variable(False)
    427     col_train_done = variables.Variable(False)
    428 
    429     init_op = state_ops.assign(init_done, True)
    430     row_prep_op = state_ops.assign(row_prep_done, True)
    431     col_prep_op = state_ops.assign(col_prep_done, True)
    432     row_train_op = state_ops.assign(row_train_done, True)
    433     col_train_op = state_ops.assign(col_train_done, True)
    434     train_op = control_flow_ops.no_op()
    435     switch_op = control_flow_ops.group(
    436         state_ops.assign(is_sweep_done_var, False),
    437         state_ops.assign(is_row_sweep_var,
    438                          math_ops.logical_not(is_row_sweep_var)))
    439     mark_sweep_done = state_ops.assign(is_sweep_done_var, True)
    440 
    441     with self.test_session() as sess:
    442       sweep_hook = wals_lib._SweepHook(
    443           is_row_sweep_var,
    444           is_sweep_done_var,
    445           init_op,
    446           [row_prep_op],
    447           [col_prep_op],
    448           row_train_op,
    449           col_train_op,
    450           switch_op)
    451       mon_sess = monitored_session._HookedSession(sess, [sweep_hook])
    452       sess.run([variables.global_variables_initializer()])
    453 
    454       # Row sweep.
    455       mon_sess.run(train_op)
    456       self.assertTrue(sess.run(init_done),
    457                       msg='init op not run by the Sweephook')
    458       self.assertTrue(sess.run(row_prep_done),
    459                       msg='row_prep_op not run by the SweepHook')
    460       self.assertTrue(sess.run(row_train_done),
    461                       msg='row_train_op not run by the SweepHook')
    462       self.assertTrue(
    463           sess.run(is_row_sweep_var),
    464           msg='Row sweep is not complete but is_row_sweep_var is False.')
    465       # Col sweep.
    466       mon_sess.run(mark_sweep_done)
    467       mon_sess.run(train_op)
    468       self.assertTrue(sess.run(col_prep_done),
    469                       msg='col_prep_op not run by the SweepHook')
    470       self.assertTrue(sess.run(col_train_done),
    471                       msg='col_train_op not run by the SweepHook')
    472       self.assertFalse(
    473           sess.run(is_row_sweep_var),
    474           msg='Col sweep is not complete but is_row_sweep_var is True.')
    475       # Row sweep.
    476       mon_sess.run(mark_sweep_done)
    477       mon_sess.run(train_op)
    478       self.assertTrue(
    479           sess.run(is_row_sweep_var),
    480           msg='Col sweep is complete but is_row_sweep_var is False.')
    481 
    482 
    483 class StopAtSweepHookTest(test.TestCase):
    484 
    485   def test_stop(self):
    486     hook = wals_lib._StopAtSweepHook(last_sweep=10)
    487     completed_sweeps = variables.Variable(
    488         8, name=wals_lib.WALSMatrixFactorization.COMPLETED_SWEEPS)
    489     train_op = state_ops.assign_add(completed_sweeps, 1)
    490     hook.begin()
    491 
    492     with self.test_session() as sess:
    493       sess.run([variables.global_variables_initializer()])
    494       mon_sess = monitored_session._HookedSession(sess, [hook])
    495       mon_sess.run(train_op)
    496       # completed_sweeps is 9 after running train_op.
    497       self.assertFalse(mon_sess.should_stop())
    498       mon_sess.run(train_op)
    499       # completed_sweeps is 10 after running train_op.
    500       self.assertTrue(mon_sess.should_stop())
    501 
    502 
    503 if __name__ == '__main__':
    504   test.main()
    505