Home | History | Annotate | Download | only in tf2xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
     17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
     18 
     19 #include <cassert>
     20 #include <string>
     21 
     22 #include "tensorflow/compiler/xla/executable_run_options.h"
     23 #include "tensorflow/core/platform/types.h"
     24 
     25 // Forward-declare, rather than include, to reduce code size for users that
     26 // never use this functionality.
     27 namespace xla {
     28 class ProgramShape;
     29 class HloProfilePrinterData;
     30 }
     31 
     32 namespace tensorflow {
     33 
     34 // Represents a function compiled by XLA, produced via either JIT or AOT.
     35 //
     36 // The Run method invokes the actual computation, with inputs read from arg
     37 // buffers, and outputs written to result buffers. Each Run call may also use a
     38 // set of temporary buffers for the computation.
     39 //
     40 // By default each instance of this class manages its own arg, result and temp
     41 // buffers. The AllocMode constructor parameter may be used to modify the buffer
     42 // allocation strategy.
     43 //
     44 // Under the default allocation strategy, this class is thread-compatible:
     45 // o Calls to non-const methods require exclusive access to the object.
     46 // o Concurrent calls to const methods are OK, if those calls are made while it
     47 //   is guaranteed that no thread may call a non-const method.
     48 class XlaCompiledCpuFunction {
     49  public:
     50   // Type of the raw function, produced by either JIT or AOT.
     51   using RawFunction = void (*)(void* result,
     52                                const xla::ExecutableRunOptions* run_options,
     53                                const void** args, void** temps,
     54                                int64* profile_counters);
     55 
     56   // StaticData represents the state necessary to run an XLA-compiled
     57   // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for
     58   // AOT this is backed by data compiled into the object file.
     59   struct StaticData {
     60     // The raw function to call.
     61     RawFunction raw_function;
     62 
     63     // Cardinality and sizes of arg and temp buffers.
     64     const intptr_t* arg_sizes = nullptr;
     65     size_t num_args = 0;
     66     const intptr_t* temp_sizes = nullptr;
     67     size_t num_temps = 0;
     68 
     69     // The 0-based index of the result tuple, in the temp buffers.
     70     size_t result_index = 0;
     71 
     72     // [Optional] Arrays of arg and result names. These are arrays of C-style
     73     // strings, where the array is terminated by nullptr.
     74     const char** arg_names = nullptr;
     75     const char** result_names = nullptr;
     76 
     77     // [Optional] Arg and result shapes.
     78     const xla::ProgramShape* program_shape = nullptr;
     79 
     80     // [Optional] Profile printer data.  Null if profiling is disabled.
     81     const xla::HloProfilePrinterData* hlo_profile_printer_data = nullptr;
     82 
     83     // [Optional] The number of profile counters expected in the profile counter
     84     // buffer by the generated code and hlo_profile_printer.  0 if profiling is
     85     // disabled.  This information is already present in
     86     // hlo_profile_printer_data but xla::HloProfilePrinterData is forward
     87     // declared so we don't have access to that information here.
     88     int64 profile_counters_size = 0;
     89   };
     90 
     91   // AllocMode controls the buffer allocation mode.
     92   enum class AllocMode {
     93     // Allocate all buffers - args, results, profile and temps.
     94     ARGS_RESULTS_PROFILES_AND_TEMPS,
     95 
     96     // Only allocate result, profile and temp buffers.
     97     // Use set_arg_data to set argument buffers before Run is called.
     98     RESULTS_PROFILES_AND_TEMPS_ONLY,
     99   };
    100 
    101   XlaCompiledCpuFunction(
    102       const StaticData& static_data,
    103       AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS);
    104   virtual ~XlaCompiledCpuFunction();
    105 
    106   XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete;
    107   XlaCompiledCpuFunction& operator=(const XlaCompiledCpuFunction&) = delete;
    108 
    109   // Sets the intra-op thread pool used to run individual ops concurrently.
    110   void set_thread_pool(const Eigen::ThreadPoolDevice* pool) {
    111     run_options_.set_intra_op_thread_pool(pool);
    112   }
    113 
    114   // Runs the computation, with inputs read from arg buffers, and outputs
    115   // written to result buffers. Returns true on success and false on failure.
    116   bool Run() {
    117     raw_function_(temps_[result_index_], &run_options_,
    118                   const_cast<const void**>(args_), temps_, profile_counters_);
    119     return true;
    120   }
    121 
    122   // Returns the error message from the previous failed Run call.
    123   //
    124   // TODO(fschneider): For now this always returns an empty string because there
    125   // is no support for error reporting in XLA. Remove this once all callers are
    126   // updated.
    127   string error_msg() const { return {}; }
    128 
    129   // ------------------------------
    130   // Arg methods for managing input buffers. Buffers are in row-major order.
    131 
    132   // Returns the underlying array of argument buffers, where args()[I] is the
    133   // buffer for the positional argument at index I.
    134   void** args() { return args_; }
    135   const void* const* args() const { return args_; }
    136 
    137   // Returns the buffer for the positional argument at the given `index`.
    138   void* arg_data(size_t index) { return args_[index]; }
    139   const void* arg_data(size_t index) const { return args_[index]; }
    140 
    141   // Sets the buffer for the positional argument at the given `index` to `data`.
    142   // Must be called before Run to have an effect. May be called under any
    143   // AllocMode; if the AllocMode is RESULTS_AND_TEMPS_ONLY, this method must be
    144   // called for each positional argument, in order to set the argument buffers.
    145   //
    146   // Allocated memory must be aligned to the size specified by
    147   // tensorflow::tfcompile::runtime::kAlign. If possible, use the functions in
    148   // tensorflow/compiler/aot/runtime.h to ensure correct alignment.
    149   //
    150   // Aliasing of argument and result buffers is not allowed, and results in
    151   // undefined behavior.
    152   void set_arg_data(size_t index, void* data) { args_[index] = data; }
    153 
    154   // ------------------------------
    155   // Result methods for managing output buffers. Buffers are in row-major order.
    156   // Must only be called after a successful Run call. Unlike the arg methods,
    157   // there is no set_resultN_data method. The result buffers are managed
    158   // internally, and may change after each call to Run.
    159 
    160   // Returns the underlying array of result buffers, where results()[I] is the
    161   // buffer for the positional result at index I.
    162   void** results() { return static_cast<void**>(temps_[result_index_]); }
    163   const void* const* results() const {
    164     return static_cast<const void* const*>(temps_[result_index_]);
    165   }
    166 
    167   // Profile counters for this XLA computation.
    168   //
    169   // When Hlo profiling is enabled (`hlo_profiling_enabled()` return true in
    170   // this case) these counters are non-null and are automatically populated by
    171   // `Run`.  The counters can then be pretty-printed using
    172   // `hlo_profile_printer()`.
    173   //
    174   // When Hlo profiling is disabled, this accessor returns null.
    175   const int64* profile_counters() const { return profile_counters_; }
    176 
    177   // Returns the buffer for the positional result at the given `index`.
    178   void* result_data(size_t index) { return results()[index]; }
    179   const void* result_data(size_t index) const { return results()[index]; }
    180 
    181   // ------------------------------
    182   // Methods for extracting optional metadata.
    183 
    184   // Returns true iff data is available for the Lookup{Arg,Result}Index methods.
    185   // E.g. the data might not be compiled into the binary for AOT.
    186   bool HasNameIndices() const {
    187     return arg_names_ != nullptr && result_names_ != nullptr;
    188   }
    189 
    190   // Returns the 0-based index for the argument with the given `name`.
    191   // Returns -1 if the name wasn't found, or data isn't available.
    192   //
    193   // The index remains constant for every instance of XlaCompiledCpuFunction
    194   // generated from the same static data, and might not be cheap to determine.
    195   // Recommended usage is to capture this in a variable for re-use.
    196   int LookupArgIndex(const string& name) const;
    197 
    198   // Returns the 0-based index for the result with the given `name`.
    199   // Returns -1 if the name wasn't found, or data isn't available.
    200   //
    201   // The index remains constant for every instance of XlaCompiledCpuFunction
    202   // generated from the same static data, and might not be cheap to determine.
    203   // Recommended usage is to capture this in a variable for re-use.
    204   int LookupResultIndex(const string& name) const;
    205 
    206   // Returns the shape of the args and results. May return nullptr if the
    207   // program shape isn't available.
    208   const xla::ProgramShape* ProgramShape() const { return program_shape_; }
    209 
    210   bool hlo_profiling_enabled() const {
    211     return hlo_profile_printer_data_ != nullptr;
    212   }
    213   const xla::HloProfilePrinterData& hlo_profile_printer_data() const {
    214     assert(hlo_profiling_enabled());
    215     return *hlo_profile_printer_data_;
    216   }
    217 
    218  private:
    219   const RawFunction raw_function_;
    220   const size_t result_index_;
    221 
    222   // Arrays of argument and temp buffers; entries in args_ may be overwritten by
    223   // the user.
    224   void** args_ = nullptr;
    225   void** temps_ = nullptr;
    226 
    227   // Backing memory for individual arg and temp buffers.
    228   void* alloc_args_ = nullptr;
    229   void* alloc_temps_ = nullptr;
    230 
    231   // Backing memory for profiling counters.
    232   int64* profile_counters_ = nullptr;
    233 
    234   // Options and context passed to the compiled function.
    235   xla::ExecutableRunOptions run_options_;
    236 
    237   // Optional metadata.
    238   const char** arg_names_ = nullptr;
    239   const char** result_names_ = nullptr;
    240   const xla::ProgramShape* program_shape_ = nullptr;
    241   const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
    242 };
    243 
    244 }  // namespace tensorflow
    245 
    246 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
    247