Home | History | Annotate | Download | only in common_runtime
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Classes to maintain a static registry of whole-graph optimization
     17 // passes to be applied by the Session when it initializes a graph.
     18 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_
     19 #define TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_
     20 
#include <functional>
#include <map>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph.h"
     29 
     30 namespace tensorflow {
     31 struct SessionOptions;
     32 
     33 // All the parameters used by an optimization pass are packaged in
     34 // this struct. They should be enough for the optimization pass to use
     35 // as a key into a state dictionary if it wants to keep state across
     36 // calls.
     37 struct GraphOptimizationPassOptions {
     38   string session_handle;
     39   const SessionOptions* session_options = nullptr;
     40   const CostModel* cost_model = nullptr;
     41 
     42   FunctionLibraryDefinition* flib_def = nullptr;  // Not owned.
     43   // The DeviceSet contains all the devices known to the system and is
     44   // filled in for optimizations run by the session master, i.e.,
     45   // PRE_PLACEMENT, POST_PLACEMENT, and POST_REWRITE_FOR_EXEC. It is
     46   // nullptr for POST_PARTITIONING optimizations which are run at the
     47   // workers.
     48   const DeviceSet* device_set = nullptr;  // Not owned.
     49 
     50   // The graph to optimize, for optimization passes that run before
     51   // partitioning. Null for post-partitioning passes.
     52   // An optimization pass may replace *graph with a new graph object.
     53   std::unique_ptr<Graph>* graph = nullptr;
     54 
     55   // Graphs for each partition, if running post-partitioning. Optimization
     56   // passes may alter the graphs, but must not add or remove partitions.
     57   // Null for pre-partitioning passes.
     58   std::unordered_map<string, std::unique_ptr<Graph>>* partition_graphs =
     59       nullptr;
     60 };
     61 
     62 // Optimization passes are implemented by inheriting from
     63 // GraphOptimizationPass.
     64 class GraphOptimizationPass {
     65  public:
     66   virtual ~GraphOptimizationPass() {}
     67   virtual Status Run(const GraphOptimizationPassOptions& options) = 0;
     68 };
     69 
     70 // The key is a 'phase' number. Phases are executed in increasing
     71 // order. Within each phase the order of passes is undefined.
     72 typedef std::map<int, std::vector<std::unique_ptr<GraphOptimizationPass>>>
     73     GraphOptimizationPasses;
     74 
     75 // A global OptimizationPassRegistry is used to hold all passes.
     76 class OptimizationPassRegistry {
     77  public:
     78   // Groups of passes are run at different points in initialization.
     79   enum Grouping {
     80     PRE_PLACEMENT,          // after cost model assignment, before placement.
     81     POST_PLACEMENT,         // after placement.
     82     POST_REWRITE_FOR_EXEC,  // after re-write using feed/fetch endpoints.
     83     POST_PARTITIONING,      // after partitioning
     84   };
     85 
     86   // Add an optimization pass to the registry.
     87   void Register(Grouping grouping, int phase,
     88                 std::unique_ptr<GraphOptimizationPass> pass);
     89 
     90   // Run all passes in grouping, ordered by phase, with the same
     91   // options.
     92   Status RunGrouping(Grouping grouping,
     93                      const GraphOptimizationPassOptions& options);
     94 
     95   // Returns the global registry of optimization passes.
     96   static OptimizationPassRegistry* Global();
     97 
     98  private:
     99   std::map<Grouping, GraphOptimizationPasses> groups_;
    100 };
    101 
    102 namespace optimization_registration {
    103 
    104 class OptimizationPassRegistration {
    105  public:
    106   OptimizationPassRegistration(OptimizationPassRegistry::Grouping grouping,
    107                                int phase,
    108                                std::unique_ptr<GraphOptimizationPass> pass) {
    109     OptimizationPassRegistry::Global()->Register(grouping, phase,
    110                                                  std::move(pass));
    111   }
    112 };
    113 
    114 }  // namespace optimization_registration
    115 
    116 #define REGISTER_OPTIMIZATION(grouping, phase, optimization) \
    117   REGISTER_OPTIMIZATION_UNIQ_HELPER(__COUNTER__, grouping, phase, optimization)
    118 
    119 #define REGISTER_OPTIMIZATION_UNIQ_HELPER(ctr, grouping, phase, optimization) \
    120   REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization)
    121 
    122 #define REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization) \
    123   static optimization_registration::OptimizationPassRegistration       \
    124       register_optimization_##ctr(                                     \
    125           grouping, phase,                                             \
    126           std::unique_ptr<GraphOptimizationPass>(new optimization))
    127 
    128 }  // namespace tensorflow
    129 
    130 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_
    131