1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/kernels/data/dataset_utils.h" 17 #include "tensorflow/core/common_runtime/device.h" 18 #include "tensorflow/core/common_runtime/function.h" 19 #include "tensorflow/core/framework/op_kernel.h" 20 #include "tensorflow/core/lib/gtl/cleanup.h" 21 #include "tensorflow/core/util/work_sharder.h" 22 23 namespace tensorflow { 24 namespace data { 25 26 Status ComputeShortCircuitIndices(OpKernelConstruction* ctx, 27 const NameAttrList& func, 28 std::vector<int>* indices) { 29 FunctionLibraryRuntime::Handle fn_handle; 30 TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate( 31 func.name(), AttrSlice(&func.attr()), &fn_handle)); 32 auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() { 33 Status s = ctx->function_library()->ReleaseHandle(fn_handle); 34 if (!s.ok()) { 35 LOG(WARNING) << "Failed to release handle: " << s.error_message(); 36 } 37 }); 38 39 // If the function contains any stateful operations, we conservatively execute 40 // the entire function. 41 if (ctx->function_library()->IsStateful(func.name())) { 42 indices->clear(); 43 return Status::OK(); 44 } 45 46 const FunctionBody* fn_body = 47 ctx->function_library()->GetFunctionBody(fn_handle); 48 indices->resize(fn_body->ret_nodes.size()); 49 50 for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) { 51 Node* ret_node = fn_body->ret_nodes[i]; 52 Node* ret_input_node; 53 TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node)); 54 55 while (ret_input_node->def().op() == "Identity") { 56 TF_RETURN_IF_ERROR(ret_input_node->input_node(0, &ret_input_node)); 57 } 58 59 if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) { 60 TF_RETURN_IF_ERROR( 61 GetNodeAttr(ret_input_node->def(), "index", &((*indices)[i]))); 62 } else { 63 indices->clear(); 64 break; 65 } 66 } 67 return Status::OK(); 68 } 69 70 std::vector<bool> ComputeMoveVector(const std::vector<int>& indices) { 71 std::map<int, int> last_use; 72 for (size_t i = 0; i < indices.size(); ++i) { 73 last_use[indices[i]] = i; 74 } 75 std::vector<bool> can_move; 76 can_move.resize(indices.size()); 77 for (size_t i = 0; i < indices.size(); ++i) { 78 can_move[i] = last_use[indices[i]] == i; 79 } 80 return can_move; 81 } 82 83 Status MakeIteratorFromInputElement( 84 IteratorContext* ctx, const std::vector<Tensor>& input_element, 85 int64 thread_index, const InstantiatedCapturedFunction& inst_captured_func, 86 StringPiece prefix, std::unique_ptr<IteratorBase>* out_iterator) { 87 std::vector<Tensor> return_values; 88 89 TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(ctx, input_element, 90 &return_values)); 91 92 if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT && 93 TensorShapeUtils::IsScalar(return_values[0].shape()))) { 94 return errors::InvalidArgument( 95 "Function must return a single scalar of dtype DT_VARIANT."); 96 } 97 98 // Retrieve the dataset that was created in `f`. 99 DatasetBase* returned_dataset; 100 TF_RETURN_IF_ERROR( 101 GetDatasetFromVariantTensor(return_values[0], &returned_dataset)); 102 103 // Create an iterator for the dataset that was returned by `f`. 104 return returned_dataset->MakeIterator( 105 ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator); 106 } 107 108 Status VerifyTypesMatch(const DataTypeVector& expected, 109 const DataTypeVector& received) { 110 if (expected.size() != received.size()) { 111 return errors::InvalidArgument( 112 "Number of components does not match: expected ", expected.size(), 113 " types but got ", received.size(), "."); 114 } 115 for (size_t i = 0; i < expected.size(); ++i) { 116 if (expected[i] != received[i]) { 117 return errors::InvalidArgument("Data type mismatch at component ", i, 118 ": expected ", DataTypeString(expected[i]), 119 " but got ", DataTypeString(received[i]), 120 "."); 121 } 122 } 123 return Status::OK(); 124 } 125 126 Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, 127 const std::vector<PartialTensorShape>& received) { 128 if (expected.size() != received.size()) { 129 return errors::InvalidArgument( 130 "Number of components does not match: expected ", expected.size(), 131 " shapes but got ", received.size(), "."); 132 } 133 for (size_t i = 0; i < expected.size(); ++i) { 134 if (!expected[i].IsCompatibleWith(received[i])) { 135 return errors::InvalidArgument("Incompatible shapes at component ", i, 136 ": expected ", expected[i].DebugString(), 137 " but got ", received[i].DebugString(), 138 "."); 139 } 140 } 141 142 return Status::OK(); 143 } 144 145 namespace { 146 147 constexpr char kDelimiter[] = "@@"; 148 149 } // namespace 150 151 VariantTensorDataReader::VariantTensorDataReader( 152 const tensorflow::VariantTensorData* data) 153 : data_(data) { 154 string metadata; 155 data_->get_metadata(&metadata); 156 auto keys = str_util::Split(metadata, kDelimiter, str_util::SkipEmpty()); 157 for (size_t i = 0; i < keys.size(); ++i) { 158 map_[keys[i]] = i; 159 } 160 } 161 162 Status VariantTensorDataReader::ReadScalar(StringPiece key, int64* val) { 163 return ReadScalarInternal(key, val); 164 } 165 166 Status VariantTensorDataReader::ReadScalar(StringPiece key, string* val) { 167 return ReadScalarInternal(key, val); 168 } 169 170 Status VariantTensorDataReader::ReadTensor(StringPiece key, Tensor* val) { 171 return ReadTensorInternal(key, val); 172 } 173 174 bool VariantTensorDataReader::Contains(StringPiece key) { 175 return map_.find(string(key)) != map_.end(); 176 } 177 178 template <typename T> 179 Status VariantTensorDataReader::ReadScalarInternal(StringPiece key, T* val) { 180 if (map_.find(string(key)) == map_.end()) { 181 return errors::NotFound(key); 182 } 183 *val = data_->tensors(map_[string(key)]).scalar<T>()(); 184 return Status::OK(); 185 } 186 187 Status VariantTensorDataReader::ReadTensorInternal(StringPiece key, 188 Tensor* val) { 189 if (map_.find(string(key)) == map_.end()) { 190 return errors::NotFound(key); 191 } 192 *val = data_->tensors(map_[string(key)]); 193 return Status::OK(); 194 } 195 196 Status VariantTensorDataWriter::WriteScalar(StringPiece key, const int64 val) { 197 return WriteScalarInternal(key, val); 198 } 199 200 Status VariantTensorDataWriter::WriteScalar(StringPiece key, 201 const string& val) { 202 return WriteScalarInternal(key, val); 203 } 204 205 Status VariantTensorDataWriter::WriteTensor(StringPiece key, 206 const Tensor& val) { 207 return WriteTensorInternal(key, val); 208 } 209 210 Status VariantTensorDataWriter::Flush() { 211 string metadata; 212 for (size_t i = 0; i < keys_.size(); ++i) { 213 strings::StrAppend(&metadata, kDelimiter, keys_[i]); 214 } 215 data_->set_metadata(metadata); 216 return Status::OK(); 217 } 218 219 template <typename T> 220 Status VariantTensorDataWriter::WriteScalarInternal(StringPiece key, 221 const T& val) { 222 Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({})); 223 val_t.scalar<T>()() = val; 224 return WriteTensorInternal(key, val_t); 225 } 226 227 Status VariantTensorDataWriter::WriteTensorInternal(StringPiece key, 228 const Tensor& val) { 229 DCHECK_EQ(key.find(kDelimiter), string::npos); 230 keys_.push_back(string(key)); 231 *(data_->add_tensors()) = val; 232 return Status::OK(); 233 } 234 235 Status AddToFunctionLibrary(FunctionLibraryDefinition* base, 236 const FunctionLibraryDefinition& to_add) { 237 for (const auto& fn : to_add.ListFunctionNames()) { 238 if (auto found = base->Find(fn)) { 239 if (!OpDefEqual(found->signature(), to_add.Find(fn)->signature())) { 240 return errors::InvalidArgument("Cannot add function '", fn, 241 "' because a different function with " 242 "the same signature already exists."); 243 } 244 TF_RETURN_IF_ERROR(base->RemoveFunction(fn)); 245 } 246 } 247 return base->AddLibrary(to_add); 248 } 249 250 Status AddToFunctionLibrary(FunctionLibraryDefinition* base, 251 const FunctionDefLibrary& to_add) { 252 for (const auto& fd : to_add.function()) { 253 if (auto found = base->Find(fd.signature().name())) { 254 if (!OpDefEqual(found->signature(), fd.signature())) { 255 return errors::InvalidArgument("Cannot add function '", 256 fd.signature().name(), 257 "' because a different function with " 258 "the same signature already exists."); 259 } 260 TF_RETURN_IF_ERROR(base->RemoveFunction(fd.signature().name())); 261 } 262 } 263 return base->AddLibrary(to_add); 264 } 265 266 std::function<void(std::function<void()>)> RunnerWithMaxParallelism( 267 std::function<void(std::function<void()>)> runner, int max_parallelism) { 268 return std::bind( 269 [max_parallelism]( 270 // Note: `runner` is a const reference to avoid copying it. 271 const std::function<void(std::function<void()>)>& runner, 272 std::function<void()> fn) { 273 std::function<void()> scoped_fn = std::bind( 274 [max_parallelism](const std::function<void()>& fn) { 275 ScopedPerThreadMaxParallelism scope(max_parallelism); 276 fn(); 277 }, 278 std::move(fn)); 279 runner(std::move(scoped_fn)); 280 }, 281 std::move(runner), std::placeholders::_1); 282 } 283 } // namespace data 284 } // namespace tensorflow 285