/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/interpreter/compiler.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/interpreter/executable.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace interpreter {

namespace {

// Handles custom_call ops during evaluation by routing them through the
// global CPU registry used by other CPU-based backends.
StatusOr<Literal> HandleEvaluatorCustomCall(
    HloInstruction* custom_call, absl::Span<const Literal*> operands) {
  // Find the target C function in the global registry.
  auto* registry = xla::cpu::CustomCallTargetRegistry::Global();
  void* target_fn = registry->Lookup(custom_call->custom_call_target());
  if (!target_fn) {
    return NotFound("Custom call target '%s' was not registered",
                    custom_call->custom_call_target());
  }

  // Populate pointers to operand and output literal data.
  std::vector<const void*> operand_data;
  operand_data.reserve(operands.size());
  for (const auto* literal : operands) {
    operand_data.push_back(literal->untyped_data());
  }
  auto output = Literal::CreateFromShape(custom_call->shape());
  void* output_data = output.untyped_data();

  // Call the target function, matching the C ABI used by the CPU backends:
  // the output buffer first, then an array of operand buffers.
  auto* typed_fn = reinterpret_cast<void (*)(void*, const void**)>(target_fn);
  (*typed_fn)(output_data, operand_data.data());

  return std::move(output);
}

}  // namespace

// Runs a minimal HLO pass pipeline before the module is handed to the
// evaluator-based executable.
Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
  HloPassPipeline pipeline("Interpreter");

  pipeline.AddPass<DynamicIndexSplitter>();
  pipeline.AddPass<CholeskyExpander>();
  pipeline.AddPass<TriangularSolveExpander>();
  pipeline.AddPass<LayoutAssignment>(
      hlo_module->mutable_entry_computation_layout(),
      LayoutAssignment::InstructionCanChangeLayout);

  ReducePrecisionInsertion::AddPasses(
      &pipeline, hlo_module->config().debug_options(),
      ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);

  return pipeline.Run(hlo_module).status();
}

StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
    std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* /*stream_exec*/,
    DeviceMemoryAllocator* /*device_allocator*/) {
  VLOG(1) << "Run hlo passes on graph " << hlo_module->name();
  TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));
  return std::move(hlo_module);
}

Status InterpreterCompiler::RunHloPassesOnModuleGroup(
    HloModuleGroup* module_group,
    absl::Span<se::StreamExecutor* const> executors,
    DeviceMemoryAllocator* device_allocator) {
  return Unimplemented(
      "Module group compilation is not supported on Interpreter.");
}

StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
    std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
    DeviceMemoryAllocator* /*device_allocator*/) {
  TF_RET_CHECK(stream_exec != nullptr);

  VLOG(1) << "Run backend " << hlo_module->name();

  // Typically a backend would visit the HLO graph here and build up a
  // compiled equivalent. The interpreter instead runs an HloEvaluator at
  // execution time, so there is nothing to compile.

  auto evaluator = absl::make_unique<HloEvaluator>();
  evaluator->set_use_fast_path(
      hlo_module->config().debug_options().xla_hlo_evaluator_use_fast_path());
  evaluator->set_custom_call_handler(HandleEvaluatorCustomCall);

  // Create the executable directly from the HLO module and the evaluator.
  std::unique_ptr<Executable> executable =
      absl::make_unique<InterpreterExecutable>(std::move(hlo_module),
                                               std::move(evaluator));

  return std::move(executable);
}

StatusOr<std::vector<std::unique_ptr<Executable>>>
InterpreterCompiler::RunBackendOnModuleGroup(
    std::unique_ptr<HloModuleGroup> module_group,
    std::vector<std::vector<se::StreamExecutor*>> stream_exec,
    DeviceMemoryAllocator* device_allocator) {
  return Unimplemented(
      "Module group compilation is not supported on Interpreter.");
}

// Compiles a module group containing exactly one module by running the HLO
// passes followed by the backend on that module.
StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
    std::unique_ptr<HloModuleGroup> module_group,
    std::vector<std::vector<se::StreamExecutor*>> stream_exec,
    DeviceMemoryAllocator* device_allocator) {
  if (module_group->empty()) {
    return std::vector<std::unique_ptr<Executable>>();
  }
  if (module_group->size() > 1) {
    return tensorflow::errors::Unimplemented(
        "Compilation of multiple HLO modules is not supported on "
        "Interpreter.");
  }
  if (stream_exec.size() != 1 || stream_exec[0].size() != 1) {
    return tensorflow::errors::Unimplemented(
        "Unexpected number of StreamExecutors.");
  }
  auto hlo_modules = module_group->ConsumeModules();
  TF_ASSIGN_OR_RETURN(auto module,
                      RunHloPasses(std::move(hlo_modules[0]), stream_exec[0][0],
                                   device_allocator));
  TF_ASSIGN_OR_RETURN(
      auto executable,
      RunBackend(std::move(module), stream_exec[0][0], device_allocator));
  std::vector<std::unique_ptr<Executable>> ret;
  ret.push_back(std::move(executable));
  return std::move(ret);
}

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
InterpreterCompiler::CompileAheadOfTime(
    std::unique_ptr<HloModuleGroup> module_group,
    const AotCompilationOptions& aot_options) {
  return tensorflow::errors::InvalidArgument(
      "AOT compilation not supported on Interpreter");
}

se::Platform::Id InterpreterCompiler::PlatformId() const {
  return se::interpreter::kXlaInterpreterPlatformId;
}

HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction()
    const {
  return InterpreterExecutable::ShapeSizeBytes;
}

// Registers the interpreter compiler and a default ComputationPlacer for the
// interpreter platform when this translation unit is loaded.
static bool InitModule() {
  xla::Compiler::RegisterCompilerFactory(
      se::interpreter::kXlaInterpreterPlatformId, []() {
        return absl::make_unique<xla::interpreter::InterpreterCompiler>();
      });
  xla::ComputationPlacer::RegisterComputationPlacer(
      se::interpreter::kXlaInterpreterPlatformId,
      []() { return absl::make_unique<xla::ComputationPlacer>(); });
  return true;
}

static bool module_initialized = InitModule();

}  // namespace interpreter
}  // namespace xla
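
// Usage sketch (illustrative; not part of the original file): for
// HandleEvaluatorCustomCall above to resolve a target, the target function
// must first be registered with the global CPU custom-call registry under the
// same name that the HLO custom-call instruction uses. The function name and
// element types below are hypothetical, and this assumes
// CustomCallTargetRegistry::Register takes a symbol name and a function
// address; a registration might look like:
//
//   // Target matching the C ABI the handler expects: the output buffer
//   // first, then an array of operand buffers.
//   extern "C" void MyAddOne(void* output, const void** operands) {
//     *static_cast<float*>(output) =
//         *static_cast<const float*>(operands[0]) + 1.0f;
//   }
//
//   static bool my_add_one_registered = [] {
//     xla::cpu::CustomCallTargetRegistry::Global()->Register(
//         "MyAddOne", reinterpret_cast<void*>(MyAddOne));
//     return true;
//   }();
//
// An HLO custom-call whose custom_call_target is "MyAddOne" would then be
// routed to MyAddOne by the Lookup() call in HandleEvaluatorCustomCall.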