/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_C_C_API_INTERNAL_H_
     17 #define TENSORFLOW_C_C_API_INTERNAL_H_
     18 
#include "tensorflow/c/c_api.h"

#include <atomic>
#include <list>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"  // NO_LINT

#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/core/framework/op_gen_lib.h"
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
     43 
     44 namespace tensorflow {
     45 class Device;
     46 class DeviceMgr;
     47 class ServerInterface;
     48 }  // namespace tensorflow
     49 
// Internal structures used by the C API. These are likely to change and should
// not be depended on.
     52 
     53 struct TF_Status {
     54   tensorflow::Status status;
     55 };
     56 
     57 struct TF_Tensor {
     58   ~TF_Tensor();
     59 
     60   TF_DataType dtype;
     61   tensorflow::TensorShape shape;
     62   tensorflow::TensorBuffer* buffer;
     63 };
     64 
     65 struct TF_SessionOptions {
     66   tensorflow::SessionOptions options;
     67 };
     68 
     69 struct TF_DeprecatedSession {
     70   tensorflow::Session* session;
     71 };
     72 
     73 struct TF_Library {
     74   void* lib_handle;
     75   TF_Buffer op_list;
     76 };
     77 
     78 struct TF_Graph {
     79   TF_Graph();
     80 
     81   tensorflow::mutex mu;
     82   tensorflow::Graph graph GUARDED_BY(mu);
     83 
     84   // Runs shape inference.
     85   tensorflow::ShapeRefiner refiner GUARDED_BY(mu);
     86 
     87   // Maps from name of an operation to the Node* in 'graph'.
     88   std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
     89       GUARDED_BY(mu);
     90 
     91   // The keys of this map are all the active sessions using this graph. Each
     92   // value records whether the graph has been mutated since the corresponding
     93   // session has been run (this is detected in RecordMutation function). If the
     94   // string is empty, no mutation has occurred. Otherwise the string is a
     95   // description of the mutation suitable for returning to the user.
     96   //
     97   // Sessions are added to this map in TF_NewSession, and removed in
     98   // TF_DeleteSession.
     99   // TF_Graph may only / must be deleted when
    100   //   sessions.size() == 0 && delete_requested == true
    101   //
    102   // TODO(b/74949947): mutations currently trigger a warning instead of a bad
    103   // status, this should be reverted when possible.
    104   tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions
    105       GUARDED_BY(mu);
    106   bool delete_requested GUARDED_BY(mu);  // set true by TF_DeleteGraph
    107 
    108   // Used to link graphs contained in TF_WhileParams to the parent graph that
    109   // will eventually contain the full while loop.
    110   TF_Graph* parent;
    111   TF_Output* parent_inputs;
    112 };
    113 
    114 struct TF_OperationDescription {
    115   TF_OperationDescription(TF_Graph* g, const char* op_type,
    116                           const char* node_name)
    117       : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {}
    118 
    119   tensorflow::NodeBuilder node_builder;
    120   TF_Graph* graph;
    121   std::set<tensorflow::string> colocation_constraints;
    122 };
    123 
    124 struct TF_Operation {
    125   tensorflow::Node node;
    126 };
    127 
    128 struct TF_Session {
    129   TF_Session(tensorflow::Session* s, TF_Graph* g);
    130 
    131   tensorflow::Session* session;
    132   TF_Graph* const graph;
    133 
    134   tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu);
    135   int last_num_graph_nodes;
    136 
    137   // If true, TF_SessionRun and similar methods will call
    138   // ExtendSessionGraphHelper before running the graph (this is the default
    139   // public behavior). Can be set to false if the caller needs to call
    140   // ExtendSessionGraphHelper manually.
    141   std::atomic<bool> extend_before_run;
    142 };
    143 
    144 struct TF_ImportGraphDefOptions {
    145   tensorflow::ImportGraphDefOptions opts;
    146 
    147   // Backing memory for TensorId fields in opts.
    148   // TODO(skyewm): it'd be better if ImportGraphDefOptions owned this.
    149   std::list<tensorflow::string> tensor_id_data;
    150 };
    151 
    152 struct TF_ImportGraphDefResults {
    153   std::vector<TF_Output> return_tensors;
    154   std::vector<TF_Operation*> return_nodes;
    155   std::vector<const char*> missing_unused_key_names;
    156   std::vector<int> missing_unused_key_indexes;
    157 
    158   // Backing memory for missing_unused_key_names values.
    159   std::list<tensorflow::string> missing_unused_key_names_data;
    160 };
    161 
    162 struct TF_DeviceList {
    163   std::vector<tensorflow::DeviceAttributes> response;
    164 };
    165 
    166 struct TF_Function {
    167   tensorflow::FunctionDef fdef;
    168 };
    169 
    170 struct TF_ApiDefMap {
    171   explicit TF_ApiDefMap(const tensorflow::OpList& op_list)
    172       :
    173 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
    174         api_def_map(op_list),
    175 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
    176         update_docs_called(false) {
    177   }
    178 
    179 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
    180   tensorflow::ApiDefMap api_def_map GUARDED_BY(lock);
    181 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
    182   bool update_docs_called GUARDED_BY(lock);
    183   tensorflow::mutex lock;
    184 };
    185 
    186 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
    187 struct TF_Server {
    188   TF_Server(std::unique_ptr<tensorflow::ServerInterface> server);
    189 
    190   const tensorflow::string target;
    191   std::unique_ptr<tensorflow::ServerInterface> server;
    192 };
    193 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
    194 
    195 namespace tensorflow {
    196 
    197 class TensorCApi {
    198  public:
    199   static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; }
    200   static Tensor MakeTensor(TF_DataType type, const TensorShape& shape,
    201                            TensorBuffer* buf) {
    202     return Tensor(static_cast<DataType>(type), shape, buf);
    203   }
    204 };
    205 
    206 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
    207 
    208 TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
    209 
    210 Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
    211                        TF_Buffer* out);
    212 
    213 // Set the shapes and types of the output's handle.
    214 //
    215 // The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must
    216 // all be equal to `num_shapes_and_types`. If `ranks[i] != -1`, (i.e., if the
    217 // rank is known), then it must be equal to the length of `shapes[i]`; if
    218 // `ranks[i] == 1`, then `shapes[i]` may be nullptr.
    219 //
    220 // TODO(akshayka): Implement a corresponding getter method.
    221 void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
    222                                            int num_shapes_and_types,
    223                                            const int64_t** shapes,
    224                                            const int* ranks,
    225                                            const TF_DataType* types,
    226                                            TF_Status* status);
    227 
    228 void RecordMutation(TF_Graph* graph, const TF_Operation& op,
    229                     const char* mutation_type)
    230     EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
    231 
    232 bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
    233     LOCKS_EXCLUDED(session->graph->mu, session->mu);
    234 
    235 std::string getTF_OutputDebugString(TF_Output node);
    236 
    237 }  // end namespace tensorflow
    238 
    239 #endif  // TENSORFLOW_C_C_API_INTERNAL_H_
    240