1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ 18 19 #include <map> 20 #include <memory> 21 #include <string> 22 23 #include "tensorflow/compiler/xla/service/executable.h" 24 #include "tensorflow/compiler/xla/service/hlo_module_config.h" 25 #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" 26 #include "tensorflow/compiler/xla/types.h" 27 #include "tensorflow/core/platform/macros.h" 28 #include "tensorflow/core/platform/mutex.h" 29 #include "tensorflow/core/platform/thread_annotations.h" 30 31 namespace xla { 32 33 // A cache which stores Executables indexed by computation handle and version. 34 class CompilationCache { 35 public: 36 CompilationCache() {} 37 38 // Insert the given Executable into the cache. Return a bare Executable 39 // pointer for the caller to use. Note: the returned pointer will *not* be the 40 // same as the given unique pointer if the computation already exists in the 41 // cache. See comments in the .cc implementation for details of this case. 42 // 43 // module_config is provided by the caller, instead of being taken from the 44 // executable, so that we can insert keys into the compilation cache that are 45 // devoid of layout (where XLA gets to choose what layout to compile). 46 // 47 // A shared_ptr is returned so the caller can keep the Executable from being 48 // destructed in the event that the Executable is evicted from the 49 // computation cache (and the cache's shared_ptr to the Executable is 50 // destructed). 51 std::shared_ptr<Executable> Insert(std::unique_ptr<Executable> executable, 52 const HloModuleConfig& module_config); 53 54 // Lookup the Executable for the specified versioned computation in the cache. 55 // Return a shared_ptr to the Executable if it exists in the cache. Return 56 // nullptr otherwise. 57 std::shared_ptr<Executable> LookUp( 58 const VersionedComputationHandle& versioned_handle, 59 const HloModuleConfig& module_config) const; 60 61 protected: 62 mutable tensorflow::mutex mutex_; 63 64 // Map from versioned handle with program layout to Executable built 65 // for that computation version and program layout. 66 using CacheKey = string; 67 68 CacheKey BuildKey(const VersionedComputationHandle& versioned_handle, 69 const HloModuleConfig& module_config) const; 70 std::map<CacheKey, std::shared_ptr<Executable>> cache_ GUARDED_BY(mutex_); 71 72 private: 73 TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache); 74 }; 75 76 } // namespace xla 77 78 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ 79