// tensorflow/compiler/tf2xla/xla_op_kernel.cc
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     17 
     18 #include <numeric>
     19 
     20 #include "tensorflow/compiler/tf2xla/literal_util.h"
     21 #include "tensorflow/compiler/tf2xla/shape_util.h"
     22 #include "tensorflow/compiler/tf2xla/type_util.h"
     23 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
     24 #include "tensorflow/compiler/tf2xla/xla_context.h"
     25 #include "tensorflow/compiler/xla/client/xla_builder.h"
     26 #include "tensorflow/compiler/xla/client/xla_computation.h"
     27 #include "tensorflow/compiler/xla/status_macros.h"
     28 #include "tensorflow/core/common_runtime/dma_helper.h"
     29 
     30 namespace tensorflow {
     31 
// Wraps `context`; stores (does not own) the underlying OpKernelContext.
XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
    : context_(context) {}
     34 
// Forwards the same-shape validation to the wrapped OpKernelContext.
bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
  return context_->ValidateInputsAreSameShape(op);
}
     38 
// Looks up the XlaContext associated with the wrapped OpKernelContext.
XlaContext* XlaOpKernelContext::xla_context() const {
  return &XlaContext::Get(context_);
}
     42 
// Returns the XlaBuilder owned by the current XlaContext.
xla::XlaBuilder* XlaOpKernelContext::builder() const {
  return xla_context()->builder();
}
     46 
// Returns the XlaCompiler owned by the current XlaContext.
XlaCompiler* XlaOpKernelContext::compiler() const {
  return xla_context()->compiler();
}
     50 
// Retrieves an XlaExpression that was allocated by a previous Op.
//
// On the XLA compilation device a Tensor's buffer is really a metadata slot
// holding an XlaExpression (see AssignExpressionToTensor), so the data
// pointer is reinterpreted as the expression. CHECK-fails if the slot was
// never assigned (kind() == kInvalid).
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
  const XlaExpression* expression =
      reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
  CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
      << expression->HumanString();
  return expression;
}
     59 
// Assigns an XlaExpression to a tensor on an XLA compilation device.
//
// The tensor's buffer is the metadata slot allocated by the compilation
// device; CHECK-fails if it already holds a valid expression (each slot may
// be written exactly once). The const_cast is needed because tensor_data()
// only exposes a const view of the buffer.
static void AssignExpressionToTensor(Tensor* tensor,
                                     const XlaExpression& value) {
  const XlaExpression* expression =
      reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
  CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
      << expression->HumanString();
  *const_cast<XlaExpression*>(expression) = value;
}
     69 
// Returns the XlaExpression stored in input `index`.
const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
  return *CastExpressionFromTensor(context_->input(index));
}
     73 
// Returns the XlaExpression stored in the input named `name`.
const XlaExpression& XlaOpKernelContext::InputExpression(
    absl::string_view name) {
  return *CastExpressionFromTensor(GetInputTensorByName(name));
}
     78 
// Returns input `index` as an XlaOp in this kernel's builder.
xla::XlaOp XlaOpKernelContext::Input(int index) {
  return InputExpression(index).AsXlaOp(builder());
}
     82 
// Returns the input named `name` as an XlaOp in this kernel's builder.
xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) {
  return InputExpression(name).AsXlaOp(builder());
}
     86 
// Returns the TensorShape of input `index`.
TensorShape XlaOpKernelContext::InputShape(int index) {
  return context_->input(index).shape();
}
     90 
// Returns the TensorShape of the input named `name`.
TensorShape XlaOpKernelContext::InputShape(absl::string_view name) {
  return GetInputTensorByName(name).shape();
}
     94 
// Returns the TensorFlow dtype of input `index`.
DataType XlaOpKernelContext::input_type(int index) const {
  return context_->input_dtype(index);
}
     98 
