Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
     18 
#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
     31 
     32 namespace xla {
     33 
     34 // Responsible for evaluating HLO and obtain literal as the evaluation results.
     35 //
     36 // This class is not thread-safe.
     37 class HloEvaluator : public DfsHloVisitorWithDefault {
     38  public:
     39   HloEvaluator();
     40   // Evaluates an HLO module and an array of pointers to literals.
     41   // Returns the evaluated result as a literal if successful.
     42   // Precondition: The indices of arg_literals correspond to the parameter
     43   // numbers of the HLO parameters in the computation. See comment below for an
     44   // example.
     45   // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
     46   // type.
     47   template <typename LiteralPtr>
     48   StatusOr<std::unique_ptr<Literal>> Evaluate(
     49       const HloModule& module,
     50       tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals);
     51 
     52   // Evaluates an HLO computation and an array of pointers to literals.
     53   // Returns the evaluated result as a literal if successful.
     54   // Precondition: The indices of arg_literals correspond to the parameter
     55   // numbers of the HLO parameters in the computation. For e.g., consider the
     56   // following graph:
     57   //
     58   //                *
     59   //            /       \
     60   //            +     Parameter1
     61   //        /      \
     62   //       /        \
     63   //    Parameter0  Constant
     64   //
     65   // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number
     66   // 1 in this computation. The input literals array will then have its first
     67   // literal map to Parameter0 and the second map to Parameter1.
     68   // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
     69   // type.
     70   template <typename LiteralPtr>
     71   StatusOr<std::unique_ptr<Literal>> Evaluate(
     72       const HloComputation& computation,
     73       tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals);
     74 
     75   // Evaluates a single HLO instruction and an array of pointers to literals.
     76   // Return the evaluated result as literal if successful.
     77   // Precondition:
     78   // 1. argument literals correspond to the input instruction's parameters in
     79   // their post-ordering.
     80   // 2. the instruction's operands must be of either Parameter or Constant type.
     81   // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
     82   // type.
     83   template <typename LiteralPtr>
     84   StatusOr<std::unique_ptr<Literal>> Evaluate(
     85       HloInstruction* instruction,
     86       tensorflow::gtl::ArraySlice<LiteralPtr> arg_literals);
     87 
     88   // Evaluates a single HLO instruction with constant operands.
     89   // Returns the evaluated result as literal if successful.
     90   // Precondition:
     91   // 1. all operands of the input instruction are constants.
     92   // 2. the instruction is not a Parameter operation.
     93   StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction);
     94 
     95   // Same as Evaluate, except returning nullptr on error.
     96   std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction);
     97 
     98   // Evaluates a single HLO instruction, substituting the given literals for
     99   // some of the instruction's operands.
    100   //
    101   // For example, given instruction = op(A, B, C) and the map
    102   // {A = x, C = y}, this evaluates op(x, B, y).
    103   StatusOr<std::unique_ptr<Literal>> EvaluateWithSubstitutions(
    104       const HloInstruction* instruction,
    105       const std::unordered_map<const HloInstruction*, const Literal*>&
    106           substitutions);
    107 
    108  protected:
    109   // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting
    110   // literal type of each evaluated Handle* method of a TypedVisitor.
    111   // There are however a few notable exceptions to this rule, notably:
    112   // - HandleCompare and HandleIsFinite: where the resulting literal type is
    113   // always boolean.
    114   // These operations are handled outside of the parent HloEvaluator handlers
    115   // instead of from within TypedVisitor.
    116   //
    117   // Type params:
    118   //   - ReturnT: The type of input and output of each operation.
    119   //   - ElementwiseT: The type in which internal computation are done.
    120   template <typename ReturnT, typename ElementwiseT = ReturnT>
    121   class TypedVisitor;
    122 
    123   // Wraps around instruction handling to infer types before dispatching to
    124   // the corresponding typed Visitor.
    125   Status DefaultAction(HloInstruction* hlo) override {
    126     return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get());
    127   }
    128 
    129   Status Preprocess(HloInstruction* hlo) override;
    130 
    131   Status Postprocess(HloInstruction* hlo) override;
    132 
    133   // Operations that are type-agnostic or always return a specific type, such as
    134   // HandleIsFinite where boolean is always returned.
    135   //
    136   Status HandleParameter(HloInstruction* parameter) override;
    137 
    138   Status HandleConstant(HloInstruction* constant) override;
    139 
    140   Status HandleConcatenate(HloInstruction* concatenate) override;
    141 
    142   Status HandleReshape(HloInstruction* reshape) override;
    143 
    144   Status HandleTranspose(HloInstruction* transpose) override;
    145 
    146   Status HandleIsFinite(HloInstruction* is_finite) override;
    147 
    148   Status HandleCompare(HloInstruction* compare) override;
    149 
    150   Status HandleTuple(HloInstruction* tuple) override;
    151 
    152   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
    153 
    154   Status HandleCopy(HloInstruction* copy) override;
    155 
    156  private:
    157   // Returns the already-evaluated literal result for the instruction.
    158   // A Constant instruction is considered evaluated and its literal will be
    159   // returned directly without looking up the cache.
    160   // Crash with log if the given instruction has not been evaluated previously.
    161   const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) {
    162     if (hlo->IsConstant()) {
    163       return hlo->literal();
    164     }
    165     auto it = evaluated_.find(hlo);
    166     CHECK(it != evaluated_.end())
    167         << "could not find evaluated value for: " << hlo->ToString();
    168     return *(it->second);
    169   }
    170 
    171   // Map from a primitive type to its associated (templated) DfsHloVisitor.
    172   // Note: the hash function here is only needed because current gcc std::hash
    173   // does not specialize for enum types. This should however be fixed in the
    174   // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5
    175   tensorflow::gtl::FlatMap<PrimitiveType, std::unique_ptr<DfsHloVisitor>,
    176                            std::hash<int>>
    177       typed_visitors_;
    178 
    179   // Tracks the HLO instruction and its evaluated literal result.
    180   // TODO(b/35950897): have better memory management here to free instructions
    181   // that are no longer a parent for any other subsequent instruction in
    182   // post-orderring.
    183   // Must be cleared for each evaluation.
    184   tensorflow::gtl::FlatMap<const HloInstruction*, std::unique_ptr<Literal>>
    185       evaluated_;
    186 
    187   // Caches pointers to input literals, assuming they are in post-order.
    188   // Literals are not owned by this class, and they must outlive the lifetime of
    189   // each invocation to the Evaluate* method.
    190   // Must be cleared for each evaluation.
    191   std::vector<const Literal*> arg_literals_;
    192 
    193   TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator);
    194 };
    195 
    196 }  // namespace xla
    197 
    198 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_
    199