Home | History | Annotate | Download | only in estimator
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for utilities that replicate `Estimator.model_fn` over GPUs."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import re
     22 import shutil
     23 import tempfile
     24 import numpy as np
     25 import six
     26 
     27 from tensorflow.contrib.estimator.python.estimator import replicate_model_fn
     28 from tensorflow.python.estimator import estimator as estimator_lib
     29 from tensorflow.python.estimator import model_fn as model_fn_lib
     30 from tensorflow.python.estimator.canned import dnn
     31 from tensorflow.python.estimator.canned import optimizers
     32 from tensorflow.python.estimator.canned import prediction_keys
     33 from tensorflow.python.estimator.export import export
     34 from tensorflow.python.estimator.export import export_output
     35 from tensorflow.python.estimator.inputs import numpy_io
     36 from tensorflow.python.feature_column import feature_column
     37 from tensorflow.python.framework import constant_op
     38 from tensorflow.python.framework import dtypes
     39 from tensorflow.python.framework import ops as ops_lib
     40 from tensorflow.python.framework import sparse_tensor
     41 from tensorflow.python.framework import test_util
     42 from tensorflow.python.ops import array_ops
     43 from tensorflow.python.ops import control_flow_ops
     44 from tensorflow.python.ops import losses
     45 from tensorflow.python.ops import math_ops
     46 from tensorflow.python.ops import metrics as metrics_lib
     47 from tensorflow.python.ops import variable_scope
     48 from tensorflow.python.ops import variables
     49 from tensorflow.python.ops.losses import losses
     50 from tensorflow.python.platform import gfile
     51 from tensorflow.python.platform import test
     52 from tensorflow.python.saved_model import signature_constants
     53 from tensorflow.python.summary.writer import writer_cache
     54 from tensorflow.python.training import adam
     55 from tensorflow.python.training import device_setter
     56 from tensorflow.python.training import gradient_descent
     57 from tensorflow.python.training import training
     58 
     59 
     60 # TODO(isaprykin):  Parametrize all the tests on
     61 #   replicate_model_fn._VariableDistributionMode when it's supported.
     62 class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
     63 
     64   def setUp(self):
     65     self._model_dir = tempfile.mkdtemp()
     66 
     67   def test_complete_flow_with_public_version(self):
     68     return self._complete_flow_with_mode(mode=None)
     69 
     70   def test_complete_flow_with_mode_local_ps_server(self):
     71     return self._complete_flow_with_mode(
     72         replicate_model_fn._VariableDistributionMode.
     73         SHARED_LOCAL_PARAMETER_SERVER)
     74 
     75   def test_complete_flow_with_mode_round_robin(self):
     76     return self._complete_flow_with_mode(
     77         replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN)
     78 
     79   def _complete_flow_with_mode(self, mode):
     80     n_classes = 3
     81     input_dimension = 2
     82     batch_size = 12
     83 
     84     data = np.linspace(
     85         0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
     86     x_data = data.reshape(batch_size, input_dimension)
     87     categorical_data = np.random.random_integers(
     88         0, len(x_data), size=len(x_data))
     89     y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
     90     train_input_fn = numpy_io.numpy_input_fn(
     91         x={'x': x_data,
     92            'categories': categorical_data},
     93         y=y_data,
     94         batch_size=batch_size,
     95         num_epochs=None,
     96         shuffle=True)
     97     eval_input_fn = numpy_io.numpy_input_fn(
     98         x={'x': x_data,
     99            'categories': categorical_data},
    100         y=y_data,
    101         batch_size=batch_size,
    102         shuffle=False)
    103     predict_input_fn = numpy_io.numpy_input_fn(
    104         x={'x': x_data,
    105            'categories': categorical_data},
    106         batch_size=batch_size,
    107         shuffle=False)
    108 
    109     feature_columns = [
    110         feature_column.numeric_column('x', shape=(input_dimension,)),
    111         feature_column.embedding_column(
    112             feature_column.categorical_column_with_vocabulary_list(
    113                 'categories',
    114                 vocabulary_list=np.linspace(
    115                     0., len(x_data), len(x_data), dtype=np.int64)), 1)
    116     ]
    117 
    118     def optimizer_fn():
    119       return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05)
    120 
    121     estimator = dnn.DNNClassifier(
    122         hidden_units=(2, 2),
    123         # Adagrad is configured with `get_optimizer_instance`, so the function
    124         # form of `TowerOptimizer.__init__` is used.
    125         optimizer=replicate_model_fn.TowerOptimizer(optimizer_fn),
    126         feature_columns=feature_columns,
    127         n_classes=n_classes,
    128         model_dir=self._model_dir)
    129 
    130     if not mode:  # Use the public `replicate_model_fn`.
    131       model_fn = replicate_model_fn.replicate_model_fn(
    132           estimator.model_fn, devices=['/gpu:0', '/gpu:1', '/gpu:2'])
    133     else:
    134       model_fn = replicate_model_fn._replicate_model_fn_with_mode(
    135           estimator.model_fn,
    136           devices=['/gpu:0', '/gpu:1', '/gpu:2'],
    137           loss_reduction=losses.Reduction.SUM,
    138           mode=mode)
    139 
    140     estimator = estimator_lib.Estimator(
    141         model_fn=model_fn,
    142         model_dir=estimator.model_dir,
    143         config=estimator.config,
    144         params=estimator.params)
    145 
    146     num_steps = 10
    147     estimator.train(train_input_fn, steps=num_steps)
    148 
    149     scores = estimator.evaluate(eval_input_fn)
    150     self.assertEqual(num_steps, scores[ops_lib.GraphKeys.GLOBAL_STEP])
    151     self.assertIn('loss', six.iterkeys(scores))
    152 
    153     predicted_proba = np.array([
    154         x[prediction_keys.PredictionKeys.PROBABILITIES]
    155         for x in estimator.predict(predict_input_fn)
    156     ])
    157     self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
    158 
    159     feature_spec = feature_column.make_parse_example_spec(feature_columns)
    160     serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
    161         feature_spec)
    162     export_dir = estimator.export_savedmodel(tempfile.mkdtemp(),
    163                                              serving_input_receiver_fn)
    164     self.assertTrue(gfile.Exists(export_dir))
    165 
    166     # Nothing should be left in the graph so that it doesn't get serialized.
    167     self.assertFalse(ops_lib.get_default_graph().get_collection_ref(
    168         replicate_model_fn.TowerOptimizer.COLLECTION_FOR_GRAPH_STATES))
    169 
    170   def _as_label(self, data_in_float):
    171     return np.rint(data_in_float).astype(np.int64)
    172 
    173   def tearDown(self):
    174     if self._model_dir:
    175       writer_cache.FileWriterCache.clear()
    176       shutil.rmtree(self._model_dir)
    177 
    178 
    179 class ReplicateModelTest(test_util.TensorFlowTestCase):
    180 
    181   def model_fn(self, mode, features, labels, params):
    182     c = variable_scope.get_variable(
    183         'c',
    184         initializer=constant_op.constant(10, dtype=dtypes.float64),
    185         dtype=dtypes.float64)
    186 
    187     predictions = math_ops.multiply(features, c)
    188 
    189     loss = losses.absolute_difference(
    190         labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
    191     loss = math_ops.reduce_sum(loss)
    192 
    193     metrics = {
    194         'accuracy': metrics_lib.accuracy(labels, predictions),
    195         'auc': metrics_lib.auc(labels, predictions)
    196     }
    197 
    198     optimizer = replicate_model_fn.TowerOptimizer(
    199         gradient_descent.GradientDescentOptimizer(params['learning_rate']))
    200 
    201     return model_fn_lib.EstimatorSpec(
    202         mode=mode,
    203         loss=loss,
    204         eval_metric_ops=metrics,
    205         predictions={'probabilities': predictions},
    206         train_op=optimizer.minimize(loss))
    207 
    208   @property
    209   def params(self):
    210     params = {}
    211     params['learning_rate'] = 1.0
    212     return params
    213 
    214   def test_train(self):
    215     features = np.array([[1.0], [2.0]])
    216     labels = np.array([[1.0], [2.0]])
    217 
    218     with self.test_session() as session:
    219       replicated_model_fn = replicate_model_fn.replicate_model_fn(
    220           self.model_fn,
    221           loss_reduction=losses.Reduction.SUM,
    222           devices=['/gpu:0', '/gpu:1'])
    223       estimator_spec = replicated_model_fn(
    224           features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
    225       session.run(variables.global_variables_initializer())
    226 
    227       # loss = feature * c - label
    228       total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
    229       self.assertEqual(total_loss, session.run(estimator_spec.loss))
    230 
    231       # derivative of loss = (1*c - 1) + (2*c - 2) is 3.
    232       # new value of c = 10 - learning rate * 3 = 7.0.
    233       session.run(estimator_spec.train_op)
    234       with variable_scope.variable_scope('', reuse=True):
    235         c = variable_scope.get_variable('c', dtype=dtypes.float64)
    236         self.assertEqual(7.0, session.run(c))
    237 
    238   def test_train_with_mean_reduction(self):
    239     features = np.array([[1.0], [2.0]])
    240     labels = np.array([[1.0], [2.0]])
    241 
    242     with self.test_session() as session:
    243       # Add another trainable variable that doesn't produce a gradient to
    244       # verify that None gradients are supported.
    245       _ = variable_scope.get_variable(
    246           'another_variable',
    247           initializer=constant_op.constant(1, dtype=dtypes.float64),
    248           dtype=dtypes.float64)
    249 
    250       replicated_model_fn = replicate_model_fn.replicate_model_fn(
    251           self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1'])
    252       estimator_spec = replicated_model_fn(
    253           features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
    254       session.run(variables.global_variables_initializer())
    255 
    256       # loss = feature * c - label
    257       total_loss = ((1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)) / 2.0
    258       self.assertEqual(total_loss, session.run(estimator_spec.loss))
    259 
    260       # derivative of loss = (1*c - 1)/2 + (2*c - 2)/2 is 1.5.
    261       # It's the same computation as without mean reduction, but the
    262       # loss from every tower is scaled by 1/<number of towers>.
    263       # new value of c = 10 - learning rate * 1.5 = 8.5
    264       session.run(estimator_spec.train_op)
    265       with variable_scope.variable_scope('', reuse=True):
    266         c = variable_scope.get_variable('c', dtype=dtypes.float64)
    267         self.assertEqual(8.5, session.run(c))
    268 
    269   def test_train_two_steps_collected_gradients_are_reset_between_steps(self):
    270     with ops_lib.Graph().as_default():
    271       features = array_ops.placeholder(dtypes.float64)
    272       labels = array_ops.placeholder(dtypes.float64)
    273 
    274       feature_inputs = np.array([[1.0], [2.0]]), np.array([[1.5], [2.5]])
    275       label_inputs = np.array([[1.0], [2.0]]), np.array([[1.5], [2.5]])
    276 
    277       # loss = feature * c - label
    278       expected_losses = ((1.0 * 10 - 1.0) + (2.0 * 10 - 2.0),
    279                          (1.5 * 7.0 - 1.5) + (2.5 * 7.0 - 2.5))
    280       # Derivative of the loss is 1.0 + 2.0 for the first step and 1.5 + 2.5
    281       # for the second.
    282       expected_c = 10.0 - 3.0, 7.0 - 4.0
    283 
    284       with self.test_session() as session, variable_scope.variable_scope(
    285           '', reuse=variable_scope.AUTO_REUSE):
    286         replicated_model_fn = replicate_model_fn.replicate_model_fn(
    287             self.model_fn,
    288             loss_reduction=losses.Reduction.SUM,
    289             devices=['/gpu:0', '/gpu:1'])
    290         estimator_spec = replicated_model_fn(
    291             features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
    292         session.run(variables.global_variables_initializer())
    293 
    294         for feature_input, label_input, loss, weight in zip(
    295             feature_inputs, label_inputs, expected_losses, expected_c):
    296           feeds = {features: feature_input, labels: label_input}
    297 
    298           self.assertEqual(loss, session.run(estimator_spec.loss, feeds))
    299 
    300           session.run(estimator_spec.train_op, feeds)
    301           c = variable_scope.get_variable('c', dtype=dtypes.float64)
    302           self.assertEqual(weight, session.run(c, feeds))
    303 
    304   def test_eval(self):
    305     features = np.array([[0.01], [0.002]])
    306     labels = np.array([[0.01], [0.02]])
    307 
    308     with self.test_session() as session:
    309       replicated_model_fn = replicate_model_fn.replicate_model_fn(
    310           self.model_fn,
    311           loss_reduction=losses.Reduction.SUM,
    312           devices=['/gpu:0', '/gpu:1'])
    313       estimator_spec = replicated_model_fn(
    314           features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
    315       session.run(variables.local_variables_initializer())
    316       session.run(variables.global_variables_initializer())
    317 
    318       accuracy, a = estimator_spec.eval_metric_ops['accuracy']
    319       auc, b = estimator_spec.eval_metric_ops['auc']
    320 
    321       session.run([a, b])
    322       accuracy = session.run(accuracy)
    323       auc = session.run(auc)
    324 
    325       # loss[i] = features[i] * 10 - labels[i].
    326       # Accuracy is 0.0 (no match) in the first tower.
    327       # Accuracy is 1.0 (match) in the second tower, since the feature
    328       # times weight "c" happened to be equal to the label.
    329       total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02))
    330 
    331       self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
    332       self.assertEqual(0, auc)
    333       self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
    334 
    335   def test_eval_with_mean_reduction(self):
    336     features = np.array([[0.01], [0.002]])
    337     labels = np.array([[0.01], [0.02]])
    338 
    339     with self.test_session() as session:
    340       replicated_model_fn = replicate_model_fn.replicate_model_fn(
    341           self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1'])
    342       estimator_spec = replicated_model_fn(
    343           features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
    344       session.run(variables.local_variables_initializer())
    345       session.run(variables.global_variables_initializer())
    346 
    347       accuracy, a = estimator_spec.eval_metric_ops['accuracy']
    348       auc, b = estimator_spec.eval_metric_ops['auc']
    349 
    350       session.run([a, b])
    351       accuracy = session.run(accuracy)
    352       auc = session.run(auc)
    353 
    354       # loss[i] = features[i] * 10 - labels[i].
    355       # Accuracy is 0.0 (no match) in the first tower.
    356       # Accuracy is 1.0 (match) in the second tower, since the feature
    357       # times weight "c" happened to be equal to the label.
    358       total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02)) / 2.0
    359 
    360       self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
    361       self.assertEqual(0, auc)
    362       self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
    363 
    364   def test_predict(self):
    365     features = np.array([[0.01], [0.002]])
    366     labels = np.array([[0.01], [0.02]])
    367 
    368     with self.test_session() as session:
    369       replicated_model_fn = replicate_model_fn.replicate_model_fn(
    370           self.model_fn, devices=['/gpu:0', '/gpu:1'])
    371       estimator_spec = replicated_model_fn(
    372           features, labels, model_fn_lib.ModeKeys.PREDICT, self.params)
    373       session.run(variables.global_variables_initializer())
    374 
    375       self.assertAllClose({
    376           'probabilities': np.array([[0.1], [0.02]])
    377       }, session.run(estimator_spec.predictions))
    378 
    379   def test_train_single_tower(self):
    380     features = np.array([[1.0], [2.0]])
    381     labels = np.array([[1.0], [2.0]])
    382 
    383     with self.test_session() as session:
    384       replicated_model_fn = replicate_model_fn.replicate_model_fn(
    385           self.model_fn, devices=['/gpu:0'])
    386       estimator_spec = replicated_model_fn(
    387           features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
    388       session.run(variables.global_variables_initializer())
    389 
    390       # loss = feature * c - label
    391       total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
    392       self.assertEqual(total_loss, session.run(estimator_spec.loss))
    393 
    394       # loss' of c is 3.
    395       # new value of c = 10 - learning rate * 3 = 7.0.
    396       session.run(estimator_spec.train_op)
    397       with variable_scope.variable_scope('', reuse=True):
    398         c = variable_scope.get_variable('c', dtype=dtypes.float64)
    399         self.assertEqual(7.0, session.run(c))
    400 
    401   def test_eval_single_tower(self):
    402     features = np.array([[0.01], [0.002]])
    403     labels = np.array([[0.01], [0.02]])
    404 
    405     with self.test_session() as session:
    406       replicated_model_fn = replicate_model_fn.replicate_model_fn(
    407           self.model_fn, devices=['/gpu:0'])
    408       estimator_spec = replicated_model_fn(
    409           features, labels, model_fn_lib.ModeKeys.EVAL, self.params)
    410       session.run(variables.local_variables_initializer())
    411       session.run(variables.global_variables_initializer())
    412 
    413       accuracy, a = estimator_spec.eval_metric_ops['accuracy']
    414       auc, b = estimator_spec.eval_metric_ops['auc']
    415 
    416       session.run([a, b])
    417       accuracy = session.run(accuracy)
    418       auc = session.run(auc)
    419 
    420       # Accuracy is 0.0 (no match) in the first tower.
    421       # Accuracy is 1.0 (match) in the second tower, since the feature
    422       # times weight "c" happened to be equal to the label.
    423       total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02))
    424 
    425       self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
    426       self.assertEqual(0, auc)
    427       self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
    428 
    429   def test_predict_single_tower(self):
    430     features = np.array([[0.01], [0.002]])
    431     labels = np.array([[0.01], [0.02]])
    432 
    433     with self.test_session() as session:
    434       replicated_model_fn = replicate_model_fn.replicate_model_fn(
    435           self.model_fn, devices=['/gpu:0'])
    436       estimator_spec = replicated_model_fn(
    437           features, labels, model_fn_lib.ModeKeys.PREDICT, self.params)
    438       session.run(variables.global_variables_initializer())
    439 
    440       self.assertAllClose({
    441           'probabilities': np.array([[0.1], [0.02]])
    442       }, session.run(estimator_spec.predictions))
    443 
    444   def test_batch_size_that_is_not_divisible_by_the_number_of_gpus(self):
    445     features = np.array([[1.0], [2.0], [3.0]])
    446     labels = np.array([[1.0], [2.0], [3.0]])
    447 
    448     with self.assertRaisesRegexp(
    449         ValueError, '.*Batch.+size.+needs.+to.+be.+divisible.+by.+GPUs.+'):
    450       replicated_model_fn = replicate_model_fn.replicate_model_fn(
    451           self.model_fn, devices=['/gpu:0', '/gpu:1'])
    452       _ = replicated_model_fn(
    453           features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
    454 
    455   def test_unsupported_loss_reduction(self):
    456     with self.assertRaisesRegexp(ValueError,
    457                                  '.+none.+reduction.+is.+specified.+'):
    458       _ = replicate_model_fn.replicate_model_fn(self.model_fn,
    459                                                 losses.Reduction.NONE)
    460 
    461   def test_places_on_gpu_with_upper_case_spelling(self):
    462     features = np.array([[0.01], [0.002]])
    463     labels = np.array([[0.01], [0.02]])
    464 
    465     with self.test_session():
    466       replicated_model_fn = replicate_model_fn.replicate_model_fn(
    467           self.model_fn, devices=['/GPU:0'])
    468       _ = replicated_model_fn(
    469           features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
    470 
    471       with variable_scope.variable_scope('', reuse=True):
    472         c = variable_scope.get_variable('c', dtype=dtypes.float64)
    473         self.assertEqual('/device:GPU:0', c.device)
    474 
    475   def test_places_on_gpu_with_lower_case_spelling(self):
    476     features = np.array([[0.01], [0.002]])
    477     labels = np.array([[0.01], [0.02]])
    478 
    479     with self.test_session():
    480       replicated_model_fn = replicate_model_fn.replicate_model_fn(
    481           self.model_fn, devices=['/gpu:0'])
    482       _ = replicated_model_fn(
    483           features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
    484 
    485       with variable_scope.variable_scope('', reuse=True):
    486         c = variable_scope.get_variable('c', dtype=dtypes.float64)
    487         self.assertEqual('/device:GPU:0', c.device)
    488 
    489 
    490 class ReplicateAcrossASingleDeviceWithoutTowerOptimizer(
    491     test_util.TensorFlowTestCase):
    492 
    493   def model_fn(self, mode, features, labels, params):
    494     c = variable_scope.get_variable(
    495         'c',
    496         initializer=constant_op.constant(10, dtype=dtypes.float64),
    497         dtype=dtypes.float64)
    498 
    499     predictions = math_ops.multiply(features, c)
    500 
    501     loss = losses.absolute_difference(
    502         labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
    503     loss = math_ops.reduce_sum(loss)
    504 
    505     metrics = {
    506         'accuracy': metrics_lib.accuracy(labels, predictions),
    507         'auc': metrics_lib.auc(labels, predictions)
    508     }
    509 
    510     optimizer = gradient_descent.GradientDescentOptimizer(
    511         params['learning_rate'])
    512 
    513     return model_fn_lib.EstimatorSpec(
    514         mode=mode,
    515         loss=loss,
    516         eval_metric_ops=metrics,
    517         predictions={'probabilities': predictions},
    518         train_op=optimizer.minimize(loss))
    519 
    520   @property
    521   def params(self):
    522     params = {}
    523     params['learning_rate'] = 1.0
    524     return params
    525 
    526   def test_train_single_tower(self):
    527     features = np.array([[1.0], [2.0]])
    528     labels = np.array([[1.0], [2.0]])
    529 
    530     with self.test_session() as session:
    531       replicated_model_fn = replicate_model_fn.replicate_model_fn(
    532           self.model_fn, devices=['/gpu:0'])
    533       estimator_spec = replicated_model_fn(
    534           features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
    535       session.run(variables.global_variables_initializer())
    536 
    537       # loss = feature * c - label
    538       total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
    539       self.assertEqual(total_loss, session.run(estimator_spec.loss))
    540 
    541       # loss' of c is 3.
    542       # new value of c = 10 - learning rate * 3 = 7.0.
    543       session.run(estimator_spec.train_op)
    544       with variable_scope.variable_scope('', reuse=True):
    545         c = variable_scope.get_variable('c', dtype=dtypes.float64)
    546         self.assertEqual(7.0, session.run(c))
    547 
    548 
    549 class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase):
    550 
    551   def model_fn(self, mode, features, labels, params):
    552     c = variable_scope.get_variable(
    553         'c',
    554         initializer=constant_op.constant(10, dtype=dtypes.float64),
    555         dtype=dtypes.float64)
    556 
    557     features = features['features']
    558     predictions = math_ops.multiply(features, c)
    559 
    560     loss = losses.absolute_difference(
    561         labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
    562     loss = math_ops.reduce_sum(loss)
    563 
    564     metrics = {
    565         'accuracy': metrics_lib.accuracy(labels, predictions),
    566         'auc': metrics_lib.auc(labels, predictions)
    567     }
    568 
    569     optimizer = replicate_model_fn.TowerOptimizer(
    570         gradient_descent.GradientDescentOptimizer(params['learning_rate']))
    571 
    572     return model_fn_lib.EstimatorSpec(
    573         mode=mode,
    574         loss=loss,
    575         eval_metric_ops=metrics,
    576         predictions={'probabilities': predictions},
    577         train_op=optimizer.minimize(loss))
    578 
    579   @property
    580   def params(self):
    581     params = {}
    582     params['learning_rate'] = 1.0
    583     return params
    584 
    585   def test_train_single_tower(self):
    586     features = np.array([[1.0], [2.0]])
    587     labels = np.array([[1.0], [2.0]])
    588 
    589     train_input_fn = numpy_io.numpy_input_fn(
    590         x={'features': features}, y=labels, batch_size=2, shuffle=False)
    591 
    592     with self.test_session():
    593       estimator = estimator_lib.Estimator(
    594           model_fn=self.model_fn,
    595           model_dir=tempfile.mkdtemp(),
    596           params=self.params)
    597       estimator.train(train_input_fn, steps=1)
    598 
    599       self.assertEqual(7.0, estimator.get_variable_value('c'))
    600 
    601 
    602 class MakeSureSyncReplicasOptimizerWorks(test_util.TensorFlowTestCase):
    603 
    604   def model_fn(self, mode, features, labels, params):
    605     c = variable_scope.get_variable(
    606         'c',
    607         initializer=constant_op.constant(10, dtype=dtypes.float64),
    608         dtype=dtypes.float64)
    609 
    610     features = features['features']
    611     predictions = math_ops.multiply(features, c)
    612 
    613     loss = losses.absolute_difference(
    614         labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
    615     loss = math_ops.reduce_sum(loss)
    616 
    617     metrics = {
    618         'accuracy': metrics_lib.accuracy(labels, predictions),
    619         'auc': metrics_lib.auc(labels, predictions)
    620     }
    621 
    622     optimizer = gradient_descent.GradientDescentOptimizer(
    623         params['learning_rate'])
    624     optimizer = training.SyncReplicasOptimizer(
    625         optimizer, replicas_to_aggregate=1)
    626     sync_hook = optimizer.make_session_run_hook(True)
    627     optimizer = replicate_model_fn.TowerOptimizer(optimizer)
    628 
    629     return model_fn_lib.EstimatorSpec(
    630         mode=mode,
    631         loss=loss,
    632         eval_metric_ops=metrics,
    633         training_hooks=[sync_hook],
    634         predictions={'probabilities': predictions},
    635         train_op=optimizer.minimize(
    636             loss, global_step=training.get_global_step()))
    637 
    638   @property
    639   def params(self):
    640     params = {}
    641     params['learning_rate'] = 1.0
    642     return params
    643 
    644   def test_train_multiple_towers(self):
    645     features = np.array([[1.0], [2.0]])
    646     labels = np.array([[1.0], [2.0]])
    647 
    648     train_input_fn = numpy_io.numpy_input_fn(
    649         x={'features': features}, y=labels, batch_size=2, shuffle=False)
    650 
    651     model_fn = replicate_model_fn.replicate_model_fn(
    652         self.model_fn,
    653         loss_reduction=losses.Reduction.SUM,
    654         devices=['/gpu:0', '/gpu:1'])
    655 
    656     estimator = estimator_lib.Estimator(
    657         model_fn=model_fn, model_dir=tempfile.mkdtemp(), params=self.params)
    658     estimator.train(train_input_fn, steps=1)
    659 
    660     self.assertEqual(7.0, estimator.get_variable_value('c'))
    661 
    662 
    663 class ReplicateWithTwoOptimizersTest(test_util.TensorFlowTestCase):
    664 
    665   def model_fn(self, mode, features, labels, params):
    666     c = variable_scope.get_variable(
    667         'c',
    668         initializer=constant_op.constant(10, dtype=dtypes.float64),
    669         dtype=dtypes.float64)
    670 
    671     side_effects = variable_scope.get_variable(
    672         'side_effects',
    673         initializer=constant_op.constant(0, dtype=dtypes.float64),
    674         dtype=dtypes.float64,
    675         trainable=False)
    676 
    677     predictions = math_ops.multiply(features, c)
    678 
    679     loss = losses.absolute_difference(
    680         labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
    681     loss = math_ops.reduce_sum(loss)
    682 
    683     metrics = {
    684         'accuracy': metrics_lib.accuracy(labels, predictions),
    685         'auc': metrics_lib.auc(labels, predictions)
    686     }
    687 
    688     first_optimizer = replicate_model_fn.TowerOptimizer(
    689         gradient_descent.GradientDescentOptimizer(1.0))
    690     second_optimizer = replicate_model_fn.TowerOptimizer(
    691         adam.AdamOptimizer(1.0))
    692 
    693     with ops_lib.control_dependencies([side_effects.assign_add(1.0)]):
    694       first_grads_and_vars = first_optimizer.compute_gradients(loss)
    695 
    696     train_op = control_flow_ops.group(
    697         [first_optimizer.apply_gradients(first_grads_and_vars),
    698          second_optimizer.minimize(loss)])
    699 
    700     return model_fn_lib.EstimatorSpec(
    701         mode=mode,
    702         loss=loss,
    703         eval_metric_ops=metrics,
    704         predictions={'probabilities': predictions},
    705         train_op=train_op)
    706 
    707   def test_train(self):
    708     features = np.array([[1.0], [2.0]])
    709     labels = np.array([[1.0], [2.0]])
    710 
    711     with self.test_session() as session:
    712       replicated_model_fn = replicate_model_fn.replicate_model_fn(
    713           self.model_fn,
    714           loss_reduction=losses.Reduction.SUM,
    715           devices=['/gpu:0', '/gpu:1'])
    716       estimator_spec = replicated_model_fn(features, labels,
    717                                            model_fn_lib.ModeKeys.TRAIN, {})
    718       session.run(variables.global_variables_initializer())
    719 
    720       # loss = feature * c - label
    721       total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)
    722       self.assertEqual(total_loss, session.run(estimator_spec.loss))
    723 
    724       # loss' of c is 3.
    725       # new value of c = 10 - learning rate * 3 = 7.0.
    726       # Adam subtracts another ~1.
    727       session.run(estimator_spec.train_op)
    728       with variable_scope.variable_scope('', reuse=True):
    729         c = variable_scope.get_variable('c', dtype=dtypes.float64)
    730         self.assertNear(6.0, session.run(c), 0.000001)
    731 
    732         side_effects = variable_scope.get_variable(
    733             'side_effects', dtype=dtypes.float64)
    734         self.assertNear(2.0, session.run(side_effects), 0.000001)
    735 
    736 
    737 class ReplicateWithTwoLossesAndOneOptimizer(test_util.TensorFlowTestCase):
    738 
    739   def setUp(self):
    740     self._should_skip_optimizer = False
    741     self._towers_left_before_skipping_optimizer = -1
    742 
    743   def incorrectly_skip_optimizer_for_tower(self, tower_number):
    744     self._should_skip_optimizer = True
    745     self._towers_left_before_skipping_optimizer = tower_number
    746 
    747   def should_skip_optimizer(self):
    748     if not self._should_skip_optimizer:
    749       return False
    750     if self._towers_left_before_skipping_optimizer == 0:
    751       return True
    752     else:
    753       self._towers_left_before_skipping_optimizer -= 1
    754       return False
    755 
    756   def model_fn(self, mode, features, labels, params):
    757     c = variable_scope.get_variable(
    758         'c',
    759         initializer=constant_op.constant(10, dtype=dtypes.float64),
    760         dtype=dtypes.float64)
    761     d = variable_scope.get_variable(
    762         'd',
    763         initializer=constant_op.constant(2, dtype=dtypes.float64),
    764         dtype=dtypes.float64)
    765 
    766     predictions = math_ops.multiply(features, c)
    767 
    768     loss = losses.absolute_difference(
    769         labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
    770     loss = math_ops.reduce_sum(loss)
    771 
    772     another_predictions = math_ops.multiply(features, d)
    773     another_loss = losses.absolute_difference(
    774         labels=labels,
    775         predictions=another_predictions,
    776         reduction=losses.Reduction.SUM)
    777     another_loss = math_ops.reduce_sum(another_loss)
    778 
    779     total_loss = math_ops.add(loss, another_loss)
    780 
    781     metrics = {
    782         'accuracy': metrics_lib.accuracy(labels, predictions),
    783         'auc': metrics_lib.auc(labels, predictions)
    784     }
    785 
    786     train_ops = []
    787 
    788     optimizer = replicate_model_fn.TowerOptimizer(
    789         gradient_descent.GradientDescentOptimizer(1.0))
    790     train_ops.append(optimizer.minimize(loss, var_list=[c]))
    791     if not self.should_skip_optimizer():
    792       another_optimizer = replicate_model_fn.TowerOptimizer(
    793           gradient_descent.GradientDescentOptimizer(1.0))
    794       train_ops.append(another_optimizer.minimize(another_loss, var_list=[d]))
    795 
    796     train_op = control_flow_ops.group(train_ops)
    797     return model_fn_lib.EstimatorSpec(
    798         mode=mode,
    799         loss=total_loss,
    800         eval_metric_ops=metrics,
    801         predictions={'probabilities': predictions},
    802         train_op=train_op)
    803 
    804   def test_train(self):
    805     features = np.array([[1.0], [2.0]])
    806     labels = np.array([[1.0], [2.0]])
    807 
    808     with self.test_session() as session:
    809       replicated_model_fn = replicate_model_fn.replicate_model_fn(
    810           self.model_fn,
    811           loss_reduction=losses.Reduction.SUM,
    812           devices=['/gpu:0', '/gpu:1'])
    813       estimator_spec = replicated_model_fn(features, labels,
    814                                            model_fn_lib.ModeKeys.TRAIN, {})
    815       session.run(variables.global_variables_initializer())
    816 
    817       # For each tower, loss = (feature * c - label) + (feature * d - label).
    818       total_loss = (1.0 * 10 - 1.0 + 1.0 * 2.0 - 1.0) + (
    819           2.0 * 10 - 2.0 + 2.0 * 2.0 - 2.0)
    820       self.assertEqual(total_loss, session.run(estimator_spec.loss))
    821 
    822       session.run(estimator_spec.train_op)
    823 
    824       # loss' of c or loss' of d is 3.
    825       # new value of c = 10 - learning rate * 3 = 7.0.
    826       # new value of d = 2  - learning rate * 3 = -1.0.
    827       with variable_scope.variable_scope('', reuse=True):
    828         c = variable_scope.get_variable('c', dtype=dtypes.float64)
    829         self.assertNear(7.0, session.run(c), 0.000001)
    830         d = variable_scope.get_variable('d', dtype=dtypes.float64)
    831         self.assertNear(-1.0, session.run(d), 0.000001)
    832 
    833   def test_different_optimizer_calls_within_towers(self):
    834     self.incorrectly_skip_optimizer_for_tower(1)
    835 
    836     features = np.array([[1.0], [2.0]])
    837     labels = np.array([[1.0], [2.0]])
    838 
    839     with self.test_session(), ops_lib.Graph().as_default():
    840       with self.assertRaisesRegexp(
    841           ValueError, '.+was.+supposed.+to.+make.+same.+optimizer.+calls.+'):
    842         replicated_model_fn = replicate_model_fn.replicate_model_fn(
    843             self.model_fn, devices=['/gpu:0', '/gpu:1'])
    844         _ = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN,
    845                                 {})
    846 
    847 
    848 class FailToWrapOptimizerInTheModelFn(test_util.TensorFlowTestCase):
    849 
    850   def model_fn(self, mode, features, labels, params):
    851     c = variable_scope.get_variable(
    852         'c',
    853         initializer=constant_op.constant(10, dtype=dtypes.float64),
    854         dtype=dtypes.float64)
    855 
    856     predictions = math_ops.multiply(features, c)
    857 
    858     loss = losses.absolute_difference(
    859         labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
    860     loss = math_ops.reduce_sum(loss)
    861 
    862     metrics = {
    863         'accuracy': metrics_lib.accuracy(labels, predictions),
    864         'auc': metrics_lib.auc(labels, predictions)
    865     }
    866 
    867     optimizer = gradient_descent.GradientDescentOptimizer(1.0)
    868     train_op = optimizer.minimize(loss)
    869 
    870     return model_fn_lib.EstimatorSpec(
    871         mode=mode,
    872         loss=loss,
    873         eval_metric_ops=metrics,
    874         predictions={'probabilities': predictions},
    875         train_op=train_op)
    876 
    877   def test_train(self):
    878     features = np.array([[1.0], [2.0]])
    879     labels = np.array([[1.0], [2.0]])
    880 
    881     with self.test_session():
    882       with self.assertRaisesRegexp(ValueError,
    883                                    'Please.+wrap.+with.+TowerOptimizer'):
    884         replicated_model_fn = replicate_model_fn.replicate_model_fn(
    885             self.model_fn, devices=['/gpu:0', '/gpu:1'])
    886         _ = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN,
    887                                 {})
    888 
    889 
class GetLossTowersTest(test_util.TensorFlowTestCase):
  """Tests for `replicate_model_fn._get_loss_towers`."""

  def model_fn(self, mode, features, labels, params):
    """Builds a one-variable model with a hand-computable SUM loss.

    Predictions are `[0.1, 0.2, 0.3, features[0]] + c` against labels
    `[0.1, 0.2, 0.3, labels[0]]`. With `c` initialised to 0.25 the loss is
    `3 * 0.25 + |features[0] + 0.25 - labels[0]|`, assuming each tower
    receives a single example (TODO confirm against the batch splitting).
    """
    c = variable_scope.get_variable(
        'c',
        initializer=constant_op.constant(0.25, dtype=dtypes.float64),
        dtype=dtypes.float64)

    predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c)
    labels = np.array([0.1, 0.2, 0.3, labels[0]])

    loss = losses.absolute_difference(
        labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)

    return model_fn_lib.EstimatorSpec(mode=mode, loss=math_ops.reduce_sum(loss))

  def test_gradients_are_computed(self):
    # NOTE(review): despite the name, the assertions below cover per-tower
    # loss values, devices, op names and variable sharing, not gradients.
    with self.test_session() as session:
      tower_specs = replicate_model_fn._get_loss_towers(
          self.model_fn,
          mode=None,
          features=[[0.6], [1.6]],
          labels=[[0.6], [0.6]],
          params=None,
          config=None,
          loss_reduction=losses.Reduction.SUM,
          devices=['/gpu:0', '/gpu:1'],
          local_ps_devices=['/gpu:0'],
          name_scope_pattern='test_tower_{}')
      session.run(variables.global_variables_initializer())

      self.assertEqual(len(tower_specs), 2)

      # The first tower's ops carry no tower name-scope prefix.
      self.assertEqual('/device:GPU:0', tower_specs[0].loss.device)
      self.assertEqual('Sum:0', tower_specs[0].loss.name)
      # 3 * 0.25 + |0.6 + 0.25 - 0.6| = 1.0
      self.assertEqual(1.0, session.run(tower_specs[0].loss))

      # Later towers are scoped by `name_scope_pattern`, here 'test_tower_1/'.
      self.assertEqual('/device:GPU:1', tower_specs[1].loss.device)
      self.assertEqual('test_tower_1/Sum:0', tower_specs[1].loss.name)
      # The input batch for the second tower had a loss that is 1.0
      # bigger: 0.6 vs 1.6.
      self.assertEqual(2.0, session.run(tower_specs[1].loss))

      # `c` is shared by both towers rather than created once per tower.
      self.assertEqual(1, len(variables.global_variables()))
      self.assertEqual(1, len(variables.trainable_variables()))

      with variable_scope.variable_scope('', reuse=True):
        c = variable_scope.get_variable('c', dtype=dtypes.float64)
        self.assertEqual(0.25, session.run(c))

  def test_gradients_are_computed_with_mean_reduction(self):
    # Same setup as above but with MEAN reduction: each tower's loss is
    # expected to be halved (two towers) and named 'averaged_loss'.
    with self.test_session() as session:
      tower_specs = replicate_model_fn._get_loss_towers(
          self.model_fn,
          mode=model_fn_lib.ModeKeys.EVAL,
          features=[[0.6], [1.6]],
          labels=[[0.6], [0.6]],
          params=None,
          loss_reduction=losses.Reduction.MEAN,
          config=None,
          devices=['/gpu:0', '/gpu:1'],
          local_ps_devices=['/gpu:0'],
          name_scope_pattern='test_tower_{}')
      session.run(variables.global_variables_initializer())

      self.assertEqual(len(tower_specs), 2)

      self.assertEqual('/device:GPU:0', tower_specs[0].loss.device)
      self.assertEqual('averaged_loss:0', tower_specs[0].loss.name)
      # 1.0 (the SUM-case loss) / 2 towers.
      self.assertEqual(0.5, session.run(tower_specs[0].loss))

      self.assertEqual('/device:GPU:1', tower_specs[1].loss.device)
      self.assertEqual('test_tower_1/averaged_loss:0', tower_specs[1].loss.name)
      # The input batch for the second tower had a loss that is 1.0
      # bigger: 0.6 vs 1.6.
      self.assertEqual(1.0, session.run(tower_specs[1].loss))

      self.assertEqual(1, len(variables.global_variables()))
      self.assertEqual(1, len(variables.trainable_variables()))

      with variable_scope.variable_scope('', reuse=True):
        c = variable_scope.get_variable('c', dtype=dtypes.float64)
        self.assertEqual(0.25, session.run(c))

  def test_variables_are_round_robined_correctly(self):
    """Test that creates multiple variables and tests round-robin placement.

    Variables are expected on `local_ps_devices` in creation order, wrapping
    around: a -> GPU:0, b -> GPU:1, c -> GPU:3, d -> GPU:0.
    """

    def model_fn(mode, features, labels, params):
      del params
      # Creates 'a'..'d'; after the loop the local `c` is variable 'd'.
      for variable_name in ['a', 'b', 'c', 'd']:
        c = variable_scope.get_variable(
            variable_name,
            initializer=constant_op.constant(0.25, dtype=dtypes.float64),
            dtype=dtypes.float64)

      predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c)
      labels = np.array([0.1, 0.2, 0.3, labels[0]])
      loss = losses.absolute_difference(
          labels=labels,
          predictions=predictions,
          reduction=losses.Reduction.SUM)
      return model_fn_lib.EstimatorSpec(
          mode=mode, loss=math_ops.reduce_sum(loss))

    with self.test_session() as session:
      tower_specs = replicate_model_fn._get_loss_towers(
          model_fn,
          mode=None,
          features=[[0.6], [1.6], [2.6]],
          labels=[[0.6], [0.6], [2.6]],
          params=None,
          loss_reduction=losses.Reduction.SUM,
          config=None,
          devices=['/gpu:0', '/gpu:1', '/gpu:3'],
          local_ps_devices=['/gpu:0', '/gpu:1', '/gpu:3'],
          name_scope_pattern='test_tower_{}')
      session.run(variables.global_variables_initializer())

      # One tower per device, in the order the devices were given.
      self.assertEqual(len(tower_specs), 3)
      self.assertEqual('/device:GPU:0', tower_specs[0].loss.device)
      self.assertEqual('/device:GPU:1', tower_specs[1].loss.device)
      self.assertEqual('/device:GPU:3', tower_specs[2].loss.device)

      # Variable placement cycles through `local_ps_devices`.
      with variable_scope.variable_scope('', reuse=True):
        a = variable_scope.get_variable('a', dtype=dtypes.float64)
        self.assertEqual('/device:GPU:0', a.device)
        b = variable_scope.get_variable('b', dtype=dtypes.float64)
        self.assertEqual('/device:GPU:1', b.device)
        c = variable_scope.get_variable('c', dtype=dtypes.float64)
        self.assertEqual('/device:GPU:3', c.device)
        d = variable_scope.get_variable('d', dtype=dtypes.float64)
        self.assertEqual('/device:GPU:0', d.device)
   1022 
   1023 
   1024 class SplitBatchTest(test_util.TensorFlowTestCase):
   1025 
   1026   def evaluate_shards(self, first_list, second_list):
   1027     evaluate_items = lambda x: x.eval()
   1028     return list(map(evaluate_items, first_list)), list(
   1029         map(evaluate_items, second_list))
   1030 
   1031   def assertSparseValuesEqual(self, a, b):
   1032     self.assertAllEqual(a.indices, b.indices)
   1033     self.assertAllEqual(a.values, b.values)
   1034     self.assertAllEqual(a.dense_shape, b.dense_shape)
   1035 
   1036   def test_simple_half_split(self):
   1037     with self.test_session():
   1038       features = [0.0, 1.0, 2.0, 3.0]
   1039       labels = [10.0, 11.0, 12.0, 13.0]
   1040       feature_shards, label_shards = replicate_model_fn._split_batch(
   1041           features, labels, 2, device='/gpu:0')
   1042 
   1043       feature_shards, label_shards = self.evaluate_shards(
   1044           feature_shards, label_shards)
   1045 
   1046       self.assertAllEqual([[0.0, 1.0], [2.0, 3.0]], feature_shards)
   1047       self.assertAllEqual([[10.0, 11.0], [12.0, 13.0]], label_shards)
   1048 
   1049   def test_to_each_their_own(self):
   1050     with self.test_session():
   1051       features = [0.0, 1.0, 2.0, 3.0]
   1052       labels = [10.0, 11.0, 12.0, 13.0]
   1053       feature_shards, label_shards = replicate_model_fn._split_batch(
   1054           features, labels, 4, device='/gpu:0')
   1055 
   1056       feature_shards, label_shards = self.evaluate_shards(
   1057           feature_shards, label_shards)
   1058 
   1059       self.assertAllEqual([[0.0], [1.0], [2.0], [3.0]], feature_shards)
   1060       self.assertAllEqual([[10.0], [11.0], [12.0], [13.0]], label_shards)
   1061 
   1062   def test_one_batch(self):
   1063     with self.test_session():
   1064       features = [0.0, 1.0, 2.0, 3.0]
   1065       labels = [10.0, 11.0, 12.0, 13.0]
   1066       feature_shards, label_shards = replicate_model_fn._split_batch(
   1067           features, labels, 1, device='/gpu:0')
   1068 
   1069       feature_shards, label_shards = self.evaluate_shards(
   1070           feature_shards, label_shards)
   1071 
   1072       self.assertAllEqual([[0.0, 1.0, 2.0, 3.0]], feature_shards)
   1073       self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards)
   1074 
   1075   def test_half_split_in_dictionary(self):
   1076     with self.test_session():
   1077       features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
   1078       labels = [10.0, 11.0, 12.0, 13.0]
   1079 
   1080       feature_shards, label_shards = replicate_model_fn._split_batch(
   1081           features, labels, 2, device='/gpu:0')
   1082 
   1083       self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval())
   1084       self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval())
   1085       self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval())
   1086       self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval())
   1087       self.assertAllEqual([10.0, 11.0], label_shards[0].eval())
   1088       self.assertAllEqual([12.0, 13.0], label_shards[1].eval())
   1089 
   1090   def test_sparse_tensor_can_be_split_unevenly(self):
   1091     with self.test_session():
   1092       features = {
   1093           'x':
   1094               sparse_tensor.SparseTensor(
   1095                   indices=[[0, 0], [1, 2], [2, 2]],
   1096                   values=[1.0, 2.0, 3.0],
   1097                   dense_shape=[3, 4])
   1098       }
   1099       labels = np.array([[1.0], [2.0]])
   1100 
   1101       feature_shards, label_shards = replicate_model_fn._split_batch(
   1102           features, labels, 2, device='/gpu:0')
   1103 
   1104       self.assertSparseValuesEqual(
   1105           sparse_tensor.SparseTensorValue(
   1106               indices=[[0, 0], [1, 2]], values=[1., 2.], dense_shape=[2, 4]),
   1107           feature_shards[0]['x'].eval())
   1108       self.assertSparseValuesEqual(
   1109           sparse_tensor.SparseTensorValue(
   1110               indices=[[0, 2]], values=[3.], dense_shape=[1, 4]),
   1111           feature_shards[1]['x'].eval())
   1112       self.assertAllEqual([[1.0]], label_shards[0].eval())
   1113       self.assertAllEqual([[2.0]], label_shards[1].eval())
   1114 
   1115   def test_sparse_tensor_can_be_split_unevenly_repeated_row(self):
   1116     with self.test_session():
   1117       features = {
   1118           'x':
   1119               sparse_tensor.SparseTensor(
   1120                   indices=[[0, 0], [1, 0], [1, 1]],
   1121                   values=[1.0, 2.0, 3.0],
   1122                   dense_shape=[3, 4])
   1123       }
   1124       labels = np.array([[1.0], [2.0]])
   1125 
   1126       feature_shards, label_shards = replicate_model_fn._split_batch(
   1127           features, labels, 2, device='/gpu:0')
   1128 
   1129       self.assertSparseValuesEqual(
   1130           sparse_tensor.SparseTensorValue(
   1131               indices=[[0, 0], [1, 0], [1, 1]],
   1132               values=[1., 2., 3.],
   1133               dense_shape=[2, 4]), feature_shards[0]['x'].eval())
   1134 
   1135       second_batch = feature_shards[1]['x'].eval()
   1136       self.assertFalse(len(second_batch.indices))
   1137       self.assertFalse(len(second_batch.values))
   1138       self.assertAllEqual([1, 4], second_batch.dense_shape)
   1139       self.assertAllEqual([[1.0]], label_shards[0].eval())
   1140       self.assertAllEqual([[2.0]], label_shards[1].eval())
   1141 
   1142   def test_one_batch_in_dictionary(self):
   1143     with self.test_session() as session:  # pylint: disable=unused-variable
   1144       features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
   1145       labels = [10.0, 11.0, 12.0, 13.0]
   1146 
   1147       feature_shards, label_shards = replicate_model_fn._split_batch(
   1148           features, labels, 1, device='/gpu:0')
   1149 
   1150       self.assertAllEqual([0.0, 1.0, 2.0, 3.0],
   1151                           feature_shards[0]['first'].eval())
   1152       self.assertAllEqual([4.0, 5.0, 6.0, 7.0],
   1153                           feature_shards[0]['second'].eval())
   1154       self.assertAllEqual([10.0, 11.0, 12.0, 13.0], label_shards[0].eval())
   1155 
   1156   def test_feature_and_label_dictionaries(self):
   1157     with self.test_session() as session:  # pylint: disable=unused-variable
   1158       features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
   1159       labels = {'first': [10.0, 11.0], 'second': [12.0, 13.0]}
   1160 
   1161       feature_shards, label_shards = replicate_model_fn._split_batch(
   1162           features, labels, 2, device='/gpu:0')
   1163 
   1164       self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval())
   1165       self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval())
   1166       self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval())
   1167       self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval())
   1168       self.assertAllEqual([10.0], label_shards[0]['first'].eval())
   1169       self.assertAllEqual([12.0], label_shards[0]['second'].eval())
   1170       self.assertAllEqual([11], label_shards[1]['first'].eval())
   1171       self.assertAllEqual([13.0], label_shards[1]['second'].eval())
   1172 
   1173 
   1174 class TrainSpecTest(test_util.TensorFlowTestCase):
   1175 
   1176   expected_predictions = {}
   1177 
   1178   def create_estimator_spec(self, loss):
   1179     return model_fn_lib.EstimatorSpec(
   1180         mode=model_fn_lib.ModeKeys.TRAIN,
   1181         loss=loss,
   1182         train_op=loss,  # Not used; currently required.
   1183         predictions=self.expected_predictions)
   1184 
   1185   def create_constant_loss(self, loss_value):
   1186     return constant_op.constant(loss_value, dtype=dtypes.float64)
   1187 
   1188   def test_example(self):
   1189     with self.test_session() as session:
   1190       tower_losses = list(map(self.create_constant_loss, [2, 4, 6]))
   1191       tower_specs = list(map(self.create_estimator_spec, tower_losses))
   1192 
   1193       expected_train_op = tower_losses[1]
   1194 
   1195       estimator_spec = replicate_model_fn._train_spec(
   1196           tower_specs, expected_train_op, aggregation_device='/gpu:0')
   1197 
   1198       self.assertEqual(expected_train_op, estimator_spec.train_op)
   1199       self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
   1200       self.assertEqual(self.expected_predictions, estimator_spec.predictions)
   1201 
   1202 
   1203 class EvalSpecTest(test_util.TensorFlowTestCase):
   1204 
   1205   def create_estimator_spec(self, loss, metrics):
   1206     return model_fn_lib.EstimatorSpec(
   1207         mode=model_fn_lib.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics)
   1208 
   1209   def create_constant_loss(self, loss_value):
   1210     return constant_op.constant(loss_value, dtype=dtypes.float64)
   1211 
   1212   def create_eval_metrics(self, noise):
   1213     predictions = np.array([0.1, 0.2, 0.3, 0.6 + noise])
   1214     labels = np.array([0.1, 0.2, 0.3, 0.6])
   1215 
   1216     metrics = {
   1217         'accuracy': metrics_lib.accuracy(labels, predictions),
   1218         'auc': metrics_lib.auc(labels, predictions)
   1219     }
   1220     return metrics
   1221 
   1222   def test_example(self):
   1223     with self.test_session() as session:
   1224       tower_losses = map(self.create_constant_loss, [2, 4, 6])
   1225       tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3])
   1226       tower_specs = [
   1227           self.create_estimator_spec(l, m)
   1228           for l, m in zip(tower_losses, tower_metrics)
   1229       ]
   1230       session.run(variables.local_variables_initializer())
   1231 
   1232       estimator_spec = replicate_model_fn._eval_spec(
   1233           tower_specs, aggregation_device='/device:GPU:0')
   1234 
   1235       accuracy, a = estimator_spec.eval_metric_ops['accuracy']
   1236       auc, b = estimator_spec.eval_metric_ops['auc']
   1237 
   1238       self.assertEqual('/device:CPU:0', accuracy.device)
   1239       self.assertEqual('/device:CPU:0', auc.device)
   1240 
   1241       session.run([a, b])
   1242       accuracy, auc = session.run([accuracy, auc])
   1243 
   1244       self.assertNear((12 - 2) / 12, accuracy, 0.01)
   1245       self.assertEqual(0, auc)
   1246       self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
   1247 
   1248   def test_handles_single_tower(self):
   1249     with self.test_session() as session:
   1250       tower_losses = map(self.create_constant_loss, [5])
   1251       tower_metrics = map(self.create_eval_metrics, [0.2])
   1252       tower_specs = [
   1253           self.create_estimator_spec(l, m)
   1254           for l, m in zip(tower_losses, tower_metrics)
   1255       ]
   1256       session.run(variables.local_variables_initializer())
   1257 
   1258       estimator_spec = replicate_model_fn._eval_spec(
   1259           tower_specs, aggregation_device='/device:GPU:0')
   1260 
   1261       accuracy, a = estimator_spec.eval_metric_ops['accuracy']
   1262       auc, b = estimator_spec.eval_metric_ops['auc']
   1263 
   1264       self.assertEqual('/device:CPU:0', accuracy.device)
   1265       self.assertEqual('/device:CPU:0', auc.device)
   1266 
   1267       session.run([a, b])
   1268       accuracy = session.run(accuracy)
   1269       auc = session.run(auc)
   1270 
   1271       self.assertNear((4 - 1) / 4, accuracy, 0.01)
   1272       self.assertEqual(0, auc)
   1273       self.assertEqual(5, session.run(estimator_spec.loss))
   1274 
   1275 
   1276 class PredictSpecTest(test_util.TensorFlowTestCase):
   1277 
   1278   def model_fn(self, mode, features, labels, params):
   1279     c = variable_scope.get_variable(
   1280         'c',
   1281         initializer=constant_op.constant(0.25, dtype=dtypes.float64),
   1282         dtype=dtypes.float64)
   1283 
   1284     predictions = math_ops.add(np.array([features[0], features[0]]), c)
   1285 
   1286     return model_fn_lib.EstimatorSpec(
   1287         mode=model_fn_lib.ModeKeys.PREDICT,
   1288         predictions={
   1289             'probabilities': predictions
   1290         })
   1291 
   1292   def test_example(self):
   1293     with self.test_session() as session:
   1294       tower_specs = replicate_model_fn._get_loss_towers(
   1295           self.model_fn,
   1296           mode=None,
   1297           features=[[0.1], [0.2]],
   1298           loss_reduction=losses.Reduction.SUM,
   1299           labels=[[], []],
   1300           params=None,
   1301           config=None,
   1302           devices=['/gpu:0', '/gpu:1'],
   1303           local_ps_devices=['/gpu:0'],
   1304       )
   1305       session.run(variables.global_variables_initializer())
   1306 
   1307       estimator_spec = replicate_model_fn._predict_spec(
   1308           tower_specs, aggregation_device='/gpu:0')
   1309 
   1310       self.assertEqual('/device:GPU:0',
   1311                        estimator_spec.predictions['probabilities'].device)
   1312       self.assertAllClose({
   1313           'probabilities': np.array([0.35, 0.35, 0.45, 0.45])
   1314       }, session.run(estimator_spec.predictions))
   1315 
   1316 
   1317 class ReduceMetricVariablesTest(test_util.TensorFlowTestCase):
   1318 
   1319   def create_metric_variable(self, initial_value, name):
   1320     return variable_scope.variable(
   1321         initial_value,
   1322         trainable=False,
   1323         collections=[ops_lib.GraphKeys.METRIC_VARIABLES],
   1324         validate_shape=True,
   1325         name=name)
   1326 
   1327   def create_tower_metrics(self, tower_id):
   1328     with variable_scope.variable_scope('', reuse=(tower_id != 0)):
   1329       self.create_metric_variable(1.3 * (tower_id + 1), 'total')
   1330       self.create_metric_variable(2.3 * (tower_id + 1), 'count')
   1331       self.create_metric_variable(
   1332           np.array([3.3, 3.5, 3.7]) * (tower_id + 1), 'total')
   1333 
   1334   def test_example(self):
   1335     with self.test_session() as session:
   1336       for tower_id in range(3):
   1337         self.create_tower_metrics(tower_id)
   1338 
   1339       session.run(
   1340           variables.variables_initializer(
   1341               ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
   1342 
   1343       session.run(
   1344           replicate_model_fn._reduce_metric_variables(number_of_towers=3))
   1345 
   1346       # 1st tower = 1.3, 2.3,  [3.3, 3.5, 3.7]
   1347       # 2nd tower = 2.6, 4.6,  [6.6, 7.0, 7.4]
   1348       # 3rd tower = 3.9, 6.9,  [9.9, 10.5, 11.1]
   1349       # Reduced =   7.8, 13.8, [19.8, 21.0, 22.2]
   1350       # Towers are accumulated in the first tower.
   1351       local_metrics = session.run(
   1352           ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
   1353 
   1354       self.assertNear(7.8, local_metrics[0], 0.01)
   1355       self.assertNear(13.8, local_metrics[1], 0.01)
   1356       self.assertAllClose([19.8, 21., 22.1], local_metrics[2], 0.01)
   1357       self.assertNear(0.0, local_metrics[3], 0.01)
   1358       self.assertNear(0.0, local_metrics[4], 0.01)
   1359       self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
   1360       self.assertNear(0.0, local_metrics[6], 0.01)
   1361       self.assertNear(0.0, local_metrics[7], 0.01)
   1362       self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
   1363 
   1364   def test_reduce_is_idempotent(self):
   1365     with self.test_session() as session:
   1366       for tower_id in range(3):
   1367         self.create_tower_metrics(tower_id)
   1368 
   1369       session.run(
   1370           variables.variables_initializer(
   1371               ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
   1372 
   1373       for _ in range(20):
   1374         session.run(
   1375             replicate_model_fn._reduce_metric_variables(number_of_towers=3))
   1376 
   1377       local_metrics = session.run(
   1378           ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
   1379 
   1380       self.assertNear(7.8, local_metrics[0], 0.01)
   1381       self.assertNear(13.8, local_metrics[1], 0.01)
   1382       self.assertAllClose([19.8, 21., 22.1], local_metrics[2], 0.01)
   1383       self.assertNear(0.0, local_metrics[3], 0.01)
   1384       self.assertNear(0.0, local_metrics[4], 0.01)
   1385       self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
   1386       self.assertNear(0.0, local_metrics[6], 0.01)
   1387       self.assertNear(0.0, local_metrics[7], 0.01)
   1388       self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
   1389 
   1390   def test_handles_single_tower(self):
   1391     with self.test_session() as session:
   1392       self.create_tower_metrics(0)
   1393       session.run(
   1394           variables.variables_initializer(
   1395               ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
   1396 
   1397       session.run(
   1398           replicate_model_fn._reduce_metric_variables(number_of_towers=1))
   1399 
   1400       local_metrics = session.run(
   1401           ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))
   1402 
   1403       self.assertNear(1.3, local_metrics[0], 0.01)
   1404       self.assertNear(2.3, local_metrics[1], 0.01)
   1405       self.assertAllClose([3.3, 3.5, 3.7], local_metrics[2], 0.01)
   1406 
   1407   def test_doesnt_accept_uneven_number_of_variables(self):
   1408     with self.test_session() as session:
   1409       for tower_id in range(3):
   1410         self.create_tower_metrics(tower_id)
   1411       self.create_metric_variable(-1.0, 'oddball')
   1412 
   1413       session.run(
   1414           variables.variables_initializer(
   1415               ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))
   1416 
   1417       with self.assertRaisesRegexp(
   1418           ValueError, '.+Expected.+local.+variables.+but.+got.+instead.+'):
   1419         session.run(
   1420             replicate_model_fn._reduce_metric_variables(number_of_towers=3))
   1421 
   1422 
   1423 class MergeExportOutputsTest(test_util.TensorFlowTestCase):
   1424 
   1425   def model_fn(self, mode, features, labels, params):
   1426     c = variable_scope.get_variable(
   1427         'c',
   1428         initializer=constant_op.constant(10, dtype=dtypes.float64),
   1429         dtype=dtypes.float64)
   1430 
   1431     predictions = {'probabilities': math_ops.multiply(features, c)}
   1432     loss = losses.absolute_difference(
   1433         labels=labels,
   1434         predictions=predictions['probabilities'],
   1435         reduction=losses.Reduction.SUM)
   1436 
   1437     metrics = {
   1438         'accuracy': metrics_lib.accuracy(labels, predictions['probabilities']),
   1439         'auc': metrics_lib.auc(labels, predictions['probabilities'])
   1440     }
   1441     tensor_string_repr = str(features)
   1442     classes = constant_op.constant(
   1443         re.search('(split_inputs/split:[0-9])', tensor_string_repr).group(1),
   1444         dtype=dtypes.string)
   1445 
   1446     export_outputs = {
   1447         signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
   1448             export_output.PredictOutput(predictions),
   1449         'classification_output':
   1450             export_output.ClassificationOutput(predictions['probabilities'],
   1451                                                classes),
   1452         'classification_scores':
   1453             export_output.ClassificationOutput(
   1454                 scores=predictions['probabilities']),
   1455         'classification_classes':
   1456             export_output.ClassificationOutput(classes=classes),
   1457         'regression_output':
   1458             export_output.RegressionOutput(predictions['probabilities']),
   1459     }
   1460 
   1461     return model_fn_lib.EstimatorSpec(
   1462         mode=mode,
   1463         loss=math_ops.reduce_sum(loss),
   1464         eval_metric_ops=metrics,
   1465         predictions=predictions,
   1466         export_outputs=export_outputs)
   1467 
   1468   def replicate_estimator_spec(self, session):
   1469     features = np.array([0.01, 0.002])
   1470     labels = np.array([0.01, 0.02])
   1471 
   1472     replicated_model_fn = replicate_model_fn.replicate_model_fn(
   1473         self.model_fn, devices=['/gpu:0', '/gpu:1'])
   1474     estimator_spec = replicated_model_fn(features, labels,
   1475                                          model_fn_lib.ModeKeys.PREDICT, {})
   1476     session.run(variables.global_variables_initializer())
   1477     return estimator_spec
   1478 
   1479   def test_merge_predict_output(self):
   1480     with self.test_session() as session:
   1481       estimator_spec = self.replicate_estimator_spec(session)
   1482       self.assertAllClose(
   1483           {
   1484               'probabilities': np.array([0.1, 0.02])
   1485           },
   1486           session.run(estimator_spec.export_outputs[
   1487               signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs))
   1488 
   1489   def test_merge_classification_output_scores_classes(self):
   1490     with self.test_session() as session:
   1491       estimator_spec = self.replicate_estimator_spec(session)
   1492       self.assertAllClose(
   1493           [0.1, 0.02],
   1494           session.run(
   1495               estimator_spec.export_outputs['classification_output'].scores))
   1496       self.assertAllEqual(
   1497           [b'split_inputs/split:0', b'split_inputs/split:1'],
   1498           session.run(
   1499               estimator_spec.export_outputs['classification_output'].classes))
   1500 
   1501   def test_merge_classification_output_scores(self):
   1502     with self.test_session() as session:
   1503       estimator_spec = self.replicate_estimator_spec(session)
   1504       self.assertAllClose(
   1505           [0.1, 0.02],
   1506           session.run(
   1507               estimator_spec.export_outputs['classification_scores'].scores))
   1508       self.assertEqual(
   1509           None, estimator_spec.export_outputs['classification_scores'].classes)
   1510 
   1511   def test_merge_classification_output_classes(self):
   1512     with self.test_session() as session:
   1513       estimator_spec = self.replicate_estimator_spec(session)
   1514       self.assertAllEqual(
   1515           [b'split_inputs/split:0', b'split_inputs/split:1'],
   1516           session.run(
   1517               estimator_spec.export_outputs['classification_classes'].classes))
   1518       self.assertEqual(
   1519           None, estimator_spec.export_outputs['classification_classes'].scores)
   1520 
   1521   def test_merge_regression_output(self):
   1522     with self.test_session() as session:
   1523       estimator_spec = self.replicate_estimator_spec(session)
   1524       self.assertAllClose(
   1525           [0.1, 0.02],
   1526           session.run(estimator_spec.export_outputs['regression_output'].value))
   1527 
   1528 
   1529 class GetLocalDevicesTest(test_util.TensorFlowTestCase):
   1530 
   1531   def test_there_is_at_least_a_cpu(self):
   1532     self.assertTrue(replicate_model_fn._get_local_devices('CPU'))
   1533 
   1534   def test_there_is_no_xpu(self):
   1535     self.assertFalse(
   1536         replicate_model_fn._get_local_devices('XPU'))  # XPU doesn't exist.
   1537 
   1538   def test_whether_there_is_a_gpu(self):
   1539     if test.is_gpu_available():
   1540       self.assertTrue(len(replicate_model_fn._get_local_devices('GPU')))
   1541 
   1542 
   1543 class LocalDeviceSetterTest(test_util.TensorFlowTestCase):
   1544 
   1545   def test_vars_are_on_ps_but_ops_are_on_workers(self):
   1546     ps_devices = ['/device:GPU:3']
   1547     round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices))
   1548 
   1549     local_device_setter = replicate_model_fn._local_device_setter(
   1550         ps_devices=ps_devices,
   1551         ps_strategy=round_robin,
   1552         worker_device='/device:GPU:2')
   1553 
   1554     with ops_lib.device(local_device_setter):
   1555       a = variables.Variable(0.01)
   1556       self.assertEqual('/device:GPU:3', a.device)
   1557 
   1558       b = variables.Variable(0.02)
   1559       self.assertEqual('/device:GPU:3', b.device)
   1560 
   1561       c = variables.Variable(0.03)
   1562       self.assertEqual('/device:GPU:3', c.device)
   1563 
   1564       a_op = array_ops.concat(a, axis=0)
   1565       self.assertEqual('/device:GPU:2', a_op.device)
   1566 
   1567       b_op = array_ops.concat(b, axis=0)
   1568       self.assertEqual('/device:GPU:2', b_op.device)
   1569 
   1570   def test_round_robin_placement(self):
   1571     ps_devices = [
   1572         '/device:GPU:0', '/device:GPU:1', '/device:GPU:3', '/device:GPU:4'
   1573     ]
   1574     round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices))
   1575 
   1576     local_device_setter = replicate_model_fn._local_device_setter(
   1577         ps_devices=ps_devices,
   1578         ps_strategy=round_robin,
   1579         worker_device='/device:GPU:2')
   1580 
   1581     with ops_lib.device(local_device_setter):
   1582       a = variables.Variable(0.01)
   1583       self.assertEqual('/device:GPU:0', a.device)
   1584 
   1585       b = variables.Variable(0.02)
   1586       self.assertEqual('/device:GPU:1', b.device)
   1587 
   1588       c = variables.Variable(0.03)
   1589       self.assertEqual('/device:GPU:3', c.device)
   1590 
   1591       a_op = array_ops.concat(a, axis=0)
   1592       self.assertEqual('/device:GPU:2', a_op.device)
   1593 
   1594       b_op = array_ops.concat(b, axis=0)
   1595       self.assertEqual('/device:GPU:2', b_op.device)
   1596 
   1597       c = variables.Variable(0.03)
   1598       self.assertEqual('/device:GPU:4', c.device)
   1599 
   1600       d = variables.Variable(0.03)
   1601       self.assertEqual('/device:GPU:0', d.device)
   1602 
   1603       c_op = array_ops.concat(c, axis=0)
   1604       self.assertEqual('/device:GPU:2', c_op.device)
   1605 
   1606 
   1607 class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
   1608 
   1609   def test_vectors(self):
   1610     with self.test_session() as session:
   1611       total = replicate_model_fn._compute_sum_on_device(
   1612           [1.0, 2.0, 3.0, 4.0], device='/device:GPU:0', name='test_sum')
   1613 
   1614       self.assertEqual('/device:GPU:0', total.device)
   1615       self.assertEqual('test_sum', total.op.name)
   1616       self.assertEqual(10.0, session.run(total))
   1617 
   1618   def test_tensors(self):
   1619     with self.test_session() as session:
   1620       total = replicate_model_fn._compute_sum_on_device(
   1621           [[1.0, 2.0], [3.0, 4.0]], device='/device:GPU:0', name='test_sum')
   1622 
   1623       self.assertEqual('/device:GPU:0', total.device)
   1624       self.assertEqual('test_sum', total.op.name)
   1625       self.assertAllEqual([4.0, 6.0], session.run(total))
   1626 
   1627   def test_indexedslices(self):
   1628     with self.test_session() as session:
   1629       a = ops_lib.IndexedSlices(
   1630           constant_op.constant([1.0, 2.0]), [0, 1],
   1631           dense_shape=constant_op.constant([2]))
   1632       b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
   1633 
   1634       total = replicate_model_fn._compute_sum_on_device(
   1635           [a, b], device='/device:GPU:0')
   1636 
   1637       self.assertEqual('/device:GPU:0', total.device)
   1638       self.assertAllEqual([4.0, 6.0],
   1639                           session.run(ops_lib.convert_to_tensor(total)))
   1640 
   1641   def test_indexedslices_higher_dimensions(self):
   1642     with self.test_session() as session:
   1643       a = ops_lib.IndexedSlices(
   1644           constant_op.constant([[1.0, 5.0], [2.0, 6.0]]), [0, 1],
   1645           dense_shape=constant_op.constant([2, 4]))
   1646       b = ops_lib.IndexedSlices(
   1647           constant_op.constant([[3.0, 7.0], [4.0, 8.0]]), [0, 1])
   1648 
   1649       total = replicate_model_fn._compute_sum_on_device(
   1650           [a, b], device='/device:GPU:0')
   1651 
   1652       self.assertEqual('/device:GPU:0', total.device)
   1653       self.assertAllEqual([[4.0, 12.0], [6.0, 14.0]],
   1654                           session.run(ops_lib.convert_to_tensor(total)))
   1655 
   1656   def test_indexedslices_some_dont_overlap(self):
   1657     with self.test_session() as session:
   1658       a = ops_lib.IndexedSlices(
   1659           constant_op.constant([1.0, 2.0]), [0, 3],
   1660           dense_shape=constant_op.constant([4]))
   1661       b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
   1662 
   1663       total = replicate_model_fn._compute_sum_on_device(
   1664           [a, b], device='/device:GPU:0')
   1665 
   1666       self.assertEqual('/device:GPU:0', total.device)
   1667       self.assertAllEqual([4.0, 4.0, 0.0, 2.0],
   1668                           session.run(ops_lib.convert_to_tensor(total)))
   1669 
   1670   def test_no_name_for_indexslices(self):
   1671     a = ops_lib.IndexedSlices(
   1672         constant_op.constant([1.0, 2.0]), [0, 1],
   1673         dense_shape=constant_op.constant([2]))
   1674     b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
   1675 
   1676     with self.assertRaisesRegexp(ValueError, '.+name.+not.+expected.+'):
   1677       _ = replicate_model_fn._compute_sum_on_device(
   1678           [a, b], device='/device:GPU:0', name='cant_name_indexslices')
   1679 
   1680 
   1681 class ConcatTensorDictsTest(test_util.TensorFlowTestCase):
   1682 
   1683   def test_example(self):
   1684     tensor_dicts = [
   1685         {
   1686             'a': np.array([1.0, 2.0]),
   1687             'b': np.array([11.0]),
   1688             'c': np.array([21.0]),
   1689         },
   1690         {
   1691             'a': np.array([3.0]),
   1692             'b': np.array([12.0, 13.0]),
   1693         },
   1694         {
   1695             'b': np.array([14.0]),
   1696         },
   1697     ]
   1698 
   1699     with self.test_session() as session:
   1700       self.assertAllClose({
   1701           'a': np.array([1.0, 2.0, 3.0]),
   1702           'b': np.array([11.0, 12.0, 13.0, 14.0]),
   1703           'c': np.array([21.0]),
   1704       }, session.run(replicate_model_fn._concat_tensor_dicts(*tensor_dicts)))
   1705 
   1706 
if __name__ == '__main__':
  # Run this module's test cases through the `test` module's entry point.
  test.main()
   1709