1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/framework/shape_inference_testutil.h" 16 17 #include "tensorflow/core/framework/node_def_util.h" 18 #include "tensorflow/core/framework/op.h" 19 #include "tensorflow/core/lib/gtl/map_util.h" 20 #include "tensorflow/core/lib/strings/numbers.h" 21 #include "tensorflow/core/lib/strings/scanner.h" 22 #include "tensorflow/core/lib/strings/str_util.h" 23 24 namespace tensorflow { 25 namespace shape_inference { 26 27 using errors::Unknown; 28 29 Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op, 30 const string& ins, 31 const string& expected_outs) { 32 const OpRegistrationData* op_reg_data; 33 TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(op.name, &op_reg_data)); 34 35 std::vector<string> ins_v = str_util::Split(ins, ';'); 36 std::unique_ptr<const NodeDef> new_node_def; 37 38 InferenceContext::ShapeManager manager; 39 std::vector<ShapeHandle> in_shapes; 40 for (const string& spec : ins_v) { 41 ShapeHandle shape; 42 TF_RETURN_IF_ERROR(MakeShapeFromString(&manager, spec, &shape)); 43 in_shapes.push_back(shape); 44 } 45 46 std::vector<std::unique_ptr<std::vector<shape_inference::ShapeAndType>>> 47 input_resource_handle_shapes_and_types; 48 for (const auto p : op.input_resource_handle_shapes_and_types) { 49 if (p == nullptr) { 50 input_resource_handle_shapes_and_types.push_back(nullptr); 51 } else { 52 std::unique_ptr<std::vector<ShapeAndType>> v( 53 new std::vector<ShapeAndType>()); 54 for (const auto& shape_and_type : *p) { 55 ShapeHandle shape; 56 TF_RETURN_IF_ERROR( 57 MakeShapeFromString(&manager, shape_and_type.first, &shape)); 58 v->emplace_back(shape, shape_and_type.second); 59 } 60 input_resource_handle_shapes_and_types.emplace_back(v.release()); 61 } 62 } 63 shape_inference::InferenceContext c( 64 op.graph_def_version, &op.node_def, op_reg_data->op_def, in_shapes, 65 op.input_tensors, {}, std::move(input_resource_handle_shapes_and_types)); 66 TF_RETURN_IF_ERROR(c.construction_status()); 67 if (op_reg_data->shape_inference_fn == nullptr) { 68 return errors::InvalidArgument( 69 "No shape inference function exists for op '", op.name, 70 "', did you forget to define it?"); 71 } 72 73 TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn)); 74 75 const int num_outputs = c.num_outputs(); 76 77 if (expected_outs == "e") { 78 return Unknown("Shape inference should have returned error"); 79 } 80 81 // Verify the output shape. 82 std::vector<string> expected_outs_v = str_util::Split(expected_outs, ';'); 83 if (num_outputs != expected_outs_v.size()) { 84 return Unknown("The expected output string lists the wrong number of ", 85 "outputs. It lists ", expected_outs_v.size(), 86 " but should list ", num_outputs); 87 } 88 for (int i = 0; i < num_outputs; ++i) { 89 StringPiece expected(expected_outs_v[i]); 90 shape_inference::ShapeHandle out = c.output(i); 91 92 string err_prefix = strings::StrCat("Output ", i); 93 string err_suffix = 94 strings::StrCat(". Output shape was ", c.DebugString(out)); 95 96 int in_index = -1; 97 for (int i = 0; i < c.num_inputs(); ++i) { 98 if (c.input(i).SameHandle(out)) { 99 in_index = i; 100 } 101 } 102 103 if (expected.starts_with("in")) { 104 if (in_index == -1) { 105 return Unknown(err_prefix, 106 " should have matched an input shape by " 107 "handle, but matched no input shape. This means the ", 108 "shape function was expected to pass an input " 109 "ShapeHandle through for this output, but did not", 110 err_suffix); 111 } 112 auto v = str_util::Split(expected, '|'); 113 if (std::find(v.begin(), v.end(), strings::StrCat("in", in_index)) == 114 v.end()) { 115 return Unknown( 116 err_prefix, " matched input ", in_index, 117 " by handle, but should have matched one of (", expected, 118 ") instead. This means the shape function passed the ShapeHandle ", 119 "for input ", in_index, 120 " to the output, but should have passed a different input ", 121 "ShapeHandle through", err_suffix); 122 } 123 continue; 124 } 125 if (in_index != -1) { 126 return Unknown(err_prefix, " matched input ", in_index, 127 " by ShapeHandle, but was expected to not match an input ", 128 "shape by handle", err_suffix); 129 } 130 if (expected == "?") { 131 if (c.RankKnown(out)) { 132 return Unknown(err_prefix, " expected to be unknown", err_suffix); 133 } 134 continue; 135 } 136 137 // Verify the dimensions. 138 CHECK(expected.starts_with("[") && expected.ends_with("]")) << expected; 139 expected.remove_prefix(1); 140 expected.remove_suffix(1); 141 142 // Split expected as a dimension. 143 auto expected_dims = str_util::Split(expected, ','); 144 if (!c.RankKnown(out)) { 145 return Unknown(err_prefix, " expected rank ", expected_dims.size(), 146 " but was ?", err_suffix); 147 } 148 if (c.Rank(out) != expected_dims.size()) { 149 return Unknown(err_prefix, " expected rank ", expected_dims.size(), 150 " but was ", c.Rank(out), err_suffix); 151 } 152 for (int j = 0; j < expected_dims.size(); ++j) { 153 err_prefix = strings::StrCat("Output dim ", i, ",", j); 154 StringPiece expected_dim(expected_dims[j]); 155 DimensionHandle out_dim = c.Dim(out, j); 156 157 std::pair<int, int> in_dim_idx(-1, -1); 158 for (int i = 0; i < c.num_inputs(); ++i) { 159 auto in = c.input(i); 160 for (int j = 0; j < c.Rank(in); ++j) { 161 if (c.Dim(in, j).SameHandle(out_dim)) { 162 in_dim_idx = std::make_pair(i, j); 163 } 164 } 165 } 166 167 if (expected_dim == "?") { 168 if (in_dim_idx.first != -1) { 169 return Unknown(err_prefix, 170 " expected to be an unknown but matched input d", 171 in_dim_idx.first, "_", in_dim_idx.second, 172 ". The shape function passed through ", 173 "a DimensionHandle from an input instead of making ", 174 "a new unknown dimension", err_suffix); 175 } else if (c.ValueKnown(out_dim)) { 176 return Unknown(err_prefix, " expected to be unknown but was ", 177 c.Value(out_dim), err_suffix); 178 } 179 } else if (expected_dim.starts_with("d")) { 180 // Compare the dimension values. 181 auto v = str_util::Split(expected_dim, '|'); 182 if (in_dim_idx.first == -1) { 183 return Unknown( 184 err_prefix, " was expected to match the dimension of an input, ", 185 "but did not match any input dimension. The shape ", 186 "function was expected to pass through a ", 187 "DimensionHandle for an input, but did not", err_suffix); 188 } 189 if (std::find(v.begin(), v.end(), 190 strings::StrCat("d", in_dim_idx.first, "_", 191 in_dim_idx.second)) == v.end()) { 192 return Unknown(err_prefix, " matched input d", in_dim_idx.first, "_", 193 in_dim_idx.second, 194 ", but should have matched one of (", expected_dim, 195 "). The shape function passed through " 196 "the DimensionHandle for an input, but ", 197 "was expected to pass a different one", err_suffix); 198 } 199 } else { 200 // Parse it as a value. 201 int64 value = -1; 202 if (!strings::safe_strto64(expected_dim, &value)) { 203 return Unknown(err_prefix, ": the expected dimension value '", 204 expected_dim, "' failed to parse as int64", 205 err_suffix); 206 } 207 if (in_dim_idx.first != -1) { 208 return Unknown( // 209 err_prefix, " expected to be ", value, " but matched input d", 210 in_dim_idx.first, "_", in_dim_idx.second, 211 ". The shape function was not expected to pass a DimensionHandle " 212 "from the input to the output, but did. Note that even if the " 213 "passed through output has the same dimension value as the " 214 "expected value, this is considered a failure for the test; " 215 "switch to using d#_# syntax if passing through the " 216 "DimensionHandle should be the expected behavior", 217 err_suffix); 218 } else if (value != c.Value(out_dim)) { 219 return Unknown(err_prefix, " expected to be ", value, " but was ", 220 c.DebugString(out_dim), err_suffix); 221 } 222 } 223 } 224 } 225 return Status::OK(); 226 } 227 228 // static 229 Status ShapeInferenceTestutil::MakeShapeFromString( 230 InferenceContext::ShapeManager* manager, const string& spec, 231 ShapeHandle* output) { 232 if (spec == "?") { 233 *output = manager->UnknownShape(); 234 return Status::OK(); 235 } 236 237 std::vector<DimensionHandle> dims; 238 strings::Scanner scanner(spec); 239 scanner.OneLiteral("["); 240 while (scanner.Peek() != ']') { 241 if (scanner.Peek() == '?') { 242 scanner.OneLiteral("?"); 243 dims.push_back(manager->MakeDim(InferenceContext::kUnknownDim)); 244 } else { 245 scanner.RestartCapture().Many(strings::Scanner::DIGIT); 246 StringPiece match; 247 int64 dim_size = 0; 248 249 if (!scanner.GetResult(nullptr, &match) || 250 !strings::safe_strto64(match, &dim_size)) { 251 return errors::InvalidArgument("Could not parse number in ", spec); 252 } 253 254 dims.push_back(manager->MakeDim(dim_size)); 255 } 256 257 if (scanner.Peek() == ',') { 258 scanner.OneLiteral(","); 259 } else if (scanner.Peek() != ']') { 260 return errors::InvalidArgument( 261 "Invalid input spec (] not found in dim shape): ", spec); 262 } 263 } 264 if (!scanner.OneLiteral("]").Eos().GetResult()) { 265 return errors::InvalidArgument("Malformed shape spec: did not end in ']'."); 266 } 267 *output = manager->MakeShape(dims); 268 269 return Status::OK(); 270 } 271 272 } // namespace shape_inference 273 } // namespace tensorflow 274