/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/kernels/while_op.h"

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

namespace {

// Builds XlaCompiler argument descriptions `args` from `ctx`.
Status MakeXlaCompilerArgumentsFromInputs(
    XlaOpKernelContext* ctx, std::vector<XlaCompiler::Argument>* args,
    bool* has_uninitialized_vars, bool* has_tensor_arrays) {
  VLOG(2) << "Num inputs " << ctx->num_inputs();
  args->resize(ctx->num_inputs());
  *has_uninitialized_vars = false;
  *has_tensor_arrays = false;
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    VLOG(2) << " Input " << i
            << " type: " << DataTypeString(ctx->input_type(i))
            << " shape: " << ctx->InputShape(i).DebugString();
    XlaCompiler::Argument& arg = (*args)[i];
    DataType type = ctx->input_type(i);
    // When reading a resource input, use the type and shape of the resource's
    // current value.
    if (type == DT_RESOURCE) {
      XlaResource* resource;
      TF_RETURN_IF_ERROR(ctx->GetResourceInput(i, &resource));

      arg.initialized = resource->initialized();
      arg.kind = XlaCompiler::Argument::kResource;
      arg.resource_kind = resource->kind();
      if (arg.resource_kind == XlaResource::kTensorArray) {
        *has_tensor_arrays = true;
      }

      arg.type = resource->type();
      arg.shape = resource->shape();
      if (!arg.initialized) {
        *has_uninitialized_vars = true;
      }
      arg.tensor_array_size = resource->tensor_array_size();
      for (const auto& gradient : resource->tensor_array_gradients()) {
        arg.tensor_array_gradients.insert(gradient.first);
      }
      arg.name = resource->name();
      VLOG(2) << " resource " << resource->name()
              << " type: " << DataTypeString(arg.type)
              << " shape: " << arg.shape.DebugString()
              << " initialized: " << arg.initialized;
    } else {
      arg.kind = XlaCompiler::Argument::kParameter;
      arg.type = ctx->input_type(i);
      arg.shape = ctx->InputShape(i);
    }
  }
  return Status::OK();
}

}  // anonymous namespace

XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
  const NameAttrList* name_attr;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &name_attr));
  cond_name_attr_ = *name_attr;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr));
  body_name_attr_ = *name_attr;
}

void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
  VLOG(1) << "WhileOp::Compile";

  std::vector<XlaCompiler::Argument> arguments;
  bool has_uninitialized_vars;
  bool has_tensor_arrays;
  OP_REQUIRES_OK(
      ctx, MakeXlaCompilerArgumentsFromInputs(
               ctx, &arguments, &has_uninitialized_vars, &has_tensor_arrays));

  xla::ComputationBuilder* builder = ctx->builder();
  XlaCompiler* compiler = ctx->compiler();

  VLOG(1) << "Compiling body";

  // All resources that are inputs to the loop's body must also be present as
  // loop body outputs; the signatures of the loop's input and output must
  // match. We ensure this by asking the compiler to include the current
  // values of all resources, even if they haven't been updated by the
  // computation. We must also ask the compiler to keep compile-time constant
  // outputs as part of the generated computation, for the same reason.
  // TODO(phawkins): consider adding loop-invariant inputs to XLA's While()
  // operator.
  XlaCompiler::CompileOptions body_options;
  body_options.use_tuple_arg = true;
  body_options.return_updated_values_for_all_resources = true;
  body_options.resolve_compile_time_constants = false;
  body_options.is_entry_computation = false;
  XlaCompiler::CompilationResult body;
  OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
                                                arguments, &body));

  // We must use a static shape for parameters to an XLA compilation. However,
  // we may not know the shape of a resource if it is first written inside the
  // loop. Furthermore, we do not know ahead of time which gradient
  // TensorArrays will be created by the TensorArrayGradV3 operator.
  //
  // Ideally we would change TensorFlow to always provide static shapes, but
  // this is not easy to do. So if the loop body uses uninitialized resources
  // or TensorArrays, we compile the body function twice:
  // 1) once with uninitialized resource inputs and no TensorArray gradient
  //    inputs. We then discard the computation, but we assume the resource
  //    shapes and the set of gradients read or written will reach a fixpoint
  //    after one iteration; hence we can use the output shapes and
  //    TensorArray gradients of each resource as the "true" shapes.
  // 2) again with the "correct" resource information determined by (1).
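  //
  // For example (illustrative): a resource variable that is first written
  // inside the loop as f32[10] is uninitialized on entry, so pass (1) cannot
  // assign it a parameter shape; after pass (1), its resource update reports
  // shape f32[10], and pass (2) then compiles the body with that shape and a
  // zero-initialized input value.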
  if (has_uninitialized_vars || has_tensor_arrays) {
    VLOG(2) << "Recompiling loop body: has_uninitialized_vars: "
            << has_uninitialized_vars
            << " has_tensor_arrays: " << has_tensor_arrays;
    // Initializes any uninitialized resource with zero values of the shape
    // determined by the first compilation.
    for (int i = 0; i < body.resource_updates.size(); ++i) {
      const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
      XlaResource* resource;
      OP_REQUIRES_OK(ctx,
                     ctx->GetResourceInput(update.input_index, &resource));

      XlaCompiler::Argument& arg = arguments[update.input_index];
      if (!arg.initialized) {
        VLOG(2) << "Update shape for argument " << update.input_index << " "
                << update.shape.DebugString();
        arg.initialized = true;
        arg.shape = update.shape;
        OP_REQUIRES_OK(ctx,
                       resource->SetTypeAndShape(update.type, update.shape));
        OP_REQUIRES_OK(ctx, resource->SetZeroValue(builder));
      }

      // Add any TensorArray gradients touched by the body to the enclosing
      // graph.
      for (const string& grad_source :
           update.tensor_array_gradients_accessed) {
        VLOG(4) << "TensorArray " << resource->name() << " accessed gradient "
                << grad_source;
        XlaResource* gradient;
        OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient(
                                grad_source, builder, &gradient));
      }

      // Add all of the TensorArray gradients to the argument. For simplicity,
      // we always pass all known gradients.
      for (const auto& gradient : resource->tensor_array_gradients()) {
        arg.tensor_array_gradients.insert(gradient.first);
      }
    }

    // Recompile the body with the "correct" resource shapes.
    VLOG(1) << "Recompiling body with corrected resource shapes";
    body = {};
    OP_REQUIRES_OK(ctx, compiler->CompileFunction(
                            body_options, body_name_attr_, arguments, &body));
  }

  VLOG(1) << "Compiling condition";

  XlaCompiler::CompileOptions cond_options;
  cond_options.use_tuple_arg = true;
  cond_options.resolve_compile_time_constants = false;
  cond_options.is_entry_computation = false;
  XlaCompiler::CompilationResult cond;
  OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_,
                                                arguments, &cond));

  OP_REQUIRES(ctx, body.xla_input_shapes.size() == 1,
              errors::FailedPrecondition("Expected one input shape"));
  xla::Shape body_input_shape = body.xla_input_shapes[0];
  OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(body_input_shape),
              errors::FailedPrecondition("Expected tuple shape"));
  OP_REQUIRES(ctx, cond.xla_input_shapes.size() == 1,
              errors::FailedPrecondition("Expected one input shape"));
  xla::Shape cond_input_shape = cond.xla_input_shapes[0];
  OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(cond_input_shape),
              errors::FailedPrecondition("Expected tuple shape"));

  VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape)
          << " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape);
  VLOG(2) << "Cond shape: " << xla::ShapeUtil::HumanString(cond_input_shape)
          << " -> " << xla::ShapeUtil::HumanString(cond.xla_output_shape);

  OP_REQUIRES(ctx,
              xla::ShapeUtil::Compatible(body_input_shape, cond_input_shape),
              errors::InvalidArgument(
                  "Input shapes of loop body and condition do not match: ",
                  xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
                  xla::ShapeUtil::HumanString(cond_input_shape)));
  OP_REQUIRES(
      ctx, xla::ShapeUtil::Compatible(body_input_shape, body.xla_output_shape),
      errors::InvalidArgument(
          "Input and output shapes of loop body do not match: ",
          xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
          xla::ShapeUtil::HumanString(body.xla_output_shape)));

  xla::Shape expected_cond_output_shape = xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeShape(xla::PRED, {})});
  OP_REQUIRES(ctx,
              xla::ShapeUtil::Compatible(cond.xla_output_shape,
                                         expected_cond_output_shape),
              errors::InvalidArgument(
                  "Output shape of loop condition should be (pred[]), got: ",
                  xla::ShapeUtil::HumanString(cond.xla_output_shape)));

  int num_inputs = body.input_mapping.size();
  std::vector<xla::ComputationDataHandle> inputs(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    int input_num = body.input_mapping[i];
    if (ctx->input_type(input_num) == DT_RESOURCE) {
      XlaResource* resource;
      OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
      OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder));
    } else {
      inputs[i] = ctx->Input(input_num);
    }
  }

  xla::ComputationDataHandle init = builder->Tuple(inputs);

  VLOG(1) << "Building while loop";

  // Wraps the condition in a computation that unpacks the output tuple.
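  // XLA's While() expects the condition computation to return a scalar PRED,
  // but the compiled TensorFlow condition function returns its result as a
  // single-element tuple (checked against expected_cond_output_shape above).
  // The wrapper therefore calls the condition and extracts tuple element 0;
  // the last operation added to the sub-builder becomes the root of the
  // computation it builds.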
  xla::Computation cond_wrapper;
  {
    std::unique_ptr<xla::ComputationBuilder> cb =
        builder->CreateSubBuilder("cond_wrapper");
    auto inputs = cb->Parameter(0, cond_input_shape, "inputs");
    auto outputs = cb->Call(*cond.computation, {inputs});
    cb->GetTupleElement(outputs, 0);
    xla::StatusOr<xla::Computation> result = cb->Build();
    OP_REQUIRES_OK(ctx, result.status());
    cond_wrapper = std::move(result.ValueOrDie());
  }

  xla::ComputationDataHandle while_result =
      builder->While(cond_wrapper, *body.computation, init);

  // Sets non-variable outputs.
  for (int i = 0; i < ctx->num_outputs(); ++i) {
    if (ctx->input_type(i) != DT_RESOURCE) {
      ctx->SetOutput(body.input_mapping[i],
                     builder->GetTupleElement(while_result, i));
    }
  }

  // Updates the values of any resource variables modified by the loop.
  for (int i = 0; i < body.resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));
    if (update.modified) {
      int pos = body.outputs.size() + i;
      OP_REQUIRES_OK(ctx,
                     resource->SetFromPack(
                         arguments[update.input_index].tensor_array_gradients,
                         builder->GetTupleElement(while_result, pos), builder));
    }
    VLOG(2) << "Loop-carried variable: pos: " << update.input_index
            << " name: " << resource->name()
            << " modified: " << update.modified
            << " type: " << DataTypeString(update.type)
            << " shape: " << update.shape.DebugString();
    // Copies the identity of the resource variable from input to output
    // unchanged, even if the variable was not modified.
    ctx->op_kernel_context()->set_output(
        update.input_index,
        ctx->op_kernel_context()->input(update.input_index));
  }

  VLOG(1) << "Done building while loop";
}

REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp);

}  // namespace tensorflow