Home | History | Annotate | Download | only in tf2xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/xla_context.h"
     17 
     18 #include <memory>
     19 #include <utility>
     20 #include <vector>
     21 
     22 #include "tensorflow/compiler/tf2xla/literal_util.h"
     23 #include "tensorflow/compiler/tf2xla/shape_util.h"
     24 #include "tensorflow/compiler/tf2xla/type_util.h"
     25 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     26 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     27 #include "tensorflow/compiler/xla/client/client_library.h"
     28 #include "tensorflow/compiler/xla/client/computation_builder.h"
     29 #include "tensorflow/compiler/xla/layout_util.h"
     30 #include "tensorflow/compiler/xla/literal_util.h"
     31 #include "tensorflow/compiler/xla/statusor.h"
     32 #include "tensorflow/core/common_runtime/dma_helper.h"
     33 #include "tensorflow/core/lib/gtl/array_slice.h"
     34 #include "tensorflow/core/lib/strings/strcat.h"
     35 #include "tensorflow/core/platform/logging.h"
     36 
     37 namespace tensorflow {
     38 
     39 const char XlaContext::kXlaContextResourceName[] = "_xla_context";
     40 
     41 // Looks up the context associated with the current step. It is stored
     42 // in a resource container managed by the device.
     43 /* static */ XlaContext& XlaContext::Get(const OpKernelContext* ctx) {
     44   // When an Op kernel wants to use an XLA JIT context, the
     45   // per-step context is looked up in the resource manager. The
     46   // JIT will prepopulate the JITContext.
     47   XlaContext* context;
     48   TF_CHECK_OK(ctx->resource_manager()->Lookup(
     49       ctx->step_container()->name(), kXlaContextResourceName, &context));
     50   // The resource manager handed us a fresh reference to 'context', but retains
     51   // a reference itself so the context won't be freed. The resource manager will
     52   // outlive the JIT compilation.
     53   context->Unref();
     54   return *context;
     55 }
     56 
     57 /* static */ XlaContext& XlaContext::Get(const XlaOpKernelContext* ctx) {
     58   return Get(ctx->op_kernel_context());
     59 }
     60 
     61 void XlaContext::set_args(std::vector<XlaExpression> args) {
     62   args_ = std::move(args);
     63 }
     64 
     65 XlaContext::XlaContext(
     66     XlaCompiler* compiler, xla::ComputationBuilder* builder,
     67     bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
     68     const std::function<TensorShape(const TensorShape&, DataType)>*
     69         variable_representation_shape_fn)
     70     : compiler_(compiler),
     71       builder_(builder),
     72       allow_cpu_custom_calls_(allow_cpu_custom_calls),
     73       resolve_compile_time_constants_(resolve_compile_time_constants),
     74       variable_representation_shape_fn_(variable_representation_shape_fn) {}
     75 
     76 string XlaContext::DebugString() { return "TLA JIT context"; }
     77 
     78 // This is called by the Retval Op to associate a computed value
     79 // with a specific return value of the subgraph.
     80 void XlaContext::AddRetval(int retval_index, DataType type,
     81                            const xla::ComputationDataHandle& handle) {
     82   VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
     83   // Add the return value to the list being built up.
     84   if (retvals_.size() <= retval_index) {
     85     retvals_.resize(retval_index + 1);
     86   }
     87   retvals_[retval_index].set_handle(handle);
     88 }
     89 
     90 Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
     91                                   const xla::Literal& literal) {
     92   VLOG(1) << "Adding retval index " << retval_index
     93           << " with non-data-dependent tensor to XLA computation";
     94   if (retvals_.size() <= retval_index) {
     95     retvals_.resize(retval_index + 1);
     96   }
     97   if (resolve_compile_time_constants_) {
     98     Tensor value;
     99     TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
    100     retvals_[retval_index].set_constant_value(std::move(value));
    101   } else {
    102     retvals_[retval_index].set_handle(builder_->ConstantLiteral(literal));
    103   }
    104   return Status::OK();
    105 }
    106 
    107 xla::ComputationBuilder* XlaContext::builder() { return builder_; }
    108 
    109 Status XlaContext::CreateResource(
    110     XlaResource::Kind kind, int arg_num, string name, DataType type,
    111     TensorShape shape, const xla::ComputationDataHandle& handle,
    112     int64 tensor_array_size, const std::set<string>& tensor_array_gradients,
    113     XlaResource** resource) {
    114   resources_.emplace_back(
    115       new XlaResource(kind, arg_num, std::move(name), type, std::move(shape),
    116                       handle, tensor_array_size, tensor_array_gradients));
    117   *resource = resources_.back().get();
    118   return Status::OK();
    119 }
    120 
    121 TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape,
    122                                                     DataType type) const {
    123   return (*variable_representation_shape_fn_)(shape, type);
    124 }
    125 
    126 const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) {
    127   return LookupOrCreate(type, &max_func_, [this, type] {
    128     const string type_string = DataTypeString(type);
    129     VLOG(1) << "Building Max() for " << type_string;
    130     xla::ComputationBuilder b(builder()->client(), "max<" + type_string + ">");
    131     xla::PrimitiveType xla_type;
    132     TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
    133     auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
    134     auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
    135     b.Max(x, y);
    136     return b.Build().ConsumeValueOrDie();
    137   });
    138 }
    139 
    140 const xla::Computation* XlaContext::GetOrCreateMin(const DataType type) {
    141   return LookupOrCreate(type, &min_func_, [this, type] {
    142     const string type_string = DataTypeString(type);
    143     VLOG(1) << "Building Min() for " << type_string;
    144     xla::ComputationBuilder b(builder()->client(), "min<" + type_string + ">");
    145     xla::PrimitiveType xla_type;
    146     TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
    147     auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
    148     auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
    149     b.Min(x, y);
    150     return b.Build().ConsumeValueOrDie();
    151   });
    152 }
    153 
    154 const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) {
    155   return LookupOrCreate(type, &add_func_, [this, type] {
    156     const string type_string = DataTypeString(type);
    157     VLOG(1) << "Building Add() for " << type_string;
    158     xla::ComputationBuilder b(builder()->client(), "add<" + type_string + ">");
    159     xla::PrimitiveType xla_type;
    160     TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
    161     auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
    162     auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
    163     b.Add(x, y);
    164     return b.Build().ConsumeValueOrDie();
    165   });
    166 }
    167 
    168 const xla::Computation* XlaContext::GetOrCreateMul(const DataType type) {
    169   return LookupOrCreate(type, &mul_func_, [this, type] {
    170     const string type_string = DataTypeString(type);
    171     VLOG(1) << "Building Mul() for " << type_string;
    172     xla::ComputationBuilder b(builder()->client(), "mul<" + type_string + ">");
    173     xla::PrimitiveType xla_type;
    174     TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
    175     auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
    176     auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
    177     b.Mul(x, y);
    178     return b.Build().ConsumeValueOrDie();
    179   });
    180 }
    181 
    182 const xla::Computation* XlaContext::LookupOrCreate(
    183     DataType type, ComputationMap* out,
    184     const std::function<xla::Computation()>& create) {
    185   {
    186     const auto& entry = (*out)[type];
    187     if (!entry.IsNull()) {
    188       return &entry;
    189     }
    190   }
    191   auto new_entry = create();
    192   {
    193     // Somebody else might have made one concurrently.
    194     auto& entry = (*out)[type];
    195     if (entry.IsNull()) {
    196       entry = std::move(new_entry);
    197     }
    198     return &entry;
    199   }
    200 }
    201 
    202 }  // namespace tensorflow
    203