Home | History | Annotate | Download | only in c
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_C_C_API_INTERNAL_H_
     17 #define TENSORFLOW_C_C_API_INTERNAL_H_
     18 
     19 #include "tensorflow/c/c_api.h"
     20 
     21 #include <list>
     22 #include <set>
     23 #include <string>
     24 #include <unordered_map>
     25 #include <vector>
     26 
     27 #ifndef __ANDROID__
     28 #include "tensorflow/core/framework/op_gen_lib.h"
     29 #endif
     30 #include "tensorflow/core/common_runtime/shape_refiner.h"
     31 #include "tensorflow/core/framework/tensor.h"
     32 #include "tensorflow/core/framework/tensor_shape.h"
     33 #include "tensorflow/core/graph/graph.h"
     34 #include "tensorflow/core/graph/graph_constructor.h"
     35 #include "tensorflow/core/graph/node_builder.h"
     36 #include "tensorflow/core/lib/core/status.h"
     37 #include "tensorflow/core/platform/mutex.h"
     38 #include "tensorflow/core/platform/types.h"
     39 #include "tensorflow/core/public/session.h"
     40 
// Forward declarations for device types that TF_Session (below) refers to
// only through pointers, so their full definitions are not needed here.
namespace tensorflow {
class Device;
class DeviceMgr;
}  // namespace tensorflow
     45 
     46 // Internal structures used by the C API. These are likely to change and should
     47 // not be depended on.
     48 
// Concrete definition of the C API's TF_Status: a thin wrapper holding a
// tensorflow::Status value.
struct TF_Status {
  // The wrapped status.
  tensorflow::Status status;
};
     52 
// Concrete definition of the C API's TF_Tensor: an element type, a shape, and
// a pointer to the buffer holding the elements.
//
// NOTE(review): the destructor is user-declared (defined out of line), but the
// copy constructor/assignment are still implicitly generated. If ~TF_Tensor()
// releases `buffer`, copying a TF_Tensor would release it twice. Presumably
// instances are only ever handled by pointer; confirm before copying one.
struct TF_Tensor {
  // Defined out of line.
  ~TF_Tensor();

  // Element type of the tensor.
  TF_DataType dtype;
  // Dimensions of the tensor.
  tensorflow::TensorShape shape;
  // Backing storage for the elements. Presumably released by ~TF_Tensor();
  // TODO confirm against its definition.
  tensorflow::TensorBuffer* buffer;
};
     60 
// Concrete definition of the C API's TF_SessionOptions: a wrapper holding a
// tensorflow::SessionOptions value.
struct TF_SessionOptions {
  // The wrapped session options.
  tensorflow::SessionOptions options;
};
     64 
// Concrete definition of the C API's TF_DeprecatedSession: holds a raw pointer
// to a tensorflow::Session.
struct TF_DeprecatedSession {
  // NOTE(review): ownership of this pointer is not stated here; presumably the
  // C API's delete function for this type releases it. Confirm.
  tensorflow::Session* session;
};
     68 
// Concrete definition of the C API's TF_Library.
struct TF_Library {
  // Opaque handle; presumably the handle of a dynamically loaded library.
  // TODO confirm.
  void* lib_handle;
  // Presumably the serialized list of ops the library provides. Confirm.
  TF_Buffer op_list;
};
     73 
// Concrete definition of the C API's TF_Graph: a tensorflow::Graph together
// with the bookkeeping the C API keeps alongside it. Members annotated
// GUARDED_BY(mu) must only be accessed while holding `mu`.
struct TF_Graph {
  // Defined out of line. NOTE(review): `delete_requested`, `parent` and
  // `parent_inputs` have no in-class initializers, so they are presumably
  // initialized here. Confirm.
  TF_Graph();

  // Protects every member annotated GUARDED_BY(mu) below.
  tensorflow::mutex mu;
  tensorflow::Graph graph GUARDED_BY(mu);

  // Runs shape inference.
  tensorflow::ShapeRefiner refiner GUARDED_BY(mu);

  // Maps from name of an operation to the Node* in 'graph'.
  std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
      GUARDED_BY(mu);

  // The keys of this map are all the active sessions using this graph.
  // Each value is the current "runnability" status of the corresponding
  // session. Under normal conditions all statuses are Status::OK(), but
  // if some operation is mutated after it was run by a session (this
  // is detected in the RecordMutation function), that session is no longer
  // safe to run. Its status will contain the error that will be returned
  // to the user, should they try running this session.
  //
  // Sessions are added to this map in TF_NewSession, and removed in
  // TF_DeleteSession.
  // A TF_Graph may be deleted only when, and must be deleted once,
  //   sessions.size() == 0 && delete_requested == true
  //
  // NOTE(review): this header does not include the gtl FlatMap header
  // directly; the name reaches here through a transitive include.
  tensorflow::gtl::FlatMap<TF_Session*, tensorflow::Status> sessions
      GUARDED_BY(mu);
  bool delete_requested GUARDED_BY(mu);  // set true by TF_DeleteGraph

  // Used to link graphs contained in TF_WhileParams to the parent graph that
  // will eventually contain the full while loop.
  // NOTE(review): these two members are not annotated GUARDED_BY(mu), and the
  // ownership of `parent_inputs` is not stated here. Confirm against the
  // while-loop construction code.
  TF_Graph* parent;
  TF_Output* parent_inputs;
};
    108 
// Concrete definition of the C API's TF_OperationDescription: the state used
// to build a new operation in `graph` through `node_builder`.
struct TF_OperationDescription {
  // Note the argument order: this constructor takes (op_type, node_name), but
  // NodeBuilder is constructed with (node_name, op_type, registry). The op
  // registry is taken from g->graph.
  // NOTE(review): g->graph is GUARDED_BY(g->mu) yet is read here without
  // taking the lock. This is presumably acceptable because only the registry
  // pointer is read; confirm.
  TF_OperationDescription(TF_Graph* g, const char* op_type,
                          const char* node_name)
      : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {}

