Home | History | Annotate | Download | only in python
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests that show that DistributionStrategy works with Keras OptimizerV2."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from absl.testing import parameterized
     22 import numpy as np
     23 
     24 from tensorflow.contrib.distribute.python import combinations
     25 from tensorflow.python import keras
     26 from tensorflow.python.distribute import distribution_strategy_context as ds_context
     27 from tensorflow.python.eager import context
     28 from tensorflow.python.framework import constant_op
     29 from tensorflow.python.framework import dtypes
     30 from tensorflow.python.framework import ops
     31 from tensorflow.python.keras.optimizer_v2 import adam
     32 from tensorflow.python.keras.optimizer_v2 import gradient_descent
     33 from tensorflow.python.ops import math_ops
     34 from tensorflow.python.ops import variable_scope
     35 from tensorflow.python.ops import variables
     36 from tensorflow.python.platform import test
     37 
     38 
     39 def get_model():
     40   x = keras.layers.Input(shape=(3,), name='input')
     41   y = keras.layers.Dense(4, name='dense')(x)
     42   model = keras.Model(x, y)
     43   return model
     44 
     45 
     46 class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase):
     47 
     48   @combinations.generate(
     49       combinations.combine(
     50           distribution=[
     51               combinations.core_mirrored_strategy_with_gpu_and_cpu,
     52               combinations.core_mirrored_strategy_with_two_gpus,
     53               combinations.parameter_server_strategy_with_two_gpus,
     54           ],
     55           mode=['graph', 'eager']))
     56   def testKerasOptimizerWithUnequalInput(self, distribution):
     57     with distribution.scope():
     58       var = variables.Variable(
     59           2.0, name='var', aggregation=variable_scope.VariableAggregation.SUM)
     60       optimizer = adam.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2)
     61       all_vars = []
     62 
     63       def model_fn():
     64 
     65         def loss_fn():
     66           replica_id = _replica_id()
     67           return math_ops.cast(replica_id + 1, dtype=dtypes.float32) * 0.5 * var
     68 
     69         train_op = optimizer.minimize(loss_fn, var_list=[var])
     70 
     71         return train_op, optimizer
     72 
     73       def train_fn():
     74         train_op, optimizer = distribution.extended.call_for_each_replica(
     75             model_fn)
     76         if not all_vars:
     77           all_vars.append(var)
     78           all_vars.append(optimizer.get_slot(var, 'm'))
     79           all_vars.append(optimizer.get_slot(var, 'v'))
     80         return distribution.group(train_op)
     81 
     82       if not context.executing_eagerly():
     83         with self.cached_session() as sess:
     84           train_fn = sess.make_callable(train_fn())
     85       self.evaluate(variables.global_variables_initializer())
     86 
     87       # first step.
     88       train_fn()
     89       # var(1) = var(0) - lr * m(1) * sqrt(1 - beta2) / sqrt(v(1)) / (1 - beta1)
     90       #        = 2.0 - 0.01 * 1.2 * sqrt(0.8) / sqrt(1.8) / 0.8
     91       self.assertAllClose(1.99, self.evaluate(all_vars[0]))
     92       # m(1) = beta1 * m(0) + (1-beta1) * grad = 0.2 * 0 + 0.8 * (1 + 2) / 2
     93       self.assertAllClose(1.2, self.evaluate(all_vars[1]))
     94       # v(1) = beta2 * v(0) + (1-beta2) * grad^2 = 0.2 * 0 + 0.8 * 2.25
     95       self.assertAllClose(1.8, self.evaluate(all_vars[2]))
     96 
     97       # second step.
     98       train_fn()
     99       # var(1) = var(0) - lr * 2 = 1.98
    100       self.assertAllClose(1.98, self.evaluate(all_vars[0]))
    101       # m(2) = beta1 * m(1) + (1-beta1) * grad = 0.2 * 1.2 + 0.8 * 1.5
    102       self.assertAllClose(1.44, self.evaluate(all_vars[1]))
    103       # v(2) = beta2 * v(1) + (1-beta2) * grad^2 = 0.2 * 1.8 + 0.8 * 2.25
    104       self.assertAllClose(2.16, self.evaluate(all_vars[2]))
    105 
    106   @combinations.generate(
    107       combinations.combine(
    108           distribution=[
    109               combinations.core_mirrored_strategy_with_gpu_and_cpu,
    110               combinations.parameter_server_strategy_with_two_gpus,
    111           ],
    112           mode=['graph', 'eager']))
    113   def testOptimizerWithKerasModelAndNumpyArrays(self, distribution):
    114 
    115     with self.cached_session():
    116       with distribution.scope():
    117         model = get_model()
    118         optimizer = gradient_descent.SGD(0.001)
    119         loss = 'mse'
    120         metrics = ['mae']
    121         model.compile(optimizer, loss, metrics=metrics)
    122 
    123       inputs = np.zeros((64, 3), dtype=np.float32)
    124       targets = np.zeros((64, 4), dtype=np.float32)
    125 
    126       model.fit(
    127           inputs,
    128           targets,
    129           epochs=1,
    130           batch_size=2,
    131           verbose=0,
    132           validation_data=(inputs, targets))
    133       model.evaluate(inputs, targets)
    134       model.predict(inputs)
    135 
    136 
    137 def _replica_id():
    138   replica_id = ds_context.get_replica_context().replica_id_in_sync_group
    139   if not isinstance(replica_id, ops.Tensor):
    140     replica_id = constant_op.constant(replica_id)
    141   return replica_id
    142 
    143 
# Run this module's tests through `test.main()` when executed as a script.
if __name__ == '__main__':
  test.main()
    146