1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/cc/ops/while_loop.h" 17 18 #include "tensorflow/cc/framework/scope_internal.h" 19 #include "tensorflow/cc/ops/control_flow_ops_internal.h" 20 #include "tensorflow/cc/ops/standard_ops.h" 21 #include "tensorflow/core/common_runtime/shape_refiner.h" 22 #include "tensorflow/core/graph/node_builder.h" 23 24 namespace tensorflow { 25 namespace ops { 26 27 namespace { 28 29 // Utility function for converting to internal C++ datatypes. 30 OutputTensor ToOutputTensor(const Output& output) { 31 return OutputTensor(output.node(), output.index()); 32 } 33 34 // Utility function for converting to internal C++ datatypes. 35 std::vector<OutputTensor> ToOutputTensors(const std::vector<Output>& outputs) { 36 std::vector<OutputTensor> result(outputs.size()); 37 for (int i = 0; i < outputs.size(); ++i) { 38 result[i] = ToOutputTensor(outputs[i]); 39 } 40 return result; 41 } 42 43 // Utility function for converting to internal C++ datatypes. 44 std::vector<Node*> ToNodes(const std::vector<Output>& outputs) { 45 std::vector<Node*> result(outputs.size()); 46 for (int i = 0; i < outputs.size(); ++i) { 47 result[i] = outputs[i].node(); 48 } 49 return result; 50 } 51 52 // Manually generates the name of the `loop_var_idx`-th NextIteration node of a 53 // loop being constructed with `scope`. This is used to define the backedge 54 // before the NextIteration node is created. 55 string NextIterationName(const Scope& scope, int loop_var_idx) { 56 string result; 57 const string& prefix = scope.impl()->name(); 58 if (!prefix.empty()) strings::StrAppend(&result, prefix, "/"); 59 strings::StrAppend(&result, "NextIteration"); 60 if (loop_var_idx > 0) strings::StrAppend(&result, "_", loop_var_idx); 61 return result; 62 } 63 64 // Creates the `loop_var_idx`-th Merge node of a loop being constructed with 65 // `scope`. `enter_output` is the `loop_var_idx`-th Enter node's output. 66 Status CreateMerge(const Scope& scope, int loop_var_idx, 67 const Output& enter_output, Output* merge_output) { 68 // The merge nodes accept the while loop's back edges as an input (i.e. the 69 // not-yet-created next iteration nodes). Use the underlying NodeBuilder API 70 // directly to create the back edge. 71 NodeBuilder::NodeOut enter_input(enter_output.node(), enter_output.index()); 72 73 const int next_output_index = 0; 74 DataType dtype = enter_output.node()->output_type(0); 75 NodeBuilder::NodeOut next_input(NextIterationName(scope, loop_var_idx), 76 next_output_index, dtype); 77 78 std::vector<NodeBuilder::NodeOut> input_list({enter_input, next_input}); 79 const string unique_name = scope.GetUniqueNameForOp("Merge"); 80 NodeBuilder builder = NodeBuilder(unique_name, "Merge").Input(input_list); 81 scope.UpdateBuilder(&builder); 82 83 Node* merge_node; 84 TF_RETURN_IF_ERROR(builder.Finalize(scope.graph(), &merge_node)); 85 TF_RETURN_IF_ERROR(scope.DoShapeInference(merge_node)); 86 *merge_output = Output(merge_node, 0); 87 return Status::OK(); 88 } 89 90 // Creates the condition subgraph defined by `cond`. 91 Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond, 92 const std::vector<Output>& inputs, Output* output) { 93 // The control dependency is for constants in the cond graph, and other ops 94 // that do not depend on the loop variables. This ensures that these ops are 95 // in the while loop frame (since they will indirectly depend on an Enter node 96 // defining the frame) and that they are executed once per loop iteration. 97 // 98 // TODO(skyewm): the control dep will be added to all nodes in the cond graph. 99 // This is at best unnecessary, and at worst may prevent different parts of 100 // different loop iterations from executing in parallel. 101 Scope cond_scope = 102 scope.NewSubScope("cond").WithControlDependencies(inputs[0]); 103 Output raw_cond_out; 104 TF_RETURN_IF_ERROR(cond(cond_scope, inputs, &raw_cond_out)); 105 106 TF_RETURN_IF_ERROR(scope.graph()->IsValidOutputTensor(raw_cond_out.node(), 107 raw_cond_out.index())); 108 if (raw_cond_out.type() != DT_BOOL) { 109 return errors::InvalidArgument( 110 "BuildWhileLoop: 'cond' argument must return a boolean output, got ", 111 DataTypeString(raw_cond_out.type())); 112 } 113 // TODO(skyewm): check that raw_cond_out is scalar 114 115 *output = LoopCond(scope, raw_cond_out).output; 116 return Status::OK(); 117 } 118 119 // Create the body subgraph defined by `body`. `outputs` must be non-null and 120 // empty. 121 Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body, 122 const std::vector<Output>& inputs, 123 std::vector<Output>* outputs) { 124 DCHECK(outputs != nullptr); 125 DCHECK(outputs->empty()); 126 127 // The control dependency is analogous to that in CreateCond(). 128 Scope body_scope = 129 scope.NewSubScope("body").WithControlDependencies(inputs[0]); 130 TF_RETURN_IF_ERROR(body(body_scope, inputs, outputs)); 131 132 const size_t num_loop_vars = inputs.size(); 133 if (outputs->size() != num_loop_vars) { 134 return errors::InvalidArgument( 135 "BuildWhileLoop: 'body' argument expected to return ", num_loop_vars, 136 " output(s), got ", outputs->size()); 137 } 138 for (const Output& output : *outputs) { 139 TF_RETURN_IF_ERROR( 140 scope.graph()->IsValidOutputTensor(output.node(), output.index())); 141 // TODO(skyewm): check output types/shapes 142 } 143 return Status::OK(); 144 } 145 146 } // namespace 147 148 // A while loop with a single loop variable looks like this: 149 // 150 // (output) 151 // ^ +---------------+ 152 // | | body subgraph +-------------+ 153 // Exit +---------------+ | 154 // ^ ^ | 155 // | | | 156 // Switch<--------+ v 157 // ^ | NextIteration 158 // | +------+--------+ | 159 // +---->| cond subgraph | | 160 // | +---------------+ | 161 // Merge<---------------------------+ 162 // ^ 163 // | 164 // Enter 165 // ^ 166 // | 167 // (input) 168 // 169 // If there are multiple loop variables, each of the control flow ops is 170 // duplicated for each loop variable. 171 // TODO(skyewm): link to public version of design doc 172 Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs, 173 const CondGraphBuilderFn& cond, 174 const BodyGraphBuilderFn& body, const string& frame_name, 175 OutputList* outputs, bool create_while_ctx, 176 Output* cond_output) { 177 DCHECK(!inputs.empty()); 178 DCHECK(outputs != nullptr); 179 DCHECK(outputs->empty()); 180 181 TF_RETURN_IF_ERROR(scope.status()); 182 const size_t num_loop_vars = inputs.size(); 183 184 std::vector<Output> enter_outputs(num_loop_vars); 185 for (int i = 0; i < num_loop_vars; ++i) { 186 enter_outputs[i] = internal::Enter(scope, inputs[i], frame_name); 187 } 188 TF_RETURN_IF_ERROR(scope.status()); 189 190 std::vector<Output> merge_outputs(num_loop_vars); 191 for (int i = 0; i < num_loop_vars; ++i) { 192 TF_RETURN_IF_ERROR( 193 CreateMerge(scope, i, enter_outputs[i], &merge_outputs[i])); 194 } 195 196 Output cond_out; 197 TF_RETURN_IF_ERROR(CreateCond(scope, cond, merge_outputs, &cond_out)); 198 if (cond_output != nullptr) *cond_output = cond_out; 199 200 std::vector<Output> switch_trues(num_loop_vars); 201 std::vector<Output> switch_falses(num_loop_vars); 202 for (int i = 0; i < num_loop_vars; ++i) { 203 auto switch_i = Switch(scope, merge_outputs[i], cond_out); 204 switch_trues[i] = switch_i.output_true; 205 switch_falses[i] = switch_i.output_false; 206 } 207 TF_RETURN_IF_ERROR(scope.status()); 208 209 std::vector<Output> body_outputs; 210 TF_RETURN_IF_ERROR(CreateBody(scope, body, switch_trues, &body_outputs)); 211 212 std::vector<Output> next_outputs(num_loop_vars); 213 for (int i = 0; i < num_loop_vars; ++i) { 214 next_outputs[i] = NextIteration(scope, body_outputs[i]); 215 DCHECK_EQ(next_outputs[i].node()->name(), NextIterationName(scope, i)); 216 } 217 TF_RETURN_IF_ERROR(scope.status()); 218 219 // Create the backedges from the NextIteration nodes to the Merge nodes. 220 for (int i = 0; i < num_loop_vars; ++i) { 221 const int merge_backedge_output_index = 1; 222 scope.graph()->AddEdge(next_outputs[i].node(), next_outputs[i].index(), 223 merge_outputs[i].node(), 224 merge_backedge_output_index); 225 } 226 227 outputs->resize(num_loop_vars); 228 for (int i = 0; i < num_loop_vars; ++i) { 229 (*outputs)[i] = internal::Exit(scope, switch_falses[i]); 230 } 231 TF_RETURN_IF_ERROR(scope.status()); 232 233 if (create_while_ctx) { 234 WhileContext* while_ctx; 235 TF_RETURN_IF_ERROR(scope.graph()->AddWhileContext( 236 frame_name, ToNodes(enter_outputs), ToNodes(*outputs), 237 ToOutputTensor(cond_out), ToOutputTensors(switch_trues), 238 ToOutputTensors(body_outputs), &while_ctx)); 239 240 // Set while_ctx for all exit nodes. We currently don't require knowing the 241 // while_ctx for any other nodes. 242 for (int i = 0; i < num_loop_vars; ++i) { 243 (*outputs)[i].node()->set_while_ctx(while_ctx); 244 } 245 } 246 return Status::OK(); 247 } 248 249 } // namespace ops 250 } // namespace tensorflow 251