1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_H_ 16 #define TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_H_ 17 18 #include <functional> 19 #include <map> 20 #include <regex> // NOLINT 21 #include <utility> 22 #include <vector> 23 24 #include "absl/container/flat_hash_map.h" 25 #include "absl/memory/memory.h" 26 #include "tensorflow/core/platform/logging.h" 27 #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" 28 #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" 29 30 namespace tflite { 31 namespace evaluation { 32 33 class EvaluationStage; 34 35 typedef std::function<std::unique_ptr<EvaluationStage>( 36 const EvaluationStageConfig&)> 37 FactoryFunc; 38 39 // Superclass for a single stage of an EvaluationPipeline. 40 // Provides basic functionality for construction and accessing 41 // initializers/inputs/outputs. 42 // Every subclass of EvaluationStage will define its own behavior by specifying 43 // appropriate accessor TAGs and implementing the Init, Run and Close methods. 44 class EvaluationStage { 45 public: 46 // Initializes an EvaluationStage. Returns false if initialization failed, 47 // true otherwise. 48 // Should be called only once, before any call to Run(). 49 // object_map should contain {initializer name : object pointer} mappings 50 // required for initialization. 51 // 52 // NOTE: EvaluationStage will not take ownership of any elements of 53 // object_map. 54 bool Init(absl::flat_hash_map<std::string, void*>& object_map); 55 56 // An individual run of the EvaluationStage. Returns false if there was a 57 // failure, true otherwise. 58 // Init() should be called before any calls to run(). 59 // Inputs are acquired from and outputs are written to the incoming 60 // object_map, using appropriate TAGs. 61 // 62 // NOTE: The EvaluationStage should maintain ownership of outputs it 63 // populates into object_map. Ownership of inputs must be maintained 64 // elsewhere. 65 virtual bool Run(absl::flat_hash_map<std::string, void*>& object_map) = 0; 66 67 // Returns the latest metrics based on all Run() calls made so far. 68 virtual EvaluationStageMetrics LatestMetrics() = 0; 69 70 // The canonical way to instantiate EvaluationStages. 71 // Remember to call <classname>_ENABLE() first. 72 static std::unique_ptr<EvaluationStage> Create( 73 const EvaluationStageConfig& config) { 74 if (!config.has_specification() || 75 !config.specification().has_process_class()) { 76 LOG(ERROR) << "Process specification not present in config: " 77 << config.name(); 78 return nullptr; 79 } 80 auto& factory_ptr = 81 (*GetFactoryMapPtr())[config.specification().process_class()]; 82 if (!factory_ptr) return nullptr; 83 return factory_ptr(config); 84 } 85 86 // Used by DEFINE_REGISTRATION. 87 // This method takes ownership of factory. 88 // Should only be used via DEFINE_REGISTRATION macro. 89 static void RegisterStage(const ProcessClass& process_class, 90 FactoryFunc class_factory) { 91 (*GetFactoryMapPtr())[process_class] = std::move(class_factory); 92 } 93 94 virtual ~EvaluationStage() = default; 95 96 protected: 97 // Constructs an EvaluationStage. 98 // Each subclass constructor must invoke this constructor. 99 // 100 // NOTE: Do NOT use constructors to obtain new EvaluationStages. Use 101 // EvaluationStage::Create instead. 102 explicit EvaluationStage(const EvaluationStageConfig& config) 103 : config_(config) {} 104 105 // Class-specific initialization, to be overridden by EvaluationStage 106 // sub-classes. Gets called in EvaluationStage::Init(). 107 // 108 // NOTE: This object should not take ownership of any elements of object_map. 109 virtual bool DoInit(absl::flat_hash_map<std::string, void*>& object_map) = 0; 110 111 // The three following functions return the initializer/input/output TAGs used 112 // by an EvaluationStage. These should be mapped to meaningful names in the 113 // EvaluationStageConfig, and to required objects during calls to Init/Run. 114 // Format for TAGs: [A-Z0-9_]+ (Uppercase letters, numbers, "_") 115 // Refer docs in tflite.evaluation.EvaluationStageConfig for more information. 116 117 // Returns the expected initializer TAGs. 118 virtual std::vector<std::string> GetInitializerTags() = 0; 119 120 // Returns the expected input TAGs. 121 virtual std::vector<std::string> GetInputTags() = 0; 122 123 // Returns the expected output TAGs. 124 virtual std::vector<std::string> GetOutputTags() = 0; 125 126 // Populates a pointer to the object corresponding to provided TAG. 127 // Returns true if success, false otherwise. 128 // object_map contain a {name : object pointer} mapping, with the 129 // name being mapped to the expected TAG in the EvaluationStageConfig. 130 // NOTE: object pointer must be non-NULL. 131 template <class T> 132 bool GetObjectFromTag(const std::string& tag, 133 absl::flat_hash_map<std::string, void*>& object_map, 134 T** object_ptr) { 135 *object_ptr = nullptr; 136 // Find name corresponding to TAG. 137 auto mapping_iter = tags_to_names_map_.find(tag); 138 if (mapping_iter == tags_to_names_map_.end()) { 139 LOG(ERROR) << "Unexpected TAG to GetObjectFromTag: " << tag; 140 return false; 141 } 142 const std::string& expected_name = mapping_iter->second; 143 144 // Find object from name. 145 auto object_iter = object_map.find(expected_name); 146 if (object_iter == object_map.end()) { 147 LOG(ERROR) << "Could not find object for name: " << expected_name; 148 return false; 149 } 150 if (!object_iter->second) { 151 LOG(ERROR) << "Found null pointer for name: " << expected_name; 152 return false; 153 } 154 *object_ptr = static_cast<T*>(object_iter->second); 155 return true; 156 } 157 158 // Maps the appropriate name to a given object in object_map. The name is 159 // derived from mappings provided in the EvaluationStageConfig. 160 // Returns false if tag is invalid, true otherwise. 161 // 162 // NOTE: The EvaluationStage must maintain ownership of object for the 163 // lifetime of object_map 164 bool AssignObjectToTag(const std::string& tag, void* object_ptr, 165 absl::flat_hash_map<std::string, void*>& object_map) { 166 // Find name corresponding to TAG. 167 auto mapping_iter = tags_to_names_map_.find(tag); 168 if (mapping_iter == tags_to_names_map_.end()) { 169 LOG(ERROR) << "Unexpected TAG to AssignObjectToTag: " << tag; 170 return false; 171 } 172 const std::string& expected_name = mapping_iter->second; 173 174 object_map[expected_name] = object_ptr; 175 return true; 176 } 177 178 EvaluationStageConfig config_; 179 180 private: 181 // Verifies that all TAGs from expected_tags are present in 182 // tag_to_name_mappings, and then populates tags_to_names_map_ with the 183 // appropriate entries. Returns false in case any TAG/mapping is invalid, true 184 // otherwise. 185 // expected_tags should be a list of TAG-strings. 186 // tag_to_name_mappings should be RepeatedPtrField of strings mapping TAGs to 187 // names in the form "SOME_TAG:some_name". 188 bool ProcessExpectedTags(const std::vector<std::string>& expected_tags, 189 std::vector<std::string>& tag_to_name_mappings); 190 191 static std::map<ProcessClass, FactoryFunc>* GetFactoryMapPtr() { 192 return process_class_to_factory_map_; 193 } 194 195 // Used by factories. 196 static std::map<ProcessClass, FactoryFunc>* process_class_to_factory_map_; 197 198 // Maps expected TAGs to their names as defined by the EvaluationStageConfig. 199 absl::flat_hash_map<std::string, std::string> tags_to_names_map_; 200 201 // To ensure correct formatting in the config. 202 const std::regex kTagNameMappingPattern{"^([A-Z0-9_]+):([a-z0-9_]+)$", 203 std::regex::optimize}; 204 205 // To ensure correct formatting in TAG names. 206 const std::regex kTagPattern{"^[A-Z0-9_]+$", std::regex::optimize}; 207 }; 208 209 // Add this to headers of new EvaluationStages. 210 #define DECLARE_FACTORY(classname) void classname##_ENABLE(); 211 212 // Add this to implementation files of new EvaluationStages. 213 // Call <stage_name>_ENABLE() before using EvaluationStage::Create for the 214 // class. 215 #define DEFINE_FACTORY(classname, processclass) \ 216 void classname##_ENABLE() { \ 217 FactoryFunc classname##Factory = [](const EvaluationStageConfig& config) { \ 218 return absl::make_unique<classname>(config); \ 219 }; \ 220 EvaluationStage::RegisterStage(processclass, classname##Factory); \ 221 } 222 223 // Use this to assign a non-nullptr pointer to tag in object_map. 224 #define ASSIGN_OBJECT(tag, ptr, object_map) \ 225 if (!AssignObjectToTag(tag, ptr, object_map)) { \ 226 return false; \ 227 } 228 229 // Use this to obtain pointers to required object. 230 // Will return false if name corresponding to tag is not found, or if the 231 // pointer found is nullptr. 232 #define GET_OBJECT(tag, object_map, location) \ 233 if (!GetObjectFromTag(tag, object_map, location)) { \ 234 return false; \ 235 } 236 237 } // namespace evaluation 238 } // namespace tflite 239 240 #endif // TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_H_ 241