Home | History | Annotate | Download | only in toco
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
     16 #define TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
     17 
     18 #include <algorithm>
     19 #include <cmath>
     20 #include <iostream>
     21 #include <limits>
     22 #include <memory>
     23 #include <string>
     24 #include <vector>
     25 
     26 #include "absl/strings/string_view.h"
     27 #include "tensorflow/core/platform/logging.h"
     28 #if TOCO_SUPPORT_PORTABLE_PROTOS
     29 #include "third_party/protobuf/src/google/protobuf/text_format.h"
     30 #endif  // TOCO_SUPPORT_PORTABLE_PROTOS
     31 #include "tensorflow/contrib/lite/toco/model.h"
     32 #include "tensorflow/contrib/lite/toco/model_flags.pb.h"
     33 #include "tensorflow/contrib/lite/toco/runtime/types.h"
     34 #include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
     35 #include "tensorflow/contrib/lite/toco/toco_port.h"
     36 #include "tensorflow/contrib/lite/toco/types.pb.h"
     37 
     38 // TODO(aselle): Replace with using a container specific hash override instead.
     39 namespace std {
     40 template <>
     41 struct hash<toco::OperatorType> {
     42   size_t operator()(const toco::OperatorType& op) const {
     43     return std::hash<size_t>()(static_cast<size_t>(op));
     44   }
     45 };
     46 }  // namespace std
     47 
     48 namespace toco {
     49 
     50 constexpr int kLogLevelModelChanged = 1;
     51 constexpr int kLogLevelModelUnchanged = 2;
     52 
     53 absl::string_view FindLongestCommonPrefix(absl::string_view a,
     54                                           absl::string_view b);
     55 string LogName(const Operator& op);
     56 
     57 bool IsInputArray(const Model& model, const string& name);
     58 bool IsArrayConsumed(const Model& model, const string& name);
     59 int CountTrueOutputs(const Model& model, const Operator& op);
     60 
     61 int CountOpsWithInput(const Model& model, const string& array_name);
     62 bool DeleteArrayIfUnused(const string& array_name, Model* model);
     63 bool DeleteArrayIfUsedOnce(const string& array_name, Model* model);
     64 
     65 std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
     66     const Model& model, const string& array_name);
     67 Operator* GetOpWithOutput(const Model& model, const string& array_name);
     68 
     69 std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput(
     70     Model& model, const string& array_name);
     71 
     72 Operator* GetOpWithOutput(const Model& model, const string& array_name);
     73 
     74 std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput(
     75     const Model& model, const string& array_name);
     76 
     77 std::vector<std::unique_ptr<Operator>>::iterator FindOpWithInput(
     78     Model& model, const string& array_name);
     79 
     80 Operator* GetOpWithInput(const Model& model, const string& array_name);
     81 Operator* GetFirstOpWithInput(const Model& model, const string& array_name);
     82 
     83 std::vector<std::unique_ptr<Operator>>::const_iterator FindOp(
     84     const Model& model, const Operator* op);
     85 std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model,
     86                                                         const Operator* op);
     87 
     88 const char* OperatorTypeName(OperatorType type);
     89 string HelpfulOperatorTypeName(const Operator& op);
     90 
     91 bool OperatorSupportsFusedActivation(OperatorType type);
     92 
     93 void DumpGraphvizVideoFrame(const Model& model);
     94 void LogDump(int log_level, const string& message, const Model& model);
     95 void LogSummary(int log_level, const string& message, const Model& model);
     96 
     97 // TODO(b/36075966): Clean up when dims superseded by array shape.
     98 void ExtendShape(Shape* shape, int new_shape_size);
     99 
    100 // TODO(b/36075966): Clean up when dims superseded by array shape.
    101 void UnextendShape(Shape* shape, int new_shape_size);
    102 
    103 // Checks (using CHECK) that all dimensions of 'shape' are at least 1.
    104 void CheckShapeDimensions(const Shape& shape);
    105 
    106 // Given two shapes with potentially different dimensionality and dimension
    107 // arrays d0 and d1. Without loss of generality, assume that shape0 may have
    108 // higher dimensionality (length(d0) >= length(d1)). Then shape0 and shape1
    109 // "agree up to broadcasting" if:
    110 // - When walking the d0 and d1 from back to front with indices i0, i1,
    111 //   d0[i0] == d1[i1] or d0[i0] == 1 or d1[i1] == 1, for each dimension until
    112 //   i1 == 0 (inclusive).
    113 bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1);
    114 
    115 // A stricter constraint than ShapesAgreeUpToBroadcasting().
    116 //
    117 // Given two shapes with potentially different dimensionality and dimension
    118 // arrays d0 and d1. Without loss of generality, assume that shape0 may have
    119 // higher dimensionality (length(d0) >= length(d1)). Then shape0 and shape1
    120 // "agree up to extending" if:
    121 // - When walking the d0 and d1 from back to front with indices i0, i1,
    122 //   d0[i0] == d1[i1] for each dimension until i1 == 0 (inclusive).
    123 // - For the remaining indices [0..i0), d0[i0] == 1.
    124 bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1);
    125 
    126 bool IsArrayFullyConnectedWeights(const Model& model, const string& name);
    127 
    128 // If there is a wildcard dimension (-1), this may return a negative value.
    129 int RequiredBufferSizeForShape(const Shape& shape);
    130 
    131 bool IsConstantParameterArray(const Model& model, const string& name);
    132 
    133 void CheckNoMissingArray(const Model& model);
    134 void CheckInvariants(const Model& model);
    135 
    136 void CheckModelCounts(const Model& model);
    137 
    138 void FixOperatorOrdering(Model* model);
    139 void FixNoMissingArray(Model* model);
    140 void FixNoOrphanedArray(Model* model);
    141 
    142 void ResolveModelFlags(const ModelFlags& model_flags, Model* model);
    143 
