# Source-browser listing header (non-code artifact): Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the experimental input pipeline ops that need test_util."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.core.protobuf import config_pb2
     23 from tensorflow.python.client import session
     24 from tensorflow.python.data.ops import dataset_ops
     25 from tensorflow.python.data.ops import iterator_ops
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import dtypes
     28 from tensorflow.python.framework import errors
     29 from tensorflow.python.framework import function
     30 from tensorflow.python.framework import ops
     31 from tensorflow.python.framework import test_util
     32 from tensorflow.python.ops import array_ops
     33 from tensorflow.python.ops import functional_ops
     34 from tensorflow.python.ops import lookup_ops
     35 from tensorflow.python.ops import math_ops
     36 from tensorflow.python.ops import string_ops
     37 from tensorflow.python.platform import test
     38 
     39 
class IteratorClusterTest(test.TestCase):
  """Tests for dataset iterators used across devices in a local cluster.

  Each test spins up an in-process TF cluster with
  `test_util.create_local_cluster` and exercises how iterator resources
  behave when the consumer runs on a different device (or in a different
  process) than the device that owns the iterator.
  """

  def testRemoteIteratorWithoutRemoteCallFail(self):
    """Consuming a remote iterator without `remote_call` must fail.

    The iterator resource lives on cpu:1; rebuilding it from its string
    handle on cpu:0 and calling `get_next()` directly (i.e. without routing
    the call through `functional_ops.remote_call`) should raise
    `InvalidArgumentError` because the resource is not local.
    """
    worker_config = config_pb2.ConfigProto()
    # Two CPU devices on the single worker so the iterator and its consumer
    # can be pinned to different devices within one process.
    worker_config.device_count["CPU"] = 2
    worker, _ = test_util.create_local_cluster(
        1, 1, worker_config=worker_config)

    # Build the dataset and its one-shot iterator on cpu:1; export the
    # iterator as a string handle so another device can refer to it.
    with ops.device("/job:worker/replica:0/task:0/cpu:1"):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_3.make_one_shot_iterator()
      iterator_3_handle = iterator_3.string_handle()

    # Reconstruct the iterator from the handle on cpu:0. The underlying
    # resource still lives on cpu:1, so get_next() cannot resolve it here.
    with ops.device("/job:worker/replica:0/task:0/cpu:0"):
      remote_it = iterator_ops.Iterator.from_string_handle(
          iterator_3_handle, dataset_3.output_types, dataset_3.output_shapes)
      get_next_op = remote_it.get_next()

    with session.Session(worker[0].target) as sess:
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(get_next_op)

  def _testRemoteIteratorHelper(self, device0, device1, target):
    """Drives an iterator on `device1` from `device0` via `remote_call`.

    Args:
      device0: Device string for the consumer side (where `remote_call` and
        its target placeholder are placed).
      device1: Device string that owns the dataset/iterator resource.
      target: Session target (gRPC URL) of the cluster to connect to.

    The helper checks that:
      * `remote_call` targeted at `device1` advances the iterator,
      * targeting `device0` (where the resource is absent) raises
        `InvalidArgumentError`,
      * direct `get_next()` and `remote_call` interleave on the same
        underlying iterator state,
      * exhaustion raises `OutOfRangeError`.
    """
    with ops.device(device1):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_3.make_one_shot_iterator()
      iterator_3_handle = iterator_3.string_handle()

    # Function body executed remotely: rebuild the iterator from the
    # serialized handle `h` and pull one element from it.
    @function.Defun(dtypes.string)
    def _remote_fn(h):
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, dataset_3.output_types, dataset_3.output_shapes)
      return remote_iterator.get_next()

    with ops.device(device0):
      # The target device is fed at run time so the same op can be pointed
      # at either device below.
      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      remote_op = functional_ops.remote_call(
          args=[iterator_3_handle],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target=target_placeholder)

    with session.Session(target) as sess:
      elem = sess.run(remote_op, feed_dict={target_placeholder: device1})
      self.assertEqual(elem, [1])
      # Fails when target is cpu:0 where the resource is not located.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(remote_op, feed_dict={target_placeholder: device0})
      # Direct get_next() shares state with the remote calls: it consumes
      # the second element, so the next remote_call sees the third.
      elem = sess.run(iterator_3.get_next())
      self.assertEqual(elem, [2])
      elem = sess.run(remote_op, feed_dict={target_placeholder: device1})
      self.assertEqual(elem, [3])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(remote_op, feed_dict={target_placeholder: device1})

  def testRemoteIteratorUsingRemoteCallOp(self):
    """Remote-call access between two CPU devices in the same process."""
    worker_config = config_pb2.ConfigProto()
    # Two CPU devices on one worker: consumer on cpu:0, iterator on cpu:1.
    worker_config.device_count["CPU"] = 2
    worker, _ = test_util.create_local_cluster(
        1, 1, worker_config=worker_config)

    self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0",
                                   "/job:worker/replica:0/task:0/cpu:1",
                                   worker[0].target)

  def testRemoteIteratorUsingRemoteCallOpCrossProcess(self):
    """Remote-call access where the iterator lives in another process."""
    workers, _ = test_util.create_local_cluster(2, 1)

    # Consumer is task 0; the iterator resource is owned by task 1.
    self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0",
                                   "/job:worker/replica:0/task:1/cpu:0",
                                   workers[0].target)

  def testCaptureHashTableInSharedIterator(self):
    """A lookup table captured in a shared iterator survives across sessions.

    The iterator is created with `shared_name`, so a second session against
    the same worker resumes consuming from where the first session stopped.
    """
    worker, _ = test_util.create_local_cluster(1, 1)

    # NOTE(mrry): We must use the V2 variants of `HashTable`
    # etc. because these produce a `tf.resource`-typed output that is
    # compatible with the in-graph function implementation.
    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = lookup_ops.HashTable(
        lookup_ops.KeyValueTensorInitializer(keys, values),
        default_val,
        shared_name="shared_table")

    input_sentences = dataset_ops.Dataset.from_tensor_slices(
        ["brain brain tank salad surgery", "surgery brain"])

    # Tokenize each sentence and map tokens through the shared table;
    # `shared_name` lets a later session attach to the same iterator state.
    iterator = (
        input_sentences.map(lambda x: string_ops.string_split([x]).values).map(
            table.lookup)
        .make_initializable_iterator(shared_name="shared_iterator"))
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with session.Session(worker[0].target) as sess:
      sess.run(table.init)
      sess.run(init_op)
      # "tank" is not in the table, so it maps to default_val (-1).
      self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next))

    # A fresh session picks up the shared iterator where the previous
    # session left off: the second sentence, then exhaustion.
    with session.Session(worker[0].target) as sess:
      self.assertAllEqual([2, 0], sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testImplicitDisposeParallelMapDataset(self):
    # Tests whether a parallel map dataset will be cleaned up correctly when
    # the pipeline does not run it until exhaustion.
    # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
    # RepeatDataset(None) -> PrefetchDataset(10000).
    worker, _ = test_util.create_local_cluster(1, 1)

    # Three component tensors of 1000 slices each: a vector, a (1000, 3)
    # matrix, and a scaled vector.
    components = (np.arange(1000),
                  np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
                  np.array(37.0) * np.arange(1000))

    def _map_fn(x, y, z):
      # Element-wise square of every component.
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    dataset = (
        dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
        .repeat(None).prefetch(10000))

    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with session.Session(worker[0].target) as sess:
      sess.run(init_op)
      # Pull only a few elements; the infinite, prefetching pipeline must
      # still be disposed of cleanly when the session goes away.
      for _ in range(3):
        sess.run(get_next)
    172 
    173 
if __name__ == "__main__":
  # Delegate to the TensorFlow test runner when executed as a script.
  test.main()
    176