1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/kernels/dataset.h" 17 18 #include "tensorflow/core/framework/tensor.h" 19 20 #include "src-cpp/rdkafkacpp.h" 21 22 namespace tensorflow { 23 24 class KafkaDatasetOp : public DatasetOpKernel { 25 public: 26 using DatasetOpKernel::DatasetOpKernel; 27 28 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { 29 const Tensor* topics_tensor; 30 OP_REQUIRES_OK(ctx, ctx->input("topics", &topics_tensor)); 31 OP_REQUIRES( 32 ctx, topics_tensor->dims() <= 1, 33 errors::InvalidArgument("`topics` must be a scalar or a vector.")); 34 35 std::vector<string> topics; 36 topics.reserve(topics_tensor->NumElements()); 37 for (int i = 0; i < topics_tensor->NumElements(); ++i) { 38 topics.push_back(topics_tensor->flat<string>()(i)); 39 } 40 41 std::string servers = ""; 42 OP_REQUIRES_OK(ctx, 43 ParseScalarArgument<std::string>(ctx, "servers", &servers)); 44 std::string group = ""; 45 OP_REQUIRES_OK(ctx, ParseScalarArgument<std::string>(ctx, "group", &group)); 46 bool eof = false; 47 OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "eof", &eof)); 48 int64 timeout = -1; 49 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "timeout", &timeout)); 50 OP_REQUIRES(ctx, (timeout > 0), 51 errors::InvalidArgument( 52 "Timeout value should be large than 0, got ", timeout)); 53 *output = new Dataset(ctx, std::move(topics), servers, group, eof, timeout); 54 } 55 56 private: 57 class Dataset : public GraphDatasetBase { 58 public: 59 Dataset(OpKernelContext* ctx, std::vector<string> topics, 60 const string& servers, const string& group, const bool eof, 61 const int64 timeout) 62 : GraphDatasetBase(ctx), 63 topics_(std::move(topics)), 64 servers_(servers), 65 group_(group), 66 eof_(eof), 67 timeout_(timeout) {} 68 69 std::unique_ptr<IteratorBase> MakeIterator( 70 const string& prefix) const override { 71 return std::unique_ptr<IteratorBase>( 72 new Iterator({this, strings::StrCat(prefix, "::Kafka")})); 73 } 74 75 const DataTypeVector& output_dtypes() const override { 76 static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); 77 return *dtypes; 78 } 79 80 const std::vector<PartialTensorShape>& output_shapes() const override { 81 static std::vector<PartialTensorShape>* shapes = 82 new std::vector<PartialTensorShape>({{}}); 83 return *shapes; 84 } 85 86 string DebugString() override { return "KafkaDatasetOp::Dataset"; } 87 88 protected: 89 Status AsGraphDefInternal(DatasetGraphDefBuilder* b, 90 Node** output) const override { 91 Node* topics = nullptr; 92 TF_RETURN_IF_ERROR(b->AddVector(topics_, &topics)); 93 Node* servers = nullptr; 94 TF_RETURN_IF_ERROR(b->AddScalar(servers_, &servers)); 95 Node* group = nullptr; 96 TF_RETURN_IF_ERROR(b->AddScalar(group_, &group)); 97 Node* eof = nullptr; 98 TF_RETURN_IF_ERROR(b->AddScalar(eof_, &eof)); 99 Node* timeout = nullptr; 100 TF_RETURN_IF_ERROR(b->AddScalar(timeout_, &timeout)); 101 TF_RETURN_IF_ERROR( 102 b->AddDataset(this, {topics, servers, group, eof, timeout}, output)); 103 return Status::OK(); 104 } 105 106 private: 107 class Iterator : public DatasetIterator<Dataset> { 108 public: 109 explicit Iterator(const Params& params) 110 : DatasetIterator<Dataset>(params) {} 111 112 Status GetNextInternal(IteratorContext* ctx, 113 std::vector<Tensor>* out_tensors, 114 bool* end_of_sequence) override { 115 mutex_lock l(mu_); 116 do { 117 // We are currently processing a topic, so try to read the next line. 118 if (consumer_.get()) { 119 while (true) { 120 if (limit_ >= 0 && 121 (topic_partition_->offset() >= limit_ || offset_ >= limit_)) { 122 // EOF current topic 123 break; 124 } 125 std::unique_ptr<RdKafka::Message> message( 126 consumer_->consume(dataset()->timeout_)); 127 if (message->err() == RdKafka::ERR_NO_ERROR) { 128 // Produce the line as output. 129 Tensor line_tensor(cpu_allocator(), DT_STRING, {}); 130 line_tensor.scalar<string>()() = 131 std::string(static_cast<const char*>(message->payload()), 132 message->len()); 133 out_tensors->emplace_back(std::move(line_tensor)); 134 *end_of_sequence = false; 135 // Sync offset 136 offset_ = message->offset(); 137 return Status::OK(); 138 } 139 140 if (message->err() == RdKafka::ERR__PARTITION_EOF && 141 dataset()->eof_) { 142 // EOF current topic 143 break; 144 } 145 if (message->err() != RdKafka::ERR__TIMED_OUT) { 146 return errors::Internal("Failed to consume:", 147 message->errstr()); 148 } 149 message.reset(nullptr); 150 consumer_->poll(0); 151 } 152 153 // We have reached the end of the current topic, so maybe 154 // move on to next topic. 155 ResetStreamsLocked(); 156 ++current_topic_index_; 157 } 158 159 // Iteration ends when there are no more topic to process. 160 if (current_topic_index_ == dataset()->topics_.size()) { 161 *end_of_sequence = true; 162 return Status::OK(); 163 } 164 165 TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); 166 } while (true); 167 } 168 169 protected: 170 Status SaveInternal(IteratorStateWriter* writer) override { 171 mutex_lock l(mu_); 172 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_topic_index"), 173 current_topic_index_)); 174 175 // `consumer_` is empty if 176 // 1. GetNext has not been called even once. 177 // 2. All topics have been read and iterator has been exhausted. 178 if (consumer_.get()) { 179 TF_RETURN_IF_ERROR( 180 writer->WriteScalar(full_name("current_pos"), offset_)); 181 } 182 return Status::OK(); 183 } 184 185 Status RestoreInternal(IteratorContext* ctx, 186 IteratorStateReader* reader) override { 187 mutex_lock l(mu_); 188 ResetStreamsLocked(); 189 int64 current_topic_index; 190 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_topic_index"), 191 ¤t_topic_index)); 192 current_topic_index_ = size_t(current_topic_index); 193 // The key "current_pos" is written only if the iterator was saved 194 // with an open topic. 195 if (reader->Contains(full_name("current_pos"))) { 196 int64 current_pos; 197 TF_RETURN_IF_ERROR( 198 reader->ReadScalar(full_name("current_pos"), ¤t_pos)); 199 200 TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); 201 topic_partition_->set_offset(current_pos); 202 if (topic_partition_->offset() != current_pos) { 203 return errors::Internal("Failed to restore to offset ", 204 current_pos); 205 } 206 offset_ = current_pos; 207 } 208 return Status::OK(); 209 } 210 211 private: 212 // Sets up Kafka streams to read from the topic at 213 // `current_topic_index_`. 214 Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { 215 if (current_topic_index_ >= dataset()->topics_.size()) { 216 return errors::InvalidArgument( 217 "current_topic_index_:", current_topic_index_, 218 " >= topics_.size():", dataset()->topics_.size()); 219 } 220 221 // Actually move on to next topic. 222 string entry = dataset()->topics_[current_topic_index_]; 223 224 std::vector<string> parts = str_util::Split(entry, ":"); 225 if (parts.size() < 1) { 226 return errors::InvalidArgument("Invalid parameters: ", entry); 227 } 228 string topic = parts[0]; 229 int32 partition = 0; 230 if (parts.size() > 1) { 231 if (!strings::safe_strto32(parts[1], &partition)) { 232 return errors::InvalidArgument("Invalid parameters: ", entry); 233 } 234 } 235 int64 offset = 0; 236 if (parts.size() > 2) { 237 if (!strings::safe_strto64(parts[2], &offset)) { 238 return errors::InvalidArgument("Invalid parameters: ", entry); 239 } 240 } 241 242 topic_partition_.reset( 243 RdKafka::TopicPartition::create(topic, partition, offset)); 244 245 offset_ = topic_partition_->offset(); 246 limit_ = -1; 247 if (parts.size() > 3) { 248 if (!strings::safe_strto64(parts[3], &limit_)) { 249 return errors::InvalidArgument("Invalid parameters: ", entry); 250 } 251 } 252 253 std::unique_ptr<RdKafka::Conf> conf( 254 RdKafka::Conf::create(RdKafka::Conf::CONF_GLOBAL)); 255 std::unique_ptr<RdKafka::Conf> topic_conf( 256 RdKafka::Conf::create(RdKafka::Conf::CONF_TOPIC)); 257 258 std::string errstr; 259 260 RdKafka::Conf::ConfResult result = 261 conf->set("default_topic_conf", topic_conf.get(), errstr); 262 if (result != RdKafka::Conf::CONF_OK) { 263 return errors::Internal("Failed to set default_topic_conf:", errstr); 264 } 265 266 result = conf->set("bootstrap.servers", dataset()->servers_, errstr); 267 if (result != RdKafka::Conf::CONF_OK) { 268 return errors::Internal("Failed to set bootstrap.servers ", 269 dataset()->servers_, ":", errstr); 270 } 271 result = conf->set("group.id", dataset()->group_, errstr); 272 if (result != RdKafka::Conf::CONF_OK) { 273 return errors::Internal("Failed to set group.id ", dataset()->group_, 274 ":", errstr); 275 } 276 277 consumer_.reset(RdKafka::KafkaConsumer::create(conf.get(), errstr)); 278 if (!consumer_.get()) { 279 return errors::Internal("Failed to create consumer:", errstr); 280 } 281 282 std::vector<RdKafka::TopicPartition*> partitions; 283 partitions.emplace_back(topic_partition_.get()); 284 RdKafka::ErrorCode err = consumer_->assign(partitions); 285 if (err != RdKafka::ERR_NO_ERROR) { 286 return errors::Internal( 287 "Failed to assign partition [", topic_partition_->topic(), ", ", 288 topic_partition_->partition(), ", ", topic_partition_->offset(), 289 "]:", RdKafka::err2str(err)); 290 } 291 292 return Status::OK(); 293 } 294 295 // Resets all Kafka streams. 296 void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { 297 consumer_->unassign(); 298 consumer_->close(); 299 consumer_.reset(nullptr); 300 } 301 302 mutex mu_; 303 size_t current_topic_index_ GUARDED_BY(mu_) = 0; 304 int64 offset_ GUARDED_BY(mu_) = 0; 305 int64 limit_ GUARDED_BY(mu_) = -1; 306 std::unique_ptr<RdKafka::TopicPartition> topic_partition_ GUARDED_BY(mu_); 307 std::unique_ptr<RdKafka::KafkaConsumer> consumer_ GUARDED_BY(mu_); 308 }; 309 310 const std::vector<string> topics_; 311 const std::string servers_; 312 const std::string group_; 313 const bool eof_; 314 const int64 timeout_; 315 }; 316 }; 317 318 REGISTER_KERNEL_BUILDER(Name("KafkaDataset").Device(DEVICE_CPU), 319 KafkaDatasetOp); 320 321 } // namespace tensorflow 322