// Computes affine quantization parameters (zero_point, scale) mapping the
// real-valued range [minmax.min, minmax.max] onto the full range of the
// integer type selected by template parameter A. The zero point is nudged
// to an exact integer so the real value 0.0 is exactly representable as a
// quantized value.
//
// Preconditions (CHECKed): the range must contain 0, i.e.
// minmax.min <= 0 <= minmax.max.
//
// NOTE(review): 'model_flags' is never referenced in this body; presumably
// kept for call-site uniformity — confirm before removing.
template <ArrayDataType A>
void GetQuantizationParamsFromMinMax(const ModelFlags& model_flags,
                                     const MinMax& minmax,
                                     QuantizationParams* quantization_params) {
  using Integer = DataType<A>;
  // Quantized-value range limits for the target integer type, and the same
  // limits as doubles for the floating-point arithmetic below.
  const Integer qmin = std::numeric_limits<Integer>::min();
  const Integer qmax = std::numeric_limits<Integer>::max();
  const double qmin_double = qmin;
  const double qmax_double = qmax;
  const double rmin = minmax.min;
  const double rmax = minmax.max;
  // 0 should always be a representable value. Let's assume that the initial
  // min,max range contains 0.
  CHECK_LE(rmin, 0.);
  CHECK_GE(rmax, 0.);
  if (rmin == rmax) {
    // Special case where the min,max range is a point. Should be {0}.
    CHECK_EQ(rmin, 0.);
    CHECK_EQ(rmax, 0.);
    quantization_params->zero_point = 0;
    quantization_params->scale = 0.;
    return;
  }

  // General case.
  //
  // First determine the scale.
  const double scale = (rmax - rmin) / (qmax_double - qmin_double);

  // Zero-point computation.
  // First the initial floating-point computation. The zero-point can be
  // determined from solving an affine equation for any known pair
  // (real value, corresponding quantized value).
  // We know two such pairs: (rmin, qmin) and (rmax, qmax).
  // The arithmetic error on the zero point computed from either pair
  // will be roughly machine_epsilon * (sum of absolute values of terms)
  // so we want to use the variant that adds the smaller terms.
  const double zero_point_from_min = qmin_double - rmin / scale;
  const double zero_point_from_max = qmax_double - rmax / scale;
  const double zero_point_from_min_error =
      std::abs(qmin_double) + std::abs(rmin / scale);
  const double zero_point_from_max_error =
      std::abs(qmax_double) + std::abs(rmax / scale);

  const double zero_point_double =
      zero_point_from_min_error < zero_point_from_max_error
          ? zero_point_from_min
          : zero_point_from_max;

  // Now we need to nudge the zero point to be an integer
  // (our zero points are integer, and this is motivated by the requirement
  // to be able to represent the real value "0" exactly as a quantized value,
  // which is required in multiple places, for example in Im2col with SAME
  // padding).
  Integer nudged_zero_point = 0;
  if (zero_point_double < qmin_double) {
    nudged_zero_point = qmin;
  } else if (zero_point_double > qmax_double) {
    nudged_zero_point = qmax;
  } else {
    nudged_zero_point = static_cast<Integer>(std::round(zero_point_double));
  }
  // The zero point should always be in the range of quantized value,
  // [qmin, qmax].
  CHECK_GE(nudged_zero_point, qmin);
  CHECK_LE(nudged_zero_point, qmax);

  // Finally, store the result nudged quantization params.
  quantization_params->zero_point = nudged_zero_point;
  quantization_params->scale = scale;
}
    215 
    216 void CheckIsReadyForQuantization(const Model& model);
    217 void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min,
    218                                  double default_ranges_max);
    219 
    220 inline int Offset(const Shape& shape, const std::vector<int>& indices) {
    221   DCHECK_EQ(shape.dimensions_count(), indices.size());
    222   const int dims_count = shape.dimensions_count();
    223   int offset = 0;
    224   for (int i = 0; i < dims_count; i++) {
    225     const int index = indices[i];
    226     DCHECK(index >= 0 && index < shape.dims(i));
    227     offset *= shape.dims(i);
    228     offset += index;
    229   }
    230   return offset;
    231 }
    232 
    233 inline std::vector<int> ReverseOffset(const Shape& shape, int index) {
    234   DCHECK_GE(index, 0);
    235   DCHECK_LT(index, RequiredBufferSizeForShape(shape));
    236   const int dims_count = shape.dimensions_count();
    237   std::vector<int> indices(dims_count);
    238   int residual = index;
    239   for (int i = dims_count - 1; i >= 0; i--) {
    240     indices[i] = residual % shape.dims(i);
    241     residual /= shape.dims(i);
    242   }
    243   return indices;
    244 }
    245 
    246 int ElementSize(ArrayDataType data_type);
    247 
    248 void DropMinMax(Model* model, const string& array_name);
    249 
    250 bool IsAllocatableTransientArray(const Model& model, const string& array_name);
    251 
    252 void CreateOrCheckRnnStateArray(const string& name, int size, Model* model);
    253 
    254 string AvailableArrayName(const Model& model, const string& name);
    255 
    256 // Formats a shape as a string: [ dims(0), dims(1), ..., dims(num_dims-1) ].
    257 string ShapeToString(const Shape& shape);
    258 
    259 void PrintArrayShape(Model* model, const string& name);
    260 
    261 void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
    262                    std::vector<int>* out_dims);
    263 
    264 // Defines a constant int32 array with the provided values formatted for use
    265 // as op parameters.
    266 string CreateInt32Array(Model* model, const string& param_name,
    267                         const std::vector<int>& value);
    268 
    269 bool EstimateArithmeticOpsCount(const Model& model, int64* result);
    270 
    271 int AxesCount(AxesOrder axes_order);
    272 
    273 // Returns the permutation of the dimensions based on the input axes order and
    274 // output axes order.
    275 void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
    276                      std::vector<int>* shuffle);
    277 
    278 // Extend shuffle is designed to match ExtendShape, which pads the shape with
    279 // unit dimensions at the beginning.
    280 void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim,
    281                    std::vector<int>* extended_shuffle);
    282 
    283 void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order,
    284                  AxesOrder output_axes_order, Shape* output_shape);
    285 void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
    286                   AxesOrder output_axes_order, const Shape& output_shape,
    287                   const float* input_data, float* output_data);
    288 void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
    289                   AxesOrder output_axes_order, const Shape& output_shape,
    290                   const uint8* input_data, uint8* output_data);
    291 
    292 // Returns true if it may be OK for any graph transformation to ever discard
    293 // that array. The idea is that we can't ever discard arrays that are either
    294 // an input or an output of the whole graph, or that appear in RNN back-edges,
    295 // as that would undercut explicit flags that the user might pass.
    296 bool IsDiscardableArray(const Model& model, const string& array_name);
    297 
    298 void CheckFinalDataTypesSatisfied(const Model& model);
    299 
    300 ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type);
    301 
    302 // The process of building models varies according to the import format.
    303 //
    304 // (a) In some cases, such as model-proto format, the model should be fully
    305 // specified. In these cases, no extra action should be taken by this function.
    306 // (b) In other cases, such as TF graphdef format, the desired types of RNN
    307 // arrays are not specified directly in the model, neither can they be inferred.
    308 // However, we can set the types of RNN destination arrays to float. This breaks
    309 // any cycles such as when resolution of the type of an RNN source array depends
    310 // on the type of its destination array.
    311 //
    312 // This function is applied after the main import, after resolution of flags and
    313 // after application of ArraysExtraInfo. It only defaults destination RNN arrays
    314 // to float. If the model is subsequently quantized, it is assumed that the
    315 // model contains sufficient information for that to be completed. If it is
    316 // already quantized, then case (a) should hold.
    317 void FinishBuildingRNNStates(Model* model);
    318 
    319 void UseArraysExtraInfo(Model* model);
    320 
    321 }  // namespace toco
    322 
    323 #endif  // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
    324