/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TOCO_TOOLING_UTIL_H_
#define TENSORFLOW_LITE_TOCO_TOOLING_UTIL_H_

#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/logging.h"
#if TOCO_SUPPORT_PORTABLE_PROTOS
#include "third_party/protobuf/include/google/protobuf/text_format.h"
#endif  // TOCO_SUPPORT_PORTABLE_PROTOS
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

// TODO(aselle): Replace with using a container specific hash override instead.
namespace std {
template <>
struct hash<toco::OperatorType> {
  size_t operator()(const toco::OperatorType& op) const {
    return std::hash<size_t>()(static_cast<size_t>(op));
  }
};
}  // namespace std
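// Note: the specialization above is what allows OperatorType to be used as a
// key in unordered containers, e.g. (illustrative sketch):
//   std::unordered_set<toco::OperatorType> fusable_ops = {
//       toco::OperatorType::kAdd, toco::OperatorType::kMul};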

namespace toco {

constexpr int kLogLevelModelChanged = 1;
constexpr int kLogLevelModelUnchanged = 2;

absl::string_view FindLongestCommonPrefix(absl::string_view a,
                                          absl::string_view b);
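// For example (illustrative), FindLongestCommonPrefix("conv1/weights",
// "conv1/bias") returns a view of "conv1/".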
string LogName(const Operator& op);

string ArrayDataTypeName(ArrayDataType data_type);

// Returns true if the given array is specified as a model input array.
bool IsInputArray(const Model& model, const string& array_name);
// Returns true if the given array is specified as a model output array.
bool IsOutputArray(const Model& model, const string& array_name);

bool IsArrayConsumed(const Model& model, const string& name);
int CountTrueOutputs(const Model& model, const Operator& op);

int CountOpsWithInput(const Model& model, const string& array_name);
bool DeleteArrayIfUnused(const string& array_name, Model* model);
bool DeleteArrayIfUsedOnce(const string& array_name, Model* model);

// Deletes the op and any of its input and output arrays if they are unused
// after the op has been deleted.
void DeleteOpAndArraysIfUnused(Model* model, const Operator* op);

std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
    const Model& model, const string& array_name);
Operator* GetOpWithOutput(const Model& model, const string& array_name);

std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput(
    Model& model, const string& array_name);

std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput(
    const Model& model, const string& array_name);

std::vector<std::unique_ptr<Operator>>::iterator FindOpWithInput(
    Model& model, const string& array_name);

Operator* GetOpWithInput(const Model& model, const string& array_name);
Operator* GetFirstOpWithInput(const Model& model, const string& array_name);

// Replaces all uses of the |old_array_name| with the |new_array_name|.
void ReplaceArrayUsage(Model* model, const string& old_array_name,
                       const string& new_array_name);

std::vector<std::unique_ptr<Operator>>::const_iterator FindOp(
    const Model& model, const Operator* op);
std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model,
                                                        const Operator* op);

const char* OperatorTypeName(OperatorType type);
string HelpfulOperatorTypeName(const Operator& op);

// Whether the operator can be fused with an activation function. Note that this
// will return false by default for new operators; fusing support is opt-in.
bool OperatorSupportsFusedActivation(OperatorType type);

void DumpGraphvizVideoFrame(const Model& model);
void LogDump(int log_level, const string& message, const Model& model);
void LogSummary(int log_level, const string& message, const Model& model);

// TODO(b/36075966): Clean up when dims superseded by array shape.
void ExtendShape(Shape* shape, int new_shape_size);

// TODO(b/36075966): Clean up when dims superseded by array shape.
void UnextendShape(Shape* shape, int new_shape_size);

// Checks that all dimensions of 'shape' are at least 1. Note that scalars,
// lacking dimensions, satisfy this condition and are considered non-empty.
bool IsNonEmpty(const Shape& shape);

// Given two shapes with potentially different dimensionality and dimension
// arrays d0 and d1, assume without loss of generality that shape0 has the
// higher dimensionality (length(d0) >= length(d1)). Then shape0 and shape1
// "agree up to broadcasting" if:
// - When walking d0 and d1 from back to front with indices i0, i1,
//   d0[i0] == d1[i1] or d0[i0] == 1 or d1[i1] == 1, for each dimension until
//   i1 == 0 (inclusive).
bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1);
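// For example (illustrative): [3, 4, 5] and [4, 1] agree up to broadcasting
// (trailing dims pair up as 5 vs 1 and 4 vs 4), while [3, 4, 5] and [2, 5]
// do not (4 vs 2, with neither dimension equal to 1).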

// A stricter constraint than ShapesAgreeUpToBroadcasting().
//
// Given two shapes with potentially different dimensionality and dimension
// arrays d0 and d1, assume without loss of generality that shape0 has the
// higher dimensionality (length(d0) >= length(d1)). Then shape0 and shape1
// "agree up to extending" if:
// - When walking d0 and d1 from back to front with indices i0, i1,
//   d0[i0] == d1[i1] for each dimension until i1 == 0 (inclusive).
// - For the remaining indices [0..i0), d0[i0] == 1.
bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1);
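// For example (illustrative): [1, 1, 3, 4] and [3, 4] agree up to extending,
// but [2, 3, 4] and [3, 4] do not, since the extra leading dimension of the
// longer shape is 2 rather than 1.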

inline ::tflite::RuntimeShape ToRuntimeShape(const Shape& shape) {
  return ::tflite::RuntimeShape(shape.dimensions_count(), shape.dims().data());
}

bool IsArrayFullyConnectedWeights(const Model& model, const string& name);

// If there is a wildcard dimension (-1), this may return a negative value.
int RequiredBufferSizeForShape(const Shape& shape);
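// For a fully specified shape this is the flat element count, e.g. a shape
// of [2, 3, 4] requires a buffer of 2 * 3 * 4 = 24 elements.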

bool IsConstantParameterArray(const Model& model, const string& name);

// Compares two constant parameter arrays for exact equality.
bool CompareConstantArrays(const Array& lhs_array, const Array& rhs_array);

void CheckNoMissingArray(const Model& model);
void CheckInvariants(const Model& model);

void CheckModelCounts(const Model& model);

void FixOperatorOrdering(Model* model);
void FixNoMissingArray(Model* model);
void FixNoOrphanedArray(Model* model);

// Fixes input/output arrays that may have issues during export or inference.
void FixEdgeArrays(Model* model);

// Finds and deduplicates large constant arrays in the model.
// After constant propagation runs, it is possible to end up with several
// copies of the same large array (whether they are zeros or otherwise).
//
// |min_size| is the minimum size in bytes an array must have before it is
// considered for deduping. Since deduping can make graphs harder to read,
// this threshold helps prevent small arrays from spidering out.
void DedupeConstantArrays(Model* model, size_t min_size);

// Copies the contents of an array into another.
// Expects that the shape and data type match.
template <ArrayDataType A>
void CopyArrayBuffer(const Array& source_array, Array* target_array) {
  int source_buffer_size = RequiredBufferSizeForShape(source_array.shape());
  int target_buffer_size = RequiredBufferSizeForShape(target_array->shape());
  CHECK_EQ(source_buffer_size, target_buffer_size)
      << "Buffer sizes must match in element count";
  CHECK(source_array.data_type == target_array->data_type)
      << "Data types must match";
  if (source_array.buffer) {
    const auto& source_buffer = source_array.GetBuffer<A>();
    auto& target_buffer = target_array->GetMutableBuffer<A>();
    target_buffer.data = source_buffer.data;
  }
}
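// Minimal usage sketch (illustrative), assuming both arrays already have
// matching shapes and float data:
//   CopyArrayBuffer<ArrayDataType::kFloat>(source_array, &target_array);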

// Inserts a no-op reshape operator between the source array and the target
// array. This effectively just copies the data.
void InsertCopyOperator(Model* model, const string& source_array_name,
                        const string& target_array_name);

// Clones an array with all data and parameters.
void CloneArray(Model* model, const string& source_array_name,
                const string& target_array_name);

void ResolveModelFlags(const ModelFlags& model_flags, Model* model);

template <typename T>
T ConvertOperator(Operator* o, OperatorType type) {
  if (o != nullptr && o->type == type) {
    return static_cast<T>(o);
  }

  return nullptr;
}
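// Usage sketch (illustrative): downcast an Operator* to a concrete operator
// type, getting nullptr when the type does not match:
//   auto* conv_op = ConvertOperator<ConvOperator*>(op, OperatorType::kConv);
//   if (conv_op != nullptr) { /* op is known to be a convolution */ }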

void CheckIsReadyForQuantization(const Model& model);

bool ReshapeIsEquivalentToTranspose(const Model& model,
                                    const TensorFlowReshapeOperator* op,
                                    bool allow_extra_unary_dims);

inline int Offset(const Shape& shape, const std::vector<int>& indices) {
  DCHECK_EQ(shape.dimensions_count(), indices.size());
  const int dims_count = shape.dimensions_count();
  int offset = 0;
  for (int i = 0; i < dims_count; i++) {
    const int index = indices[i];
    DCHECK(index >= 0 && index < shape.dims(i));
    offset *= shape.dims(i);
    offset += index;
  }
  return offset;
}
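// Worked example (illustrative): for the row-major shape [2, 3, 4], indices
// {1, 2, 3} map to offset ((1 * 3) + 2) * 4 + 3 == 23.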

inline std::vector<int> ReverseOffset(const Shape& shape, int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, RequiredBufferSizeForShape(shape));
  const int dims_count = shape.dimensions_count();
  std::vector<int> indices(dims_count);
  int residual = index;
  for (int i = dims_count - 1; i >= 0; i--) {
    indices[i] = residual % shape.dims(i);
    residual /= shape.dims(i);
  }
  return indices;
}
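// Worked example (illustrative): ReverseOffset inverts Offset; for shape
// [2, 3, 4], offset 23 maps back to indices {1, 2, 3}.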

int ElementSize(ArrayDataType data_type);

void DropMinMax(Model* model, const string& array_name);

bool IsAllocatableTransientArray(const Model& model, const string& array_name);

void CreateOrCheckRnnStateArray(const string& name, int size,
                                int state_num_dims, Model* model);

string AvailableArrayName(const Model& model, const string& name);

// Formats a shape as a string: [ dims(0), dims(1), ..., dims(num_dims-1) ].
string ShapeToString(const Shape& shape);

void PrintArrayShape(Model* model, const string& name);

void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
                   std::vector<int>* out_dims);

// Defines a constant int32 array with the provided values formatted for use
// as op parameters.
string CreateInt32Array(Model* model, const string& param_name,
                        const std::vector<int>& value);

bool EstimateArithmeticOpsCount(const Model& model, const Operator& op,
                                int64* result);
bool EstimateArithmeticOpsCount(const Model& model, int64* result);
string FormattedNumber(int64 x);

int AxesCount(AxesOrder axes_order);

// Returns the permutation of the dimensions based on the input axes order and
// output axes order.
void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
                     std::vector<int>* shuffle);

// ExtendShuffle is designed to match ExtendShape, which pads the shape with
// unit dimensions at the beginning.
void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim,
                   std::vector<int>* extended_shuffle);

void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order,
                 AxesOrder output_axes_order, Shape* output_shape);
void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
                  AxesOrder output_axes_order, const Shape& output_shape,
                  const float* input_data, float* output_data);
void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
                  AxesOrder output_axes_order, const Shape& output_shape,
                  const uint8* input_data, uint8* output_data);

// Returns true if it may be OK for any graph transformation to ever discard
// that array. The idea is that we can't ever discard arrays that are either
// an input or an output of the whole graph, or that appear in RNN back-edges,
// as that would undercut explicit flags that the user might pass.
bool IsDiscardableArray(const Model& model, const string& array_name);

void CheckFinalDataTypesSatisfied(const Model& model);

ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type);

// The process of building models varies according to the import format.
//
// (a) In some cases, such as model-proto format, the model should be fully
// specified. In these cases, no extra action should be taken by this function.
// (b) In other cases, such as TF graphdef format, the desired types of RNN
// arrays are not specified directly in the model, nor can they be inferred.
// However, we can set the types of RNN destination arrays to float. This
// breaks any cycles such as when resolution of the type of an RNN source
// array depends on the type of its destination array.
//
// This function is applied after the main import, after resolution of flags
// and after application of ArraysExtraInfo. It only defaults destination RNN
// arrays to float. If the model is subsequently quantized, it is assumed that
// the model contains sufficient information for that to be completed. If it
// is already quantized, then case (a) should hold.
void FinishBuildingRNNStates(Model* model);

void UseArraysExtraInfo(Model* model, bool quantize_output);

// Calculates the number of elements in a tensor given a shape. Shape elements
// are assumed to be of type T, while the resulting total is of type U. If U
// doesn't have enough range to represent the product of the dimensions, an
// error is returned.
template <typename T, typename U>
tensorflow::Status NumElements(const std::vector<T>& shape, U* num_elements) {
  static_assert(
      std::numeric_limits<T>::max() <= std::numeric_limits<uint64_t>::max(),
      "vector type exceeds capabilities of NumElements");

  *num_elements = 1;
  for (const T& dim : shape) {
    if (dim < 0) {
      // TensorFlow's shapes sometimes include -1 to represent an "unknown"
      // size but TOCO isn't able to create arrays of unknown sizes and will
      // crash in RequiredBufferSizeForShape().
      return tensorflow::errors::InvalidArgument(
          "Tensor shape should not include negative values");
    }
    if (*num_elements != 0 &&
        static_cast<uint64_t>(dim) >
            std::numeric_limits<U>::max() / *num_elements) {
      *num_elements = 0;
      return tensorflow::errors::InvalidArgument("Tensor shape is too large");
    }
    *num_elements *= dim;
  }
  return tensorflow::Status::OK();
}
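// Usage sketch (illustrative):
//   std::vector<int> dims = {2, 3, 4};
//   int num_elements = 0;
//   tensorflow::Status status = NumElements(dims, &num_elements);
//   // On success num_elements == 24; a negative dimension or an overflowing
//   // product yields an InvalidArgument status instead.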

// A model file may have shuffled FC weights.
// When that happens, we want to de-shuffle them immediately on import,
// so that the rest of toco doesn't need to know about shuffled weights.
void UndoWeightsShuffling(Model* model);

// Copies minmax, quantization_params, and narrow_range.
void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst);

}  // namespace toco

#endif  // TENSORFLOW_LITE_TOCO_TOOLING_UTIL_H_