      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      7     http://www.apache.org/licenses/LICENSE-2.0
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     16 #include "tensorflow/core/kernels/data/dataset_utils.h"
     17 #include "tensorflow/core/common_runtime/device.h"
     18 #include "tensorflow/core/common_runtime/function.h"
     19 #include "tensorflow/core/framework/op_kernel.h"
     20 #include "tensorflow/core/lib/gtl/cleanup.h"
     21 #include "tensorflow/core/util/work_sharder.h"
     23 namespace tensorflow {
     24 namespace data {
     26 Status ComputeShortCircuitIndices(OpKernelConstruction* ctx,
     27                                   const NameAttrList& func,
     28                                   std::vector<int>* indices) {
     29   FunctionLibraryRuntime::Handle fn_handle;
     30   TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
     31       func.name(), AttrSlice(&func.attr()), &fn_handle));
     32   auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() {
     33     Status s = ctx->function_library()->ReleaseHandle(fn_handle);
     34     if (!s.ok()) {
     35       LOG(WARNING) << "Failed to release handle: " << s.error_message();
     36     }
     37   });
     39   // If the function contains any stateful operations, we conservatively execute
     40   // the entire function.
     41   if (ctx->function_library()->IsStateful(func.name())) {
     42     indices->clear();
     43     return Status::OK();
     44   }
     46   const FunctionBody* fn_body =
     47       ctx->function_library()->GetFunctionBody(fn_handle);
     48   indices->resize(fn_body->ret_nodes.size());
     50   for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) {
     51     Node* ret_node = fn_body->ret_nodes[i];
     52     Node* ret_input_node;
     53     TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node));
     55     while (ret_input_node->def().op() == "Identity") {
     56       TF_RETURN_IF_ERROR(ret_input_node->input_node(0, &ret_input_node));
     57     }
     59     if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) {
     60       TF_RETURN_IF_ERROR(
     61           GetNodeAttr(ret_input_node->def(), "index", &((*indices)[i])));
     62     } else {
     63       indices->clear();
     64       break;
     65     }
     66   }
     67   return Status::OK();
     68 }
     70 std::vector<bool> ComputeMoveVector(const std::vector<int>& indices) {
     71   std::map<int, int> last_use;
     72   for (size_t i = 0; i < indices.size(); ++i) {
     73     last_use[indices[i]] = i;
     74   }
     75   std::vector<bool> can_move;
     76   can_move.resize(indices.size());
     77   for (size_t i = 0; i < indices.size(); ++i) {
     78     can_move[i] = last_use[indices[i]] == i;
     79   }
     80   return can_move;
     81 }
     83 Status MakeIteratorFromInputElement(
     84     IteratorContext* ctx, const std::vector<Tensor>& input_element,
     85     int64 thread_index, const InstantiatedCapturedFunction& inst_captured_func,
     86     StringPiece prefix, std::unique_ptr<IteratorBase>* out_iterator) {
     87   std::vector<Tensor> return_values;
     89   TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(ctx, input_element,
     90                                                             &return_values));
     92   if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
     93         TensorShapeUtils::IsScalar(return_values[0].shape()))) {
     94     return errors::InvalidArgument(
     95         "Function must return a single scalar of dtype DT_VARIANT.");
     96   }
     98   // Retrieve the dataset that was created in `f`.
     99   DatasetBase* returned_dataset;
    101       GetDatasetFromVariantTensor(return_values[0], &returned_dataset));
    103   // Create an iterator for the dataset that was returned by `f`.
    104   return returned_dataset->MakeIterator(
    105       ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator);
    106 }
    108 Status VerifyTypesMatch(const DataTypeVector& expected,
    109                         const DataTypeVector& received) {
    110   if (expected.size() != received.size()) {
    111     return errors::InvalidArgument(
    112         "Number of components does not match: expected ", expected.size(),
    113         " types but got ", received.size(), ".");
    114   }
    115   for (size_t i = 0; i < expected.size(); ++i) {
    116     if (expected[i] != received[i]) {
    117       return errors::InvalidArgument("Data type mismatch at component ", i,
    118                                      ": expected ", DataTypeString(expected[i]),
    119                                      " but got ", DataTypeString(received[i]),
    120                                      ".");
    121     }
    122   }
    123   return Status::OK();
    124 }
    126 Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
    127                               const std::vector<PartialTensorShape>& received) {
    128   if (expected.size() != received.size()) {
    129     return errors::InvalidArgument(
    130         "Number of components does not match: expected ", expected.size(),
    131         " shapes but got ", received.size(), ".");
    132   }
    133   for (size_t i = 0; i < expected.size(); ++i) {
    134     if (!expected[i].IsCompatibleWith(received[i])) {
    135       return errors::InvalidArgument("Incompatible shapes at component ", i,
    136                                      ": expected ", expected[i].DebugString(),
    137                                      " but got ", received[i].DebugString(),
    138                                      ".");
    139     }
    140   }
    142   return Status::OK();
    143 }
