1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ 16 #define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ 17 18 #include <cstddef> 19 #include <initializer_list> 20 #include <unordered_set> 21 #include <vector> 22 23 #include "tensorflow/contrib/lite/toco/model.h" 24 #include "tensorflow/contrib/lite/toco/toco_port.h" 25 26 namespace toco { 27 28 class GraphTransformation { 29 public: 30 virtual bool Run(Model* model, std::size_t op_index) = 0; 31 virtual const char* Name() const = 0; 32 virtual ~GraphTransformation() {} 33 // Returns the list of messages that this graph transformation 34 // generated since ClearMessages() was called. 35 const std::vector<string>& Messages() const { return messages_; } 36 // Clears the list of messages; should be called after every 37 // run of this graph transformation. 38 void ClearMessages() { return messages_.clear(); } 39 // Adds a message; normally only called by the graph transformation 40 // itself during its run (this function could be protected). 41 template <typename... Args> 42 void AddMessageF(const char* format, const Args&... 
args) { 43 return messages_.push_back(toco::port::StringF(format, args...)); 44 } 45 46 protected: 47 GraphTransformation() {} 48 49 // List of messages generated by this graph transformation. 50 std::vector<string> messages_; 51 52 private: 53 GraphTransformation(const GraphTransformation& other) = delete; 54 GraphTransformation(const GraphTransformation&& other) = delete; 55 }; 56 57 class GraphTransformationsSet { 58 public: 59 // The choice of a container with fully-specified iteration order 60 // ensures that graph transformations are always run in the same order, 61 // which avoids having toco randomly fail or produce different results 62 // depending on the toolchain. Ideally success/results should be independent 63 // of the order in which graph transformations are run, but that's 64 // unfortunately not currently guaranteed to be the case. 65 using TransformationsContainer = 66 std::vector<std::unique_ptr<GraphTransformation>>; 67 68 GraphTransformationsSet() {} 69 GraphTransformationsSet( 70 const std::initializer_list<GraphTransformation*> transformations) { 71 for (GraphTransformation* t : transformations) { 72 Add(t); 73 } 74 } 75 void Add(GraphTransformation* transformation) { 76 const string& name = transformation->Name(); 77 CHECK(!names_.count(name)); 78 names_.insert(name); 79 transformations_.emplace_back(transformation); 80 } 81 TransformationsContainer::const_iterator begin() const { 82 return transformations_.begin(); 83 } 84 TransformationsContainer::const_iterator end() const { 85 return transformations_.end(); 86 } 87 bool empty() const { return transformations_.empty(); } 88 89 private: 90 GraphTransformationsSet(const GraphTransformationsSet& other) = delete; 91 GraphTransformationsSet(const GraphTransformationsSet&& other) = delete; 92 std::vector<std::unique_ptr<GraphTransformation>> transformations_; 93 // Names of transformations in the set. Only used to guard against dupes. 
94 std::unordered_set<string> names_; 95 }; 96 97 // Run the given list of graph transformations on the model. 98 // The message is only for logging purposes. 99 // The transformations is a rvalue reference, indicating that 100 // nothing else will use these pointers. The user is supposed to 101 // construct GraphTransformation objects by using 'new', pass us 102 // the resulting raw pointers, and this RunGraphTransformations 103 // takes care of delete'ing these pointers. 104 void RunGraphTransformations(Model* model, const string& message, 105 const GraphTransformationsSet& transformations); 106 107 #define DECLARE_GRAPH_TRANSFORMATION(GTName) \ 108 class GTName : public GraphTransformation { \ 109 public: \ 110 bool Run(Model* model, std::size_t op_index) override; \ 111 const char* Name() const override { return #GTName; } \ 112 }; 113 114 // List of all graph transformations 115 DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape) 116 DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise) 117 DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialAddNToAdd) 118 DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialStackToReshape) 119 DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape) 120 DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes) 121 DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors) 122 DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions) 123 DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine) 124 DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine) 125 DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization) 126 DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool) 127 DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell) 128 DECLARE_GRAPH_TRANSFORMATION(SplitLstmCellInputs) 129 DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs) 130 DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1) 131 DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator) 132 DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes) 133 DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes) 134 
// Each line below expands, through DECLARE_GRAPH_TRANSFORMATION, into its own
// GraphTransformation subclass that declares Run() and whose Name() returns
// the class name. The declarations are independent of one another.
DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax)
DECLARE_GRAPH_TRANSFORMATION(Quantize)
DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp)
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert)
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialBinaryOperator)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenation)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenationInput)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialSlice)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedActivationFunc)
DECLARE_GRAPH_TRANSFORMATION(RemoveUnusedOp)
DECLARE_GRAPH_TRANSFORMATION(ResolveBatchNormalization)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantBinaryOperator)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantUnaryOperator)
DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays)
DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays)
DECLARE_GRAPH_TRANSFORMATION(ReadFakeQuantMinMax)
DECLARE_GRAPH_TRANSFORMATION(ReorderActivationFunctions)
DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowConcat)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge)
DECLARE_GRAPH_TRANSFORMATION(ResolveSqueezeAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTranspose)
DECLARE_GRAPH_TRANSFORMATION(DropFakeQuant)
DECLARE_GRAPH_TRANSFORMATION(UnfuseActivationFunctions)
DECLARE_GRAPH_TRANSFORMATION(UnrollBatchMatMul)
DECLARE_GRAPH_TRANSFORMATION(ResolveSpaceToBatchNDAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveBatchToSpaceNDAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes) 169 DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes) 170 DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes) 171 DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes) 172 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange) 173 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank) 174 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStack) 175 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice) 176 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill) 177 DECLARE_GRAPH_TRANSFORMATION(ResolveMultiplyByZero) 178 DECLARE_GRAPH_TRANSFORMATION(Dequantize) 179 180 class ResolveReshapeAttributes : public GraphTransformation { 181 public: 182 bool Run(Model* model, std::size_t op_index) override; 183 const char* Name() const override { return "ResolveReshapeAttributes"; } 184 }; 185 186 class RemoveTrivialReshape : public GraphTransformation { 187 public: 188 bool Run(Model* model, std::size_t op_index) override; 189 const char* Name() const override { return "RemoveTrivialReshape"; } 190 bool treat_expand_dims_as_trivial() const { 191 return treat_expand_dims_as_trivial_; 192 } 193 void set_treat_expand_dims_as_trivial(bool val) { 194 treat_expand_dims_as_trivial_ = val; 195 } 196 197 private: 198 bool treat_expand_dims_as_trivial_ = false; 199 }; 200 201 #undef DECLARE_GRAPH_TRANSFORMATION 202 203 } // end namespace toco 204 205 #endif // TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ 206