Home | History | Annotate | Download | only in tests
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #define EIGEN_USE_THREADS
     16 
     17 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
     18 
     19 #include <vector>
     20 
     21 #include "absl/memory/memory.h"
     22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     23 #include "tensorflow/compiler/xla/client/local_client.h"
     24 #include "tensorflow/compiler/xla/client/xla_computation.h"
     25 #include "tensorflow/compiler/xla/map_util.h"
     26 #include "tensorflow/compiler/xla/shape_util.h"
     27 #include "tensorflow/compiler/xla/status_macros.h"
     28 #include "tensorflow/compiler/xla/test_helpers.h"
     29 #include "tensorflow/core/common_runtime/eigen_thread_pool.h"
     30 #include "tensorflow/core/lib/core/threadpool.h"
     31 #include "tensorflow/core/platform/byte_order.h"
     32 #include "tensorflow/core/platform/env.h"
     33 #include "tensorflow/core/platform/logging.h"
     34 
     35 namespace xla {
     36 
     37 /* static */ TestAllocator* LocalClientTestBase::allocator_;
     38 
     39 StatusOr<OwningDeviceMemory> TestAllocator::Allocate(int device_ordinal,
     40                                                      uint64 size,
     41                                                      bool retry_on_failure) {
     42   VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")";
     43   {
     44     tensorflow::mutex_lock lock(count_mutex_);
     45     allocation_count_++;
     46     device_allocation_count_[device_ordinal]++;
     47   }
     48   return StreamExecutorMemoryAllocator::Allocate(device_ordinal, size,
     49                                                  retry_on_failure);
     50 }
     51 
     52 Status TestAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
     53   VLOG(2) << "Deallocate(" << device_ordinal << ")";
     54   {
     55     tensorflow::mutex_lock lock(count_mutex_);
     56     deallocation_count_++;
     57     device_deallocation_count_[device_ordinal]++;
     58   }
     59   return StreamExecutorMemoryAllocator::Deallocate(device_ordinal, mem);
     60 }
     61 
     62 int64 TestAllocator::allocation_count() const {
     63   tensorflow::mutex_lock lock(count_mutex_);
     64   return allocation_count_;
     65 }
     66 
     67 int64 TestAllocator::allocation_count(int device_ordinal) const {
     68   tensorflow::mutex_lock lock(count_mutex_);
     69   auto it = device_allocation_count_.find(device_ordinal);
     70   if (it == device_allocation_count_.end()) {
     71     return 0;
     72   } else {
     73     return it->second;
     74   }
     75 }
     76 
     77 int64 TestAllocator::deallocation_count() const {
     78   tensorflow::mutex_lock lock(count_mutex_);
     79   return deallocation_count_;
     80 }
     81 
     82 int64 TestAllocator::deallocation_count(int device_ordinal) const {
     83   tensorflow::mutex_lock lock(count_mutex_);
     84   auto it = device_deallocation_count_.find(device_ordinal);
     85   if (it == device_deallocation_count_.end()) {
     86     return 0;
     87   } else {
     88     return it->second;
     89   }
     90 }
     91 
     92 /* static */ TestAllocator* LocalClientTestBase::GetOrCreateAllocator(
     93     se::Platform* platform) {
     94   static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
     95   tensorflow::mutex_lock lock(mu);
     96 
     97   if (allocator_ == nullptr) {
     98     allocator_ = new TestAllocator(
     99         platform == nullptr ? PlatformUtil::GetDefaultPlatform().ValueOrDie()
    100                             : platform);
    101   }
    102   return allocator_;
    103 }
    104 
    105 // Define this in .cc file to avoid having to include eigen or forward declare
    106 // these types in the header.
    107 struct LocalClientTestBase::EigenThreadPoolWrapper {
    108   explicit EigenThreadPoolWrapper()
    109       : pool(new tensorflow::thread::ThreadPool(
    110             tensorflow::Env::Default(), "XLAEigenTest", /*num_threads=*/2)),
    111         wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())),
    112         device(new Eigen::ThreadPoolDevice(wrapper.get(),
    113                                            wrapper->NumThreads())) {}
    114 
    115   std::unique_ptr<tensorflow::thread::ThreadPool> pool;
    116   std::unique_ptr<tensorflow::EigenThreadPoolWrapper> wrapper;
    117   std::unique_ptr<Eigen::ThreadPoolDevice> device;
    118 };
    119 
    120 LocalClientTestBase::LocalClientTestBase(se::Platform* platform)
    121     : local_client_(
    122           ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()),
    123       thread_pool_wrapper_(new EigenThreadPoolWrapper()) {
    124   stream_executor_ = PlatformUtil::GetStreamExecutors(local_client_->platform())
    125                          .ValueOrDie()[local_client_->default_device_ordinal()];
    126   transfer_manager_ =
    127       TransferManager::GetForPlatform(local_client_->platform()).ValueOrDie();
    128 }
    129 
    130 LocalClientTestBase::~LocalClientTestBase() {}
    131 
    132 ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer(
    133     const Literal& literal) {
    134   return local_client_
    135       ->LiteralToShapedBuffer(literal, local_client_->default_device_ordinal())
    136       .ConsumeValueOrDie();
    137 }
    138 
    139 Literal LocalClientTestBase::ShapedBufferToLiteral(
    140     const ShapedBuffer& shaped_buffer) {
    141   return local_client_->ShapedBufferToLiteral(shaped_buffer)
    142       .ConsumeValueOrDie();
    143 }
    144 
    145 ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions()
    146     const {
    147   return ExecutableBuildOptions();
    148 }
    149 
    150 ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const {
    151   ExecutableRunOptions run_options;
    152   run_options.set_intra_op_thread_pool(thread_pool_wrapper_->device.get());
    153   run_options.set_allocator(GetOrCreateAllocator(local_client_->platform()));
    154   return run_options;
    155 }
    156 
    157 ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
    158     const XlaComputation& computation,
    159     absl::Span<const ShapedBuffer* const> arguments) {
    160   return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
    161                         DefaultExecutableRunOptions())
    162       .ConsumeValueOrDie();
    163 }
    164 
    165 ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
    166     const XlaComputation& computation,
    167     absl::Span<const ShapedBuffer* const> arguments,
    168     const ExecutableBuildOptions& build_options,
    169     const ExecutableRunOptions& run_options) {
    170   return ExecuteLocally(computation, arguments, build_options, run_options)
    171       .ConsumeValueOrDie();
    172 }
    173 
    174 StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
    175     const XlaComputation& computation,
    176     absl::Span<const ShapedBuffer* const> arguments) {
    177   return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
    178                         DefaultExecutableRunOptions());
    179 }
    180 
    181 StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
    182     const XlaComputation& computation,
    183     absl::Span<const ShapedBuffer* const> arguments,
    184     const ExecutableBuildOptions& build_options,
    185     const ExecutableRunOptions& run_options) {
    186   std::vector<const Shape*> argument_layouts(arguments.size());
    187   for (int i = 0; i < arguments.size(); ++i) {
    188     argument_layouts[i] = &arguments[i]->on_host_shape();
    189   }
    190   TF_ASSIGN_OR_RETURN(
    191       std::unique_ptr<LocalExecutable> executable,
    192       local_client_->Compile(computation, argument_layouts, build_options));
    193   TF_ASSIGN_OR_RETURN(auto ret, executable->Run(arguments, run_options));
    194 
    195   auto device_ordinal =
    196       build_options.device_ordinal() == -1 ? 0 : build_options.device_ordinal();
    197   auto* stream = run_options.stream();
    198   if (!stream) {
    199     stream = local_client_->mutable_backend()
    200                  ->BorrowStream(device_ordinal)
    201                  .ValueOrDie()
    202                  .get();
    203   }
    204   TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
    205   return std::move(ret);
    206 }
    207 
    208 }  // namespace xla
    209