// (code-browser navigation artifact, preserved as a comment)
// Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
     18 
     19 #include <atomic>
     20 #include <list>
     21 #include <memory>
     22 #include <random>
     23 #include <string>
     24 #include <unordered_map>
     25 #include <vector>
     26 
     27 #include "tensorflow/compiler/xla/iterator_util.h"
     28 #include "tensorflow/compiler/xla/service/hlo.pb.h"
     29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     30 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     31 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
     32 #include "tensorflow/compiler/xla/service/name_uniquer.h"
     33 #include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
     34 #include "tensorflow/compiler/xla/types.h"
     35 #include "tensorflow/core/lib/gtl/array_slice.h"
     36 #include "tensorflow/core/lib/gtl/iterator_range.h"
     37 #include "tensorflow/core/platform/logging.h"
     38 #include "tensorflow/core/platform/mutex.h"
     39 
     40 namespace xla {
     41 
     42 // Describes a compilation unit at the HLO level.
     43 //
     44 // A HLO module contains one or more HLO computations. The module contains one
     45 // "entry" computation which produces the result. The module also includes any
     46 // embedded computations used by instructions such as "map" and "reduce". All
     47 // computations are owned by the module.
     48 class HloModule {
     49  public:
     50   HloModule(const string& name,
     51             const VersionedComputationHandle& entry_computation_handle,
     52             const HloModuleConfig& config);
     53 
     54   // Constructor without a versioned computation handle. This constructor should
     55   // only be used for HloModules used outside of the XLA service (eg
     56   // tests). The versioned handle is used by the service in the compilation
     57   // cache. A default configuration is created for this module.
     58   explicit HloModule(const string& name);
     59   explicit HloModule(const string& name, const HloModuleConfig& config);
     60 
     61   // Adds an entry computation to the module. A module can only have one entry
     62   // computation. Returns a pointer to the newly added computation.
     63   HloComputation* AddEntryComputation(
     64       std::unique_ptr<HloComputation> computation);
     65 
     66   // Adds an embedded computation to the module.
     67   HloComputation* AddEmbeddedComputation(
     68       std::unique_ptr<HloComputation> computation);
     69 
     70   // Removes an embedded computation.
     71   Status RemoveEmbeddedComputation(HloComputation* to_remove);
     72 
     73   // Replaces all uses of computations that are keys of 'replacements' with
     74   // the corresponding values in 'replacements'. Replaces the entry computation,
     75   // if applicable.
     76   //
     77   // This function iterates over all instructions in the module to find
     78   // computations to replace. We could speed it up by keeping track of users of
     79   // computations.
     80   void ReplaceComputations(
     81       const std::unordered_map<HloComputation*, HloComputation*>& replacements);
     82 
     83   const string& name() const { return name_; }
     84 
     85   // Returns a deep copy of this module including all computations.
     86   std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
     87 
     88   // Performs a deep clone of the computation, by recursively cloning all
     89   // the called computations as well.
     90   HloComputation* DeepCloneComputation(HloComputation* computation);
     91 
     92   // Return a pointer to the entry computation of the module..
     93   const HloComputation* entry_computation() const {
     94     CHECK_NE(nullptr, entry_computation_);
     95     return entry_computation_;
     96   }
     97   HloComputation* entry_computation() {
     98     CHECK_NE(nullptr, entry_computation_);
     99     return entry_computation_;
    100   }
    101 
    102   ComputationLayout* mutable_entry_computation_layout() {
    103     return config_.mutable_entry_computation_layout();
    104   }
    105 
    106   ComputationLayout entry_computation_layout() const {
    107     return config_.entry_computation_layout();
    108   }
    109 
    110   const VersionedComputationHandle& entry_computation_handle() const {
    111     return entry_computation_handle_;
    112   }
    113 
    114   // Gets the computations in this module.
    115   //
    116   // Returns a view of HloComputation*s, so you can iterate over this in the
    117   // natural way:
    118   //
    119   //   for (HloComputation* c : module->computations()) { ... }
    120   //
    121   tensorflow::gtl::iterator_range<UnwrappingIterator<
    122       std::vector<std::unique_ptr<HloComputation>>::const_iterator>>
    123   computations() const {
    124     return {MakeUnwrappingIterator(computations_.begin()),
    125             MakeUnwrappingIterator(computations_.end())};
    126   }
    127   tensorflow::gtl::iterator_range<UnwrappingIterator<
    128       std::vector<std::unique_ptr<HloComputation>>::iterator>>
    129   computations() {
    130     return {MakeUnwrappingIterator(computations_.begin()),
    131             MakeUnwrappingIterator(computations_.end())};
    132   }
    133 
    134   // Gets the number of computations in this module.
    135   int64 computation_count() const { return computations_.size(); }
    136 
    137   // Gets the number of instructions in this module.
    138   int64 instruction_count() const;
    139 
    140   // Compute and return a post order of all computations in the module. The sort
    141   // is defined like so: if computation A has an instruction which calls
    142   // computation B, then A will appear after B in the sort.
    143   std::list<HloComputation*> MakeComputationPostOrder() const;
    144 
    145   // Gets the computations in this module which aren't for fusion nodes.
    146   //
    147   // Postcondition: All computations in the returned list have
    148   // !IsFusionComputation().
    149   //
    150   // Note: Callers can and do rely on the return value here being a *snapshot*
    151   // of the module's non-fusion computations -- that is, it's OK to add or
    152   // remove computations from a module while iterating over
    153   // MakeNonfusionComputations().
    154   std::vector<HloComputation*> MakeNonfusionComputations() const;
    155 
    156   const HloModuleConfig& config() const { return config_; }
    157 
    158   // Return a string representation of the module.
    159   //
    160   // (We express the default options using an overload rather than a default
    161   // param because gdb ignores default params, but does resolve overloads.)
    162   string ToString() const { return ToString(HloPrintOptions()); }
    163   string ToString(const HloPrintOptions& options) const;
    164 
    165   // Convert an HloModule to or from a proto.
    166   HloModuleProto ToProto() const;
    167   static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
    168       const HloModuleProto& proto, const HloModuleConfig& module_config,
    169       const VersionedComputationHandle& entry_computation_handle =
    170           VersionedComputationHandle());
    171 
    172   // Creates and returns an HloModuleConfig with an appropriate program shape
    173   // for the HLO module in the given proto.
    174   static StatusOr<HloModuleConfig> CreateModuleConfigFromProto(
    175       const HloModuleProto& module);
    176 
    177   // Outlines the given expression from the given computation.
    178   // instructions_to_outline contains the instructions that form the expression.
    179   //
    180   // Precondition: instructions in instructions_to_outline are in topological
    181   // order (root of outlined instructions last). TODO(jingyue): takes a set of
    182   // instructions and topologically sorts them.
    183   HloInstruction* OutlineExpressionFromComputation(
    184       tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_outline,
    185       const string& outlined_computation_name, HloComputation* computation);
    186 
    187   // Returns a randomly generated uint64.
    188   uint64 RandomNew64() const;
    189 
    190   // Returns the unique name for a computation in this module.
    191   string GetUniqueCompuationName(const string& prefix) {
    192     return computation_name_uniquer_.GetUniqueName(prefix);
    193   }
    194 
    195   // Returns the NameUniquer for uniquing instruction names in this module.
    196   NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; }
    197 
    198   // Assign a new unique dense id for an instruction
    199   int NewUniqueInstructionId() {
    200     int result = next_unique_id_;
    201     next_unique_id_++;
    202     return result;
    203   }
    204 
    205   // Returns the number of unique intruction ids given out.  All ids up to
    206   // this point are guaranteed to be in the range [0..NumUniqueInstructionIds())
    207   int NumUniqueInstructionIds() const { return next_unique_id_; }
    208 
    209   // Returns an id that is unique to this module across all modules created over
    210   // the lifetime of this process.
    211   int unique_id() const { return unique_id_; }
    212 
    213  private:
    214   HloComputation* AddComputationInternal(
    215       std::unique_ptr<HloComputation> computation, bool is_entry,
    216       bool uniquify_names);
    217 
    218   const string name_;
    219   HloModuleConfig config_;
    220   HloComputation* entry_computation_ = nullptr;
    221   std::vector<std::unique_ptr<HloComputation>> computations_;
    222 
    223   // Random number generator engine to use when generating random numbers per
    224   // HloModule compilation.
    225   // TODO(b/25995601): Replace with better seed setting or dev/random for
    226   // where we don't need deterministic execution.
    227   mutable std::mt19937_64 rng_{42};
    228   mutable tensorflow::mutex rng_mutex_;
    229 
    230   // Versioned handle of the entry computation of the module.
    231   bool has_entry_computation_handle_ = false;
    232   VersionedComputationHandle entry_computation_handle_;
    233 
    234   // Unique name generator for computation and instruction names, which are
    235   // unique per module.
    236   NameUniquer computation_name_uniquer_{/*separator=*/"."};
    237   NameUniquer instruction_name_uniquer_{/*separator=*/"."};
    238   int next_unique_id_ = 0;
    239 
    240   // Used to keep track of the next unique module id that should be assigned.
    241   static std::atomic<int> next_unique_module_id_;
    242   // A unique id to label modules with.
    243   int unique_id_;
    244 };
    245 
    246 }  // namespace xla
    247 
    248 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
    249