1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_GRAPPLER_CLUSTERS_CLUSTER_H_ 17 #define TENSORFLOW_GRAPPLER_CLUSTERS_CLUSTER_H_ 18 19 #include <string> 20 #include <unordered_map> 21 #include <utility> 22 #include <vector> 23 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/grappler/grappler_item.h" 26 #include "tensorflow/core/lib/core/status.h" 27 #include "tensorflow/core/lib/strings/strcat.h" 28 #include "tensorflow/core/protobuf/device_properties.pb.h" 29 #include "tensorflow/core/public/session_options.h" 30 31 namespace tensorflow { 32 namespace grappler { 33 34 // A cluster represents of collection of hardware resources available to run 35 // the TensorFlow model. 36 // A process can only create a single cluster at a time. 37 class Cluster { 38 public: 39 explicit Cluster(int timeout_s); 40 virtual ~Cluster(); 41 42 // Returns a string that represent the type of cluster that was instantiated. 43 virtual string type() const = 0; 44 45 // Provision the hardware resources needed to run TensorFlow and start a 46 // TensorFlow session that can take advantage of these resources. 47 // The actual resources that are leveraged depend on the type of cluster 48 // instantiated. 49 // Returns OK iff all the requested resources could be reserved and a 50 // TensorFlow session successfully created. Returns an error otherwise. 51 // There is no graceful degradation to handle the case where only a subset 52 // of the requested resources are available. 53 virtual Status Provision() = 0; 54 55 // Attempts to shutdown the cluster. 56 // Returns OK iff there are no pending calls to the Run() method and all the 57 // resources used by the cluster could be released. Returns an error 58 // otherwise. 59 virtual Status Shutdown() { return Status::OK(); } 60 61 // Whether soft placement is allowed. If allow_soft_placement is true, 62 // an op will be placed on CPU if there's no GPU implementation for the OP 63 // or if no GPU devices are known or registered or if we need to co-locate 64 // with reftype input(s) which are from CPU. 65 void AllowSoftPlacement(bool soft_placement_state); 66 67 // Set the number of steps required to warmup TensorFlow. Must be called 68 // before Provision(). 69 void SetNumWarmupSteps(int num_steps); 70 71 // Returns the number of warmup steps. 72 int NumWarmupSteps() const; 73 74 // Disable the collection of detailed statistics. Must be called 75 // before Provision(). 76 void DisableDetailedStats(bool disable); 77 78 // Returns true iff the collection of detailed statistics is enabled. 79 bool DetailedStatsEnabled() const; 80 81 // Disable the TensorFlow optimizer. This ensures that the graph that TF 82 // executes is similar to the input graph. Must be called before Provision(). 83 void DisableOptimizer(bool disable); 84 85 // Return the list of TensorFlow devices that are available to execute a 86 // graph. This is empty until provision() is called. 87 const std::unordered_map<string, DeviceProperties>& GetDevices() const { 88 return devices_; 89 } 90 91 // Convenience method that returns the set of device names. These names are 92 // sorted alphabetically. 93 const std::vector<string> GetDeviceNames() const; 94 95 // Enables collecting the allocator stats. Call with enable=true must be made 96 // before Provision(). 97 virtual Status EnablePeakMemoryStats(bool enable) { 98 return errors::Unimplemented(strings ::StrCat( 99 "Peak Memory Stats are not supported on ", type(), " clusters")); 100 } 101 102 // Returns peak memory of all devices during the session creation and session 103 // runs. 104 virtual Status GetPeakMemoryUsage( 105 std::unordered_map<string, uint64>* device_peak_memory) const { 106 return errors::Unimplemented( 107 "GetPeakMemoryUsage is not implemented for this type of cluster."); 108 } 109 110 // Prepare the session to run the specified grappler item. This include 111 // initializing all the model variables. 112 virtual Status Initialize(const GrapplerItem& item) = 0; 113 114 // Run the specified graph_def and return the corresponding metadata. 115 virtual Status Run(const GraphDef& graph_def, 116 const std::vector<std::pair<string, Tensor>>& feed, 117 const std::vector<string>& fetch, 118 RunMetadata* metadata) = 0; 119 120 protected: 121 std::unordered_map<string, DeviceProperties> devices_; 122 const int timeout_s_; 123 SessionOptions options_; 124 RunOptions run_options_; 125 }; 126 127 } // end namespace grappler 128 } // end namespace tensorflow 129 130 #endif // TENSORFLOW_GRAPPLER_CLUSTERS_CLUSTER_H_ 131