/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/while_loop.h"

#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/graph/node_builder.h"

namespace tensorflow {
namespace ops {

namespace {

// Utility function for converting to internal C++ datatypes.
OutputTensor ToOutputTensor(const Output& output) {
  return OutputTensor(output.node(), output.index());
}

// Utility function for converting to internal C++ datatypes.
std::vector<OutputTensor> ToOutputTensors(const std::vector<Output>& outputs) {
  std::vector<OutputTensor> result(outputs.size());
  for (int i = 0; i < outputs.size(); ++i) {
    result[i] = ToOutputTensor(outputs[i]);
  }
  return result;
}

// Utility function for converting to internal C++ datatypes.
std::vector<Node*> ToNodes(const std::vector<Output>& outputs) {
  std::vector<Node*> result(outputs.size());
  for (int i = 0; i < outputs.size(); ++i) {
    result[i] = outputs[i].node();
  }
  return result;
}

// Manually generates the name of the `loop_var_idx`-th NextIteration node of a
// loop being constructed with `scope`. This is used to define the backedge
// before the NextIteration node is created.
string NextIterationName(const Scope& scope, int loop_var_idx) {
  string result;
  const string& prefix = scope.impl()->name();
  if (!prefix.empty()) strings::StrAppend(&result, prefix, "/");
  strings::StrAppend(&result, "NextIteration");
  if (loop_var_idx > 0) strings::StrAppend(&result, "_", loop_var_idx);
  return result;
}
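// For example (illustrative names only): in a scope named "while", loop
// variable 0 maps to "while/NextIteration" and loop variable 2 to
// "while/NextIteration_2", matching the "_<n>" suffixing the scope applies
// when it later names the real NextIteration ops.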

// Creates the `loop_var_idx`-th Merge node of a loop being constructed with
// `scope`. `enter_output` is the `loop_var_idx`-th Enter node's output.
Status CreateMerge(const Scope& scope, int loop_var_idx,
                   const Output& enter_output, Output* merge_output) {
  // The merge nodes accept the while loop's back edges as an input (i.e. the
  // not-yet-created next iteration nodes). Use the underlying NodeBuilder API
  // directly to create the back edge.
  NodeBuilder::NodeOut enter_input(enter_output.node(), enter_output.index());

  const int next_output_index = 0;
  DataType dtype = enter_output.node()->output_type(0);
  NodeBuilder::NodeOut next_input(NextIterationName(scope, loop_var_idx),
                                  next_output_index, dtype);

  std::vector<NodeBuilder::NodeOut> input_list({enter_input, next_input});
  const string unique_name = scope.GetUniqueNameForOp("Merge");
  NodeBuilder builder = NodeBuilder(unique_name, "Merge").Input(input_list);
  scope.UpdateBuilder(&builder);

  Node* merge_node;
  TF_RETURN_IF_ERROR(builder.Finalize(scope.graph(), &merge_node));
  TF_RETURN_IF_ERROR(scope.DoShapeInference(merge_node));
  *merge_output = Output(merge_node, 0);
  return Status::OK();
}
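// As an illustration (actual names depend on the scope), the NodeDef built
// here for loop variable 0 in a scope named "while" would look roughly like:
//
//   name: "while/Merge"
//   op: "Merge"
//   input: "while/Enter"
//   input: "while/NextIteration"
//
// The second input names a node that does not exist yet; the corresponding
// graph edge is added later in BuildWhileLoop(), once the NextIteration nodes
// have been created.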

// Creates the condition subgraph defined by `cond`.
Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond,
                  const std::vector<Output>& inputs, Output* output) {
  // The control dependency is for constants in the cond graph, and other ops
  // that do not depend on the loop variables. This ensures that these ops are
  // in the while loop frame (since they will indirectly depend on an Enter node
  // defining the frame) and that they are executed once per loop iteration.
  //
  // TODO(skyewm): the control dep will be added to all nodes in the cond graph.
  // This is at best unnecessary, and at worst may prevent different parts of
  // different loop iterations from executing in parallel.
  Scope cond_scope =
      scope.NewSubScope("cond").WithControlDependencies(inputs[0]);
  Output raw_cond_out;
  TF_RETURN_IF_ERROR(cond(cond_scope, inputs, &raw_cond_out));

  TF_RETURN_IF_ERROR(scope.graph()->IsValidOutputTensor(raw_cond_out.node(),
                                                        raw_cond_out.index()));
  if (raw_cond_out.type() != DT_BOOL) {
    return errors::InvalidArgument(
        "BuildWhileLoop: 'cond' argument must return a boolean output, got ",
        DataTypeString(raw_cond_out.type()));
  }
  // TODO(skyewm): check that raw_cond_out is scalar

  *output = LoopCond(scope, raw_cond_out).output;
  return Status::OK();
}
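// A minimal sketch of a CondGraphBuilderFn (hypothetical example; any
// subgraph producing a scalar boolean works): compare the first loop
// variable against a constant limit.
//
//   CondGraphBuilderFn cond = [](const Scope& s,
//                                const std::vector<Output>& loop_vars,
//                                Output* out) {
//     *out = Less(s, loop_vars[0], Const(s, 10));
//     return s.status();
//   };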

// Creates the body subgraph defined by `body`. `outputs` must be non-null and
// empty.
Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body,
                  const std::vector<Output>& inputs,
                  std::vector<Output>* outputs) {
  DCHECK(outputs != nullptr);
  DCHECK(outputs->empty());

  // The control dependency is analogous to that in CreateCond().
  Scope body_scope =
      scope.NewSubScope("body").WithControlDependencies(inputs[0]);
  TF_RETURN_IF_ERROR(body(body_scope, inputs, outputs));

  const size_t num_loop_vars = inputs.size();
  if (outputs->size() != num_loop_vars) {
    return errors::InvalidArgument(
        "BuildWhileLoop: 'body' argument expected to return ", num_loop_vars,
        " output(s), got ", outputs->size());
  }
  for (const Output& output : *outputs) {
    TF_RETURN_IF_ERROR(
        scope.graph()->IsValidOutputTensor(output.node(), output.index()));
    // TODO(skyewm): check output types/shapes
  }
  return Status::OK();
}
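// A minimal sketch of a matching BodyGraphBuilderFn (hypothetical example):
// increment the single loop variable, returning exactly one output per loop
// variable.
//
//   BodyGraphBuilderFn body = [](const Scope& s,
//                                const std::vector<Output>& loop_vars,
//                                std::vector<Output>* body_outputs) {
//     body_outputs->push_back(Add(s, loop_vars[0], Const(s, 1)));
//     return s.status();
//   };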

}  // namespace

// A while loop with a single loop variable looks like this:
//
// (output)
//     ^    +---------------+
//     |    | body subgraph +-------------+
//    Exit  +---------------+             |
//      ^    ^                            |
//      |    |                            |
//      Switch<--------+                  v
//        ^            |             NextIteration
//        |     +------+--------+         |
//        +---->| cond subgraph |         |
//        |     +---------------+         |
//       Merge<---------------------------+
//       ^
//       |
//    Enter
//      ^
//      |
//   (input)
//
// If there are multiple loop variables, each of the control flow ops is
// duplicated for each loop variable.
// TODO(skyewm): link to public version of design doc
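//
// Example usage (a sketch, assuming `cond` and `body` builder functions like
// the ones illustrated above, and the default trailing arguments): build a
// loop equivalent to `int i = 0; while (i < 10) ++i;`.
//
//   Scope root = Scope::NewRootScope();
//   std::vector<Output> inputs = {Const(root, 0)};
//   OutputList outputs;
//   TF_CHECK_OK(
//       BuildWhileLoop(root, inputs, cond, body, "while_frame", &outputs));
//   // outputs[0] holds the value of the loop variable after the final
//   // iteration.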
Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
                      const CondGraphBuilderFn& cond,
                      const BodyGraphBuilderFn& body, const string& frame_name,
                      OutputList* outputs, bool create_while_ctx,
                      Output* cond_output) {
  DCHECK(!inputs.empty());
  DCHECK(outputs != nullptr);
  DCHECK(outputs->empty());

  TF_RETURN_IF_ERROR(scope.status());
  const size_t num_loop_vars = inputs.size();

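  // Enter nodes mark each input's entry into the loop frame `frame_name`;
  // every node inside the loop ultimately depends on one of them.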
  std::vector<Output> enter_outputs(num_loop_vars);
  for (int i = 0; i < num_loop_vars; ++i) {
    enter_outputs[i] = internal::Enter(scope, inputs[i], frame_name);
  }
  TF_RETURN_IF_ERROR(scope.status());

  std::vector<Output> merge_outputs(num_loop_vars);
  for (int i = 0; i < num_loop_vars; ++i) {
    TF_RETURN_IF_ERROR(
        CreateMerge(scope, i, enter_outputs[i], &merge_outputs[i]));
  }

  Output cond_out;
  TF_RETURN_IF_ERROR(CreateCond(scope, cond, merge_outputs, &cond_out));
  if (cond_output != nullptr) *cond_output = cond_out;

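  // Each Switch routes its loop variable on the cond result: output_true
  // feeds the body subgraph (another iteration), while output_false feeds the
  // Exit nodes created below (loop termination).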
  std::vector<Output> switch_trues(num_loop_vars);
  std::vector<Output> switch_falses(num_loop_vars);
  for (int i = 0; i < num_loop_vars; ++i) {
    auto switch_i = Switch(scope, merge_outputs[i], cond_out);
    switch_trues[i] = switch_i.output_true;
    switch_falses[i] = switch_i.output_false;
  }
  TF_RETURN_IF_ERROR(scope.status());

  std::vector<Output> body_outputs;
  TF_RETURN_IF_ERROR(CreateBody(scope, body, switch_trues, &body_outputs));

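  // NextIteration nodes feed each body output back toward its Merge node. The
  // DCHECK verifies that the name predicted by NextIterationName() (and baked
  // into the Merge node's inputs in CreateMerge()) matches the actual name.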
  std::vector<Output> next_outputs(num_loop_vars);
  for (int i = 0; i < num_loop_vars; ++i) {
    next_outputs[i] = NextIteration(scope, body_outputs[i]);
    DCHECK_EQ(next_outputs[i].node()->name(), NextIterationName(scope, i));
  }
  TF_RETURN_IF_ERROR(scope.status());

  // Create the backedges from the NextIteration nodes to the Merge nodes. The
  // backedge is each Merge node's second input (input index 1).
  for (int i = 0; i < num_loop_vars; ++i) {
    const int merge_backedge_input_index = 1;
    scope.graph()->AddEdge(next_outputs[i].node(), next_outputs[i].index(),
                           merge_outputs[i].node(),
                           merge_backedge_input_index);
  }

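  // Exit nodes on the false branch of each Switch produce the loop's outputs.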
  outputs->resize(num_loop_vars);
  for (int i = 0; i < num_loop_vars; ++i) {
    (*outputs)[i] = internal::Exit(scope, switch_falses[i]);
  }
  TF_RETURN_IF_ERROR(scope.status());

  if (create_while_ctx) {
    WhileContext* while_ctx;
    TF_RETURN_IF_ERROR(scope.graph()->AddWhileContext(
        frame_name, ToNodes(enter_outputs), ToNodes(*outputs),
        ToOutputTensor(cond_out), ToOutputTensors(switch_trues),
        ToOutputTensors(body_outputs), &while_ctx));

    // Set while_ctx for all exit nodes. We currently don't require knowing the
    // while_ctx for any other nodes.
    for (int i = 0; i < num_loop_vars; ++i) {
      (*outputs)[i].node()->set_while_ctx(while_ctx);
    }
  }
  return Status::OK();
}

}  // namespace ops
}  // namespace tensorflow