Home | History | Annotate | Download | only in training
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for warm_starting_util."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import os
     22 import numpy as np
     23 import six
     24 
     25 from tensorflow.python.feature_column import feature_column_lib as fc
     26 from tensorflow.python.framework import dtypes
     27 from tensorflow.python.framework import ops
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.ops import init_ops
     30 from tensorflow.python.ops import variable_scope
     31 from tensorflow.python.ops import variables
     32 from tensorflow.python.platform import test
     33 from tensorflow.python.training import checkpoint_utils
     34 from tensorflow.python.training import saver as saver_lib
     35 from tensorflow.python.training import warm_starting_util as ws_util
     36 
     37 ones = init_ops.ones_initializer
     38 norms = init_ops.truncated_normal_initializer
     39 rand = init_ops.random_uniform_initializer
     40 zeros = init_ops.zeros_initializer
     41 
     42 
     43 class WarmStartingUtilTest(test.TestCase):
     44 
     45   def _write_vocab(self, string_values, file_name):
     46     vocab_file = os.path.join(self.get_temp_dir(), file_name)
     47     with open(vocab_file, "w") as f:
     48       f.write("\n".join(string_values))
     49     return vocab_file
     50 
     51   def _write_checkpoint(self, sess):
     52     self.evaluate(variables.global_variables_initializer())
     53     saver = saver_lib.Saver()
     54     ckpt_prefix = os.path.join(self.get_temp_dir(), "model")
     55     saver.save(sess, ckpt_prefix, global_step=0)
     56 
     57   def _create_prev_run_var(self,
     58                            var_name,
     59                            shape=None,
     60                            initializer=None,
     61                            partitioner=None):
     62     with ops.Graph().as_default() as g:
     63       with self.session(graph=g) as sess:
     64         var = variable_scope.get_variable(
     65             var_name,
     66             shape=shape,
     67             initializer=initializer,
     68             partitioner=partitioner)
     69         self._write_checkpoint(sess)
     70         if partitioner:
     71           self.assertTrue(isinstance(var, variables.PartitionedVariable))
     72           var = var._get_variable_list()
     73         return var, self.evaluate(var)
     74 
     75   def _create_prev_run_vars(self,
     76                             var_names,
     77                             shapes,
     78                             initializers):
     79     with ops.Graph().as_default() as g:
     80       with self.session(graph=g) as sess:
     81         all_vars = []
     82         for var_name, shape, initializer in zip(var_names, shapes,
     83                                                 initializers):
     84           all_vars.append(variable_scope.get_variable(
     85               var_name,
     86               shape=shape,
     87               initializer=initializer))
     88         self._write_checkpoint(sess)
     89         return [self.evaluate(var) for var in all_vars]
     90 
     91   def _create_dummy_inputs(self):
     92     return {
     93         "sc_int": array_ops.sparse_placeholder(dtypes.int32),
     94         "sc_hash": array_ops.sparse_placeholder(dtypes.string),
     95         "sc_keys": array_ops.sparse_placeholder(dtypes.string),
     96         "sc_vocab": array_ops.sparse_placeholder(dtypes.string),
     97         "real": array_ops.placeholder(dtypes.float32)
     98     }
     99 
    100   def _create_linear_model(self, feature_cols, partitioner):
    101     cols_to_vars = {}
    102     with variable_scope.variable_scope("", partitioner=partitioner):
    103       # Create the variables.
    104       fc.linear_model(
    105           features=self._create_dummy_inputs(),
    106           feature_columns=feature_cols,
    107           units=1,
    108           cols_to_vars=cols_to_vars)
    109     # Return a dictionary mapping each column to its variable.
    110     return cols_to_vars
    111 
    112   def _assert_cols_to_vars(self, cols_to_vars, cols_to_expected_values, sess):
    113     for col, expected_values in six.iteritems(cols_to_expected_values):
    114       for i, var in enumerate(cols_to_vars[col]):
    115         self.assertAllClose(expected_values[i], var.eval(sess))
    116 
    117   def testWarmStartVar(self):
    118     _, prev_val = self._create_prev_run_var(
    119         "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])
    120 
    121     with ops.Graph().as_default() as g:
    122       with self.session(graph=g) as sess:
    123         fruit_weights = variable_scope.get_variable(
    124             "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
    125         prev_tensor_name, var = ws_util._get_var_info(fruit_weights)
    126         checkpoint_utils.init_from_checkpoint(self.get_temp_dir(),
    127                                               {prev_tensor_name: var})
    128         self.evaluate(variables.global_variables_initializer())
    129         self.assertAllClose(prev_val, fruit_weights.eval(sess))
    130 
    131   def testWarmStartVarPrevVarPartitioned(self):
    132     _, weights = self._create_prev_run_var(
    133         "fruit_weights",
    134         shape=[4, 1],
    135         initializer=[[0.5], [1.], [1.5], [2.]],
    136         partitioner=lambda shape, dtype: [2, 1])
    137     prev_val = np.concatenate([weights[0], weights[1]], axis=0)
    138 
    139     with ops.Graph().as_default() as g:
    140       with self.session(graph=g) as sess:
    141         fruit_weights = variable_scope.get_variable(
    142             "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
    143         prev_tensor_name, var = ws_util._get_var_info(fruit_weights)
    144         checkpoint_utils.init_from_checkpoint(self.get_temp_dir(),
    145                                               {prev_tensor_name: var})
    146         self.evaluate(variables.global_variables_initializer())
    147         self.assertAllClose(prev_val, fruit_weights.eval(sess))
    148 
    149   def testWarmStartVarCurrentVarPartitioned(self):
    150     _, prev_val = self._create_prev_run_var(
    151         "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])
    152 
    153     with ops.Graph().as_default() as g:
    154       with self.session(graph=g) as sess:
    155         fruit_weights = variable_scope.get_variable(
    156             "fruit_weights",
    157             shape=[4, 1],
    158             initializer=[[0.], [0.], [0.], [0.]],
    159             partitioner=lambda shape, dtype: [2, 1])
    160         self.assertTrue(
    161             isinstance(fruit_weights, variables.PartitionedVariable))
    162         prev_tensor_name, var = ws_util._get_var_info(fruit_weights)
    163         checkpoint_utils.init_from_checkpoint(self.get_temp_dir(),
    164                                               {prev_tensor_name: var})
    165         self.evaluate(variables.global_variables_initializer())
    166         fruit_weights = fruit_weights._get_variable_list()
    167         new_val = np.concatenate(
    168             [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
    169         self.assertAllClose(prev_val, new_val)
    170 
    171   def testWarmStartVarBothVarsPartitioned(self):
    172     _, weights = self._create_prev_run_var(
    173         "old_scope/fruit_weights",
    174         shape=[4, 1],
    175         initializer=[[0.5], [1.], [1.5], [2.]],
    176         partitioner=lambda shape, dtype: [2, 1])
    177     prev_val = np.concatenate([weights[0], weights[1]], axis=0)
    178     # New session and new graph.
    179     with ops.Graph().as_default() as g:
    180       with self.session(graph=g) as sess:
    181         fruit_weights = variable_scope.get_variable(
    182             "new_scope/fruit_weights",
    183             shape=[4, 1],
    184             initializer=[[0.], [0.], [0.], [0.]],
    185             partitioner=lambda shape, dtype: [2, 1])
    186         self.assertTrue(
    187             isinstance(fruit_weights, variables.PartitionedVariable))
    188         prev_tensor_name, var = ws_util._get_var_info(
    189             fruit_weights, prev_tensor_name="old_scope/fruit_weights")
    190         checkpoint_utils.init_from_checkpoint(self.get_temp_dir(),
    191                                               {prev_tensor_name: var})
    192         self.evaluate(variables.global_variables_initializer())
    193         fruit_weights = fruit_weights._get_variable_list()
    194         new_val = np.concatenate(
    195             [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
    196         self.assertAllClose(prev_val, new_val)
    197 
    198   def testWarmStartVarWithVocab(self):
    199     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
    200                                         "old_vocab")
    201     self._create_prev_run_var(
    202         "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])
    203 
    204     # New vocab with elements in reverse order and one new element.
    205     new_vocab_path = self._write_vocab(
    206         ["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
    207     # New session and new graph.
    208     with ops.Graph().as_default() as g:
    209       with self.session(graph=g) as sess:
    210         fruit_weights = variable_scope.get_variable(
    211             "fruit_weights", initializer=[[0.], [0.], [0.], [0.], [0.]])
    212         ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
    213                                            self.get_temp_dir(), prev_vocab_path)
    214         self.evaluate(variables.global_variables_initializer())
    215         self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
    216                             fruit_weights.eval(sess))
    217 
    218   def testWarmStartVarWithColumnVocab(self):
    219     prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
    220     self._create_prev_run_var(
    221         "fruit_output_layer",
    222         initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])
    223 
    224     # New vocab with elements in reverse order and one new element.
    225     new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
    226                                        "new_vocab")
    227     # New session and new graph.
    228     with ops.Graph().as_default() as g:
    229       with self.session(graph=g) as sess:
    230         fruit_output_layer = variable_scope.get_variable(
    231             "fruit_output_layer",
    232             initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
    233                          [0., 0., 0.]])
    234         ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
    235                                            current_vocab_size=3,
    236                                            prev_ckpt=self.get_temp_dir(),
    237                                            prev_vocab_path=prev_vocab_path,
    238                                            axis=1)
    239         self.evaluate(variables.global_variables_initializer())
    240         self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
    241                              [2.3, 2., 0.]], fruit_output_layer.eval(sess))
    242 
    243   def testWarmStartVarWithVocabConstrainedOldVocabSize(self):
    244     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
    245                                         "old_vocab")
    246     self._create_prev_run_var(
    247         "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])
    248 
    249     # New vocab with elements in reverse order and one new element.
    250     new_vocab_path = self._write_vocab(
    251         ["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
    252     # New session and new graph.
    253     with ops.Graph().as_default() as g:
    254       with self.session(graph=g) as sess:
    255         fruit_weights = variable_scope.get_variable(
    256             "fruit_weights", initializer=[[0.], [0.], [0.], [0.], [0.]])
    257         ws_util._warm_start_var_with_vocab(
    258             fruit_weights,
    259             new_vocab_path,
    260             5,
    261             self.get_temp_dir(),
    262             prev_vocab_path,
    263             previous_vocab_size=2)
    264         self.evaluate(variables.global_variables_initializer())
    265         # Old vocabulary limited to ['apple', 'banana'].
    266         self.assertAllClose([[0.], [0.], [1.], [0.5], [0.]],
    267                             fruit_weights.eval(sess))
    268 
    269   def testWarmStartVarWithVocabPrevVarPartitioned(self):
    270     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
    271                                         "old_vocab")
    272     self._create_prev_run_var(
    273         "fruit_weights",
    274         shape=[4, 1],
    275         initializer=[[0.5], [1.], [1.5], [2.]],
    276         partitioner=lambda shape, dtype: [2, 1])
    277 
    278     # New vocab with elements in reverse order and one new element.
    279     new_vocab_path = self._write_vocab(
    280         ["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
    281     # New session and new graph.
    282     with ops.Graph().as_default() as g:
    283       with self.session(graph=g) as sess:
    284         fruit_weights = variable_scope.get_variable(
    285             "fruit_weights", initializer=[[0.], [0.], [0.], [0.], [0.]])
    286         ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
    287                                            self.get_temp_dir(), prev_vocab_path)
    288         self.evaluate(variables.global_variables_initializer())
    289         self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
    290                             fruit_weights.eval(sess))
    291 
    292   def testWarmStartVarWithColumnVocabPrevVarPartitioned(self):
    293     prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
    294     self._create_prev_run_var(
    295         "fruit_output_layer",
    296         shape=[4, 2],
    297         initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
    298         partitioner=lambda shape, dtype: [2, 1])
    299 
    300     # New vocab with elements in reverse order and one new element.
    301     new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
    302                                        "new_vocab")
    303     # New session and new graph.
    304     with ops.Graph().as_default() as g:
    305       with self.session(graph=g) as sess:
    306         fruit_output_layer = variable_scope.get_variable(
    307             "fruit_output_layer",
    308             initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
    309                          [0., 0., 0.]])
    310         ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
    311                                            current_vocab_size=3,
    312                                            prev_ckpt=self.get_temp_dir(),
    313                                            prev_vocab_path=prev_vocab_path,
    314                                            axis=1)
    315         self.evaluate(variables.global_variables_initializer())
    316         self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
    317                              [2.3, 2., 0.]], fruit_output_layer.eval(sess))
    318 
    319   def testWarmStartVarWithVocabCurrentVarPartitioned(self):
    320     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
    321                                         "old_vocab")
    322     self._create_prev_run_var(
    323         "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])
    324 
    325     # New vocab with elements in reverse order and one new element.
    326     new_vocab_path = self._write_vocab(
    327         ["orange", "guava", "banana", "apple", "raspberry"], "new_vocab")
    328     # New session and new graph.
    329     with ops.Graph().as_default() as g:
    330       with self.session(graph=g) as sess:
    331         fruit_weights = variable_scope.get_variable(
    332             "fruit_weights",
    333             shape=[6, 1],
    334             initializer=[[0.], [0.], [0.], [0.], [0.], [0.]],
    335             partitioner=lambda shape, dtype: [2, 1])
    336         ws_util._warm_start_var_with_vocab(
    337             fruit_weights,
    338             new_vocab_path,
    339             5,
    340             self.get_temp_dir(),
    341             prev_vocab_path,
    342             current_oov_buckets=1)
    343         self.evaluate(variables.global_variables_initializer())
    344         self.assertTrue(
    345             isinstance(fruit_weights, variables.PartitionedVariable))
    346         fruit_weights_vars = fruit_weights._get_variable_list()
    347         self.assertAllClose([[2.], [1.5], [1.]],
    348                             fruit_weights_vars[0].eval(sess))
    349         self.assertAllClose([[0.5], [0.], [0.]],
    350                             fruit_weights_vars[1].eval(sess))
    351 
    352   def testWarmStartVarWithColumnVocabCurrentVarPartitioned(self):
    353     prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
    354     self._create_prev_run_var(
    355         "fruit_output_layer",
    356         initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])
    357 
    358     # New vocab with elements in reverse order and one new element.
    359     new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
    360                                        "new_vocab")
    361     # New session and new graph.
    362     with ops.Graph().as_default() as g:
    363       with self.session(graph=g) as sess:
    364         fruit_output_layer = variable_scope.get_variable(
    365             "fruit_output_layer",
    366             shape=[4, 3],
    367             initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
    368                          [0., 0., 0.]],
    369             partitioner=lambda shape, dtype: [2, 1])
    370         ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
    371                                            current_vocab_size=3,
    372                                            prev_ckpt=self.get_temp_dir(),
    373                                            prev_vocab_path=prev_vocab_path,
    374                                            axis=1)
    375         self.evaluate(variables.global_variables_initializer())
    376         self.assertTrue(
    377             isinstance(fruit_output_layer, variables.PartitionedVariable))
    378         fruit_output_layer_vars = fruit_output_layer._get_variable_list()
    379         self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
    380                             fruit_output_layer_vars[0].eval(sess))
    381         self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
    382                             fruit_output_layer_vars[1].eval(sess))
    383 
    384   def testWarmStartVarWithVocabBothVarsPartitioned(self):
    385     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
    386                                         "old_vocab")
    387     self._create_prev_run_var(
    388         "fruit_weights",
    389         shape=[4, 1],
    390         initializer=[[0.5], [1.], [1.5], [2.]],
    391         partitioner=lambda shape, dtype: [2, 1])
    392 
    393     # New vocab with elements in reverse order and two new elements.
    394     new_vocab_path = self._write_vocab(
    395         ["orange", "guava", "banana", "apple", "raspberry",
    396          "blueberry"], "new_vocab")
    397     # New session and new graph.
    398     with ops.Graph().as_default() as g:
    399       with self.session(graph=g) as sess:
    400         fruit_weights = variable_scope.get_variable(
    401             "fruit_weights",
    402             shape=[6, 1],
    403             initializer=[[0.], [0.], [0.], [0.], [0.], [0.]],
    404             partitioner=lambda shape, dtype: [2, 1])
    405         ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 6,
    406                                            self.get_temp_dir(), prev_vocab_path)
    407         self.evaluate(variables.global_variables_initializer())
    408         self.assertTrue(
    409             isinstance(fruit_weights, variables.PartitionedVariable))
    410         fruit_weights_vars = fruit_weights._get_variable_list()
    411         self.assertAllClose([[2.], [1.5], [1.]],
    412                             fruit_weights_vars[0].eval(sess))
    413         self.assertAllClose([[0.5], [0.], [0.]],
    414                             fruit_weights_vars[1].eval(sess))
    415 
    416   def testWarmStartVarWithColumnVocabBothVarsPartitioned(self):
    417     prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
    418     self._create_prev_run_var(
    419         "fruit_output_layer",
    420         shape=[4, 2],
    421         initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
    422         partitioner=lambda shape, dtype: [2, 1])
    423 
    424     # New vocab with elements in reverse order and one new element.
    425     new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
    426                                        "new_vocab")
    427     # New session and new graph.
    428     with ops.Graph().as_default() as g:
    429       with self.session(graph=g) as sess:
    430         fruit_output_layer = variable_scope.get_variable(
    431             "fruit_output_layer",
    432             shape=[4, 3],
    433             initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
    434                          [0., 0., 0.]],
    435             partitioner=lambda shape, dtype: [2, 1])
    436         ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
    437                                            current_vocab_size=3,
    438                                            prev_ckpt=self.get_temp_dir(),
    439                                            prev_vocab_path=prev_vocab_path,
    440                                            axis=1)
    441         self.evaluate(variables.global_variables_initializer())
    442         self.assertTrue(
    443             isinstance(fruit_output_layer, variables.PartitionedVariable))
    444         fruit_output_layer_vars = fruit_output_layer._get_variable_list()
    445         self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
    446                             fruit_output_layer_vars[0].eval(sess))
    447         self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
    448                             fruit_output_layer_vars[1].eval(sess))
    449 
    450   def testWarmStart_ListOfVariables(self):
    451     # Save checkpoint from which to warm-start.
    452     _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
    453                                                 initializer=ones())
    454     # Verify we initialized the values correctly.
    455     self.assertAllEqual(np.ones([10, 1]), prev_int_val)
    456 
    457     # New graph, new session with warm-starting.
    458     with ops.Graph().as_default() as g:
    459       with self.session(graph=g) as sess:
    460         # Initialize with zeros.
    461         var = variable_scope.get_variable(
    462             "v1",
    463             shape=[10, 1],
    464             initializer=zeros())
    465         ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=[var])
    466         self.evaluate(variables.global_variables_initializer())
    467         # Verify weights were correctly warm-started (init overridden to ones).
    468         self.assertAllEqual(var.eval(), prev_int_val)
    469 
    470   def testWarmStart_ListOfStrings(self):
    471     # Save checkpoint from which to warm-start.
    472     _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
    473                                                 initializer=ones())
    474     # Verify we initialized the values correctly.
    475     self.assertAllEqual(np.ones([10, 1]), prev_int_val)
    476 
    477     # New graph, new session with warm-starting.
    478     with ops.Graph().as_default() as g:
    479       with self.session(graph=g) as sess:
    480         # Initialize with zeros.
    481         var = variable_scope.get_variable(
    482             "v1",
    483             shape=[10, 1],
    484             initializer=zeros())
    485         ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=["v1"])
    486         self.evaluate(variables.global_variables_initializer())
    487         # Verify weights were correctly warm-started (init overridden to ones).
    488         self.assertAllEqual(var.eval(), prev_int_val)
    489 
    490   def testWarmStart_ListOfRegexes(self):
    491     # Save checkpoint from which to warm-start.
    492     [prev_v1_val, prev_v1_momentum_val,
    493      prev_v2_val, _] = self._create_prev_run_vars(
    494          var_names=["v1", "v1/Momentum", "v2", "v2/Momentum"],
    495          shapes=[[10, 1]] * 4,
    496          initializers=[ones()] * 4)
    497 
    498     # New graph, new session with warm-starting.
    499     with ops.Graph().as_default() as g:
    500       with self.session(graph=g) as sess:
    501         # Initialize with zeros.
    502         v1 = variable_scope.get_variable(
    503             "v1",
    504             shape=[10, 1],
    505             initializer=zeros())
    506         v1_momentum = variable_scope.get_variable(
    507             "v1/Momentum",
    508             shape=[10, 1],
    509             initializer=zeros())
    510         v2 = variable_scope.get_variable(
    511             "v2",
    512             shape=[10, 1],
    513             initializer=zeros())
    514         v2_momentum = variable_scope.get_variable(
    515             "v2/Momentum",
    516             shape=[10, 1],
    517             initializer=zeros())
    518         ws_util.warm_start(self.get_temp_dir(),
    519                            # This warm-starts both v1 and v1/Momentum, but only
    520                            # v2 (and not v2/Momentum).
    521                            vars_to_warm_start=["v1", "v2[^/]"])
    522         self.evaluate(variables.global_variables_initializer())
    523         # Verify the selection of weights were correctly warm-started (init
    524         # overridden to ones).
    525         self.assertAllEqual(v1.eval(), prev_v1_val)
    526         self.assertAllEqual(v1_momentum.eval(), prev_v1_momentum_val)
    527         self.assertAllEqual(v2.eval(), prev_v2_val)
    528         self.assertAllEqual(v2_momentum.eval(), np.zeros([10, 1]))
    529 
    530   def testWarmStart_SparseColumnIntegerized(self):
    531     # Create feature column.
    532     sc_int = fc.categorical_column_with_identity("sc_int", num_buckets=10)
    533 
    534     # Save checkpoint from which to warm-start.
    535     _, prev_int_val = self._create_prev_run_var(
    536         "linear_model/sc_int/weights", shape=[10, 1], initializer=ones())
    537     # Verify we initialized the values correctly.
    538     self.assertAllEqual(np.ones([10, 1]), prev_int_val)
    539 
    540     partitioner = lambda shape, dtype: [1] * len(shape)
    541     # New graph, new session WITHOUT warm-starting.
    542     with ops.Graph().as_default() as g:
    543       with self.session(graph=g) as sess:
    544         cols_to_vars = self._create_linear_model([sc_int], partitioner)
    545         self.evaluate(variables.global_variables_initializer())
    546         # Without warm-starting, the weights should be initialized using default
    547         # initializer (which is init_ops.zeros_initializer).
    548         self._assert_cols_to_vars(cols_to_vars, {sc_int: [np.zeros([10, 1])]},
    549                                   sess)
    550 
    551     # New graph, new session with warm-starting.
    552     with ops.Graph().as_default() as g:
    553       with self.session(graph=g) as sess:
    554         cols_to_vars = self._create_linear_model([sc_int], partitioner)
    555         ws_util.warm_start(self.get_temp_dir(), vars_to_warm_start=".*sc_int.*")
    556         self.evaluate(variables.global_variables_initializer())
    557         # Verify weights were correctly warm-started.
    558         self._assert_cols_to_vars(cols_to_vars, {sc_int: [prev_int_val]}, sess)
    559 
    560   def testWarmStart_SparseColumnHashed(self):
    561     # Create feature column.
    562     sc_hash = fc.categorical_column_with_hash_bucket(
    563         "sc_hash", hash_bucket_size=15)
    564 
    565     # Save checkpoint from which to warm-start.
    566     _, prev_hash_val = self._create_prev_run_var(
    567         "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
    568 
    569     partitioner = lambda shape, dtype: [1] * len(shape)
    570     # New graph, new session WITHOUT warm-starting.
    571     with ops.Graph().as_default() as g:
    572       with self.session(graph=g) as sess:
    573         cols_to_vars = self._create_linear_model([sc_hash], partitioner)
    574         self.evaluate(variables.global_variables_initializer())
    575         # Without warm-starting, the weights should be initialized using default
    576         # initializer (which is init_ops.zeros_initializer).
    577         self._assert_cols_to_vars(cols_to_vars, {sc_hash: [np.zeros([15, 1])]},
    578                                   sess)
    579 
    580     # New graph, new session with warm-starting.
    581     with ops.Graph().as_default() as g:
    582       with self.session(graph=g) as sess:
    583         cols_to_vars = self._create_linear_model([sc_hash], partitioner)
    584         ws_util.warm_start(
    585             self.get_temp_dir(), vars_to_warm_start=".*sc_hash.*")
    586         self.evaluate(variables.global_variables_initializer())
    587         # Verify weights were correctly warm-started.
    588         self._assert_cols_to_vars(cols_to_vars, {sc_hash: [prev_hash_val]},
    589                                   sess)
    590 
    591   def testWarmStart_SparseColumnVocabulary(self):
    592     # Create vocab for sparse column "sc_vocab".
    593     vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
    594                                    "vocab")
    595     # Create feature column.
    596     sc_vocab = fc.categorical_column_with_vocabulary_file(
    597         "sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)
    598 
    599     # Save checkpoint from which to warm-start.
    600     _, prev_vocab_val = self._create_prev_run_var(
    601         "linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())
    602 
    603     partitioner = lambda shape, dtype: [1] * len(shape)
    604     # New graph, new session WITHOUT warm-starting.
    605     with ops.Graph().as_default() as g:
    606       with self.session(graph=g) as sess:
    607         cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
    608         self.evaluate(variables.global_variables_initializer())
    609         # Without warm-starting, the weights should be initialized using default
    610         # initializer (which is init_ops.zeros_initializer).
    611         self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([4, 1])]},
    612                                   sess)
    613 
    614     # New graph, new session with warm-starting.
    615     with ops.Graph().as_default() as g:
    616       with self.session(graph=g) as sess:
    617         cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
    618         # Since old vocab is not explicitly set in WarmStartSettings, the old
    619         # vocab is assumed to be same as new vocab.
    620         ws_util.warm_start(
    621             self.get_temp_dir(), vars_to_warm_start=".*sc_vocab.*")
    622         self.evaluate(variables.global_variables_initializer())
    623         # Verify weights were correctly warm-started.
    624         self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
    625                                   sess)
    626 
    627   def testWarmStart_ExplicitCheckpointFile(self):
    628     # Create vocab for sparse column "sc_vocab".
    629     vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
    630                                    "vocab")
    631     # Create feature column.
    632     sc_vocab = fc.categorical_column_with_vocabulary_file(
    633         "sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)
    634 
    635     # Save checkpoint from which to warm-start.
    636     _, prev_vocab_val = self._create_prev_run_var(
    637         "linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())
    638 
    639     partitioner = lambda shape, dtype: [1] * len(shape)
    640     # New graph, new session WITHOUT warm-starting.
    641     with ops.Graph().as_default() as g:
    642       with self.session(graph=g) as sess:
    643         cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
    644         self.evaluate(variables.global_variables_initializer())
    645         # Without warm-starting, the weights should be initialized using default
    646         # initializer (which is init_ops.zeros_initializer).
    647         self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([4, 1])]},
    648                                   sess)
    649 
    650     # New graph, new session with warm-starting.
    651     with ops.Graph().as_default() as g:
    652       with self.session(graph=g) as sess:
    653         cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
    654         # Since old vocab is not explicitly set in WarmStartSettings, the old
    655         # vocab is assumed to be same as new vocab.
    656         ws_util.warm_start(
    657             # Explicitly provide the file prefix instead of just the dir.
    658             os.path.join(self.get_temp_dir(), "model-0"),
    659             vars_to_warm_start=".*sc_vocab.*")
    660         self.evaluate(variables.global_variables_initializer())
    661         # Verify weights were correctly warm-started.
    662         self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [prev_vocab_val]},
    663                                   sess)
    664 
    665   def testWarmStart_SparseColumnVocabularyConstrainedVocabSizes(self):
    666     # Create old vocabulary, and use a size smaller than the total number of
    667     # entries.
    668     old_vocab_path = self._write_vocab(["apple", "guava", "banana"],
    669                                        "old_vocab")
    670     old_vocab_size = 2  # ['apple', 'guava']
    671 
    672     # Create new vocab for sparse column "sc_vocab".
    673     current_vocab_path = self._write_vocab(
    674         ["apple", "banana", "guava", "orange"], "current_vocab")
    675     # Create feature column.  Only use 2 of the actual entries, resulting in
    676     # ['apple', 'banana'] for the new vocabulary.
    677     sc_vocab = fc.categorical_column_with_vocabulary_file(
    678         "sc_vocab", vocabulary_file=current_vocab_path, vocabulary_size=2)
    679 
    680     # Save checkpoint from which to warm-start.
    681     self._create_prev_run_var(
    682         "linear_model/sc_vocab/weights", shape=[2, 1], initializer=ones())
    683 
    684     partitioner = lambda shape, dtype: [1] * len(shape)
    685     # New graph, new session WITHOUT warm-starting.
    686     with ops.Graph().as_default() as g:
    687       with self.session(graph=g) as sess:
    688         cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
    689         self.evaluate(variables.global_variables_initializer())
    690         # Without warm-starting, the weights should be initialized using default
    691         # initializer (which is init_ops.zeros_initializer).
    692         self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [np.zeros([2, 1])]},
    693                                   sess)
    694 
    695     # New graph, new session with warm-starting.
    696     with ops.Graph().as_default() as g:
    697       with self.session(graph=g) as sess:
    698         cols_to_vars = self._create_linear_model([sc_vocab], partitioner)
    699         vocab_info = ws_util.VocabInfo(
    700             new_vocab=sc_vocab.vocabulary_file,
    701             new_vocab_size=sc_vocab.vocabulary_size,
    702             num_oov_buckets=sc_vocab.num_oov_buckets,
    703             old_vocab=old_vocab_path,
    704             old_vocab_size=old_vocab_size)
    705         ws_util.warm_start(
    706             ckpt_to_initialize_from=self.get_temp_dir(),
    707             vars_to_warm_start=".*sc_vocab.*",
    708             var_name_to_vocab_info={
    709                 "linear_model/sc_vocab/weights": vocab_info
    710             })
    711         self.evaluate(variables.global_variables_initializer())
    712         # Verify weights were correctly warm-started.  'banana' isn't in the
    713         # first two entries of the old vocabulary, so it's newly initialized.
    714         self._assert_cols_to_vars(cols_to_vars, {sc_vocab: [[[1], [0]]]}, sess)
    715 
    716   def testWarmStart_BucketizedColumn(self):
    717     # Create feature column.
    718     real = fc.numeric_column("real")
    719     real_bucket = fc.bucketized_column(real, boundaries=[0., 1., 2., 3.])
    720 
    721     # Save checkpoint from which to warm-start.
    722     _, prev_bucket_val = self._create_prev_run_var(
    723         "linear_model/real_bucketized/weights",
    724         shape=[5, 1],
    725         initializer=norms())
    726 
    727     partitioner = lambda shape, dtype: [1] * len(shape)
    728     # New graph, new session WITHOUT warm-starting.
    729     with ops.Graph().as_default() as g:
    730       with self.session(graph=g) as sess:
    731         cols_to_vars = self._create_linear_model([real_bucket], partitioner)
    732         self.evaluate(variables.global_variables_initializer())
    733         # Without warm-starting, the weights should be initialized using default
    734         # initializer (which is init_ops.zeros_initializer).
    735         self._assert_cols_to_vars(cols_to_vars,
    736                                   {real_bucket: [np.zeros([5, 1])]}, sess)
    737 
    738     # New graph, new session with warm-starting.
    739     with ops.Graph().as_default() as g:
    740       with self.session(graph=g) as sess:
    741         cols_to_vars = self._create_linear_model([real_bucket], partitioner)
    742         ws_util.warm_start(
    743             self.get_temp_dir(), vars_to_warm_start=".*real_bucketized.*")
    744         self.evaluate(variables.global_variables_initializer())
    745         # Verify weights were correctly warm-started.
    746         self._assert_cols_to_vars(cols_to_vars,
    747                                   {real_bucket: [prev_bucket_val]}, sess)
    748 
    749   def testWarmStart_MultipleCols(self):
    750     # Create vocab for sparse column "sc_vocab".
    751     vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
    752                                    "vocab")
    753 
    754     # Create feature columns.
    755     sc_int = fc.categorical_column_with_identity("sc_int", num_buckets=10)
    756     sc_hash = fc.categorical_column_with_hash_bucket(
    757         "sc_hash", hash_bucket_size=15)
    758     sc_keys = fc.categorical_column_with_vocabulary_list(
    759         "sc_keys", vocabulary_list=["a", "b", "c", "e"])
    760     sc_vocab = fc.categorical_column_with_vocabulary_file(
    761         "sc_vocab", vocabulary_file=vocab_path, vocabulary_size=4)
    762     real = fc.numeric_column("real")
    763     real_bucket = fc.bucketized_column(real, boundaries=[0., 1., 2., 3.])
    764     cross = fc.crossed_column([sc_keys, sc_vocab], hash_bucket_size=20)
    765     all_linear_cols = [sc_int, sc_hash, sc_keys, sc_vocab, real_bucket, cross]
    766 
    767     # Save checkpoint from which to warm-start.  Also create a bias variable,
    768     # so we can check that it's also warm-started.
    769     with ops.Graph().as_default() as g:
    770       with self.session(graph=g) as sess:
    771         sc_int_weights = variable_scope.get_variable(
    772             "linear_model/sc_int/weights", shape=[10, 1], initializer=ones())
    773         sc_hash_weights = variable_scope.get_variable(
    774             "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
    775         sc_keys_weights = variable_scope.get_variable(
    776             "linear_model/sc_keys/weights", shape=[4, 1], initializer=rand())
    777         sc_vocab_weights = variable_scope.get_variable(
    778             "linear_model/sc_vocab/weights", shape=[4, 1], initializer=ones())
    779         real_bucket_weights = variable_scope.get_variable(
    780             "linear_model/real_bucketized/weights",
    781             shape=[5, 1],
    782             initializer=norms())
    783         cross_weights = variable_scope.get_variable(
    784             "linear_model/sc_keys_X_sc_vocab/weights",
    785             shape=[20, 1],
    786             initializer=rand())
    787         bias = variable_scope.get_variable(
    788             "linear_model/bias_weights",
    789             shape=[1],
    790             initializer=rand())
    791         self._write_checkpoint(sess)
    792         (prev_int_val, prev_hash_val, prev_keys_val, prev_vocab_val,
    793          prev_bucket_val, prev_cross_val, prev_bias_val) = sess.run([
    794              sc_int_weights, sc_hash_weights, sc_keys_weights, sc_vocab_weights,
    795              real_bucket_weights, cross_weights, bias
    796          ])
    797 
    798     partitioner = lambda shape, dtype: [1] * len(shape)
    799     # New graph, new session WITHOUT warm-starting.
    800     with ops.Graph().as_default() as g:
    801       with self.session(graph=g) as sess:
    802         cols_to_vars = self._create_linear_model(all_linear_cols, partitioner)
    803         self.evaluate(variables.global_variables_initializer())
    804         # Without warm-starting, all weights should be initialized using default
    805         # initializer (which is init_ops.zeros_initializer).
    806         self._assert_cols_to_vars(cols_to_vars, {
    807             sc_int: [np.zeros([10, 1])],
    808             sc_hash: [np.zeros([15, 1])],
    809             sc_keys: [np.zeros([4, 1])],
    810             sc_vocab: [np.zeros([4, 1])],
    811             real_bucket: [np.zeros([5, 1])],
    812             cross: [np.zeros([20, 1])],
    813         }, sess)
    814 
    815     # New graph, new session with warm-starting.
    816     with ops.Graph().as_default() as g:
    817       with self.session(graph=g) as sess:
    818         cols_to_vars = self._create_linear_model(all_linear_cols, partitioner)
    819         vocab_info = ws_util.VocabInfo(
    820             new_vocab=sc_vocab.vocabulary_file,
    821             new_vocab_size=sc_vocab.vocabulary_size,
    822             num_oov_buckets=sc_vocab.num_oov_buckets,
    823             old_vocab=vocab_path)
    824         ws_util.warm_start(
    825             self.get_temp_dir(),
    826             var_name_to_vocab_info={
    827                 "linear_model/sc_vocab/weights": vocab_info
    828             })
    829         self.evaluate(variables.global_variables_initializer())
    830         # Verify weights were correctly warm-started.
    831         self._assert_cols_to_vars(cols_to_vars, {
    832             sc_int: [prev_int_val],
    833             sc_hash: [prev_hash_val],
    834             sc_keys: [prev_keys_val],
    835             sc_vocab: [prev_vocab_val],
    836             real_bucket: [prev_bucket_val],
    837             cross: [prev_cross_val],
    838             "bias": [prev_bias_val],
    839         }, sess)
    840 
    841   def testWarmStartMoreSettings(self):
    842     # Create old and new vocabs for sparse column "sc_vocab".
    843     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
    844                                         "old_vocab")
    845     new_vocab_path = self._write_vocab(
    846         ["orange", "guava", "banana", "apple", "raspberry",
    847          "blueberry"], "new_vocab")
    848     # Create feature columns.
    849     sc_hash = fc.categorical_column_with_hash_bucket(
    850         "sc_hash", hash_bucket_size=15)
    851     sc_keys = fc.categorical_column_with_vocabulary_list(
    852         "sc_keys", vocabulary_list=["a", "b", "c", "e"])
    853     sc_vocab = fc.categorical_column_with_vocabulary_file(
    854         "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
    855     all_linear_cols = [sc_hash, sc_keys, sc_vocab]
    856 
    857     # Save checkpoint from which to warm-start.
    858     with ops.Graph().as_default() as g:
    859       with self.session(graph=g) as sess:
    860         variable_scope.get_variable(
    861             "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
    862         sc_keys_weights = variable_scope.get_variable(
    863             "some_other_name", shape=[4, 1], initializer=rand())
    864         variable_scope.get_variable(
    865             "linear_model/sc_vocab/weights",
    866             initializer=[[0.5], [1.], [2.], [3.]])
    867         self._write_checkpoint(sess)
    868         prev_keys_val = self.evaluate(sc_keys_weights)
    869 
    870     def _partitioner(shape, dtype):  # pylint:disable=unused-argument
    871       # Partition each var into 2 equal slices.
    872       partitions = [1] * len(shape)
    873       partitions[0] = min(2, shape.dims[0].value)
    874       return partitions
    875 
    876     # New graph, new session with warm-starting.
    877     with ops.Graph().as_default() as g:
    878       with self.session(graph=g) as sess:
    879         cols_to_vars = self._create_linear_model(all_linear_cols, _partitioner)
    880         vocab_info = ws_util.VocabInfo(
    881             new_vocab=sc_vocab.vocabulary_file,
    882             new_vocab_size=sc_vocab.vocabulary_size,
    883             num_oov_buckets=sc_vocab.num_oov_buckets,
    884             old_vocab=prev_vocab_path)
    885         ws_util.warm_start(
    886             self.get_temp_dir(),
    887             vars_to_warm_start=".*(sc_keys|sc_vocab).*",
    888             var_name_to_vocab_info={
    889                 ws_util._infer_var_name(cols_to_vars[sc_vocab]): vocab_info
    890             },
    891             var_name_to_prev_var_name={
    892                 ws_util._infer_var_name(cols_to_vars[sc_keys]):
    893                     "some_other_name"
    894             })
    895         self.evaluate(variables.global_variables_initializer())
    896         # Verify weights were correctly warm-started.  Var corresponding to
    897         # sc_hash should not be warm-started.  Var corresponding to sc_vocab
    898         # should be correctly warm-started after vocab remapping.
    899         self._assert_cols_to_vars(cols_to_vars, {
    900             sc_keys:
    901                 np.split(prev_keys_val, 2),
    902             sc_hash: [np.zeros([8, 1]), np.zeros([7, 1])],
    903             sc_vocab: [
    904                 np.array([[3.], [2.], [1.]]),
    905                 np.array([[0.5], [0.], [0.]])
    906             ]
    907         }, sess)
    908 
    909   def testWarmStartMoreSettingsNoPartitioning(self):
    910     # Create old and new vocabs for sparse column "sc_vocab".
    911     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
    912                                         "old_vocab")
    913     new_vocab_path = self._write_vocab(
    914         ["orange", "guava", "banana", "apple", "raspberry",
    915          "blueberry"], "new_vocab")
    916     # Create feature columns.
    917     sc_hash = fc.categorical_column_with_hash_bucket(
    918         "sc_hash", hash_bucket_size=15)
    919     sc_keys = fc.categorical_column_with_vocabulary_list(
    920         "sc_keys", vocabulary_list=["a", "b", "c", "e"])
    921     sc_vocab = fc.categorical_column_with_vocabulary_file(
    922         "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
    923     all_linear_cols = [sc_hash, sc_keys, sc_vocab]
    924 
    925     # Save checkpoint from which to warm-start.
    926     with ops.Graph().as_default() as g:
    927       with self.session(graph=g) as sess:
    928         variable_scope.get_variable(
    929             "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
    930         sc_keys_weights = variable_scope.get_variable(
    931             "some_other_name", shape=[4, 1], initializer=rand())
    932         variable_scope.get_variable(
    933             "linear_model/sc_vocab/weights",
    934             initializer=[[0.5], [1.], [2.], [3.]])
    935         self._write_checkpoint(sess)
    936         prev_keys_val = self.evaluate(sc_keys_weights)
    937 
    938     # New graph, new session with warm-starting.
    939     with ops.Graph().as_default() as g:
    940       with self.session(graph=g) as sess:
    941         cols_to_vars = self._create_linear_model(all_linear_cols,
    942                                                  partitioner=None)
    943         vocab_info = ws_util.VocabInfo(
    944             new_vocab=sc_vocab.vocabulary_file,
    945             new_vocab_size=sc_vocab.vocabulary_size,
    946             num_oov_buckets=sc_vocab.num_oov_buckets,
    947             old_vocab=prev_vocab_path)
    948         ws_util.warm_start(
    949             self.get_temp_dir(),
    950             vars_to_warm_start=".*(sc_keys|sc_vocab).*",
    951             var_name_to_vocab_info={
    952                 ws_util._infer_var_name(cols_to_vars[sc_vocab]): vocab_info
    953             },
    954             var_name_to_prev_var_name={
    955                 ws_util._infer_var_name(cols_to_vars[sc_keys]):
    956                     "some_other_name"
    957             })
    958         self.evaluate(variables.global_variables_initializer())
    959         # Verify weights were correctly warm-started.  Var corresponding to
    960         # sc_hash should not be warm-started.  Var corresponding to sc_vocab
    961         # should be correctly warm-started after vocab remapping.
    962         self._assert_cols_to_vars(cols_to_vars, {
    963             sc_keys: [prev_keys_val],
    964             sc_hash: [np.zeros([15, 1])],
    965             sc_vocab: [np.array([[3.], [2.], [1.], [0.5], [0.], [0.]])]
    966         }, sess)
    967 
    968   def testWarmStartVarsToWarmstartIsNone(self):
    969     # Create old and new vocabs for sparse column "sc_vocab".
    970     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
    971                                         "old_vocab")
    972     new_vocab_path = self._write_vocab(
    973         ["orange", "guava", "banana", "apple", "raspberry",
    974          "blueberry"], "new_vocab")
    975     # Create feature columns.
    976     sc_hash = fc.categorical_column_with_hash_bucket(
    977         "sc_hash", hash_bucket_size=15)
    978     sc_keys = fc.categorical_column_with_vocabulary_list(
    979         "sc_keys", vocabulary_list=["a", "b", "c", "e"])
    980     sc_vocab = fc.categorical_column_with_vocabulary_file(
    981         "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
    982     all_linear_cols = [sc_hash, sc_keys, sc_vocab]
    983 
    984     # Save checkpoint from which to warm-start.
    985     with ops.Graph().as_default() as g:
    986       with self.session(graph=g) as sess:
    987         variable_scope.get_variable(
    988             "linear_model/sc_hash/weights", shape=[15, 1], initializer=norms())
    989         variable_scope.get_variable(
    990             "some_other_name", shape=[4, 1], initializer=rand())
    991         variable_scope.get_variable(
    992             "linear_model/sc_vocab/weights",
    993             initializer=[[0.5], [1.], [2.], [3.]])
    994         self._write_checkpoint(sess)
    995 
    996     def _partitioner(shape, dtype):  # pylint:disable=unused-argument
    997       # Partition each var into 2 equal slices.
    998       partitions = [1] * len(shape)
    999       partitions[0] = min(2, shape.dims[0].value)
   1000       return partitions
   1001 
   1002     # New graph, new session with warm-starting.
   1003     with ops.Graph().as_default() as g:
   1004       with self.session(graph=g) as sess:
   1005         cols_to_vars = self._create_linear_model(all_linear_cols, _partitioner)
   1006         vocab_info = ws_util.VocabInfo(
   1007             new_vocab=sc_vocab.vocabulary_file,
   1008             new_vocab_size=sc_vocab.vocabulary_size,
   1009             num_oov_buckets=sc_vocab.num_oov_buckets,
   1010             old_vocab=prev_vocab_path)
   1011         ws_util.warm_start(
   1012             self.get_temp_dir(),
   1013             # The special value of None here will ensure that only the variable
   1014             # specified in var_name_to_vocab_info (sc_vocab embedding) is
   1015             # warm-started.
   1016             vars_to_warm_start=None,
   1017             var_name_to_vocab_info={
   1018                 ws_util._infer_var_name(cols_to_vars[sc_vocab]): vocab_info
   1019             },
   1020             # Even though this is provided, the None value for
   1021             # vars_to_warm_start overrides the logic, and this will not be
   1022             # warm-started.
   1023             var_name_to_prev_var_name={
   1024                 ws_util._infer_var_name(cols_to_vars[sc_keys]):
   1025                     "some_other_name"
   1026             })
   1027         self.evaluate(variables.global_variables_initializer())
   1028         # Verify weights were correctly warm-started.  Var corresponding to
   1029         # sc_vocab should be correctly warm-started after vocab remapping,
   1030         # and neither of the other two should be warm-started..
   1031         self._assert_cols_to_vars(cols_to_vars, {
   1032             sc_keys: [np.zeros([2, 1]), np.zeros([2, 1])],
   1033             sc_hash: [np.zeros([8, 1]), np.zeros([7, 1])],
   1034             sc_vocab: [
   1035                 np.array([[3.], [2.], [1.]]),
   1036                 np.array([[0.5], [0.], [0.]])
   1037             ]
   1038         }, sess)
   1039 
   1040   def testWarmStartEmbeddingColumn(self):
   1041     # Create old and new vocabs for embedding column "sc_vocab".
   1042     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
   1043                                         "old_vocab")
   1044     new_vocab_path = self._write_vocab(
   1045         ["orange", "guava", "banana", "apple", "raspberry", "blueberry"],
   1046         "new_vocab")
   1047 
   1048     # Save checkpoint from which to warm-start.
   1049     with ops.Graph().as_default() as g:
   1050       with self.session(graph=g) as sess:
   1051         variable_scope.get_variable(
   1052             "input_layer/sc_vocab_embedding/embedding_weights",
   1053             initializer=[[0.5, 0.4], [1., 1.1], [2., 2.2], [3., 3.3]])
   1054         self._write_checkpoint(sess)
   1055 
   1056     def _partitioner(shape, dtype):  # pylint:disable=unused-argument
   1057       # Partition each var into 2 equal slices.
   1058       partitions = [1] * len(shape)
   1059       partitions[0] = min(2, shape.dims[0].value)
   1060       return partitions
   1061 
   1062     # Create feature columns.
   1063     sc_vocab = fc.categorical_column_with_vocabulary_file(
   1064         "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
   1065     emb_vocab_column = fc.embedding_column(
   1066         categorical_column=sc_vocab,
   1067         dimension=2)
   1068     all_deep_cols = [emb_vocab_column]
   1069     # New graph, new session with warm-starting.
   1070     with ops.Graph().as_default() as g:
   1071       with self.session(graph=g) as sess:
   1072         cols_to_vars = {}
   1073         with variable_scope.variable_scope("", partitioner=_partitioner):
   1074           # Create the variables.
   1075           fc.input_layer(
   1076               features=self._create_dummy_inputs(),
   1077               feature_columns=all_deep_cols,
   1078               cols_to_vars=cols_to_vars)
   1079         vocab_info = ws_util.VocabInfo(
   1080             new_vocab=sc_vocab.vocabulary_file,
   1081             new_vocab_size=sc_vocab.vocabulary_size,
   1082             num_oov_buckets=sc_vocab.num_oov_buckets,
   1083             old_vocab=prev_vocab_path,
   1084             # Can't use constant_initializer with load_and_remap.  In practice,
   1085             # use a truncated normal initializer.
   1086             backup_initializer=init_ops.random_uniform_initializer(
   1087                 minval=0.42, maxval=0.42))
   1088         ws_util.warm_start(
   1089             self.get_temp_dir(),
   1090             var_name_to_vocab_info={
   1091                 ws_util._infer_var_name(cols_to_vars[emb_vocab_column]):
   1092                     vocab_info
   1093             })
   1094         self.evaluate(variables.global_variables_initializer())
   1095         # Verify weights were correctly warm-started. Var corresponding to
   1096         # emb_vocab_column should be correctly warm-started after vocab
   1097         # remapping. Missing values are filled in with the EmbeddingColumn's
   1098         # initializer.
   1099         self._assert_cols_to_vars(
   1100             cols_to_vars, {
   1101                 emb_vocab_column: [
   1102                     np.array([[3., 3.3], [2., 2.2], [1., 1.1]]),
   1103                     np.array([[0.5, 0.4], [0.42, 0.42], [0.42, 0.42]])
   1104                 ]
   1105             }, sess)
   1106 
   1107   def testWarmStartEmbeddingColumnLinearModel(self):
   1108     # Create old and new vocabs for embedding column "sc_vocab".
   1109     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
   1110                                         "old_vocab")
   1111     new_vocab_path = self._write_vocab(
   1112         ["orange", "guava", "banana", "apple", "raspberry", "blueberry"],
   1113         "new_vocab")
   1114 
   1115     # Save checkpoint from which to warm-start.
   1116     with ops.Graph().as_default() as g:
   1117       with self.session(graph=g) as sess:
   1118         variable_scope.get_variable(
   1119             "linear_model/sc_vocab_embedding/embedding_weights",
   1120             initializer=[[0.5, 0.4], [1., 1.1], [2., 2.2], [3., 3.3]])
   1121         variable_scope.get_variable(
   1122             "linear_model/sc_vocab_embedding/weights",
   1123             initializer=[[0.69], [0.71]])
   1124         self._write_checkpoint(sess)
   1125 
   1126     def _partitioner(shape, dtype):  # pylint:disable=unused-argument
   1127       # Partition each var into 2 equal slices.
   1128       partitions = [1] * len(shape)
   1129       partitions[0] = min(2, shape.dims[0].value)
   1130       return partitions
   1131 
   1132     # Create feature columns.
   1133     sc_vocab = fc.categorical_column_with_vocabulary_file(
   1134         "sc_vocab", vocabulary_file=new_vocab_path, vocabulary_size=6)
   1135     emb_vocab = fc.embedding_column(
   1136         categorical_column=sc_vocab,
   1137         dimension=2)
   1138     all_deep_cols = [emb_vocab]
   1139     # New graph, new session with warm-starting.
   1140     with ops.Graph().as_default() as g:
   1141       with self.session(graph=g) as sess:
   1142         cols_to_vars = {}
   1143         with variable_scope.variable_scope("", partitioner=_partitioner):
   1144           # Create the variables.
   1145           fc.linear_model(
   1146               features=self._create_dummy_inputs(),
   1147               feature_columns=all_deep_cols,
   1148               cols_to_vars=cols_to_vars)
   1149 
   1150         # Construct the vocab_info for the embedding weight.
   1151         vocab_info = ws_util.VocabInfo(
   1152             new_vocab=sc_vocab.vocabulary_file,
   1153             new_vocab_size=sc_vocab.vocabulary_size,
   1154             num_oov_buckets=sc_vocab.num_oov_buckets,
   1155             old_vocab=prev_vocab_path,
   1156             # Can't use constant_initializer with load_and_remap.  In practice,
   1157             # use a truncated normal initializer.
   1158             backup_initializer=init_ops.random_uniform_initializer(
   1159                 minval=0.42, maxval=0.42))
   1160         ws_util.warm_start(
   1161             self.get_temp_dir(),
   1162             vars_to_warm_start=".*sc_vocab.*",
   1163             var_name_to_vocab_info={
   1164                 "linear_model/sc_vocab_embedding/embedding_weights": vocab_info
   1165             })
   1166         self.evaluate(variables.global_variables_initializer())
   1167         # Verify weights were correctly warm-started. Var corresponding to
   1168         # emb_vocab should be correctly warm-started after vocab remapping.
   1169         # Missing values are filled in with the EmbeddingColumn's initializer.
   1170         self._assert_cols_to_vars(
   1171             cols_to_vars,
   1172             {
   1173                 emb_vocab: [
   1174                     # linear weights part 0.
   1175                     np.array([[0.69]]),
   1176                     # linear weights part 1.
   1177                     np.array([[0.71]]),
   1178                     # embedding_weights part 0.
   1179                     np.array([[3., 3.3], [2., 2.2], [1., 1.1]]),
   1180                     # embedding_weights part 1.
   1181                     np.array([[0.5, 0.4], [0.42, 0.42], [0.42, 0.42]])
   1182                 ]
   1183             },
   1184             sess)
   1185 
   1186   def testErrorConditions(self):
   1187     x = variable_scope.get_variable(
   1188         "x",
   1189         shape=[4, 1],
   1190         initializer=ones(),
   1191         partitioner=lambda shape, dtype: [2, 1])
   1192 
   1193     # List of PartitionedVariable is invalid type when warm-starting with vocab.
   1194     self.assertRaises(TypeError, ws_util._warm_start_var_with_vocab, [x],
   1195                       "/tmp", 5, "/tmp", "/tmp")
   1196 
   1197     # Unused variable names raises ValueError.
   1198     with ops.Graph().as_default():
   1199       with self.cached_session() as sess:
   1200         x = variable_scope.get_variable(
   1201             "x",
   1202             shape=[4, 1],
   1203             initializer=ones(),
   1204             partitioner=lambda shape, dtype: [2, 1])
   1205         self._write_checkpoint(sess)
   1206 
   1207     self.assertRaises(
   1208         ValueError,
   1209         ws_util.warm_start,
   1210         self.get_temp_dir(),
   1211         var_name_to_vocab_info={"y": ws_util.VocabInfo("", 1, 0, "")})
   1212     self.assertRaises(
   1213         ValueError,
   1214         ws_util.warm_start,
   1215         self.get_temp_dir(),
   1216         var_name_to_prev_var_name={"y": "y2"})
   1217 
   1218 
if __name__ == "__main__":
  # Run all test cases in this module through TensorFlow's test runner.
  test.main()
   1221