/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"

#include <numeric>

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"

namespace tensorflow {

XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
    : context_(context) {}

bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
  return context_->ValidateInputsAreSameShape(op);
}

xla::ComputationBuilder* XlaOpKernelContext::builder() const {
  return XlaContext::Get(this).builder();
}

// Retrieves an XlaExpression that was allocated by a previous Op.
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
  const XlaExpression* expression =
      reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
  CHECK(expression->handle().handle() != 0 ||
        expression->resource() != nullptr);
  VLOG(1) << "Fetched T" << expression->handle().handle();
  return expression;
}

// Retrieves an uninitialized XlaExpression from a newly-allocated tensor.
static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) {
  const XlaExpression* expression =
      reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
  CHECK_EQ(expression->handle().handle(), 0);
  return const_cast<XlaExpression*>(expression);
}
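
// A minimal round-trip sketch of what these two helpers enable: a tensor's
// data buffer, allocated by the step's XlaCompilationAllocator, holds an
// XlaExpression rather than real data (`some_handle` is a hypothetical,
// already-built handle, not defined in this file):
//
//   Tensor* t = ...;  // freshly allocated output tensor
//   XlaExpression* e = CastExpressionFromUninitializedTensor(t);
//   e->set_handle(some_handle);                             // write side
//   const XlaExpression* r = CastExpressionFromTensor(*t);  // read side
//   CHECK_EQ(r->handle().handle(), some_handle.handle());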

// Retrieves the ComputationDataHandle from an input Tensor to an Op. This
// computation was constructed by an Op that executed previously and
// created the output Tensor using SetOutput or SetConstantOutput.
static const xla::ComputationDataHandle& GetComputationFromTensor(
    const Tensor& tensor) {
  return CastExpressionFromTensor(tensor)->handle();
}

const xla::ComputationDataHandle& XlaOpKernelContext::Input(int index) {
  return GetComputationFromTensor(context_->input(index));
}

TensorShape XlaOpKernelContext::InputShape(int index) {
  return context_->input(index).shape();
}
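
// A minimal usage sketch inside a hypothetical binary op's Compile body (the
// kernel class and its input arity are assumptions, not part of this file):
//
//   void Compile(XlaOpKernelContext* ctx) override {
//     const TensorShape lhs_shape = ctx->InputShape(0);
//     const TensorShape rhs_shape = ctx->InputShape(1);
//     OP_REQUIRES(ctx, lhs_shape == rhs_shape,
//                 errors::InvalidArgument("Shapes must match"));
//     ctx->SetOutput(0, ctx->builder()->Add(ctx->Input(0), ctx->Input(1)));
//   }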

Status XlaOpKernelContext::ConstantInput(int index,
                                         xla::Literal* constant_literal) {
  return ConstantInputReshaped(
      index, context_->input(index).shape().dim_sizes(), constant_literal);
}

Status XlaOpKernelContext::ConstantInputReshaped(
    int index, gtl::ArraySlice<int64> new_dims,
    xla::Literal* constant_literal) {
  const Tensor& tensor = context_->input(index);
  TensorShape new_shape(new_dims);
  if (tensor.NumElements() != new_shape.num_elements()) {
    return errors::InvalidArgument(
        context_->op_kernel().name(), " input ", index, " has shape ",
        tensor.shape().DebugString(),
        " but was asked to be reshaped to incompatible shape ",
        new_shape.DebugString());
  }
  const XlaExpression* expression = CastExpressionFromTensor(tensor);

  // If the tensor has a known constant value, there is no need to invoke XLA.
  if (expression->has_constant_value()) {
    Tensor temp(tensor.dtype());
    if (!temp.CopyFrom(expression->constant_value(), new_shape)) {
      // This should never happen. The constant should have a shape compatible
      // with the enclosing Tensor.
      return errors::Internal("Incompatible shapes in ConstantInputReshaped.");
    }
    return HostTensorToLiteral(temp, constant_literal);
  }

  // Make sure we treat zero-element tensors as constant.
  if (new_shape.num_elements() == 0) {
    Tensor temp(tensor.dtype(), new_shape);
    return HostTensorToLiteral(temp, constant_literal);
  }

  xla::ComputationDataHandle handle = expression->handle();
  if (new_shape != tensor.shape()) {
    // Reshape the handle to the desired shape.
    handle = builder()->Reshape(handle, new_shape.dim_sizes());
  }

  // The XLA layout is specified minor to major, and TensorFlow's minor
  // dimension is the last one.
  std::vector<int64> layout_indices(new_shape.dims());
  std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
  xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);

  xla::StatusOr<bool> is_constant = builder()->IsConstant(handle);
  if (!is_constant.ok()) {
    Status status = is_constant.status();
    errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
                            context_->op_kernel().type_string(),
                            " operator as a compile-time constant.");
    return status;
  }

  if (!is_constant.ValueOrDie()) {
    return errors::InvalidArgument(
        "Input ", index, " to ", context_->op_kernel().type_string(),
        " operator must be a compile-time constant.\n"
        "\n"
        "XLA compilation requires that operator arguments that represent "
        "shapes or dimensions be evaluated to concrete values at compile time. "
        "This error means that a shape or dimension argument could not be "
        "evaluated at compile time, usually because the value of the argument "
        "depends on a parameter to the computation, on a variable, or on a "
        "stateful operation such as a random number generator.");
  }

  // Ask the XLA compiler to evaluate the data handle to a literal.
  xla::StatusOr<std::unique_ptr<xla::Literal>> computed =
      builder()->ComputeConstant(handle, &layout);
  if (!computed.ok()) {
    return errors::Internal("Error evaluating ", context_->op_kernel().name(),
                            " input ", index,
                            " as a compile-time constant.\nError: ",
                            computed.status().error_message());
  }
  *constant_literal = std::move(*computed.ValueOrDie());

  return Status::OK();
}
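
// A hypothetical call site: flattening a constant operand to 1D before
// inspecting its elements (the input index is an assumption):
//
//   xla::Literal flat;
//   OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(
//                           1, {ctx->InputShape(1).num_elements()}, &flat));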

// Converts an int32 or int64 scalar literal to an int64.
static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) {
  if (xla::ShapeUtil::Rank(literal.shape()) != 0) {
    return errors::InvalidArgument("value is not a scalar");
  }
  if (literal.shape().element_type() == xla::S32) {
    *out = literal.Get<int32>({});
  } else if (literal.shape().element_type() == xla::S64) {
    *out = literal.Get<int64>({});
  } else {
    return errors::InvalidArgument("value must be either int32 or int64");
  }
  return Status::OK();
}

// Converts a float32 or float64 scalar literal to a float64.
static Status LiteralToFloat64Scalar(const xla::Literal& literal, double* out) {
  if (xla::ShapeUtil::Rank(literal.shape()) != 0) {
    return errors::InvalidArgument("value is not a scalar");
  }
  if (literal.shape().element_type() == xla::F32) {
    *out = literal.Get<float>({});
  } else if (literal.shape().element_type() == xla::F64) {
    *out = literal.Get<double>({});
  } else {
    return errors::InvalidArgument("value must be either float32 or float64");
  }
  return Status::OK();
}

Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
  return LiteralToInt64Scalar(literal, out);
}

Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
  return LiteralToFloat64Scalar(literal, out);
}
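
// Typical hypothetical call sites in a kernel's Compile body, e.g. reading a
// scalar axis or scale operand at compile time (the indices are assumptions):
//
//   int64 axis;
//   OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &axis));
//   double scale;
//   OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(2, &scale));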

// Converts an int32 or int64 1D literal to an int64 vector.
static Status LiteralToInt64Vector(const xla::Literal& literal,
                                   std::vector<int64>* out) {
  if (xla::ShapeUtil::Rank(literal.shape()) != 1) {
    return errors::InvalidArgument("value is not 1D");
  }
  int64 size = xla::ShapeUtil::ElementsIn(literal.shape());
  if (literal.shape().element_type() == xla::S32) {
    for (int64 i = 0; i < size; ++i) {
      out->push_back(literal.Get<int32>({i}));
    }
  } else if (literal.shape().element_type() == xla::S64) {
    for (int64 i = 0; i < size; ++i) {
      out->push_back(literal.Get<int64>({i}));
    }
  } else {
    return errors::InvalidArgument("value must be either int32 or int64");
  }
  return Status::OK();
}

Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
                                                    std::vector<int64>* out) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
  return LiteralToInt64Vector(literal, out);
}

Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
                                                       xla::Literal* out) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
  switch (literal.shape().element_type()) {
    case xla::S32: {
      *out = xla::Literal(
          xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64));
      auto src_data = literal.data<int32>();
      for (int64 i = 0; i < src_data.size(); ++i) {
        out->data<int64>()[i] = src_data[i];
      }
      return Status::OK();
    }
    case xla::S64:
      *out = std::move(literal);
      return Status::OK();

    default:
      return errors::InvalidArgument(
          "Invalid argument to ConstantInputAsInt64Literal: ",
          xla::ShapeUtil::HumanString(literal.shape()));
  }
}

// TODO(phawkins): validate that the dimensions form a valid shape, fail
// gracefully if they do not.
Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
  std::vector<int64> dims;
  TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
  *shape = TensorShape(dims);
  return Status::OK();
}
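
// A hypothetical Reshape-style kernel would combine this with SetOutput to
// materialize a compile-time target shape (the input indices are assumptions):
//
//   TensorShape target;
//   OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &target));
//   ctx->SetOutput(
//       0, ctx->builder()->Reshape(ctx->Input(0), target.dim_sizes()));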

Status XlaOpKernelContext::InputList(
    StringPiece name, std::vector<xla::ComputationDataHandle>* handles,
    std::vector<TensorShape>* shapes) {
  OpInputList inputs;
  TF_RETURN_IF_ERROR(context_->input_list(name, &inputs));
  handles->clear();
  shapes->clear();
  for (const Tensor& input : inputs) {
    handles->push_back(GetComputationFromTensor(input));
    shapes->push_back(input.shape());
  }
  return Status::OK();
}

Status XlaOpKernelContext::ConstantInputList(
    StringPiece name, std::vector<xla::Literal>* outputs) {
  int start, stop;
  TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
  outputs->resize(stop - start);
  for (int i = start; i < stop; ++i) {
    // Index `outputs` relative to `start`: it holds only `stop - start`
    // elements, so indexing with `i` directly would run past the end
    // whenever the input range does not begin at 0.
    TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i - start]));
  }
  return Status::OK();
}
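
// Hypothetical usage, assuming an op with a list input named "values":
//
//   std::vector<xla::Literal> literals;
//   OP_REQUIRES_OK(ctx, ctx->ConstantInputList("values", &literals));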

Status XlaOpKernelContext::ReadVariableInput(
    int index, DataType type, TensorShape* shape,
    xla::ComputationDataHandle* value) {
  const Tensor& tensor = context_->input(index);
  const XlaExpression* expression = CastExpressionFromTensor(tensor);
  XlaResource* variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
  if (!variable->initialized()) {
    return errors::InvalidArgument("Read of uninitialized variable ",
                                   variable->name());
  }
  if (variable->type() != type) {
    return errors::InvalidArgument(
        "Type mismatch for read of variable ", variable->name(), ". Expected ",
        DataTypeString(type), "; got ", DataTypeString(variable->type()));
  }
  if (shape) {
    *shape = variable->shape();
  }

  XlaContext& xla_context = XlaContext::Get(context_);
  TensorShape representation_shape = xla_context.VariableRepresentationShape(
      variable->shape(), variable->type());
  if (representation_shape == variable->shape()) {
    *value = variable->value();
  } else {
    *value =
        builder()->Reshape(variable->value(), variable->shape().dim_sizes());
  }
  return Status::OK();
}
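
// A hedged sketch of the read-modify-write pattern this pairs with
// AssignVariable below (variable at input 0, delta at input 1, and DT_FLOAT
// are all assumptions):
//
//   TensorShape var_shape;
//   xla::ComputationDataHandle value;
//   OP_REQUIRES_OK(ctx,
//                  ctx->ReadVariableInput(0, DT_FLOAT, &var_shape, &value));
//   value = ctx->builder()->Add(value, ctx->Input(1));
//   OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, DT_FLOAT, value));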

Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
                                                   TensorShape* shape) const {
  const Tensor& tensor = context_->input(index);
  const XlaExpression* expression = CastExpressionFromTensor(tensor);
  XlaResource* variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
  if (!variable->initialized()) {
    return errors::InvalidArgument("Read of uninitialized variable ",
                                   variable->name());
  }
  *type = variable->type();
  *shape = variable->shape();
  return Status::OK();
}

void XlaOpKernelContext::SetOutput(int index,
                                   const xla::ComputationDataHandle& handle) {
  // Makes the host Tensor that will refer to the expression.
  Tensor* output = nullptr;
  auto shape = builder()->GetShape(handle);
  if (!shape.ok()) {
    SetStatus(shape.status());
    return;
  }

  // The step's default allocator is the dummy XlaCompilationAllocator which
  // simply allocates a metadata buffer to hold the expression to which it
  // corresponds.
  TensorShape tensor_shape;
  OP_REQUIRES_OK(context_,
                 XLAShapeToTensorShape(*shape.ValueOrDie(), &tensor_shape));
  OP_REQUIRES_OK(context_,
                 context_->allocate_output(index, tensor_shape, &output));

  // The expression is stored in the tensor's data buffer. Fill in the
  // fields now.
  XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
  expression->set_handle(handle);
}

void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
  const TensorShape& shape = constant.shape();

  xla::Literal literal;
  OP_REQUIRES_OK(context_, HostTensorToLiteral(constant, &literal));
  xla::ComputationDataHandle handle = builder()->ConstantLiteral(literal);
  CHECK_NE(handle.handle(), 0);

  // Make the Tensor that will refer to the expression.
  Tensor* output = nullptr;
  // The step's default allocator is the dummy XlaCompilationAllocator which
  // simply allocates a metadata buffer to hold the expression to which it
  // corresponds.
  OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output));

  // The expression is stored in the tensor's data buffer. Fill in the
  // fields now.
  XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
  expression->set_handle(handle);
  expression->set_constant_value(constant);
}
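
// Hypothetical use: an op whose result is fully known at compile time can
// emit it via SetConstantOutput so that downstream ops may consume it as a
// compile-time constant (the values here are made up):
//
//   Tensor result(DT_INT32, TensorShape({2}));
//   result.vec<int32>()(0) = 7;
//   result.vec<int32>()(1) = 42;
//   ctx->SetConstantOutput(0, result);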

void XlaOpKernelContext::SetInvalidOutput(int index) {
  Tensor* output = nullptr;
  OP_REQUIRES_OK(context_,
                 context_->allocate_output(index, TensorShape({}), &output));
  XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
  xla::ComputationDataHandle handle;
  handle.set_handle(0);
  expression->set_handle(handle);
}

void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
  Tensor* output = nullptr;
  // The shape of the output tensor is the shape of the resource itself
  // (i.e., a scalar), not the shape of the resource's value.
  OP_REQUIRES_OK(context_,
                 context_->allocate_output(index, TensorShape(), &output));
  XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
  expression->set_resource(resource);
}

Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
  const XlaExpression* expression =
      CastExpressionFromTensor(context_->input(index));
  TF_RET_CHECK(expression->resource() != nullptr);
  *resource = expression->resource();
  return Status::OK();
}

Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
                                          xla::ComputationDataHandle handle) {
  TF_RET_CHECK(handle.handle() != 0);

  const XlaExpression* expression =
      CastExpressionFromTensor(context_->input(input_index));
  XlaResource* variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);

  auto shape_or_status = builder()->GetShape(handle);
  if (!shape_or_status.ok()) {
    return shape_or_status.status();
  }
  TensorShape shape;
  TF_RETURN_IF_ERROR(
      XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));

  TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));

  XlaContext& xla_context = XlaContext::Get(context_);
  TensorShape representation_shape =
      xla_context.VariableRepresentationShape(shape, type);
  if (shape != representation_shape) {
    handle = builder()->Reshape(handle, representation_shape.dim_sizes());
  }
  return variable->SetValue(handle);
}

XlaCompiler* XlaOpKernelContext::compiler() const {
  return XlaContext::Get(context_).compiler();
}

void XlaOpKernelContext::CtxFailure(const Status& s) {
  context_->CtxFailure(s);
}
void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) {
  context_->CtxFailureWithWarning(s);
}
void XlaOpKernelContext::CtxFailure(const char* file, int line,
                                    const Status& s) {
  context_->CtxFailure(file, line, s);
}
void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line,
                                               const Status& s) {
  context_->CtxFailureWithWarning(file, line, s);
}

const xla::Computation* XlaOpKernelContext::GetOrCreateMax(
    const DataType type) {
  return XlaContext::Get(context_).GetOrCreateMax(type);
}

const xla::Computation* XlaOpKernelContext::GetOrCreateMin(
    const DataType type) {
  return XlaContext::Get(context_).GetOrCreateMin(type);
}

const xla::Computation* XlaOpKernelContext::GetOrCreateAdd(
    const DataType type) {
  return XlaContext::Get(context_).GetOrCreateAdd(type);
}

const xla::Computation* XlaOpKernelContext::GetOrCreateMul(
    const DataType type) {
  return XlaContext::Get(context_).GetOrCreateMul(type);
}
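
// These cached computations are the reduction functions that
// xla::ComputationBuilder::Reduce expects. A hedged sketch of a sum-reduction
// over dimension 0 (XlaHelpers::Zero comes from xla_helpers.h, which this
// file does not include; its use here is an assumption):
//
//   const xla::Computation* add = ctx->GetOrCreateAdd(ctx->input_type(0));
//   xla::ComputationDataHandle sum = ctx->builder()->Reduce(
//       ctx->Input(0), XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)),
//       *add, /*dimensions_to_reduce=*/{0});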

XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}

void XlaOpKernel::Compute(OpKernelContext* context) {
  XlaOpKernelContext xla_context(context);
  Compile(&xla_context);
}
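
// A minimal end-to-end sketch of defining and registering an XlaOpKernel
// subclass (the names are illustrative; REGISTER_XLA_OP is declared in
// xla_op_registry.h, which this file does not include):
//
//   class NegOp : public XlaOpKernel {
//    public:
//     explicit NegOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
//     void Compile(XlaOpKernelContext* ctx) override {
//       ctx->SetOutput(0, ctx->builder()->Neg(ctx->Input(0)));
//     }
//   };
//   REGISTER_XLA_OP(Name("Neg"), NegOp);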

}  // namespace tensorflow