// Part of the XLA CPU backend (tensorflow/compiler/xla/service/cpu).
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
     18 
#include <memory>
#include <unordered_map>

#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
     22 
     23 namespace xla {
     24 namespace cpu {
     25 
     26 // Simple interface for different parallel cost model implementations.
     27 class ParallelCostModel {
     28  public:
     29   virtual ~ParallelCostModel() = default;
     30   virtual int64 GetParallelTaskCount(HloInstruction* instruction) = 0;
     31 };
     32 
     33 // ParallelTaskAssignment computes parallel task counts for HLOs in 'module'.
     34 class ParallelTaskAssignment {
     35  public:
     36   // 'max_parallelism': the maximum parallel task count per instruction.
     37   // 'shape_size': shape size function used by HloCostAnalysis during parallel
     38   //               task assignment.
     39   // 'module': the containing HloModule.
     40   ParallelTaskAssignment(const int64 max_parallelism,
     41                          const HloCostAnalysis::ShapeSizeFunction& shape_size,
     42                          HloModule* module);
     43   ~ParallelTaskAssignment() {}
     44 
     45   // Computes and returns the target parallel task count for 'instruction'.
     46   int64 GetTargetParallelTaskCount(HloInstruction* instruction);
     47 
     48  private:
     49   std::unique_ptr<ParallelCostModel> cost_model_;
     50 };
     51 
     52 // ParallelTaskAssigner computes target parallel task counts for all HLOs
     53 // in the module, then assigns parallel task counts to HLOs in the entry
     54 // computation, or to HLOs in embedded computations invoked by (potentially
     55 // nested) kWhile or kCall instructions.
     56 // Each HLO which is assigned parallel task counts is outlined into its
     57 // own embedded computation, which is compiled as a parallel compute function,
     58 // and which is invoked from a kCall instruction that is lowered in codegen to
     59 // a runtime parallel fork/join call.
     60 class ParallelTaskAssigner : public HloPassInterface {
     61  public:
     62   // 'max_parallelism': the maximum parallel task count per instruction.
     63   // 'shape_size': shape size function used by HloCostAnalysis during parallel
     64   //               task assignment.
     65   ParallelTaskAssigner(const int64 max_parallelism,
     66                        const HloCostAnalysis::ShapeSizeFunction& shape_size)
     67       : max_parallelism_(max_parallelism), shape_size_function_(shape_size) {}
     68   ~ParallelTaskAssigner() override {}
     69 
     70   tensorflow::StringPiece name() const override {
     71     return "cpu-parallel-task-assigner";
     72   }
     73 
     74   // Run parallel task assigner on 'module'.
     75   // Returns true if the computation was changed, false otherwise.
     76   StatusOr<bool> Run(HloModule* module) override;
     77 
     78  private:
     79   using HloToParallelTasks = std::unordered_map<const HloInstruction*, int64>;
     80 
     81   // Assigns target parallel tasks from 'hlo_to_parallel_tasks' to HLOs in
     82   // 'module'.
     83   // Returns true if the computation was changed, false otherwise.
     84   bool AssignParallelTasks(HloModule* module,
     85                            const HloToParallelTasks& hlo_to_parallel_tasks);
     86   bool AssignParallelTasksHelper(
     87       HloModule* module, HloComputation* computation,
     88       const HloToParallelTasks& hlo_to_parallel_tasks);
     89 
     90   // Computes target parallel task counts (returned in 'parallel_task_counts')
     91   // for parallelizable instructions in 'module'.
     92   void ComputeTargetParallelTasks(HloModule* module,
     93                                   HloToParallelTasks* hlo_to_parallel_tasks);
     94 
     95   int64 max_parallelism_;
     96   HloCostAnalysis::ShapeSizeFunction shape_size_function_;
     97 };
     98 
     99 }  // namespace cpu
    100 }  // namespace xla
    101 
    102 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
    103