Home | History | Annotate | Download | only in session_bundle
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Helpers for working with TensorFlow exports and their signatures.
     17 
     18 #ifndef TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
     19 #define TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
     20 
     21 #include <string>
     22 #include <utility>
     23 #include <vector>
     24 
     25 #include "tensorflow/contrib/session_bundle/manifest.pb.h"
     26 #include "tensorflow/core/framework/tensor.h"
     27 #include "tensorflow/core/lib/core/status.h"
     28 #include "tensorflow/core/platform/types.h"
     29 #include "tensorflow/core/protobuf/meta_graph.pb.h"
     30 #include "tensorflow/core/protobuf/saver.pb.h"
     31 #include "tensorflow/core/public/session.h"
     32 
     33 namespace tensorflow {
     34 namespace serving {
     35 
     36 const char kSignaturesKey[] = "serving_signatures";
     37 
     38 // Get Signatures from a MetaGraphDef.
     39 Status GetSignatures(const tensorflow::MetaGraphDef& meta_graph_def,
     40                      Signatures* signatures);
     41 
     42 // (Re)set Signatures in a MetaGraphDef.
     43 Status SetSignatures(const Signatures& signatures,
     44                      tensorflow::MetaGraphDef* meta_graph_def);
     45 
     46 // Gets a ClassificationSignature from a MetaGraphDef's default signature.
     47 // Returns an error if the default signature is not a ClassificationSignature,
     48 // or does not exist.
     49 Status GetClassificationSignature(
     50     const tensorflow::MetaGraphDef& meta_graph_def,
     51     ClassificationSignature* signature);
     52 
     53 // Gets a named ClassificationSignature from a MetaGraphDef.
     54 // Returns an error if a ClassificationSignature with the given name does
     55 // not exist.
     56 Status GetNamedClassificationSignature(
     57     const string& name, const tensorflow::MetaGraphDef& meta_graph_def,
     58     ClassificationSignature* signature);
     59 
     60 // Gets a RegressionSignature from a MetaGraphDef's default signature.
     61 // Returns an error if the default signature is not a RegressionSignature,
     62 // or does not exist.
     63 Status GetRegressionSignature(const tensorflow::MetaGraphDef& meta_graph_def,
     64                               RegressionSignature* signature);
     65 
     66 // Runs a classification using the provided signature and initialized Session.
     67 //   input: input batch of items to classify
     68 //   classes: output batch of classes; may be null if not needed
     69 //   scores: output batch of scores; may be null if not needed
     70 // Validates sizes of the inputs and outputs are consistent (e.g., input
     71 // batch size equals output batch sizes).
     72 // Does not do any type validation.
     73 Status RunClassification(const ClassificationSignature& signature,
     74                          const Tensor& input, Session* session, Tensor* classes,
     75                          Tensor* scores);
     76 
     77 // Runs regression using the provided signature and initialized Session.
     78 //   input: input batch of items to run the regression model against
     79 //   output: output targets
     80 // Validates sizes of the inputs and outputs are consistent (e.g., input
     81 // batch size equals output batch sizes).
     82 // Does not do any type validation.
     83 Status RunRegression(const RegressionSignature& signature, const Tensor& input,
     84                      Session* session, Tensor* output);
     85 
     86 // Gets the named GenericSignature from a MetaGraphDef.
     87 // Returns an error if a GenericSignature with the given name does not exist.
     88 Status GetGenericSignature(const string& name,
     89                            const tensorflow::MetaGraphDef& meta_graph_def,
     90                            GenericSignature* signature);
     91 
     92 // Gets the default signature from a MetaGraphDef.
     93 Status GetDefaultSignature(const tensorflow::MetaGraphDef& meta_graph_def,
     94                            Signature* default_signature);
     95 
     96 // Gets a named Signature from a MetaGraphDef.
     97 // Returns an error if a Signature with the given name does not exist.
     98 Status GetNamedSignature(const string& name,
     99                          const tensorflow::MetaGraphDef& meta_graph_def,
    100                          Signature* default_signature);
    101 
    102 // Binds TensorFlow inputs specified by the caller using the logical names
    103 // specified at Graph export time, to the actual Graph names.
    104 // Returns an error if any of the inputs do not have a binding in the export's
    105 // MetaGraphDef.
    106 Status BindGenericInputs(const GenericSignature& signature,
    107                          const std::vector<std::pair<string, Tensor>>& inputs,
    108                          std::vector<std::pair<string, Tensor>>* bound_inputs);
    109 
    110 // Binds the input names specified by the caller using the logical names
    111 // specified at Graph export time, to the actual Graph names. This is useful
    112 // for binding names of both the TensorFlow output tensors and target nodes,
    113 // with the latter (target nodes) being optional and rarely used (if ever) at
    114 // serving time.
    115 // Returns an error if any of the input names do not have a binding in the
    116 // export's MetaGraphDef.
    117 Status BindGenericNames(const GenericSignature& signature,
    118                         const std::vector<string>& input_names,
    119                         std::vector<string>* bound_names);
    120 
    121 }  // namespace serving
    122 }  // namespace tensorflow
    123 
    124 #endif  // TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
    125