/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/data/dataset_utils.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace data {

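// Determines whether the user-defined function `func` can be short-circuited,
// which is possible when every function output is directly forwarded from one
// of its inputs (e.g. f(x) = x, f(x, y) = (y, x), or f(x) = (x, x)).
//
// On success, `indices` maps each output to the index of the input argument
// it forwards; for f(x, y) = (y, x) it would contain {1, 0}. If
// short-circuiting is not possible, `indices` is cleared.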
Status ComputeShortCircuitIndices(OpKernelConstruction* ctx,
                                  const NameAttrList& func,
                                  std::vector<int>* indices) {
  FunctionLibraryRuntime::Handle fn_handle;
  TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
      func.name(), AttrSlice(&func.attr()), &fn_handle));
  auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() {
    Status s = ctx->function_library()->ReleaseHandle(fn_handle);
    if (!s.ok()) {
      LOG(WARNING) << "Failed to release handle: " << s.error_message();
    }
  });

  // If the function contains any stateful operations, we conservatively
  // execute the entire function.
  if (ctx->function_library()->IsStateful(func.name())) {
    indices->clear();
    return Status::OK();
  }

  const FunctionBody* fn_body =
      ctx->function_library()->GetFunctionBody(fn_handle);
  indices->resize(fn_body->ret_nodes.size());

  for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) {
    Node* ret_node = fn_body->ret_nodes[i];
    Node* ret_input_node;
    TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node));

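    // Walk through Identity nodes, which forward their input unchanged, to
    // find the true producer of the return value.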
    while (ret_input_node->def().op() == "Identity") {
      TF_RETURN_IF_ERROR(ret_input_node->input_node(0, &ret_input_node));
    }

    if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) {
      TF_RETURN_IF_ERROR(
          GetNodeAttr(ret_input_node->def(), "index", &((*indices)[i])));
    } else {
      indices->clear();
      break;
    }
  }
  return Status::OK();
}

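// Given the output-to-input mapping produced by ComputeShortCircuitIndices,
// determines which arguments can be moved (rather than copied) into the
// outputs. An argument can only be moved into its last use; e.g. for
// indices {0, 0, 1} the result is {false, true, true}.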
std::vector<bool> ComputeMoveVector(const std::vector<int>& indices) {
  std::map<int, int> last_use;
  for (size_t i = 0; i < indices.size(); ++i) {
    last_use[indices[i]] = i;
  }
  std::vector<bool> can_move;
  can_move.resize(indices.size());
  for (size_t i = 0; i < indices.size(); ++i) {
    can_move[i] = last_use[indices[i]] == i;
  }
  return can_move;
}

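// Applies `inst_captured_func` to `input_element`, expecting it to return a
// single scalar DT_VARIANT tensor wrapping a dataset, and creates an iterator
// over that dataset. The iterator's prefix is `prefix` with `thread_index`
// appended in brackets, e.g. "Prefix[3]".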
Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const std::vector<Tensor>& input_element,
    int64 thread_index, const InstantiatedCapturedFunction& inst_captured_func,
    StringPiece prefix, std::unique_ptr<IteratorBase>* out_iterator) {
  std::vector<Tensor> return_values;

  TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(ctx, input_element,
                                                            &return_values));

  if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
        TensorShapeUtils::IsScalar(return_values[0].shape()))) {
    return errors::InvalidArgument(
        "Function must return a single scalar of dtype DT_VARIANT.");
  }

  // Retrieve the dataset that was created in `f`.
  DatasetBase* returned_dataset;
  TF_RETURN_IF_ERROR(
      GetDatasetFromVariantTensor(return_values[0], &returned_dataset));

  // Create an iterator for the dataset that was returned by `f`.
  return returned_dataset->MakeIterator(
      ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator);
}

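// Returns an InvalidArgument error if `received` does not have the same
// number of components as `expected` or if any component type differs;
// otherwise returns OK.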
Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " types but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    if (expected[i] != received[i]) {
      return errors::InvalidArgument("Data type mismatch at component ", i,
                                     ": expected ", DataTypeString(expected[i]),
                                     " but got ", DataTypeString(received[i]),
                                     ".");
    }
  }
  return Status::OK();
}

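// Returns an InvalidArgument error if `received` does not have the same
// number of components as `expected` or if any received shape is not
// compatible with the corresponding (possibly partial) expected shape;
// otherwise returns OK.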
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received) {
  if (expected.size() != received.size()) {
    return errors::InvalidArgument(
        "Number of components does not match: expected ", expected.size(),
        " shapes but got ", received.size(), ".");
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    if (!expected[i].IsCompatibleWith(received[i])) {
      return errors::InvalidArgument("Incompatible shapes at component ", i,
                                     ": expected ", expected[i].DebugString(),
                                     " but got ", received[i].DebugString(),
                                     ".");
    }
  }
  return Status::OK();
}

namespace {

constexpr char kDelimiter[] = "@@";

}  // namespace

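// Builds an index from key to tensor position. The writer stores the keys,
// joined by kDelimiter ("@@"), in the metadata string, in the same order as
// the tensors they name.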
VariantTensorDataReader::VariantTensorDataReader(
    const tensorflow::VariantTensorData* data)
    : data_(data) {
  string metadata;
  data_->get_metadata(&metadata);
  auto keys = str_util::Split(metadata, kDelimiter, str_util::SkipEmpty());
  for (size_t i = 0; i < keys.size(); ++i) {
    map_[keys[i]] = i;
  }
}

Status VariantTensorDataReader::ReadScalar(StringPiece key, int64* val) {
  return ReadScalarInternal(key, val);
}

Status VariantTensorDataReader::ReadScalar(StringPiece key, string* val) {
  return ReadScalarInternal(key, val);
}

Status VariantTensorDataReader::ReadTensor(StringPiece key, Tensor* val) {
  return ReadTensorInternal(key, val);
}

bool VariantTensorDataReader::Contains(StringPiece key) {
  return map_.find(string(key)) != map_.end();
}

template <typename T>
Status VariantTensorDataReader::ReadScalarInternal(StringPiece key, T* val) {
  if (map_.find(string(key)) == map_.end()) {
    return errors::NotFound(key);
  }
  *val = data_->tensors(map_[string(key)]).scalar<T>()();
  return Status::OK();
}

Status VariantTensorDataReader::ReadTensorInternal(StringPiece key,
                                                   Tensor* val) {
  if (map_.find(string(key)) == map_.end()) {
    return errors::NotFound(key);
  }
  *val = data_->tensors(map_[string(key)]);
  return Status::OK();
}

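// The writer stores each value as a tensor and records its key; scalars are
// boxed into rank-0 tensors so that everything is stored uniformly. Keys are
// not serialized until Flush() is called.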
Status VariantTensorDataWriter::WriteScalar(StringPiece key, const int64 val) {
  return WriteScalarInternal(key, val);
}

Status VariantTensorDataWriter::WriteScalar(StringPiece key,
                                            const string& val) {
  return WriteScalarInternal(key, val);
}

Status VariantTensorDataWriter::WriteTensor(StringPiece key,
                                            const Tensor& val) {
  return WriteTensorInternal(key, val);
}

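// Serializes the recorded keys into the metadata string, joined by
// kDelimiter, in the order in which their tensors were written. The reader
// reconstructs its key-to-index map from this string.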
Status VariantTensorDataWriter::Flush() {
  string metadata;
  for (size_t i = 0; i < keys_.size(); ++i) {
    strings::StrAppend(&metadata, kDelimiter, keys_[i]);
  }
  data_->set_metadata(metadata);
  return Status::OK();
}

template <typename T>
Status VariantTensorDataWriter::WriteScalarInternal(StringPiece key,
                                                    const T& val) {
  Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
  val_t.scalar<T>()() = val;
  return WriteTensorInternal(key, val_t);
}

Status VariantTensorDataWriter::WriteTensorInternal(StringPiece key,
                                                    const Tensor& val) {
  DCHECK_EQ(key.find(kDelimiter), string::npos);
  keys_.push_back(string(key));
  *(data_->add_tensors()) = val;
  return Status::OK();
}

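// Merges the functions in `to_add` into `base`. A function already present
// in `base` is replaced if the incoming definition has an identical
// signature; if the signatures differ, the merge fails with InvalidArgument.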
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add) {
  for (const auto& fn : to_add.ListFunctionNames()) {
    if (auto found = base->Find(fn)) {
      if (!OpDefEqual(found->signature(), to_add.Find(fn)->signature())) {
        return errors::InvalidArgument("Cannot add function '", fn,
                                       "' because a function with the same "
                                       "name but a different signature "
                                       "already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fn));
    }
  }
  return base->AddLibrary(to_add);
}

Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add) {
  for (const auto& fd : to_add.function()) {
    if (auto found = base->Find(fd.signature().name())) {
      if (!OpDefEqual(found->signature(), fd.signature())) {
        return errors::InvalidArgument("Cannot add function '",
                                       fd.signature().name(),
                                       "' because a function with the same "
                                       "name but a different signature "
                                       "already exists.");
      }
      TF_RETURN_IF_ERROR(base->RemoveFunction(fd.signature().name()));
    }
  }
  return base->AddLibrary(to_add);
}

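// Wraps `runner` so that every function it executes runs with a per-thread
// maximum parallelism of `max_parallelism` (via ScopedPerThreadMaxParallelism
// from work_sharder.h), limiting the parallelism available to any sharded
// work the function performs.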
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism) {
  return std::bind(
      [max_parallelism](
          // Note: `runner` is a const reference to avoid copying it.
          const std::function<void(std::function<void()>)>& runner,
          std::function<void()> fn) {
        std::function<void()> scoped_fn = std::bind(
            [max_parallelism](const std::function<void()>& fn) {
              ScopedPerThreadMaxParallelism scope(max_parallelism);
              fn();
            },
            std::move(fn));
        runner(std::move(scoped_fn));
      },
      std::move(runner), std::placeholders::_1);
}

}  // namespace data
}  // namespace tensorflow