Home | History | Annotate | Download | only in clusters
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_GRAPPLER_CLUSTERS_CLUSTER_H_
     17 #define TENSORFLOW_GRAPPLER_CLUSTERS_CLUSTER_H_
     18 
     19 #include <string>
     20 #include <unordered_map>
     21 #include <utility>
     22 #include <vector>
     23 
     24 #include "tensorflow/core/framework/tensor.h"
     25 #include "tensorflow/core/grappler/grappler_item.h"
     26 #include "tensorflow/core/lib/core/status.h"
     27 #include "tensorflow/core/lib/strings/strcat.h"
     28 #include "tensorflow/core/protobuf/device_properties.pb.h"
     29 #include "tensorflow/core/public/session_options.h"
     30 
     31 namespace tensorflow {
     32 namespace grappler {
     33 
     34 // A cluster represents of collection of hardware resources available to run
     35 // the TensorFlow model.
     36 // A process can only create a single cluster at a time.
     37 class Cluster {
     38  public:
     39   explicit Cluster(int timeout_s);
     40   virtual ~Cluster();
     41 
     42   // Returns a string that represent the type of cluster that was instantiated.
     43   virtual string type() const = 0;
     44 
     45   // Provision the hardware resources needed to run TensorFlow and start a
     46   // TensorFlow session that can take advantage of these resources.
     47   // The actual resources that are leveraged depend on the type of cluster
     48   // instantiated.
     49   // Returns OK iff all the requested resources could be reserved and a
     50   // TensorFlow session successfully created. Returns an error otherwise.
     51   // There is no graceful degradation to handle the case where only a subset
     52   // of the requested resources are available.
     53   virtual Status Provision() = 0;
     54 
     55   // Attempts to shutdown the cluster.
     56   // Returns OK iff there are no pending calls to the Run() method and all the
     57   // resources used by the cluster could be released. Returns an error
     58   // otherwise.
     59   virtual Status Shutdown() { return Status::OK(); }
     60 
     61   // Whether soft placement is allowed. If allow_soft_placement is true,
     62   // an op will be placed on CPU if there's no GPU implementation for the OP
     63   // or if no GPU devices are known or registered or if we need to co-locate
     64   // with reftype input(s) which are from CPU.
     65   void AllowSoftPlacement(bool soft_placement_state);
     66 
     67   // Set the number of steps required to warmup TensorFlow. Must be called
     68   // before Provision().
     69   void SetNumWarmupSteps(int num_steps);
     70 
     71   // Returns the number of warmup steps.
     72   int NumWarmupSteps() const;
     73 
     74   // Disable the collection of detailed statistics. Must be called
     75   // before Provision().
     76   void DisableDetailedStats(bool disable);
     77 
     78   // Returns true iff the collection of detailed statistics is enabled.
     79   bool DetailedStatsEnabled() const;
     80 
     81   // Disable the TensorFlow optimizer. This ensures that the graph that TF
     82   // executes is similar to the input graph. Must be called before Provision().
     83   void DisableOptimizer(bool disable);
     84 
     85   // Return the list of TensorFlow devices that are available to execute a
     86   // graph. This is empty until provision() is called.
     87   const std::unordered_map<string, DeviceProperties>& GetDevices() const {
     88     return devices_;
     89   }
     90 
     91   // Convenience method that returns the set of device names. These names are
     92   // sorted alphabetically.
     93   const std::vector<string> GetDeviceNames() const;
     94 
     95   // Enables collecting the allocator stats. Call with enable=true must be made
     96   // before Provision().
     97   virtual Status EnablePeakMemoryStats(bool enable) {
     98     return errors::Unimplemented(strings ::StrCat(
     99         "Peak Memory Stats are not supported on ", type(), " clusters"));
    100   }
    101 
    102   // Returns peak memory of all devices during the session creation and session
    103   // runs.
    104   virtual Status GetPeakMemoryUsage(
    105       std::unordered_map<string, uint64>* device_peak_memory) const {
    106     return errors::Unimplemented(
    107         "GetPeakMemoryUsage is not implemented for this type of cluster.");
    108   }
    109 
    110   // Prepare the session to run the specified grappler item. This include
    111   // initializing all the model variables.
    112   virtual Status Initialize(const GrapplerItem& item) = 0;
    113 
    114   // Run the specified graph_def and return the corresponding metadata.
    115   virtual Status Run(const GraphDef& graph_def,
    116                      const std::vector<std::pair<string, Tensor>>& feed,
    117                      const std::vector<string>& fetch,
    118                      RunMetadata* metadata) = 0;
    119 
    120  protected:
    121   std::unordered_map<string, DeviceProperties> devices_;
    122   const int timeout_s_;
    123   SessionOptions options_;
    124   RunOptions run_options_;
    125 };
    126 
    127 }  // end namespace grappler
    128 }  // end namespace tensorflow
    129 
    130 #endif  // TENSORFLOW_GRAPPLER_CLUSTERS_CLUSTER_H_
    131