# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
# ==============================================================================
"""Tests for KafkaDataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.kafka.python.ops import kafka_dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


def _message(i):
  """Return the expected payload of the i-th test message ("D<i>") as bytes.

  `Session.run` returns the elements of a `tf.string` tensor as `bytes` under
  Python 3, so a plain `str` would never compare equal. Under Python 2,
  `str.encode()` on this ASCII text returns an equal `str`, so the comparison
  works on both versions.
  """
  return ("D" + str(i)).encode()


class KafkaDatasetTest(test.TestCase):

  def setUp(self):
    # The Kafka server has to be set up before the test and torn down after
    # the test manually. The docker engine has to be installed.
    #
    # To set up the Kafka server:
    # $ bash kafka_test.sh start kafka
    #
    # To tear down the Kafka server:
    # $ bash kafka_test.sh stop kafka
    pass

  def testKafkaDataset(self):
    """Reads, repeats and batches messages from Kafka subscriptions."""
    # Subscriptions are strings such as "test:0:0:4". NOTE(review): the field
    # layout looks like topic:partition:start:end, with -1 meaning "until the
    # end" -- confirm against KafkaDataset before relying on it.
    topics = array_ops.placeholder(dtypes.string, shape=[None])
    num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
    batch_size = array_ops.placeholder(dtypes.int64, shape=[])

    repeat_dataset = kafka_dataset_ops.KafkaDataset(
        topics, group="test", eof=True).repeat(num_epochs)
    batch_dataset = repeat_dataset.batch(batch_size)

    # A single iterator is re-initialized from either the unbatched or the
    # batched dataset, depending on the sub-test.
    iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
    init_op = iterator.make_initializer(repeat_dataset)
    init_batch_op = iterator.make_initializer(batch_dataset)
    get_next = iterator.get_next()

    with self.test_session() as sess:
      # Basic test: read the first range of messages (D0..D4).
      sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1})
      for i in range(5):
        self.assertEqual(_message(i), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Basic test: read the second range of messages (D5..D9).
      sess.run(init_op, feed_dict={topics: ["test:0:5:-1"], num_epochs: 1})
      for i in range(5):
        self.assertEqual(_message(i + 5), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Basic test: read both subscriptions, in order (D0..D9).
      sess.run(
          init_op,
          feed_dict={
              topics: ["test:0:0:4", "test:0:5:-1"],
              num_epochs: 1
          })
      for i in range(10):
        self.assertEqual(_message(i), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Test repeated iteration through both subscriptions.
      sess.run(
          init_op,
          feed_dict={
              topics: ["test:0:0:4", "test:0:5:-1"],
              num_epochs: 10
          })
      for _ in range(10):
        for i in range(10):
          self.assertEqual(_message(i), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Test batched and repeated iteration through both subscriptions.
      sess.run(
          init_batch_op,
          feed_dict={
              topics: ["test:0:0:4", "test:0:5:-1"],
              num_epochs: 10,
              batch_size: 5
          })
      first_batch = [_message(i) for i in range(5)]
      second_batch = [_message(i + 5) for i in range(5)]
      for _ in range(10):
        self.assertAllEqual(first_batch, sess.run(get_next))
        self.assertAllEqual(second_batch, sess.run(get_next))
      # 10 epochs x 10 messages split into batches of 5 leaves no partial
      # batch, so the iterator must now be exhausted, as in the sub-tests
      # above.
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)


if __name__ == "__main__":
  test.main()