/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"

#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"

namespace tensorflow {

namespace {

// A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor.
// Objects of the wrapper class own a reference on an instance of `DatasetBase`,
// and the wrapper's copy constructor and destructor take care of managing the
// reference count.
//
// NOTE(mrry): This is not a feature-complete implementation of the DT_VARIANT
// specification. In particular, we cannot currently serialize an arbitrary
// `DatasetBase` object, so the `Encode()` and `Decode()` methods are not
// implemented.
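//
// Illustrative sketch (not part of this file's API surface): a wrapped
// dataset typically round-trips through a scalar DT_VARIANT tensor via the
// helpers defined below. The `my_dataset` pointer here is hypothetical.
//
//   Tensor t(DT_VARIANT, TensorShape({}));
//   TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(my_dataset, &t));
//   DatasetBase* restored = nullptr;
//   TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(t, &restored));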
class DatasetVariantWrapper {
 public:
  DatasetVariantWrapper() : dataset_(nullptr) {}

  // Transfers ownership of `dataset` to `*this`.
  explicit DatasetVariantWrapper(DatasetBase* dataset) : dataset_(dataset) {}

  DatasetVariantWrapper(const DatasetVariantWrapper& other)
      : dataset_(other.dataset_) {
    if (dataset_) dataset_->Ref();
  }

  ~DatasetVariantWrapper() {
    if (dataset_) dataset_->Unref();
  }

  DatasetBase* get() const { return dataset_; }

  string TypeName() const { return "tensorflow::DatasetVariantWrapper"; }
  string DebugString() const {
    if (dataset_) {
      return dataset_->DebugString();
    } else {
      return "<Uninitialized DatasetVariantWrapper>";
    }
  }
  void Encode(VariantTensorData* data) const {
    LOG(ERROR) << "The Encode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
  }
  bool Decode(const VariantTensorData& data) {
    LOG(ERROR) << "The Decode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
    return false;
  }

 private:
  DatasetBase* const dataset_;  // Owns one reference.
};

}  // namespace

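// Adds a node for `dataset` to the graph under construction. Single-tensor
// inputs (`inputs`) and list inputs (`list_inputs`) are each tagged with the
// argument index they occupy, and are interleaved below strictly by that
// index. As an illustrative sketch (the nodes named here are hypothetical),
// a dataset op taking (input_dataset, other_arguments...) would be wired up
// as:
//
//   TF_RETURN_IF_ERROR(AddDataset(dataset,
//                                 {{0, input_node}},            // index 0
//                                 {{1, other_argument_nodes}},  // index 1
//                                 {},
//                                 &output_node));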
Status GraphDefBuilderWrapper::AddDataset(
    const GraphDatasetBase* dataset,
    const std::vector<std::pair<size_t, Node*>>& inputs,
    const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
    Node** output) {
  const string& op_type_name = dataset->op_name();
  std::unique_ptr<const GraphDefBuilder::Options> opts(
      new GraphDefBuilder::Options(b_->opts()));
  // TODO(srbs|mrry): Not all datasets have output_types and output_shapes
  // attributes defined. It would be nice to have a consistent pattern.
  bool has_output_types_attr = HasAttr(op_type_name, "output_types");
  bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes");
  if (has_output_shapes_attr) {
    opts.reset(new GraphDefBuilder::Options(
        opts->WithAttr("output_shapes", dataset->output_shapes())));
  }
  if (has_output_types_attr) {
    opts.reset(new GraphDefBuilder::Options(
        opts->WithAttr("output_types", dataset->output_dtypes())));
  }
  for (const auto& attr : attrs) {
    opts.reset(
        new GraphDefBuilder::Options(opts->WithAttr(attr.first, attr.second)));
  }
  if (opts->HaveError()) {
    return errors::Internal("AddDataset: Failed to build Options with error ",
                            opts->StatusToString());
  }
  NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name,
                           opts->op_registry());
  {
    size_t total_size = inputs.size() + list_inputs.size();
    auto inputs_iter = inputs.begin();
    auto list_inputs_iter = list_inputs.begin();
    for (size_t i = 0; i < total_size; i++) {
      if (inputs_iter != inputs.end() && inputs_iter->first == i) {
        node_builder.Input(NodeBuilder::NodeOut(inputs_iter->second));
        inputs_iter++;
      } else if (list_inputs_iter != list_inputs.end() &&
                 list_inputs_iter->first == i) {
        std::vector<NodeBuilder::NodeOut> nodeout_inputs;
        nodeout_inputs.reserve(list_inputs_iter->second.size());
        for (Node* n : list_inputs_iter->second) {
          nodeout_inputs.emplace_back(n);
        }
        node_builder.Input(nodeout_inputs);
        list_inputs_iter++;
      } else {
        return errors::InvalidArgument("No input found for index ", i);
      }
    }
  }
  *output = opts->FinalizeBuilder(&node_builder);
  if (*output == nullptr) {
    return errors::Internal("AddDataset: Failed to build ", op_type_name,
                            " op with error ", opts->StatusToString());
  }
  return Status::OK();
}

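// Copies the FunctionDef for `function_name` (and its registered gradient, if
// any) into the graph under construction, then recurses into any functions
// referenced by the ops or attrs of its body. An illustrative sketch, where
// "MyMapFn" is a hypothetical function name already present in the step's
// function library:
//
//   TF_RETURN_IF_ERROR(AddFunction(ctx, "MyMapFn"));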
Status GraphDefBuilderWrapper::AddFunction(OpKernelContext* ctx,
                                           const string& function_name) {
  if (b_->HasFunction(function_name)) {
    LOG(INFO) << "Function with name " << function_name
              << " already exists in the graph. It will not be added again.";
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(ctx, function_name));
  const FunctionLibraryDefinition* flib_def =
      ctx->function_library()->GetFunctionLibraryDefinition();
  const FunctionDef* f_def = flib_def->Find(function_name);
  if (f_def == nullptr) {
    return errors::InvalidArgument("Unable to find FunctionDef for ",
                                   function_name, " in the registry.");
  }
  FunctionDefLibrary def;
  *def.add_function() = *f_def;
  const string gradient_func = flib_def->FindGradient(function_name);
  if (!gradient_func.empty()) {
    GradientDef* g_def = def.add_gradient();
    g_def->set_function_name(function_name);
    g_def->set_gradient_func(gradient_func);
  }
  TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def));

  // Recursively add any functions invoked by ops in the body of
  // function_name.
  for (const NodeDef& node_def : f_def->node_def()) {
    const OpRegistrationData* op_reg_data = nullptr;
    TF_RETURN_IF_ERROR(flib_def->LookUp(node_def.op(), &op_reg_data));
    if (op_reg_data->is_function_op) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name()));
    }
    // Recursively add functions in attrs of this NodeDef.
    for (const auto& pair : node_def.attr()) {
      TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, ctx));
    }
  }

  // Recursively add functions in attrs of function_name.
  for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) {
    TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, ctx));
  }
  return Status::OK();
}

void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
                                               Node** output) {
  *output = ops::SourceOp(
      "Const",
      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
}

bool GraphDefBuilderWrapper::HasAttr(const string& op_type_name,
                                     const string& attr_name) const {
  const OpDef* op_def = nullptr;
  Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def);
  if (!s.ok() || op_def == nullptr) {
    return false;
  }
  return HasAttr(op_def, attr_name);
}

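// Serializes this dataset as a GraphDef and records the name of the node that
// produces it. Downstream code can pair the two strings, for example under the
// kDatasetGraphKey and kDatasetGraphOutputNodeKey constants defined at the end
// of this file, to rebuild the dataset later.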
Status GraphDatasetBase::Serialize(OpKernelContext* ctx,
                                   string* serialized_graph_def,
                                   string* output_node) const {
  GraphDefBuilder b;
  DatasetGraphDefBuilder db(&b);
  Node* node = nullptr;
  TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node));
  *output_node = node->name();
  GraphDef graph_def;
  TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
  if (!graph_def.SerializeToString(serialized_graph_def)) {
    return errors::Internal("Failed to serialize dataset GraphDef.");
  }
  return Status::OK();
}

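// Extracts the DatasetBase pointer held by a scalar DT_VARIANT tensor that was
// populated by StoreDatasetInVariantTensor() below. The returned pointer is
// borrowed: the wrapper inside the tensor keeps its reference, so the caller
// should not Unref() the result.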
Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset) {
  if (!(tensor.dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor.shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  const Variant& variant = tensor.scalar<Variant>()();
  const DatasetVariantWrapper* wrapper = variant.get<DatasetVariantWrapper>();
  if (wrapper == nullptr) {
    return errors::InvalidArgument("Tensor must be a Dataset object.");
  }
  *out_dataset = wrapper->get();
  if (*out_dataset == nullptr) {
    return errors::Internal("Read uninitialized Dataset variant.");
  }
  return Status::OK();
}

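// Wraps `dataset` in a DatasetVariantWrapper and stores it in the given scalar
// DT_VARIANT tensor. The wrapper takes ownership of one reference on
// `dataset`, which is released when the tensor's variant is destroyed.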
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) {
  if (!(tensor->dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(tensor->shape()))) {
    return errors::InvalidArgument(
        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
  }
  tensor->scalar<Variant>()() = DatasetVariantWrapper(dataset);
  return Status::OK();
}

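// A DatasetOpKernel subclass only has to implement MakeDataset(); Compute()
// then packs the resulting DatasetBase into the op's scalar DT_VARIANT output.
// A minimal sketch of a derived kernel (the class and comments here are
// hypothetical):
//
//   class MyRangeDatasetOp : public DatasetOpKernel {
//    public:
//     using DatasetOpKernel::DatasetOpKernel;
//     void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
//       // Parse inputs/attrs and construct a DatasetBase subclass here.
//     }
//   };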
void DatasetOpKernel::Compute(OpKernelContext* ctx) {
  DatasetBase* dataset = nullptr;
  MakeDataset(ctx, &dataset);
  if (ctx->status().ok()) {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
    OP_REQUIRES_OK(ctx, StoreDatasetInVariantTensor(dataset, output));
  }
}

void UnaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
                                       DatasetBase** output) {
  DatasetBase* input;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
  MakeDataset(ctx, input, output);
}

void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
                                        DatasetBase** output) {
  DatasetBase* input;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
  DatasetBase* another_input;
  OP_REQUIRES_OK(ctx,
                 GetDatasetFromVariantTensor(ctx->input(1), &another_input));
  MakeDataset(ctx, input, another_input, output);
}

const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] =
    "_DATASET_GRAPH_OUTPUT_NODE";

}  // namespace tensorflow