Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
     18 
     19 #include <string>
     20 
     21 #include "tensorflow/compiler/xla/service/computation_layout.h"
     22 #include "tensorflow/compiler/xla/types.h"
     23 #include "tensorflow/compiler/xla/xla.pb.h"
     24 #include "tensorflow/compiler/xla/xla_data.pb.h"
     25 #include "tensorflow/core/lib/gtl/optional.h"
     26 
     27 namespace xla {
     28 
     29 // This class gathers all settings and values which affect the compiled
     30 // executable outside of the HLO code itself. This include layouts of inputs and
     31 // outputs to the module and settings such as HLO profiling. Together the
     32 // HloModule and HloModuleConfig unambiguously determine a particular
     33 // executable.
     34 class HloModuleConfig {
     35  public:
     36   // A configuration can be created either with, or without an entry
     37   // ComputationLayout. The default ctor creates it without -- in this case
     38   // accessing entry_computation_layout will CHECK-fail. The ctor accepting a
     39   // ProgramShape creates a computation layout using this shape.
     40   HloModuleConfig();
     41   explicit HloModuleConfig(const ProgramShape& program_shape);
     42 
     43   // Checks if this config has an entry computation layout already.
     44   bool has_entry_computation_layout() const {
     45     return entry_computation_layout_.has_value();
     46   }
     47 
     48   // Sets the entry computation layout for this config. If the entry computation
     49   // layout already exists, it is silently replaced.
     50   void SetDefaultComputationLayout(const ProgramShape& program_shape);
     51 
     52   // Returns a constant reference to the layout of the entry computation.
     53   // Assumes the layout was set.
     54   const ComputationLayout& entry_computation_layout() const {
     55     CHECK(entry_computation_layout_.has_value());
     56     return *entry_computation_layout_;
     57   }
     58 
     59   // Returns a mutable pointer to the layout of the entry computation. Assumes
     60   // the layout was set.
     61   ComputationLayout* mutable_entry_computation_layout() {
     62     CHECK(entry_computation_layout_.has_value());
     63     return &(*entry_computation_layout_);
     64   }
     65 
     66   // Sets/returns whether to enable HLO-level profiling.
     67   bool hlo_profiling_enabled() const { return hlo_profiling_enabled_; }
     68   void enable_hlo_profiling(bool enabled) { hlo_profiling_enabled_ = enabled; }
     69 
     70   // Sets/returns the module seed set during execution.
     71   void set_seed(uint64 seed) { seed_ = seed; }
     72   uint64 seed() const { return seed_; }
     73 
     74   void set_replica_count(int64 replica_count) {
     75     replica_count_ = replica_count;
     76   }
     77   int64 replica_count() const { return replica_count_; }
     78 
     79   // Return a string which unambiguously represents all the fields of this data
     80   // structure. Used for generating a cache key for storing the compiled
     81   // executable.
     82   string compilation_cache_key() const;
     83 
     84   const DebugOptions& debug_options() const { return debug_options_; }
     85 
     86   void set_debug_options(const DebugOptions& debug_options) {
     87     debug_options_ = debug_options;
     88   }
     89 
     90   // Sets/returns the number of intra op threads for this module.
     91   void set_intra_op_parallelism_threads(
     92       const int intra_op_parallelism_threads) {
     93     intra_op_parallelism_threads_ = intra_op_parallelism_threads;
     94   }
     95   int64 intra_op_parallelism_threads() const {
     96     return intra_op_parallelism_threads_;
     97   }
     98 
     99  private:
    100   // If you add new members, be sure to update compilation_cache_key.
    101 
    102   tensorflow::gtl::optional<ComputationLayout> entry_computation_layout_;
    103 
    104   // Whether to enable HLO-level profiling.
    105   bool hlo_profiling_enabled_ = false;
    106 
    107   // Module/graph-level seed handle.
    108   uint64 seed_ = 0;
    109 
    110   // The number of replicas to compile this binary for.
    111   int64 replica_count_ = 1;
    112 
    113   // The target maximum parallelism at which to partition HLOs for parallel
    114   // execution on the CPU backend.
    115   int64 intra_op_parallelism_threads_ = -1;
    116 
    117   DebugOptions debug_options_;
    118 };
    119 
    120 }  // namespace xla
    121 
    122 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
    123