Home | History | Annotate | Download | only in framework
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/python/framework/cpp_shape_inference.h"
     17 
     18 #include "tensorflow/core/framework/node_def.pb.h"
     19 #include "tensorflow/core/framework/op.h"
     20 #include "tensorflow/core/framework/shape_inference.h"
     21 #include "tensorflow/core/framework/tensor_shape.pb.h"
     22 #include "tensorflow/core/lib/core/errors.h"
     23 #include "tensorflow/core/lib/strings/strcat.h"
     24 #include "tensorflow/python/framework/cpp_shape_inference.pb.h"
     25 #include "tensorflow/python/lib/core/py_func.h"
     26 
     27 namespace tensorflow {
     28 
     29 namespace swig {
     30 namespace {
     31 
     32 void ProtoFromShapeHandle(tensorflow::shape_inference::ShapeHandle s,
     33                           tensorflow::shape_inference::InferenceContext* c,
     34                           TensorShapeProto* out) {
     35   if (c->RankKnown(s)) {
     36     const int32 rank = c->Rank(s);
     37     for (int i = 0; i < rank; ++i) {
     38       shape_inference::DimensionHandle d = c->Dim(s, i);
     39       auto* out_dim = out->add_dim();
     40       if (c->ValueKnown(d)) {
     41         out_dim->set_size(c->Value(d));
     42       } else {
     43         out_dim->set_size(-1);
     44       }
     45     }
     46   } else {
     47     out->set_unknown_rank(true);
     48   }
     49 }
     50 
     51 Status RunCppShapeInferenceImpl(
     52     int graph_def_version, const string& serialized_node_def,
     53     const std::vector<string>& input_serialized_shapes,
     54     const std::vector<PyObject*>& input_constant_tensor_values,
     55     const std::vector<string>& input_constant_tensor_as_shape_values,
     56     std::vector<string>* output_tensor_shape_protos,
     57     string* input_tensors_needed_out) {
     58   tensorflow::NodeDef node;
     59   if (!node.ParseFromString(serialized_node_def)) {
     60     return errors::InvalidArgument(
     61         "Error parsing node_def during cpp shape inference");
     62   }
     63   DCHECK_EQ(output_tensor_shape_protos->size(), 0);
     64 
     65   const OpRegistrationData* op_reg_data;
     66   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(node.op(), &op_reg_data));
     67 
     68   if (op_reg_data->shape_inference_fn == nullptr) {
     69     return errors::InvalidArgument(
     70         "No shape inference function exists for op '", node.op(),
     71         "', did you forget to define it?");
     72   }
     73 
     74   // Convert input shapes.
     75   std::vector<TensorShapeProto> input_shapes;
     76   std::vector<
     77       std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>
     78       input_handle_shapes_and_types;
     79   input_shapes.resize(input_serialized_shapes.size());
     80   input_handle_shapes_and_types.resize(input_serialized_shapes.size());
     81   CppShapeInferenceResult tmp;
     82   for (int i = 0; i < input_serialized_shapes.size(); ++i) {
     83     tmp.Clear();
     84     if (!tmp.ParseFromString(input_serialized_shapes[i])) {
     85       return errors::InvalidArgument(
     86           "Error parsing shape proto during cpp shape inference");
     87     }
     88 
     89     input_shapes[i].Swap(tmp.mutable_shape());
     90 
     91     if (tmp.handle_data().is_set()) {
     92       input_handle_shapes_and_types[i].reset(
     93           new std::vector<std::pair<TensorShapeProto, DataType>>);
     94       auto& v = *input_handle_shapes_and_types[i];
     95       for (const auto& x : tmp.handle_data().shape_and_type()) {
     96         v.emplace_back(x.shape(), x.dtype());
     97       }
     98     }
     99   }
    100 
    101   // Convert input tensor values;
    102   std::vector<Tensor> input_tensor_values(input_constant_tensor_values.size());
    103   std::vector<const Tensor*> input_tensors;
    104   for (int i = 0; i < input_constant_tensor_values.size(); ++i) {
    105     auto* py_val = input_constant_tensor_values[i];
    106     if (py_val == Py_None) {
    107       input_tensors.push_back(nullptr);
    108     } else {
    109       TF_RETURN_IF_ERROR(
    110           ConvertNdarrayToTensor(py_val, &input_tensor_values[i]));
    111       input_tensors.push_back(&input_tensor_values[i]);
    112     }
    113   }
    114 
    115   // Convert input tensor-as-shape values;
    116   std::vector<TensorShapeProto> input_tensor_as_shapes_protos(
    117       input_constant_tensor_as_shape_values.size());
    118   for (int i = 0; i < input_constant_tensor_as_shape_values.size(); ++i) {
    119     if (!input_tensor_as_shapes_protos[i].ParseFromString(
    120             input_constant_tensor_as_shape_values[i])) {
    121       return errors::InvalidArgument(
    122           "Error parsing shape proto during cpp shape inference");
    123     }
    124   }
    125 
    126   // Run shape inference.
    127   tensorflow::shape_inference::InferenceContext c(
    128       graph_def_version, &node, op_reg_data->op_def, input_shapes,
    129       input_tensors, input_tensor_as_shapes_protos,
    130       input_handle_shapes_and_types);
    131   TF_RETURN_IF_ERROR(c.construction_status());
    132 
    133   TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn));
    134 
    135   // Convert output shapes.
    136   output_tensor_shape_protos->resize(c.num_outputs());
    137   CppShapeInferenceResult out;
    138   for (int i = 0; i < c.num_outputs(); ++i) {
    139     out.Clear();
    140     ProtoFromShapeHandle(c.output(i), &c, out.mutable_shape());
    141 
    142     const auto* shapes_and_types = c.output_handle_shapes_and_types(i);
    143     if (shapes_and_types != nullptr) {
    144       auto* out_handle_data = out.mutable_handle_data();
    145       out_handle_data->set_is_set(true);
    146       for (const auto& p : *shapes_and_types) {
    147         auto* out_shape_and_type = out_handle_data->add_shape_and_type();
    148         ProtoFromShapeHandle(p.shape, &c, out_shape_and_type->mutable_shape());
    149         out_shape_and_type->set_dtype(p.dtype);
    150       }
    151     }
    152 
    153     CHECK(out.AppendToString(&(*output_tensor_shape_protos)[i]));
    154   }
    155 
    156   // Add info about requested inputs.
    157   CppShapeInferenceInputsNeeded needed;
    158   for (int i = 0; i < c.num_inputs(); ++i) {
    159     if (c.requested_input_tensor(i)) {
    160       needed.add_input_tensors_needed(i);
    161     }
    162     if (c.requested_input_tensor_as_partial_shape(i)) {
    163       needed.add_input_tensors_as_shapes_needed(i);
    164     }
    165   }
    166   *input_tensors_needed_out = needed.SerializeAsString();
    167 
    168   return Status::OK();
    169 }
    170 
    171 }  // namespace
    172 
    173 std::vector<string> RunCppShapeInference(
    174     int graph_def_version, const string& serialized_node_def,
    175     const std::vector<string>& input_serialized_shapes,
    176     PyObject* input_constant_tensor_values,
    177     const std::vector<string>& input_constant_tensor_as_shape_values,
    178     TF_Status* out_status) {
    179   if (!PyList_Check(input_constant_tensor_values)) {
    180     TF_SetStatus(out_status, TF_INVALID_ARGUMENT, "Invalid python value");
    181     return std::vector<string>();
    182   }
    183 
    184   std::vector<PyObject*> input_constant_tensor_values_v;
    185   int cnt = PyList_Size(input_constant_tensor_values);
    186   input_constant_tensor_values_v.reserve(cnt);
    187   for (int i = 0; i < cnt; ++i) {
    188     input_constant_tensor_values_v.push_back(
    189         PyList_GetItem(input_constant_tensor_values, i));
    190   }
    191 
    192   std::vector<string> output;
    193   string input_tensors_needed_out;
    194   tensorflow::Status status = RunCppShapeInferenceImpl(
    195       graph_def_version, serialized_node_def, input_serialized_shapes,
    196       input_constant_tensor_values_v, input_constant_tensor_as_shape_values,
    197       &output, &input_tensors_needed_out);
    198 
    199   Set_TF_Status_from_Status(out_status, status);
    200   if (!status.ok()) {
    201     return std::vector<string>();
    202   }
    203   output.push_back(input_tensors_needed_out);
    204   return output;
    205 }
    206 
    207 }  // namespace swig
    208 }  // namespace tensorflow
    209