/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The "client library" instantiates a local (in-process) XLA service for
// use by this process, and connects to it with a singleton XLA local
// client. ClientLibrary::GetOrCreateLocalClient will spawn a local service,
// and return a client that's connected to it and ready to run XLA
// computations.
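//
// A minimal usage sketch (assuming the default platform is registered and
// available):
//
//   xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
//   // ... compile and run XLA computations through `client`.
//
// LocalClientOrDie() is the "or-die" convenience wrapper around
// GetOrCreateLocalClient() declared below.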
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/compile_only_service.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace xla {

// Options to configure the local client when it is created.
class LocalClientOptions {
 public:
  LocalClientOptions(perftools::gputools::Platform* platform = nullptr,
                     int number_of_replicas = 1,
                     int intra_op_parallelism_threads = -1);

  // Set the platform backing the service, or nullptr for the default platform.
  LocalClientOptions& set_platform(perftools::gputools::Platform* platform);
  perftools::gputools::Platform* platform() const;

  // Set the number of replicas to use when compiling replicated
  // programs.
  LocalClientOptions& set_number_of_replicas(int number_of_replicas);
  int number_of_replicas() const;

  // Sets the thread pool size for parallel execution of an individual operator.
  LocalClientOptions& set_intra_op_parallelism_threads(int num_threads);
  int intra_op_parallelism_threads() const;

 private:
  perftools::gputools::Platform* platform_;
  int number_of_replicas_;
  int intra_op_parallelism_threads_;
};
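
// Example of configuring a client with non-default options (a sketch; the
// specific values are placeholders, not recommendations):
//
//   xla::LocalClientOptions options;
//   options.set_number_of_replicas(2).set_intra_op_parallelism_threads(4);
//   xla::StatusOr<xla::LocalClient*> client_or =
//       xla::ClientLibrary::GetOrCreateLocalClient(options);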

class ClientLibrary {
 public:
  // Singleton constructor-or-accessor -- returns a client for the application
  // to issue XLA commands on. Arguments:
  //
  //   platform : The platform the underlying XLA service should target. If
  //     null, the default platform is used.
  static StatusOr<LocalClient*> GetOrCreateLocalClient(
      perftools::gputools::Platform* platform = nullptr);
  static StatusOr<LocalClient*> GetOrCreateLocalClient(
      const LocalClientOptions& options);

  // Convenience "or-die" wrapper around the above which returns the existing
  // client or creates one with the default platform and allocator.
  static LocalClient* LocalClientOrDie();

  // Returns the service from the service thread. Only used in unit tests to
  // access user computations from the client.
  static LocalService* GetXlaService(perftools::gputools::Platform* platform);

  // Singleton constructor-or-accessor for compile-only clients. Arguments:
  //
  //   platform : The platform the underlying XLA service should target. If
  //     null, the default platform is used.
  static StatusOr<CompileOnlyClient*> GetOrCreateCompileOnlyClient(
      perftools::gputools::Platform* platform = nullptr);
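  //
  // A compile-only client is obtained analogously (a sketch using only the
  // accessor declared above; assumes the default platform is registered):
  //
  //   StatusOr<CompileOnlyClient*> compile_client_or =
  //       ClientLibrary::GetOrCreateCompileOnlyClient();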
  // Clears the local instance and compile-only instance caches. The client
  // pointers returned by previous GetOrCreateLocalClient() or
  // GetOrCreateCompileOnlyClient() invocations are no longer valid.
  static void DestroyLocalInstances();

 private:
  // Returns the singleton instance of ClientLibrary.
  static ClientLibrary& Singleton();

  ClientLibrary();
  ~ClientLibrary();

  struct LocalInstance {
    // Service that is wrapped by the singleton client object.
    std::unique_ptr<LocalService> service;
    // Singleton client object.
    std::unique_ptr<LocalClient> client;
  };

  struct CompileOnlyInstance {
    // Service that is wrapped by the singleton client object.
    std::unique_ptr<CompileOnlyService> service;
    // Singleton client object.
    std::unique_ptr<CompileOnlyClient> client;
  };

  tensorflow::mutex service_mutex_;  // Guards the singleton creation state.
  std::unordered_map<perftools::gputools::Platform::Id,
                     std::unique_ptr<LocalInstance>>
      local_instances_ GUARDED_BY(service_mutex_);

  std::unordered_map<perftools::gputools::Platform::Id,
                     std::unique_ptr<CompileOnlyInstance>>
      compile_only_instances_ GUARDED_BY(service_mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(ClientLibrary);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_