// Returns the TensorFlow dtype of the input named `name`.
DataType XlaOpKernelContext::InputType(absl::string_view name) {
  return GetInputTensorByName(name).dtype();
}
    102 
    103 xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
    104   xla::PrimitiveType type;
    105   Status status = DataTypeToPrimitiveType(input_type(index), &type);
    106   if (!status.ok()) {
    107     SetStatus(status);
    108     return xla::PRIMITIVE_TYPE_INVALID;
    109   }
    110   return type;
    111 }
    112 
// Evaluates input `index` as a compile-time constant into `constant_literal`,
// keeping the input's own shape (no reshape).
Status XlaOpKernelContext::ConstantInput(int index,
                                         xla::Literal* constant_literal) {
  return ConstantInputReshaped(
      index, context_->input(index).shape().dim_sizes(), constant_literal);
}
    118 
    119 static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context,
    120                                      absl::string_view name) {
    121   int start, stop;
    122   TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
    123   if (stop != start + 1) {
    124     return errors::InvalidArgument("OpKernel used list-valued input name '",
    125                                    name,
    126                                    "' when single-valued input was "
    127                                    "expected");
    128   }
    129   return start;
    130 }
    131 
// Named-input variant of ConstantInput: resolves `name` to an index and
// delegates.
Status XlaOpKernelContext::ConstantInput(absl::string_view name,
                                         xla::Literal* constant_literal) {
  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
  return ConstantInput(index, constant_literal);
}
    137 
// Evaluates input `index` as a compile-time constant, reshapes it to
// `new_dims` (which must describe the same number of elements), and converts
// the result into `constant_literal`.
//
// Fails with InvalidArgument if the input cannot be resolved to a constant
// (e.g. it depends on a parameter, variable, or stateful op) or if the
// requested shape is incompatible.
Status XlaOpKernelContext::ConstantInputReshaped(
    int index, absl::Span<const int64> new_dims,
    xla::Literal* constant_literal) {
  XlaExpression e = InputExpression(index);
  // Ask the compiler's client to fold the expression to a host Tensor; an
  // empty optional means the value is not a compile-time constant.
  xla::StatusOr<absl::optional<Tensor>> constant_or_status =
      e.ResolveConstant(compiler()->client());
  if (!constant_or_status.ok()) {
    Status status = constant_or_status.status();
    errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
                            context_->op_kernel().type_string(),
                            " operator as a compile-time constant.");
    return status;
  }
  absl::optional<Tensor> constant = constant_or_status.ValueOrDie();
  if (!constant.has_value()) {
    return errors::InvalidArgument(
        "Input ", index, " to ", context_->op_kernel().type_string(),
        " operator must be a compile-time constant.\n"
        "\n"
        "XLA compilation requires that operator arguments that represent "
        "shapes or dimensions be evaluated to concrete values at compile time. "
        "This error means that a shape or dimension argument could not be "
        "evaluated at compile time, usually because the value of the argument "
        "depends on a parameter to the computation, on a variable, or on a "
        "stateful operation such as a random number generator.");
  }

  // CopyFrom shares the buffer and reinterprets it with the new shape; it
  // fails if the element counts differ.
  Tensor temp(constant->dtype());
  if (!temp.CopyFrom(*constant, TensorShape(new_dims))) {
    return errors::InvalidArgument(
        context_->op_kernel().name(), " input ", index, " has shape ",
        constant->shape().DebugString(),
        " but was asked to be reshaped to incompatible shape ",
        TensorShape(new_dims).DebugString());
  }

  TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
  return Status::OK();
}
    177 
    178 // Converts an int32 or int64 scalar literal to an int64.
    179 static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
    180                                    int64* out) {
    181   if (literal.shape().rank() != 0) {
    182     return errors::InvalidArgument("value is not a scalar");
    183   }
    184   if (literal.shape().element_type() == xla::S32) {
    185     *out = literal.Get<int32>({});
    186   } else if (literal.shape().element_type() == xla::S64) {
    187     *out = literal.Get<int64>({});
    188   } else {
    189     return errors::InvalidArgument("value must be either int32 or int64");
    190   }
    191   return Status::OK();
    192 }
    193 
    194 // Converts an float32 or float64 scalar literal to a float64.
    195 static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal,
    196                                      double* out) {
    197   if (literal.shape().rank() != 0) {
    198     return errors::InvalidArgument("value is not a scalar");
    199   }
    200   if (literal.shape().element_type() == xla::F32) {
    201     *out = literal.Get<float>({});
    202   } else if (literal.shape().element_type() == xla::F64) {
    203     *out = literal.Get<double>({});
    204   } else {
    205     return errors::InvalidArgument("value must be either float32 or float64");
    206   }
    207   return Status::OK();
    208 }
    209 
