/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/dataset.h"

#include "tensorflow/core/framework/tensor.h"

#include "src-cpp/rdkafkacpp.h"

namespace tensorflow {

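// KafkaDatasetOp produces a dataset of scalar strings, one element per Kafka
// message payload, consumed through librdkafka's KafkaConsumer. Its inputs
// are a vector of topic specifications (see SetupStreamsLocked below for the
// format) plus four scalars: `servers` (bootstrap servers), `group` (consumer
// group id), `eof` (whether partition EOF ends the current topic), and
// `timeout` (the consume timeout, in milliseconds per librdkafka's consume()).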
class KafkaDatasetOp : public DatasetOpKernel {
 public:
  using DatasetOpKernel::DatasetOpKernel;

  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
    const Tensor* topics_tensor;
    OP_REQUIRES_OK(ctx, ctx->input("topics", &topics_tensor));
    OP_REQUIRES(
        ctx, topics_tensor->dims() <= 1,
        errors::InvalidArgument("`topics` must be a scalar or a vector."));

    std::vector<string> topics;
    topics.reserve(topics_tensor->NumElements());
    for (int i = 0; i < topics_tensor->NumElements(); ++i) {
      topics.push_back(topics_tensor->flat<string>()(i));
    }

    std::string servers = "";
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument<std::string>(ctx, "servers", &servers));
    std::string group = "";
    OP_REQUIRES_OK(ctx, ParseScalarArgument<std::string>(ctx, "group", &group));
    bool eof = false;
    OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "eof", &eof));
    int64 timeout = -1;
    OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "timeout", &timeout));
    OP_REQUIRES(ctx, (timeout > 0),
                errors::InvalidArgument(
                    "`timeout` must be greater than 0, got ", timeout));
    *output = new Dataset(ctx, std::move(topics), servers, group, eof, timeout);
  }

 private:
  class Dataset : public GraphDatasetBase {
   public:
    Dataset(OpKernelContext* ctx, std::vector<string> topics,
            const string& servers, const string& group, const bool eof,
            const int64 timeout)
        : GraphDatasetBase(ctx),
          topics_(std::move(topics)),
          servers_(servers),
          group_(group),
          eof_(eof),
          timeout_(timeout) {}

    std::unique_ptr<IteratorBase> MakeIterator(
        const string& prefix) const override {
      return std::unique_ptr<IteratorBase>(
          new Iterator({this, strings::StrCat(prefix, "::Kafka")}));
    }

    const DataTypeVector& output_dtypes() const override {
      static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
      return *dtypes;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      static std::vector<PartialTensorShape>* shapes =
          new std::vector<PartialTensorShape>({{}});
      return *shapes;
    }

    string DebugString() override { return "KafkaDatasetOp::Dataset"; }

   protected:
    Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* topics = nullptr;
      TF_RETURN_IF_ERROR(b->AddVector(topics_, &topics));
      Node* servers = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(servers_, &servers));
      Node* group = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(group_, &group));
      Node* eof = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(eof_, &eof));
      Node* timeout = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(timeout_, &timeout));
      TF_RETURN_IF_ERROR(
          b->AddDataset(this, {topics, servers, group, eof, timeout}, output));
      return Status::OK();
    }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {}

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
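        // Each pass of the loop below reads from the currently assigned
        // topic/partition. Once the current topic reaches EOF or its offset
        // limit, the stream is torn down and we advance to the next topic;
        // iteration ends after the last topic has been consumed.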
        do {
          // We are currently processing a topic, so try to read the next
          // message.
          if (consumer_.get()) {
            while (true) {
              if (limit_ >= 0 &&
                  (topic_partition_->offset() >= limit_ || offset_ >= limit_)) {
                // EOF current topic
                break;
              }
              std::unique_ptr<RdKafka::Message> message(
                  consumer_->consume(dataset()->timeout_));
              if (message->err() == RdKafka::ERR_NO_ERROR) {
                // Produce the message payload as output.
                Tensor line_tensor(cpu_allocator(), DT_STRING, {});
                line_tensor.scalar<string>()() =
                    std::string(static_cast<const char*>(message->payload()),
                                message->len());
                out_tensors->emplace_back(std::move(line_tensor));
                *end_of_sequence = false;
                // Sync offset
                offset_ = message->offset();
                return Status::OK();
              }

              if (message->err() == RdKafka::ERR__PARTITION_EOF &&
                  dataset()->eof_) {
                // EOF current topic
                break;
              }
              if (message->err() != RdKafka::ERR__TIMED_OUT) {
                return errors::Internal("Failed to consume: ",
                                        message->errstr());
              }
              message.reset(nullptr);
              consumer_->poll(0);
            }

            // We have reached the end of the current topic, so tear down the
            // stream and move on to the next topic.
            ResetStreamsLocked();
            ++current_topic_index_;
          }

          // Iteration ends when there are no more topics to process.
          if (current_topic_index_ == dataset()->topics_.size()) {
            *end_of_sequence = true;
            return Status::OK();
          }

          TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
        } while (true);
      }

     protected:
      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_topic_index"),
                                               current_topic_index_));

        // `consumer_` is empty if
        // 1. GetNext has not been called even once, or
        // 2. all topics have been read and the iterator has been exhausted.
        if (consumer_.get()) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("current_pos"), offset_));
        }
        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        ResetStreamsLocked();
        int64 current_topic_index;
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_topic_index"),
                                              &current_topic_index));
        current_topic_index_ = size_t(current_topic_index);
        // The key "current_pos" is written only if the iterator was saved
        // with an open topic.
        if (reader->Contains(full_name("current_pos"))) {
          int64 current_pos;
          TF_RETURN_IF_ERROR(
              reader->ReadScalar(full_name("current_pos"), &current_pos));

          TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
          topic_partition_->set_offset(current_pos);
          if (topic_partition_->offset() != current_pos) {
            return errors::Internal("Failed to restore to offset ",
                                    current_pos);
          }
          offset_ = current_pos;
        }
        return Status::OK();
      }

     private:
      // Sets up Kafka streams to read from the topic at
      // `current_topic_index_`.
      Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        if (current_topic_index_ >= dataset()->topics_.size()) {
          return errors::InvalidArgument(
              "current_topic_index_:", current_topic_index_,
              " >= topics_.size():", dataset()->topics_.size());
        }

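        // A topic entry takes the form "topic[:partition[:offset[:limit]]]",
        // as parsed below. For example (illustrative values):
        //   "logs"           -> topic "logs", partition 0, offset 0, no limit
        //   "logs:1"         -> partition 1
        //   "logs:1:100"     -> start consuming at offset 100
        //   "logs:1:100:200" -> stop once offset 200 is reached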
        // Move on to the next topic.
        string entry = dataset()->topics_[current_topic_index_];

        std::vector<string> parts = str_util::Split(entry, ":");
        if (parts.size() < 1) {
          return errors::InvalidArgument("Invalid parameters: ", entry);
        }
        string topic = parts[0];
        int32 partition = 0;
        if (parts.size() > 1) {
          if (!strings::safe_strto32(parts[1], &partition)) {
            return errors::InvalidArgument("Invalid parameters: ", entry);
          }
        }
        int64 offset = 0;
        if (parts.size() > 2) {
          if (!strings::safe_strto64(parts[2], &offset)) {
            return errors::InvalidArgument("Invalid parameters: ", entry);
          }
        }

        topic_partition_.reset(
            RdKafka::TopicPartition::create(topic, partition, offset));

        offset_ = topic_partition_->offset();
        limit_ = -1;
        if (parts.size() > 3) {
          if (!strings::safe_strto64(parts[3], &limit_)) {
            return errors::InvalidArgument("Invalid parameters: ", entry);
          }
        }

        std::unique_ptr<RdKafka::Conf> conf(
            RdKafka::Conf::create(RdKafka::Conf::CONF_GLOBAL));
        std::unique_ptr<RdKafka::Conf> topic_conf(
            RdKafka::Conf::create(RdKafka::Conf::CONF_TOPIC));

        std::string errstr;

        RdKafka::Conf::ConfResult result =
            conf->set("default_topic_conf", topic_conf.get(), errstr);
        if (result != RdKafka::Conf::CONF_OK) {
          return errors::Internal("Failed to set default_topic_conf: ", errstr);
        }

        result = conf->set("bootstrap.servers", dataset()->servers_, errstr);
        if (result != RdKafka::Conf::CONF_OK) {
          return errors::Internal("Failed to set bootstrap.servers ",
                                  dataset()->servers_, ": ", errstr);
        }
        result = conf->set("group.id", dataset()->group_, errstr);
        if (result != RdKafka::Conf::CONF_OK) {
          return errors::Internal("Failed to set group.id ", dataset()->group_,
                                  ": ", errstr);
        }

        consumer_.reset(RdKafka::KafkaConsumer::create(conf.get(), errstr));
        if (!consumer_.get()) {
          return errors::Internal("Failed to create consumer: ", errstr);
        }

        std::vector<RdKafka::TopicPartition*> partitions;
        partitions.emplace_back(topic_partition_.get());
        RdKafka::ErrorCode err = consumer_->assign(partitions);
        if (err != RdKafka::ERR_NO_ERROR) {
          return errors::Internal(
              "Failed to assign partition [", topic_partition_->topic(), ", ",
              topic_partition_->partition(), ", ", topic_partition_->offset(),
              "]: ", RdKafka::err2str(err));
        }

        return Status::OK();
      }

      // Resets all Kafka streams. Safe to call when no stream is open (e.g.
      // from RestoreInternal before any topic has been set up).
      void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        if (consumer_) {
          consumer_->unassign();
          consumer_->close();
          consumer_.reset(nullptr);
        }
      }

      mutex mu_;
      size_t current_topic_index_ GUARDED_BY(mu_) = 0;
      int64 offset_ GUARDED_BY(mu_) = 0;
      int64 limit_ GUARDED_BY(mu_) = -1;
      std::unique_ptr<RdKafka::TopicPartition> topic_partition_ GUARDED_BY(mu_);
      std::unique_ptr<RdKafka::KafkaConsumer> consumer_ GUARDED_BY(mu_);
    };

    const std::vector<string> topics_;
    const std::string servers_;
    const std::string group_;
    const bool eof_;
    const int64 timeout_;
  };
};

REGISTER_KERNEL_BUILDER(Name("KafkaDataset").Device(DEVICE_CPU),
                        KafkaDatasetOp);

}  // namespace tensorflow