1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # =================================================================== 15 16 """Utilities for the functionalities.""" 17 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 22 import time 23 import six 24 25 from tensorflow.python.platform import tf_logging as logging 26 from tensorflow.python.training import training 27 28 def check_positive_integer(value, name): 29 """Checks whether `value` is a positive integer.""" 30 if not isinstance(value, six.integer_types): 31 raise TypeError('{} must be int, got {}'.format(name, type(value))) 32 33 if value <= 0: 34 raise ValueError('{} must be positive, got {}'.format(name, value)) 35 36 37 # TODO(b/118302029) Remove this copy of MultiHostDatasetInitializerHook after we 38 # release a tensorflow_estimator with MultiHostDatasetInitializerHook in 39 # python/estimator/util.py. 40 class MultiHostDatasetInitializerHook(training.SessionRunHook): 41 """Creates a SessionRunHook that initializes all passed iterators.""" 42 43 def __init__(self, dataset_initializers): 44 self._initializers = dataset_initializers 45 46 def after_create_session(self, session, coord): 47 del coord 48 start = time.time() 49 session.run(self._initializers) 50 logging.info('Initialized dataset iterators in %d seconds', 51 time.time() - start) 52