// Evaluates input `index` as a constant int32/int64 scalar into `*out`.
Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
  return LiteralToInt64Scalar(literal, out);
}
    215 
// Named-input variant of ConstantInputAsIntScalar.
Status XlaOpKernelContext::ConstantInputAsIntScalar(absl::string_view name,
                                                    int64* out) {
  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
  return ConstantInputAsIntScalar(index, out);
}
    221 
// Evaluates input `index` as a constant float32/float64 scalar into `*out`.
Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
  return LiteralToFloat64Scalar(literal, out);
}
    227 
    228 // Converts an int32 or int64 1D literal to an int64 vector.
    229 static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
    230                                    std::vector<int64>* out) {
    231   if (literal.shape().rank() != 1) {
    232     return errors::InvalidArgument("value is not 1D, rank: ",
    233                                    literal.shape().rank());
    234   }
    235   int64 size = xla::ShapeUtil::ElementsIn(literal.shape());
    236   if (literal.shape().element_type() == xla::S32) {
    237     for (int64 i = 0; i < size; ++i) {
    238       out->push_back(literal.Get<int32>({i}));
    239     }
    240   } else if (literal.shape().element_type() == xla::S64) {
    241     for (int64 i = 0; i < size; ++i) {
    242       out->push_back(literal.Get<int64>({i}));
    243     }
    244   } else {
    245     return errors::InvalidArgument("value must be either int32 or int64");
    246   }
    247   return Status::OK();
    248 }
    249 
// Evaluates input `index` as a constant 1D int32/int64 tensor, appending its
// elements to `*out`.
Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
                                                    std::vector<int64>* out) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
  return LiteralToInt64Vector(literal, out);
}
    256 
// Named-input variant of ConstantInputAsIntVector.
Status XlaOpKernelContext::ConstantInputAsIntVector(absl::string_view name,
                                                    std::vector<int64>* out) {
  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
  return ConstantInputAsIntVector(index, out);
}
    262 
// Evaluates input `index` as a constant, flattened to a rank-1 tensor of its
// element count, and appends the int values to `*out`.
Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
    int index, std::vector<int64>* out) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInputReshaped(
      index, {InputShape(index).num_elements()}, &literal));
  return LiteralToInt64Vector(literal, out);
}
    270 
    271 Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
    272     absl::string_view name, std::vector<int64>* out) {
    273   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
    274   xla::Literal literal;
    275   TF_RETURN_IF_ERROR(ConstantInputReshaped(
    276       index, {InputShape(index).num_elements()}, &literal));
    277   return LiteralToInt64Vector(literal, out);
    278 }
    279 
// Evaluates input `index` as a constant and returns it in `*out` as an S64
// literal: S32 inputs are widened element-by-element, S64 inputs are moved
// through unchanged; any other element type is an error.
Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
                                                       xla::Literal* out) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
  switch (literal.shape().element_type()) {
    case xla::S32: {
      // Allocate an S64 literal of the same shape and copy-widen each value.
      *out = xla::Literal(
          xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64));
      auto src_data = literal.data<int32>();
      for (int64 i = 0; i < src_data.size(); ++i) {
        out->data<int64>()[i] = src_data[i];
      }
      return Status::OK();
    }
    case xla::S64:
      *out = std::move(literal);
      return Status::OK();

    default:
      return errors::InvalidArgument(
          "Invalid argument to ConstantInputAsInt64Literal: ",
          xla::ShapeUtil::HumanString(literal.shape()));
  }
}
    304 
