1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 17 18 #include <numeric> 19 20 #include "tensorflow/compiler/tf2xla/literal_util.h" 21 #include "tensorflow/compiler/tf2xla/shape_util.h" 22 #include "tensorflow/compiler/tf2xla/xla_context.h" 23 24 namespace tensorflow { 25 26 XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context) 27 : context_(context) {} 28 29 bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { 30 return context_->ValidateInputsAreSameShape(op); 31 } 32 33 xla::ComputationBuilder* XlaOpKernelContext::builder() const { 34 return XlaContext::Get(this).builder(); 35 } 36 37 // Retrieves an XlaExpression that was allocated by a previous Op. 38 static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { 39 const XlaExpression* expression = 40 reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data()); 41 CHECK(expression->handle().handle() != 0 || 42 expression->resource() != nullptr); 43 VLOG(1) << "Fetched T" << expression->handle().handle(); 44 return expression; 45 } 46 47 // Retrieves an uninitialized XlaExpression from a newly-allocated tensor. 48 static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) { 49 const XlaExpression* expression = 50 reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data()); 51 CHECK_EQ(expression->handle().handle(), 0); 52 return const_cast<XlaExpression*>(expression); 53 } 54 55 // Retrieves the ComputationDataHandle from an input Tensor to an Op. This 56 // computation was constructed by an Op that executed previously and 57 // created the output Tensor using CreateOutputTensorFromComputation 58 // or CreateConstantOutputTensor. 59 static const xla::ComputationDataHandle& GetComputationFromTensor( 60 const Tensor& tensor) { 61 return CastExpressionFromTensor(tensor)->handle(); 62 } 63 64 const xla::ComputationDataHandle& XlaOpKernelContext::Input(int index) { 65 return GetComputationFromTensor(context_->input(index)); 66 } 67 68 TensorShape XlaOpKernelContext::InputShape(int index) { 69 return context_->input(index).shape(); 70 } 71 72 Status XlaOpKernelContext::ConstantInput(int index, 73 xla::Literal* constant_literal) { 74 return ConstantInputReshaped( 75 index, context_->input(index).shape().dim_sizes(), constant_literal); 76 } 77 78 Status XlaOpKernelContext::ConstantInputReshaped( 79 int index, gtl::ArraySlice<int64> new_dims, 80 xla::Literal* constant_literal) { 81 const Tensor& tensor = context_->input(index); 82 TensorShape new_shape(new_dims); 83 if (tensor.NumElements() != new_shape.num_elements()) { 84 return errors::InvalidArgument( 85 context_->op_kernel().name(), " input ", index, " has shape ", 86 tensor.shape().DebugString(), 87 " but was asked to be reshaped to incompatible shape ", 88 new_shape.DebugString()); 89 } 90 const XlaExpression* expression = CastExpressionFromTensor(tensor); 91 92 // If the tensor has a known constant value, there is no need to invoke XLA. 93 if (expression->has_constant_value()) { 94 Tensor temp(tensor.dtype()); 95 if (!temp.CopyFrom(expression->constant_value(), new_shape)) { 96 // This should never happen. The constant should have a shape compatible 97 // with the enclosing Tensor. 98 return errors::Internal("Incompatible shapes in ConstantInputReshaped."); 99 } 100 return HostTensorToLiteral(temp, constant_literal); 101 } 102 103 // Make sure we treat zero-element tensors as constant. 104 if (new_shape.num_elements() == 0) { 105 Tensor temp(tensor.dtype(), new_shape); 106 return HostTensorToLiteral(temp, constant_literal); 107 } 108 109 xla::ComputationDataHandle handle = expression->handle(); 110 if (new_shape != tensor.shape()) { 111 // Reshape the handle to the desired shape. 112 handle = builder()->Reshape(handle, new_shape.dim_sizes()); 113 } 114 115 // The XLA layout is specified minor to major, and TensorFlow's minor 116 // dimension is the last one. 117 std::vector<int64> layout_indices(new_shape.dims()); 118 std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); 119 xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); 120 121 xla::StatusOr<bool> is_constant = builder()->IsConstant(handle); 122 if (!is_constant.ok()) { 123 Status status = is_constant.status(); 124 errors::AppendToMessage(&status, "while evaluating input ", index, " of ", 125 context_->op_kernel().type_string(), 126 " operator as a compile-time constant."); 127 return status; 128 } 129 130 if (!is_constant.ValueOrDie()) { 131 return errors::InvalidArgument( 132 "Input ", index, " to ", context_->op_kernel().type_string(), 133 " operator must be a compile-time constant.\n" 134 "\n" 135 "XLA compilation requires that operator arguments that represent " 136 "shapes or dimensions be evaluated to concrete values at compile time. " 137 "This error means that a shape or dimension argument could not be " 138 "evaluated at compile time, usually because the value of the argument " 139 "depends on a parameter to the computation, on a variable, or on a " 140 "stateful operation such as a random number generator."); 141 } 142 143 // Ask the XLA compiler to evaluate the data handle to a literal. 144 xla::StatusOr<std::unique_ptr<xla::Literal>> computed = 145 builder()->ComputeConstant(handle, &layout); 146 if (!computed.ok()) { 147 return errors::Internal("Error evaluating ", context_->op_kernel().name(), 148 " input ", index, 149 "as a compile-time constant.\nError: ", 150 computed.status().error_message()); 151 } 152 *constant_literal = std::move(*computed.ValueOrDie()); 153 154 return Status::OK(); 155 } 156 157 // Converts an int32 or int64 scalar literal to an int64. 158 static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) { 159 if (xla::ShapeUtil::Rank(literal.shape()) != 0) { 160 return errors::InvalidArgument("value is not a scalar"); 161 } 162 if (literal.shape().element_type() == xla::S32) { 163 *out = literal.Get<int32>({}); 164 } else if (literal.shape().element_type() == xla::S64) { 165 *out = literal.Get<int64>({}); 166 } else { 167 return errors::InvalidArgument("value must be either int32 or int64"); 168 } 169 return Status::OK(); 170 } 171 172 // Converts an float32 or float64 scalar literal to a float64. 173 static Status LiteralToFloat64Scalar(const xla::Literal& literal, double* out) { 174 if (xla::ShapeUtil::Rank(literal.shape()) != 0) { 175 return errors::InvalidArgument("value is not a scalar"); 176 } 177 if (literal.shape().element_type() == xla::F32) { 178 *out = literal.Get<float>({}); 179 } else if (literal.shape().element_type() == xla::F64) { 180 *out = literal.Get<double>({}); 181 } else { 182 return errors::InvalidArgument("value must be either float32 or float64"); 183 } 184 return Status::OK(); 185 } 186 187 Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) { 188 xla::Literal literal; 189 TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); 190 return LiteralToInt64Scalar(literal, out); 191 } 192 193 Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { 194 xla::Literal literal; 195 TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); 196 return LiteralToFloat64Scalar(literal, out); 197 } 198 199 // Converts an int32 or int64 1D literal to an int64 vector. 200 static Status LiteralToInt64Vector(const xla::Literal& literal, 201 std::vector<int64>* out) { 202 if (xla::ShapeUtil::Rank(literal.shape()) != 1) { 203 return errors::InvalidArgument("value is not 1D"); 204 } 205 int64 size = xla::ShapeUtil::ElementsIn(literal.shape()); 206 if (literal.shape().element_type() == xla::S32) { 207 for (int64 i = 0; i < size; ++i) { 208 out->push_back(literal.Get<int32>({i})); 209 } 210 } else if (literal.shape().element_type() == xla::S64) { 211 for (int64 i = 0; i < size; ++i) { 212 out->push_back(literal.Get<int64>({i})); 213 } 214 } else { 215 return errors::InvalidArgument("value must be either int32 or int64"); 216 } 217 return Status::OK(); 218 } 219 220 Status XlaOpKernelContext::ConstantInputAsIntVector(int index, 221 std::vector<int64>* out) { 222 xla::Literal literal; 223 TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); 224 return LiteralToInt64Vector(literal, out); 225 } 226 227 Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, 228 xla::Literal* out) { 229 xla::Literal literal; 230 TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); 231 switch (literal.shape().element_type()) { 232 case xla::S32: { 233 *out = xla::Literal( 234 xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64)); 235 auto src_data = literal.data<int32>(); 236 for (int64 i = 0; i < src_data.size(); ++i) { 237 out->data<int64>()[i] = src_data[i]; 238 } 239 return Status::OK(); 240 } 241 case xla::S64: 242 *out = std::move(literal); 243 return Status::OK(); 244 245 default: 246 return errors::InvalidArgument( 247 "Invalid argument to ConstantInputAsInt64Literal: ", 248 xla::ShapeUtil::HumanString(literal.shape())); 249 } 250 } 251 252 // TODO(phawkins): validate that the dimensions form a valid shape, fail 253 // gracefully if they do not. 254 Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { 255 xla::Literal literal; 256 TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); 257 std::vector<int64> dims; 258 TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims)); 259 *shape = TensorShape(dims); 260 return Status::OK(); 261 } 262 263 Status XlaOpKernelContext::InputList( 264 StringPiece name, std::vector<xla::ComputationDataHandle>* handles, 265 std::vector<TensorShape>* shapes) { 266 OpInputList inputs; 267 TF_RETURN_IF_ERROR(context_->input_list(name, &inputs)); 268 handles->clear(); 269 shapes->clear(); 270 for (const Tensor& input : inputs) { 271 handles->push_back(GetComputationFromTensor(input)); 272 shapes->push_back(input.shape()); 273 } 274 return Status::OK(); 275 } 276 277 Status XlaOpKernelContext::ConstantInputList( 278 StringPiece name, std::vector<xla::Literal>* outputs) { 279 int start, stop; 280 TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop)); 281 outputs->resize(stop - start); 282 for (int i = start; i < stop; ++i) { 283 TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i])); 284 } 285 return Status::OK(); 286 } 287 288 Status XlaOpKernelContext::ReadVariableInput( 289 int index, DataType type, TensorShape* shape, 290 xla::ComputationDataHandle* value) { 291 const Tensor& tensor = context_->input(index); 292 const XlaExpression* expression = CastExpressionFromTensor(tensor); 293 XlaResource* variable = expression->resource(); 294 TF_RET_CHECK(variable != nullptr); 295 TF_RET_CHECK(variable->kind() == XlaResource::kVariable); 296 if (!variable->initialized()) { 297 return errors::InvalidArgument("Read of uninitialized variable ", 298 variable->name()); 299 } 300 if (variable->type() != type) { 301 return errors::InvalidArgument( 302 "Type mismatch for read of variable ", variable->name(), ". Expected ", 303 DataTypeString(type), "; got ", DataTypeString(variable->type())); 304 } 305 if (shape) { 306 *shape = variable->shape(); 307 } 308 309 XlaContext& xla_context = XlaContext::Get(context_); 310 TensorShape representation_shape = xla_context.VariableRepresentationShape( 311 variable->shape(), variable->type()); 312 if (representation_shape == variable->shape()) { 313 *value = variable->value(); 314 } else { 315 *value = 316 builder()->Reshape(variable->value(), variable->shape().dim_sizes()); 317 } 318 return Status::OK(); 319 } 320 321 Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, 322 TensorShape* shape) const { 323 const Tensor& tensor = context_->input(index); 324 const XlaExpression* expression = CastExpressionFromTensor(tensor); 325 XlaResource* variable = expression->resource(); 326 TF_RET_CHECK(variable != nullptr); 327 TF_RET_CHECK(variable->kind() == XlaResource::kVariable); 328 if (!variable->initialized()) { 329 return errors::InvalidArgument("Read of uninitialized variable ", 330 variable->name()); 331 } 332 *type = variable->type(); 333 *shape = variable->shape(); 334 return Status::OK(); 335 } 336 337 void XlaOpKernelContext::SetOutput(int index, 338 const xla::ComputationDataHandle& handle) { 339 // Makes the host Tensor that will refer to the expression. 340 Tensor* output = nullptr; 341 auto shape = builder()->GetShape(handle); 342 if (!shape.ok()) { 343 SetStatus(shape.status()); 344 return; 345 } 346 347 // The step's default allocator is the dummy XlaCompilationAllocator which 348 // simply allocates a metadata buffer to hold the expression to which it 349 // corresponds. 350 TensorShape tensor_shape; 351 OP_REQUIRES_OK(context_, 352 XLAShapeToTensorShape(*shape.ValueOrDie(), &tensor_shape)); 353 OP_REQUIRES_OK(context_, 354 context_->allocate_output(index, tensor_shape, &output)); 355 356 // The expression is stored in the tensor's data buffer. Fill in the 357 // fields now. 358 XlaExpression* expression = CastExpressionFromUninitializedTensor(output); 359 expression->set_handle(handle); 360 } 361 362 void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { 363 const TensorShape& shape = constant.shape(); 364 365 xla::Literal literal; 366 OP_REQUIRES_OK(context_, HostTensorToLiteral(constant, &literal)); 367 xla::ComputationDataHandle handle = builder()->ConstantLiteral(literal); 368 CHECK_NE(handle.handle(), 0); 369 370 // Make the Tensor that will refer to the expression. 371 Tensor* output = nullptr; 372 // The step's default allocator is the dummy XlaCompilationAllocator which 373 // simply allocates a metadata buffer to hold the expression to which it 374 // corresponds. 375 OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output)); 376 377 // The expression is stored in the tensor's data buffer. Fill in the 378 // fields now. 379 XlaExpression* expression = CastExpressionFromUninitializedTensor(output); 380 expression->set_handle(handle); 381 expression->set_constant_value(constant); 382 } 383 384 void XlaOpKernelContext::SetInvalidOutput(int index) { 385 Tensor* output = nullptr; 386 OP_REQUIRES_OK(context_, 387 context_->allocate_output(index, TensorShape({}), &output)); 388 XlaExpression* expression = CastExpressionFromUninitializedTensor(output); 389 xla::ComputationDataHandle handle; 390 handle.set_handle(0); 391 expression->set_handle(handle); 392 } 393 394 void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { 395 Tensor* output = nullptr; 396 // The shape of the output tensor is the shape of the resource itself 397 // (i.e., a scalar), not the shape of the resource's value. 398 OP_REQUIRES_OK(context_, 399 context_->allocate_output(index, TensorShape(), &output)); 400 XlaExpression* expression = CastExpressionFromUninitializedTensor(output); 401 expression->set_resource(resource); 402 } 403 404 Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { 405 const XlaExpression* expression = 406 CastExpressionFromTensor(context_->input(index)); 407 TF_RET_CHECK(expression->resource() != nullptr); 408 *resource = expression->resource(); 409 return Status::OK(); 410 } 411 412 Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, 413 xla::ComputationDataHandle handle) { 414 TF_RET_CHECK(handle.handle() != 0); 415 416 const XlaExpression* expression = 417 CastExpressionFromTensor(context_->input(input_index)); 418 XlaResource* variable = expression->resource(); 419 TF_RET_CHECK(variable != nullptr); 420 TF_RET_CHECK(variable->kind() == XlaResource::kVariable); 421 422 auto shape_or_status = builder()->GetShape(handle); 423 if (!shape_or_status.ok()) { 424 return shape_or_status.status(); 425 } 426 TensorShape shape; 427 TF_RETURN_IF_ERROR( 428 XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); 429 430 TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); 431 432 XlaContext& xla_context = XlaContext::Get(context_); 433 TensorShape representation_shape = 434 xla_context.VariableRepresentationShape(shape, type); 435 if (shape != representation_shape) { 436 handle = builder()->Reshape(handle, representation_shape.dim_sizes()); 437 } 438 return variable->SetValue(handle); 439 } 440 441 XlaCompiler* XlaOpKernelContext::compiler() const { 442 return XlaContext::Get(context_).compiler(); 443 } 444 445 void XlaOpKernelContext::CtxFailure(const Status& s) { 446 context_->CtxFailure(s); 447 } 448 void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) { 449 context_->CtxFailureWithWarning(s); 450 } 451 void XlaOpKernelContext::CtxFailure(const char* file, int line, 452 const Status& s) { 453 context_->CtxFailure(file, line, s); 454 } 455 void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line, 456 const Status& s) { 457 context_->CtxFailureWithWarning(file, line, s); 458 } 459 460 const xla::Computation* XlaOpKernelContext::GetOrCreateMax( 461 const DataType type) { 462 return XlaContext::Get(context_).GetOrCreateMax(type); 463 } 464 465 const xla::Computation* XlaOpKernelContext::GetOrCreateMin( 466 const DataType type) { 467 return XlaContext::Get(context_).GetOrCreateMin(type); 468 } 469 470 const xla::Computation* XlaOpKernelContext::GetOrCreateAdd( 471 const DataType type) { 472 return XlaContext::Get(context_).GetOrCreateAdd(type); 473 } 474 475 const xla::Computation* XlaOpKernelContext::GetOrCreateMul( 476 const DataType type) { 477 return XlaContext::Get(context_).GetOrCreateMul(type); 478 } 479 480 XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {} 481 482 void XlaOpKernel::Compute(OpKernelContext* context) { 483 XlaOpKernelContext xla_context(context); 484 Compile(&xla_context); 485 } 486 487 } // namespace tensorflow 488