Home | History | Annotate | Download | only in graph_transformations
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      7     http://www.apache.org/licenses/LICENSE-2.0
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"
     28 namespace toco {
     30 namespace {
     32 void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth,
     33                       int kheight, int stride_width, int stride_height,
     34                       PaddingType padding_type, Shape* output_shape,
     35                       FixedPadding* fixed_padding) {
     36   const int input_width = input_shape.dims(2);
     37   const int input_height = input_shape.dims(1);
     38   const int batch = input_shape.dims(0);
     40   int output_height = 0;
     41   int output_width = 0;
     42   if (padding_type == PaddingType::kValid) {
     43     output_height = (input_height + stride_height - kheight) / stride_height;
     44     output_width = (input_width + stride_width - kwidth) / stride_width;
     45   } else if (padding_type == PaddingType::kSame) {
     46     output_height = (input_height + stride_height - 1) / stride_height;
     47     output_width = (input_width + stride_width - 1) / stride_width;
     48   } else {
     49     LOG(FATAL) << "Only supporting SAME or VALID padding";
     50   }
     52   fixed_padding->height = std::max(
     53       0, ((output_height - 1) * stride_height + kheight - input_height) / 2);
     54   fixed_padding->width = std::max(
     55       0, ((output_width - 1) * stride_width + kwidth - input_width) / 2);
     57   // Actually had to debug a situation where those were negative due to bad
     58   // propagation of placeholder -1 sizes in TensorFlowReshape.
     59   CHECK_GT(output_width, 0);
     60   CHECK_GT(output_height, 0);
     61   output_shape->ReplaceDims({batch, output_height, output_width, output_depth});
     62 }
     64 void ComputeBinaryOperatorOutputSize(const Shape& input_shape_x,
     65                                      const Shape& input_shape_y,
     66                                      Array* output_array) {
     67   // This matches the code in BroadcastBinaryOpShapeFn from tensorflow.
     68   // It zips together the two input shapes and pads with 1 to make them the
     69   // same length. For each dimension we broadcast if either dimension is 1 and
     70   // otherwise expect them to match.
     71   int rank_x = input_shape_x.dimensions_count();
     72   int rank_y = input_shape_y.dimensions_count();
     73   int rank_out = std::max(rank_x, rank_y);
     74   std::vector<int>* dims_out = output_array->mutable_shape()->mutable_dims();
     75   dims_out->clear();
     76   dims_out->reserve(rank_out);
     77   for (int i = 0; i < rank_out; ++i) {
     78     int dim_x = i < (rank_out - rank_x)
     79                     ? 1
     80                     : input_shape_x.dims(i - (rank_out - rank_x));
     81     bool dim_y_is_one = i < (rank_out - rank_y);
     82     int dim_y = dim_y_is_one ? 1 : input_shape_y.dims(i - (rank_out - rank_y));
     83     if (dim_x == -1 || dim_y == -1) {
     84       // One or both dimensions is unknown.
     85       QCHECK(false) << "Shapes must be specified";
     86     } else if (dim_x == 1 || dim_y == 1) {
     87       // Broadcast one dimension to the other that is 1.
     88       if (dim_x == 1 && !dim_y_is_one) {
     89         // Broadcast dim_y to dim_x (1).
     90         dims_out->push_back(dim_y);
     91       } else {
     92         // Broadcast dim_x to dim_y (1).
     93         DCHECK_EQ(dim_y, 1);
     94         dims_out->push_back(dim_x);
     95       }
     96     } else {
     97       // Expect the dimensions to match.
     98       CHECK_EQ(dim_x, dim_y) << "Dimensions must match";
     99       dims_out->push_back(dim_x);
    100     }
    101   }
    102   CHECK(output_array->has_shape());
    103 }
    105 int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
    106   const string& weights_name = op.inputs[1];
    107   const auto& weights_shape = model.GetArray(weights_name).shape();
    108   if (op.type == OperatorType::kConv ||
    109       op.type == OperatorType::kFullyConnected) {
    110     return weights_shape.dims(0);
    111   } else if (op.type == OperatorType::kDepthwiseConv) {
    112     return weights_shape.dims(3);
    113   } else {
    114     LOG(FATAL) << "Unhandled operator type";
    115   }
    116 }
    118 bool EnsureBiasVectorShape(Model* model, Operator* op) {
    119   const string& weights_name = op->inputs[1];
    120   const auto& weights_array = model->GetArray(weights_name);
    121   // Yield until weights shape has been resolved.
    122   if (!weights_array.has_shape()) {
    123     return false;
    124   }
    126   if (op->inputs.size() < 3) {
    127     return false;
    128   }
    129   auto& bias_array = model->GetArray(op->inputs[2]);
    130   if (bias_array.has_shape()) {
    131     return true;
    132   }
    134   const int output_depth = GetOutputDepthFromWeights(*model, *op);
    135   bias_array.copy_shape(Shape({output_depth}));
    137   auto& float_buffer = bias_array.GetMutableBuffer<ArrayDataType::kFloat>();
    138   float_buffer.data.resize(output_depth, 0);
    140   return true;
    141 }
    143 void ProcessConvOperator(Model* model, ConvOperator* op) {
    144   if (!EnsureBiasVectorShape(model, op)) {
    145     return;
    146   }
    148   const auto& input_array = model->GetArray(op->inputs[0]);
    149   // Yield until input dims have been resolved.
    150   if (!input_array.has_shape()) {
    151     return;
    152   }
    153   const auto& input_shape = input_array.shape();
    154   CHECK_EQ(input_shape.dimensions_count(), 4);
    156   const auto& weights_array = model->GetArray(op->inputs[1]);
    157   // Yield until weights dims have been resolved.
    158   if (!weights_array.has_shape()) {
    159     return;
    160   }
    161   const auto& weights_shape = weights_array.shape();
    162   CHECK_EQ(weights_shape.dimensions_count(), 4);
    164   auto& output_array = model->GetArray(op->outputs[0]);
    165   const int output_depth = weights_shape.dims(0);
    166   const int kheight = weights_shape.dims(1);
    167   const int kwidth = weights_shape.dims(2);
    168   ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
    169                    op->stride_height, op->padding.type,
    170                    output_array.mutable_shape(),
    171                    &op->padding.GetOrCreateFixedPadding());
    172   CHECK_EQ(output_array.shape().dimensions_count(), 4);
    174   // Set im2col array dimensions if there is one.
    175   if (op->outputs.size() == 2) {
    176     const auto& output_shape = output_array.shape();
    177     const int input_depth = weights_shape.dims(3);
    178     auto& im2col_array = model->GetArray(op->outputs[1]);
    179     im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
    180                                   output_shape.dims(2),
    181                                   input_depth * kheight * kwidth});
    182   }
    183 }
    185 void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
    186   if (!EnsureBiasVectorShape(model, op)) {
    187     return;
    188   }
    190   const auto& input_array = model->GetArray(op->inputs[0]);
    191   // Yield until input dims have been resolved.
    192   if (!input_array.has_shape()) {
    193     return;
    194   }
    195   const auto& input_shape = input_array.shape();
    196   CHECK_EQ(input_shape.dimensions_count(), 4);
    198   const auto& weights_array = model->GetArray(op->inputs[1]);
    199   // Yield until weights dims have been resolved.
    200   if (!weights_array.has_shape()) {
    201     return;
    202   }
    203   const auto& weights_shape = weights_array.shape();
    204   CHECK_EQ(weights_shape.dimensions_count(), 4);
    206   const string& output_name = op->outputs[0];
    207   const int input_depth = input_shape.dims(3);
    208   const int output_depth = weights_shape.dims(3);
    209   // TensorFlow doesn't define the depth_multiplier value on DepthwiseConv ops,
    210   // instead it has to be inferred from the weights dims. However, once we are
    211   // here, weights dims have already been converted to our own internal format,
    212   // where the multiplier is no longer readily apparent. So instead we get it
    213   // as the quotient of output and input depths. We only want to do that when
    214   // depth_multiplier had the zero value: any other value should be checked
    215   // as done by the next if() below.
    216   if (!op->depth_multiplier) {
    217     op->depth_multiplier = output_depth / input_depth;
    218   }
    219   QCHECK_EQ(output_depth, input_depth * op->depth_multiplier)
    220       << "input/output depths and depth_multiplier don't match";
    222   const int kheight = weights_shape.dims(1);
    223   const int kwidth = weights_shape.dims(2);
    224   ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
    225                    op->stride_height, op->padding.type,
    226                    model->GetArray(output_name).mutable_shape(),
    227                    &op->padding.GetOrCreateFixedPadding());
    228 }
    230 void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
    231   const auto& input_array = model->GetArray(op->inputs[0]);
    232   // Yield until input dims have been resolved.
    233   if (!input_array.has_shape()) {
    234     return;
    235   }
    236   const auto& input_shape = input_array.shape();
    237   CHECK_EQ(input_shape.dimensions_count(), 4);
    239   const string& output_name = op->outputs[0];
    240   const int block_size = op->block_size;
    241   CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
    242   const int batch = input_shape.dims(0);
    243   const int height = input_shape.dims(1);
    244   const int width = input_shape.dims(2);
    245   const int depth = input_shape.dims(3);
    246   QCHECK_EQ(depth % (block_size * block_size), 0);
    248   model->GetArray(output_name)
    249       .copy_shape(Shape({batch, height * block_size, width * block_size,
    250                          depth / block_size / block_size}));
    251 }
    253 void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
    254   const auto& input_array = model->GetArray(op->inputs[0]);
    255   // Yield until input dims have been resolved.
    256   if (!input_array.has_shape()) {
    257     return;
    258   }
    259   const auto& input_shape = input_array.shape();
    260   CHECK_EQ(input_shape.dimensions_count(), 4);
    262   const string& output_name = op->outputs[0];
    263   const int block_size = op->block_size;
    264   CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
    265   const int batch = input_shape.dims(0);
    266   const int height = input_shape.dims(1);
    267   const int width = input_shape.dims(2);
    268   const int depth = input_shape.dims(3);
    269   QCHECK_EQ(width % block_size, 0);
    270   QCHECK_EQ(height % block_size, 0);
    272   model->GetArray(output_name)
    273       .copy_shape(Shape({batch, height / block_size, width / block_size,
    274                          depth * block_size * block_size}));
    275 }
    277 void ProcessFillOperator(Model* model, FillOperator* op) {
    278   CHECK_EQ(op->inputs.size(), 2);
    279   CHECK_EQ(op->outputs.size(), 1);
    280   auto& output_array = model->GetArray(op->outputs[0]);
    281   if (output_array.has_shape()) {
    282     // We have already run
    283     return;
    284   }
    286   auto& dims_array = model->GetArray(op->inputs[0]);
    287   if (!dims_array.has_shape()) {
    288     // Yield until dims shape been resolved.
    289     return;
    290   }
    291   if (!dims_array.buffer) {
    292     // Yield until the dims are constant
    293     return;
    294   }
    295   CHECK(dims_array.data_type == ArrayDataType::kInt32) << "dims must be int32";
    296   CHECK_LE(RequiredBufferSizeForShape(dims_array.shape()), 4)
    297       << "dims vector can be no larger than 4 values";
    299   std::vector<int32> const& dims =
    300       dims_array.GetBuffer<ArrayDataType::kInt32>().data;
    301   *(output_array.mutable_shape()->mutable_dims()) = dims;
    302 }
    304 void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
    305   if (!EnsureBiasVectorShape(model, op)) {
    306     return;
    307   }
    309   const auto& input_array = model->GetArray(op->inputs[0]);
    310   // Yield until input dims have been resolved.
    311   if (!input_array.has_shape()) {
    312     return;
    313   }
    314   const auto& input_shape = input_array.shape();
    315   CHECK_GE(input_shape.dimensions_count(), 1);
    317   const auto& weights_array = model->GetArray(op->inputs[1]);
    318   // Yield until weights dims have been resolved.
    319   if (!weights_array.has_shape()) {
    320     return;
    321   }
    322   const auto& weights_shape = weights_array.shape();
    324   const int weights_output_depth = weights_shape.dims(0);
    325   CHECK_EQ(weights_shape.dimensions_count(), 2);
    327   const int input_overall_size = RequiredBufferSizeForShape(input_shape);
    328   const int matmul_repeats = input_overall_size / weights_shape.dims(1);
    329   CHECK_EQ(matmul_repeats * weights_shape.dims(1), input_overall_size);
    331   auto& output_array = model->GetArray(op->outputs[0]);
    332   output_array.copy_shape(Shape({matmul_repeats, weights_output_depth}));
    333 }
    335 void ProcessTensorFlowReshapeOperator(Model* model,
    336                                       TensorFlowReshapeOperator* op) {
    337   auto& output_array = model->GetArray(op->outputs[0]);
    338   if (output_array.has_shape()) {
    339     // We have already run
    340     return;
    341   }
    343   const auto& input_array = model->GetArray(op->inputs[0]);
    344   if (!input_array.has_shape()) {
    345     // Yield until input dims have been resolved.
    346     return;
    347   }
    348   const auto& input_shape = input_array.shape();
    350   auto& shape_array = model->GetArray(op->inputs[1]);
    351   if (!shape_array.has_shape()) {
    352     // Yield until target_shape shape been resolved.
    353     return;
    354   }
    355   if (!shape_array.buffer) {
    356     // Yield until the target_shape is constant
    357     return;
    358   }
    359   CHECK(shape_array.data_type == ArrayDataType::kInt32)
    360       << "Reshape dims must be int32";
    362   // shape_data is the raw array of ints describing the shape
    363   // in the TensorFlow node. We intentionally make a copy here, rather than
    364   // modify wildcards in-place below, because in some graphs, the same shape
    365   // array with a wildcard may be referenced from multiple Reshape nodes, where
    366   // the wildcard needs to resolved to distinct values.
    367   std::vector<int32> shape_data =
    368       shape_array.GetBuffer<ArrayDataType::kInt32>().data;
    369   // The Reshape shape may have a wildcard dim, encoded as -1.
    370   bool has_wildcard = false;
    371   int wildcard_index = 0;
    372   int product_non_wildcard_dims = 1;
    373   for (int i = 0; i < shape_data.size(); i++) {
    374     if (shape_data[i] == -1) {
    375       CHECK(!has_wildcard);
    376       has_wildcard = true;
    377       wildcard_index = i;
    378     } else {
    379       product_non_wildcard_dims *= shape_data[i];
    380     }
    381   }
    382   const int input_flat_size = RequiredBufferSizeForShape(input_shape);
    383   if (has_wildcard) {
    384     CHECK_GE(input_flat_size, product_non_wildcard_dims)
    385         << "Array not large enough to fill the requested dimensions for "
    386            "Reshape op with output \""
    387         << op->outputs[0] << "\". Are your input shapes correct?";
    388     shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims;
    389   }
    390   auto& output_shape = *output_array.mutable_shape();
    391   *output_shape.mutable_dims() = shape_data;
    392   CHECK_EQ(input_flat_size, RequiredBufferSizeForShape(output_shape))
    393       << "Input cannot be reshaped to requested dimensions for Reshape op with "
    394          "output \""
    395       << op->outputs[0] << "\". Are your input shapes correct?";
    396 }
    398 void ProcessSimpleOperator(Model* model, Operator* op) {
    399   const auto& input_array = model->GetArray(op->inputs[0]);
    400   // Yield until input dims have been resolved.
    401   if (!input_array.has_shape()) {
    402     return;
    403   }
    405   const string& output_name = op->outputs[0];
    406   auto& output_array = model->GetArray(output_name);
    407   if (output_array.has_shape()) {
    408     return;
    409   }
    411   output_array.copy_shape(input_array.shape());
    412 }
    414 void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
    415   CHECK_EQ(op->inputs.size(), 2);
    416   const auto& input0_array = model->GetArray(op->inputs[0]);
    417   const auto& input1_array = model->GetArray(op->inputs[1]);
    418   // Yield until input dims have been resolved.
    419   if (!input0_array.has_shape() || !input1_array.has_shape()) {
    420     return;
    421   }
    422   const string& output_name = op->outputs[0];
    423   auto& output_array = model->GetArray(output_name);
    424   ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
    425                                   &output_array);
    426 }
    428 void ProcessAddNOperator(Model* model, Operator* op) {
    429   // Yield until all input dims have been resolved.
    430   //
    431   // TODO(myenik): Since AddN does not support broadcasting, maybe we could
    432   // actually use this to improve shape propagation by propagating the shape of
    433   // one input to all other inputs once it is resolved instead of just the
    434   // output, since all inputs must be the same size and shape for a well-formed
    435   // graph.
    436   for (const auto& input : op->inputs) {
    437     const auto& input_array = model->GetArray(input);
    438     if (!input_array.has_shape()) {
    439       return;
    440     }
    441   }
    443   // AddN does not support broadcasting, all inputs must be the same shape, so
    444   // we just take the first input shape and apply it to the output.
    445   const auto& input0_array = model->GetArray(op->inputs[0]);
    446   auto& output_array = model->GetArray(op->outputs[0]);
    447   output_array.copy_shape(input0_array.shape());
    448 }
    450 bool KeepDims(const Operator& op) {
    451   switch (op.type) {
    452     case OperatorType::kTensorFlowMin:
    453       return static_cast<const TensorFlowMinOperator&>(op).keep_dims;
    454     case OperatorType::kTensorFlowMax:
    455       return static_cast<const TensorFlowMaxOperator&>(op).keep_dims;
    456     case OperatorType::kTensorFlowSum:
    457       return static_cast<const TensorFlowSumOperator&>(op).keep_dims;
    458     case OperatorType::kMean:
    459       return static_cast<const MeanOperator&>(op).keep_dims;
    460     default:
    461       LOG(FATAL) << "Not a reduction operator!";
    462       return false;
    463   }
    464 }
    466 void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
    467   CHECK_LE(op->inputs.size(), 2);
    468   auto& output_array = model->GetArray(op->outputs[0]);
    469   if (output_array.has_shape()) {
    470     return;
    471   }
    472   const auto& input_array = model->GetArray(op->inputs[0]);
    473   if (!input_array.has_shape()) {
    474     return;
    475   }
    476   const auto& input_shape = input_array.shape();
    477   const bool keep_dims = KeepDims(*op);
    478   if (op->inputs.size() == 2) {
    479     // There is a reduction_indices input.
    480     const auto& reduction_array = model->GetArray(op->inputs[1]);
    481     if (!reduction_array.buffer) {
    482       return;
    483     }
    484     CHECK(reduction_array.buffer->type == ArrayDataType::kInt32);
    485     const auto& reduction_array_vals =
    486         reduction_array.GetBuffer<ArrayDataType::kInt32>().data;
    487     auto& output_dims = *output_array.mutable_shape()->mutable_dims();
    488     output_dims.clear();
    489     for (int i = 0; i < input_shape.dimensions_count(); i++) {
    490       bool is_reduction_dim = false;
    491       for (int r : reduction_array_vals) {
    492         if (i == r) {
    493           is_reduction_dim = true;
    494         }
    495       }
    496       if (!is_reduction_dim) {
    497         output_dims.push_back(input_shape.dims(i));
    498       } else if (keep_dims) {
    499         output_dims.push_back(1);
    500       }
    501     }
    502   } else {
    503     // No reduction_indices means complete reduction to a single scalar.
    504     if (keep_dims) {
    505       output_array.copy_shape(input_shape);
    506     } else {
    507       output_array.copy_shape(Shape({}));
    508     }
    509   }
    510 }
    512 void ProcessSliceOperator(Model* model, SliceOperator* op) {
    513   CHECK_EQ(op->inputs.size(), 3);
    514   CHECK_EQ(op->outputs.size(), 1);
    516   // Yield until the Slice params have been resolved.
    517   if (op->begin.empty()) return;
    519   // Yield until input dims have been resolved.
    520   const auto& input_array = model->GetArray(op->inputs[0]);
    521   if (!input_array.has_shape()) return;
    522   const Shape& input_shape = input_array.shape();
    524   auto& output_array = model->GetArray(op->outputs[0]);
    525   if (output_array.has_shape()) return;
    527   CHECK_EQ(input_shape.dims().size(), op->size.size());
    528   CHECK_EQ(op->begin.size(), op->size.size());
    530   std::vector<int> output_dims;
    531   for (int i = 0; i < op->begin.size(); ++i) {
    532     int size = op->size[i];
    533     if (size == -1) {
    534       size = input_array.shape().dims(i) - op->begin[i];
    535     }
    536     output_dims.push_back(size);
    537   }
    539   *output_array.mutable_shape()->mutable_dims() = output_dims;
    540 }
    542 void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
    543   const string& input_name = op->inputs[0];
    544   const auto& input_array = model->GetArray(input_name);
    545   // Yield until input dims have been resolved.
    546   if (!input_array.has_shape()) {
    547     return;
    548   }
    549   const auto& input_shape = input_array.shape();
    550   const string& output_name = op->outputs[0];
    551   Shape* output_shape = model->GetArray(output_name).mutable_shape();
    552   ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order,
    553               output_shape);
    554 }
    556 void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
    557   // Yield until input dims have been resolved.
    558   for (const auto& input_name : op->inputs) {
    559     auto& input_array = model->GetArray(input_name);
    560     if (!input_array.has_shape()) {
    561       return;
    562     }
    563   }
    564   auto& output_array = model->GetArray(op->outputs[0]);
    565   // Use 0 input as basis for output dimensions.
    566   const auto& first_input_array = model->GetArray(op->inputs[0]);
    567   output_array.copy_shape(first_input_array.shape());
    568   // Negative axis means the count starts at the back of the dims().
    569   int axis = op->axis;
    570   if (axis < 0) axis += first_input_array.shape().dims().size();
    571   // Determine the concat size, and enfore that all inputs have
    572   // the same dimensions count.
    573   int concat_size = 0;
    574   for (const auto& input_name : op->inputs) {
    575     auto& input_array = model->GetArray(input_name);
    576     CHECK(input_array.has_shape());
    577     if (input_array.shape().dimensions_count() == 0) {
    578       continue;
    579     }
    580     CHECK_EQ(input_array.shape().dimensions_count(),
    581              output_array.shape().dimensions_count());
    582     const std::vector<int>& input_dims = input_array.shape().dims();
    583     CHECK_LT(axis, input_dims.size());
    584     concat_size += input_dims[axis];
    585   }
    586   // Write out the concat_size on the output array shape.
    587   auto& output_shape = *output_array.mutable_shape();
    588   auto& output_dims = *output_shape.mutable_dims();
    589   CHECK_LT(axis, output_shape.dimensions_count());
    590   output_dims[axis] = concat_size;
    591 }
    593 void ProcessRangeOperator(Model* model, RangeOperator* op) {
    594   CHECK_EQ(op->inputs.size(), 3);
    595   const auto& start_array = model->GetArray(op->inputs[0]);
    596   if (!start_array.has_shape()) {
    597     // Yield until input dims have been resolved.
    598     return;
    599   }
    600   const auto& limit_array = model->GetArray(op->inputs[1]);
    601   if (!limit_array.has_shape()) {
    602     return;
    603   }
    604   const auto& delta_array = model->GetArray(op->inputs[2]);
    605   if (!delta_array.has_shape()) {
    606     return;
    607   }
    609   if (!IsConstantParameterArray(*model, op->inputs[0])) {
    610     // Yield until inputs are constant.
    611     return;
    612   }
    613   if (!IsConstantParameterArray(*model, op->inputs[1])) {
    614     return;
    615   }
    616   if (!IsConstantParameterArray(*model, op->inputs[2])) {
    617     return;
    618   }
    620   CHECK(start_array.data_type == ArrayDataType::kInt32)
    621       << "Range op inputs must be int32.";
    622   CHECK(limit_array.data_type == ArrayDataType::kInt32)
    623       << "Range op inputs must be int32.";
    624   CHECK(delta_array.data_type == ArrayDataType::kInt32)
    625       << "Range op inputs must be int32.";
    626   CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1)
    627       << "Range op inputs must be scalar.";
    628   CHECK_EQ(RequiredBufferSizeForShape(limit_array.shape()), 1)
    629       << "Range op inputs must be scalar.";
    630   CHECK_EQ(RequiredBufferSizeForShape(delta_array.shape()), 1)
    631       << "Range op inputs must be scalar.";
    632   int size = floor((limit_array.GetBuffer<ArrayDataType::kInt32>().data[0] -
    633                     start_array.GetBuffer<ArrayDataType::kInt32>().data[0]) /
    634                    delta_array.GetBuffer<ArrayDataType::kInt32>().data[0]);
    636   // Only set the output shape. Contents are set by ResolveConstantRange.
    637   CHECK_EQ(op->outputs.size(), 1);
    638   auto& output_array = model->GetArray(op->outputs[0]);
    639   Shape* output_shape = output_array.mutable_shape();
    640   output_shape->ReplaceDims({size});
    641 }
    643 void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
    644   CHECK_EQ(op->inputs.size(), 2);
    645   const string& input_name = op->inputs[1];
    646   const auto& input_array = model->GetArray(input_name);
    647   // Yield until input dims have been resolved.
    648   if (!input_array.has_shape()) {
    649     return;
    650   }
    651   const Shape& input_shape = input_array.shape();
    653   // Yield until axis is constant.
    654   if (!IsConstantParameterArray(*model, op->inputs[0])) {
    655     return;
    656   }
    658   const auto& axis_array = model->GetArray(op->inputs[0]);
    660   // Yield until axis dims have been resolved.
    661   if (!axis_array.has_shape()) {
    662     return;
    663   }
    665   CHECK(axis_array.data_type == ArrayDataType::kInt32)
    666       << "Axis array must be int32.";
    667   CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
    668       << "Axis array must be scalar.";
    670   int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
    671   if (axis < 0) {
    672     axis += input_shape.dimensions_count();
    673   }
    675   const int split_dim = input_shape.dims(axis);
    676   CHECK_EQ(split_dim % op->num_split, 0);
    677   const int split_depth = split_dim / op->num_split;
    679   Shape output_shape = input_shape;
    680   (*output_shape.mutable_dims())[axis] = split_depth;
    682   CHECK_EQ(op->outputs.size(), op->num_split);
    683   for (const auto& output : op->outputs) {
    684     model->GetArray(output).copy_shape(output_shape);
    685   }
    686 }
    688 void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
    689   const string& input_name = op->inputs[0];
    690   const auto& input_array = model->GetArray(input_name);
    691   // Yield until input dims have been resolved.
    692   if (!input_array.has_shape()) {
    693     return;
    694   }
    695   const auto& input_shape = input_array.shape();
    696   CHECK_EQ(input_shape.dimensions_count(), 4);
    697   const string& output_name = op->outputs[0];
    698   const int output_depth = input_shape.dims(3);
    699   ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
    700                    op->stride_width, op->stride_height, op->padding.type,
    701                    model->GetArray(output_name).mutable_shape(),
    702                    &op->padding.GetOrCreateFixedPadding());
    703 }
    705 void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
    706   const string& input_name = op->inputs[0];
    707   const auto& input_array = model->GetArray(input_name);
    708   // Yield until input dims have been resolved.
    709   if (!input_array.has_shape()) {
    710     return;
    711   }
    712   const auto& input_shape = input_array.shape();
    713   CHECK_EQ(input_shape.dimensions_count(), 4);
    714   const string& output_name = op->outputs[0];
    715   const int output_depth = input_shape.dims(3);
    716   ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
    717                    op->stride_width, op->stride_height, op->padding.type,
    718                    model->GetArray(output_name).mutable_shape(),
    719                    &op->padding.GetOrCreateFixedPadding());
    720 }
    722 void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
    723   const string& input_name = op->inputs[0];
    724   const auto& input_array = model->GetArray(input_name);
    725   // Yield until input dims have been resolved.
    726   if (!input_array.has_shape()) {
    727     return;
    728   }
    729   const auto& input_shape = input_array.shape();
    730   if (input_shape.dimensions_count() < 4) {
    731     LOG(FATAL) << "missing dimensions for " << input_name;
    732   }
    733   const string& output_name = op->outputs[0];
    734   const int output_depth = input_shape.dims(3);
    735   ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
    736                    op->stride_width, op->stride_height, op->padding.type,
    737                    model->GetArray(output_name).mutable_shape(),
    738                    &op->padding.GetOrCreateFixedPadding());
    739 }
    741 void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
    742   CHECK_EQ(op->inputs.size(), 2);
    743   CHECK_EQ(op->outputs.size(), 1);
    745   if (!model->GetArray(op->inputs[0]).has_shape() ||
    746       !model->GetArray(op->inputs[1]).has_shape()) {
    747     return;
    748   }
    749   const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
    751   const string& output_size_name = op->inputs[1];
    752   const auto& output_size_array = model->GetArray(output_size_name);
    753   CHECK(output_size_array.data_type == ArrayDataType::kInt32);
    754   CHECK(output_size_array.has_shape());
    755   const auto& output_size_shape = output_size_array.shape();
    756   CHECK_EQ(output_size_shape.dimensions_count(), 1);
    757   CHECK_EQ(output_size_shape.dims(0), 2);
    758   if (!output_size_array.buffer) {
    759     return;
    760   }
    761   std::vector<int32> output_shape =
    762       output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
    763   model->GetArray(op->outputs[0])
    764       .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
    765                          output_shape[1], input_data_shape.dims(3)}));
    766 }
    768 void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
    769   // Only required for compact LstmCell with default NUM_INPUTS of inputs.
    770   if (op->inputs.size() != LstmCellOperator::NUM_INPUTS) return;
    772   const auto& input_array =
    773       model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]);
    774   // Yield until all input dims have been resolved.
    775   if (!input_array.has_shape()) {
    776     return;
    777   }
    778   const auto& input_shape = input_array.shape();
    779   CHECK_GE(input_shape.dimensions_count(), 2);
    781   const auto& prev_activ_array =
    782       model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]);
    783   // Yield until all input dims have been resolved.
    784   if (!prev_activ_array.has_shape()) {
    785     return;
    786   }
    787   const auto& prev_activ_shape = prev_activ_array.shape();
    788   CHECK_GE(prev_activ_shape.dimensions_count(), 2);
    790   const auto& weights_array =
    791       model->GetArray(op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
    792   // Yield until weights dims have been resolved.
    793   if (!weights_array.has_shape()) {
    794     return;
    795   }
    796   const auto& weights_shape = weights_array.shape();
    797   CHECK_EQ(weights_shape.dimensions_count(), 2);
    799   const auto& bias_array =
    800       model->GetArray(op->inputs[LstmCellOperator::BIASES_INPUT]);
    801   // Yield until bias dims have been resolved.
    802   if (!bias_array.has_shape()) {
    803     return;
    804   }
    805   const auto& bias_shape = bias_array.shape();
    806   CHECK_GE(bias_shape.dimensions_count(), 1);
    808   const auto& prev_state_array =
    809       model->GetArray(op->inputs[LstmCellOperator::PREV_STATE_INPUT]);
    810   // Yield until all input dims have been resolved.
    811   if (!prev_state_array.has_shape()) {
    812     return;
    813   }
    814   const auto& prev_state_shape = prev_state_array.shape();
    815   CHECK_GE(prev_state_shape.dimensions_count(), 2);
    817   const int fc_output_depth = weights_shape.dims(0);
    818   CHECK_EQ(fc_output_depth, bias_shape.dims(0));
    819   CHECK_EQ(fc_output_depth % 4, 0);
    820   const int depth = fc_output_depth / 4;
    822   const int input_depth = input_shape.dims(input_shape.dimensions_count() - 1);
    823   const int fc_input_depth = weights_shape.dims(1);
    824   CHECK_EQ(input_depth + depth, fc_input_depth);
    825   Shape output_shape(input_shape);
    826   (*output_shape.mutable_dims())[output_shape.dimensions_count() - 1] = depth;
    828   // Set output dimensions
    829   model->GetArray(op->outputs[LstmCellOperator::STATE_OUTPUT])
    830       .copy_shape(output_shape);
    831   model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT])
    832       .copy_shape(output_shape);
    834   Shape concat_temp_shape(input_shape);
    835   (*concat_temp_shape
    836         .mutable_dims())[concat_temp_shape.dimensions_count() - 1] =
    837       fc_input_depth;
    838   model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP])
    839       .copy_shape(concat_temp_shape);
    841   Shape activ_temp_shape(input_shape);
    842   (*activ_temp_shape.mutable_dims())[activ_temp_shape.dimensions_count() - 1] =
    843       fc_output_depth;
    844   model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP])
    845       .copy_shape(activ_temp_shape);
    846 }
    848 void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
    849   const auto& input_array = model->GetArray(op->inputs[0]);
    850   // Yield until input dims have been resolved.
    851   if (!input_array.has_shape()) {
    852     return;
    853   }
    854   const auto& input_shape = input_array.shape();
    855   // This method only handles input dimensions of 4.
    856   if (input_shape.dimensions_count() != 4) {
    857     return;
    858   }
    859   const auto input_height = input_shape.dims(1);
    860   const auto input_width = input_shape.dims(2);
    862   const auto& block_shape_array = model->GetArray(op->inputs[1]);
    863   const auto& paddings_array = model->GetArray(op->inputs[2]);
    864   const auto& block_shape_array_shape = block_shape_array.shape();
    865   const auto& paddings_array_shape = paddings_array.shape();
    866   QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
    867   QCHECK_EQ(paddings_array_shape.dimensions_count(), 2);
    869   // We only support two dimensions.
    870   QCHECK_EQ(block_shape_array_shape.dims(0), 2);
    871   if (!block_shape_array.buffer) {
    872     return;
    873   }
    874   QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
    875   const auto& block_shape_data =
    876       block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
    877   auto block_height = block_shape_data[0];
    878   auto block_width = block_shape_data[1];
    880   QCHECK_EQ(paddings_array_shape.dims(0), 2);  // Number of block dimensions
    881   QCHECK_EQ(paddings_array_shape.dims(1), 2);  // Two parameters per dimension.
    882   if (!paddings_array.buffer) {
    883     return;
    884   }
    885   QCHECK(paddings_array.data_type == ArrayDataType::kInt32);
    886   const auto& paddings_data =
    887       paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
    888   int height_with_paddings = input_height + paddings_data[0] + paddings_data[1];
    889   int width_with_paddings = input_width + paddings_data[2] + paddings_data[3];
    890   QCHECK_EQ(height_with_paddings % block_height, 0);
    891   QCHECK_EQ(width_with_paddings % block_width, 0);
    892   int output_height = height_with_paddings / block_height;
    893   int output_width = width_with_paddings / block_width;
    895   model->GetArray(op->outputs[0])
    896       .copy_shape(Shape({input_shape.dims(0) * block_height * block_width,
    897                          output_height, output_width, input_shape.dims(3)}));
    898 }
    900 void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
    901   const auto& input_array = model->GetArray(op->inputs[0]);
    902   // Yield until input dims have been resolved.
    903   if (!input_array.has_shape()) {
    904     return;
    905   }
    906   const auto& input_shape = input_array.shape();
    907   CHECK_EQ(input_shape.dimensions_count(), 4);
    908   const auto input_height = input_shape.dims(1);
    909   const auto input_width = input_shape.dims(2);
    911   const auto& block_shape_array = model->GetArray(op->inputs[1]);
    912   const auto& crops_array = model->GetArray(op->inputs[2]);
    913   const auto& block_shape_array_shape = block_shape_array.shape();
    914   const auto& crops_array_shape = crops_array.shape();
    915   QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
    916   QCHECK_EQ(crops_array_shape.dimensions_count(), 2);
    918   // We only support two dimensions.
    919   QCHECK_EQ(block_shape_array_shape.dims(0), 2);
    920   if (!block_shape_array.buffer) {
    921     return;
    922   }
    923   QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
    924   const auto& block_shape_data =
    925       block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
    926   auto block_height = block_shape_data[0];
    927   auto block_width = block_shape_data[1];
    929   QCHECK_EQ(crops_array_shape.dims(0), 2);  // Number of block dimensions
    930   QCHECK_EQ(crops_array_shape.dims(1), 2);  // Two parameters per dimension.
    931   if (!crops_array.buffer) {
    932     return;
    933   }
    934   QCHECK(crops_array.data_type == ArrayDataType::kInt32);
    935   const auto& crops_data = crops_array.GetBuffer<ArrayDataType::kInt32>().data;
    936   // We don't support crops now.
    937   QCHECK_EQ(crops_data[0], 0);
    938   QCHECK_EQ(crops_data[1], 0);
    939   QCHECK_EQ(crops_data[2], 0);
    940   QCHECK_EQ(crops_data[3], 0);
    942   QCHECK_EQ(input_shape.dims(0) % (block_height * block_width), 0);
    944   int output_height = input_height * block_height;
    945   int output_width = input_width * block_width;
    947   model->GetArray(op->outputs[0])
    948       .copy_shape(Shape({input_shape.dims(0) / (block_height * block_width),
    949                          output_height, output_width, input_shape.dims(3)}));
    950 }
    952 void ProcessGatherOperator(Model* model, GatherOperator* op) {
    953   const auto& input_array = model->GetArray(op->inputs[0]);
    954   const auto& indices_array = model->GetArray(op->inputs[1]);
    955   auto& output_array = model->GetArray(op->outputs[0]);
    957   // Bail if we already know the output shape.
    958   if (output_array.has_shape()) {
    959     return;
    960   }
    962   // Yield until input dims have been resolved.
    963   if (!input_array.has_shape() || !indices_array.has_shape()) {
    964     return;
    965   }
    967   const auto& input_shape = input_array.shape();
    968   const auto& indices_shape = indices_array.shape();
    969   QCHECK_GE(input_shape.dimensions_count(), 1);
    970   op->input_rank = input_shape.dimensions_count();
    972   // We only support 1-D indices.
    973   QCHECK_EQ(indices_shape.dimensions_count(), 1);
    975   // Copy the input dimensions to the output except for dimension 0,
    976   // where the dimension of indices_shape is used.
    977   // TODO(mgubin): if axis != 0 this is not true, change when it's supported.
    978   auto output_dims = output_array.mutable_shape()->mutable_dims();
    979   output_dims->push_back(indices_shape.dims(0));
    980   for (int dim = 1; dim < input_shape.dimensions_count(); dim++) {
    981     output_dims->push_back(input_shape.dims(dim));
    982   }
    983 }
    985 void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) {
    986   const auto& input_values = model->GetArray(op->inputs[0]);
    987   const auto& input_k = model->GetArray(op->inputs[1]);
    988   auto& output_indexes = model->GetArray(op->outputs[0]);
    989   auto& output_values = model->GetArray(op->outputs[1]);
    991   // Bail if we already know the output shape.
    992   if (output_indexes.has_shape()) {
    993     QCHECK(output_values.has_shape());
    994     return;
    995   }
    997   // Yield until input dims have been resolved.
    998   if (!input_values.has_shape()) {
    999     return;
   1000   }
   1002   const auto& input_values_shape = input_values.shape();
   1003   auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims();
   1004   auto output_values_dims = output_values.mutable_shape()->mutable_dims();
   1005   for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) {
   1006     output_indexes_dims->push_back(input_values_shape.dims(dim));
   1007     output_values_dims->push_back(input_values_shape.dims(dim));
   1008   }
   1009   // If the value is initialized, we can specify the last dimension, otherwise
   1010   // unknown.
   1011   if (input_k.buffer) {
   1012     const int32_t k_value = input_k.GetBuffer<ArrayDataType::kInt32>().data[0];
   1013     output_indexes_dims->push_back(k_value);
   1014     output_values_dims->push_back(k_value);
   1016   } else {
   1017     output_indexes_dims->push_back(0);
   1018     output_values_dims->push_back(0);
   1019   }
   1020 }
   1022 void ProcessPadOperator(Model* model, PadOperator* op) {
   1023   CHECK_EQ(op->inputs.size(), 2);
   1024   CHECK_EQ(op->outputs.size(), 1);
   1026   const auto& input_array = model->GetArray(op->inputs[0]);
   1028   // Yield until input dims have been resolved.
   1029   if (!input_array.has_shape()) return;
   1031   if (op->left_padding.empty()) return;
   1032   CHECK_EQ(op->left_padding.size(), op->right_padding.size());
   1034   auto& output_array = model->GetArray(op->outputs[0]);
   1035   if (output_array.has_shape()) return;
   1037   Shape output_shape = input_array.shape();
   1038   std::vector<int>& dims = *output_shape.mutable_dims();
   1039   CHECK_EQ(op->left_padding.size(), dims.size());
   1041   for (int i = 0; i < op->left_padding.size(); ++i) {
   1042     dims[i] += op->left_padding[i] + op->right_padding[i];
   1043   }
   1045   output_array.copy_shape(output_shape);
   1046 }
   1048 void ProcessRankOperator(Model* model, RankOperator* op) {
   1049   CHECK_GE(op->inputs.size(), 1);
   1050   CHECK_EQ(op->outputs.size(), 1);
   1051   auto& output_array = model->GetArray(op->outputs[0]);
   1052   if (output_array.has_shape()) {
   1053     // Shape already propagated
   1054     return;
   1055   }
   1057   const auto& input_array = model->GetArray(op->inputs[0]);
   1058   if (!input_array.has_shape()) {
   1059     // Yield until input dims have been resolved.
   1060     return;
   1061   }
   1063   // Only set the output shape. Array contents are set by
   1064   // ResolveConstantShapeOrRank.
   1065   Shape* output_shape = output_array.mutable_shape();
   1066   output_shape->ReplaceDims({});
   1067 }
   1069 void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
   1070   CHECK_GE(op->inputs.size(), 1);
   1071   CHECK_EQ(op->outputs.size(), 1);
   1072   auto& output_array = model->GetArray(op->outputs[0]);
   1073   if (output_array.has_shape()) {
   1074     // Shape already propagated
   1075     return;
   1076   }
   1078   const auto& input_array = model->GetArray(op->inputs[0]);
   1079   if (!input_array.has_shape()) {
   1080     // Yield until input dims have been resolved.
   1081     return;
   1082   }
   1084   // Only set the output shape. Array contents are set by
   1085   // ResolveConstantShapeOrRank.
   1086   Shape* output_shape = output_array.mutable_shape();
   1087   output_shape->ReplaceDims({input_array.shape().dimensions_count()});
   1088 }
   1090 void ProcessStackOperator(Model* model, StackOperator* op) {
   1091   CHECK_GE(op->inputs.size(), 1);
   1092   CHECK_EQ(op->outputs.size(), 1);
   1093   auto& output_array = model->GetArray(op->outputs[0]);
   1094   if (output_array.has_shape()) {
   1095     // Shape already propagated
   1096     return;
   1097   }
   1099   std::unique_ptr<Shape> stacked_shape;
   1100   for (const auto& input : op->inputs) {
   1101     const auto& input_array = model->GetArray(input);
   1102     if (!input_array.has_shape()) {
   1103       // Yield until all input dims have been resolved.
   1104       return;
   1105     }
   1107     Shape shape = input_array.shape();
   1108     if (shape.dimensions_count() == 0) {
   1109       // Convert 0D scalars to 1D scalars of shape {1}.
   1110       shape.mutable_dims()->push_back(1);
   1111     }
   1112     if (!stacked_shape) {
   1113       stacked_shape.reset(new Shape(shape));
   1114     } else {
   1115       CHECK(*stacked_shape == shape) << "All input arrays to Stack operators "
   1116                                         "must have the same shape. Input \""
   1117                                      << input << "\" is different.";
   1118     }
   1119   }
   1121   int axis = op->axis;
   1122   if (axis < 0) {
   1123     // Handle negative axis
   1124     axis += stacked_shape->dims().size() + 1;
   1125   }
   1126   stacked_shape->mutable_dims()->insert(
   1127       stacked_shape->mutable_dims()->begin() + axis, op->inputs.size());
   1128   output_array.copy_shape(*stacked_shape);
   1129 }
   1131 void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
   1132   CHECK_GE(op->inputs.size(), 1);
   1133   CHECK_EQ(op->outputs.size(), 1);
   1134   auto& output_array = model->GetArray(op->outputs[0]);
   1135   if (output_array.has_shape()) {
   1136     // Shape already propagated
   1137     return;
   1138   }
   1140   if (op->start_indices.empty() || op->stop_indices.empty() ||
   1141       op->strides.empty()) {
   1142     // ResolveStridedSliceAttributes has not run yet.
   1143     return;
   1144   }
   1146   const auto& input_array = model->GetArray(op->inputs[0]);
   1147   if (!input_array.has_shape()) {
   1148     // Yield until input dims have been resolved.
   1149     return;
   1150   }
   1152   if (op->ellipsis_mask != 0) {
   1153     // Something like LOG_FIRST_N(WARNING, 10) would be prefferable to reduce
   1154     // log noise. However, the TensorFlow logging library does not appear to
   1155     // support this.
   1156     LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
   1157                  << "\". ellipsis_mask is not supported (mask="
   1158                  << op->ellipsis_mask << ")";
   1159     return;
   1160   }
   1161   if (op->new_axis_mask != 0) {
   1162     LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
   1163                  << "\". new_axis_mask is not supported (mask="
   1164                  << op->new_axis_mask << ")";
   1165     return;
   1166   }
   1168   int dim_count = input_array.shape().dimensions_count();
   1169   CHECK(op->start_indices.size() == dim_count)
   1170       << ": Incorrect number of start indices supplied to StridedSlice op with "
   1171          "output \""
   1172       << op->outputs[0] << "\". Op requires " << dim_count << " start indices";
   1173   CHECK(op->stop_indices.size() == dim_count)
   1174       << ": Incorrect number of stop indices supplied to StridedSlice op with "
   1175          "output \""
   1176       << op->outputs[0] << "\". Op requires " << dim_count << " stop indices";
   1177   CHECK(op->strides.size() == dim_count)
   1178       << ": Incorrect number of strides supplied to StridedSlice op with "
   1179          " output \""
   1180       << op->outputs[0] << "\". Op requires " << dim_count << " strides";
   1182   // Create output shape
   1183   std::vector<int>* dims = output_array.mutable_shape()->mutable_dims();
   1185   // Compute output shape
   1186   for (int i = 0; i < dim_count; ++i) {
   1187     const int mask = 1 << i;
   1188     int start = (op->begin_mask & mask) ? 0 : op->start_indices[i];
   1189     if (start < 0) {
   1190       // handle negative indices
   1191       start += input_array.shape().dims(i);
   1192     }
   1193     int stop = (op->end_mask & mask) ? input_array.shape().dims(i)
   1194                                      : op->stop_indices[i];
   1195     if (stop < 0) {
   1196       // handle negative indices
   1197       stop += input_array.shape().dims(i);
   1198     }
   1200     int dim_size = ceil((stop - start) / static_cast<float>(op->strides[i]));
   1201     dim_size = dim_size < 0 ? 0 : dim_size;
   1202     if (op->shrink_axis_mask & mask) {
   1203       CHECK_EQ(dim_size, 1) << "Output size for an axis must compute to 1 when "
   1204                                "shrinking that axis";
   1205     } else {
   1206       dims->push_back(dim_size);
   1207     }
   1208   }
   1209 }
   1211 void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
   1212   CHECK_EQ(op->inputs.size(), 1);
   1213   CHECK_EQ(op->outputs.size(), 1);
   1215   const auto& input_array = model->GetArray(op->inputs[0]);
   1217   // Yield until input dims have been resolved.
   1218   if (!input_array.has_shape()) return;
   1220   auto& output_array = model->GetArray(op->outputs[0]);
   1221   if (output_array.has_shape()) return;
   1223   const std::vector<int>& input_dims = input_array.shape().dims();
   1224   std::vector<int> output_dims;
   1226   for (int i = 0; i < input_dims.size(); ++i) {
   1227     if (input_dims[i] != 1 ||
   1228         (!op->squeeze_dims.empty() &&
   1229          std::find(op->squeeze_dims.begin(), op->squeeze_dims.end(), i) ==
   1230              op->squeeze_dims.end())) {
   1231       output_dims.push_back(input_dims[i]);
   1232     }
   1233   }
   1234   *output_array.mutable_shape()->mutable_dims() = output_dims;
   1235 }
   1237 void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
   1238   CHECK(op->inputs.size() == 3 || op->inputs.size() == 4);
   1239   const auto& input_array = model->GetArray(op->inputs[0]);
   1240   if (!input_array.has_shape()) return;
   1242   auto& weights_feature_array = model->GetArray(op->inputs[1]);
   1243   if (!weights_feature_array.has_shape()) return;
   1245   const auto& weights_time_array = model->GetArray(op->inputs[2]);
   1246   if (!weights_time_array.has_shape()) return;
   1248   const bool has_bias = (op->inputs.size() == 4);
   1249   if (has_bias) {
   1250     const auto& bias_array = model->GetArray(op->inputs[3]);
   1251     if (!bias_array.has_shape()) return;
   1252   }
   1254   const int batch_size = input_array.shape().dims()[0];
   1255   const int num_units = weights_feature_array.shape().dims()[0];
   1256   const int memory_size = weights_time_array.shape().dims()[1];
   1258   auto& state_array = model->GetArray(op->outputs[0]);
   1259   state_array.mutable_shape()->ReplaceDims(
   1260       {batch_size, memory_size * num_units});
   1262   auto& output_array = model->GetArray(op->outputs[1]);
   1263   output_array.mutable_shape()->ReplaceDims({batch_size, num_units});
   1264 }
   1266 void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
   1267   auto& output_array = model->GetArray(op->outputs[0]);
   1268   if (output_array.has_shape()) {
   1269     // We have already run
   1270     return;
   1271   }
   1273   const auto& input_array = model->GetArray(op->inputs[0]);
   1274   if (!input_array.has_shape()) {
   1275     // Yield until input dims have been resolved.
   1276     return;
   1277   }
   1278   const auto& input_shape = input_array.shape();
   1280   auto& perm_array = model->GetArray(op->inputs[1]);
   1281   if (!perm_array.has_shape()) {
   1282     // Yield until permutation shape been resolved.
   1283     return;
   1284   }
   1285   if (!perm_array.buffer) {
   1286     // Yield until the permutation is constant
   1287     return;
   1288   }
   1289   CHECK(perm_array.data_type == ArrayDataType::kInt32)
   1290       << "Transpose permutation input must be int32";
   1292   std::vector<int32> const& perm =
   1293       perm_array.GetBuffer<ArrayDataType::kInt32>().data;
   1294   CHECK_EQ(perm.size(), input_shape.dimensions_count())
   1295       << "Transpose permutation input " << op->inputs[0]
   1296       << " must be same length as input dimensions";
   1297   std::vector<int>* output_dims = output_array.mutable_shape()->mutable_dims();
   1298   for (int i = 0; i < perm.size(); i++) {
   1299     int axis = perm[i];
   1300     CHECK_GE(axis, 0);
   1301     CHECK_LT(axis, input_shape.dimensions_count());
   1302     output_dims->push_back(input_shape.dims(axis));
   1303   }
   1304 }
   1306 void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
   1307   CHECK_EQ(op->inputs.size(), 2);
   1308   const auto& input_array = model->GetArray(op->inputs[0]);
   1309   // Yield until input dims have been resolved.
   1310   if (!input_array.has_shape()) {
   1311     return;
   1312   }
   1314   // The current ArgMax implementation only supports 4-dimensional inputs with
   1315   // the last dimension as the axis to perform ArgMax for.
   1316   const std::vector<int>& input_dims = input_array.shape().dims();
   1317   CHECK_EQ(input_dims.size(), 4);
   1318   std::vector<int> output_dims;
   1320   output_dims.reserve(input_dims.size() - 1);
   1321   for (int i = 0; i < input_dims.size() - 1; ++i) {
   1322     output_dims.push_back(input_dims[i]);
   1323   }
   1324   output_dims.push_back(1);
   1325   const string& output_name = op->outputs[0];
   1326   auto& output_array = model->GetArray(output_name);
   1327   if (output_array.has_shape()) {
   1328     return;
   1329   }
   1330   *output_array.mutable_shape()->mutable_dims() = output_dims;
   1331 }
   1333 }  // namespace
   1335 bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
   1336   auto it = model->operators.begin() + op_index;
   1337   auto* op = it->get();
   1338   std::unordered_map<string, std::vector<int>> old_output_dims;
   1339   for (const auto& output : op->outputs) {
   1340     if (model->GetArray(output).has_shape()) {
   1341       old_output_dims[output] = model->GetArray(output).shape().dims();
   1342     }
   1343   }
   1345   switch (op->type) {
   1346     case OperatorType::kBatchNormalization:
   1347     case OperatorType::kL2Normalization:
   1348     case OperatorType::kDequantize:
   1349     case OperatorType::kRelu:
   1350     case OperatorType::kRelu1:
   1351     case OperatorType::kRelu6:
   1352     case OperatorType::kSoftmax:
   1353     case OperatorType::kLogSoftmax:
   1354     case OperatorType::kLogistic:
   1355     case OperatorType::kTanh:
   1356     case OperatorType::kLocalResponseNormalization:
   1357     case OperatorType::kTensorFlowIdentity:
   1358     case OperatorType::kFakeQuant:
   1359     case OperatorType::kNeg:
   1360     case OperatorType::kTensorFlowRsqrt:
   1361     case OperatorType::kTensorFlowSqrt:
   1362     case OperatorType::kTensorFlowSquare:
   1363     case OperatorType::kTensorFlowAll:
   1364     case OperatorType::kTensorFlowAssert:
   1365     case OperatorType::kCast:
   1366     case OperatorType::kFloor:
   1367     case OperatorType::kExp:
   1368       ProcessSimpleOperator(model, op);
   1369       break;
   1370     case OperatorType::kGather:
   1371       ProcessGatherOperator(model, static_cast<GatherOperator*>(op));
   1372       break;
   1373     case OperatorType::kTopK_V2:
   1374       ProcessTopkV2Operator(model, static_cast<TopKV2Operator*>(op));
   1375       break;
   1376     case OperatorType::kAdd:
   1377     case OperatorType::kSub:
   1378     case OperatorType::kMul:
   1379     case OperatorType::kDiv:
   1380     case OperatorType::kFloorDiv:
   1381     case OperatorType::kFloorMod:
   1382     case OperatorType::kTensorFlowLess:
   1383     case OperatorType::kTensorFlowLessEqual:
   1384     case OperatorType::kTensorFlowGreater:
   1385     case OperatorType::kTensorFlowMaximum:
   1386     case OperatorType::kTensorFlowMinimum:
   1387     case OperatorType::kTensorFlowGreaterEqual:
   1388       ProcessSimpleBinaryOperator(model, op);
   1389       break;
   1390     case OperatorType::kAddN:
   1391       ProcessAddNOperator(model, op);
   1392       break;
   1393     case OperatorType::kConv:
   1394       ProcessConvOperator(model, static_cast<ConvOperator*>(op));
   1395       break;
   1396     case OperatorType::kTransposeConv:
   1397       // Unimplemented, hopefully another graph transformation will drop it or
   1398       // rewrite it.
   1399       break;
   1400     case OperatorType::kDepthwiseConv:
   1401       ProcessDepthwiseConvOperator(model,
   1402                                    static_cast<DepthwiseConvOperator*>(op));
   1403       break;
   1404     case OperatorType::kDepthToSpace:
   1405       ProcessDepthToSpaceOperator(model,
   1406                                   static_cast<DepthToSpaceOperator*>(op));
   1407       break;
   1408     case OperatorType::kSpaceToDepth:
   1409       ProcessSpaceToDepthOperator(model,
   1410                                   static_cast<SpaceToDepthOperator*>(op));
   1411       break;
   1412     case OperatorType::kFill:
   1413       ProcessFillOperator(model, static_cast<FillOperator*>(op));
   1414       break;
   1415     case OperatorType::kFullyConnected:
   1416       ProcessFullyConnectedOperator(model,
   1417                                     static_cast<FullyConnectedOperator*>(op));
   1418       break;
   1419     case OperatorType::kTensorFlowReshape:
   1420       ProcessTensorFlowReshapeOperator(
   1421           model, static_cast<TensorFlowReshapeOperator*>(op));
   1422       break;
   1423     case OperatorType::kAveragePool:
   1424       ProcessAveragePoolOperator(model, static_cast<AveragePoolOperator*>(op));
   1425       break;
   1426     case OperatorType::kMaxPool:
   1427       ProcessMaxPoolOperator(model, static_cast<MaxPoolOperator*>(op));
   1428       break;
   1429     case OperatorType::kL2Pool:
   1430       ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op));
   1431       break;
   1432     case OperatorType::kTensorFlowMin:
   1433     case OperatorType::kTensorFlowMax:
   1434     case OperatorType::kTensorFlowSum:
   1435     case OperatorType::kMean:
   1436       ProcessTensorFlowReductionOperator(model, op);
   1437       break;
   1439     case OperatorType::kSlice:
   1440       ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
   1441       break;
   1443     case OperatorType::kTensorFlowTile:
   1444       // We don't currently implement the propagation of fixed sizes through
   1445       // a TensorFlow Tile.
   1446       //
   1447       // Fortunately, we don't need to: so far, we have only dealt with Tile
   1448       // or Slice ops in subgraphs that are identified as L2Normalization.
   1449       // See IdentifyL2Normalization.
   1450       break;
   1451     case OperatorType::kTensorFlowSwitch:
   1452       // We can't know the sizes of the outputs until we have resolved the
   1453       // predicate, and once we have resolved the predicate, the whole
   1454       // Switch node will get resolved away.
   1455       // See ResolveTensorFlowSwitch.
   1456       break;
   1457     case OperatorType::kTensorFlowMerge:
   1458       // No need to bother resolving TensorFlow Merge ops: other graph
   1459       // transformations will remove them anyway.
   1460       // See ResolveTensorFlowMerge.
   1461       break;
   1462     case OperatorType::kTensorFlowSplit:
   1463       ProcessTensorFlowSplitOperator(model,
   1464                                      static_cast<TensorFlowSplitOperator*>(op));
   1465       break;
   1466     case OperatorType::kSqueeze:
   1467       ProcessSqueezeOperator(model, static_cast<SqueezeOperator*>(op));
   1468       break;
   1469     case OperatorType::kTensorFlowConcat:
   1470     case OperatorType::kTensorFlowConcatV2:
   1471       // Unimplemented, hopefully another graph transformation will
   1472       // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat
   1473       // will resolve this node to a DepthConcatenation, or else we have
   1474       // a more general non-depth concatenation that will hopefully be dropped,
   1475       // or else at the moment we will abort.
   1476       break;
   1477     case OperatorType::kExpandDims:
   1478       // Yield until ExpandDims is converted to Reshape
   1479       break;
   1480     case OperatorType::kRange:
   1481       ProcessRangeOperator(model, static_cast<RangeOperator*>(op));
   1482       break;
   1483     case OperatorType::kRank:
   1484       ProcessRankOperator(model, static_cast<RankOperator*>(op));
   1485       break;
   1486     case OperatorType::kTensorFlowShape:
   1487       ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op));
   1488       break;
   1489     case OperatorType::kStack:
   1490       ProcessStackOperator(model, static_cast<StackOperator*>(op));
   1491       break;
   1492     case OperatorType::kReorderAxes:
   1493       ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op));
   1494       break;
   1495     case OperatorType::kConcatenation:
   1496       ProcessConcatenationOperator(model,
   1497                                    static_cast<ConcatenationOperator*>(op));
   1498       break;
   1499     case OperatorType::kResizeBilinear:
   1500       ProcessResizeBilinearOperator(model,
   1501                                     static_cast<ResizeBilinearOperator*>(op));
   1502       break;
   1503     case OperatorType::kLstmCell:
   1504       ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
   1505       break;
   1506     case OperatorType::kBatchMatMul:
   1507     case OperatorType::kTensorFlowMatMul:
   1508       // MatMul operators are converted to FullyConnected, after which their
   1509       // shapes are propagated.
   1510       break;
   1511     case OperatorType::kSpaceToBatchND:
   1512       ProcessSpaceToBatchNDOperator(model,
   1513                                     static_cast<SpaceToBatchNDOperator*>(op));
   1514       break;
   1515     case OperatorType::kBatchToSpaceND:
   1516       ProcessBatchToSpaceNDOperator(model,
   1517                                     static_cast<BatchToSpaceNDOperator*>(op));
   1518       break;
   1519     case OperatorType::kPad:
   1520       ProcessPadOperator(model, static_cast<PadOperator*>(op));
   1521       break;
   1522     case OperatorType::kStridedSlice:
   1523       ProcessStridedSliceOperator(model,
   1524                                   static_cast<StridedSliceOperator*>(op));
   1525       break;
   1526     case OperatorType::kArgMax:
   1527       ProcessArgMaxOperator(model, static_cast<ArgMaxOperator*>(op));
   1528       break;
   1529     case OperatorType::kTensorFlowUnsupported:
   1530       break;
   1531     case OperatorType::kSvdf:
   1532       ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
   1533       break;
   1534     case OperatorType::kTranspose:
   1535       ProcessTransposeOperator(model, static_cast<TransposeOperator*>(op));
   1536       break;
   1537     default:
   1538       // Unimplemented, another graph transformation should drop it.
   1539       LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
   1540   }
   1542   // Return true if any output dim changed, false if none changed.
   1543   // Assumption: no transformation clears an output shape, they only add shapes.
   1544   for (const auto& output : op->outputs) {
   1545     if (model->GetArray(output).has_shape() &&
   1546         (old_output_dims[output] != model->GetArray(output).shape().dims())) {
   1547       AddMessageF("Set shape of %s to [%s]", output,
   1548                   absl::StrJoin(model->GetArray(output).shape().dims(), ","));
   1549       return true;
   1550     }
   1551   }
   1552   return false;
   1553 }
   1555 }  // namespace toco