// Named-input variant of ConstantInputAsInt64Literal.
Status XlaOpKernelContext::ConstantInputAsInt64Literal(absl::string_view name,
                                                       xla::Literal* out) {
  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
  return ConstantInputAsInt64Literal(index, out);
}
    310 
// TODO(phawkins): validate that the dimensions form a valid shape, fail
// gracefully if they do not.
// Evaluates input `index` as a constant 1D int tensor and interprets its
// elements as the dimensions of a TensorShape.
Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
  std::vector<int64> dims;
  TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
  *shape = TensorShape(dims);
  return Status::OK();
}
    321 
// Evaluates input `index` as a constant and interprets it as a
// PartialTensorShape: a scalar -1 means "unknown rank"; otherwise the input
// must be a 1D int tensor whose elements are the (possibly -1) dimensions.
Status XlaOpKernelContext::ConstantInputAsPartialShape(
    int index, PartialTensorShape* shape) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
  // If `literal` is a scalar its value must be -1.
  if (literal.shape().rank() == 0) {
    int64 shape_val;
    TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val));
    if (shape_val != -1) {
      return errors::InvalidArgument(
          "Cannot convert value to PartialTensorShape: ", shape_val);
    }
    *shape = PartialTensorShape();  // Shape with unknown rank.
    return Status::OK();
  }
  std::vector<int64> dims;
  TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
  *shape = PartialTensorShape(dims);
  return Status::OK();
}
    342 
// Collects the XlaOps and shapes of the list-valued input `name`. Both output
// vectors are cleared first and filled in input order.
Status XlaOpKernelContext::InputList(absl::string_view name,
                                     std::vector<xla::XlaOp>* handles,
                                     std::vector<TensorShape>* shapes) {
  OpInputList inputs;
  TF_RETURN_IF_ERROR(context_->input_list(name, &inputs));
  handles->clear();
  shapes->clear();
  for (const Tensor& input : inputs) {
    handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder()));
    shapes->push_back(input.shape());
  }
  return Status::OK();
}
    356 
    357 Status XlaOpKernelContext::ConstantInputList(
    358     absl::string_view name, std::vector<xla::Literal>* outputs) {
    359   int start, stop;
    360   TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
    361   outputs->resize(stop - start);
    362   for (int i = start; i < stop; ++i) {
    363     TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i]));
    364   }
    365   return Status::OK();
    366 }
    367 
    368 namespace {
    369 
// Reads the value of the variable resource held by `tensor`'s XlaExpression.
//
// Fails if the resource is missing, is not a variable, is uninitialized, or
// has a dtype other than `type`. On success, writes the variable's declared
// shape to `*shape` (if non-null) and its current XLA value to `*value`,
// reshaping back to the declared shape when the compiler's
// shape_representation_fn stores the variable in a different layout.
Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
                               const XlaOpKernelContext* ctx,
                               TensorShape* shape, xla::XlaOp* value) {
  const XlaExpression* expression = CastExpressionFromTensor(tensor);
  XlaResource* variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
  if (!variable->initialized()) {
    return errors::FailedPrecondition("Read of uninitialized variable ",
                                      variable->name());
  }
  if (variable->type() != type) {
    return errors::InvalidArgument(
        "Type mismatch for read of variable ", variable->name(), ". Expected ",
        DataTypeString(type), "; got ", DataTypeString(variable->type()));
  }
  if (shape) {
    *shape = variable->shape();
  }

  // Compare the variable's natural XLA shape against the representation
  // shape chosen by the compiler; if they differ, the stored value must be
  // reshaped back to the user-visible shape.
  TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
                      ctx->compiler()->options().shape_representation_fn(
                          variable->shape(), variable->type()));
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(
      TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape));
  if (xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
    *value = variable->value();
  } else {
    *value = xla::Reshape(variable->value(), variable->shape().dim_sizes());
  }
  return Status::OK();
}
    403 
    404 }  // namespace
    405 
