1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/interpreter/executor.h" 17 18 #include <cstring> 19 20 #include "tensorflow/compiler/xla/status_macros.h" 21 22 namespace perftools { 23 namespace gputools { 24 namespace interpreter { 25 26 host::HostStream *AsExecutorStream(Stream *stream) { 27 DCHECK(stream != nullptr); 28 return dynamic_cast<host::HostStream *>(stream->implementation()); 29 } 30 31 InterpreterExecutor::InterpreterExecutor(const PluginConfig &plugin_config) 32 : plugin_config_(plugin_config) {} 33 34 InterpreterExecutor::~InterpreterExecutor() {} 35 36 void *InterpreterExecutor::Allocate(uint64 size) { return new char[size]; } 37 38 void *InterpreterExecutor::AllocateSubBuffer(DeviceMemoryBase *parent, 39 uint64 offset_bytes, 40 uint64 /*size_bytes*/) { 41 return parent + offset_bytes; 42 } 43 44 void InterpreterExecutor::Deallocate(DeviceMemoryBase *mem) { 45 if (!mem->is_sub_buffer()) { 46 delete[] static_cast<char *>(mem->opaque()); 47 } 48 } 49 50 bool InterpreterExecutor::Memcpy(Stream *stream, void *host_dst, 51 const DeviceMemoryBase &dev_src, uint64 size) { 52 AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() { 53 port::Status ok = SynchronousMemcpy(host_dst, dev_src, size); 54 }); 55 return true; 56 } 57 58 bool 
InterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, 59 const void *host_src, uint64 size) { 60 AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() { 61 port::Status ok = SynchronousMemcpy(dev_dst, host_src, size); 62 }); 63 return true; 64 } 65 66 port::Status InterpreterExecutor::SynchronousMemcpy(DeviceMemoryBase *dev_dst, 67 const void *host_src, 68 uint64 size) { 69 memcpy(dev_dst->opaque(), host_src, size); 70 return port::Status::OK(); 71 } 72 73 port::Status InterpreterExecutor::SynchronousMemcpy( 74 void *host_dst, const DeviceMemoryBase &dev_src, uint64 size) { 75 memcpy(host_dst, dev_src.opaque(), size); 76 return port::Status::OK(); 77 } 78 79 bool InterpreterExecutor::HostCallback(Stream *stream, 80 std::function<void()> callback) { 81 AsExecutorStream(stream)->EnqueueTask(callback); 82 return true; 83 } 84 85 bool InterpreterExecutor::CreateStreamDependency(Stream *dependent, 86 Stream *other) { 87 AsExecutorStream(dependent)->EnqueueTask( 88 [other]() { SE_CHECK_OK(other->BlockHostUntilDone()); }); 89 AsExecutorStream(dependent)->BlockUntilDone(); 90 return true; 91 } 92 93 bool InterpreterExecutor::StartTimer(Stream *stream, Timer *timer) { 94 dynamic_cast<host::HostTimer *>(timer->implementation())->Start(stream); 95 return true; 96 } 97 98 bool InterpreterExecutor::StopTimer(Stream *stream, Timer *timer) { 99 dynamic_cast<host::HostTimer *>(timer->implementation())->Stop(stream); 100 return true; 101 } 102 103 port::Status InterpreterExecutor::BlockHostUntilDone(Stream *stream) { 104 AsExecutorStream(stream)->BlockUntilDone(); 105 return port::Status::OK(); 106 } 107 108 DeviceDescription *InterpreterExecutor::PopulateDeviceDescription() const { 109 internal::DeviceDescriptionBuilder builder; 110 111 builder.set_device_address_bits(64); 112 113 builder.set_name("Interpreter"); 114 builder.set_device_memory_size(static_cast<uint64>(4) * 1024 * 1024 * 1024); 115 
  // Clock rate derived from the host's CLOCKS_PER_SEC, expressed in GHz.
  // NOTE(review): CLOCKS_PER_SEC measures clock() tick resolution, not CPU
  // frequency — this is a placeholder value; confirm consumers don't rely
  // on it being a real clock rate.
  builder.set_clock_rate_ghz(static_cast<float>(CLOCKS_PER_SEC) / 1e9);

  // Ownership of the built description transfers to the caller.
  return builder.Build().release();
}

}  // namespace interpreter
}  // namespace gputools
}  // namespace perftools