namespace {

// Separator placed in front of every key when VariantTensorDataWriter::Flush
// serializes its key list into the metadata string, and used by
// VariantTensorDataReader to split that string back into keys.
constexpr char kDelimiter[] = "@@";

}  // namespace
    151 VariantTensorDataReader::VariantTensorDataReader(
    152     const tensorflow::VariantTensorData* data)
    153     : data_(data) {
    154   string metadata;
    155   data_->get_metadata(&metadata);
    156   auto keys = str_util::Split(metadata, kDelimiter, str_util::SkipEmpty());
    157   for (size_t i = 0; i < keys.size(); ++i) {
    158     map_[keys[i]] = i;
    159   }
    160 }
    162 Status VariantTensorDataReader::ReadScalar(StringPiece key, int64* val) {
    163   return ReadScalarInternal(key, val);
    164 }
    166 Status VariantTensorDataReader::ReadScalar(StringPiece key, string* val) {
    167   return ReadScalarInternal(key, val);
    168 }
    170 Status VariantTensorDataReader::ReadTensor(StringPiece key, Tensor* val) {
    171   return ReadTensorInternal(key, val);
    172 }
    174 bool VariantTensorDataReader::Contains(StringPiece key) {
    175   return map_.find(string(key)) != map_.end();
    176 }
    178 template <typename T>
    179 Status VariantTensorDataReader::ReadScalarInternal(StringPiece key, T* val) {
    180   if (map_.find(string(key)) == map_.end()) {
    181     return errors::NotFound(key);
    182   }
    183   *val = data_->tensors(map_[string(key)]).scalar<T>()();
    184   return Status::OK();
    185 }
    187 Status VariantTensorDataReader::ReadTensorInternal(StringPiece key,
    188                                                    Tensor* val) {
    189   if (map_.find(string(key)) == map_.end()) {
    190     return errors::NotFound(key);
    191   }
    192   *val = data_->tensors(map_[string(key)]);
    193   return Status::OK();
    194 }
    196 Status VariantTensorDataWriter::WriteScalar(StringPiece key, const int64 val) {
    197   return WriteScalarInternal(key, val);
    198 }
    200 Status VariantTensorDataWriter::WriteScalar(StringPiece key,
    201                                             const string& val) {
    202   return WriteScalarInternal(key, val);
    203 }
    205 Status VariantTensorDataWriter::WriteTensor(StringPiece key,
    206                                             const Tensor& val) {
    207   return WriteTensorInternal(key, val);
    208 }
    210 Status VariantTensorDataWriter::Flush() {
    211   string metadata;
    212   for (size_t i = 0; i < keys_.size(); ++i) {
    213     strings::StrAppend(&metadata, kDelimiter, keys_[i]);
    214   }
    215   data_->set_metadata(metadata);
    216   return Status::OK();
    217 }
    219 template <typename T>
    220 Status VariantTensorDataWriter::WriteScalarInternal(StringPiece key,
    221                                                     const T& val) {
    222   Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
    223   val_t.scalar<T>()() = val;
    224   return WriteTensorInternal(key, val_t);
    225 }
    227 Status VariantTensorDataWriter::WriteTensorInternal(StringPiece key,
    228                                                     const Tensor& val) {
    229   DCHECK_EQ(key.find(kDelimiter), string::npos);
    230   keys_.push_back(string(key));
    231   *(data_->add_tensors()) = val;
    232   return Status::OK();
    233 }
    235 Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
    236                             const FunctionLibraryDefinition& to_add) {
    237   for (const auto& fn : to_add.ListFunctionNames()) {
    238     if (auto found = base->Find(fn)) {
    239       if (!OpDefEqual(found->signature(), to_add.Find(fn)->signature())) {
    240         return errors::InvalidArgument("Cannot add function '", fn,
    241                                        "' because a different function with "
    242                                        "the same signature already exists.");
    243       }
    244       TF_RETURN_IF_ERROR(base->RemoveFunction(fn));
    245     }
    246   }
    247   return base->AddLibrary(to_add);
    248 }
    250 Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
    251                             const FunctionDefLibrary& to_add) {
    252   for (const auto& fd : to_add.function()) {
    253     if (auto found = base->Find(fd.signature().name())) {
    254       if (!OpDefEqual(found->signature(), fd.signature())) {
    255         return errors::InvalidArgument("Cannot add function '",
    256                                        fd.signature().name(),
    257                                        "' because a different function with "
    258                                        "the same signature already exists.");
    259       }
    260       TF_RETURN_IF_ERROR(base->RemoveFunction(fd.signature().name()));
    261     }
    262   }
    263   return base->AddLibrary(to_add);
    264 }
    266 std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    267     std::function<void(std::function<void()>)> runner, int max_parallelism) {
    268   return std::bind(
    269       [max_parallelism](
    270           // Note: `runner` is a const reference to avoid copying it.
    271           const std::function<void(std::function<void()>)>& runner,
    272           std::function<void()> fn) {
    273         std::function<void()> scoped_fn = std::bind(
    274             [max_parallelism](const std::function<void()>& fn) {
    275               ScopedPerThreadMaxParallelism scope(max_parallelism);
    276               fn();
    277             },
    278             std::move(fn));
    279         runner(std::move(scoped_fn));
    280       },
    281       std::move(runner), std::placeholders::_1);
    282 }
    283 }  // namespace data
    284 }  // namespace tensorflow