Home | History | Annotate | Download | only in client
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_
     17 #define TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_
     18 
     19 #include <memory>
     20 #include <string>
     21 #include <unordered_map>
     22 #include <vector>
     23 
     24 #include "tensorflow/cc/framework/ops.h"
     25 #include "tensorflow/cc/framework/scope.h"
     26 #include "tensorflow/core/public/session_options.h"
     27 
     28 namespace tensorflow {
     29 
     30 /// @addtogroup core
     31 /// @{
     32 
     33 /// A `ClientSession` object lets the caller drive the evaluation of the
     34 /// TensorFlow graph constructed with the C++ API.
     35 ///
     36 /// Example:
     37 ///
     38 ///     Scope root = Scope::NewRootScope();
     39 ///     auto a = Placeholder(root, DT_INT32);
     40 ///     auto c = Add(root, a, {41});
     41 ///
     42 ///     ClientSession session(root);
     43 ///     std::vector<Tensor> outputs;
     44 ///
     45 ///     Status s = session.Run({ {a, {1}} }, {c}, &outputs);
     46 ///     if (!s.ok()) { ... }
     47 class ClientSession {
     48  public:
     49   /// A data type to represent feeds to a Run call.
     50   ///
     51   /// This is a map of `Output` objects returned by op-constructors to the value
     52   /// to feed them with. See `Input::Initializer` for details on what can be
     53   /// used as feed values.
     54   typedef std::unordered_map<Output, Input::Initializer, OutputHash> FeedType;
     55 
     56   /// Create a new session to evaluate the graph contained in `scope` by
     57   /// connecting to the TensorFlow runtime specified by `target`.
     58   ClientSession(const Scope& scope, const string& target);
     59 
     60   /// Same as above, but use the empty string ("") as the target specification.
     61   ClientSession(const Scope& scope);
     62 
     63   /// Create a new session, configuring it with `session_options`.
     64   ClientSession(const Scope& scope, const SessionOptions& session_options);
     65 
     66   ~ClientSession();
     67 
     68   /// Evaluate the tensors in `fetch_outputs`. The values are returned as
     69   /// `Tensor` objects in `outputs`. The number and order of `outputs` will
     70   /// match `fetch_outputs`.
     71   Status Run(const std::vector<Output>& fetch_outputs,
     72              std::vector<Tensor>* outputs) const;
     73 
     74   /// Same as above, but use the mapping in `inputs` as feeds.
     75   Status Run(const FeedType& inputs, const std::vector<Output>& fetch_outputs,
     76              std::vector<Tensor>* outputs) const;
     77 
     78   /// Same as above. Additionally runs the operations ins `run_outputs`.
     79   Status Run(const FeedType& inputs, const std::vector<Output>& fetch_outputs,
     80              const std::vector<Operation>& run_outputs,
     81              std::vector<Tensor>* outputs) const;
     82 
     83   /// Use `run_options` to turn on performance profiling. `run_metadata`, if not
     84   /// null, is filled in with the profiling results.
     85   Status Run(const RunOptions& run_options, const FeedType& inputs,
     86              const std::vector<Output>& fetch_outputs,
     87              const std::vector<Operation>& run_outputs,
     88              std::vector<Tensor>* outputs, RunMetadata* run_metadata) const;
     89 
     90   // TODO(keveman): Add support for partial run.
     91 
     92  private:
     93   class Impl;
     94   std::unique_ptr<Impl> impl_;
     95   Impl* impl() { return impl_.get(); }
     96   const Impl* impl() const { return impl_.get(); }
     97 };
     98 
     99 /// @}
    100 
    101 }  // end namespace tensorflow
    102 
    103 #endif  // TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_
    104