Home | History | Annotate | Download | only in session_bundle
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/contrib/session_bundle/signature.h"
     17 
     18 #include <string>
     19 #include <utility>
     20 #include <vector>
     21 
     22 #include "google/protobuf/any.pb.h"
     23 #include "tensorflow/contrib/session_bundle/manifest.pb.h"
     24 #include "tensorflow/core/framework/tensor.h"
     25 #include "tensorflow/core/lib/core/errors.h"
     26 #include "tensorflow/core/lib/core/status.h"
     27 #include "tensorflow/core/lib/strings/strcat.h"
     28 #include "tensorflow/core/platform/protobuf_internal.h"
     29 #include "tensorflow/core/platform/types.h"
     30 #include "tensorflow/core/protobuf/meta_graph.pb.h"
     31 #include "tensorflow/core/public/session.h"
     32 
     33 namespace tensorflow {
     34 namespace serving {
     35 namespace {
     36 
     37 // Returns OK if the input and output batch sizes match.
     38 Status BatchSizesMatch(const Tensor& input, const Tensor& output) {
     39   // Ensure the number of outputs match the number of inputs.
     40   if (input.dim_size(0) != output.dim_size(0)) {
     41     return errors::Internal(strings::StrCat(
     42         "Input batch size did not match output batch size: ", input.dim_size(0),
     43         " vs. ", output.dim_size(0)));
     44   }
     45   return Status::OK();
     46 }
     47 }  // namespace
     48 
     49 Status GetSignatures(const tensorflow::MetaGraphDef& meta_graph_def,
     50                      Signatures* signatures) {
     51   const auto& collection_def = meta_graph_def.collection_def();
     52   const auto it = collection_def.find(kSignaturesKey);
     53   if (it == collection_def.end() || it->second.any_list().value_size() != 1) {
     54     return errors::FailedPrecondition(
     55         strings::StrCat("Expected exactly one signatures proto in : ",
     56                         DebugStringIfAvailable(meta_graph_def)));
     57   }
     58   const auto& any = it->second.any_list().value(0);
     59   return ParseAny(any, signatures, "tensorflow.serving.Signatures");
     60 }
     61 
     62 Status SetSignatures(const Signatures& signatures,
     63                      tensorflow::MetaGraphDef* meta_graph_def) {
     64   auto& collection_def = *(meta_graph_def->mutable_collection_def());
     65   auto* any_list = collection_def[kSignaturesKey].mutable_any_list();
     66   any_list->mutable_value()->Clear();
     67 #ifdef TENSORFLOW_LITE_PROTOS
     68   signatures.SerializeToString(
     69       any_list->mutable_value()->Add()->mutable_value());
     70 #else
     71   any_list->mutable_value()->Add()->PackFrom(signatures);
     72 #endif
     73   return Status::OK();
     74 }
     75 
     76 Status GetClassificationSignature(
     77     const tensorflow::MetaGraphDef& meta_graph_def,
     78     ClassificationSignature* signature) {
     79   Signatures signatures;
     80   TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
     81   if (!signatures.has_default_signature()) {
     82     return errors::FailedPrecondition(
     83         strings::StrCat("Expected a default signature in: ",
     84                         DebugStringIfAvailable(signatures)));
     85   }
     86   if (!signatures.default_signature().has_classification_signature()) {
     87     return errors::FailedPrecondition(strings::StrCat(
     88         "Expected a classification signature in: ",
     89         DebugStringIfAvailable(signatures.default_signature())));
     90   }
     91   *signature = signatures.default_signature().classification_signature();
     92   return Status::OK();
     93 }
     94 
     95 Status GetNamedClassificationSignature(
     96     const string& name, const tensorflow::MetaGraphDef& meta_graph_def,
     97     ClassificationSignature* signature) {
     98   Signatures signatures;
     99   TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
    100   const auto& it = signatures.named_signatures().find(name);
    101   if (it == signatures.named_signatures().end()) {
    102     return errors::NotFound(
    103         strings::StrCat("Missing signature named \"", name,
    104                         "\" in: ", DebugStringIfAvailable(signatures)));
    105   }
    106   if (!it->second.has_classification_signature()) {
    107     return errors::FailedPrecondition(
    108         strings::StrCat("Expected a classification signature for name \"", name,
    109                         "\" in: ", DebugStringIfAvailable(it->second)));
    110   }
    111   *signature = it->second.classification_signature();
    112   return Status::OK();
    113 }
    114 
    115 Status RunClassification(const ClassificationSignature& signature,
    116                          const Tensor& input, Session* session, Tensor* classes,
    117                          Tensor* scores) {
    118   std::vector<string> output_tensor_names;
    119   if (classes) {
    120     output_tensor_names.push_back(signature.classes().tensor_name());
    121   }
    122   if (scores) {
    123     output_tensor_names.push_back(signature.scores().tensor_name());
    124   }
    125   // Run the graph with our inputs and outputs.
    126   std::vector<Tensor> outputs;
    127   const Status run_status =
    128       session->Run({{signature.input().tensor_name(), input}},
    129                    output_tensor_names, {}, &outputs);
    130   if (!run_status.ok()) {
    131     return run_status;
    132   }
    133   // Ensure the output is shaped how we expect.
    134   // There should be one string Tensor of shape,
    135   //   [batch_size, num_recommendations].
    136   if (outputs.size() != output_tensor_names.size()) {
    137     return errors::Internal(
    138         strings::StrCat("Expected ", output_tensor_names.size(),
    139                         " output tensor(s).  Got: ", outputs.size()));
    140   }
    141   if (classes) {
    142     *classes = outputs[0];
    143     TF_RETURN_IF_ERROR(BatchSizesMatch(input, *classes));
    144   }
    145   if (scores) {
    146     *scores = outputs[classes ? 1 : 0];
    147     TF_RETURN_IF_ERROR(BatchSizesMatch(input, *scores));
    148   }
    149   return Status::OK();
    150 }
    151 
    152 Status GetRegressionSignature(const tensorflow::MetaGraphDef& meta_graph_def,
    153                               RegressionSignature* signature) {
    154   Signatures signatures;
    155   TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
    156   if (!signatures.has_default_signature()) {
    157     return errors::FailedPrecondition(
    158         strings::StrCat("Expected a default signature in: ",
    159                         DebugStringIfAvailable(signatures)));
    160   }
    161   if (!signatures.default_signature().has_regression_signature()) {
    162     return errors::FailedPrecondition(strings::StrCat(
    163         "Expected a regression signature in: ",
    164         DebugStringIfAvailable(signatures.default_signature())));
    165   }
    166   *signature = signatures.default_signature().regression_signature();
    167   return Status::OK();
    168 }
    169 
    170 Status RunRegression(const RegressionSignature& signature,
    171                      const Tensor& regression_input, Session* session,
    172                      Tensor* regression_output) {
    173   std::vector<string> output_tensor_names;
    174   if (regression_output) {
    175     output_tensor_names.push_back(signature.output().tensor_name());
    176   }
    177   // Run the graph with our inputs and outputs.
    178   std::vector<Tensor> outputs;
    179   const Status run_status =
    180       session->Run({{signature.input().tensor_name(), regression_input}},
    181                    output_tensor_names, {}, &outputs);
    182   if (!run_status.ok()) {
    183     return run_status;
    184   }
    185   // Ensure the regression score output is shaped how we expect.
    186   // There should be one float Tensor of shape,
    187   //   [batch_size, num_recommendations].
    188   if (outputs.size() != output_tensor_names.size()) {
    189     return errors::Internal(
    190         strings::StrCat("Expected ", output_tensor_names.size(),
    191                         " output tensor(s).  Got: ", outputs.size()));
    192   }
    193   if (regression_output) {
    194     *regression_output = outputs[0];
    195     TF_RETURN_IF_ERROR(BatchSizesMatch(regression_input, *regression_output));
    196   }
    197   return Status::OK();
    198 }
    199 
    200 Status GetGenericSignature(const string& name,
    201                            const tensorflow::MetaGraphDef& meta_graph_def,
    202                            GenericSignature* signature) {
    203   Signatures signatures;
    204   TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
    205   const auto& it = signatures.named_signatures().find(name);
    206   if (it == signatures.named_signatures().end()) {
    207     return errors::InvalidArgument(
    208         strings::StrCat("Missing generic signature named \"", name, "\" in ",
    209                         DebugStringIfAvailable(signatures)));
    210   }
    211   if (!it->second.has_generic_signature()) {
    212     return errors::InvalidArgument(strings::StrCat(
    213         "Expected a generic signature: ", DebugStringIfAvailable(it->second)));
    214   }
    215   *signature = it->second.generic_signature();
    216   return Status::OK();
    217 }
    218 
    219 Status GetDefaultSignature(const tensorflow::MetaGraphDef& meta_graph_def,
    220                            Signature* default_signature) {
    221   Signatures signatures;
    222   TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
    223   *default_signature = signatures.default_signature();
    224   return Status::OK();
    225 }
    226 
    227 Status GetNamedSignature(const string& name,
    228                          const tensorflow::MetaGraphDef& meta_graph_def,
    229                          Signature* signature) {
    230   Signatures signatures;
    231   TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
    232   const auto& it = signatures.named_signatures().find(name);
    233   if (it == signatures.named_signatures().end()) {
    234     return errors::NotFound(
    235         strings::StrCat("Missing signature named \"", name,
    236                         "\" in: ", DebugStringIfAvailable(signatures)));
    237   }
    238   *signature = it->second;
    239   return Status::OK();
    240 }
    241 
    242 Status BindGenericInputs(const GenericSignature& signature,
    243                          const std::vector<std::pair<string, Tensor>>& inputs,
    244                          std::vector<std::pair<string, Tensor>>* bound_inputs) {
    245   const protobuf::Map<string, serving::TensorBinding>& bindings =
    246       signature.map();
    247 
    248   for (const auto& entry : inputs) {
    249     const auto mapped = bindings.find(entry.first);
    250     if (mapped == bindings.end()) {
    251       return errors::NotFound(
    252           strings::StrCat("Could not find generic binding for: ", entry.first));
    253     }
    254     bound_inputs->push_back({mapped->second.tensor_name(), entry.second});
    255   }
    256   return Status::OK();
    257 }
    258 
    259 Status BindGenericNames(const GenericSignature& signature,
    260                         const std::vector<string>& input_names,
    261                         std::vector<string>* bound_names) {
    262   const protobuf::Map<string, serving::TensorBinding>& bindings =
    263       signature.map();
    264 
    265   for (const string& entry : input_names) {
    266     const auto mapped = bindings.find(entry);
    267     if (mapped == bindings.end()) {
    268       return errors::NotFound(
    269           strings::StrCat("Could not find generic binding for: ", entry));
    270     }
    271     bound_names->push_back(mapped->second.tensor_name());
    272   }
    273   return Status::OK();
    274 }
    275 
    276 }  // namespace serving
    277 }  // namespace tensorflow
    278