  tensorflow::NodeBuilder node_builder;
  // The graph the operation is being built for. This struct declares no
  // destructor, so it does not release this pointer.
  TF_Graph* graph;
  // Presumably the names of ops this op must be colocated with, collected
  // before the node is finalized. TODO confirm.
  std::set<tensorflow::string> colocation_constraints;
};
    118 
// Concrete definition of the C API's TF_Operation. Its only member is a
// tensorflow::Node held by value.
// NOTE(review): presumably this single-member layout lets code convert
// between Node* and TF_Operation* with a pointer cast instead of allocating
// wrappers. Confirm before adding members.
struct TF_Operation {
  tensorflow::Node node;
};
    122 
// Concrete definition of the C API's TF_Session: pairs a tensorflow::Session
// with the TF_Graph it executes. (TF_Graph::sessions tracks the TF_Sessions
// that use a graph.)
struct TF_Session {
  // Defined out of line.
  TF_Session(tensorflow::Session* s, TF_Graph* g);

  // NOTE(review): ownership of `session` and `graph` is not stated in this
  // header. Confirm against the constructor and TF_DeleteSession.
  tensorflow::Session* session;
  TF_Graph* graph;

  // NOTE(review): no member here is annotated GUARDED_BY(mu). Which fields
  // this mutex protects must be confirmed from the implementation.
  tensorflow::mutex mu;
  // Presumably the number of graph nodes already handed to `session`, used to
  // detect nodes added since. TODO confirm.
  int last_num_graph_nodes;

  // NOTE(ashankar): Experimental fields to help keep the
  // buffers of a TF_Tensor pinned in device memory.
  const tensorflow::DeviceMgr* device_mgr;   // Owned by session.
  std::vector<tensorflow::Device*> devices;  // Owned by device_mgr.
};
    137 
// Concrete definition of the C API's TF_ImportGraphDefOptions: wraps
// tensorflow::ImportGraphDefOptions plus storage its fields refer to.
struct TF_ImportGraphDefOptions {
  tensorflow::ImportGraphDefOptions opts;

  // Backing memory for TensorId fields in opts.
  // A std::list never relocates existing elements when new ones are appended,
  // so strings stored here keep stable addresses. This matters if, as
  // presumed, the TensorId fields in `opts` refer to these strings without
  // owning them.
  // TODO(skyewm): it'd be better if ImportGraphDefOptions owned this.
  std::list<tensorflow::string> tensor_id_data;
};
    145 
// Concrete definition of the C API's TF_ImportGraphDefResults: the results of
// a graph import, expressed in C API types.
struct TF_ImportGraphDefResults {
  std::vector<TF_Output> return_tensors;
  std::vector<TF_Operation*> return_nodes;
  // Each entry points at a string stored in missing_unused_key_names_data.
  std::vector<const char*> missing_unused_key_names;
  // Presumably parallel to missing_unused_key_names (one index per name).
  // TODO confirm.
  std::vector<int> missing_unused_key_indexes;

  // Backing memory for missing_unused_key_names values. A std::list keeps
  // each string at a stable address as more are appended, so the pointers
  // above are not invalidated by later insertions.
  std::list<tensorflow::string> missing_unused_key_names_data;
};
    155 
// Concrete definition of the C API's TF_DeviceList: a vector of
// tensorflow::DeviceAttributes.
struct TF_DeviceList {
  std::vector<tensorflow::DeviceAttributes> response;
};
    159 
// Concrete definition of the C API's TF_Function: a wrapper holding a
// tensorflow::FunctionDef.
struct TF_Function {
  tensorflow::FunctionDef fdef;
};
    163 
    164 struct TF_ApiDefMap {
    165   explicit TF_ApiDefMap(const tensorflow::OpList& op_list)
    166       :
    167 #ifndef __ANDROID__
    168         api_def_map(op_list),
    169 #endif
    170         update_docs_called(false) {
    171   }
    172 
    173 #ifndef __ANDROID__
    174   tensorflow::ApiDefMap api_def_map GUARDED_BY(lock);
    175 #endif
    176   bool update_docs_called GUARDED_BY(lock);
    177   tensorflow::mutex lock;
    178 };
    179 
    180 namespace tensorflow {
    181 
    182 class TensorCApi {
    183  public:
    184   static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; }
    185   static Tensor MakeTensor(TF_DataType type, const TensorShape& shape,
    186                            TensorBuffer* buf) {
    187     return Tensor(static_cast<DataType>(type), shape, buf);
    188   }
    189 };
    190 
    191 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
    192 
    193 TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
    194 
    195 Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out);
    196 
    197 // Set the shapes and types of the output's handle.
    198 //
    199 // The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must
    200 // all be equal to `num_shapes_and_types`. If `ranks[i] != -1`, (i.e., if the
    201 // rank is known), then it must be equal to the length of `shapes[i]`; if
    202 // `ranks[i] == 1`, then `shapes[i]` may be nullptr.
    203 //
    204 // TODO(akshayka): Implement a corresponding getter method.
    205 void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
    206                                            int num_shapes_and_types,
    207                                            const int64_t** shapes,
    208                                            const int* ranks,
    209                                            const TF_DataType* types,
    210                                            TF_Status* status);
    211 
    212 void RecordMutation(TF_Graph* graph, const TF_Operation& op,
    213                     const char* mutation_type);
    214 
    215 }  // end namespace tensorflow
    216 
    217 #endif  // TENSORFLOW_C_C_API_INTERNAL_H_
    218