1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/op_kernel.h" 17 18 #include <unordered_map> 19 #include <utility> 20 #include <vector> 21 22 #include "tensorflow/core/framework/attr_value_util.h" 23 #include "tensorflow/core/framework/device_attributes.pb.h" 24 #include "tensorflow/core/framework/graph.pb_text.h" 25 #include "tensorflow/core/framework/kernel_def.pb_text.h" 26 #include "tensorflow/core/framework/log_memory.h" 27 #include "tensorflow/core/framework/memory_types.h" 28 #include "tensorflow/core/framework/node_def.pb.h" 29 #include "tensorflow/core/framework/node_def_util.h" 30 #include "tensorflow/core/framework/op_def_util.h" 31 #include "tensorflow/core/framework/types.h" 32 #include "tensorflow/core/graph/graph.h" 33 #include "tensorflow/core/lib/core/errors.h" 34 #include "tensorflow/core/lib/core/notification.h" 35 #include "tensorflow/core/lib/core/stringpiece.h" 36 #include "tensorflow/core/lib/gtl/map_util.h" 37 #include "tensorflow/core/lib/io/path.h" 38 #include "tensorflow/core/lib/strings/str_util.h" 39 #include "tensorflow/core/lib/strings/strcat.h" 40 #include "tensorflow/core/platform/logging.h" 41 #include "tensorflow/core/platform/mutex.h" 42 #include "tensorflow/core/platform/types.h" 43 44 namespace tensorflow { 45 46 namespace { 47 48 Status MatchSignatureHelper(const DataTypeSlice expected_inputs, 49 const DataTypeSlice expected_outputs, 50 const DataTypeSlice inputs, 51 const DataTypeSlice outputs) { 52 bool signature_mismatch = false; 53 54 if (inputs.size() != expected_inputs.size()) signature_mismatch = true; 55 for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) { 56 if (!TypesCompatible(expected_inputs[i], inputs[i])) { 57 signature_mismatch = true; 58 } 59 } 60 61 if (outputs.size() != expected_outputs.size()) signature_mismatch = true; 62 for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) { 63 if (!TypesCompatible(expected_outputs[i], outputs[i])) { 64 signature_mismatch = true; 65 } 66 } 67 68 if (signature_mismatch) { 69 return errors::InvalidArgument( 70 "Signature mismatch, have: ", DataTypeSliceString(inputs), "->", 71 DataTypeSliceString(outputs), 72 " expected: ", DataTypeSliceString(expected_inputs), "->", 73 DataTypeSliceString(expected_outputs)); 74 } 75 return Status::OK(); 76 } 77 78 } // namespace 79 80 // OpKernel ------------------------------------------------------------------ 81 82 // TODO(mrry): Convert to std::make_unique when available. 83 OpKernel::OpKernel(OpKernelConstruction* context) 84 : OpKernel(context, 85 std::unique_ptr<const NodeDef>(new NodeDef(context->def()))) {} 86 87 OpKernel::OpKernel(OpKernelConstruction* context, 88 std::unique_ptr<const NodeDef> node_def) 89 : def_(std::move(node_def)), 90 input_types_(context->input_types().begin(), 91 context->input_types().end()), 92 input_memory_types_(context->input_memory_types().begin(), 93 context->input_memory_types().end()), 94 output_types_(context->output_types().begin(), 95 context->output_types().end()), 96 output_memory_types_(context->output_memory_types().begin(), 97 context->output_memory_types().end()), 98 graph_def_version_(context->graph_def_version()), 99 is_internal_(StringPiece(type_string()).starts_with("_")), 100 input_name_map_(context->num_inputs()), 101 output_name_map_(context->num_outputs()) { 102 OP_REQUIRES_OK(context, 103 NameRangesForNode(*def_, *context->op_def_, &input_name_map_, 104 &output_name_map_)); 105 OP_REQUIRES_OK(context, CheckOpDeprecation(*context->op_def_, 106 context->graph_def_version())); 107 108 // Kernels executing on GPU/SYCL tie very few resources on the CPU where the 109 // scheduler runs: we consider them as inexpensive. 110 expensive_ = context->device_type() != DeviceType(DEVICE_GPU) && 111 context->device_type() != DeviceType(DEVICE_SYCL); 112 } 113 114 OpKernel::~OpKernel() {} 115 116 const string& OpKernel::name() const { return def_->name(); } 117 const string& OpKernel::type_string() const { return def_->op(); } 118 const string& OpKernel::requested_device() const { return def_->device(); } 119 const string& OpKernel::requested_input(int i) const { return def_->input(i); } 120 121 Status OpKernel::InputRange(StringPiece input_name, int* start, 122 int* stop) const { 123 const auto result = input_name_map_.find(input_name); 124 if (result == input_name_map_.end()) { 125 return errors::InvalidArgument("Unknown input name: ", input_name); 126 } else { 127 *start = result->second.first; 128 *stop = result->second.second; 129 return Status::OK(); 130 } 131 } 132 133 Status OpKernel::OutputRange(StringPiece output_name, int* start, 134 int* stop) const { 135 const auto result = output_name_map_.find(output_name); 136 if (result == output_name_map_.end()) { 137 return errors::InvalidArgument("Unknown output name: ", output_name); 138 } else { 139 *start = result->second.first; 140 *stop = result->second.second; 141 return Status::OK(); 142 } 143 } 144 145 Status OpKernel::MakeShape(const Tensor& shape, TensorShape* out) const { 146 if (!IsLegacyVector(shape.shape())) { 147 return errors::InvalidArgument( 148 "shape must be a vector of {int32,int64}, got shape ", 149 shape.shape().DebugString()); 150 } 151 if (shape.dtype() == DataType::DT_INT32) { 152 auto vec = shape.flat<int32>(); 153 return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out); 154 } else if (shape.dtype() == DataType::DT_INT64) { 155 auto vec = shape.flat<int64>(); 156 return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out); 157 } else { 158 return errors::InvalidArgument("shape must be a vector of {int32,int64}."); 159 } 160 } 161 162 void AsyncOpKernel::Compute(OpKernelContext* context) { 163 Notification n; 164 ComputeAsync(context, [&n]() { n.Notify(); }); 165 n.WaitForNotification(); 166 } 167 168 // PersistentTensor ---------------------------------------------------------- 169 170 Tensor* PersistentTensor::AccessTensor(OpKernelConstruction* context) { 171 // the caller has to have a valid context 172 CHECK(context); 173 return &tensor_; 174 } 175 176 Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) { 177 context->NotifyUseOfPersistentTensor(tensor_); 178 return &tensor_; 179 } 180 181 // OpKernelConstruction ------------------------------------------------------ 182 183 OpKernelConstruction::OpKernelConstruction( 184 DeviceType device_type, DeviceBase* device, Allocator* allocator, 185 const NodeDef* node_def, const OpDef* op_def, FunctionLibraryRuntime* flib, 186 const DataTypeSlice& input_types, const MemoryTypeSlice& input_memory_types, 187 const DataTypeSlice& output_types, 188 const MemoryTypeSlice& output_memory_types, int graph_def_version, 189 Status* status) 190 : device_type_(std::move(device_type)), 191 device_(device), 192 allocator_(allocator), 193 def_(node_def), 194 op_def_(op_def), 195 flib_(flib), 196 input_types_(input_types), 197 input_memory_types_(input_memory_types), 198 output_types_(output_types), 199 output_memory_types_(output_memory_types), 200 graph_def_version_(graph_def_version), 201 status_(status) {} 202 203 bool OpKernelConstruction::HasAttr(StringPiece attr_name) const { 204 return HasNodeAttr(def(), attr_name); 205 } 206 207 void OpKernelConstruction::SetStatus(const Status& status) { 208 status_->Update(status); 209 } 210 211 Status OpKernelConstruction::MatchSignature( 212 const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) { 213 return MatchSignatureHelper(expected_inputs, expected_outputs, input_types_, 214 output_types_); 215 } 216 217 Status OpKernelConstruction::allocate_temp(DataType type, 218 const TensorShape& shape, 219 Tensor* out_temp) { 220 AllocationAttributes attr; 221 attr.allocation_will_be_logged = true; 222 Tensor new_temp(allocator_, type, shape, attr); 223 224 if (!new_temp.IsInitialized()) { 225 return errors::ResourceExhausted( 226 "OOM when allocating temporary tensor with shape", shape.DebugString()); 227 } 228 if (LogMemory::IsEnabled()) { 229 LogMemory::RecordTensorAllocation( 230 def_->name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp); 231 } 232 *out_temp = new_temp; 233 return Status::OK(); 234 } 235 236 Status OpKernelConstruction::allocate_persistent( 237 DataType type, const TensorShape& shape, PersistentTensor* out_persistent, 238 Tensor** out_tensor) { 239 // for now just do the same thing as allocate_temp 240 // TODO(misard) add specific memory tracking for persistent tensors 241 Tensor persistent; 242 Status s = allocate_temp(type, shape, &persistent); 243 if (!s.ok()) { 244 return s; 245 } 246 *out_persistent = PersistentTensor(persistent); 247 Tensor* allocated = out_persistent->AccessTensor(this); 248 if (out_tensor) { 249 *out_tensor = allocated; 250 } 251 return s; 252 } 253 254 // OpKernelContext ----------------------------------------------------------- 255 256 OpKernelContext::OpKernelContext(Params* params) 257 : OpKernelContext( 258 params, static_cast<int>(params->op_kernel->output_types().size())) {} 259 260 OpKernelContext::OpKernelContext(Params* params, int num_outputs) 261 : params_(params), 262 outputs_(num_outputs), 263 temp_memory_allocated_(0), 264 persistent_memory_allocated_(0) { 265 Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes()); 266 params_->ensure_eigen_gpu_device(); 267 params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device, 268 params_->op_device_context, 269 eigen_gpu_allocator); 270 if (params_->record_tensor_accesses) { 271 referenced_tensors_.Init(); 272 } 273 } 274 275 OpKernelContext::~OpKernelContext() { 276 for (TensorValue& value : outputs_) { 277 if (!value.is_ref()) { 278 delete value.tensor; 279 } 280 } 281 if (params_->record_tensor_accesses) referenced_tensors_.Destroy(); 282 } 283 284 Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) { 285 Allocator* allocator = 286 params_->device->GetStepAllocator(attr, resource_manager()); 287 if (track_allocations()) { 288 mutex_lock lock(mu_); 289 for (const auto& wrapped : wrapped_allocators_) { 290 if (wrapped.first == allocator) { 291 return wrapped.second; 292 } 293 } 294 TrackingAllocator* wrapped_allocator = 295 new TrackingAllocator(allocator, params_->track_allocations); 296 wrapped_allocators_.push_back(std::make_pair(allocator, wrapped_allocator)); 297 return wrapped_allocator; 298 } else { 299 return allocator; 300 } 301 } 302 303 void OpKernelContext::SetStatus(const Status& status) { 304 status_.Update(status); 305 } 306 307 void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) { 308 mutex_lock l(mu_); 309 // Keep a reference to the underlying memory around. 310 referenced_tensors_->Add(tensor); 311 } 312 313 Status OpKernelContext::input(StringPiece name, const Tensor** tensor) { 314 int start, stop; 315 TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); 316 if (stop != start + 1) { 317 return errors::InvalidArgument("OpKernel used list-valued input name '", 318 name, 319 "' when single-valued input was " 320 "expected"); 321 } 322 if (input_is_ref(start)) { 323 return errors::InvalidArgument("OpKernel used ref input name '", name, 324 "' when non-ref input was expected"); 325 } 326 *tensor = (*params_->inputs)[start].tensor; 327 record_tensor_reference(**tensor); 328 return Status::OK(); 329 } 330 331 Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const { 332 int start, stop; 333 TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); 334 if (stop != start + 1) { 335 return errors::InvalidArgument("OpKernel used list-valued input name '", 336 name, 337 "' when single-valued input was " 338 "expected"); 339 } 340 const TensorValue& value((*params_->inputs)[start]); 341 if (value.is_ref()) { 342 *dtype = MakeRefType(value->dtype()); 343 } else { 344 *dtype = value->dtype(); 345 } 346 return Status::OK(); 347 } 348 349 Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) { 350 int start, stop; 351 TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); 352 if (stop != start + 1) { 353 return errors::InvalidArgument("OpKernel used list-valued input name '", 354 name, 355 "' when single-valued input was expected"); 356 } 357 *out_mutex = input_ref_mutex(start); 358 return Status::OK(); 359 } 360 361 const Tensor& OpKernelContext::input(int index) { 362 DCHECK_GE(index, 0); 363 DCHECK_LT(index, num_inputs()); 364 DCHECK(!input_is_ref(index)); 365 const Tensor& tensor = *((*params_->inputs)[index].tensor); 366 record_tensor_reference(tensor); 367 return tensor; 368 } 369 370 Tensor OpKernelContext::mutable_input(int index, bool lock_held) { 371 DCHECK_GE(index, 0); 372 DCHECK_LT(index, num_inputs()); 373 DCHECK(input_is_ref(index)); 374 // return a copy of the Ref acquired while holding the mutex 375 if (lock_held) { 376 Tensor& tensor = *((*params_->inputs)[index].tensor); 377 record_tensor_reference(tensor); 378 return tensor; 379 } else { 380 mutex_lock l(*input_ref_mutex(index)); 381 Tensor& tensor = *((*params_->inputs)[index].tensor); 382 record_tensor_reference(tensor); 383 return tensor; 384 } 385 } 386 387 void OpKernelContext::replace_ref_input(int index, const Tensor& tensor, 388 bool lock_held) { 389 DCHECK_GE(index, 0); 390 DCHECK_LT(index, num_inputs()); 391 DCHECK(input_is_ref(index)); 392 // should only modify the tensor while holding the mutex 393 if (lock_held) { 394 *(*params_->inputs)[index].tensor = tensor; 395 } else { 396 mutex_lock l(*input_ref_mutex(index)); 397 *(*params_->inputs)[index].tensor = tensor; 398 } 399 record_tensor_reference(tensor); 400 } 401 402 void OpKernelContext::forward_ref_input_to_ref_output(int input_index, 403 int output_index) { 404 DCHECK_GE(input_index, 0); 405 DCHECK_LT(input_index, num_inputs()); 406 DCHECK(input_is_ref(input_index)); 407 set_output_ref(output_index, (*params_->inputs)[input_index].mutex_if_ref, 408 (*params_->inputs)[input_index].tensor); 409 } 410 411 bool OpKernelContext::forward_input_to_output_with_shape( 412 int input_index, int output_index, const TensorShape& output_shape, 413 Tensor** output) { 414 const auto output_attr = params_->output_attr_array == nullptr 415 ? AllocatorAttributes() 416 : output_alloc_attr(output_index); 417 std::unique_ptr<Tensor> new_tensor = forward_input( 418 input_index, expected_output_dtype(output_index), output_shape, 419 output_memory_type(output_index), output_attr); 420 if (new_tensor != nullptr) { 421 // Transfer ownership to the output slot in OpKernelContext. 422 outputs_[output_index] = TensorValue(new_tensor.release()); 423 *output = outputs_[output_index].tensor; 424 return true; 425 } else { 426 return false; 427 } 428 } 429 430 Status OpKernelContext::forward_input_to_output_with_shape( 431 StringPiece input_name, StringPiece output_name, 432 const TensorShape& output_shape, Tensor** output) { 433 int input_index, output_index, stop; 434 TF_RETURN_IF_ERROR( 435 params_->op_kernel->InputRange(input_name, &input_index, &stop)); 436 if (stop != input_index + 1) { 437 return errors::InvalidArgument("OpKernel used list-valued input name '", 438 input_name, 439 "' when single-valued input was " 440 "expected"); 441 } 442 TF_RETURN_IF_ERROR( 443 params_->op_kernel->OutputRange(output_name, &output_index, &stop)); 444 if (stop != output_index + 1) { 445 return errors::InvalidArgument("OpKernel used list-valued output name '", 446 output_name, 447 "' when single-valued output was " 448 "expected"); 449 } 450 if (!forward_input_to_output_with_shape(input_index, output_index, 451 output_shape, output)) { 452 return errors::FailedPrecondition("OpKernel could not forward input '", 453 input_name, "' to output '", output_name); 454 } 455 return Status::OK(); 456 } 457 458 std::unique_ptr<Tensor> OpKernelContext::forward_input( 459 int input_index, DataType output_dtype, const TensorShape& output_shape, 460 MemoryType output_memory_type, const AllocatorAttributes& output_attr) { 461 DCHECK_GE(input_index, 0); 462 DCHECK_LT(input_index, num_inputs()); 463 const TensorValue& input = (*params_->inputs)[input_index]; 464 // Check that input tensor exists, is not a ref, and has no other consumers. 465 if (input.tensor == nullptr || input.is_ref() || !input->RefCountIsOne()) { 466 return nullptr; 467 } 468 // Check that input type matches. 469 if (input_dtype(input_index) != output_dtype) { 470 return nullptr; 471 } 472 // Check that the input and output sizes are compatible. 473 if (input.tensor->shape().num_elements() != output_shape.num_elements()) { 474 return nullptr; 475 } 476 // Check that input and output memory types match, i.e. 477 // that they either both live in host or both live in device memory. 478 if (input_memory_type(input_index) != output_memory_type) { 479 return nullptr; 480 } 481 // Check that output allocator attributes are not more restrictive than 482 // input allocator attributes. 483 const auto input_attr = params_->input_alloc_attrs == nullptr 484 ? AllocatorAttributes() 485 : input_alloc_attr(input_index); 486 if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) { 487 return nullptr; 488 } 489 // TODO(rmlarsen): Use MakeUnique here. There is already a copy in 490 // tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of 491 // general cleanup of ownership in this code. 492 std::unique_ptr<Tensor> output_tensor(new Tensor()); 493 CHECK(output_tensor->CopyFrom(*input.tensor, output_shape)); 494 return output_tensor; 495 } 496 497 Status OpKernelContext::forward_input_or_allocate_temp( 498 gtl::ArraySlice<int> candidate_input_indices, DataType type, 499 const TensorShape& shape, const AllocatorAttributes& allocator_attr, 500 Tensor* out_temp) { 501 for (int input_index : candidate_input_indices) { 502 std::unique_ptr<Tensor> new_tensor = 503 forward_input(input_index, type, shape, DEVICE_MEMORY, allocator_attr); 504 if (new_tensor != nullptr) { 505 *out_temp = std::move(*new_tensor); 506 return Status::OK(); 507 } 508 } 509 return allocate_temp(type, shape, out_temp, allocator_attr); 510 } 511 512 void OpKernelContext::delete_ref_input(int index, bool lock_held) { 513 DCHECK_GE(index, 0); 514 DCHECK_LT(index, num_inputs()); 515 DCHECK(input_is_ref(index)); 516 // should only modify the tensor while holding the mutex 517 if (lock_held) { 518 delete (*params_->inputs)[index].tensor; 519 } else { 520 mutex_lock l(*input_ref_mutex(index)); 521 delete (*params_->inputs)[index].tensor; 522 } 523 } 524 525 Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor, 526 bool lock_held) { 527 int start, stop; 528 TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); 529 if (stop != start + 1) { 530 return errors::InvalidArgument("OpKernel used list-valued input name '", 531 name, 532 "' when single-valued input was expected"); 533 } 534 if (!input_is_ref(start)) { 535 return errors::InvalidArgument("OpKernel used non-ref input name '", name, 536 "' when ref input was expected"); 537 } 538 // return a copy of the Ref acquired while holding the mutex 539 if (lock_held) { 540 *tensor = *(*params_->inputs)[start].tensor; 541 } else { 542 mutex_lock l(*input_ref_mutex(start)); 543 *tensor = *(*params_->inputs)[start].tensor; 544 } 545 record_tensor_reference(*tensor); 546 return Status::OK(); 547 } 548 549 Status OpKernelContext::replace_ref_input(StringPiece name, 550 const Tensor& tensor, 551 bool lock_held) { 552 int start, stop; 553 TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); 554 if (stop != start + 1) { 555 return errors::InvalidArgument("OpKernel used list-valued input name '", 556 name, 557 "' when single-valued input was expected"); 558 } 559 if (!input_is_ref(start)) { 560 return errors::InvalidArgument("OpKernel used immutable input name '", name, 561 "' when ref input was expected"); 562 } 563 replace_ref_input(start, tensor, lock_held); 564 return Status::OK(); 565 } 566 567 Status OpKernelContext::input_list(StringPiece name, OpInputList* list) { 568 int start, stop; 569 TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); 570 *list = OpInputList(this, start, stop); 571 return Status::OK(); 572 } 573 574 Status OpKernelContext::mutable_input_list(StringPiece name, 575 OpMutableInputList* list) { 576 int start, stop; 577 TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); 578 *list = OpMutableInputList(this, start, stop); 579 return Status::OK(); 580 } 581 582 Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) { 583 int start, stop; 584 TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); 585 *list = OpOutputList(this, start, stop); 586 return Status::OK(); 587 } 588 589 Status OpKernelContext::allocate_output(int index, const TensorShape& shape, 590 Tensor** output) { 591 DCHECK_GE(index, 0); 592 DCHECK_LT(index, num_outputs()); 593 AllocatorAttributes attr = output_alloc_attr(index); 594 return allocate_output(index, shape, output, attr); 595 } 596 597 Status OpKernelContext::allocate_output(StringPiece name, 598 const TensorShape& shape, 599 Tensor** tensor) { 600 int start, stop; 601 TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); 602 if (stop != start + 1) { 603 return errors::InvalidArgument("OpKernel used list-valued output name '", 604 name, 605 "' when single-valued output was " 606 "expected"); 607 } 608 return allocate_output(start, shape, tensor); 609 } 610 611 Status OpKernelContext::allocate_output(StringPiece name, 612 const TensorShape& shape, 613 Tensor** tensor, 614 AllocatorAttributes attr) { 615 int start, stop; 616 TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); 617 if (stop != start + 1) { 618 return errors::InvalidArgument("OpKernel used list-valued output name '", 619 name, 620 "' when single-valued output was " 621 "expected"); 622 } 623 return allocate_output(start, shape, tensor, attr); 624 } 625 626 Status OpKernelContext::allocate_tensor( 627 DataType type, const TensorShape& shape, Tensor* out_tensor, 628 AllocatorAttributes attr, const AllocationAttributes& allocation_attr) { 629 Allocator* a = get_allocator(attr); 630 AllocationAttributes logged_attr(allocation_attr); 631 logged_attr.allocation_will_be_logged = true; 632 Tensor new_tensor(a, type, shape, logged_attr); 633 634 if (!new_tensor.IsInitialized()) { 635 return errors::ResourceExhausted( 636 "OOM when allocating tensor with shape", shape.DebugString(), 637 " and type ", DataTypeString(type), " on ", params_->device->name(), 638 " by allocator ", a->Name()); 639 } 640 if (params_->log_memory) { 641 LogMemory::RecordTensorAllocation(params_->op_kernel->name(), 642 params_->step_id, new_tensor); 643 } 644 record_tensor_reference(new_tensor); 645 *out_tensor = std::move(new_tensor); 646 return Status::OK(); 647 } 648 649 Status OpKernelContext::allocate_output(int index, const TensorShape& shape, 650 Tensor** output, 651 AllocatorAttributes attr) { 652 DCHECK_GE(index, 0); 653 DCHECK_LT(index, outputs_.size()); 654 const DataType type = params_->op_kernel->output_type(index); 655 DCHECK(!IsRefType(type)); 656 DCHECK(mutable_output(index) == nullptr); 657 Tensor* output_tensor = new Tensor(); 658 Status s = allocate_tensor(type, shape, output_tensor, attr); 659 if (s.ok()) { 660 outputs_[index] = TensorValue(output_tensor); 661 *output = outputs_[index].tensor; 662 } 663 return s; 664 } 665 666 Status OpKernelContext::allocate_temp( 667 DataType type, const TensorShape& shape, Tensor* out_temp, 668 AllocatorAttributes allocator_attr, 669 const AllocationAttributes& allocation_attr) { 670 Status s = 671 allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr); 672 if (track_allocations() && s.ok() && out_temp->TotalBytes() > 0) { 673 Allocator* a = get_allocator(allocator_attr); 674 if (a->TracksAllocationSizes()) { 675 int64 alloc_size = a->AllocatedSize(out_temp->tensor_data().data()); 676 record_temp_memory_allocation(alloc_size, *out_temp); 677 } 678 } 679 return s; 680 } 681 682 Status OpKernelContext::allocate_persistent(DataType type, 683 const TensorShape& shape, 684 PersistentTensor* out_persistent, 685 Tensor** out_tensor, 686 AllocatorAttributes attr) { 687 Tensor persistent; 688 Status s = allocate_tensor(type, shape, &persistent, attr); 689 if (s.ok()) { 690 *out_persistent = PersistentTensor(persistent); 691 if (out_tensor) { 692 *out_tensor = out_persistent->AccessTensor(this); 693 } 694 if (track_allocations()) { 695 Tensor* t = out_persistent->AccessTensor(this); 696 Allocator* a = get_allocator(attr); 697 if (a->TracksAllocationSizes()) { 698 int64 alloc_size = a->AllocatedSize(t->tensor_data().data()); 699 int64 alloc_id = a->AllocationId(t->tensor_data().data()); 700 record_persistent_memory_allocation(alloc_size, alloc_id); 701 } 702 } 703 } 704 return s; 705 } 706 707 Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) { 708 int start, stop; 709 TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); 710 if (stop != start + 1) { 711 return errors::InvalidArgument("OpKernel used list-valued output name '", 712 name, 713 "' when single-valued output was " 714 "expected"); 715 } 716 set_output(start, tensor); 717 return Status::OK(); 718 } 719 720 void OpKernelContext::set_output(int index, const Tensor& tensor) { 721 DCHECK_GE(index, 0); 722 DCHECK_LT(index, outputs_.size()); 723 DCHECK(!IsRefType(params_->op_kernel->output_type(index))); 724 DCHECK_EQ(mutable_output(index), nullptr); 725 record_tensor_reference(tensor); 726 outputs_[index] = TensorValue(new Tensor(tensor)); 727 if (track_allocations() && tensor.TotalBytes() > 0) { 728 mutex_lock l(stats_mu_); 729 if (!temp_tensor_buffer_and_size_) { 730 return; 731 } 732 auto it = std::find_if(temp_tensor_buffer_and_size_->begin(), 733 temp_tensor_buffer_and_size_->end(), 734 [&tensor](const std::pair<const void*, int64>& e) { 735 return e.first == static_cast<const void*>( 736 tensor.tensor_data().data()); 737 }); 738 if (it != temp_tensor_buffer_and_size_->end()) { 739 temp_memory_allocated_ -= it->second; 740 temp_tensor_buffer_and_size_->erase(it); 741 } 742 } 743 } 744 745 void OpKernelContext::set_output_ref(int index, mutex* mu, 746 Tensor* tensor_for_ref) { 747 DCHECK_GE(index, 0); 748 DCHECK_LT(index, outputs_.size()); 749 DCHECK(IsRefType(params_->op_kernel->output_type(index))); 750 record_tensor_reference(*tensor_for_ref); 751 outputs_[index] = TensorValue(mu, tensor_for_ref); 752 } 753 754 Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu, 755 Tensor* tensor_for_ref) { 756 int start, stop; 757 TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); 758 if (stop != start + 1) { 759 return errors::InvalidArgument("OpKernel used list-valued output name '", 760 name, 761 "' when single-valued output was " 762 "expected"); 763 } 764 set_output_ref(start, mu, tensor_for_ref); 765 return Status::OK(); 766 } 767 768 Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) { 769 int start, stop; 770 TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); 771 if (stop != start + 1) { 772 return errors::InvalidArgument("OpKernel used list-valued output name '", 773 name, 774 "' when single-valued output was " 775 "expected"); 776 } 777 *tensor = mutable_output(start); 778 return Status::OK(); 779 } 780 781 Status OpKernelContext::release_output(StringPiece name, TensorValue* value) { 782 int start, stop; 783 TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); 784 if (stop != start + 1) { 785 return errors::InvalidArgument("OpKernel used list-valued output name '", 786 name, 787 "' when single-valued output was " 788 "expected"); 789 } 790 *value = release_output(start); 791 return Status::OK(); 792 } 793 794 bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { 795 const auto& inputs = *params_->inputs; 796 for (size_t i = 1; i < inputs.size(); ++i) { 797 if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) { 798 SetStatus(errors::InvalidArgument( 799 "Inputs to operation ", op->name(), " of type ", op->type_string(), 800 " must have the same size and shape. Input 0: ", 801 inputs[0]->shape().DebugString(), " != input ", i, ": ", 802 inputs[i]->shape().DebugString())); 803 return false; 804 } 805 } 806 return true; 807 } 808 809 Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs, 810 const DataTypeSlice expected_outputs) { 811 DataTypeVector inputs; 812 for (const TensorValue& t : *params_->inputs) { 813 inputs.push_back(t.is_ref() ? MakeRefType(t->dtype()) : t->dtype()); 814 } 815 DataTypeVector outputs = params_->op_kernel->output_types(); 816 return MatchSignatureHelper(expected_inputs, expected_outputs, inputs, 817 outputs); 818 } 819 820 void OpKernelContext::record_temp_memory_allocation(int64 size, 821 const Tensor& t) { 822 mutex_lock l(stats_mu_); 823 temp_memory_allocated_ += size; 824 if (!temp_tensor_buffer_and_size_) { 825 temp_tensor_buffer_and_size_.reset( 826 new gtl::InlinedVector<std::pair<const void*, int64>, 2>()); 827 } 828 temp_tensor_buffer_and_size_->emplace_back( 829 static_cast<const void*>(t.tensor_data().data()), size); 830 } 831 832 int64 OpKernelContext::temp_memory_allocated() const { 833 mutex_lock l(stats_mu_); 834 return temp_memory_allocated_; 835 } 836 837 void OpKernelContext::record_persistent_memory_allocation(int64 size, 838 int64 alloc_id) { 839 mutex_lock l(stats_mu_); 840 persistent_memory_allocated_ += size; 841 if (alloc_id >= 0) { 842 if (!persistent_alloc_ids_) { 843 persistent_alloc_ids_.reset(new gtl::InlinedVector<int64, 2>()); 844 } 845 persistent_alloc_ids_->push_back(alloc_id); 846 } 847 } 848 849 int64 OpKernelContext::persistent_memory_allocated() const { 850 mutex_lock l(stats_mu_); 851 return persistent_memory_allocated_; 852 } 853 854 std::vector<int64> OpKernelContext::persistent_alloc_ids() const { 855 mutex_lock l(stats_mu_); 856 if (persistent_alloc_ids_) { 857 return std::vector<int64>(persistent_alloc_ids_->begin(), 858 persistent_alloc_ids_->end()); 859 } else { 860 return std::vector<int64>(); 861 } 862 } 863 864 void OpKernelContext::clear_recorded_memory() { 865 mutex_lock l(stats_mu_); 866 temp_memory_allocated_ = 0; 867 persistent_memory_allocated_ = 0; 868 if (temp_tensor_buffer_and_size_) { 869 temp_tensor_buffer_and_size_->clear(); 870 } 871 if (persistent_alloc_ids_) { 872 persistent_alloc_ids_->clear(); 873 } 874 } 875 876 // OpKernel registration ------------------------------------------------------ 877 878 struct KernelRegistration { 879 KernelRegistration(const KernelDef& d, StringPiece c, 880 kernel_factory::OpKernelRegistrar::Factory f) 881 : def(d), kernel_class_name(c.ToString()), factory(f) {} 882 const KernelDef def; 883 const string kernel_class_name; 884 const kernel_factory::OpKernelRegistrar::Factory factory; 885 }; 886 887 // This maps from 'op_type' + DeviceType to the set of KernelDefs and 888 // factory functions for instantiating the OpKernel that matches the 889 // KernelDef. 890 typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry; 891 892 void* GlobalKernelRegistry() { 893 static KernelRegistry* global_kernel_registry = new KernelRegistry; 894 return global_kernel_registry; 895 } 896 897 static KernelRegistry* GlobalKernelRegistryTyped() { 898 return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry()); 899 } 900 901 static string Key(StringPiece op_type, const DeviceType& device_type, 902 StringPiece label) { 903 return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":", 904 label); 905 } 906 907 namespace kernel_factory { 908 909 void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def, 910 StringPiece kernel_class_name, 911 Factory factory) { 912 // See comments in register_kernel::Name in header for info on _no_register. 913 if (kernel_def->op() != "_no_register") { 914 const string key = 915 Key(kernel_def->op(), DeviceType(kernel_def->device_type()), 916 kernel_def->label()); 917 GlobalKernelRegistryTyped()->insert(std::make_pair( 918 key, KernelRegistration(*kernel_def, kernel_class_name, factory))); 919 } 920 delete kernel_def; 921 } 922 923 } // namespace kernel_factory 924 925 namespace { 926 927 // Helper for AttrsMatch(). 928 bool InTypeList(DataType dt, const AttrValue& type_list) { 929 for (int in_list : type_list.list().type()) { 930 if (dt == in_list) return true; 931 } 932 return false; 933 } 934 935 // Returns whether the attrs satisfy the constraints in the kernel_def. Returns 936 // an error if attrs in kernel_def are not found, or have a mismatching type. 937 Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) { 938 *match = false; 939 for (const auto& constraint : kernel_def.constraint()) { 940 if (constraint.allowed_values().list().type_size() == 0) { 941 return errors::Unimplemented( 942 "KernelDef '", ProtoShortDebugString(kernel_def), 943 " has constraint on attr '", constraint.name(), 944 "' with unsupported type: ", 945 SummarizeAttrValue(constraint.allowed_values())); 946 } 947 948 const AttrValue* found = attrs.Find(constraint.name()); 949 if (found) { 950 if (found->type() != DT_INVALID) { 951 if (!InTypeList(found->type(), constraint.allowed_values())) { 952 return Status::OK(); 953 } 954 } else { 955 if (!AttrValueHasType(*found, "list(type)").ok()) { 956 return errors::InvalidArgument( 957 "KernelDef '", ProtoShortDebugString(kernel_def), 958 "' has constraint on attr '", constraint.name(), 959 "' that has value '", SummarizeAttrValue(*found), 960 "' that does not have type 'type' or 'list(type)' in NodeDef " 961 "'", 962 attrs.SummarizeNode(), "'"); 963 } 964 965 for (int t : found->list().type()) { 966 if (!InTypeList(static_cast<DataType>(t), 967 constraint.allowed_values())) { 968 return Status::OK(); 969 } 970 } 971 } 972 } else { 973 return errors::InvalidArgument( 974 "OpKernel '", kernel_def.op(), "' has constraint on attr '", 975 constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(), 976 "', KernelDef: '", ProtoShortDebugString(kernel_def), "'"); 977 } 978 } 979 *match = true; 980 return Status::OK(); 981 } 982 983 static const StringPiece kKernelAttr("_kernel"); 984 985 // TODO(irving): Replace with const Node& version below. 986 Status FindKernelRegistration(const DeviceType& device_type, 987 const NodeDef& node_def, 988 const KernelRegistration** reg, 989 bool* was_attr_mismatch) { 990 *reg = nullptr; 991 *was_attr_mismatch = false; 992 // Label defaults to empty if not found in NodeDef. 993 const string& label = GetNodeAttrString(node_def, kKernelAttr); 994 995 const string key = Key(node_def.op(), device_type, label); 996 auto regs = GlobalKernelRegistryTyped()->equal_range(key); 997 for (auto iter = regs.first; iter != regs.second; ++iter) { 998 // If there is a kernel registered for the op and device_type, 999 // check that the attrs match. 1000 bool match; 1001 TF_RETURN_IF_ERROR(AttrsMatch(node_def, iter->second.def, &match)); 1002 if (match) { 1003 if (*reg != nullptr) { 1004 return errors::InvalidArgument( 1005 "Multiple OpKernel registrations match NodeDef '", 1006 SummarizeNodeDef(node_def), "': '", 1007 ProtoShortDebugString((*reg)->def), "' and '", 1008 ProtoShortDebugString(iter->second.def), "'"); 1009 } 1010 *reg = &iter->second; 1011 } else { 1012 *was_attr_mismatch = true; 1013 } 1014 } 1015 return Status::OK(); 1016 } 1017 1018 } // namespace 1019 1020 // TODO(irving): Change const NodeDef& to const Node& 1021 Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def, 1022 const KernelDef** def, string* kernel_class_name) { 1023 const KernelRegistration* reg = nullptr; 1024 bool was_attr_mismatch; 1025 TF_RETURN_IF_ERROR( 1026 FindKernelRegistration(device_type, node_def, ®, &was_attr_mismatch)); 1027 if (reg == nullptr) { 1028 Status s = errors::NotFound( 1029 "No registered '", node_def.op(), "' OpKernel for ", 1030 DeviceTypeString(device_type), " devices compatible with node ", 1031 SummarizeNodeDef(node_def)); 1032 if (was_attr_mismatch) { 1033 errors::AppendToMessage( 1034 &s, " (OpKernel was found, but attributes didn't match)"); 1035 } 1036 errors::AppendToMessage( 1037 &s, ". Registered:", KernelsRegisteredForOp(node_def.op())); 1038 return s; 1039 } 1040 if (def != nullptr) *def = ®->def; 1041 if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name; 1042 return Status::OK(); 1043 } 1044 1045 Status SupportedDeviceTypesForNode( 1046 const std::vector<DeviceType>& prioritized_types, const NodeDef& def, 1047 DeviceTypeVector* device_types) { 1048 // TODO(zhifengc): Changes the callers (SimplePlacer and 1049 // DynamicPlacer) to consider the possibility that 'def' is call to 1050 // a user-defined function and only calls this 1051 // SupportedDeviceTypesForNode for primitive ops. 1052 const OpRegistrationData* op_reg_data; 1053 const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data); 1054 if (s.ok()) { 1055 for (const DeviceType& device_type : prioritized_types) { 1056 const KernelRegistration* reg = nullptr; 1057 bool was_attr_mismatch; 1058 TF_RETURN_IF_ERROR( 1059 FindKernelRegistration(device_type, def, ®, &was_attr_mismatch)); 1060 if (reg != nullptr) device_types->push_back(device_type); 1061 } 1062 } else { 1063 // Assumes that all device types support this node. 1064 for (const DeviceType& device_type : prioritized_types) { 1065 device_types->push_back(device_type); 1066 } 1067 } 1068 return Status::OK(); 1069 } 1070 1071 void LogAllRegisteredKernels() { 1072 for (const auto& key_registration : *GlobalKernelRegistryTyped()) { 1073 const KernelDef& kernel_def(key_registration.second.def); 1074 LOG(INFO) << "OpKernel ('" << ProtoShortDebugString(kernel_def) << "')"; 1075 } 1076 } 1077 1078 string KernelsRegisteredForOp(StringPiece op_name) { 1079 string ret; 1080 for (const auto& key_registration : *GlobalKernelRegistryTyped()) { 1081 const KernelDef& kernel_def(key_registration.second.def); 1082 if (kernel_def.op() == op_name) { 1083 strings::StrAppend(&ret, " device='", kernel_def.device_type(), "'"); 1084 if (!kernel_def.label().empty()) { 1085 strings::StrAppend(&ret, "; label='", kernel_def.label(), "'"); 1086 } 1087 for (int i = 0; i < kernel_def.constraint_size(); ++i) { 1088 strings::StrAppend( 1089 &ret, "; ", kernel_def.constraint(i).name(), " in ", 1090 SummarizeAttrValue(kernel_def.constraint(i).allowed_values())); 1091 } 1092 strings::StrAppend(&ret, "\n"); 1093 } 1094 } 1095 if (ret.empty()) return " <no registered kernels>\n"; 1096 return ret; 1097 } 1098 1099 std::unique_ptr<OpKernel> CreateOpKernel( 1100 DeviceType device_type, DeviceBase* device, Allocator* allocator, 1101 const NodeDef& node_def, int graph_def_version, Status* status) { 1102 OpKernel* kernel = nullptr; 1103 *status = CreateOpKernel(std::move(device_type), device, allocator, nullptr, 1104 node_def, graph_def_version, &kernel); 1105 return std::unique_ptr<OpKernel>(kernel); 1106 } 1107 1108 Status CreateOpKernel(DeviceType device_type, DeviceBase* device, 1109 Allocator* allocator, FunctionLibraryRuntime* flib, 1110 const NodeDef& node_def, int graph_def_version, 1111 OpKernel** kernel) { 1112 VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def); 1113 1114 // Look up the Op registered for this op name. 1115 const OpDef* op_def = nullptr; 1116 Status s = OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def); 1117 if (!s.ok()) return s; 1118 1119 // Validate node_def against OpDef. 1120 s = ValidateNodeDef(node_def, *op_def); 1121 if (!s.ok()) return s; 1122 1123 // Look up kernel registration. 1124 const KernelRegistration* registration; 1125 bool was_attr_mismatch; 1126 s = FindKernelRegistration(device_type, node_def, ®istration, 1127 &was_attr_mismatch); 1128 if (!s.ok()) { 1129 errors::AppendToMessage(&s, " when instantiating ", node_def.op()); 1130 return s; 1131 } 1132 if (registration == nullptr) { 1133 s.Update(errors::NotFound("No registered '", node_def.op(), 1134 "' OpKernel for ", DeviceTypeString(device_type), 1135 " devices compatible with node ", 1136 SummarizeNodeDef(node_def))); 1137 if (was_attr_mismatch) { 1138 errors::AppendToMessage( 1139 &s, " (OpKernel was found, but attributes didn't match)"); 1140 } 1141 errors::AppendToMessage( 1142 &s, ". Registered:", KernelsRegisteredForOp(node_def.op())); 1143 return s; 1144 } 1145 1146 // Get signature from the OpDef & NodeDef 1147 DataTypeVector inputs; 1148 DataTypeVector outputs; 1149 s.Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs)); 1150 if (!s.ok()) { 1151 errors::AppendToMessage(&s, " for node: ", SummarizeNodeDef(node_def)); 1152 return s; 1153 } 1154 1155 // We are creating a kernel for an op registered in 1156 // OpRegistry::Global(), we consult the kernel registry to decide 1157 // the kernel's input and output memory types. 1158 MemoryTypeVector input_memory_types; 1159 MemoryTypeVector output_memory_types; 1160 TF_RETURN_IF_ERROR(MemoryTypesForNode(OpRegistry::Global(), device_type, 1161 node_def, &input_memory_types, 1162 &output_memory_types)); 1163 1164 // Everything needed for OpKernel construction. 1165 OpKernelConstruction context( 1166 device_type, device, allocator, &node_def, op_def, flib, inputs, 1167 input_memory_types, outputs, output_memory_types, graph_def_version, &s); 1168 *kernel = (*registration->factory)(&context); 1169 if (!s.ok()) { 1170 delete *kernel; 1171 *kernel = nullptr; 1172 } 1173 return s; 1174 } 1175 1176 namespace { 1177 1178 bool FindArgInOp(StringPiece arg_name, 1179 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) { 1180 for (const auto& arg : args) { 1181 if (arg_name == arg.name()) { 1182 return true; 1183 } 1184 } 1185 return false; 1186 } 1187 1188 } // namespace 1189 1190 Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) { 1191 for (const auto& key_registration : *GlobalKernelRegistryTyped()) { 1192 const KernelDef& kernel_def(key_registration.second.def); 1193 const OpRegistrationData* op_reg_data; 1194 const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data); 1195 if (!status.ok()) { 1196 // TODO(josh11b): Make this a hard error. 1197 LOG(ERROR) << "OpKernel ('" << ProtoShortDebugString(kernel_def) 1198 << "') for unknown op: " << kernel_def.op(); 1199 continue; 1200 } 1201 const OpDef& op_def = op_reg_data->op_def; 1202 for (const auto& host_memory_arg : kernel_def.host_memory_arg()) { 1203 if (!FindArgInOp(host_memory_arg, op_def.input_arg()) && 1204 !FindArgInOp(host_memory_arg, op_def.output_arg())) { 1205 return errors::InvalidArgument( 1206 "HostMemory arg '", host_memory_arg, 1207 "' not found in OpDef: ", SummarizeOpDef(op_def)); 1208 } 1209 } 1210 } 1211 return Status::OK(); 1212 } 1213 1214 template <> 1215 const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const { 1216 return eigen_cpu_device(); 1217 } 1218 1219 template <> 1220 const Eigen::GpuDevice& OpKernelContext::eigen_device() const { 1221 return eigen_gpu_device(); 1222 } 1223 1224 #ifdef TENSORFLOW_USE_SYCL 1225 template <> 1226 const Eigen::SyclDevice& OpKernelContext::eigen_device() const { 1227 return eigen_sycl_device(); 1228 } 1229 #endif 1230 1231 void OpKernelConstruction::CtxFailure(const Status& s) { 1232 VLOG(1) << s; 1233 SetStatus(s); 1234 } 1235 1236 void OpKernelConstruction::CtxFailureWithWarning(const Status& s) { 1237 LOG(WARNING) << s; 1238 SetStatus(s); 1239 } 1240 1241 void OpKernelConstruction::CtxFailure(const char* file, int line, 1242 const Status& s) { 1243 VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line 1244 << " : " << s; 1245 SetStatus(s); 1246 } 1247 1248 void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line, 1249 const Status& s) { 1250 LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line 1251 << " : " << s; 1252 SetStatus(s); 1253 } 1254 1255 void OpKernelContext::CtxFailure(const Status& s) { 1256 VLOG(1) << s; 1257 SetStatus(s); 1258 } 1259 1260 void OpKernelContext::CtxFailureWithWarning(const Status& s) { 1261 LOG(WARNING) << s; 1262 SetStatus(s); 1263 } 1264 1265 void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) { 1266 VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line 1267 << " : " << s; 1268 SetStatus(s); 1269 } 1270 1271 void OpKernelContext::CtxFailureWithWarning(const char* file, int line, 1272 const Status& s) { 1273 LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line 1274 << " : " << s; 1275 SetStatus(s); 1276 } 1277 1278 } // namespace tensorflow 1279