/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"

#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/kernel_def.pb_text.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace {

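// Checks that the actual input/output dtypes of a kernel match the dtypes the
// caller expects. TypesCompatible() treats a reference dtype (e.g.
// DT_FLOAT_REF) as compatible with its base type, so a kernel declared with
// DT_FLOAT also accepts a DT_FLOAT_REF input.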
Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
                            const DataTypeSlice expected_outputs,
                            const DataTypeSlice inputs,
                            const DataTypeSlice outputs) {
  bool signature_mismatch = false;

  if (inputs.size() != expected_inputs.size()) signature_mismatch = true;
  for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) {
    if (!TypesCompatible(expected_inputs[i], inputs[i])) {
      signature_mismatch = true;
    }
  }

  if (outputs.size() != expected_outputs.size()) signature_mismatch = true;
  for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) {
    if (!TypesCompatible(expected_outputs[i], outputs[i])) {
      signature_mismatch = true;
    }
  }

  if (signature_mismatch) {
    return errors::InvalidArgument(
        "Signature mismatch, have: ", DataTypeSliceString(inputs), "->",
        DataTypeSliceString(outputs),
        " expected: ", DataTypeSliceString(expected_inputs), "->",
        DataTypeSliceString(expected_outputs));
  }
  return Status::OK();
}

}  // namespace

// OpKernel ------------------------------------------------------------------

// TODO(mrry): Convert to std::make_unique when available.
OpKernel::OpKernel(OpKernelConstruction* context)
    : OpKernel(context,
               std::unique_ptr<const NodeDef>(new NodeDef(context->def()))) {}

OpKernel::OpKernel(OpKernelConstruction* context,
                   std::unique_ptr<const NodeDef> node_def)
    : def_(std::move(node_def)),
      input_types_(context->input_types().begin(),
                   context->input_types().end()),
      input_memory_types_(context->input_memory_types().begin(),
                          context->input_memory_types().end()),
      output_types_(context->output_types().begin(),
                    context->output_types().end()),
      output_memory_types_(context->output_memory_types().begin(),
                           context->output_memory_types().end()),
      graph_def_version_(context->graph_def_version()),
      is_internal_(StringPiece(type_string()).starts_with("_")),
      input_name_map_(context->num_inputs()),
      output_name_map_(context->num_outputs()) {
  OP_REQUIRES_OK(context,
                 NameRangesForNode(*def_, *context->op_def_, &input_name_map_,
                                   &output_name_map_));
  OP_REQUIRES_OK(context, CheckOpDeprecation(*context->op_def_,
                                             context->graph_def_version()));

  // Kernels executing on GPU/SYCL tie up very few resources on the CPU where
  // the scheduler runs: we consider them inexpensive.
  expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
               context->device_type() != DeviceType(DEVICE_SYCL);
}

OpKernel::~OpKernel() {}

const string& OpKernel::name() const { return def_->name(); }
const string& OpKernel::type_string() const { return def_->op(); }
const string& OpKernel::requested_device() const { return def_->device(); }
const string& OpKernel::requested_input(int i) const { return def_->input(i); }

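// Maps a named input (e.g. one declared in the OpDef as `values: N * T`) to
// the half-open index range [*start, *stop) in the flat input vector. The
// name/range pairs were precomputed by NameRangesForNode() at construction.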
Status OpKernel::InputRange(StringPiece input_name, int* start,
                            int* stop) const {
  const auto result = input_name_map_.find(input_name);
  if (result == input_name_map_.end()) {
    return errors::InvalidArgument("Unknown input name: ", input_name);
  } else {
    *start = result->second.first;
    *stop = result->second.second;
    return Status::OK();
  }
}

Status OpKernel::OutputRange(StringPiece output_name, int* start,
                             int* stop) const {
  const auto result = output_name_map_.find(output_name);
  if (result == output_name_map_.end()) {
    return errors::InvalidArgument("Unknown output name: ", output_name);
  } else {
    *start = result->second.first;
    *stop = result->second.second;
    return Status::OK();
  }
}

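// Converts a rank-1 (or legacy scalar) int32/int64 tensor into a TensorShape.
// For example, a tensor holding {2, 3, 4} yields the shape [2, 3, 4].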
Status OpKernel::MakeShape(const Tensor& shape, TensorShape* out) const {
  if (!IsLegacyVector(shape.shape())) {
    return errors::InvalidArgument(
        "shape must be a vector of {int32,int64}, got shape ",
        shape.shape().DebugString());
  }
  if (shape.dtype() == DataType::DT_INT32) {
    auto vec = shape.flat<int32>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else if (shape.dtype() == DataType::DT_INT64) {
    auto vec = shape.flat<int64>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else {
    return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
  }
}

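// Adapts an asynchronous kernel to the synchronous Compute() interface by
// blocking the calling thread until the done-callback fires. Note this ties
// up the caller's thread for the duration of the async work.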
void AsyncOpKernel::Compute(OpKernelContext* context) {
  Notification n;
  ComputeAsync(context, [&n]() { n.Notify(); });
  n.WaitForNotification();
}

// PersistentTensor ----------------------------------------------------------

Tensor* PersistentTensor::AccessTensor(OpKernelConstruction* context) {
  // The caller must supply a valid context.
  CHECK(context);
  return &tensor_;
}

Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) {
  context->NotifyUseOfPersistentTensor(tensor_);
  return &tensor_;
}

// OpKernelConstruction ------------------------------------------------------

OpKernelConstruction::OpKernelConstruction(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const NodeDef* node_def, const OpDef* op_def, FunctionLibraryRuntime* flib,
    const DataTypeSlice& input_types, const MemoryTypeSlice& input_memory_types,
    const DataTypeSlice& output_types,
    const MemoryTypeSlice& output_memory_types, int graph_def_version,
    Status* status)
    : device_type_(std::move(device_type)),
      device_(device),
      allocator_(allocator),
      def_(node_def),
      op_def_(op_def),
      flib_(flib),
      input_types_(input_types),
      input_memory_types_(input_memory_types),
      output_types_(output_types),
      output_memory_types_(output_memory_types),
      graph_def_version_(graph_def_version),
      status_(status) {}

bool OpKernelConstruction::HasAttr(StringPiece attr_name) const {
  return HasNodeAttr(def(), attr_name);
}

void OpKernelConstruction::SetStatus(const Status& status) {
  status_->Update(status);
}

Status OpKernelConstruction::MatchSignature(
    const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) {
  return MatchSignatureHelper(expected_inputs, expected_outputs, input_types_,
                              output_types_);
}

Status OpKernelConstruction::allocate_temp(DataType type,
                                           const TensorShape& shape,
                                           Tensor* out_temp) {
  AllocationAttributes attr;
  attr.allocation_will_be_logged = true;
  Tensor new_temp(allocator_, type, shape, attr);

  if (!new_temp.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating temporary tensor with shape ",
        shape.DebugString());
  }
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation(
        def_->name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp);
  }
  *out_temp = new_temp;
  return Status::OK();
}

Status OpKernelConstruction::allocate_persistent(
    DataType type, const TensorShape& shape, PersistentTensor* out_persistent,
    Tensor** out_tensor) {
  // For now just do the same thing as allocate_temp.
  // TODO(misard) add specific memory tracking for persistent tensors
  Tensor persistent;
  Status s = allocate_temp(type, shape, &persistent);
  if (!s.ok()) {
    return s;
  }
  *out_persistent = PersistentTensor(persistent);
  Tensor* allocated = out_persistent->AccessTensor(this);
  if (out_tensor) {
    *out_tensor = allocated;
  }
  return s;
}

// OpKernelContext -----------------------------------------------------------

OpKernelContext::OpKernelContext(Params* params)
    : OpKernelContext(
          params, static_cast<int>(params->op_kernel->output_types().size())) {}

OpKernelContext::OpKernelContext(Params* params, int num_outputs)
    : params_(params),
      outputs_(num_outputs),
      temp_memory_allocated_(0),
      persistent_memory_allocated_(0) {
  Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
  params_->ensure_eigen_gpu_device();
  params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device,
                                         params_->op_device_context,
                                         eigen_gpu_allocator);
  if (params_->record_tensor_accesses) {
    referenced_tensors_.Init();
  }
}

OpKernelContext::~OpKernelContext() {
  for (TensorValue& value : outputs_) {
    if (!value.is_ref()) {
      delete value.tensor;
    }
  }
  if (params_->record_tensor_accesses) referenced_tensors_.Destroy();
}

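// Returns the step allocator for the given attributes. When allocation
// tracking is on, the raw allocator is wrapped in a TrackingAllocator (one per
// distinct underlying allocator, cached in wrapped_allocators_) so per-step
// memory statistics can be collected.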
Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
  Allocator* allocator =
      params_->device->GetStepAllocator(attr, resource_manager());
  if (track_allocations()) {
    mutex_lock lock(mu_);
    for (const auto& wrapped : wrapped_allocators_) {
      if (wrapped.first == allocator) {
        return wrapped.second;
      }
    }
    TrackingAllocator* wrapped_allocator =
        new TrackingAllocator(allocator, params_->track_allocations);
    wrapped_allocators_.push_back(std::make_pair(allocator, wrapped_allocator));
    return wrapped_allocator;
  } else {
    return allocator;
  }
}

void OpKernelContext::SetStatus(const Status& status) {
  status_.Update(status);
}

void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) {
  mutex_lock l(mu_);
  // Keep a reference to the underlying memory around.
  referenced_tensors_->Add(tensor);
}

Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was "
                                   "expected");
  }
  if (input_is_ref(start)) {
    return errors::InvalidArgument("OpKernel used ref input name '", name,
                                   "' when non-ref input was expected");
  }
  *tensor = (*params_->inputs)[start].tensor;
  record_tensor_reference(**tensor);
  return Status::OK();
}

Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was "
                                   "expected");
  }
  const TensorValue& value((*params_->inputs)[start]);
  if (value.is_ref()) {
    *dtype = MakeRefType(value->dtype());
  } else {
    *dtype = value->dtype();
  }
  return Status::OK();
}

Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was expected");
  }
  *out_mutex = input_ref_mutex(start);
  return Status::OK();
}

const Tensor& OpKernelContext::input(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(!input_is_ref(index));
  const Tensor& tensor = *((*params_->inputs)[index].tensor);
  record_tensor_reference(tensor);
  return tensor;
}

Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  // return a copy of the Ref acquired while holding the mutex
  if (lock_held) {
    Tensor& tensor = *((*params_->inputs)[index].tensor);
    record_tensor_reference(tensor);
    return tensor;
  } else {
    mutex_lock l(*input_ref_mutex(index));
    Tensor& tensor = *((*params_->inputs)[index].tensor);
    record_tensor_reference(tensor);
    return tensor;
  }
}

void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
                                        bool lock_held) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  // should only modify the tensor while holding the mutex
  if (lock_held) {
    *(*params_->inputs)[index].tensor = tensor;
  } else {
    mutex_lock l(*input_ref_mutex(index));
    *(*params_->inputs)[index].tensor = tensor;
  }
  record_tensor_reference(tensor);
}

void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
                                                      int output_index) {
  DCHECK_GE(input_index, 0);
  DCHECK_LT(input_index, num_inputs());
  DCHECK(input_is_ref(input_index));
  set_output_ref(output_index, (*params_->inputs)[input_index].mutex_if_ref,
                 (*params_->inputs)[input_index].tensor);
}

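// Attempts to reuse the input tensor's buffer as the output (a zero-copy
// "forwarding"). Returns true and populates *output on success; returns false
// if the buffer cannot be forwarded, in which case the caller must allocate.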
bool OpKernelContext::forward_input_to_output_with_shape(
    int input_index, int output_index, const TensorShape& output_shape,
    Tensor** output) {
  const auto output_attr = params_->output_attr_array == nullptr
                               ? AllocatorAttributes()
                               : output_alloc_attr(output_index);
  std::unique_ptr<Tensor> new_tensor = forward_input(
      input_index, expected_output_dtype(output_index), output_shape,
      output_memory_type(output_index), output_attr);
  if (new_tensor != nullptr) {
    // Transfer ownership to the output slot in OpKernelContext.
    outputs_[output_index] = TensorValue(new_tensor.release());
    *output = outputs_[output_index].tensor;
    return true;
  } else {
    return false;
  }
}

Status OpKernelContext::forward_input_to_output_with_shape(
    StringPiece input_name, StringPiece output_name,
    const TensorShape& output_shape, Tensor** output) {
  int input_index, output_index, stop;
  TF_RETURN_IF_ERROR(
      params_->op_kernel->InputRange(input_name, &input_index, &stop));
  if (stop != input_index + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   input_name,
                                   "' when single-valued input was "
                                   "expected");
  }
  TF_RETURN_IF_ERROR(
      params_->op_kernel->OutputRange(output_name, &output_index, &stop));
  if (stop != output_index + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   output_name,
                                   "' when single-valued output was "
                                   "expected");
  }
  if (!forward_input_to_output_with_shape(input_index, output_index,
                                          output_shape, output)) {
    return errors::FailedPrecondition("OpKernel could not forward input '",
                                      input_name, "' to output '", output_name,
                                      "'");
  }
  return Status::OK();
}

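// Returns the input tensor wrapped for reuse as an output buffer, or nullptr
// if forwarding is unsafe. Forwarding requires that the input is not a ref,
// that this context holds the only reference to the buffer, and that dtype,
// element count, memory type, and allocator attributes are all compatible.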
std::unique_ptr<Tensor> OpKernelContext::forward_input(
    int input_index, DataType output_dtype, const TensorShape& output_shape,
    MemoryType output_memory_type, const AllocatorAttributes& output_attr) {
  DCHECK_GE(input_index, 0);
  DCHECK_LT(input_index, num_inputs());
  const TensorValue& input = (*params_->inputs)[input_index];
  // Check that input tensor exists, is not a ref, and has no other consumers.
  if (input.tensor == nullptr || input.is_ref() || !input->RefCountIsOne()) {
    return nullptr;
  }
  // Check that input type matches.
  if (input_dtype(input_index) != output_dtype) {
    return nullptr;
  }
  // Check that the input and output sizes are compatible.
  if (input.tensor->shape().num_elements() != output_shape.num_elements()) {
    return nullptr;
  }
  // Check that input and output memory types match, i.e. that they either
  // both live in host memory or both live in device memory.
  if (input_memory_type(input_index) != output_memory_type) {
    return nullptr;
  }
  // Check that output allocator attributes are not more restrictive than
  // input allocator attributes.
  const auto input_attr = params_->input_alloc_attrs == nullptr
                              ? AllocatorAttributes()
                              : input_alloc_attr(input_index);
  if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) {
    return nullptr;
  }
  // TODO(rmlarsen): Use MakeUnique here. There is already a copy in
  // tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of
  // general cleanup of ownership in this code.
  std::unique_ptr<Tensor> output_tensor(new Tensor());
  CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
  return output_tensor;
}

Status OpKernelContext::forward_input_or_allocate_temp(
    gtl::ArraySlice<int> candidate_input_indices, DataType type,
    const TensorShape& shape, const AllocatorAttributes& allocator_attr,
    Tensor* out_temp) {
  for (int input_index : candidate_input_indices) {
    std::unique_ptr<Tensor> new_tensor =
        forward_input(input_index, type, shape, DEVICE_MEMORY, allocator_attr);
    if (new_tensor != nullptr) {
      *out_temp = std::move(*new_tensor);
      return Status::OK();
    }
  }
  return allocate_temp(type, shape, out_temp, allocator_attr);
}

void OpKernelContext::delete_ref_input(int index, bool lock_held) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  // should only modify the tensor while holding the mutex
  if (lock_held) {
    delete (*params_->inputs)[index].tensor;
  } else {
    mutex_lock l(*input_ref_mutex(index));
    delete (*params_->inputs)[index].tensor;
  }
}

Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
                                      bool lock_held) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was expected");
  }
  if (!input_is_ref(start)) {
    return errors::InvalidArgument("OpKernel used non-ref input name '", name,
                                   "' when ref input was expected");
  }
  // return a copy of the Ref acquired while holding the mutex
  if (lock_held) {
    *tensor = *(*params_->inputs)[start].tensor;
  } else {
    mutex_lock l(*input_ref_mutex(start));
    *tensor = *(*params_->inputs)[start].tensor;
  }
  record_tensor_reference(*tensor);
  return Status::OK();
}

Status OpKernelContext::replace_ref_input(StringPiece name,
                                          const Tensor& tensor,
                                          bool lock_held) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was expected");
  }
  if (!input_is_ref(start)) {
    return errors::InvalidArgument("OpKernel used immutable input name '", name,
                                   "' when ref input was expected");
  }
  replace_ref_input(start, tensor, lock_held);
  return Status::OK();
}

Status OpKernelContext::input_list(StringPiece name, OpInputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  *list = OpInputList(this, start, stop);
  return Status::OK();
}

Status OpKernelContext::mutable_input_list(StringPiece name,
                                           OpMutableInputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  *list = OpMutableInputList(this, start, stop);
  return Status::OK();
}

Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  *list = OpOutputList(this, start, stop);
  return Status::OK();
}

Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
                                        Tensor** output) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  AllocatorAttributes attr = output_alloc_attr(index);
  return allocate_output(index, shape, output, attr);
}

Status OpKernelContext::allocate_output(StringPiece name,
                                        const TensorShape& shape,
                                        Tensor** tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  return allocate_output(start, shape, tensor);
}

Status OpKernelContext::allocate_output(StringPiece name,
                                        const TensorShape& shape,
                                        Tensor** tensor,
                                        AllocatorAttributes attr) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  return allocate_output(start, shape, tensor, attr);
}

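// Central allocation routine used by allocate_output/allocate_temp/
// allocate_persistent: picks the allocator for `attr`, allocates the tensor,
// and logs the allocation when memory logging is enabled.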
Status OpKernelContext::allocate_tensor(
    DataType type, const TensorShape& shape, Tensor* out_tensor,
    AllocatorAttributes attr, const AllocationAttributes& allocation_attr) {
  Allocator* a = get_allocator(attr);
  AllocationAttributes logged_attr(allocation_attr);
  logged_attr.allocation_will_be_logged = true;
  Tensor new_tensor(a, type, shape, logged_attr);

  if (!new_tensor.IsInitialized()) {
    return errors::ResourceExhausted(
        "OOM when allocating tensor with shape ", shape.DebugString(),
        " and type ", DataTypeString(type), " on ", params_->device->name(),
        " by allocator ", a->Name());
  }
  if (params_->log_memory) {
    LogMemory::RecordTensorAllocation(params_->op_kernel->name(),
                                      params_->step_id, new_tensor);
  }
  record_tensor_reference(new_tensor);
  *out_tensor = std::move(new_tensor);
  return Status::OK();
}

Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
                                        Tensor** output,
                                        AllocatorAttributes attr) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, outputs_.size());
  const DataType type = params_->op_kernel->output_type(index);
  DCHECK(!IsRefType(type));
  DCHECK(mutable_output(index) == nullptr);
  Tensor* output_tensor = new Tensor();
  Status s = allocate_tensor(type, shape, output_tensor, attr);
  if (s.ok()) {
    outputs_[index] = TensorValue(output_tensor);
    *output = outputs_[index].tensor;
  }
  return s;
}

Status OpKernelContext::allocate_temp(
    DataType type, const TensorShape& shape, Tensor* out_temp,
    AllocatorAttributes allocator_attr,
    const AllocationAttributes& allocation_attr) {
  Status s =
      allocate_tensor(type, shape, out_temp, allocator_attr, allocation_attr);
  if (track_allocations() && s.ok() && out_temp->TotalBytes() > 0) {
    Allocator* a = get_allocator(allocator_attr);
    if (a->TracksAllocationSizes()) {
      int64 alloc_size = a->AllocatedSize(out_temp->tensor_data().data());
      record_temp_memory_allocation(alloc_size, *out_temp);
    }
  }
  return s;
}

Status OpKernelContext::allocate_persistent(DataType type,
                                            const TensorShape& shape,
                                            PersistentTensor* out_persistent,
                                            Tensor** out_tensor,
                                            AllocatorAttributes attr) {
  Tensor persistent;
  Status s = allocate_tensor(type, shape, &persistent, attr);
  if (s.ok()) {
    *out_persistent = PersistentTensor(persistent);
    if (out_tensor) {
      *out_tensor = out_persistent->AccessTensor(this);
    }
    if (track_allocations()) {
      Tensor* t = out_persistent->AccessTensor(this);
      Allocator* a = get_allocator(attr);
      if (a->TracksAllocationSizes()) {
        int64 alloc_size = a->AllocatedSize(t->tensor_data().data());
        int64 alloc_id = a->AllocationId(t->tensor_data().data());
        record_persistent_memory_allocation(alloc_size, alloc_id);
      }
    }
  }
  return s;
}

Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  set_output(start, tensor);
  return Status::OK();
}

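// Sets the output by copying the Tensor object (the underlying buffer is
// shared, not copied). If the tensor was previously recorded as temp memory
// in this context, it is now an output, so its temp accounting entry is
// removed.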
void OpKernelContext::set_output(int index, const Tensor& tensor) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, outputs_.size());
  DCHECK(!IsRefType(params_->op_kernel->output_type(index)));
  DCHECK_EQ(mutable_output(index), nullptr);
  record_tensor_reference(tensor);
  outputs_[index] = TensorValue(new Tensor(tensor));
  if (track_allocations() && tensor.TotalBytes() > 0) {
    mutex_lock l(stats_mu_);
    if (!temp_tensor_buffer_and_size_) {
      return;
    }
    auto it = std::find_if(temp_tensor_buffer_and_size_->begin(),
                           temp_tensor_buffer_and_size_->end(),
                           [&tensor](const std::pair<const void*, int64>& e) {
                             return e.first == static_cast<const void*>(
                                                   tensor.tensor_data().data());
                           });
    if (it != temp_tensor_buffer_and_size_->end()) {
      temp_memory_allocated_ -= it->second;
      temp_tensor_buffer_and_size_->erase(it);
    }
  }
}

void OpKernelContext::set_output_ref(int index, mutex* mu,
                                     Tensor* tensor_for_ref) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, outputs_.size());
  DCHECK(IsRefType(params_->op_kernel->output_type(index)));
  record_tensor_reference(*tensor_for_ref);
  outputs_[index] = TensorValue(mu, tensor_for_ref);
}

Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu,
                                       Tensor* tensor_for_ref) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  set_output_ref(start, mu, tensor_for_ref);
  return Status::OK();
}

Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  *tensor = mutable_output(start);
  return Status::OK();
}

Status OpKernelContext::release_output(StringPiece name, TensorValue* value) {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued output name '",
                                   name,
                                   "' when single-valued output was "
                                   "expected");
  }
  *value = release_output(start);
  return Status::OK();
}

bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
  const auto& inputs = *params_->inputs;
  for (size_t i = 1; i < inputs.size(); ++i) {
    if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) {
      SetStatus(errors::InvalidArgument(
          "Inputs to operation ", op->name(), " of type ", op->type_string(),
          " must have the same size and shape.  Input 0: ",
          inputs[0]->shape().DebugString(), " != input ", i, ": ",
          inputs[i]->shape().DebugString()));
      return false;
    }
  }
  return true;
}

Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
                                       const DataTypeSlice expected_outputs) {
  DataTypeVector inputs;
  for (const TensorValue& t : *params_->inputs) {
    inputs.push_back(t.is_ref() ? MakeRefType(t->dtype()) : t->dtype());
  }
  DataTypeVector outputs = params_->op_kernel->output_types();
  return MatchSignatureHelper(expected_inputs, expected_outputs, inputs,
                              outputs);
}

void OpKernelContext::record_temp_memory_allocation(int64 size,
                                                    const Tensor& t) {
  mutex_lock l(stats_mu_);
  temp_memory_allocated_ += size;
  if (!temp_tensor_buffer_and_size_) {
    temp_tensor_buffer_and_size_.reset(
        new gtl::InlinedVector<std::pair<const void*, int64>, 2>());
  }
  temp_tensor_buffer_and_size_->emplace_back(
      static_cast<const void*>(t.tensor_data().data()), size);
}

int64 OpKernelContext::temp_memory_allocated() const {
  mutex_lock l(stats_mu_);
  return temp_memory_allocated_;
}

void OpKernelContext::record_persistent_memory_allocation(int64 size,
                                                          int64 alloc_id) {
  mutex_lock l(stats_mu_);
  persistent_memory_allocated_ += size;
  if (alloc_id >= 0) {
    if (!persistent_alloc_ids_) {
      persistent_alloc_ids_.reset(new gtl::InlinedVector<int64, 2>());
    }
    persistent_alloc_ids_->push_back(alloc_id);
  }
}

int64 OpKernelContext::persistent_memory_allocated() const {
  mutex_lock l(stats_mu_);
  return persistent_memory_allocated_;
}

std::vector<int64> OpKernelContext::persistent_alloc_ids() const {
  mutex_lock l(stats_mu_);
  if (persistent_alloc_ids_) {
    return std::vector<int64>(persistent_alloc_ids_->begin(),
                              persistent_alloc_ids_->end());
  } else {
    return std::vector<int64>();
  }
}

void OpKernelContext::clear_recorded_memory() {
  mutex_lock l(stats_mu_);
  temp_memory_allocated_ = 0;
  persistent_memory_allocated_ = 0;
  if (temp_tensor_buffer_and_size_) {
    temp_tensor_buffer_and_size_->clear();
  }
  if (persistent_alloc_ids_) {
    persistent_alloc_ids_->clear();
  }
}

// OpKernel registration ------------------------------------------------------

struct KernelRegistration {
  KernelRegistration(const KernelDef& d, StringPiece c,
                     kernel_factory::OpKernelRegistrar::Factory f)
      : def(d), kernel_class_name(c.ToString()), factory(f) {}
  const KernelDef def;
  const string kernel_class_name;
  const kernel_factory::OpKernelRegistrar::Factory factory;
};

// This maps from 'op_type' + DeviceType to the set of KernelDefs and
// factory functions for instantiating the OpKernel that matches the
// KernelDef.
typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry;

void* GlobalKernelRegistry() {
  static KernelRegistry* global_kernel_registry = new KernelRegistry;
  return global_kernel_registry;
}

static KernelRegistry* GlobalKernelRegistryTyped() {
  return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
}

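// Builds the registry key "<op>:<device>:<label>", e.g. "MatMul:GPU:" for an
// unlabeled MatMul kernel registered for GPU.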
static string Key(StringPiece op_type, const DeviceType& device_type,
                  StringPiece label) {
  return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
                         label);
}

namespace kernel_factory {

void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
                                     StringPiece kernel_class_name,
                                     Factory factory) {
  // See comments in register_kernel::Name in header for info on _no_register.
  if (kernel_def->op() != "_no_register") {
    const string key =
        Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
            kernel_def->label());
    GlobalKernelRegistryTyped()->insert(std::make_pair(
        key, KernelRegistration(*kernel_def, kernel_class_name, factory)));
  }
  delete kernel_def;
}

}  // namespace kernel_factory

namespace {

// Helper for AttrsMatch().
bool InTypeList(DataType dt, const AttrValue& type_list) {
  for (int in_list : type_list.list().type()) {
    if (dt == in_list) return true;
  }
  return false;
}

// Returns whether the attrs satisfy the constraints in the kernel_def.  Returns
// an error if attrs in kernel_def are not found, or have a mismatching type.
Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) {
  *match = false;
  for (const auto& constraint : kernel_def.constraint()) {
    if (constraint.allowed_values().list().type_size() == 0) {
      return errors::Unimplemented(
          "KernelDef '", ProtoShortDebugString(kernel_def),
          "' has constraint on attr '", constraint.name(),
          "' with unsupported type: ",
          SummarizeAttrValue(constraint.allowed_values()));
    }

    const AttrValue* found = attrs.Find(constraint.name());
    if (found) {
      if (found->type() != DT_INVALID) {
        if (!InTypeList(found->type(), constraint.allowed_values())) {
          return Status::OK();
        }
      } else {
        if (!AttrValueHasType(*found, "list(type)").ok()) {
          return errors::InvalidArgument(
              "KernelDef '", ProtoShortDebugString(kernel_def),
              "' has constraint on attr '", constraint.name(),
              "' that has value '", SummarizeAttrValue(*found),
              "' that does not have type 'type' or 'list(type)' in NodeDef "
              "'",
              attrs.SummarizeNode(), "'");
        }

        for (int t : found->list().type()) {
          if (!InTypeList(static_cast<DataType>(t),
                          constraint.allowed_values())) {
            return Status::OK();
          }
        }
      }
    } else {
      return errors::InvalidArgument(
          "OpKernel '", kernel_def.op(), "' has constraint on attr '",
          constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
          "', KernelDef: '", ProtoShortDebugString(kernel_def), "'");
    }
  }
  *match = true;
  return Status::OK();
}

static const StringPiece kKernelAttr("_kernel");

// TODO(irving): Replace with const Node& version below.
Status FindKernelRegistration(const DeviceType& device_type,
                              const NodeDef& node_def,
                              const KernelRegistration** reg,
                              bool* was_attr_mismatch) {
  *reg = nullptr;
  *was_attr_mismatch = false;
  // Label defaults to empty if not found in NodeDef.
  const string& label = GetNodeAttrString(node_def, kKernelAttr);

  const string key = Key(node_def.op(), device_type, label);
  auto regs = GlobalKernelRegistryTyped()->equal_range(key);
  for (auto iter = regs.first; iter != regs.second; ++iter) {
    // If there is a kernel registered for the op and device_type,
    // check that the attrs match.
    bool match;
    TF_RETURN_IF_ERROR(AttrsMatch(node_def, iter->second.def, &match));
    if (match) {
      if (*reg != nullptr) {
        return errors::InvalidArgument(
            "Multiple OpKernel registrations match NodeDef '",
            SummarizeNodeDef(node_def), "': '",
            ProtoShortDebugString((*reg)->def), "' and '",
            ProtoShortDebugString(iter->second.def), "'");
      }
      *reg = &iter->second;
    } else {
      *was_attr_mismatch = true;
    }
  }
  return Status::OK();
}

}  // namespace

// TODO(irving): Change const NodeDef& to const Node&
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
                     const KernelDef** def, string* kernel_class_name) {
  const KernelRegistration* reg = nullptr;
  bool was_attr_mismatch;
  TF_RETURN_IF_ERROR(
      FindKernelRegistration(device_type, node_def, &reg, &was_attr_mismatch));
  if (reg == nullptr) {
    Status s = errors::NotFound(
        "No registered '", node_def.op(), "' OpKernel for ",
        DeviceTypeString(device_type), " devices compatible with node ",
        SummarizeNodeDef(node_def));
    if (was_attr_mismatch) {
      errors::AppendToMessage(
          &s, " (OpKernel was found, but attributes didn't match)");
    }
    errors::AppendToMessage(
        &s, ".  Registered:", KernelsRegisteredForOp(node_def.op()));
    return s;
  }
  if (def != nullptr) *def = &reg->def;
  if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name;
  return Status::OK();
}

Status SupportedDeviceTypesForNode(
    const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
    DeviceTypeVector* device_types) {
  // TODO(zhifengc): Change the callers (SimplePlacer and
  // DynamicPlacer) to consider the possibility that 'def' is a call to
  // a user-defined function, and only call
  // SupportedDeviceTypesForNode for primitive ops.
  const OpRegistrationData* op_reg_data;
  const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data);
  if (s.ok()) {
    for (const DeviceType& device_type : prioritized_types) {
      const KernelRegistration* reg = nullptr;
      bool was_attr_mismatch;
      TF_RETURN_IF_ERROR(
          FindKernelRegistration(device_type, def, &reg, &was_attr_mismatch));
      if (reg != nullptr) device_types->push_back(device_type);
    }
  } else {
    // Assumes that all device types support this node.
    for (const DeviceType& device_type : prioritized_types) {
      device_types->push_back(device_type);
    }
  }
  return Status::OK();
}

void LogAllRegisteredKernels() {
  for (const auto& key_registration : *GlobalKernelRegistryTyped()) {
    const KernelDef& kernel_def(key_registration.second.def);
    LOG(INFO) << "OpKernel ('" << ProtoShortDebugString(kernel_def) << "')";
  }
}

string KernelsRegisteredForOp(StringPiece op_name) {
  string ret;
  for (const auto& key_registration : *GlobalKernelRegistryTyped()) {
    const KernelDef& kernel_def(key_registration.second.def);
    if (kernel_def.op() == op_name) {
      strings::StrAppend(&ret, "  device='", kernel_def.device_type(), "'");
      if (!kernel_def.label().empty()) {
        strings::StrAppend(&ret, "; label='", kernel_def.label(), "'");
      }
      for (int i = 0; i < kernel_def.constraint_size(); ++i) {
        strings::StrAppend(
            &ret, "; ", kernel_def.constraint(i).name(), " in ",
            SummarizeAttrValue(kernel_def.constraint(i).allowed_values()));
      }
      strings::StrAppend(&ret, "\n");
    }
  }
  if (ret.empty()) return "  <no registered kernels>\n";
  return ret;
}

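// Convenience overload that owns the result. A minimal usage sketch (the
// `device` and `node_def` values are assumed to come from the caller's
// runtime; TF_GRAPH_DEF_VERSION is defined in tensorflow/core/public/version.h):
//
//   Status status;
//   std::unique_ptr<OpKernel> kernel = CreateOpKernel(
//       DeviceType(DEVICE_CPU), device, device->GetAllocator(AllocatorAttributes()),
//       node_def, TF_GRAPH_DEF_VERSION, &status);
//   if (!status.ok()) { /* handle the error */ }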
std::unique_ptr<OpKernel> CreateOpKernel(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const NodeDef& node_def, int graph_def_version, Status* status) {
  OpKernel* kernel = nullptr;
  *status = CreateOpKernel(std::move(device_type), device, allocator, nullptr,
                           node_def, graph_def_version, &kernel);
  return std::unique_ptr<OpKernel>(kernel);
}

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const NodeDef& node_def, int graph_def_version,
                      OpKernel** kernel) {
  VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);

  // Look up the Op registered for this op name.
  const OpDef* op_def = nullptr;
  Status s = OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def);
  if (!s.ok()) return s;

  // Validate node_def against OpDef.
  s = ValidateNodeDef(node_def, *op_def);
  if (!s.ok()) return s;

  // Look up kernel registration.
  const KernelRegistration* registration;
  bool was_attr_mismatch;
  s = FindKernelRegistration(device_type, node_def, &registration,
                             &was_attr_mismatch);
  if (!s.ok()) {
    errors::AppendToMessage(&s, " when instantiating ", node_def.op());
    return s;
  }
  if (registration == nullptr) {
    s.Update(errors::NotFound("No registered '", node_def.op(),
                              "' OpKernel for ", DeviceTypeString(device_type),
                              " devices compatible with node ",
                              SummarizeNodeDef(node_def)));
    if (was_attr_mismatch) {
      errors::AppendToMessage(
          &s, " (OpKernel was found, but attributes didn't match)");
    }
    errors::AppendToMessage(
        &s, ".  Registered:", KernelsRegisteredForOp(node_def.op()));
    return s;
  }

  // Get signature from the OpDef & NodeDef
  DataTypeVector inputs;
  DataTypeVector outputs;
  s.Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs));
  if (!s.ok()) {
    errors::AppendToMessage(&s, " for node: ", SummarizeNodeDef(node_def));
    return s;
  }

  // Since we are creating a kernel for an op registered in
  // OpRegistry::Global(), we consult the kernel registry to decide
  // the kernel's input and output memory types.
  MemoryTypeVector input_memory_types;
  MemoryTypeVector output_memory_types;
  TF_RETURN_IF_ERROR(MemoryTypesForNode(OpRegistry::Global(), device_type,
                                        node_def, &input_memory_types,
                                        &output_memory_types));

  // Everything needed for OpKernel construction.
  OpKernelConstruction context(
      device_type, device, allocator, &node_def, op_def, flib, inputs,
      input_memory_types, outputs, output_memory_types, graph_def_version, &s);
  *kernel = (*registration->factory)(&context);
  if (!s.ok()) {
    delete *kernel;
    *kernel = nullptr;
  }
  return s;
}

namespace {

bool FindArgInOp(StringPiece arg_name,
                 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
  for (const auto& arg : args) {
    if (arg_name == arg.name()) {
      return true;
    }
  }
  return false;
}

}  // namespace

Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) {
  for (const auto& key_registration : *GlobalKernelRegistryTyped()) {
    const KernelDef& kernel_def(key_registration.second.def);
    const OpRegistrationData* op_reg_data;
    const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data);
    if (!status.ok()) {
      // TODO(josh11b): Make this a hard error.
      LOG(ERROR) << "OpKernel ('" << ProtoShortDebugString(kernel_def)
                 << "') for unknown op: " << kernel_def.op();
      continue;
    }
    const OpDef& op_def = op_reg_data->op_def;
    for (const auto& host_memory_arg : kernel_def.host_memory_arg()) {
      if (!FindArgInOp(host_memory_arg, op_def.input_arg()) &&
          !FindArgInOp(host_memory_arg, op_def.output_arg())) {
        return errors::InvalidArgument(
            "HostMemory arg '", host_memory_arg,
            "' not found in OpDef: ", SummarizeOpDef(op_def));
      }
    }
  }
  return Status::OK();
}

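// Explicit specializations of eigen_device<>() let kernel code written
// against a Device template parameter fetch the matching Eigen device, e.g.
// `ctx->eigen_device<Eigen::ThreadPoolDevice>()` on CPU.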
template <>
const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const {
  return eigen_cpu_device();
}

template <>
const Eigen::GpuDevice& OpKernelContext::eigen_device() const {
  return eigen_gpu_device();
}

#ifdef TENSORFLOW_USE_SYCL
template <>
const Eigen::SyclDevice& OpKernelContext::eigen_device() const {
  return eigen_sycl_device();
}
#endif

void OpKernelConstruction::CtxFailure(const Status& s) {
  VLOG(1) << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailureWithWarning(const Status& s) {
  LOG(WARNING) << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailure(const char* file, int line,
                                      const Status& s) {
  VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
          << " : " << s;
  SetStatus(s);
}

void OpKernelConstruction::CtxFailureWithWarning(const char* file, int line,
                                                 const Status& s) {
  LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
               << " : " << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailure(const Status& s) {
  VLOG(1) << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailureWithWarning(const Status& s) {
  LOG(WARNING) << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailure(const char* file, int line, const Status& s) {
  VLOG(1) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
          << " : " << s;
  SetStatus(s);
}

void OpKernelContext::CtxFailureWithWarning(const char* file, int line,
                                            const Status& s) {
  LOG(WARNING) << "OP_REQUIRES failed at " << io::Basename(file) << ":" << line
               << " : " << s;
  SetStatus(s);
}

}  // namespace tensorflow