// tensorflow/compiler/tf2xla/xla_resource.cc
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/xla_resource.h"
     17 
     18 #include <functional>
     19 #include <memory>
     20 
     21 #include "tensorflow/compiler/tf2xla/shape_util.h"
     22 #include "tensorflow/compiler/tf2xla/sharding_util.h"
     23 #include "tensorflow/compiler/tf2xla/xla_context.h"
     24 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     25 
     26 namespace tensorflow {
     27 
     28 XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type,
     29                          TensorShape shape,
     30                          const xla::ComputationDataHandle& initial_value,
     31                          int64 tensor_array_size,
     32                          const std::set<string>& tensor_array_gradients)
     33     : kind_(kind),
     34       arg_num_(arg_num),
     35       name_(std::move(name)),
     36       type_(type),
     37       shape_(std::move(shape)),
     38       value_(initial_value),
     39       initial_value_(initial_value),
     40       tensor_array_size_(tensor_array_size) {
     41   CHECK(kind_ != kInvalid);
     42 
     43   for (const string& gradient : tensor_array_gradients) {
     44     tensor_array_gradients_[gradient].reset(
     45         new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
     46                         /*name=*/strings::StrCat("TensorArrayGrad: ", name_),
     47                         type_, shape_, xla::ComputationDataHandle(),
     48                         tensor_array_size_, /*tensor_array_gradients=*/{}));
     49   }
     50 }
     51 
     52 Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) {
     53   if (type == DT_INVALID) {
     54     return errors::InvalidArgument("Attempted to set type of resource '", name_,
     55                                    "'' to an invalid type");
     56   }
     57   if (initialized() && type_ != type) {
     58     return errors::InvalidArgument("Type of resource ", name_,
     59                                    " cannot be changed after initialization: "
     60                                    "old type was ",
     61                                    DataTypeString(type_), ", new type is ",
     62                                    DataTypeString(type));
     63   }
     64   if (initialized() && shape_ != shape) {
     65     return errors::InvalidArgument("Shape of resource ", name_,
     66                                    " cannot be changed after initialization: "
     67                                    "old shape was ",
     68                                    shape_.DebugString(), ", new shape is ",
     69                                    shape.DebugString());
     70   }
     71   type_ = type;
     72   shape_ = shape;
     73   return Status::OK();
     74 }
     75 
     76 Status XlaResource::SetValue(const xla::ComputationDataHandle& value) {
     77   if (type_ == DT_INVALID) {
     78     return errors::InvalidArgument(
     79         "Resource '", name_,
     80         "' must be initialized with a valid type before use.");
     81   }
     82   value_ = value;
     83   return Status::OK();
     84 }
     85 
     86 Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) {
     87   if (type_ == DT_INVALID) {
     88     return errors::InvalidArgument(
     89         "Resource '", name_,
     90         "' must be initialized with a valid type before use.");
     91   }
     92   switch (kind_) {
     93     case kVariable: {
     94       value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
     95                                   shape_.dim_sizes());
     96       break;
     97     }
     98     case kTensorArray: {
     99       TensorShape ta_shape;
    100       ta_shape.AddDim(tensor_array_size_);
    101       ta_shape.AppendShape(shape_);
    102       value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
    103                                   ta_shape.dim_sizes());
    104       break;
    105     }
    106     case kStack: {
    107       TensorShape ta_shape;
    108       ta_shape.AddDim(tensor_array_size_);
    109       ta_shape.AppendShape(shape_);
    110       value_ =
    111           builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_),
    112                                              ta_shape.dim_sizes()),
    113                           builder->ConstantR0<int32>(0)});
    114       break;
    115     }
    116 
    117     case kInvalid:
    118     default:
    119       LOG(FATAL) << "Invalid resource type";
    120   }
    121   return Status::OK();
    122 }
    123 
    124 Status XlaResource::GetOrCreateTensorArrayGradient(
    125     const string& source, xla::ComputationBuilder* builder,
    126     XlaResource** gradient_out) {
    127   VLOG(2) << "Gradient lookup for resource: " << name_
    128           << " gradient: " << source;
    129   TF_RET_CHECK(kind_ == kTensorArray);
    130   std::unique_ptr<XlaResource>& gradient = tensor_array_gradients_[source];
    131   if (!gradient) {
    132     TensorShape ta_shape;
    133     ta_shape.AddDim(tensor_array_size_);
    134     ta_shape.AppendShape(shape_);
    135     xla::ComputationDataHandle gradient_value = builder->Broadcast(
    136         XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
    137     gradient.reset(
    138         new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
    139                         /*name=*/strings::StrCat("TensorArrayGrad: ", name_),
    140                         type_, shape_, gradient_value, tensor_array_size_,
    141                         /*tensor_array_gradients=*/{}));
    142   }
    143   *gradient_out = gradient.get();
    144   return Status::OK();
    145 }
    146 
    147 Status XlaResource::Pack(xla::ComputationDataHandle* pack,
    148                          xla::ComputationBuilder* builder) const {
    149   if (tensor_array_gradients_.empty()) {
    150     *pack = value_;
    151   } else {
    152     TF_RET_CHECK(kind_ == kTensorArray);
    153     std::vector<xla::ComputationDataHandle> elems;
    154     elems.push_back(value_);
    155     for (const auto& gradient : tensor_array_gradients_) {
    156       elems.push_back(gradient.second->value_);
    157     }
    158     *pack = builder->Tuple(elems);
    159   }
    160   return Status::OK();
    161 }
    162 
    163 Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
    164                                 const xla::ComputationDataHandle& pack,
    165                                 xla::ComputationBuilder* builder) {
    166   if (gradient_sources.empty()) {
    167     if (!initialized()) {
    168       initial_value_ = pack;
    169     }
    170     value_ = pack;
    171   } else {
    172     TF_RET_CHECK(kind_ == kTensorArray);
    173     int pos = 0;
    174     auto v = builder->GetTupleElement(pack, pos++);
    175     if (!initialized()) {
    176       initial_value_ = v;
    177     }
    178     value_ = v;
    179 
    180     for (const auto& source : gradient_sources) {
    181       XlaResource* gradient;
    182       TF_RETURN_IF_ERROR(
    183           GetOrCreateTensorArrayGradient(source, builder, &gradient));
    184       auto v = builder->GetTupleElement(pack, pos++);
    185       if (!gradient->initialized()) {
    186         gradient->initial_value_ = v;
    187       }
    188       gradient->value_ = v;
    189     }
    190   }
    191   return Status::OK();
    192 }
    193 
    194 }  // namespace tensorflow
    195