1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/xla_resource.h" 17 18 #include <functional> 19 #include <memory> 20 21 #include "tensorflow/compiler/tf2xla/shape_util.h" 22 #include "tensorflow/compiler/tf2xla/sharding_util.h" 23 #include "tensorflow/compiler/tf2xla/xla_context.h" 24 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 25 26 namespace tensorflow { 27 28 XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, 29 TensorShape shape, 30 const xla::ComputationDataHandle& initial_value, 31 int64 tensor_array_size, 32 const std::set<string>& tensor_array_gradients) 33 : kind_(kind), 34 arg_num_(arg_num), 35 name_(std::move(name)), 36 type_(type), 37 shape_(std::move(shape)), 38 value_(initial_value), 39 initial_value_(initial_value), 40 tensor_array_size_(tensor_array_size) { 41 CHECK(kind_ != kInvalid); 42 43 for (const string& gradient : tensor_array_gradients) { 44 tensor_array_gradients_[gradient].reset( 45 new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, 46 /*name=*/strings::StrCat("TensorArrayGrad: ", name_), 47 type_, shape_, xla::ComputationDataHandle(), 48 tensor_array_size_, /*tensor_array_gradients=*/{})); 49 } 50 } 51 52 Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) { 53 if (type == DT_INVALID) { 54 return errors::InvalidArgument("Attempted to set type of resource '", name_, 55 "'' to an invalid type"); 56 } 57 if (initialized() && type_ != type) { 58 return errors::InvalidArgument("Type of resource ", name_, 59 " cannot be changed after initialization: " 60 "old type was ", 61 DataTypeString(type_), ", new type is ", 62 DataTypeString(type)); 63 } 64 if (initialized() && shape_ != shape) { 65 return errors::InvalidArgument("Shape of resource ", name_, 66 " cannot be changed after initialization: " 67 "old shape was ", 68 shape_.DebugString(), ", new shape is ", 69 shape.DebugString()); 70 } 71 type_ = type; 72 shape_ = shape; 73 return Status::OK(); 74 } 75 76 Status XlaResource::SetValue(const xla::ComputationDataHandle& value) { 77 if (type_ == DT_INVALID) { 78 return errors::InvalidArgument( 79 "Resource '", name_, 80 "' must be initialized with a valid type before use."); 81 } 82 value_ = value; 83 return Status::OK(); 84 } 85 86 Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) { 87 if (type_ == DT_INVALID) { 88 return errors::InvalidArgument( 89 "Resource '", name_, 90 "' must be initialized with a valid type before use."); 91 } 92 switch (kind_) { 93 case kVariable: { 94 value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_), 95 shape_.dim_sizes()); 96 break; 97 } 98 case kTensorArray: { 99 TensorShape ta_shape; 100 ta_shape.AddDim(tensor_array_size_); 101 ta_shape.AppendShape(shape_); 102 value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_), 103 ta_shape.dim_sizes()); 104 break; 105 } 106 case kStack: { 107 TensorShape ta_shape; 108 ta_shape.AddDim(tensor_array_size_); 109 ta_shape.AppendShape(shape_); 110 value_ = 111 builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_), 112 ta_shape.dim_sizes()), 113 builder->ConstantR0<int32>(0)}); 114 break; 115 } 116 117 case kInvalid: 118 default: 119 LOG(FATAL) << "Invalid resource type"; 120 } 121 return Status::OK(); 122 } 123 124 Status XlaResource::GetOrCreateTensorArrayGradient( 125 const string& source, xla::ComputationBuilder* builder, 126 XlaResource** gradient_out) { 127 VLOG(2) << "Gradient lookup for resource: " << name_ 128 << " gradient: " << source; 129 TF_RET_CHECK(kind_ == kTensorArray); 130 std::unique_ptr<XlaResource>& gradient = tensor_array_gradients_[source]; 131 if (!gradient) { 132 TensorShape ta_shape; 133 ta_shape.AddDim(tensor_array_size_); 134 ta_shape.AppendShape(shape_); 135 xla::ComputationDataHandle gradient_value = builder->Broadcast( 136 XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); 137 gradient.reset( 138 new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, 139 /*name=*/strings::StrCat("TensorArrayGrad: ", name_), 140 type_, shape_, gradient_value, tensor_array_size_, 141 /*tensor_array_gradients=*/{})); 142 } 143 *gradient_out = gradient.get(); 144 return Status::OK(); 145 } 146 147 Status XlaResource::Pack(xla::ComputationDataHandle* pack, 148 xla::ComputationBuilder* builder) const { 149 if (tensor_array_gradients_.empty()) { 150 *pack = value_; 151 } else { 152 TF_RET_CHECK(kind_ == kTensorArray); 153 std::vector<xla::ComputationDataHandle> elems; 154 elems.push_back(value_); 155 for (const auto& gradient : tensor_array_gradients_) { 156 elems.push_back(gradient.second->value_); 157 } 158 *pack = builder->Tuple(elems); 159 } 160 return Status::OK(); 161 } 162 163 Status XlaResource::SetFromPack(const std::set<string>& gradient_sources, 164 const xla::ComputationDataHandle& pack, 165 xla::ComputationBuilder* builder) { 166 if (gradient_sources.empty()) { 167 if (!initialized()) { 168 initial_value_ = pack; 169 } 170 value_ = pack; 171 } else { 172 TF_RET_CHECK(kind_ == kTensorArray); 173 int pos = 0; 174 auto v = builder->GetTupleElement(pack, pos++); 175 if (!initialized()) { 176 initial_value_ = v; 177 } 178 value_ = v; 179 180 for (const auto& source : gradient_sources) { 181 XlaResource* gradient; 182 TF_RETURN_IF_ERROR( 183 GetOrCreateTensorArrayGradient(source, builder, &gradient)); 184 auto v = builder->GetTupleElement(pack, pos++); 185 if (!gradient->initialized()) { 186 gradient->initial_value_ = v; 187 } 188 gradient->value_ = v; 189 } 190 } 191 return Status::OK(); 192 } 193 194 } // namespace tensorflow 195