1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/lib/while_loop.h" 17 #include "tensorflow/compiler/tf2xla/lib/util.h" 18 #include "tensorflow/compiler/xla/shape_util.h" 19 #include "tensorflow/compiler/xla/status_macros.h" 20 21 namespace tensorflow { 22 23 xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop( 24 const LoopConditionFunction& condition_function, 25 const LoopBodyFunction& body_function, 26 gtl::ArraySlice<xla::ComputationDataHandle> initial_values, 27 StringPiece name, xla::ComputationBuilder* builder) { 28 int arity = initial_values.size(); 29 std::vector<xla::Shape> var_shapes; 30 var_shapes.reserve(arity); 31 for (const xla::ComputationDataHandle& input : initial_values) { 32 TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input)); 33 var_shapes.push_back(std::move(*shape)); 34 } 35 xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes); 36 37 // Unpacks a tuple into its component parts. 38 auto unpack_tuple = [](xla::ComputationDataHandle tuple, int arity, 39 xla::ComputationBuilder* builder) { 40 std::vector<xla::ComputationDataHandle> elements(arity); 41 for (int i = 0; i < arity; ++i) { 42 elements[i] = builder->GetTupleElement(tuple, i); 43 } 44 return elements; 45 }; 46 47 // Build the condition. 48 std::unique_ptr<xla::ComputationBuilder> cond_builder = 49 builder->CreateSubBuilder(strings::StrCat(name, "_condition")); 50 { 51 auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter"); 52 53 TF_ASSIGN_OR_RETURN( 54 auto result, 55 condition_function(unpack_tuple(parameter, arity, cond_builder.get()), 56 cond_builder.get())); 57 TF_RETURN_IF_ERROR(cond_builder->SetReturnValue(result)); 58 } 59 TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build()); 60 61 // Build the body. 62 std::unique_ptr<xla::ComputationBuilder> body_builder = 63 builder->CreateSubBuilder(strings::StrCat(name, "_body")); 64 { 65 auto parameter = body_builder->Parameter(0, tuple_shape, "parameter"); 66 67 TF_ASSIGN_OR_RETURN( 68 auto result, 69 body_function(unpack_tuple(parameter, arity, body_builder.get()), 70 body_builder.get())); 71 72 TF_RET_CHECK(result.size() == initial_values.size()); 73 body_builder->Tuple(result); 74 } 75 TF_ASSIGN_OR_RETURN(auto body, body_builder->Build()); 76 77 auto outputs = builder->While(cond, body, builder->Tuple(initial_values)); 78 79 return unpack_tuple(outputs, arity, builder); 80 } 81 82 xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex( 83 int64 num_iterations, xla::PrimitiveType num_iterations_type, 84 const ForEachIndexBodyFunction& body_function, 85 gtl::ArraySlice<xla::ComputationDataHandle> initial_values, 86 StringPiece name, xla::ComputationBuilder* builder) { 87 auto while_cond_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values, 88 xla::ComputationBuilder* cond_builder) 89 -> xla::StatusOr<xla::ComputationDataHandle> { 90 return cond_builder->Lt( 91 values[0], 92 IntegerLiteral(cond_builder, num_iterations_type, num_iterations)); 93 }; 94 auto while_body_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values, 95 xla::ComputationBuilder* body_builder) 96 -> xla::StatusOr<std::vector<xla::ComputationDataHandle>> { 97 xla::ComputationDataHandle iteration = values[0]; 98 99 std::vector<xla::ComputationDataHandle> updated_values; 100 updated_values.reserve(values.size()); 101 updated_values.push_back(body_builder->Add( 102 iteration, 103 body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type)))); 104 105 values.remove_prefix(1); 106 TF_ASSIGN_OR_RETURN(std::vector<xla::ComputationDataHandle> body_outputs, 107 body_function(iteration, values, body_builder)); 108 updated_values.insert(updated_values.end(), body_outputs.begin(), 109 body_outputs.end()); 110 return updated_values; 111 }; 112 113 std::vector<xla::ComputationDataHandle> values; 114 values.reserve(initial_values.size() + 1); 115 values.push_back( 116 builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type))); 117 values.insert(values.end(), initial_values.begin(), initial_values.end()); 118 119 TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, 120 name, builder)); 121 values.erase(values.begin(), values.begin() + 1); 122 return values; 123 } 124 125 } // namespace tensorflow 126