// Reads the variable resource passed as input `index`.
Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
                                             TensorShape* shape,
                                             xla::XlaOp* value) {
  return ReadVariableInputTensor(context_->input(index), type, this, shape,
                                 value);
}
    412 
// Reads the variable resource passed as the input named `name`.
Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
                                             DataType type, TensorShape* shape,
                                             xla::XlaOp* value) {
  return ReadVariableInputTensor(GetInputTensorByName(name), type, this, shape,
                                 value);
}
    419 
// Returns the dtype and shape of the initialized variable resource passed as
// input `index`, without reading its value.
// NOTE(review): returns InvalidArgument for an uninitialized variable,
// whereas ReadVariableInputTensor uses FailedPrecondition — confirm callers
// depend on this distinction before unifying.
Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
                                                   TensorShape* shape) const {
  const Tensor& tensor = context_->input(index);
  const XlaExpression* expression = CastExpressionFromTensor(tensor);
  XlaResource* variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
  if (!variable->initialized()) {
    return errors::InvalidArgument("Read of uninitialized variable ",
                                   variable->name());
  }
  *type = variable->type();
  *shape = variable->shape();
  return Status::OK();
}
    435 
    436 void XlaOpKernelContext::SetOutputExpression(int index,
    437                                              const XlaExpression& expression) {
    438   Status status = [&] {
    439     // The step's default allocator is the dummy XlaCompilationAllocator which
    440     // simply allocates a metadata buffer to hold the expression to which it
    441     // corresponds.
    442     Tensor* output = nullptr;
    443     // Provides a special behavior for DT_VARIANT: a variant is treated as
    444     // DT_UINT8 scalar as the type to allow mapping for variant to more generic
    445     // types.
    446     if (expression.dtype() == DT_VARIANT) {
    447       // tensor_data() is not supported for variant Tensor (i.e.,
    448       // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the
    449       // XlaExpression inside the Tensor's tensor_data() does not work for
    450       // variant. Instead construct a uint8 tensor and store the expression in
    451       // its value.
    452       // TODO(jpienaar): This should be refactored to stop masquerading
    453       // XlaExpressions as Tensors.
    454       output = new Tensor();
    455       TensorShape tensor_shape;
    456       TF_RETURN_IF_ERROR(
    457           context_->allocate_temp(DT_UINT8, tensor_shape, output));
    458       context_->set_output(index, *output);
    459     } else {
    460       TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
    461       TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
    462     }
    463     AssignExpressionToTensor(output, expression);
    464     return Status::OK();
    465   }();
    466   if (!status.ok()) {
    467     SetStatus(status);
    468   }
    469 }
    470 
    471 xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) {
    472   xla::PrimitiveType type;
    473   Status status = DataTypeToPrimitiveType(expected_output_dtype(index), &type);
    474   if (!status.ok()) {
    475     SetStatus(status);
    476     return xla::PRIMITIVE_TYPE_INVALID;
    477   }
    478   return type;
    479 }
    480 
// Sets output `index` to the XlaOp `handle`, typed with the output's
// expected dtype.
void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
  SetOutputExpression(
      index,
      XlaExpression::XlaOp(handle, context_->expected_output_dtype(index)));
}
    486 
// Sets output `index` to the compile-time constant `constant`.
void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
  SetOutputExpression(index, XlaExpression::Constant(constant));
}
    490 
// Sets output `index` to the TensorList expression `handle`.
void XlaOpKernelContext::SetTensorListOutput(int index,
                                             const xla::XlaOp& handle) {
  SetOutputExpression(index, XlaExpression::TensorList(handle));
}
    495 
// Sets output `index` to the resource `resource`.
void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
  SetOutputExpression(index, XlaExpression::Resource(resource));
}
    499 
