/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_

#include <functional>
#include <memory>
#include <random>
#include <unordered_map>
#include <vector>

#include "absl/container/node_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"

namespace xla {

// Responsible for evaluating HLO and obtaining a Literal as the evaluation
// result.
//
// This class is not thread-safe.
class HloEvaluator : public DfsHloVisitorWithDefault {
 public:
  // Only evaluates up to max_loop_iterations per while-loop execution, if
  // specified.
  explicit HloEvaluator(int64 max_loop_iterations = -1);

  // Evaluates an HLO module, given an array of pointers to argument literals.
  // Returns the evaluated result as a Literal if successful.
  //
  // Precondition: The indices of arg_literals correspond to the parameter
  // numbers of the HLO parameters in the computation. See the comment below
  // for an example.
  //
  // (The dummy template arg lowers the overload-resolution priority of the
  // second overload so that Evaluate(module, {}) resolves unambiguously.)
  StatusOr<Literal> Evaluate(const HloModule& module,
                             absl::Span<const Literal* const> arg_literals) {
    return Evaluate(*module.entry_computation(), arg_literals);
  }
  template <typename Dummy = void>
  StatusOr<Literal> Evaluate(const HloModule& module,
                             absl::Span<const Literal> arg_literals) {
    return Evaluate(*module.entry_computation(), arg_literals);
  }

  // Evaluates an HLO computation, given an array of pointers to argument
  // literals. Returns the evaluated result as a Literal if successful.
  //
  // Precondition: The indices of arg_literals correspond to the parameter
  // numbers of the HLO parameters in the computation. For example, consider
  // the following graph:
  //
  //                *
  //            /       \
  //            +     Parameter1
  //        /      \
  //       /        \
  //    Parameter0  Constant
  //
  // where Parameter0 has parameter_number 0 and Parameter1 has
  // parameter_number 1 in this computation. The first literal in the input
  // array then maps to Parameter0 and the second maps to Parameter1.
  //
  // (The dummy template arg lowers the overload-resolution priority of the
  // second overload so that Evaluate(computation, {}) resolves unambiguously.)
  StatusOr<Literal> Evaluate(const HloComputation& computation,
                             absl::Span<const Literal* const> arg_literals);
  template <typename Dummy = void>
  StatusOr<Literal> Evaluate(const HloComputation& computation,
                             absl::Span<const Literal> arg_literals) {
    std::vector<const Literal*> arg_literal_ptrs;
    for (const auto& l : arg_literals) {
      arg_literal_ptrs.push_back(&l);
    }
    return Evaluate(computation, arg_literal_ptrs);
  }
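
  // Example usage, a minimal sketch rather than a definitive recipe: assuming
  // a `computation` built elsewhere with two parameters as in the diagram
  // above, evaluation could look like:
  //
  //   HloEvaluator evaluator;
  //   Literal arg0 = LiteralUtil::CreateR0<float>(2.0f);  // Parameter0
  //   Literal arg1 = LiteralUtil::CreateR0<float>(3.0f);  // Parameter1
  //   StatusOr<Literal> result =
  //       evaluator.Evaluate(computation, {&arg0, &arg1});
  //   // On success, the result literal holds (2 + constant) * 3.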

  // Gets the value of running a single HLO instruction.
  //
  // All of the operands to this instruction must be constants.
  StatusOr<Literal> Evaluate(HloInstruction* instruction);

  // Same as Evaluate, except that it returns false on error and accepts an
  // output pointer.
  bool TryEvaluate(HloInstruction* instruction, Literal* result);
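
  // Example, a minimal sketch where `add` is a hypothetical HloInstruction*
  // whose operands are all constants:
  //
  //   HloEvaluator evaluator;
  //   Literal result;
  //   if (evaluator.TryEvaluate(add, &result)) {
  //     // `result` now holds the constant-folded value of `add`.
  //   }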

  // Evaluates a single HLO instruction, substituting the given literals for
  // some of the instruction's operands.
  //
  // For example, given instruction = op(A, B, C) and the map
  // {A = x, C = y}, this evaluates op(x, B, y).
  StatusOr<Literal> EvaluateWithSubstitutions(
      const HloInstruction* instruction,
      const std::unordered_map<const HloInstruction*, const Literal*>&
          substitutions);
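
  // Example, a minimal sketch where `mul`, `a`, and `c` are hypothetical
  // instructions, with `a` and `c` among the operands of `mul`:
  //
  //   Literal x = LiteralUtil::CreateR0<float>(4.0f);
  //   Literal y = LiteralUtil::CreateR0<float>(5.0f);
  //   HloEvaluator evaluator;
  //   StatusOr<Literal> result =
  //       evaluator.EvaluateWithSubstitutions(mul, {{a, &x}, {c, &y}});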

  // Evaluates a single elementwise binary operation on the given literals.
  StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode,
                                                const Literal& lhs,
                                                const Literal& rhs);

  // Evaluates a single elementwise unary operation on the given literal.
  StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
                                               const Literal& operand);

  // Evaluates a dot operation on `lhs` and `rhs` with the given dimension
  // numbers and precision configuration.
  StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
                                  const PrecisionConfig& precision_config,
                                  const Literal& lhs, const Literal& rhs);
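
  // Example, a minimal sketch: adding two rank-1 literals directly, without
  // building an HLO graph:
  //
  //   HloEvaluator evaluator;
  //   Literal lhs = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  //   Literal rhs = LiteralUtil::CreateR1<float>({3.0f, 4.0f});
  //   StatusOr<Literal> sum =
  //       evaluator.EvaluateElementwiseBinaryOp(HloOpcode::kAdd, lhs, rhs);
  //   // On success, `sum` holds {4.0f, 6.0f}.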

  void set_dynamic_dimension_inference(
      DynamicDimensionInference* dynamic_dimension_inference) {
    dynamic_dimension_inference_ = dynamic_dimension_inference;
  }

  // Enables the fast path for certain operations like dot or convolution.
  void set_use_fast_path(bool value) { use_fast_path_ = value; }

  // Handles evaluation of a custom-call op.
  // Operand literals are provided in |operands| and implementations must
  // return the resulting output literal (or an error status).
  using CustomCallHandler = std::function<StatusOr<Literal>(
      HloInstruction* custom_call, absl::Span<const Literal*> operands)>;

  // Sets a handler that is called during evaluation for custom-call ops.
  // If no handler is defined, evaluating a custom-call fails with the default
  // error behavior. The handler is provided evaluated literals for all
  // operands and is expected to return an output literal of the appropriate
  // shape.
  void set_custom_call_handler(CustomCallHandler handler) {
    custom_call_handler_ = std::move(handler);
  }
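
  // Example, a minimal sketch: a handler for a hypothetical "Identity"
  // custom-call target that simply copies its first operand:
  //
  //   HloEvaluator evaluator;
  //   evaluator.set_custom_call_handler(
  //       [](HloInstruction* custom_call,
  //          absl::Span<const Literal*> operands) -> StatusOr<Literal> {
  //         if (custom_call->custom_call_target() == "Identity") {
  //           return operands[0]->Clone();
  //         }
  //         return Unimplemented("Unhandled custom-call target.");
  //       });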

  // Returns the result of a matrix multiply `lhs x rhs`.
  static std::unique_ptr<Array2D<Eigen::half>> MatmulArray2D(
      const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs);
  static std::unique_ptr<Array2D<float>> MatmulArray2D(
      const Array2D<float>& lhs, const Array2D<float>& rhs);
  static std::unique_ptr<Array2D<double>> MatmulArray2D(
      const Array2D<double>& lhs, const Array2D<double>& rhs);
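
  // Example, a minimal sketch: multiplying a 2x3 matrix by a 3x2 matrix to
  // produce a 2x2 result:
  //
  //   Array2D<float> lhs({{1, 2, 3}, {4, 5, 6}});
  //   Array2D<float> rhs({{7, 8}, {9, 10}, {11, 12}});
  //   std::unique_ptr<Array2D<float>> product =
  //       HloEvaluator::MatmulArray2D(lhs, rhs);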

 protected:
  // Make HloEvaluatorTypedVisitor a friend because it is logically part of
  // this class.
  //
  // A straightforward implementation would be to make it a nested class
  // declared and defined in hlo_evaluator.cc. Instead, HloEvaluatorTypedVisitor
  // lives as a separate class with its own header because its template gets
  // instantiated many times and we want to use extern templates to shard out
  // the compilation of those instantiations across multiple .cc files.
  template <typename ReturnT, typename ElementwiseT>
  friend class HloEvaluatorTypedVisitor;

  // Wraps around instruction handling to infer types before dispatching to
  // the corresponding typed visitor.
  Status DefaultAction(HloInstruction* hlo) override {
    return hlo->Visit(typed_visitors_[hlo->shape().element_type()].get());
  }

  Status Preprocess(HloInstruction* hlo) override;

  Status Postprocess(HloInstruction* hlo) override;

  // Operations that are type-agnostic or always return a specific type, such
  // as HandleIsFinite, which always returns a boolean.
  Status HandleBitcast(HloInstruction* bitcast) override;

  Status HandleGetDimensionSize(HloInstruction* get_dimension_size) override;

  Status HandleParameter(HloInstruction* parameter) override;

  Status HandleConstant(HloInstruction* constant) override;

  Status HandleConcatenate(HloInstruction* concatenate) override;

  Status HandleReshape(HloInstruction* reshape) override;

  Status HandleTranspose(HloInstruction* transpose) override;

  Status HandleIsFinite(HloInstruction* is_finite) override;

  Status HandleCompare(HloInstruction* compare) override;

  Status HandleTuple(HloInstruction* tuple) override;

  Status HandleGather(HloInstruction* gather) override;

  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;

  Status HandleCopy(HloInstruction* copy) override;

  Status HandleConditional(HloInstruction* conditional) override;

  Status HandleCall(HloInstruction* call) override;

  Status HandleFusion(HloInstruction* fusion) override;

  Status HandleWhile(HloInstruction* while_hlo) override;

  Status HandleSelect(HloInstruction* select) override;

  Status HandleTupleSelect(HloInstruction* tuple_select) override;

  Status HandleBroadcast(HloInstruction* broadcast) override;

  Status HandleAfterAll(HloInstruction* after_all) override;

  Status HandleAddDependency(HloInstruction* add_dependency) override;

  Status HandleSort(HloInstruction* sort) override;

  Status HandleReal(HloInstruction* real) override;

  Status HandleImag(HloInstruction* imag) override;

  Status HandleComplex(HloInstruction* complex) override;

  Status HandleReduce(HloInstruction* reduce) override;

  Status HandleCustomCall(HloInstruction* custom_call) override;

  // Unsupported HLOs. Note that some of them (such as BatchNorm*) are
  // typically rewritten in a semantics-preserving way into other HLOs by an
  // expansion pass in the HLO optimization pipeline during compilation, after
  // which they can be handled by the evaluator.
  Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override {
    return Unimplemented("BatchNormGrad HLO is unsupported by the evaluator.");
  }
  Status HandleBatchNormInference(
      HloInstruction* batch_norm_inference) override {
    return Unimplemented(
        "BatchNormInference HLO is unsupported by the evaluator.");
  }
  Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override {
    return Unimplemented(
        "BatchNormTraining HLO is unsupported by the evaluator.");
  }
  Status HandleInfeed(HloInstruction* infeed) override {
    return Unimplemented("Infeed HLO is unsupported by the evaluator.");
  }
  Status HandleOutfeed(HloInstruction* outfeed) override {
    return Unimplemented("Outfeed HLO is unsupported by the evaluator.");
  }

  // Returns the already-evaluated literal result for the instruction.
  //
  // A Constant instruction is considered evaluated and its literal is
  // returned directly, without a cache lookup.
  //
  // Similarly, a Parameter instruction is considered evaluated and its
  // literal is looked up in arg_literals.
  //
  // Crashes with a log message if the given instruction has not been
  // evaluated previously.
  const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) {
    if (hlo->IsConstant()) {
      return hlo->literal();
    }
    if (hlo->opcode() == HloOpcode::kParameter) {
      return *arg_literals_.at(hlo->parameter_number());
    }
    auto it = evaluated_.find(hlo);
    CHECK(it != evaluated_.end())
        << "could not find evaluated value for: " << hlo->ToString();
    return it->second;
  }

  // Tracks each HLO instruction and its evaluated literal result.
  //
  // Parameters and constants aren't stored here; see the implementation of
  // GetEvaluatedLiteralFor.
  //
  // TODO(b/35950897): have better memory management here to free instructions
  // that are no longer a parent for any other subsequent instruction in
  // post-ordering.
  //
  // Must be cleared for each evaluation.
  //
  // Storing Literals in place requires the container to provide pointer
  // stability, so we cannot use flat_hash_map here.
  absl::node_hash_map<const HloInstruction*, Literal> evaluated_;

  // Whether to use the fast path (based on Eigen) in the evaluator.
  bool use_fast_path_ = false;

 private:
  // Helper for elementwise unary ops: applies `unary_op` to each element of
  // `operand_literal` and returns the results as a Literal with the
  // instruction's shape.
  template <typename ReturnT, typename NativeT>
  static StatusOr<Literal> ElementWiseUnaryOpImpl(
      HloInstruction* instruction,
      const std::function<ReturnT(NativeT)>& unary_op,
      const Literal& operand_literal) {
    const auto shape = instruction->shape();
    const auto* operand = instruction->operand(0);
    TF_RET_CHECK(ShapeUtil::SameDimensions(shape, operand->shape()));

    Literal result(shape);
    TF_RETURN_IF_ERROR(
        result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
          return unary_op(operand_literal.Get<NativeT>(multi_index));
        }));
    return std::move(result);
  }

  // Map from a primitive type to its associated (templated) DfsHloVisitor.
  std::unique_ptr<DfsHloVisitor> typed_visitors_[PrimitiveType_ARRAYSIZE];

  // Caches pointers to input literals, assuming they are in post-order.
  // Literals are not owned by this class; they must outlive each invocation
  // of an Evaluate* method.
  // Must be cleared for each evaluation.
  std::vector<const Literal*> arg_literals_;

  // Maximum number of loop iterations to execute per while-loop; a negative
  // value means no maximum.
  int64 max_loop_iterations_ = 0;

  // Module-level seed handle.
  uint64 seed_ = 0;
  // RNG engine.
  std::minstd_rand0 engine_;

  // DynamicDimensionInference is used to evaluate GetDimensionSize, which
  // returns the dynamic dimension size of its operand.
  DynamicDimensionInference* dynamic_dimension_inference_ = nullptr;

  // Optional handler for custom-call ops.
  CustomCallHandler custom_call_handler_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator);
};

// Returns the result of a matrix multiply `lhs x rhs`.
std::unique_ptr<Array2D<float>> MatmulArray2D(const Array2D<float>& lhs,
                                              const Array2D<float>& rhs);
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_