/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/core/platform/types.h"

// Forward-declare, rather than include, to reduce code size for users that
// never use this functionality.
namespace xla {
class ProgramShape;
class HloProfilePrinterData;
}

namespace tensorflow {

// Represents a function compiled by XLA, produced via either JIT or AOT.
//
// The Run method invokes the actual computation, with inputs read from arg
// buffers, and outputs written to result buffers. Each Run call may also use a
// set of temporary buffers for the computation.
//
// By default each instance of this class manages its own arg, result and temp
// buffers. The AllocMode constructor parameter may be used to modify the buffer
// allocation strategy.
//
// Under the default allocation strategy, this class is thread-compatible:
// o Calls to non-const methods require exclusive access to the object.
46 // o Concurrent calls to const methods are OK, if those calls are made while it 47 // is guaranteed that no thread may call a non-const method. 48 class XlaCompiledCpuFunction { 49 public: 50 // Type of the raw function, produced by either JIT or AOT. 51 using RawFunction = void (*)(void* result, 52 const xla::ExecutableRunOptions* run_options, 53 const void** args, void** temps, 54 int64* profile_counters); 55 56 // StaticData represents the state necessary to run an XLA-compiled 57 // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for 58 // AOT this is backed by data compiled into the object file. 59 struct StaticData { 60 // The raw function to call. 61 RawFunction raw_function; 62 63 // Cardinality and sizes of arg and temp buffers. 64 const intptr_t* arg_sizes = nullptr; 65 size_t num_args = 0; 66 const intptr_t* temp_sizes = nullptr; 67 size_t num_temps = 0; 68 69 // The 0-based index of the result tuple, in the temp buffers. 70 size_t result_index = 0; 71 72 // [Optional] Arrays of arg and result names. These are arrays of C-style 73 // strings, where the array is terminated by nullptr. 74 const char** arg_names = nullptr; 75 const char** result_names = nullptr; 76 77 // [Optional] Arg and result shapes. 78 const xla::ProgramShape* program_shape = nullptr; 79 80 // [Optional] Profile printer data. Null if profiling is disabled. 81 const xla::HloProfilePrinterData* hlo_profile_printer_data = nullptr; 82 83 // [Optional] The number of profile counters expected in the profile counter 84 // buffer by the generated code and hlo_profile_printer. 0 if profiling is 85 // disabled. This information is already present in 86 // hlo_profile_printer_data but xla::HloProfilePrinterData is forward 87 // declared so we don't have access to that information here. 88 int64 profile_counters_size = 0; 89 }; 90 91 // AllocMode controls the buffer allocation mode. 92 enum class AllocMode { 93 // Allocate all buffers - args, results, profile and temps. 
94 ARGS_RESULTS_PROFILES_AND_TEMPS, 95 96 // Only allocate result, profile and temp buffers. 97 // Use set_arg_data to set argument buffers before Run is called. 98 RESULTS_PROFILES_AND_TEMPS_ONLY, 99 }; 100 101 XlaCompiledCpuFunction( 102 const StaticData& static_data, 103 AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS); 104 virtual ~XlaCompiledCpuFunction(); 105 106 XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete; 107 XlaCompiledCpuFunction& operator=(const XlaCompiledCpuFunction&) = delete; 108 109 // Sets the intra-op thread pool used to run individual ops concurrently. 110 void set_thread_pool(const Eigen::ThreadPoolDevice* pool) { 111 run_options_.set_intra_op_thread_pool(pool); 112 } 113 114 // Runs the computation, with inputs read from arg buffers, and outputs 115 // written to result buffers. Returns true on success and false on failure. 116 bool Run() { 117 raw_function_(temps_[result_index_], &run_options_, 118 const_cast<const void**>(args_), temps_, profile_counters_); 119 return true; 120 } 121 122 // Returns the error message from the previous failed Run call. 123 // 124 // TODO(fschneider): For now this always returns an empty string because there 125 // is no support for error reporting in XLA. Remove this once all callers are 126 // updated. 127 string error_msg() const { return {}; } 128 129 // ------------------------------ 130 // Arg methods for managing input buffers. Buffers are in row-major order. 131 132 // Returns the underlying array of argument buffers, where args()[I] is the 133 // buffer for the positional argument at index I. 134 void** args() { return args_; } 135 const void* const* args() const { return args_; } 136 137 // Returns the buffer for the positional argument at the given `index`. 
138 void* arg_data(size_t index) { return args_[index]; } 139 const void* arg_data(size_t index) const { return args_[index]; } 140 141 // Sets the buffer for the positional argument at the given `index` to `data`. 142 // Must be called before Run to have an effect. May be called under any 143 // AllocMode; if the AllocMode is RESULTS_AND_TEMPS_ONLY, this method must be 144 // called for each positional argument, in order to set the argument buffers. 145 // 146 // Allocated memory must be aligned to the size specified by 147 // tensorflow::tfcompile::runtime::kAlign. If possible, use the functions in 148 // tensorflow/compiler/aot/runtime.h to ensure correct alignment. 149 // 150 // Aliasing of argument and result buffers is not allowed, and results in 151 // undefined behavior. 152 void set_arg_data(size_t index, void* data) { args_[index] = data; } 153 154 // ------------------------------ 155 // Result methods for managing output buffers. Buffers are in row-major order. 156 // Must only be called after a successful Run call. Unlike the arg methods, 157 // there is no set_resultN_data method. The result buffers are managed 158 // internally, and may change after each call to Run. 159 160 // Returns the underlying array of result buffers, where results()[I] is the 161 // buffer for the positional result at index I. 162 void** results() { return static_cast<void**>(temps_[result_index_]); } 163 const void* const* results() const { 164 return static_cast<const void* const*>(temps_[result_index_]); 165 } 166 167 // Profile counters for this XLA computation. 168 // 169 // When Hlo profiling is enabled (`hlo_profiling_enabled()` return true in 170 // this case) these counters are non-null and are automatically populated by 171 // `Run`. The counters can then be pretty-printed using 172 // `hlo_profile_printer()`. 173 // 174 // When Hlo profiling is disabled, this accessor returns null. 
175 const int64* profile_counters() const { return profile_counters_; } 176 177 // Returns the buffer for the positional result at the given `index`. 178 void* result_data(size_t index) { return results()[index]; } 179 const void* result_data(size_t index) const { return results()[index]; } 180 181 // ------------------------------ 182 // Methods for extracting optional metadata. 183 184 // Returns true iff data is available for the Lookup{Arg,Result}Index methods. 185 // E.g. the data might not be compiled into the binary for AOT. 186 bool HasNameIndices() const { 187 return arg_names_ != nullptr && result_names_ != nullptr; 188 } 189 190 // Returns the 0-based index for the argument with the given `name`. 191 // Returns -1 if the name wasn't found, or data isn't available. 192 // 193 // The index remains constant for every instance of XlaCompiledCpuFunction 194 // generated from the same static data, and might not be cheap to determine. 195 // Recommended usage is to capture this in a variable for re-use. 196 int LookupArgIndex(const string& name) const; 197 198 // Returns the 0-based index for the result with the given `name`. 199 // Returns -1 if the name wasn't found, or data isn't available. 200 // 201 // The index remains constant for every instance of XlaCompiledCpuFunction 202 // generated from the same static data, and might not be cheap to determine. 203 // Recommended usage is to capture this in a variable for re-use. 204 int LookupResultIndex(const string& name) const; 205 206 // Returns the shape of the args and results. May return nullptr if the 207 // program shape isn't available. 
208 const xla::ProgramShape* ProgramShape() const { return program_shape_; } 209 210 bool hlo_profiling_enabled() const { 211 return hlo_profile_printer_data_ != nullptr; 212 } 213 const xla::HloProfilePrinterData& hlo_profile_printer_data() const { 214 assert(hlo_profiling_enabled()); 215 return *hlo_profile_printer_data_; 216 } 217 218 private: 219 const RawFunction raw_function_; 220 const size_t result_index_; 221 222 // Arrays of argument and temp buffers; entries in args_ may be overwritten by 223 // the user. 224 void** args_ = nullptr; 225 void** temps_ = nullptr; 226 227 // Backing memory for individual arg and temp buffers. 228 void* alloc_args_ = nullptr; 229 void* alloc_temps_ = nullptr; 230 231 // Backing memory for profiling counters. 232 int64* profile_counters_ = nullptr; 233 234 // Options and context passed to the compiled function. 235 xla::ExecutableRunOptions run_options_; 236 237 // Optional metadata. 238 const char** arg_names_ = nullptr; 239 const char** result_names_ = nullptr; 240 const xla::ProgramShape* program_shape_ = nullptr; 241 const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; 242 }; 243 244 } // namespace tensorflow 245 246 #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ 247