/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/toco/tooling_util.h"

#include <functional>
#include <iterator>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "tensorflow/contrib/lite/toco/dump_graphviz.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

// Find the longest common prefix of two strings.
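// For example, given "conv_1/weights" and "conv_1/bias", this returns a view
// of "conv_1/" (backed by 'a').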
absl::string_view FindLongestCommonPrefix(absl::string_view a,
                                          absl::string_view b) {
  if (a.empty() || b.empty()) return absl::string_view();

  const char* pa = a.data();
  const char* pb = b.data();
  size_t count = 0;
  const size_t limit = std::min(a.size(), b.size());
  while (count < limit && *pa == *pb) {
    ++pa;
    ++pb;
    ++count;
  }

  return absl::string_view(a.data(), count);
}

string LogName(const Operator& op) {
  const string& opname = HelpfulOperatorTypeName(op);
  if (op.outputs.empty()) {
    return toco::port::StringF("{%s operator}", opname);
  } else {
    return toco::port::StringF("{%s operator with output %s}", opname,
                               op.outputs[0]);
  }
}

bool IsInputArray(const Model& model, const string& name) {
  for (const auto& input_array : model.flags.input_arrays()) {
    if (input_array.name() == name) {
      return true;
    }
  }
  return false;
}

bool IsArrayConsumed(const Model& model, const string& name) {
  if (GetOpWithInput(model, name)) {
    return true;
  }
  for (const string& model_output : model.flags.output_arrays()) {
    if (model_output == name) {
      return true;
    }
  }
  for (const auto& rnn_state : model.flags.rnn_states()) {
    if (rnn_state.back_edge_source_array() == name) {
      return true;
    }
  }
  return false;
}

int CountTrueOutputs(const Model& model, const Operator& op) {
  int count = 0;
  for (const string& output : op.outputs) {
    if (IsArrayConsumed(model, output)) {
      ++count;
    }
  }
  return count;
}

int CountOpsWithInput(const Model& model, const string& array_name) {
  int count = 0;
  for (const auto& op : model.operators) {
    for (auto& input : op->inputs) {
      if (input == array_name) {
        count++;
      }
    }
  }
  return count;
}

bool DeleteArrayIfUnused(const string& array_name, Model* model) {
  if (IsDiscardableArray(*model, array_name) &&
      CountOpsWithInput(*model, array_name) == 0) {
    model->EraseArray(array_name);
    return true;
  }
  return false;
}

bool DeleteArrayIfUsedOnce(const string& array_name, Model* model) {
  if (IsDiscardableArray(*model, array_name) &&
      CountOpsWithInput(*model, array_name) == 1) {
    model->EraseArray(array_name);
    return true;
  }
  return false;
}

std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
    const Model& model, const string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    for (auto& output : it->get()->outputs) {
      if (output == array_name) {
        return it;
      }
    }
  }
  return model.operators.end();
}

std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput(
    Model& model, const string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    for (auto& output : it->get()->outputs) {
      if (output == array_name) {
        return it;
      }
    }
  }
  return model.operators.end();
}

Operator* GetOpWithOutput(const Model& model, const string& array_name) {
  auto it = FindOpWithOutput(model, array_name);
  return it == model.operators.end() ? nullptr : it->get();
}

// Returns the first op (in model order) that has the given array among its
// inputs; GetFirstOpWithInput relies on this "first" guarantee.
std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput(
    const Model& model, const string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    for (auto& input : it->get()->inputs) {
      if (input == array_name) {
        return it;
      }
    }
  }
  return model.operators.end();
}

std::vector<std::unique_ptr<Operator>>::iterator FindOpWithInput(
    Model& model, const string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    for (auto& input : it->get()->inputs) {
      if (input == array_name) {
        return it;
      }
    }
  }
  return model.operators.end();
}

std::vector<std::unique_ptr<Operator>>::const_iterator FindOp(
    const Model& model, const Operator* op) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    if (it->get() == op) {
      return it;
    }
  }
  return model.operators.end();
}

std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model,
                                                        const Operator* op) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    if (it->get() == op) {
      return it;
    }
  }
  return model.operators.end();
}

Operator* GetOpWithInput(const Model& model, const string& array_name) {
  auto it = FindOpWithInput(model, array_name);
  return it == model.operators.end() ? nullptr : it->get();
}

Operator* GetFirstOpWithInput(const Model& model, const string& array_name) {
  auto it = FindOpWithInput(model, array_name);
  return it == model.operators.end() ? nullptr : it->get();
}

string FormatArraysList(const Model& model, const std::vector<string>& list) {
  if (list.empty()) {
    return "[]";
  }
  string result = "";
  if (list.size() > 1) {
    result += "[ ";
  }
  for (std::size_t i = 0; i < list.size(); i++) {
    if (i > 0) {
      result += ", ";
    }
    result += list[i];
  }
  if (list.size() > 1) {
    result += " ]";
  }
  return result;
}

const char* OperatorTypeName(OperatorType type) {
  switch (type) {
#define HANDLE_OPERATORTYPENAME_CASE(c) \
  case OperatorType::k##c:              \
    return #c;
    HANDLE_OPERATORTYPENAME_CASE(Add)
    HANDLE_OPERATORTYPENAME_CASE(AddN)
    HANDLE_OPERATORTYPENAME_CASE(AveragePool)
    HANDLE_OPERATORTYPENAME_CASE(BatchMatMul)
    HANDLE_OPERATORTYPENAME_CASE(BatchNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Conv)
    HANDLE_OPERATORTYPENAME_CASE(Concatenation)
    HANDLE_OPERATORTYPENAME_CASE(DepthwiseConv)
    HANDLE_OPERATORTYPENAME_CASE(DepthToSpace)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToDepth)
    HANDLE_OPERATORTYPENAME_CASE(FullyConnected)
    HANDLE_OPERATORTYPENAME_CASE(Dequantize)
    HANDLE_OPERATORTYPENAME_CASE(L2Normalization)
    HANDLE_OPERATORTYPENAME_CASE(LocalResponseNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Logistic)
    HANDLE_OPERATORTYPENAME_CASE(LstmCell)
    HANDLE_OPERATORTYPENAME_CASE(MaxPool)
    HANDLE_OPERATORTYPENAME_CASE(L2Pool)
    HANDLE_OPERATORTYPENAME_CASE(FakeQuant)
    HANDLE_OPERATORTYPENAME_CASE(Mul)
    HANDLE_OPERATORTYPENAME_CASE(Relu)
    HANDLE_OPERATORTYPENAME_CASE(Relu1)
    HANDLE_OPERATORTYPENAME_CASE(Relu6)
    HANDLE_OPERATORTYPENAME_CASE(ReorderAxes)
    HANDLE_OPERATORTYPENAME_CASE(Softmax)
    HANDLE_OPERATORTYPENAME_CASE(LogSoftmax)
    HANDLE_OPERATORTYPENAME_CASE(Div)
    HANDLE_OPERATORTYPENAME_CASE(Tanh)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowAll)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowAssert)
    HANDLE_OPERATORTYPENAME_CASE(ExpandDims)
    HANDLE_OPERATORTYPENAME_CASE(Fill)
    HANDLE_OPERATORTYPENAME_CASE(FloorMod)
    HANDLE_OPERATORTYPENAME_CASE(FloorDiv)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreater)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreaterEqual)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowIdentity)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowLess)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowLessEqual)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowMatMul)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowMax)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowMaximum)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowMerge)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowMin)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowMinimum)
    HANDLE_OPERATORTYPENAME_CASE(Neg)
    HANDLE_OPERATORTYPENAME_CASE(Pad)
    HANDLE_OPERATORTYPENAME_CASE(StridedSlice)
    HANDLE_OPERATORTYPENAME_CASE(Stack)
    HANDLE_OPERATORTYPENAME_CASE(Range)
    HANDLE_OPERATORTYPENAME_CASE(Rank)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowReshape)
    HANDLE_OPERATORTYPENAME_CASE(Squeeze)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowRsqrt)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowShape)
    HANDLE_OPERATORTYPENAME_CASE(Slice)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowSplit)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowSqrt)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowSquare)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowSwitch)
    HANDLE_OPERATORTYPENAME_CASE(Sub)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowSum)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowTile)
    HANDLE_OPERATORTYPENAME_CASE(Transpose)
    HANDLE_OPERATORTYPENAME_CASE(TransposeConv)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcat)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcatV2)
    HANDLE_OPERATORTYPENAME_CASE(Cast)
    HANDLE_OPERATORTYPENAME_CASE(Floor)
    HANDLE_OPERATORTYPENAME_CASE(Gather)
    HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND)
    HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND)
    HANDLE_OPERATORTYPENAME_CASE(Mean)
    HANDLE_OPERATORTYPENAME_CASE(Svdf)
    HANDLE_OPERATORTYPENAME_CASE(ArgMax)
    HANDLE_OPERATORTYPENAME_CASE(TopK_V2)
    HANDLE_OPERATORTYPENAME_CASE(TensorFlowUnsupported)
    HANDLE_OPERATORTYPENAME_CASE(Exp)
    default:
      LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
  }
}

string HelpfulOperatorTypeName(const Operator& op) {
  if (op.type == OperatorType::kTensorFlowUnsupported) {
    return toco::port::StringF(
        "(Unsupported TensorFlow op: %s)",
        static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op);
  }
  return OperatorTypeName(op.type);
}

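// Returns whether an operator of the given type may carry a fused activation
// function. The ops below only move or select data without computing new
// values, so there is nothing for an activation function to fuse into.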
bool OperatorSupportsFusedActivation(OperatorType type) {
  switch (type) {
    case OperatorType::kConcatenation:
    case OperatorType::kGather:
    case OperatorType::kSlice:
    case OperatorType::kSqueeze:
    case OperatorType::kTensorFlowReshape:
    case OperatorType::kTensorFlowSplit:
      return false;
    default:
      return true;
  }
}

void LogSummary(int log_level, const Model& model) {
  VLOG(log_level) << "Operators summary (" << model.operators.size()
                  << " operators):";
  std::unordered_multiset<OperatorType> ops_by_type;
  for (const auto& op : model.operators) {
    ops_by_type.insert(op->type);
  }
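  // Elements with equivalent keys are adjacent in an unordered_multiset, so
  // advancing the iterator by count(*it) jumps to the next distinct type.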
  auto it = ops_by_type.begin();
  while (it != ops_by_type.end()) {
    int count = ops_by_type.count(*it);
    VLOG(log_level) << "    " << OperatorTypeName(*it) << ": " << count;
    std::advance(it, count);
  }
}

void LogArray(int log_level, const Model& model, const string& name) {
  const auto& array = model.GetArray(name);
  VLOG(log_level) << "Array: " << name;
  switch (array.data_type) {
    case ArrayDataType::kNone:
      VLOG(log_level) << "  Data type:";
      break;
    case ArrayDataType::kFloat:
      VLOG(log_level) << "  Data type: kFloat";
      break;
    case ArrayDataType::kInt32:
      VLOG(log_level) << "  Data type: kInt32";
      break;
    case ArrayDataType::kUint8:
      VLOG(log_level) << "  Data type: kUint8";
      break;
    case ArrayDataType::kString:
      VLOG(log_level) << "  Data type: kString";
      break;
    default:
      VLOG(log_level) << "  Data type: other (numerical value: "
                      << static_cast<int>(array.data_type) << ")";
      break;
  }
  switch (array.final_data_type) {
    case ArrayDataType::kNone:
      VLOG(log_level) << "  Final type:";
      break;
    case ArrayDataType::kFloat:
      VLOG(log_level) << "  Final type: kFloat";
      break;
    case ArrayDataType::kInt32:
      VLOG(log_level) << "  Final type: kInt32";
      break;
    case ArrayDataType::kUint8:
      VLOG(log_level) << "  Final type: kUint8";
      break;
    case ArrayDataType::kString:
      VLOG(log_level) << "  Final type: kString";
      break;
    default:
      VLOG(log_level) << "  Final type: other (numerical value: "
                      << static_cast<int>(array.final_data_type) << ")";
      break;
  }
  if (array.buffer) {
    VLOG(log_level) << "  Constant Buffer";
  }
  if (array.alloc) {
    VLOG(log_level) << "  Transient Alloc";
  }
  if (array.has_shape()) {
    const Shape& array_shape = array.shape();
    if (array_shape.dimensions_count() == 0) {
      VLOG(log_level) << "  (Zero dimensions)";
    } else {
      string message = "  Dims: ";
      bool first = true;
      for (const int dim : array_shape.dims()) {
        if (!first) {
          message += ", ";
        }
        first = false;
        toco::port::AppendF(&message, "%d", dim);
      }
      VLOG(log_level) << message;
    }
  }
  if (array.minmax) {
    VLOG(log_level) << "  MinMax: " << array.minmax->min << " .. "
                    << array.minmax->max;
  }
  if (array.quantization_params) {
    VLOG(log_level) << "  QuantizationParams: zero_point="
                    << static_cast<int>(array.quantization_params->zero_point)
                    << ", scale=" << array.quantization_params->scale;
  }
}

void DumpGraphvizVideoFrame(const Model& model) {
  namespace port = toco::port;

  const auto& dump_options = *GraphVizDumpOptions::singleton();
  if (!dump_options.dump_graphviz_video) {
    return;
  }
  CHECK(!dump_options.dump_graphviz.empty());
  // TODO(benoitjacob): the static data here means that this function
  // is stateful, not reentrant, and effectively leaks memory until exit
  // (since dump_hashes can only grow in size). It also means that it
  // is really only intended to be called for a single model during the
  // process' lifetime. So it's not great design at all. The overriding
  // design aspect here is to make the video-dumping code as unintrusive
  // and self-contained as possible. Eventually, we'll want to have that
  // cleaned up, but that will require some form of general statefulness
  // in toco (some kind of 'tooling state' data structure) that does
  // not exist at present, and would be premature to design here just for
  // this new video-dumping feature.
  static int dump_id = 0;
  static std::unordered_set<std::size_t> dump_hashes;
  string graphviz_dump;
  DumpGraphviz(model, &graphviz_dump);
  std::size_t hash = std::hash<string>{}(graphviz_dump);
  if (!dump_hashes.count(hash)) {
    LOG(INFO) << "DUMPING GRAPHVIZ VIDEO FRAME: " << dump_id;
    dump_hashes.insert(hash);
    CHECK(port::file::SetContents(
              port::file::JoinPath(
                  dump_options.dump_graphviz,
                  toco::port::StringF("toco_video_%05d.dot", dump_id)),
              graphviz_dump, port::file::Defaults())
              .ok());
    dump_id++;
  }
}

void LogDump(int log_level, const string& message, const Model& model) {
  namespace port = toco::port;
  const auto& dump_options = *GraphVizDumpOptions::singleton();

  DumpGraphvizVideoFrame(model);
  if (!dump_options.dump_graphviz.empty()) {
    string graphviz_dump;

    DumpGraphviz(model, &graphviz_dump);
    CHECK(port::file::SetContents(
              port::file::JoinPath(
                  dump_options.dump_graphviz,
                  absl::StrCat("toco_",
                               absl::StrReplaceAll(message, {{" ", "_"}}),
                               ".dot")),
              graphviz_dump, port::file::Defaults())
              .ok());
  }

  if (!VLOG_IS_ON(log_level)) {
    return;
  }
  VLOG(log_level) << "BEGIN DUMP OF TOCO MODEL (" << message << ")";
  LogSummary(log_level, model);
  std::unordered_set<string> already_printed_arrays;
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      if (!already_printed_arrays.count(input)) {
        already_printed_arrays.insert(input);
        LogArray(log_level, model, input);
      }
    }
    VLOG(log_level) << HelpfulOperatorTypeName(*op) << " :";
    VLOG(log_level) << "  " << FormatArraysList(model, op->inputs) << " -> "
                    << FormatArraysList(model, op->outputs);
    if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
      VLOG(log_level) << "    (with fused activation function)";
    }
    for (const auto& output : op->outputs) {
      if (!already_printed_arrays.count(output)) {
        already_printed_arrays.insert(output);
        LogArray(log_level, model, output);
      }
    }
  }
  VLOG(log_level) << "END DUMP OF TOCO MODEL (" << message << ")";
}

// Note remaining raw-array extension in ProcessTensorFlowReshapeOperator().
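// Prepends 1s to 'shape' until it has new_shape_size dimensions; e.g. a
// [3, 4] shape extended to 4 dimensions becomes [1, 1, 3, 4].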
void ExtendShape(Shape* shape, int new_shape_size) {
  CHECK_GE(new_shape_size, shape->dimensions_count());
  const int size_increase = new_shape_size - shape->dimensions_count();
  auto* shape_dims = shape->mutable_dims();
  shape_dims->insert(shape_dims->begin(), size_increase, 1);
}

// TODO(b/62904716) Remove along with remaining uses.
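// Inverse of ExtendShape: strips leading dimensions, checking that each
// stripped dimension is 1; e.g. [1, 1, 3, 4] unextended to 2 dimensions
// becomes [3, 4].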
void UnextendShape(Shape* shape, int new_shape_size) {
  CHECK_LE(new_shape_size, shape->dimensions_count());
  const int size_reduction = shape->dimensions_count() - new_shape_size;
  for (int i = 0; i < size_reduction; i++) {
    CHECK_EQ(shape->dims(i), 1);
  }
  std::vector<int>& shape_dims = *shape->mutable_dims();
  shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction);
}

void CheckShapeDimensions(const Shape& shape) {
  for (int i = 0; i < shape.dimensions_count(); ++i) {
    CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index " << i
                                 << ". shape = " << ShapeToString(shape);
  }
}

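// Returns whether the two shapes are compatible under numpy-style
// broadcasting: walking back from the trailing dimension, each pair of
// dimensions must either match or contain a 1. For example, [8, 1, 6] and
// [7, 6] agree, while [2, 3] and [4, 3] do not.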
bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) {
  CheckShapeDimensions(shape0);
  CheckShapeDimensions(shape1);

  const Shape* longer = &shape0;
  const Shape* shorter = &shape1;
  if (shape1.dimensions_count() > shape0.dimensions_count()) {
    longer = &shape1;
    shorter = &shape0;
  }

  // Walk dimensions back to front until we run out of dimensions in the shorter
  // shape.
  int longer_index = longer->dimensions_count() - 1;
  int shorter_index = shorter->dimensions_count() - 1;
  while (shorter_index >= 0) {
    const int d_long = longer->dims(longer_index);
    const int d_short = shorter->dims(shorter_index);
    // Broadcasting fails if the dimensions are different *and* neither is 1.
    if ((d_long != d_short) && (d_long != 1) && (d_short != 1)) {
      return false;
    }
    longer_index--;
    shorter_index--;
  }
  return true;
}

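// Returns whether the two shapes agree up to prepending 1s to the shorter
// one: trailing dimensions must match exactly, and every extra leading
// dimension of the longer shape must be 1. For example, [1, 1, 3, 4] and
// [3, 4] agree, while [2, 3, 4] and [3, 4] do not.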
bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) {
  CheckShapeDimensions(shape0);
  CheckShapeDimensions(shape1);

  const Shape* longer = &shape0;
  const Shape* shorter = &shape1;
  if (shape1.dimensions_count() > shape0.dimensions_count()) {
    longer = &shape1;
    shorter = &shape0;
  }

  // Walk dimensions back to front until we run out of dimensions in the shorter
  // shape.
  int longer_index = longer->dimensions_count() - 1;
  int shorter_index = shorter->dimensions_count() - 1;
  while (shorter_index >= 0) {
    const int d_long = longer->dims(longer_index);
    const int d_short = shorter->dims(shorter_index);
    // Extending fails if the dimensions are different.
    if (d_long != d_short) {
      return false;
    }
    longer_index--;
    shorter_index--;
  }

  // The remaining dimensions in the longer shape must be 1.
  while (longer_index >= 0) {
    const int d_long = longer->dims(longer_index);
    if (d_long != 1) {
      return false;
    }
    longer_index--;
  }

  return true;
}

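// Returns the number of elements implied by 'shape', i.e. the product of
// its dimensions (e.g. 24 for a [2, 3, 4] shape).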
int RequiredBufferSizeForShape(const Shape& shape) {
  int max_offset = 1;
  for (const auto& dim : shape.dims()) {
    CHECK_GE(dim, 1);
    max_offset *= dim;
  }
  return max_offset;
}

bool IsConstantParameterArray(const Model& model, const string& name) {
  if (!model.HasArray(name)) {
    return false;
  }

  return !!model.GetArray(name).buffer;
}

namespace {
// Take an array name, which may be something like "name:3_5", and make it
// acceptable as a TF node name, say "name_3_5".
string SanitizeNameForTFNode(const string& array_name) {
  auto node_name = array_name;
  std::replace(node_name.begin(), node_name.end(), ':', '_');
  return node_name;
}

void CheckInputArraysAreNotOutputArrays(const ModelFlags& model_flags) {
  for (const auto& input_array : model_flags.input_arrays()) {
    for (const string& output_array : model_flags.output_arrays()) {
      QCHECK_NE(input_array.name(), output_array)
          << "The array " << output_array
          << " is listed in both --input_arrays and --output_arrays.";
    }
  }
}

bool IsAsciiPrintable(const string& name) {
  for (char c : name) {
    if (!absl::ascii_isprint(c)) {
      return false;
    }
  }
  return true;
}

string DumpAscii(const string& name) {
  string result;
  port::AppendF(&result, "ASCII | Hex\n");
  port::AppendF(&result, "------+----\n");
  for (char c : name) {
    if (absl::ascii_isprint(c)) {
      port::AppendF(&result, "%c     | %x\n", c, c);
    } else {
      port::AppendF(&result, "      | %x   Not ASCII printable!\n", c);
    }
  }
  return result;
}

void CheckNonAsciiIOArrays(const ModelFlags& model_flags) {
  if (model_flags.allow_nonascii_arrays()) {
    return;
  }
  for (const auto& input_array : model_flags.input_arrays()) {
    QCHECK(IsAsciiPrintable(input_array.name()))
        << "Non-ASCII-printable character found in --input_arrays: "
        << input_array.name()
        << ". Pass --allow_nonascii_arrays to allow that. "
        << "Here is a dump of the string:\n\n"
        << DumpAscii(input_array.name());
  }
  for (const string& output_array : model_flags.output_arrays()) {
    QCHECK(IsAsciiPrintable(output_array))
        << "Non-ASCII-printable character found in --output_arrays: "
        << output_array << ". Pass --allow_nonascii_arrays to allow that. "
        << "Here is a dump of the string:\n\n"
        << DumpAscii(output_array);
  }
}

void CheckNonExistentIOArrays(const Model& model) {
  if (model.flags.allow_nonexistent_arrays()) {
    return;
  }
  for (const auto& input_array : model.flags.input_arrays()) {
    CHECK(model.HasArray(input_array.name()))
        << "Input array not found: " << input_array.name();
  }
  for (const string& output_array : model.flags.output_arrays()) {
    CHECK(model.HasArray(output_array))
        << "Output array not found: " << output_array;
  }
  for (const auto& rnn_state : model.flags.rnn_states()) {
    if (!rnn_state.discardable()) {
      CHECK(model.HasArray(rnn_state.state_array()));
      CHECK(model.HasArray(rnn_state.back_edge_source_array()));
    }
  }
}
}  // namespace

void CheckNoMissingArray(const Model& model) {
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      CHECK(model.HasArray(input) || model.optional_arrays.count(input))
          << "Input: " << input << " missing for op: " << op->outputs[0] << ".";
    }
    for (const auto& output : op->outputs) {
      CHECK(model.HasArray(output)) << "Output: " << output << " missing.";
    }
  }
  CheckNonExistentIOArrays(model);
}

void FixNoMissingArray(Model* model) {
  for (const auto& op : model->operators) {
    for (const auto& input : op->inputs) {
      if (!model->HasArray(input)) {
        model->GetOrCreateArray(input);
      }
    }
    for (const auto& output : op->outputs) {
      if (!model->HasArray(output)) {
        model->GetOrCreateArray(output);
      }
    }
  }
  if (model->flags.allow_nonexistent_arrays()) {
    for (const string& output_array : model->flags.output_arrays()) {
      model->GetOrCreateArray(output_array);
    }
    for (const auto& rnn_state : model->flags.rnn_states()) {
      model->GetOrCreateArray(rnn_state.state_array());
      model->GetOrCreateArray(rnn_state.back_edge_source_array());
    }
  }
}

void CheckNoOrphanedArray(const Model& model) {
  std::unordered_set<string> arrays_without_known_use;
  for (const auto& array : model.GetArrayMap()) {
    if (IsDiscardableArray(model, array.first)) {
      arrays_without_known_use.insert(array.first);
    }
  }
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      arrays_without_known_use.erase(input);
    }
    for (const auto& output : op->outputs) {
      arrays_without_known_use.erase(output);
    }
  }
  for (const auto& rnn_state : model.flags.rnn_states()) {
    arrays_without_known_use.erase(rnn_state.state_array());
    arrays_without_known_use.erase(rnn_state.back_edge_source_array());
  }
  if (!arrays_without_known_use.empty()) {
    for (const auto& array : arrays_without_known_use) {
      LOG(INFO) << "Error: Orphaned array: " << array;
    }
  }
  CHECK(arrays_without_known_use.empty());
}

void FixNoOrphanedArray(Model* model) {
  std::unordered_set<string> arrays_without_known_use;
  for (const auto& array : model->GetArrayMap()) {
    arrays_without_known_use.insert(array.first);
  }
  for (const auto& op : model->operators) {
    for (const auto& input : op->inputs) {
      arrays_without_known_use.erase(input);
    }
    for (const auto& output : op->outputs) {
      arrays_without_known_use.erase(output);
    }
  }
  for (const auto& rnn_state : model->flags.rnn_states()) {
    arrays_without_known_use.erase(rnn_state.state_array());
    arrays_without_known_use.erase(rnn_state.back_edge_source_array());
  }
  for (const auto& array : arrays_without_known_use) {
    if (IsDiscardableArray(*model, array)) {
      model->EraseArray(array);
    }
  }
}

// Apply checks to arrays individually (for-each fashion).
//
// Check consistency of array fields, check name.
void CheckEachArray(const Model& model) {
  for (const auto& array_entry : model.GetArrayMap()) {
    const auto& array = array_entry.second;
    if (array->has_shape()) {
      for (int d : array->shape().dims()) {
        CHECK_GE(d, 1);
      }
    }
    // It's OK to have a buffer or an alloc, but not both.
    // (Since allocs are for transient arrays without a buffer).
    CHECK(!array->buffer || !array->alloc);
    // If there is a buffer, its type should be consistent with data_type.
    if (array->buffer) {
      CHECK(array->buffer->type == array->data_type);
    }

    // Check name. Names like "name_with_suffix_8" or "name_with_port:3" are
    // acceptable, but not "name_with_both:3_8".
    const string& name = array_entry.first;
    auto colon_pos = name.find_first_of(":");
    if (colon_pos != string::npos) {
      CHECK_EQ(name.substr(colon_pos + 1).find_first_not_of("0123456789"),
               string::npos)
          << "Array name must only have digits after colon";
    }
    CHECK_GT(colon_pos, 0)
        << "First character of array name must not be a colon.";
  }
}

void CheckOperatorOrdering(const Model& model) {
  std::unordered_set<string> arrays_behind_us;
  for (const auto& array_entry : model.GetArrayMap()) {
    if (!GetOpWithOutput(model, array_entry.first)) {
      arrays_behind_us.insert(array_entry.first);
    }
  }
  arrays_behind_us.insert(model.optional_arrays.begin(),
                          model.optional_arrays.end());
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      if (!IsConstantParameterArray(model, input)) {
        CHECK(arrays_behind_us.count(input));
      }
    }
    for (const auto& output : op->outputs) {
      CHECK(!arrays_behind_us.count(output));
      arrays_behind_us.insert(output);
    }
  }
  for (const string& output_array : model.flags.output_arrays()) {
    CHECK(arrays_behind_us.count(output_array));
  }
}

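// Re-orders model->operators so that each operator appears after the
// producers of all of its non-constant inputs, i.e. performs a topological
// sort by repeated selection. If no valid ordering exists, walks the graph
// back from a problematic operator to log the root cause, then aborts.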
void FixOperatorOrdering(Model* model) {
  std::unordered_set<string> arrays_behind_us;
  for (const auto& array_entry : model->GetArrayMap()) {
    if (!GetOpWithOutput(*model, array_entry.first)) {
      arrays_behind_us.insert(array_entry.first);
    }
  }
  arrays_behind_us.insert(model->optional_arrays.begin(),
                          model->optional_arrays.end());
  std::vector<std::unique_ptr<Operator>> old_operators;
  std::swap(old_operators, model->operators);
  std::set<std::size_t> remaining;
  for (std::size_t i = 0; i < old_operators.size(); i++) {
    remaining.insert(i);
  }
  std::unordered_map<string, string> reason_why_leftover;
  while (true) {
    bool inserted_something = false;
    for (auto i : remaining) {
      bool can_insert = true;
      auto& op = old_operators[i];
      CHECK(op.get());
      for (const auto& input : op->inputs) {
        if (!IsConstantParameterArray(*model, input) &&
            !arrays_behind_us.count(input)) {
          for (const string& output : op->outputs) {
            reason_why_leftover[output] = input;
          }
          can_insert = false;
          break;
        }
      }
      if (can_insert) {
        model->operators.emplace_back(nullptr);
        for (const auto& output : op->outputs) {
          arrays_behind_us.insert(output);
        }
        std::swap(op, model->operators.back());
        remaining.erase(i);
        inserted_something = true;
        break;
      }
    }
    if (!inserted_something) {
      break;
    }
  }
  if (!remaining.empty()) {
    LOG(ERROR)
        << "No viable ordering of operators was found. "
        << "Here is a 'backtrace' of at least one part of the graph that is "
        << "problematic. It starts with the first operator that has a "
        << "problematic input array, and then walks back the graph to "
        << "the operator that produced that input array, etc., until we find "
        << "the root cause:";
    LOG(ERROR) << "BEGIN TRACE OF OPERATOR WITH BAD INPUT";
    LOG(ERROR) << "Here is the first-encountered operator with a bad input: ";
    const Operator* bad_op = old_operators[*remaining.begin()].get();
    std::unordered_set<string> bad_inputs_already_traced;
    // The following while(true) loop should always end with a LOG(FATAL).
    while (true) {
      LOG(ERROR) << HelpfulOperatorTypeName(*bad_op) << " : "
                 << FormatArraysList(*model, bad_op->inputs) << " -> "
                 << FormatArraysList(*model, bad_op->outputs);
      bool found_bad_output = false;
      string bad_output;
      for (const string& output : bad_op->outputs) {
        if (reason_why_leftover.count(output)) {
          found_bad_output = true;
          bad_output = output;
          break;
        }
      }
      CHECK(found_bad_output);
      const string& bad_input = reason_why_leftover[bad_output];
      LOG(ERROR) << "The bad input here is: " << bad_input;
      if (bad_inputs_already_traced.count(bad_input)) {
        LOG(FATAL)
            << "Cycle found! We already encountered that "
            << "input array, " << bad_input << ", earlier in the "
            << "above trace! We expect graphs to be acyclic, even "
            << "RNNs. Let us know if some graph actually needs to have "
            << "cycles, but first, please check if it really is "
            << "an *inference* graph. *Training* graphs are out-of-scope "
            << "for toco.";
      }
      bad_inputs_already_traced.insert(bad_input);
      bad_op = nullptr;
      for (auto i : remaining) {
        const Operator* op = old_operators[i].get();
        for (const string& output : op->outputs) {
          if (bad_input == output) {
            bad_op = op;
            break;
          }
        }
        if (bad_op) {
          break;
        }
      }
      if (!bad_op) {
        LOG(ERROR) << "And that's the root cause: "
                   << "that array, " << bad_input << ", isn't produced by any "
                   << "operator, or provided in any other way.";
        LOG(ERROR) << "END TRACE OF OPERATOR WITH BAD INPUT";
        LOG(FATAL) << "(The above was a multi-line fatal error)";
      }
      LOG(ERROR) << "And that array is the output of the following operator:";
    }
  }
  CHECK(remaining.empty())
      << "Should never get here! In case of bad graph, "
      << "the above code should have generated a FATAL error already!";
}

void CheckInvariants(const Model& model) {
  CheckInputArraysAreNotOutputArrays(model.flags);
  CheckNonAsciiIOArrays(model.flags);
  CheckNoMissingArray(model);
  CheckNoOrphanedArray(model);
  CheckEachArray(model);
  CheckOperatorOrdering(model);
}

void CheckCountInRange(const ::toco::ModelFlags::ModelCheck& model_check,
                       const int count, const string& count_description) {
  if (model_check.count_min() >= 0) {
    CHECK_GE(count, model_check.count_min())
        << "Mismatch in " << count_description << ": count was " << count
        << ", but the specified "
        << (model_check.count_max() > model_check.count_min() ? "minimum"
                                                              : "value")
        << " was " << model_check.count_min() << ".";
  }
  if (model_check.count_max() > model_check.count_min()) {
    CHECK_LE(count, model_check.count_max())
        << "Mismatch in " << count_description << ": count was " << count
        << ", but the specified maximum was " << model_check.count_max() << ".";
  }
}

void CheckModelCounts(const Model& model) {
  std::unordered_multiset<OperatorType> ops_by_type;
  std::unordered_map<string, OperatorType> op_type_by_name;
  if (model.flags.model_checks_size() == 0) {
    return;
  }

  for (const auto& op : model.operators) {
    ops_by_type.insert(op->type);
    op_type_by_name[OperatorTypeName(op->type)] = op->type;
  }
  for (const auto& model_check : model.flags.model_checks()) {
    string count_type = model_check.count_type();
    if (count_type == "None") {
      continue;
    } else if (count_type == "Arrays") {
      CheckCountInRange(model_check, model.GetArrayMap().size(),
                        "count of arrays");
    } else if (count_type == "Total") {
      CheckCountInRange(model_check, model.operators.size(),
                        "count of all operator instances");
    } else {
      // The check type is not itself checked against the set of valid
      // operators, mainly because the enum set cannot be iterated in C++.
      const int found_count =
          op_type_by_name.count(count_type) > 0
              ? ops_by_type.count(op_type_by_name[count_type])
              : 0;
      CheckCountInRange(model_check, found_count,
                        "count of instances of " + count_type + " operator");
    }
  }
}

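// Fills 'out_dims' with a dims vector of rank 'num_dims' built from the
// given batch/height/width/depth values: 0 -> {}, 1 -> {depth},
// 2 -> {batch, depth}, 3 -> {height, width, depth},
// 4 -> {batch, height, width, depth}.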
void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
                   std::vector<int>* out_dims) {
  CHECK(out_dims->empty());
  if (num_dims == 0) {
    return;
  } else if (num_dims == 1) {
    CHECK_EQ(batch, 1);
    *out_dims = {depth};
  } else if (num_dims == 2) {
    *out_dims = {batch, depth};
  } else if (num_dims == 3) {
    CHECK_EQ(batch, 1);
    *out_dims = {height, width, depth};
  } else if (num_dims == 4) {
    *out_dims = {batch, height, width, depth};
  } else {
    LOG(FATAL) << "Should not get here: " << num_dims;
  }
}

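// Ensures the RNN state array 'name' exists and, if it has no shape yet,
// gives it one of the appropriate rank with 'size' as its depth. The rank
// and batch size are inferred from the same-named input array if there is
// one, otherwise from the first input array.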
void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) {
  int batch = 1;
  int num_dims = -1;
  for (const auto& input_array : model->flags.input_arrays()) {
    // Pick 'num_dims' and 'batch' from the first input_arrays, unless we find
    // a better match by name.
    if (input_array.name() == name || num_dims == -1) {
      num_dims = input_array.shape().dims_size();
      if (num_dims > 0) {
        batch = input_array.shape().dims(0);
      }
    }
  }
  Array& array = model->GetOrCreateArray(name);
  if (array.has_shape()) {
    num_dims = array.shape().dimensions_count();
  }
  if (!array.has_shape() && num_dims >= 0) {
    Shape* shape = array.mutable_shape();
    std::vector<int> dims;
    MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
    *shape->mutable_dims() = dims;
  }
}

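// Merges the command-line ModelFlags into model->flags, checking that any
// overlapping settings (shapes, mean/std values, data types, ...) agree,
// and then applies the resolved flags to the arrays themselves: input data
// types, input shapes, input min/max ranges, and RNN state arrays.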
void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
  // Merge info about input_arrays from model_flags into model->flags
  for (const auto& specified_input_array : model_flags.input_arrays()) {
    toco::InputArray* dst_input_array = nullptr;
    for (int i = 0; i < model->flags.input_arrays_size(); i++) {
      toco::InputArray* candidate_dst_input_array =
          model->flags.mutable_input_arrays(i);
      if (candidate_dst_input_array->name() == specified_input_array.name()) {
        // specified_input_array from model_flags maps to dst_input_array
        // in model->flags
        dst_input_array = candidate_dst_input_array;
        break;
      }
    }
    if (!dst_input_array) {
      // specified_input_array from model_flags was not found in model->flags.
      // Match a name-less specified input array when there can be no
      // ambiguity, i.e. when there is only one input array.
      if (model->flags.input_arrays_size() == 1 &&
          model_flags.input_arrays_size() == 1 &&
          !specified_input_array.has_name()) {
        dst_input_array = model->flags.mutable_input_arrays(0);
      }
    }
    if (!dst_input_array) {
      // Still no match, so create a new input array to copy
      // specified_input_array into.
      dst_input_array = model->flags.add_input_arrays();
      dst_input_array->set_name(specified_input_array.name());
    }

#define RESOLVE_MODEL_FLAG(field_name)                                       \
  if (specified_input_array.has_##field_name()) {                            \
    if (dst_input_array->has_##field_name()) {                               \
      QCHECK_EQ(dst_input_array->field_name(),                               \
                specified_input_array.field_name())                          \
          << "For input array '" << dst_input_array->name() << "', "         \
          << "specified " #field_name " flag with value: "                   \
          << specified_input_array.field_name()                              \
          << " does not agree with already defined " #field_name             \
             " of this model, with value: "                                  \
          << dst_input_array->field_name();                                  \
    } else {                                                                 \
      dst_input_array->set_##field_name(specified_input_array.field_name()); \
    }                                                                        \
  }
    RESOLVE_MODEL_FLAG(std_value);
    RESOLVE_MODEL_FLAG(mean_value);
#undef RESOLVE_MODEL_FLAG

    if (specified_input_array.has_shape()) {
      if (dst_input_array->has_shape()) {
        QCHECK_EQ(specified_input_array.shape().dims_size(),
                  dst_input_array->shape().dims_size())
            << "For input array '" << specified_input_array.name() << "', "
            << "size of specified input shape flag with size: "
            << specified_input_array.shape().dims_size()
            << " does not agree with already defined input shape"
               " of this model, with size: "
            << dst_input_array->shape().dims_size();
        // We treat the first dimension as a special case, since it is often
        // a batch size and the input_shape flag is effectively overriding
        // the model.
        for (int i = 1; i < specified_input_array.shape().dims_size(); i++) {
          QCHECK_EQ(specified_input_array.shape().dims(i),
                    dst_input_array->shape().dims(i))
              << "At dimension number " << i << " of input array "
              << specified_input_array.name() << ", the specified shape's "
              << "dimension flag with dimension: "
              << specified_input_array.shape().dims(i)
              << " does not agree with already defined shape"
              << " of this model, with dimension: "
              << dst_input_array->shape().dims(i);
        }
      } else {
        *dst_input_array->mutable_shape() = specified_input_array.shape();
      }
    }

    if (specified_input_array.has_data_type()) {
      QCHECK(!dst_input_array->has_data_type());
      dst_input_array->set_data_type(specified_input_array.data_type());
    }
  }

  if (model_flags.output_arrays_size() > 0) {
    model->flags.mutable_output_arrays()->CopyFrom(model_flags.output_arrays());
  }

#define RESOLVE_MODEL_FLAG(name)                                           \
  if (model_flags.has_##name()) {                                          \
    if (model->flags.has_##name()) {                                       \
      QCHECK_EQ(model_flags.name(), model->flags.name())                   \
          << "Specified " #name " flag with value: " << model_flags.name() \
          << " does not agree with already defined " #name                 \
             " of this model, with value: "                                \
          << model->flags.name();                                          \
    } else {                                                               \
      model->flags.set_##name(model_flags.name());                         \
    }                                                                      \
  }

  RESOLVE_MODEL_FLAG(variable_batch)

#undef RESOLVE_MODEL_FLAG

  if (!model_flags.rnn_states().empty()) {
    model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states());
  }

  if (model->flags.model_checks_size() == 0) {
    model->flags.mutable_model_checks()->CopyFrom(model_flags.model_checks());
  }

  QCHECK_GT(model->flags.output_arrays_size(), 0)
      << "This model does not define output arrays, so a "
         "--output_arrays flag must be given on the command-line.";

  for (const auto& input_array_proto : model->flags.input_arrays()) {
    auto& input_array = model->GetOrCreateArray(input_array_proto.name());
    if (input_array_proto.has_data_type()) {
      const ArrayDataType specified_type =
          ConvertIODataTypeToArrayDataType(input_array_proto.data_type());
      QCHECK(specified_type != ArrayDataType::kNone);
      if (input_array.data_type != ArrayDataType::kNone) {
        QCHECK(specified_type == input_array.data_type)
            << "For input array " << input_array_proto.name()
            << " the specified input data type "
            << IODataType_Name(input_array_proto.data_type())
            << " conflicts with the existing type.";
      }
      input_array.data_type = specified_type;
    }

    if (input_array.data_type == ArrayDataType::kNone) {
      // We start out with a float input array;
      // that may get replaced by a uint8 array later, by
      // MakeInitialDequantizeOp.
      input_array.data_type = ArrayDataType::kFloat;
    }

    // Compare/merge the model->flags describing the input_shape with
    // the actual input array's shape.
    if (!input_array.has_shape()) {
      if (input_array_proto.has_shape()) {
        auto& input_array_dims = *input_array.mutable_shape()->mutable_dims();
        for (auto dim : input_array_proto.shape().dims()) {
          CHECK_GE(dim, 1);
          input_array_dims.push_back(dim);
        }
      }
    } else {
      if (input_array_proto.has_shape()) {
        // If an input shape was specified on the flags, ensure that it
        // matches the actual shape in the model.
        const auto& input_array_dims =
            *input_array.mutable_shape()->mutable_dims();
        CHECK_EQ(input_array_dims.size(),
                 input_array_proto.shape().dims_size());
        for (int i = 0; i < input_array_dims.size(); i++) {
          CHECK_EQ(input_array_dims[i], input_array_proto.shape().dims(i));
        }
      }
    }

    const float mean_value = input_array_proto.mean_value();
    const float std_value = input_array_proto.std_value();
    MinMax input_minmax;
    input_minmax.min = (0.f - mean_value) / std_value;
    input_minmax.max = (255.f - mean_value) / std_value;
    if (input_array.minmax) {
      if (input_array_proto.has_mean_value() ||
          input_array_proto.has_std_value()) {
        CHECK(input_minmax == *input_array.minmax)
            << input_minmax.min << ", " << input_minmax.max
            << " != " << input_array.minmax->min << ", "
            << input_array.minmax->max;
      }
    } else {
      input_array.GetOrCreateMinMax() = input_minmax;
    }
  }
  // Creation of the RNN state arrays
  for (const auto& rnn_state : model->flags.rnn_states()) {
    CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
                               model);
  }

  for (const auto& input_array : model->flags.input_arrays()) {
    if (input_array.has_shape()) {
      CHECK(input_array.shape().dims_size());
    }
  }

  model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays());
  model->flags.set_allow_nonexistent_arrays(
      model_flags.allow_nonexistent_arrays());

  CHECK(!model->flags.has_arrays_extra_info());
  *model->flags.mutable_arrays_extra_info() = model_flags.arrays_extra_info();
}

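// Checks that every float input of every operator either has min/max
// information or is a constant buffer (whose min/max can be computed as a
// fallback), and aborts with remediation guidance otherwise.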
void CheckIsReadyForQuantization(const Model& model) {
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      const auto& input_array = model.GetArray(input);
      if (input_array.data_type != ArrayDataType::kFloat) {
        // The array is not floats, no quantization needed.
        continue;
      }
      if (input_array.minmax) {
        // The array has minmax, we're good.
        continue;
      }
      if (input_array.buffer) {
        // The array has a constant buffer, so we can
        // fall back to computing the minmax from actual array entries
        // (with a WARNING about possible accuracy implications).
        continue;
      }
      LOG(FATAL)
          << "Array " << input << ", which is an input to the "
          << HelpfulOperatorTypeName(*op) << " operator producing the output "
          << "array " << op->outputs[0] << ", is lacking min/max data, "
          << "which is necessary for quantization. Either target a "
          << "non-quantized output format, or change the input graph to "
          << "contain min/max information, or pass --default_ranges_min= and "
          << "--default_ranges_max= if you do not care about the accuracy of "
          << "results.";
    }
  }
}

void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min,
                                 double default_ranges_max) {
  for (const auto& op : model->operators) {
    for (const auto& input : op->inputs) {
      auto& input_array = model->GetArray(input);
      if (!input_array.minmax && !input_array.buffer) {
        auto& minmax = input_array.GetOrCreateMinMax();
        minmax.min = default_ranges_min;
        minmax.max = default_ranges_max;
      }
    }
    for (const auto& output : op->outputs) {
      auto& output_array = model->GetArray(output);
      if (!output_array.minmax && !output_array.buffer) {
        auto& minmax = output_array.GetOrCreateMinMax();
        minmax.min = default_ranges_min;
        minmax.max = default_ranges_max;
      }
    }
  }
}

   1341 int ElementSize(ArrayDataType data_type) {
   1342   switch (data_type) {
   1343     case ArrayDataType::kFloat:
   1344       return 4;
   1345     case ArrayDataType::kInt8:
   1346       return 1;
   1347     case ArrayDataType::kUint8:
   1348       return 1;
   1349     case ArrayDataType::kInt16:
   1350       return 2;
   1351     case ArrayDataType::kUint16:
   1352       return 2;
   1353     case ArrayDataType::kInt32:
   1354       return 4;
   1355     case ArrayDataType::kUint32:
   1356       return 4;
   1357     case ArrayDataType::kInt64:
   1358       return 8;
   1359     case ArrayDataType::kUint64:
   1360       return 8;
   1361 
    1362     // This is usually not a critical limitation, since string arrays
    1363     // typically occur only as model inputs and/or outputs.
   1364     case ArrayDataType::kString:
   1365       LOG(FATAL) << "Transient arrays with strings are not supported yet";
   1366       return 0;
   1367     default:
   1368       LOG(FATAL) << "Should not get here.";
   1369       return 0;
   1370   }
   1371 }
   1372 
   1373 void DropMinMax(Model* model, const string& array_name) {
   1374   auto& array = model->GetArray(array_name);
    1375   if (array.minmax) {
   1376     LOG(WARNING) << "Dropping MinMax information in array " << array_name
   1377                  << ". Expect inaccuracy in quantized inference.";
   1378     array.minmax = nullptr;
   1379   }
   1380 }
   1381 
   1382 bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
    1383   // An optional array is not transient.
   1384   if (model.IsOptionalArray(array_name)) return false;
   1385   // The model's input and output arrays are externally allocated.
   1386   // They are not transient arrays.
   1387   if (IsInputArray(model, array_name)) {
   1388     return false;
   1389   }
   1390   for (const string& output_array : model.flags.output_arrays()) {
   1391     if (array_name == output_array) {
   1392       return false;
   1393     }
   1394   }
    1395   const auto* array = &model.GetArray(array_name);
    1396   // An array with a constant buffer isn't a transient array.
    1397   if (array->buffer) {
   1398     return false;
   1399   }
   1400   // An array without shape isn't allocatable.
   1401   if (!array->has_shape()) {
   1402     return false;
   1403   }
   1404   return true;
   1405 }
   1406 
   1407 string AvailableArrayName(const Model& model, const string& name) {
   1408   string sanitized_name = SanitizeNameForTFNode(name);
   1409   if (!model.HasArray(sanitized_name) &&
   1410       !model.IsOptionalArray(sanitized_name)) {
   1411     return sanitized_name;
   1412   }
   1413   const int kNumSuffixesToTry = 1000;
   1414   for (int i = 0; i < kNumSuffixesToTry; i++) {
    1415     const string name_with_suffix =
   1416         toco::port::StringF("%s_%d", sanitized_name, i);
   1417     if (!model.HasArray(name_with_suffix) &&
   1418         !model.IsOptionalArray(name_with_suffix)) {
   1419       return name_with_suffix;
   1420     }
   1421   }
   1422   LOG(FATAL) << "Could not find an available array name starting with "
   1423              << sanitized_name << ". Tried " << kNumSuffixesToTry
   1424              << " suffixes, all were taken!";
   1425   return "";
   1426 }
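
         // Illustrative example (not part of the original source): if the model
         // already contains an array named "conv", AvailableArrayName(model,
         // "conv") returns the first untaken candidate among "conv_0",
         // "conv_1", ..., after first sanitizing the requested name via
         // SanitizeNameForTFNode.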
   1427 
   1428 string ShapeToString(const Shape& shape) {
   1429   if (shape.dimensions_count() == 0) {
   1430     return "[]";
   1431   }
   1432 
   1433   return absl::StrCat("[ ", absl::StrJoin(shape.dims(), ", "), " ]");
   1434 }
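
         // Example (illustrative): a 0-D shape prints as "[]", while a shape
         // with dims {1, 224, 224, 3} prints as "[ 1, 224, 224, 3 ]".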
   1435 
   1436 void PrintArrayShape(Model* model, const string& name) {
   1437   if (!model->GetArray(name).has_shape()) {
   1438     LOG(INFO) << name << " has no shape";
   1439     return;
   1440   }
   1441   LOG(INFO) << name
   1442             << " has shape: " << ShapeToString(model->GetArray(name).shape());
   1443 }
   1444 
   1445 bool IsArrayFullyConnectedWeights(const Model& model, const string& name) {
   1446   bool is_fc_weights = false;
   1447   bool is_something_else = false;
   1448   for (const auto& op : model.operators) {
   1449     for (int input_index = 0; input_index < op->inputs.size(); input_index++) {
   1450       if (op->inputs[input_index] == name) {
   1451         if (op->type == OperatorType::kFullyConnected && input_index == 1) {
   1452           is_fc_weights = true;
   1453         } else {
   1454           is_something_else = true;
   1455         }
   1456       }
   1457     }
   1458   }
   1459   CHECK(!(is_fc_weights && is_something_else));
   1460   return is_fc_weights;
   1461 }
   1462 
   1463 string CreateInt32Array(Model* model, const string& param_name,
   1464                         const std::vector<int>& value) {
   1465   auto param_array_name = AvailableArrayName(*model, param_name);
   1466   auto& param_array = model->GetOrCreateArray(param_array_name);
   1467   param_array.mutable_shape()->ReplaceDims({static_cast<int>(value.size())});
   1468   param_array.data_type = ArrayDataType::kInt32;
   1469   auto& param_array_data =
   1470       param_array.GetMutableBuffer<ArrayDataType::kInt32>().data;
   1471   param_array_data.resize(RequiredBufferSizeForShape(param_array.shape()));
   1472   for (int i = 0; i < value.size(); ++i) {
   1473     param_array_data[i] = value[i];
   1474   }
   1475   return param_array_name;
   1476 }
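
         // Illustrative usage sketch (hypothetical operator and names, not from
         // the original source): creating a constant perm parameter for a
         // transpose-like operator:
         //   const string perm_name =
         //       CreateInt32Array(model, "transpose_perm", {0, 3, 1, 2});
         //   op->inputs.push_back(perm_name);
         // The created array has shape [4], data type kInt32, and a constant
         // buffer holding {0, 3, 1, 2}.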
   1477 
   1478 bool EstimateArithmeticOpsCount(const Model& model, int64* result) {
   1479   int64 total = 0;
   1480   for (const auto& op : model.operators) {
   1481     switch (op->type) {
   1482       case OperatorType::kFullyConnected:
   1483       case OperatorType::kConv:
   1484       case OperatorType::kDepthwiseConv: {
   1485         const auto& output_array = model.GetArray(op->outputs[0]);
   1486         const auto& weights_array = model.GetArray(op->inputs[1]);
   1487         if (!output_array.has_shape() || !weights_array.has_shape()) {
   1488           return false;
   1489         }
   1490         int cols = 1;
   1491         for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) {
   1492           cols *= output_array.shape().dims(i);
   1493         }
   1494         const int64 cost_per_col =
   1495             2 * RequiredBufferSizeForShape(weights_array.shape());
   1496         total += cost_per_col * cols;
   1497         if (op->inputs.size() > 2) {
   1498           // There is a bias vector. One more op per output value.
   1499           total += RequiredBufferSizeForShape(output_array.shape());
   1500         }
   1501         break;
   1502       }
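               // Worked example (illustrative): a FullyConnected op with
               // weights of shape [1000, 1024] and output of shape [1, 1000]
               // has cols == 1 (the product of all output dims except the
               // last), so it costs 2 * 1000 * 1024 ops, plus 1000 more ops
               // if a bias vector is present.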
   1503       case OperatorType::kAdd:
   1504       case OperatorType::kSub:
   1505       case OperatorType::kMul: {
   1506         const auto& output_array = model.GetArray(op->outputs[0]);
   1507         if (!output_array.has_shape()) {
   1508           return false;
   1509         }
   1510         total += RequiredBufferSizeForShape(output_array.shape());
   1511         break;
   1512       }
   1513       case OperatorType::kAddN: {
   1514         const auto& output_array = model.GetArray(op->outputs[0]);
   1515         if (!output_array.has_shape()) {
   1516           return false;
   1517         }
    1518         // The cost of AddN is roughly the same as that of N-1 Adds.
   1519         const int num_adds = op->inputs.size() - 1;
   1520         total += num_adds * RequiredBufferSizeForShape(output_array.shape());
   1521         break;
   1522       }
   1523       case OperatorType::kLogistic:
   1524       case OperatorType::kSoftmax:
   1525       case OperatorType::kLogSoftmax:
   1526       case OperatorType::kTanh: {
   1527         const auto& output_array = model.GetArray(op->outputs[0]);
   1528         if (!output_array.has_shape()) {
   1529           return false;
   1530         }
   1531         // As a very rough ballpark, the cost of evaluating a math function
   1532         // such as tanh or logistic is about 32 multiplications, and about as
    1533         // many additions/subtractions. (This is just a power-of-two
    1534         // order-of-magnitude estimate based on the implementations in runtime/.)
   1535         total += 64 * RequiredBufferSizeForShape(output_array.shape());
   1536         break;
   1537       }
   1538       case OperatorType::kMaxPool: {
   1539         const auto& maxpool = *static_cast<const MaxPoolOperator*>(op.get());
   1540         const auto& output_array = model.GetArray(op->outputs[0]);
   1541         if (!output_array.has_shape()) {
   1542           return false;
   1543         }
   1544         total += RequiredBufferSizeForShape(output_array.shape()) *
   1545                  maxpool.kheight * maxpool.kwidth;
   1546         break;
   1547       }
   1548       case OperatorType::kAveragePool: {
   1549         const auto& avgpool =
   1550             *static_cast<const AveragePoolOperator*>(op.get());
   1551         const auto& output_array = model.GetArray(op->outputs[0]);
   1552         if (!output_array.has_shape()) {
   1553           return false;
   1554         }
   1555         total += RequiredBufferSizeForShape(output_array.shape()) *
   1556                  avgpool.kheight * avgpool.kwidth;
   1557         break;
   1558       }
   1559       case OperatorType::kL2Pool: {
    1560         const auto* l2pool = static_cast<const L2PoolOperator*>(op.get());
   1561         const auto& output_array = model.GetArray(op->outputs[0]);
   1562         if (!output_array.has_shape()) {
   1563           return false;
   1564         }
   1565         // The sum of squares requires (kheight*kwidth) multiply-adds,
   1566         // and then there is the sqrt which we ballpark at 32 ops.
    1567         const int64 cost_per_val = 2 * l2pool->kheight * l2pool->kwidth + 32;
   1568         total +=
   1569             RequiredBufferSizeForShape(output_array.shape()) * cost_per_val;
   1570         break;
   1571       }
   1572       case OperatorType::kL2Normalization: {
   1573         const auto& output_array = model.GetArray(op->outputs[0]);
   1574         if (!output_array.has_shape()) {
   1575           return false;
   1576         }
   1577         // Computing the squared L2 norm is N multiply-adds so 2N ops,
   1578         // then the single inverse-sqrt is negligible, then we multiply each
   1579         // value by the resulting multiplier, so an extra N ops. Total 3N ops.
   1580         total += 3 * RequiredBufferSizeForShape(output_array.shape());
   1581         break;
   1582       }
   1583       default:
   1584         break;
   1585     }
   1586   }
   1587   *result = total;
   1588   return true;
   1589 }
   1590 
   1591 void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
   1592                      std::vector<int>* shuffle) {
   1593   CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order));
   1594   shuffle->resize(4);
   1595   for (int i = 0; i < 4; i++) {
   1596     (*shuffle)[i] = i;
   1597   }
   1598   if (input_axes_order == output_axes_order) {
   1599     // nothing to do
   1600   } else if (AxesCount(input_axes_order) == 2) {
   1601     shuffle->resize(2);
   1602     (*shuffle)[0] = 1;
   1603     (*shuffle)[1] = 0;
   1604   } else if (input_axes_order == AxesOrder::kOHWI &&
   1605              output_axes_order == AxesOrder::kHWIO) {
   1606     // 3210 <- 3210
   1607     // HWIO <- OHWI
   1608     (*shuffle)[0] = 1;
   1609     (*shuffle)[1] = 2;
   1610     (*shuffle)[2] = 3;
   1611     (*shuffle)[3] = 0;
   1612   } else if (input_axes_order == AxesOrder::kHWIO &&
   1613              output_axes_order == AxesOrder::kOHWI) {
   1614     // 3210 <- 3210
   1615     // OHWI <- HWIO
   1616     (*shuffle)[0] = 3;
   1617     (*shuffle)[1] = 0;
   1618     (*shuffle)[2] = 1;
   1619     (*shuffle)[3] = 2;
   1620   } else {
   1621     LOG(FATAL) << "Bad shuffle";
   1622   }
   1623 }
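
         // Illustrative examples (not from the original source): any pair of
         // distinct 2D orders (e.g. kRC -> kCR) yields the shuffle {1, 0}, and
         // kOHWI -> kHWIO yields {1, 2, 3, 0}; output axis i is drawn from
         // input axis shuffle[i].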
   1624 
   1625 void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim,
   1626                    std::vector<int>* extended_shuffle) {
   1627   *extended_shuffle = input_shuffle;
    1628   CHECK_GE(newdim, static_cast<int>(input_shuffle.size()));
   1629   const int pad_size = newdim - input_shuffle.size();
   1630   extended_shuffle->resize(newdim);
   1631   for (int i = 0; i < pad_size; i++) {
   1632     (*extended_shuffle)[i] = i;
   1633   }
   1634   for (int i = pad_size; i < newdim; i++) {
   1635     (*extended_shuffle)[i] = input_shuffle[i - pad_size] + pad_size;
   1636   }
   1637 }
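
         // Example (illustrative): ExtendShuffle({1, 0}, 4, &s) pads the 2D
         // transpose shuffle with leading identity dims, giving s = {0, 1, 3, 2}:
         // the two new leading dims stay in place and the original axes are
         // shifted by the pad size of 2.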
   1638 
   1639 void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order,
   1640                  AxesOrder output_axes_order, Shape* output_shape) {
   1641   if (input_axes_order == AxesOrder::kHWIM &&
   1642       output_axes_order == AxesOrder::k1HWO) {
    1643     // This special case isn't just a permutation: the IM pair of dims gets
    1644     // merged into the O dim, so we have to special-case it.
   1645     *output_shape = Shape({1, input_shape.dims(0), input_shape.dims(1),
   1646                            input_shape.dims(3) * input_shape.dims(2)});
   1647   } else {
   1648     std::vector<int> shuffle;
   1649     GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
   1650     std::vector<int>* output_dims = output_shape->mutable_dims();
   1651     output_dims->resize(input_shape.dimensions_count());
   1652     for (int i = 0; i < input_shape.dimensions_count(); i++) {
   1653       (*output_dims)[i] = input_shape.dims(shuffle[i]);
   1654     }
   1655   }
   1656 }
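
         // Example (illustrative): an input shape of {3, 3, 8, 2} in kHWIM
         // order shuffled to k1HWO yields {1, 3, 3, 16}: the I and M dims
         // (8 * 2) are merged into the O dim. A plain permutation such as
         // kOHWI -> kHWIO would instead just reorder the dims.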
   1657 
   1658 template <typename T>
   1659 void ShuffleArrayTemplate(const Shape& input_shape, AxesOrder input_axes_order,
   1660                           AxesOrder output_axes_order,
   1661                           const Shape& output_shape, const T* input_data,
   1662                           T* output_data) {
   1663   if (input_axes_order == AxesOrder::kHWIM &&
   1664       output_axes_order == AxesOrder::k1HWO) {
    1665     // This special case isn't just a permutation: the IM pair of dims gets
   1666     // merged into the O dim, so we have to special-case it. Fortunately,
   1667     // as far as array shuffling is concerned, it's just the identity
   1668     // transformation.
   1669     memcpy(output_data, input_data,
   1670            RequiredBufferSizeForShape(input_shape) * sizeof(output_data[0]));
   1671     return;
   1672   }
    1673   CHECK_EQ(input_shape.dimensions_count(), output_shape.dimensions_count());
   1674   const int dim = input_shape.dimensions_count();
   1675   CHECK_LE(dim, 4);
   1676   std::vector<int> shuffle;
   1677   GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
    1678   CHECK_GE(static_cast<int>(shuffle.size()), dim);
   1679   for (int i = 0; i < dim; i++) {
   1680     CHECK(shuffle[i] >= 0 && shuffle[i] < dim);
    1681     CHECK_EQ(input_shape.dims(shuffle[i]), output_shape.dims(i));
   1682   }
   1683   Shape extended_input_shape = input_shape;
   1684   ExtendShape(&extended_input_shape, 4);
   1685   Shape extended_output_shape = output_shape;
   1686   ExtendShape(&extended_output_shape, 4);
   1687   std::vector<int> extended_shuffle;
   1688   ExtendShuffle(shuffle, 4, &extended_shuffle);
   1689 
   1690   const std::vector<int>& extended_input_dims = extended_input_shape.dims();
   1691   const std::vector<int>& extended_output_dims = extended_output_shape.dims();
   1692 
   1693   // TODO(starka): Rework to handle different numbers of dimensions.
   1694   int input_strides[4];
   1695   input_strides[3] = 1;
   1696   input_strides[2] = extended_input_dims[3];
   1697   input_strides[1] = input_strides[2] * extended_input_dims[2];
   1698   input_strides[0] = input_strides[1] * extended_input_dims[1];
   1699   const int input_stride_0 = input_strides[extended_shuffle[3]];
   1700   const int input_stride_1 = input_strides[extended_shuffle[2]];
   1701   const int input_stride_2 = input_strides[extended_shuffle[1]];
   1702   const int input_stride_3 = input_strides[extended_shuffle[0]];
   1703 
   1704   const int output_size_0 = extended_output_dims[3];
   1705   const int output_size_1 = extended_output_dims[2];
   1706   const int output_size_2 = extended_output_dims[1];
   1707   const int output_size_3 = extended_output_dims[0];
   1708   const int output_stride_0 = 1;
   1709   const int output_stride_1 = output_size_0;
   1710   const int output_stride_2 = output_stride_1 * output_size_1;
   1711   const int output_stride_3 = output_stride_2 * output_size_2;
   1712 
   1713   for (int i3 = 0; i3 < output_size_3; i3++) {
   1714     const T* const input_ptr_3 = input_data + i3 * input_stride_3;
   1715     T* const output_ptr_3 = output_data + i3 * output_stride_3;
   1716     for (int i2 = 0; i2 < output_size_2; i2++) {
   1717       const T* const input_ptr_2 = input_ptr_3 + i2 * input_stride_2;
   1718       T* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2;
   1719       for (int i1 = 0; i1 < output_size_1; i1++) {
   1720         const T* input_ptr = input_ptr_2 + i1 * input_stride_1;
   1721         T* output_ptr = output_ptr_2 + i1 * output_stride_1;
   1722         T* const output_ptr_end = output_ptr + output_size_0 * output_stride_0;
   1723         while (output_ptr != output_ptr_end) {
   1724           *output_ptr = *input_ptr;
   1725           input_ptr += input_stride_0;
   1726           output_ptr += output_stride_0;
   1727         }
   1728       }
   1729     }
   1730   }
   1731 }
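
         // Illustrative behavior of the generic path above (example values,
         // not from the original source): shuffling the row-major 2x3 array
         // {1, 2, 3, 4, 5, 6} from kRC to kCR produces the 3x2 array
         // {1, 4, 2, 5, 3, 6}, i.e. a plain transposition implemented via the
         // shuffled input strides.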
   1732 
   1733 void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
   1734                   AxesOrder output_axes_order, const Shape& output_shape,
   1735                   const uint8* input_data, uint8* output_data) {
   1736   ShuffleArrayTemplate<uint8>(input_shape, input_axes_order, output_axes_order,
   1737                               output_shape, input_data, output_data);
   1738 }
   1739 
   1740 void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
   1741                   AxesOrder output_axes_order, const Shape& output_shape,
   1742                   const float* input_data, float* output_data) {
   1743   ShuffleArrayTemplate<float>(input_shape, input_axes_order, output_axes_order,
   1744                               output_shape, input_data, output_data);
   1745 }
   1746 
   1747 int AxesCount(AxesOrder axes_order) {
   1748   switch (axes_order) {
   1749     case AxesOrder::kOneAxis:
   1750       return 1;
   1751     case AxesOrder::kRC:
   1752       return 2;
   1753     case AxesOrder::kCR:
   1754       return 2;
   1755     case AxesOrder::kHWIO:
   1756       return 4;
   1757     case AxesOrder::kOHWI:
   1758       return 4;
   1759     case AxesOrder::kHWIM:
   1760       return 4;
   1761     case AxesOrder::k1HWO:
   1762       return 4;
   1763     case AxesOrder::kNHWC:
   1764       return 4;
   1765     default:
   1766       LOG(FATAL) << "Bad AxesOrder";
   1767       return 0;
   1768   }
   1769 }
   1770 
   1771 bool IsDiscardableArray(const Model& model, const string& array_name) {
   1772   for (const auto& input_array : model.flags.input_arrays()) {
   1773     if (array_name == input_array.name()) {
   1774       return false;
   1775     }
   1776   }
   1777   for (const string& output_array : model.flags.output_arrays()) {
   1778     if (array_name == output_array) {
   1779       return false;
   1780     }
   1781   }
   1782   for (const auto& rnn_state : model.flags.rnn_states()) {
   1783     if (!rnn_state.discardable()) {
   1784       if (array_name == rnn_state.state_array()) {
   1785         return false;
   1786       }
   1787       if (array_name == rnn_state.back_edge_source_array()) {
   1788         return false;
   1789       }
   1790     }
   1791   }
   1792   return true;
   1793 }
   1794 
   1795 void CheckFinalDataTypesSatisfied(const Model& model) {
   1796   for (const auto& array_entry : model.GetArrayMap()) {
   1797     const auto& array = *array_entry.second;
   1798     if (array.final_data_type != ArrayDataType::kNone) {
   1799       CHECK(array.final_data_type == array.data_type)
   1800           << "Array \"" << array_entry.first
   1801           << "\" has mis-matching actual and final data types ("
   1802           << static_cast<int>(array.data_type) << ","
   1803           << static_cast<int>(array.final_data_type) << ").";
   1804     }
   1805   }
   1806 }
   1807 
   1808 ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
   1809   switch (type) {
   1810     case FLOAT:
   1811       return ArrayDataType::kFloat;
   1812     case QUANTIZED_UINT8:
   1813       return ArrayDataType::kUint8;
   1814     case INT32:
   1815       return ArrayDataType::kInt32;
   1816     case INT64:
   1817       return ArrayDataType::kInt64;
   1818     default:
   1819       return ArrayDataType::kNone;
   1820   }
   1821 }
   1822 
   1823 void FinishBuildingRNNStates(Model* model) {
   1824   for (const auto& rnn_state : model->flags.rnn_states()) {
   1825     if (!model->HasArray(rnn_state.back_edge_source_array()) ||
   1826         !model->HasArray(rnn_state.state_array())) {
   1827       CHECK(model->HasArray(rnn_state.back_edge_source_array()));
   1828       CHECK(model->HasArray(rnn_state.state_array()));
   1829       continue;
   1830     }
   1831     const auto& src_array = model->GetArray(rnn_state.back_edge_source_array());
   1832     auto& dst_array = model->GetArray(rnn_state.state_array());
   1833     if (src_array.data_type == ArrayDataType::kNone &&
   1834         dst_array.data_type == ArrayDataType::kNone) {
   1835       dst_array.data_type = ArrayDataType::kFloat;
   1836     }
   1837   }
   1838 }
   1839 
   1840 void UseArraysExtraInfo(Model* model) {
   1841   for (const auto& entry : model->flags.arrays_extra_info().entries()) {
   1842     QCHECK(model->HasArray(entry.name()))
   1843         << "ArraysExtraInfo refers to non-existent array name: "
   1844         << entry.name();
   1845     auto& minmax = model->GetArray(entry.name()).GetOrCreateMinMax();
   1846     minmax.min = entry.min();
   1847     minmax.max = entry.max();
   1848   }
   1849 }
   1850 
   1851 }  // namespace toco
   1852