1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Helpers for working with TensorFlow exports and their signatures. 17 18 #ifndef TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_ 19 #define TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_ 20 21 #include <string> 22 #include <utility> 23 #include <vector> 24 25 #include "tensorflow/contrib/session_bundle/manifest.pb.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/platform/types.h" 29 #include "tensorflow/core/protobuf/meta_graph.pb.h" 30 #include "tensorflow/core/protobuf/saver.pb.h" 31 #include "tensorflow/core/public/session.h" 32 33 namespace tensorflow { 34 namespace serving { 35 36 const char kSignaturesKey[] = "serving_signatures"; 37 38 // Get Signatures from a MetaGraphDef. 39 Status GetSignatures(const tensorflow::MetaGraphDef& meta_graph_def, 40 Signatures* signatures); 41 42 // (Re)set Signatures in a MetaGraphDef. 43 Status SetSignatures(const Signatures& signatures, 44 tensorflow::MetaGraphDef* meta_graph_def); 45 46 // Gets a ClassificationSignature from a MetaGraphDef's default signature. 47 // Returns an error if the default signature is not a ClassificationSignature, 48 // or does not exist. 49 Status GetClassificationSignature( 50 const tensorflow::MetaGraphDef& meta_graph_def, 51 ClassificationSignature* signature); 52 53 // Gets a named ClassificationSignature from a MetaGraphDef. 54 // Returns an error if a ClassificationSignature with the given name does 55 // not exist. 56 Status GetNamedClassificationSignature( 57 const string& name, const tensorflow::MetaGraphDef& meta_graph_def, 58 ClassificationSignature* signature); 59 60 // Gets a RegressionSignature from a MetaGraphDef's default signature. 61 // Returns an error if the default signature is not a RegressionSignature, 62 // or does not exist. 63 Status GetRegressionSignature(const tensorflow::MetaGraphDef& meta_graph_def, 64 RegressionSignature* signature); 65 66 // Runs a classification using the provided signature and initialized Session. 67 // input: input batch of items to classify 68 // classes: output batch of classes; may be null if not needed 69 // scores: output batch of scores; may be null if not needed 70 // Validates sizes of the inputs and outputs are consistent (e.g., input 71 // batch size equals output batch sizes). 72 // Does not do any type validation. 73 Status RunClassification(const ClassificationSignature& signature, 74 const Tensor& input, Session* session, Tensor* classes, 75 Tensor* scores); 76 77 // Runs regression using the provided signature and initialized Session. 78 // input: input batch of items to run the regression model against 79 // output: output targets 80 // Validates sizes of the inputs and outputs are consistent (e.g., input 81 // batch size equals output batch sizes). 82 // Does not do any type validation. 83 Status RunRegression(const RegressionSignature& signature, const Tensor& input, 84 Session* session, Tensor* output); 85 86 // Gets the named GenericSignature from a MetaGraphDef. 87 // Returns an error if a GenericSignature with the given name does not exist. 88 Status GetGenericSignature(const string& name, 89 const tensorflow::MetaGraphDef& meta_graph_def, 90 GenericSignature* signature); 91 92 // Gets the default signature from a MetaGraphDef. 93 Status GetDefaultSignature(const tensorflow::MetaGraphDef& meta_graph_def, 94 Signature* default_signature); 95 96 // Gets a named Signature from a MetaGraphDef. 97 // Returns an error if a Signature with the given name does not exist. 98 Status GetNamedSignature(const string& name, 99 const tensorflow::MetaGraphDef& meta_graph_def, 100 Signature* default_signature); 101 102 // Binds TensorFlow inputs specified by the caller using the logical names 103 // specified at Graph export time, to the actual Graph names. 104 // Returns an error if any of the inputs do not have a binding in the export's 105 // MetaGraphDef. 106 Status BindGenericInputs(const GenericSignature& signature, 107 const std::vector<std::pair<string, Tensor>>& inputs, 108 std::vector<std::pair<string, Tensor>>* bound_inputs); 109 110 // Binds the input names specified by the caller using the logical names 111 // specified at Graph export time, to the actual Graph names. This is useful 112 // for binding names of both the TensorFlow output tensors and target nodes, 113 // with the latter (target nodes) being optional and rarely used (if ever) at 114 // serving time. 115 // Returns an error if any of the input names do not have a binding in the 116 // export's MetaGraphDef. 117 Status BindGenericNames(const GenericSignature& signature, 118 const std::vector<string>& input_names, 119 std::vector<string>* bound_names); 120 121 } // namespace serving 122 } // namespace tensorflow 123 124 #endif // TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_ 125