Home | History | Annotate | Download | only in evaluation
      1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_H_
     16 #define TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_H_
     17 
     18 #include <functional>
     19 #include <map>
     20 #include <regex>  // NOLINT
     21 #include <utility>
     22 #include <vector>
     23 
     24 #include "absl/container/flat_hash_map.h"
     25 #include "absl/memory/memory.h"
     26 #include "tensorflow/core/platform/logging.h"
     27 #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
     28 #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
     29 
     30 namespace tflite {
     31 namespace evaluation {
     32 
     33 class EvaluationStage;
     34 
     35 typedef std::function<std::unique_ptr<EvaluationStage>(
     36     const EvaluationStageConfig&)>
     37     FactoryFunc;
     38 
     39 // Superclass for a single stage of an EvaluationPipeline.
     40 // Provides basic functionality for construction and accessing
     41 // initializers/inputs/outputs.
     42 // Every subclass of EvaluationStage will define its own behavior by specifying
     43 // appropriate accessor TAGs and implementing the Init, Run and Close methods.
     44 class EvaluationStage {
     45  public:
     46   // Initializes an EvaluationStage. Returns false if initialization failed,
     47   // true otherwise.
     48   // Should be called only once, before any call to Run().
     49   // object_map should contain {initializer name : object pointer} mappings
     50   // required for initialization.
     51   //
     52   // NOTE: EvaluationStage will not take ownership of any elements of
     53   // object_map.
     54   bool Init(absl::flat_hash_map<std::string, void*>& object_map);
     55 
     56   // An individual run of the EvaluationStage. Returns false if there was a
     57   // failure, true otherwise.
     58   // Init() should be called before any calls to run().
     59   // Inputs are acquired from and outputs are written to the incoming
     60   // object_map, using appropriate TAGs.
     61   //
     62   // NOTE: The EvaluationStage should maintain ownership of outputs it
     63   // populates into object_map. Ownership of inputs must be maintained
     64   // elsewhere.
     65   virtual bool Run(absl::flat_hash_map<std::string, void*>& object_map) = 0;
     66 
     67   // Returns the latest metrics based on all Run() calls made so far.
     68   virtual EvaluationStageMetrics LatestMetrics() = 0;
     69 
     70   // The canonical way to instantiate EvaluationStages.
     71   // Remember to call <classname>_ENABLE() first.
     72   static std::unique_ptr<EvaluationStage> Create(
     73       const EvaluationStageConfig& config) {
     74     if (!config.has_specification() ||
     75         !config.specification().has_process_class()) {
     76       LOG(ERROR) << "Process specification not present in config: "
     77                  << config.name();
     78       return nullptr;
     79     }
     80     auto& factory_ptr =
     81         (*GetFactoryMapPtr())[config.specification().process_class()];
     82     if (!factory_ptr) return nullptr;
     83     return factory_ptr(config);
     84   }
     85 
     86   // Used by DEFINE_REGISTRATION.
     87   // This method takes ownership of factory.
     88   // Should only be used via DEFINE_REGISTRATION macro.
     89   static void RegisterStage(const ProcessClass& process_class,
     90                             FactoryFunc class_factory) {
     91     (*GetFactoryMapPtr())[process_class] = std::move(class_factory);
     92   }
     93 
     94   virtual ~EvaluationStage() = default;
     95 
     96  protected:
     97   // Constructs an EvaluationStage.
     98   // Each subclass constructor must invoke this constructor.
     99   //
    100   // NOTE: Do NOT use constructors to obtain new EvaluationStages. Use
    101   // EvaluationStage::Create instead.
    102   explicit EvaluationStage(const EvaluationStageConfig& config)
    103       : config_(config) {}
    104 
    105   // Class-specific initialization, to be overridden by EvaluationStage
    106   // sub-classes. Gets called in EvaluationStage::Init().
    107   //
    108   // NOTE: This object should not take ownership of any elements of object_map.
    109   virtual bool DoInit(absl::flat_hash_map<std::string, void*>& object_map) = 0;
    110 
    111   // The three following functions return the initializer/input/output TAGs used
    112   // by an EvaluationStage. These should be mapped to meaningful names in the
    113   // EvaluationStageConfig, and to required objects during calls to Init/Run.
    114   // Format for TAGs: [A-Z0-9_]+ (Uppercase letters, numbers, "_")
    115   // Refer docs in tflite.evaluation.EvaluationStageConfig for more information.
    116 
    117   // Returns the expected initializer TAGs.
    118   virtual std::vector<std::string> GetInitializerTags() = 0;
    119 
    120   // Returns the expected input TAGs.
    121   virtual std::vector<std::string> GetInputTags() = 0;
    122 
    123   // Returns the expected output TAGs.
    124   virtual std::vector<std::string> GetOutputTags() = 0;
    125 
    126   // Populates a pointer to the object corresponding to provided TAG.
    127   // Returns true if success, false otherwise.
    128   // object_map contain a {name : object pointer} mapping, with the
    129   // name being mapped to the expected TAG in the EvaluationStageConfig.
    130   // NOTE: object pointer must be non-NULL.
    131   template <class T>
    132   bool GetObjectFromTag(const std::string& tag,
    133                         absl::flat_hash_map<std::string, void*>& object_map,
    134                         T** object_ptr) {
    135     *object_ptr = nullptr;
    136     // Find name corresponding to TAG.
    137     auto mapping_iter = tags_to_names_map_.find(tag);
    138     if (mapping_iter == tags_to_names_map_.end()) {
    139       LOG(ERROR) << "Unexpected TAG to GetObjectFromTag: " << tag;
    140       return false;
    141     }
    142     const std::string& expected_name = mapping_iter->second;
    143 
    144     // Find object from name.
    145     auto object_iter = object_map.find(expected_name);
    146     if (object_iter == object_map.end()) {
    147       LOG(ERROR) << "Could not find object for name: " << expected_name;
    148       return false;
    149     }
    150     if (!object_iter->second) {
    151       LOG(ERROR) << "Found null pointer for name: " << expected_name;
    152       return false;
    153     }
    154     *object_ptr = static_cast<T*>(object_iter->second);
    155     return true;
    156   }
    157 
    158   // Maps the appropriate name to a given object in object_map. The name is
    159   // derived from mappings provided in the EvaluationStageConfig.
    160   // Returns false if tag is invalid, true otherwise.
    161   //
    162   // NOTE: The EvaluationStage must maintain ownership of object for the
    163   // lifetime of object_map
    164   bool AssignObjectToTag(const std::string& tag, void* object_ptr,
    165                          absl::flat_hash_map<std::string, void*>& object_map) {
    166     // Find name corresponding to TAG.
    167     auto mapping_iter = tags_to_names_map_.find(tag);
    168     if (mapping_iter == tags_to_names_map_.end()) {
    169       LOG(ERROR) << "Unexpected TAG to AssignObjectToTag: " << tag;
    170       return false;
    171     }
    172     const std::string& expected_name = mapping_iter->second;
    173 
    174     object_map[expected_name] = object_ptr;
    175     return true;
    176   }
    177 
    178   EvaluationStageConfig config_;
    179 
    180  private:
    181   // Verifies that all TAGs from expected_tags are present in
    182   // tag_to_name_mappings, and then populates tags_to_names_map_ with the
    183   // appropriate entries. Returns false in case any TAG/mapping is invalid, true
    184   // otherwise.
    185   // expected_tags should be a list of TAG-strings.
    186   // tag_to_name_mappings should be RepeatedPtrField of strings mapping TAGs to
    187   // names in the form "SOME_TAG:some_name".
    188   bool ProcessExpectedTags(const std::vector<std::string>& expected_tags,
    189                            std::vector<std::string>& tag_to_name_mappings);
    190 
    191   static std::map<ProcessClass, FactoryFunc>* GetFactoryMapPtr() {
    192     return process_class_to_factory_map_;
    193   }
    194 
    195   // Used by factories.
    196   static std::map<ProcessClass, FactoryFunc>* process_class_to_factory_map_;
    197 
    198   // Maps expected TAGs to their names as defined by the EvaluationStageConfig.
    199   absl::flat_hash_map<std::string, std::string> tags_to_names_map_;
    200 
    201   // To ensure correct formatting in the config.
    202   const std::regex kTagNameMappingPattern{"^([A-Z0-9_]+):([a-z0-9_]+)$",
    203                                           std::regex::optimize};
    204 
    205   // To ensure correct formatting in TAG names.
    206   const std::regex kTagPattern{"^[A-Z0-9_]+$", std::regex::optimize};
    207 };
    208 
    209 // Add this to headers of new EvaluationStages.
    210 #define DECLARE_FACTORY(classname) void classname##_ENABLE();
    211 
    212 // Add this to implementation files of new EvaluationStages.
    213 // Call <stage_name>_ENABLE() before using EvaluationStage::Create for the
    214 // class.
    215 #define DEFINE_FACTORY(classname, processclass)                                \
    216   void classname##_ENABLE() {                                                  \
    217     FactoryFunc classname##Factory = [](const EvaluationStageConfig& config) { \
    218       return absl::make_unique<classname>(config);                             \
    219     };                                                                         \
    220     EvaluationStage::RegisterStage(processclass, classname##Factory);          \
    221   }
    222 
    223 // Use this to assign a non-nullptr pointer to tag in object_map.
    224 #define ASSIGN_OBJECT(tag, ptr, object_map)       \
    225   if (!AssignObjectToTag(tag, ptr, object_map)) { \
    226     return false;                                 \
    227   }
    228 
    229 // Use this to obtain pointers to required object.
    230 // Will return false if name corresponding to tag is not found, or if the
    231 // pointer found is nullptr.
    232 #define GET_OBJECT(tag, object_map, location)         \
    233   if (!GetObjectFromTag(tag, object_map, location)) { \
    234     return false;                                     \
    235   }
    236 
    237 }  // namespace evaluation
    238 }  // namespace tflite
    239 
    240 #endif  // TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_H_
    241