Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
      4 # use this file except in compliance with the License.  You may obtain a copy of
      5 # the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
     11 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
     12 # License for the specific language governing permissions and limitations under
     13 # the License.
     14 # ==============================================================================
     15 """Tests for KafkaDataset."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib.kafka.python.ops import kafka_dataset_ops
     22 from tensorflow.python.data.ops import iterator_ops
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.framework import errors
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.platform import test
     27 
     28 
     29 class KafkaDatasetTest(test.TestCase):
     30 
     31   def setUp(self):
     32     # The Kafka server has to be setup before the test
     33     # and tear down after the test manually.
     34     # The docker engine has to be installed.
     35     #
     36     # To setup the Kafka server:
     37     # $ bash kafka_test.sh start kafka
     38     #
     39     # To team down the Kafka server:
     40     # $ bash kafka_test.sh stop kafka
     41     pass
     42 
     43   def testKafkaDataset(self):
     44     topics = array_ops.placeholder(dtypes.string, shape=[None])
     45     num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
     46     batch_size = array_ops.placeholder(dtypes.int64, shape=[])
     47 
     48     repeat_dataset = kafka_dataset_ops.KafkaDataset(
     49         topics, group="test", eof=True).repeat(num_epochs)
     50     batch_dataset = repeat_dataset.batch(batch_size)
     51 
     52     iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
     53     init_op = iterator.make_initializer(repeat_dataset)
     54     init_batch_op = iterator.make_initializer(batch_dataset)
     55     get_next = iterator.get_next()
     56 
     57     with self.test_session() as sess:
     58       # Basic test: read from topic 0.
     59       sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1})
     60       for i in range(5):
     61         self.assertEqual("D" + str(i), sess.run(get_next))
     62       with self.assertRaises(errors.OutOfRangeError):
     63         sess.run(get_next)
     64 
     65       # Basic test: read from topic 1.
     66       sess.run(init_op, feed_dict={topics: ["test:0:5:-1"], num_epochs: 1})
     67       for i in range(5):
     68         self.assertEqual("D" + str(i + 5), sess.run(get_next))
     69       with self.assertRaises(errors.OutOfRangeError):
     70         sess.run(get_next)
     71 
     72       # Basic test: read from both topics.
     73       sess.run(
     74           init_op,
     75           feed_dict={
     76               topics: ["test:0:0:4", "test:0:5:-1"],
     77               num_epochs: 1
     78           })
     79       for j in range(2):
     80         for i in range(5):
     81           self.assertEqual("D" + str(i + j * 5), sess.run(get_next))
     82       with self.assertRaises(errors.OutOfRangeError):
     83         sess.run(get_next)
     84 
     85       # Test repeated iteration through both files.
     86       sess.run(
     87           init_op,
     88           feed_dict={
     89               topics: ["test:0:0:4", "test:0:5:-1"],
     90               num_epochs: 10
     91           })
     92       for _ in range(10):
     93         for j in range(2):
     94           for i in range(5):
     95             self.assertEqual("D" + str(i + j * 5), sess.run(get_next))
     96       with self.assertRaises(errors.OutOfRangeError):
     97         sess.run(get_next)
     98 
     99       # Test batched and repeated iteration through both files.
    100       sess.run(
    101           init_batch_op,
    102           feed_dict={
    103               topics: ["test:0:0:4", "test:0:5:-1"],
    104               num_epochs: 10,
    105               batch_size: 5
    106           })
    107       for _ in range(10):
    108         self.assertAllEqual(["D" + str(i) for i in range(5)],
    109                             sess.run(get_next))
    110         self.assertAllEqual(["D" + str(i + 5) for i in range(5)],
    111                             sess.run(get_next))
    112 
    113 
    114 if __name__ == "__main__":
    115   test.main()
    116