1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/node_def_util.h" 17 18 #include <algorithm> 19 #include <unordered_map> 20 #include <vector> 21 22 #include "tensorflow/core/framework/attr_value_util.h" 23 #include "tensorflow/core/framework/graph.pb_text.h" 24 #include "tensorflow/core/framework/node_def.pb.h" 25 #include "tensorflow/core/framework/op.h" 26 #include "tensorflow/core/framework/op_def.pb_text.h" 27 #include "tensorflow/core/framework/op_def_util.h" 28 #include "tensorflow/core/framework/tensor.pb_text.h" 29 #include "tensorflow/core/framework/tensor_shape.pb.h" 30 #include "tensorflow/core/graph/graph.h" 31 #include "tensorflow/core/lib/core/errors.h" 32 #include "tensorflow/core/lib/gtl/map_util.h" 33 #include "tensorflow/core/lib/strings/scanner.h" 34 #include "tensorflow/core/lib/strings/strcat.h" 35 #include "tensorflow/core/platform/protobuf.h" 36 37 namespace tensorflow { 38 39 const char* const kColocationAttrName = "_class"; 40 const char* const kColocationGroupPrefix = "loc:@"; 41 42 AttrSlice::AttrSlice() : ndef_(nullptr) { 43 static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap; 44 attrs_ = kEmptyAttrValueMap; 45 } 46 47 AttrSlice::AttrSlice(const NodeDef& node_def) 48 : ndef_(&node_def), attrs_(&ndef_->attr()) {} 49 50 AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {} 51 52 static string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) { 53 string ret; 54 55 // We sort the attrs so the output is deterministic. 56 std::vector<string> attr_names; 57 attr_names.reserve(attrs.size()); 58 for (const auto& attr : attrs) { 59 attr_names.push_back(attr.first); 60 } 61 std::sort(attr_names.begin(), attr_names.end()); 62 bool first = true; 63 for (const string& attr_name : attr_names) { 64 if (!first) strings::StrAppend(&ret, ", "); 65 first = false; 66 strings::StrAppend(&ret, attr_name, "=", 67 SummarizeAttrValue(*attrs.Find(attr_name))); 68 } 69 70 // Consider the device to be a final attr with name "_device". 71 if (!device.empty()) { 72 if (!first) strings::StrAppend(&ret, ", "); 73 first = false; 74 strings::StrAppend(&ret, "_device=\"", device, "\""); 75 } 76 return ret; 77 } 78 79 string AttrSlice::SummarizeNode() const { 80 return ndef_ ? SummarizeNodeDef(*ndef_) 81 : strings::StrCat( 82 "[", SummarizeAttrsHelper(*this, StringPiece()), "]"); 83 } 84 85 string SummarizeNode(const Node& node) { return SummarizeNodeDef(node.def()); } 86 87 string SummarizeNodeDef(const NodeDef& node_def) { 88 string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "["); 89 strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device())); 90 strings::StrAppend(&ret, "]("); 91 92 // Output inputs, including control inputs, verbatim. 93 bool first = true; 94 for (const string& input : node_def.input()) { 95 if (!first) strings::StrAppend(&ret, ", "); 96 first = false; 97 strings::StrAppend(&ret, input); 98 } 99 strings::StrAppend(&ret, ")"); 100 return ret; 101 } 102 103 const AttrValue* AttrSlice::Find(StringPiece attr_name) const { 104 // Currently, the collection used for NodeDef::attr() (google::protobuf::Map) 105 // requires that the keys used for lookups have type 'const string&'. Because 106 // this method takes a StringPiece, it is necessary to allocate a temporary 107 // string, copy attr_name to it, and then use that temporary string for the 108 // lookup. This causes an excessive number of short-lived allocations, and for 109 // large graphs, this can be a significant cost. 110 // 111 // Because most nodes have a small number of attributes, a simple linear scan 112 // is generally more efficient than a hashed lookup. If google::protobuf::Map 113 // changes so that it supports efficient lookups using StringPiece instead of 114 // const string&, then this code could be changed to use attrs_->find() again. 115 116 for (const auto& attr : *attrs_) { 117 if (attr.first == attr_name) { 118 return &attr.second; 119 } 120 } 121 return nullptr; 122 } 123 124 Status AttrSlice::Find(StringPiece attr_name, 125 const AttrValue** attr_value) const { 126 *attr_value = Find(attr_name); 127 if (*attr_value != nullptr) { 128 return Status::OK(); 129 } 130 Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:"); 131 // Skip AttachDef for internal attrs since it is a little bit 132 // expensive and it is common for them to correctly not be included 133 // in a NodeDef. 134 if (!attr_name.starts_with("_") && ndef_ != nullptr) { 135 s = AttachDef(s, *ndef_); 136 } 137 return s; 138 } 139 140 bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const { 141 if (size() != other.size()) return false; 142 143 for (const auto& attr : *other.attrs_) { 144 auto iter = attrs_->find(attr.first); 145 if (iter == attrs_->end()) return false; 146 // TODO(irving): Comparing AttrValues by proto is slightly buggy, since 147 // TensorProto is a nonunique representation of Tensor. This bug will go 148 // away once AttrSlice switches over to NodeInfo. 149 iter->second.SerializeToString(&scratch->a); 150 attr.second.SerializeToString(&scratch->b); 151 if (scratch->a != scratch->b) return false; 152 } 153 return true; 154 } 155 156 // The ... is to allow the caller to inject some value validation code. Use 157 // just ; if no additional validation code is needed. 158 #define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \ 159 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, \ 160 TYPE* value) { \ 161 const AttrValue* attr_value; \ 162 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \ 163 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE)); \ 164 const auto& v = attr_value->FIELD(); \ 165 __VA_ARGS__; \ 166 *value = CAST; \ 167 return Status::OK(); \ 168 } \ 169 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, \ 170 std::vector<TYPE>* value) { \ 171 const AttrValue* attr_value; \ 172 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \ 173 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \ 174 for (const auto& v : attr_value->list().FIELD()) { \ 175 __VA_ARGS__; \ 176 value->APPEND_OP(CAST); \ 177 } \ 178 return Status::OK(); \ 179 } 180 181 #define DEFINE_GET_ATTR_SIMPLE(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \ 182 bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name, \ 183 TYPE* value) { \ 184 const AttrValue* attr_value = attrs.Find(attr_name); \ 185 if (attr_value == nullptr) { \ 186 return false; \ 187 } \ 188 Status s = AttrValueHasType(*attr_value, ATTR_TYPE); \ 189 if (!s.ok()) { \ 190 return false; \ 191 } \ 192 const auto& v = attr_value->FIELD(); \ 193 __VA_ARGS__; \ 194 *value = CAST; \ 195 return true; \ 196 } \ 197 bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name, \ 198 std::vector<TYPE>* value) { \ 199 const AttrValue* attr_value = attrs.Find(attr_name); \ 200 if (attr_value == nullptr) { \ 201 return false; \ 202 } \ 203 Status s = AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")"); \ 204 if (!s.ok()) { \ 205 return false; \ 206 } \ 207 for (const auto& v : attr_value->list().FIELD()) { \ 208 __VA_ARGS__; \ 209 value->APPEND_OP(CAST); \ 210 } \ 211 return true; \ 212 } 213 214 DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;) 215 DEFINE_GET_ATTR_SIMPLE(string, s, "string", emplace_back, v, ;) 216 DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;) 217 DEFINE_GET_ATTR(int32, i, "int", emplace_back, static_cast<int32>(v), 218 if (static_cast<int64>(static_cast<int32>(v)) != v) { 219 return errors::InvalidArgument("Attr ", attr_name, 220 " has value ", v, 221 " out of range for an int32"); 222 }) 223 DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;) 224 // std::vector<bool> specialization does not have emplace_back until 225 // c++14, so we have to use push_back (see 226 // http://en.cppreference.com/w/cpp/container/vector/emplace_back) 227 DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;) 228 DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v), 229 ;) 230 DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;) 231 DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v), 232 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));) 233 DEFINE_GET_ATTR(PartialTensorShape, shape, "shape", emplace_back, 234 PartialTensorShape(v), 235 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));) 236 DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t; 237 if (!t.FromProto(v)) { 238 return errors::InvalidArgument( 239 "Attr ", attr_name, " has value ", 240 ProtoShortDebugString(v), 241 " that can't be converted to a Tensor"); 242 }) 243 DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;); 244 #undef DEFINE_GET_ATTR 245 246 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) { 247 return node_def.attr().find(attr_name.ToString()) != node_def.attr().end(); 248 } 249 250 static const string& kEmptyString = *new string(); 251 252 const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) { 253 const AttrValue* attr_value = attrs.Find(attr_name); 254 if (attr_value == nullptr) { 255 return kEmptyString; 256 } 257 Status s = AttrValueHasType(*attr_value, "string"); 258 if (!s.ok()) { 259 return kEmptyString; 260 } 261 return attr_value->s(); 262 } 263 264 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, 265 DataTypeVector* value) { 266 const AttrValue* attr_value; 267 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); 268 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)")); 269 for (const auto& v : attr_value->list().type()) { 270 value->push_back(static_cast<DataType>(v)); 271 } 272 return Status::OK(); 273 } 274 275 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, 276 const TensorProto** value) { 277 const AttrValue* attr_value; 278 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); 279 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor")); 280 *value = &attr_value->tensor(); 281 return Status::OK(); 282 } 283 284 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, 285 const NameAttrList** value) { 286 const AttrValue* attr_value; 287 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); 288 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func")); 289 *value = &attr_value->func(); 290 return Status::OK(); 291 } 292 293 namespace { // Helper for InOutTypesForNode(). 294 295 Status AddArgToSig(const NodeDef& node_def, const OpDef::ArgDef& arg_def, 296 DataTypeVector* sig) { 297 const int original_size = sig->size(); 298 if (!arg_def.number_attr().empty()) { 299 // Same type repeated "repeats" times. 300 int32 repeats = -1; 301 TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.number_attr(), &repeats)); 302 if (repeats < 0) { 303 return errors::InvalidArgument("Value for number_attr() ", repeats, 304 " < 0"); 305 } 306 307 if (!arg_def.type_attr().empty()) { 308 DataType dtype; 309 TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.type_attr(), &dtype)); 310 for (int i = 0; i < repeats; ++i) { 311 sig->push_back(dtype); 312 } 313 } else if (arg_def.type() != DT_INVALID) { 314 for (int i = 0; i < repeats; ++i) { 315 sig->push_back(arg_def.type()); 316 } 317 } else { 318 return errors::InvalidArgument("Missing type or type_attr field in ", 319 ProtoShortDebugString(arg_def)); 320 } 321 } else if (!arg_def.type_attr().empty()) { 322 const AttrValue* attr_value; 323 TF_RETURN_IF_ERROR( 324 AttrSlice(node_def).Find(arg_def.type_attr(), &attr_value)); 325 sig->push_back(attr_value->type()); 326 } else if (!arg_def.type_list_attr().empty()) { 327 const AttrValue* attr_value; 328 TF_RETURN_IF_ERROR( 329 AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value)); 330 for (int dtype : attr_value->list().type()) { 331 sig->push_back(static_cast<DataType>(dtype)); 332 } 333 } else if (arg_def.type() != DT_INVALID) { 334 sig->push_back(arg_def.type()); 335 } else { 336 return errors::InvalidArgument("No type fields in ", 337 ProtoShortDebugString(arg_def)); 338 } 339 if (arg_def.is_ref()) { 340 // For all types that were added by this function call, make them refs. 341 for (size_t i = original_size; i < sig->size(); ++i) { 342 (*sig)[i] = MakeRefType((*sig)[i]); 343 } 344 } 345 return Status::OK(); 346 } 347 348 } // namespace 349 350 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def, 351 int input_port, DataType* input_type) { 352 DataTypeVector input_types; 353 for (const auto& arg : op_def.input_arg()) { 354 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types)); 355 if (input_types.size() > input_port) { 356 const DataType dtype = input_types[input_port]; 357 *input_type = dtype; 358 return Status::OK(); 359 } 360 } 361 return errors::InvalidArgument("Input ", input_port, " not found for node ", 362 node_def.name()); 363 } 364 365 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, 366 int output_port, DataType* output_type) { 367 DataTypeVector output_types; 368 for (const auto& arg : op_def.output_arg()) { 369 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &output_types)); 370 if (output_types.size() > output_port) { 371 const DataType dtype = output_types[output_port]; 372 *output_type = dtype; 373 return Status::OK(); 374 } 375 } 376 return errors::InvalidArgument("Output ", output_port, " not found for node ", 377 node_def.name()); 378 } 379 380 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, 381 DataTypeVector* inputs, DataTypeVector* outputs) { 382 for (const auto& arg : op_def.input_arg()) { 383 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs)); 384 } 385 for (const auto& arg : op_def.output_arg()) { 386 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs)); 387 } 388 return Status::OK(); 389 } 390 391 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { 392 if (node_def.op() != op_def.name()) { 393 return errors::InvalidArgument("NodeDef op '", node_def.op(), 394 "' does not match ", SummarizeOpDef(op_def), 395 "; NodeDef: ", SummarizeNodeDef(node_def)); 396 } 397 398 bool seen_control = false; 399 size_t num_inputs = 0; 400 // TODO(josh11b): Unify the input field validation. 401 for (const string& input : node_def.input()) { 402 if (StringPiece(input).starts_with("^")) { 403 seen_control = true; 404 if (input.find(':') != string::npos) { 405 return errors::InvalidArgument( 406 "Control input '", input, 407 "' must not have ':' in NodeDef: ", SummarizeNodeDef(node_def)); 408 } 409 } else if (seen_control) { 410 return errors::InvalidArgument( 411 "Non-control input '", input, 412 "' after control input in NodeDef: ", SummarizeNodeDef(node_def)); 413 } else { 414 ++num_inputs; 415 } 416 } 417 418 std::unordered_map<string, const OpDef::AttrDef*> op_attrs; 419 for (const auto& attr : op_def.attr()) { 420 if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) { 421 return errors::InvalidArgument("OpDef has duplicate attr name '", 422 attr.name(), 423 "': ", SummarizeOpDef(op_def)); 424 } 425 } 426 for (const auto& attr : node_def.attr()) { 427 // Allow internal optional attributes with names starting with "_". 428 if (StringPiece(attr.first).starts_with("_")) { 429 continue; 430 } 431 auto iter = op_attrs.find(attr.first); 432 if (iter == op_attrs.end()) { 433 // A common cause of this error is that TensorFlow has made a 434 // backwards-compatible change to the NodeDef (e.g., adding a 435 // new attr with a default value), but the binary consuming the 436 // NodeDef does not know about the new attribute; the solution 437 // in these cases is to ensure that the binary consuming the 438 // NodeDef is built with a version of TensorFlow no earlier than 439 // the binary producing it. 440 return errors::InvalidArgument( 441 "NodeDef mentions attr '", attr.first, "' not in ", 442 SummarizeOpDef(op_def), "; NodeDef: ", SummarizeNodeDef(node_def), 443 ". (Check whether your GraphDef-interpreting binary is up to date " 444 "with your GraphDef-generating binary.)."); 445 } 446 TF_RETURN_WITH_CONTEXT_IF_ERROR( 447 ValidateAttrValue(attr.second, *iter->second), 448 "; NodeDef: ", SummarizeNodeDef(node_def), "; ", 449 SummarizeOpDef(op_def)); 450 // Keep track of which attr names have (not) been found in the NodeDef. 451 op_attrs.erase(iter); 452 } 453 454 // Were all attrs in the OpDef found in the NodeDef? 455 if (!op_attrs.empty()) { 456 string attrs; 457 for (const auto& attr_pair : op_attrs) { 458 if (!attrs.empty()) strings::StrAppend(&attrs, "', '"); 459 strings::StrAppend(&attrs, attr_pair.first); 460 } 461 return errors::InvalidArgument("NodeDef missing attr", 462 op_attrs.size() == 1 ? " '" : "s '", attrs, 463 "' from ", SummarizeOpDef(op_def), 464 "; NodeDef: ", SummarizeNodeDef(node_def)); 465 } 466 467 // Validate the number of inputs. 468 DataTypeVector inputs, outputs; 469 TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs)); 470 471 if (num_inputs != inputs.size()) { 472 return errors::InvalidArgument( 473 "NodeDef expected inputs '", DataTypeVectorString(inputs), 474 "' do not match ", num_inputs, " inputs specified; ", 475 SummarizeOpDef(op_def), "; NodeDef: ", SummarizeNodeDef(node_def)); 476 } 477 478 return Status::OK(); 479 } 480 481 namespace { // Helpers for NameRangesForNode() 482 483 Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def, 484 const OpDef& op_def, int* num) { 485 if (!arg_def.number_attr().empty()) { 486 // Same type repeated "num" times. 487 return GetNodeAttr(node_def, arg_def.number_attr(), num); 488 } else if (!arg_def.type_list_attr().empty()) { 489 const AttrValue* attr_value; 490 TF_RETURN_IF_ERROR( 491 AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value)); 492 *num = attr_value->list().type_size(); 493 } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) { 494 *num = 1; 495 } else { 496 return errors::InvalidArgument( 497 "Argument '", arg_def.name(), 498 "' incorrectly specified in op definition: ", SummarizeOpDef(op_def)); 499 } 500 return Status::OK(); 501 } 502 503 Status NameRangesHelper(const NodeDef& node_def, 504 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, 505 const OpDef& op_def, NameRangeMap* result) { 506 int start = 0; 507 int num; 508 for (const auto& arg : args) { 509 TF_RETURN_IF_ERROR(ComputeArgRange(node_def, arg, op_def, &num)); 510 (*result)[arg.name()] = std::make_pair(start, start + num); 511 start += num; 512 } 513 return Status::OK(); 514 } 515 516 } // namespace 517 518 Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def, 519 NameRangeMap* inputs, NameRangeMap* outputs) { 520 if (inputs != nullptr) { 521 TF_RETURN_IF_ERROR( 522 NameRangesHelper(node_def, op_def.input_arg(), op_def, inputs)); 523 } 524 if (outputs != nullptr) { 525 return NameRangesHelper(node_def, op_def.output_arg(), op_def, outputs); 526 } 527 return Status::OK(); 528 } 529 530 Status NameRangesForNode(const Node& node, const OpDef& op_def, 531 NameRangeMap* inputs, NameRangeMap* outputs) { 532 return NameRangesForNode(node.def(), op_def, inputs, outputs); 533 } 534 535 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) { 536 for (const auto& attr_def : op_def.attr()) { 537 AttrSlice attrs(*node_def); 538 if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) { 539 AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def); 540 } 541 } 542 } 543 544 namespace { 545 546 using ::tensorflow::strings::Scanner; 547 548 bool IsValidOpName(StringPiece sp) { 549 return Scanner(sp) 550 .One(Scanner::LETTER_DIGIT_DOT) 551 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) 552 .Eos() 553 .GetResult(); 554 } 555 556 bool IsValidDataInputName(StringPiece sp) { 557 // Data inputs are op_name, op_name:0, or op_name:12345. 558 Scanner scan(sp); 559 scan.One(Scanner::LETTER_DIGIT_DOT) 560 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE); 561 if (scan.Peek() == ':') { 562 scan.OneLiteral(":"); 563 if (scan.Peek() == '0') { 564 scan.OneLiteral("0"); // :0 565 } else { 566 scan.Many(Scanner::DIGIT); // :[1-9][0-9]* 567 } 568 } 569 scan.Eos(); 570 571 return scan.GetResult(); 572 } 573 574 bool IsValidControlInputName(StringPiece sp) { 575 return Scanner(sp) 576 .OneLiteral("^") 577 .One(Scanner::LETTER_DIGIT_DOT) 578 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) 579 .Eos() 580 .GetResult(); 581 } 582 583 } // namespace 584 585 Status ValidateOpInput(const string& input_name, bool* is_control_input) { 586 *is_control_input = false; 587 if (IsValidDataInputName(input_name)) { 588 return Status::OK(); 589 } else if (IsValidControlInputName(input_name)) { 590 *is_control_input = true; 591 return Status::OK(); 592 } else { 593 return errors::InvalidArgument("Illegal op input name '", input_name, "'"); 594 } 595 } 596 597 Status ValidateOpName(const string& op_name) { 598 if (IsValidOpName(op_name)) { 599 return Status::OK(); 600 } else { 601 return errors::InvalidArgument("Illegal op name '", op_name, "'"); 602 } 603 } 604 605 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) { 606 Status s = ValidateOpName(node_def.name()); 607 if (!s.ok()) { 608 return AttachDef(s, node_def); 609 } 610 bool in_control_inputs = false; 611 for (const string& input_name : node_def.input()) { 612 bool is_control_input; 613 s = ValidateOpInput(input_name, &is_control_input); 614 if (!s.ok()) { 615 return AttachDef(s, node_def); 616 } 617 618 if (in_control_inputs && !is_control_input) { 619 return AttachDef(errors::InvalidArgument( 620 "All control inputs must follow all data inputs"), 621 node_def); 622 } 623 in_control_inputs = is_control_input; 624 } 625 return Status::OK(); 626 } 627 628 Status AttachDef(const Status& status, const NodeDef& node_def) { 629 Status ret = status; 630 errors::AppendToMessage( 631 &ret, strings::StrCat(" [[Node: ", SummarizeNodeDef(node_def), "]]")); 632 return ret; 633 } 634 635 Status AttachDef(const Status& status, const Node& node) { 636 return AttachDef(status, node.def()); 637 } 638 639 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) { 640 node_def->mutable_attr()->insert( 641 AttrValueMap::value_type(name.ToString(), value)); 642 } 643 644 #define ADD_NODE_ATTR(T) \ 645 void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \ 646 AttrValue attr_value; \ 647 SetAttrValue(value, &attr_value); \ 648 AddNodeAttr(name, attr_value, node_def); \ 649 } 650 ADD_NODE_ATTR(StringPiece) 651 ADD_NODE_ATTR(const char*) 652 ADD_NODE_ATTR(int32) 653 ADD_NODE_ATTR(int64) 654 ADD_NODE_ATTR(float) 655 ADD_NODE_ATTR(double) 656 ADD_NODE_ATTR(bool) 657 ADD_NODE_ATTR(DataType) 658 ADD_NODE_ATTR(const PartialTensorShape&) 659 ADD_NODE_ATTR(const Tensor&) 660 ADD_NODE_ATTR(const TensorProto&) 661 ADD_NODE_ATTR(const NameAttrList&) 662 ADD_NODE_ATTR(gtl::ArraySlice<StringPiece>) 663 ADD_NODE_ATTR(gtl::ArraySlice<const char*>) 664 ADD_NODE_ATTR(gtl::ArraySlice<string>) 665 ADD_NODE_ATTR(gtl::ArraySlice<int32>) 666 ADD_NODE_ATTR(gtl::ArraySlice<int64>) 667 ADD_NODE_ATTR(gtl::ArraySlice<float>) 668 ADD_NODE_ATTR(gtl::ArraySlice<bool>) 669 ADD_NODE_ATTR(const std::vector<bool>&) 670 ADD_NODE_ATTR(gtl::ArraySlice<DataType>) 671 ADD_NODE_ATTR(gtl::ArraySlice<TensorShape>) 672 ADD_NODE_ATTR(gtl::ArraySlice<PartialTensorShape>) 673 ADD_NODE_ATTR(gtl::ArraySlice<TensorShapeProto>) 674 ADD_NODE_ATTR(gtl::ArraySlice<Tensor>) 675 ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>) 676 #undef ADD_NODE_ATTR 677 678 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) { 679 map->insert(AttrValueMap::value_type(name.ToString(), value)); 680 } 681 682 #define ADD_ATTR(T) \ 683 void AddAttr(StringPiece name, T value, AttrValueMap* map) { \ 684 AttrValue attr_value; \ 685 SetAttrValue(value, &attr_value); \ 686 AddAttr(name, attr_value, map); \ 687 } 688 ADD_ATTR(bool) 689 #undef ADD_ATTR 690 691 } // namespace tensorflow 692