// Returns the XlaResource carried by input `index`; fails (TF_RET_CHECK) if
// the input does not hold a resource.
Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
  const XlaExpression* expression =
      CastExpressionFromTensor(context_->input(index));
  TF_RET_CHECK(expression->resource() != nullptr);
  *resource = expression->resource();
  return Status::OK();
}
    507 
    508 namespace {
    509 
    510 Status AssignVariableTensor(const Tensor& tensor, DataType type,
    511                             const XlaOpKernelContext* ctx, xla::XlaOp handle,
    512                             xla::XlaBuilder* builder) {
    513   const XlaExpression* expression = CastExpressionFromTensor(tensor);
    514   XlaResource* variable = expression->resource();
    515   TF_RET_CHECK(variable != nullptr);
    516   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
    517 
    518   auto shape_or_status = builder->GetShape(handle);
    519   if (!shape_or_status.ok()) {
    520     return shape_or_status.status();
    521   }
    522   TensorShape shape;
    523   TF_RETURN_IF_ERROR(
    524       XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape));
    525 
    526   TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
    527 
    528   TF_ASSIGN_OR_RETURN(
    529       xla::Shape representation_shape,
    530       ctx->compiler()->options().shape_representation_fn(shape, type));
    531   xla::Shape xla_shape;
    532   TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
    533   if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
    534     handle = xla::Reshape(handle,
    535                           xla::AsInt64Slice(representation_shape.dimensions()));
    536   }
    537   variable->SetRepresentationShape(representation_shape);
    538   return variable->SetValue(handle);
    539 }
    540 
    541 }  // namespace
    542 
// Assigns `handle` to the variable resource passed as input `input_index`.
// `handle` must be a valid XlaOp.
Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
                                          xla::XlaOp handle) {
  TF_RET_CHECK(handle.valid());
  return AssignVariableTensor(context_->input(input_index), type, this, handle,
                              builder());
}
    549 
// Assigns `handle` to the variable resource passed as the input named `name`.
// `handle` must be a valid XlaOp.
Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type,
                                          xla::XlaOp handle) {
  TF_RET_CHECK(handle.valid());
  return AssignVariableTensor(GetInputTensorByName(name), type, this, handle,
                              builder());
}
    556 
// Forwards the failure status to the wrapped OpKernelContext.
void XlaOpKernelContext::CtxFailure(const Status& s) {
  context_->CtxFailure(s);
}
// Forwards the failure status (with warning logging) to the wrapped context.
void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) {
  context_->CtxFailureWithWarning(s);
}
// Forwards the failure status, annotated with the call site, to the wrapped
// context.
void XlaOpKernelContext::CtxFailure(const char* file, int line,
                                    const Status& s) {
  context_->CtxFailure(file, line, s);
}
// Forwards the failure status (with warning logging), annotated with the
// call site, to the wrapped context.
void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line,
                                               const Status& s) {
  context_->CtxFailureWithWarning(file, line, s);
}
    571 
// Returns the XlaContext's cached Max reduction computation for `type`.
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax(
    const DataType type) {
  return xla_context()->GetOrCreateMax(type);
}
    576 
// Returns the XlaContext's cached Min reduction computation for `type`.
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin(
    const DataType type) {
  return xla_context()->GetOrCreateMin(type);
}
    581 
// Returns the XlaContext's cached Add reduction computation for `type`.
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd(
    const DataType type) {
  return xla_context()->GetOrCreateAdd(type);
}
    586 
// Returns the XlaContext's cached Mul reduction computation for `type`.
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
    const DataType type) {
  return xla_context()->GetOrCreateMul(type);
}
    591 
// Looks up the input tensor named `name`; CHECK-fails if no such single
// input exists on the wrapped context.
const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
  const Tensor* tensor;
  CHECK(context_->input(name, &tensor).ok());
  return *tensor;
}
    597 
// XlaOpKernel is a regular OpKernel whose Compute wraps the context for XLA.
XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}
    599 
// Adapts the OpKernel Compute entry point: wraps the OpKernelContext in an
// XlaOpKernelContext and dispatches to the subclass's Compile.
void XlaOpKernel::Compute(OpKernelContext* context) {
  XlaOpKernelContext xla_context(context);
  Compile(&xla_context);
}
    604 
    605 }  // namespace tensorflow
    606