1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 17 18 #include <numeric> 19 20 #include "tensorflow/compiler/tf2xla/literal_util.h" 21 #include "tensorflow/compiler/tf2xla/shape_util.h" 22 #include "tensorflow/compiler/tf2xla/type_util.h" 23 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" 24 #include "tensorflow/compiler/tf2xla/xla_context.h" 25 #include "tensorflow/compiler/xla/client/xla_builder.h" 26 #include "tensorflow/compiler/xla/client/xla_computation.h" 27 #include "tensorflow/compiler/xla/status_macros.h" 28 #include "tensorflow/core/common_runtime/dma_helper.h" 29 30 namespace tensorflow { 31 32 XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context) 33 : context_(context) {} 34 35 bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { 36 return context_->ValidateInputsAreSameShape(op); 37 } 38 39 XlaContext* XlaOpKernelContext::xla_context() const { 40 return &XlaContext::Get(context_); 41 } 42 43 xla::XlaBuilder* XlaOpKernelContext::builder() const { 44 return xla_context()->builder(); 45 } 46 47 XlaCompiler* XlaOpKernelContext::compiler() const { 48 return xla_context()->compiler(); 49 } 50 51 // Retrieves an XlaExpression that was allocated by a previous Op. 52 static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { 53 const XlaExpression* expression = 54 reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data()); 55 CHECK(expression->kind() != XlaExpression::Kind::kInvalid) 56 << expression->HumanString(); 57 return expression; 58 } 59 60 // Assigns an XlaExpression to a tensor on an XLA compilation device. 61 static void AssignExpressionToTensor(Tensor* tensor, 62 const XlaExpression& value) { 63 const XlaExpression* expression = 64 reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data()); 65 CHECK(expression->kind() == XlaExpression::Kind::kInvalid) 66 << expression->HumanString(); 67 *const_cast<XlaExpression*>(expression) = value; 68 } 69 70 const XlaExpression& XlaOpKernelContext::InputExpression(int index) { 71 return *CastExpressionFromTensor(context_->input(index)); 72 } 73 74 const XlaExpression& XlaOpKernelContext::InputExpression( 75 absl::string_view name) { 76 return *CastExpressionFromTensor(GetInputTensorByName(name)); 77 } 78 79 xla::XlaOp XlaOpKernelContext::Input(int index) { 80 return InputExpression(index).AsXlaOp(builder()); 81 } 82 83 xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) { 84 return InputExpression(name).AsXlaOp(builder()); 85 } 86 87 TensorShape XlaOpKernelContext::InputShape(int index) { 88 return context_->input(index).shape(); 89 } 90 91 TensorShape XlaOpKernelContext::InputShape(absl::string_view name) { 92 return GetInputTensorByName(name).shape(); 93 } 94 95 DataType XlaOpKernelContext::input_type(int index) const { 96 return context_->input_dtype(index); 97 } 98 99 DataType XlaOpKernelContext::InputType(absl::string_view name) { 100 return GetInputTensorByName(name).dtype(); 101 } 102 103 xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) { 104 xla::PrimitiveType type; 105 Status status = DataTypeToPrimitiveType(input_type(index), &type); 106 if (!status.ok()) { 107 SetStatus(status); 108 return xla::PRIMITIVE_TYPE_INVALID; 109 } 110 return type; 111 } 112 113 Status XlaOpKernelContext::ConstantInput(int index, 114 xla::Literal* constant_literal) { 115 return ConstantInputReshaped( 116 index, context_->input(index).shape().dim_sizes(), constant_literal); 117 } 118 119 static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context, 120 absl::string_view name) { 121 int start, stop; 122 TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop)); 123 if (stop != start + 1) { 124 return errors::InvalidArgument("OpKernel used list-valued input name '", 125 name, 126 "' when single-valued input was " 127 "expected"); 128 } 129 return start; 130 } 131 132 Status XlaOpKernelContext::ConstantInput(absl::string_view name, 133 xla::Literal* constant_literal) { 134 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); 135 return ConstantInput(index, constant_literal); 136 } 137 138 Status XlaOpKernelContext::ConstantInputReshaped( 139 int index, absl::Span<const int64> new_dims, 140 xla::Literal* constant_literal) { 141 XlaExpression e = InputExpression(index); 142 xla::StatusOr<absl::optional<Tensor>> constant_or_status = 143 e.ResolveConstant(compiler()->client()); 144 if (!constant_or_status.ok()) { 145 Status status = constant_or_status.status(); 146 errors::AppendToMessage(&status, "while evaluating input ", index, " of ", 147 context_->op_kernel().type_string(), 148 " operator as a compile-time constant."); 149 return status; 150 } 151 absl::optional<Tensor> constant = constant_or_status.ValueOrDie(); 152 if (!constant.has_value()) { 153 return errors::InvalidArgument( 154 "Input ", index, " to ", context_->op_kernel().type_string(), 155 " operator must be a compile-time constant.\n" 156 "\n" 157 "XLA compilation requires that operator arguments that represent " 158 "shapes or dimensions be evaluated to concrete values at compile time. " 159 "This error means that a shape or dimension argument could not be " 160 "evaluated at compile time, usually because the value of the argument " 161 "depends on a parameter to the computation, on a variable, or on a " 162 "stateful operation such as a random number generator."); 163 } 164 165 Tensor temp(constant->dtype()); 166 if (!temp.CopyFrom(*constant, TensorShape(new_dims))) { 167 return errors::InvalidArgument( 168 context_->op_kernel().name(), " input ", index, " has shape ", 169 constant->shape().DebugString(), 170 " but was asked to be reshaped to incompatible shape ", 171 TensorShape(new_dims).DebugString()); 172 } 173 174 TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp)); 175 return Status::OK(); 176 } 177 178 // Converts an int32 or int64 scalar literal to an int64. 179 static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, 180 int64* out) { 181 if (literal.shape().rank() != 0) { 182 return errors::InvalidArgument("value is not a scalar"); 183 } 184 if (literal.shape().element_type() == xla::S32) { 185 *out = literal.Get<int32>({}); 186 } else if (literal.shape().element_type() == xla::S64) { 187 *out = literal.Get<int64>({}); 188 } else { 189 return errors::InvalidArgument("value must be either int32 or int64"); 190 } 191 return Status::OK(); 192 } 193 194 // Converts an float32 or float64 scalar literal to a float64. 195 static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal, 196 double* out) { 197 if (literal.shape().rank() != 0) { 198 return errors::InvalidArgument("value is not a scalar"); 199 } 200 if (literal.shape().element_type() == xla::F32) { 201 *out = literal.Get<float>({}); 202 } else if (literal.shape().element_type() == xla::F64) { 203 *out = literal.Get<double>({}); 204 } else { 205 return errors::InvalidArgument("value must be either float32 or float64"); 206 } 207 return Status::OK(); 208 } 209 210 Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) { 211 xla::Literal literal; 212 TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); 213 return LiteralToInt64Scalar(literal, out); 214 } 215 216 Status XlaOpKernelContext::ConstantInputAsIntScalar(absl::string_view name, 217 int64* out) { 218 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); 219 return ConstantInputAsIntScalar(index, out); 220 } 221 222 Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { 223 xla::Literal literal; 224 TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); 225 return LiteralToFloat64Scalar(literal, out); 226 } 227 228 // Converts an int32 or int64 1D literal to an int64 vector. 229 static Status LiteralToInt64Vector(const xla::LiteralSlice& literal, 230 std::vector<int64>* out) { 231 if (literal.shape().rank() != 1) { 232 return errors::InvalidArgument("value is not 1D, rank: ", 233 literal.shape().rank()); 234 } 235 int64 size = xla::ShapeUtil::ElementsIn(literal.shape()); 236 if (literal.shape().element_type() == xla::S32) { 237 for (int64 i = 0; i < size; ++i) { 238 out->push_back(literal.Get<int32>({i})); 239 } 240 } else if (literal.shape().element_type() == xla::S64) { 241 for (int64 i = 0; i < size; ++i) { 242 out->push_back(literal.Get<int64>({i})); 243 } 244 } else { 245 return errors::InvalidArgument("value must be either int32 or int64"); 246 } 247 return Status::OK(); 248 } 249 250 Status XlaOpKernelContext::ConstantInputAsIntVector(int index, 251 std::vector<int64>* out) { 252 xla::Literal literal; 253 TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); 254 return LiteralToInt64Vector(literal, out); 255 } 256 257 Status XlaOpKernelContext::ConstantInputAsIntVector(absl::string_view name, 258 std::vector<int64>* out) { 259 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); 260 return ConstantInputAsIntVector(index, out); 261 } 262 263 Status XlaOpKernelContext::ConstantInputReshapedToIntVector( 264 int index, std::vector<int64>* out) { 265 xla::Literal literal; 266 TF_RETURN_IF_ERROR(ConstantInputReshaped( 267 index, {InputShape(index).num_elements()}, &literal)); 268 return LiteralToInt64Vector(literal, out); 269 } 270 271 Status XlaOpKernelContext::ConstantInputReshapedToIntVector( 272 absl::string_view name, std::vector<int64>* out) { 273 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); 274 xla::Literal literal; 275 TF_RETURN_IF_ERROR(ConstantInputReshaped( 276 index, {InputShape(index).num_elements()}, &literal)); 277 return LiteralToInt64Vector(literal, out); 278 } 279 280 Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, 281 xla::Literal* out) { 282 xla::Literal literal; 283 TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); 284 switch (literal.shape().element_type()) { 285 case xla::S32: { 286 *out = xla::Literal( 287 xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64)); 288 auto src_data = literal.data<int32>(); 289 for (int64 i = 0; i < src_data.size(); ++i) { 290 out->data<int64>()[i] = src_data[i]; 291 } 292 return Status::OK(); 293 } 294 case xla::S64: 295 *out = std::move(literal); 296 return Status::OK(); 297 298 default: 299 return errors::InvalidArgument( 300 "Invalid argument to ConstantInputAsInt64Literal: ", 301 xla::ShapeUtil::HumanString(literal.shape())); 302 } 303 } 304 305 Status XlaOpKernelContext::ConstantInputAsInt64Literal(absl::string_view name, 306 xla::Literal* out) { 307 TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); 308 return ConstantInputAsInt64Literal(index, out); 309 } 310 311 // TODO(phawkins): validate that the dimensions form a valid shape, fail 312 // gracefully if they do not. 313 Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { 314 xla::Literal literal; 315 TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); 316 std::vector<int64> dims; 317 TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims)); 318 *shape = TensorShape(dims); 319 return Status::OK(); 320 } 321 322 Status XlaOpKernelContext::ConstantInputAsPartialShape( 323 int index, PartialTensorShape* shape) { 324 xla::Literal literal; 325 TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); 326 // If `literal` is a scalar it's value must be -1. 327 if (literal.shape().rank() == 0) { 328 int64 shape_val; 329 TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val)); 330 if (shape_val != -1) { 331 return errors::InvalidArgument( 332 "Cannot convert value to PartialTensorShape: ", shape_val); 333 } 334 *shape = PartialTensorShape(); // Shape with unknown rank. 335 return Status::OK(); 336 } 337 std::vector<int64> dims; 338 TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims)); 339 *shape = PartialTensorShape(dims); 340 return Status::OK(); 341 } 342 343 Status XlaOpKernelContext::InputList(absl::string_view name, 344 std::vector<xla::XlaOp>* handles, 345 std::vector<TensorShape>* shapes) { 346 OpInputList inputs; 347 TF_RETURN_IF_ERROR(context_->input_list(name, &inputs)); 348 handles->clear(); 349 shapes->clear(); 350 for (const Tensor& input : inputs) { 351 handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder())); 352 shapes->push_back(input.shape()); 353 } 354 return Status::OK(); 355 } 356 357 Status XlaOpKernelContext::ConstantInputList( 358 absl::string_view name, std::vector<xla::Literal>* outputs) { 359 int start, stop; 360 TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop)); 361 outputs->resize(stop - start); 362 for (int i = start; i < stop; ++i) { 363 TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i])); 364 } 365 return Status::OK(); 366 } 367 368 namespace { 369 370 Status ReadVariableInputTensor(const Tensor& tensor, DataType type, 371 const XlaOpKernelContext* ctx, 372 TensorShape* shape, xla::XlaOp* value) { 373 const XlaExpression* expression = CastExpressionFromTensor(tensor); 374 XlaResource* variable = expression->resource(); 375 TF_RET_CHECK(variable != nullptr); 376 TF_RET_CHECK(variable->kind() == XlaResource::kVariable); 377 if (!variable->initialized()) { 378 return errors::FailedPrecondition("Read of uninitialized variable ", 379 variable->name()); 380 } 381 if (variable->type() != type) { 382 return errors::InvalidArgument( 383 "Type mismatch for read of variable ", variable->name(), ". Expected ", 384 DataTypeString(type), "; got ", DataTypeString(variable->type())); 385 } 386 if (shape) { 387 *shape = variable->shape(); 388 } 389 390 TF_ASSIGN_OR_RETURN(xla::Shape representation_shape, 391 ctx->compiler()->options().shape_representation_fn( 392 variable->shape(), variable->type())); 393 xla::Shape xla_shape; 394 TF_RETURN_IF_ERROR( 395 TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape)); 396 if (xla::ShapeUtil::Compatible(xla_shape, representation_shape)) { 397 *value = variable->value(); 398 } else { 399 *value = xla::Reshape(variable->value(), variable->shape().dim_sizes()); 400 } 401 return Status::OK(); 402 } 403 404 } // namespace 405 406 Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, 407 TensorShape* shape, 408 xla::XlaOp* value) { 409 return ReadVariableInputTensor(context_->input(index), type, this, shape, 410 value); 411 } 412 413 Status XlaOpKernelContext::ReadVariableInput(absl::string_view name, 414 DataType type, TensorShape* shape, 415 xla::XlaOp* value) { 416 return ReadVariableInputTensor(GetInputTensorByName(name), type, this, shape, 417 value); 418 } 419 420 Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, 421 TensorShape* shape) const { 422 const Tensor& tensor = context_->input(index); 423 const XlaExpression* expression = CastExpressionFromTensor(tensor); 424 XlaResource* variable = expression->resource(); 425 TF_RET_CHECK(variable != nullptr); 426 TF_RET_CHECK(variable->kind() == XlaResource::kVariable); 427 if (!variable->initialized()) { 428 return errors::InvalidArgument("Read of uninitialized variable ", 429 variable->name()); 430 } 431 *type = variable->type(); 432 *shape = variable->shape(); 433 return Status::OK(); 434 } 435 436 void XlaOpKernelContext::SetOutputExpression(int index, 437 const XlaExpression& expression) { 438 Status status = [&] { 439 // The step's default allocator is the dummy XlaCompilationAllocator which 440 // simply allocates a metadata buffer to hold the expression to which it 441 // corresponds. 442 Tensor* output = nullptr; 443 // Provides a special behavior for DT_VARIANT: a variant is treated as 444 // DT_UINT8 scalar as the type to allow mapping for variant to more generic 445 // types. 446 if (expression.dtype() == DT_VARIANT) { 447 // tensor_data() is not supported for variant Tensor (i.e., 448 // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the 449 // XlaExpression inside the Tensor's tensor_data() does not work for 450 // variant. Instead construct a uint8 tensor and store the expression in 451 // its value. 452 // TODO(jpienaar): This should be refactored to stop masquerading 453 // XlaExpressions as Tensors. 454 output = new Tensor(); 455 TensorShape tensor_shape; 456 TF_RETURN_IF_ERROR( 457 context_->allocate_temp(DT_UINT8, tensor_shape, output)); 458 context_->set_output(index, *output); 459 } else { 460 TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape()); 461 TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output)); 462 } 463 AssignExpressionToTensor(output, expression); 464 return Status::OK(); 465 }(); 466 if (!status.ok()) { 467 SetStatus(status); 468 } 469 } 470 471 xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) { 472 xla::PrimitiveType type; 473 Status status = DataTypeToPrimitiveType(expected_output_dtype(index), &type); 474 if (!status.ok()) { 475 SetStatus(status); 476 return xla::PRIMITIVE_TYPE_INVALID; 477 } 478 return type; 479 } 480 481 void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { 482 SetOutputExpression( 483 index, 484 XlaExpression::XlaOp(handle, context_->expected_output_dtype(index))); 485 } 486 487 void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { 488 SetOutputExpression(index, XlaExpression::Constant(constant)); 489 } 490 491 void XlaOpKernelContext::SetTensorListOutput(int index, 492 const xla::XlaOp& handle) { 493 SetOutputExpression(index, XlaExpression::TensorList(handle)); 494 } 495 496 void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { 497 SetOutputExpression(index, XlaExpression::Resource(resource)); 498 } 499 500 Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { 501 const XlaExpression* expression = 502 CastExpressionFromTensor(context_->input(index)); 503 TF_RET_CHECK(expression->resource() != nullptr); 504 *resource = expression->resource(); 505 return Status::OK(); 506 } 507 508 namespace { 509 510 Status AssignVariableTensor(const Tensor& tensor, DataType type, 511 const XlaOpKernelContext* ctx, xla::XlaOp handle, 512 xla::XlaBuilder* builder) { 513 const XlaExpression* expression = CastExpressionFromTensor(tensor); 514 XlaResource* variable = expression->resource(); 515 TF_RET_CHECK(variable != nullptr); 516 TF_RET_CHECK(variable->kind() == XlaResource::kVariable); 517 518 auto shape_or_status = builder->GetShape(handle); 519 if (!shape_or_status.ok()) { 520 return shape_or_status.status(); 521 } 522 TensorShape shape; 523 TF_RETURN_IF_ERROR( 524 XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape)); 525 526 TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); 527 528 TF_ASSIGN_OR_RETURN( 529 xla::Shape representation_shape, 530 ctx->compiler()->options().shape_representation_fn(shape, type)); 531 xla::Shape xla_shape; 532 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); 533 if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) { 534 handle = xla::Reshape(handle, 535 xla::AsInt64Slice(representation_shape.dimensions())); 536 } 537 variable->SetRepresentationShape(representation_shape); 538 return variable->SetValue(handle); 539 } 540 541 } // namespace 542 543 Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, 544 xla::XlaOp handle) { 545 TF_RET_CHECK(handle.valid()); 546 return AssignVariableTensor(context_->input(input_index), type, this, handle, 547 builder()); 548 } 549 550 Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type, 551 xla::XlaOp handle) { 552 TF_RET_CHECK(handle.valid()); 553 return AssignVariableTensor(GetInputTensorByName(name), type, this, handle, 554 builder()); 555 } 556 557 void XlaOpKernelContext::CtxFailure(const Status& s) { 558 context_->CtxFailure(s); 559 } 560 void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) { 561 context_->CtxFailureWithWarning(s); 562 } 563 void XlaOpKernelContext::CtxFailure(const char* file, int line, 564 const Status& s) { 565 context_->CtxFailure(file, line, s); 566 } 567 void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line, 568 const Status& s) { 569 context_->CtxFailureWithWarning(file, line, s); 570 } 571 572 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax( 573 const DataType type) { 574 return xla_context()->GetOrCreateMax(type); 575 } 576 577 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin( 578 const DataType type) { 579 return xla_context()->GetOrCreateMin(type); 580 } 581 582 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd( 583 const DataType type) { 584 return xla_context()->GetOrCreateAdd(type); 585 } 586 587 const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( 588 const DataType type) { 589 return xla_context()->GetOrCreateMul(type); 590 } 591 592 const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) { 593 const Tensor* tensor; 594 CHECK(context_->input(name, &tensor).ok()); 595 return *tensor; 596 } 597 598 XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {} 599 600 void XlaOpKernel::Compute(OpKernelContext* context) { 601 XlaOpKernelContext xla_context(context); 602 Compile(&xla_context); 603 } 604 605 } // namespace tensorflow 606