Home | History | Annotate | Download | only in toco

Lines Matching refs:node

66 bool HasAttr(const NodeDef& node, const string& attr_name) {
67 return node.attr().count(attr_name) > 0;
70 const string& GetStringAttr(const NodeDef& node, const string& attr_name) {
71 CHECK(HasAttr(node, attr_name));
72 const auto& attr = node.attr().at(attr_name);
77 int GetIntAttr(const NodeDef& node, const string& attr_name) {
78 CHECK(HasAttr(node, attr_name)) << attr_name << " not found in:\n"
79 << node.DebugString();
80 const auto& attr = node.attr().at(attr_name);
85 float GetFloatAttr(const NodeDef& node, const string& attr_name) {
86 CHECK(HasAttr(node, attr_name));
87 const auto& attr = node.attr().at(attr_name);
92 bool GetBoolAttr(const NodeDef& node, const string& attr_name) {
93 CHECK(HasAttr(node, attr_name));
94 const auto& attr = node.attr().at(attr_name);
99 tensorflow::DataType GetDataTypeAttr(const NodeDef& node,
101 CHECK(HasAttr(node, attr_name));
102 const auto& attr = node.attr().at(attr_name);
107 const TensorShapeProto& GetShapeAttr(const NodeDef& node,
109 CHECK(HasAttr(node, attr_name));
110 const auto& attr = node.attr().at(attr_name);
115 const TensorProto& GetTensorAttr(const NodeDef& node, const string& attr_name) {
116 CHECK(HasAttr(node, attr_name));
117 const auto& attr = node.attr().at(attr_name);
122 const AttrValue::ListValue& GetListAttr(const NodeDef& node,
124 CHECK(HasAttr(node, attr_name));
125 const auto& attr = node.attr().at(attr_name);
296 // Count the number of inputs of a given node. If
299 int GetInputsCount(const NodeDef& node,
302 for (size_t i = 0; i < node.input_size(); ++i) {
303 if (node.input(i)[0] == '^') {
307 return node.input_size();
309 return node.input_size();
313 void CheckInputsCount(const NodeDef& node,
316 QCHECK_EQ(GetInputsCount(node, tf_import_flags), expected_input_count)
317 << node.op() << " node expects " << expected_input_count
318 << " input(s) other than control dependencies: " << node.DebugString();
321 void ConvertConstOperator(const NodeDef& node,
324 CHECK_EQ(node.op(), "Const");
325 const auto& tensor = GetTensorAttr(node, "value");
326 const auto dtype = GetDataTypeAttr(node, "dtype");
328 auto& array = model->GetOrCreateArray(node.name());
360 void ConvertConvOperator(const NodeDef& node,
363 CHECK_EQ(node.op(), "Conv2D");
364 CheckInputsCount(node, tf_import_flags, 2);
368 if (node.attr().count("data_format")) {
369 CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
371 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
373 const auto& input_name = node.input(0);
374 const auto& weights_name = node.input(1);
395 conv->outputs = {node.name()};
396 const auto& strides = GetListAttr(node, "strides");
402 const auto& padding = GetStringAttr(node, "padding");
413 void ConvertDepthwiseConvOperator(const NodeDef& node,
416 CHECK_EQ(node.op(), "DepthwiseConv2dNative");
417 CheckInputsCount(node, tf_import_flags, 2);
421 if (node.attr().count("data_format")) {
422 CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
424 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
426 node.input(0);
427 const auto& weights_name = node.input(1);
448 conv->outputs = {node.name()};
449 const auto& strides = GetListAttr(node, "strides");
455 const auto& padding = GetStringAttr(node, "padding");
466 void ConvertDepthToSpaceOperator(const NodeDef& node,
469 CHECK_EQ(node.op(), "DepthToSpace");
470 CheckInputsCount(node, tf_import_flags, 1);
472 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
474 op->inputs.push_back(node.input(0));
475 op->outputs.push_back(node.name());
476 op->block_size = GetIntAttr(node, "block_size");
481 void ConvertSpaceToDepthOperator(const NodeDef& node,
484 CHECK_EQ(node.op(), "SpaceToDepth");
485 CheckInputsCount(node, tf_import_flags, 1);
487 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
489 op->inputs.push_back(node.input(0));
490 op->outputs.push_back(node.name());
491 op->block_size = GetIntAttr(node, "block_size");
496 void ConvertBiasAddOperator(const NodeDef& node,
499 CHECK_EQ(node.op(), "BiasAdd");
500 CheckInputsCount(node, tf_import_flags, 2);
502 const auto& input_name = node.input(0);
503 const auto& bias_name = node.input(1);
504 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
508 biasadd->outputs.push_back(node.name());
512 void ConvertReluOperator(const NodeDef& node,
515 CHECK_EQ(node.op(), "Relu");
516 CheckInputsCount(node, tf_import_flags, 1);
517 const auto& input_name = node.input(0);
520 relu->outputs.push_back(node.name());
524 void ConvertRelu6Operator(const NodeDef& node,
527 CHECK_EQ(node.op(), "Relu6");
528 CheckInputsCount(node, tf_import_flags, 1);
530 const auto& input_name = node.input(0);
533 op->outputs.push_back(node.name());
537 void ConvertLogisticOperator(const NodeDef& node,
540 CHECK_EQ(node.op(), "Sigmoid");
541 CheckInputsCount(node, tf_import_flags, 1);
543 const auto& input_name = node.input(0);
546 op->outputs.push_back(node.name());
550 void ConvertTanhOperator(const NodeDef& node,
553 CHECK_EQ(node.op(), "Tanh");
554 CheckInputsCount(node, tf_import_flags, 1);
556 const auto& input_name = node.input(0);
559 op->outputs.push_back(node.name());
563 void ConvertDivOperator(const NodeDef& node,
566 CHECK(node.op() == "Div" || node.op() == "RealDiv");
567 CheckInputsCount(node, tf_import_flags, 2);
569 op->inputs.push_back(node.input(0));
570 op->inputs.push_back(node.input(1));
571 op->outputs.push_back(node.name());
575 void ConvertIdentityOperator(const NodeDef& node,
578 CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
579 node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient");
586 QCHECK_GE(node.input_size(), 1)
587 << node.op()
588 << " node expects at least 1 input other than control dependencies: "
589 << node.DebugString();
590 const auto& input_name = node.input(0);
592 op->outputs.push_back(node.name());
597 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
599 CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs");
600 CheckInputsCount(node, tf_import_flags, 1);
602 op->inputs.push_back(node.input(0));
605 minmax.min = GetFloatAttr(node, "min");
606 minmax.max = GetFloatAttr(node, "max");
607 op->outputs.push_back(node.name());
612 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
614 CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars");
615 const int num_inputs = GetInputsCount(node, tf_import_flags);
617 << "FakeQuantWithMinMaxVars node expects 3 or 4 inputs other than "
619 << node.DebugString();
622 op->inputs.push_back(node.input(i));
624 op->outputs.push_back(node.name());
628 void ConvertNegOperator(const NodeDef& node,
631 CHECK_EQ(node.op(), "Neg");
632 CheckInputsCount(node, tf_import_flags, 1);
634 op->inputs.push_back(node.input(0));
635 op->outputs.push_back(node.name());
639 void ConvertRsqrtOperator(const NodeDef& node,
642 CHECK_EQ(node.op(), "Rsqrt");
643 CheckInputsCount(node, tf_import_flags, 1);
645 op->inputs.push_back(node.input(0));
646 op->outputs.push_back(node.name());
650 void ConvertSqrtOperator(const NodeDef& node,
653 CHECK_EQ(node.op(), "Sqrt");
654 CheckInputsCount(node, tf_import_flags, 1);
656 op->inputs.push_back(node.input(0));
657 op->outputs.push_back(node.name());
661 void ConvertSqueezeOperator(const NodeDef& node,
664 CHECK_EQ(node.op(), "Squeeze");
665 CheckInputsCount(node, tf_import_flags, 1);
667 op->inputs.push_back(node.input(0));
668 op->outputs.push_back(node.name());
670 const auto& squeeze_dims = GetListAttr(node, "squeeze_dims");
678 void ConvertSquareOperator(const NodeDef& node,
681 CHECK_EQ(node.op(), "Square");
682 CheckInputsCount(node, tf_import_flags, 1);
684 op->inputs.push_back(node.input(0));
685 op->outputs.push_back(node.name());
689 void ConvertAddOperator(const NodeDef& node,
692 CHECK_EQ(node.op(), "Add");
693 CheckInputsCount(node, tf_import_flags, 2);
695 op->inputs.push_back(node.input(0));
696 op->inputs.push_back(node.input(1));
697 op->outputs.push_back(node.name());
701 void ConvertAddNOperator(const NodeDef& node,
704 CHECK_EQ(node.op(), "AddN");
705 const int num_inputs = GetInputsCount(node, tf_import_flags);
708 op->inputs.push_back(node.input(i));
710 op->outputs.push_back(node.name());
714 void ConvertMulOperator(const NodeDef& node,
717 CHECK_EQ(node.op(), "Mul");
718 CheckInputsCount(node, tf_import_flags, 2);
720 op->inputs.push_back(node.input(0));
721 op->inputs.push_back(node.input(1));
722 op->outputs.push_back(node.name());
726 void ConvertSubOperator(const NodeDef& node,
729 CHECK_EQ(node.op(), "Sub");
730 CheckInputsCount(node, tf_import_flags, 2);
732 op->inputs.push_back(node.input(0));
733 op->inputs.push_back(node.input(1));
734 op->outputs.push_back(node.name());
738 void ConvertSumOperator(const NodeDef& node,
741 CHECK_EQ(node.op(), "Sum");
742 CheckInputsCount(node, tf_import_flags, 2);
744 op->inputs.push_back(node.input(0));
745 op->inputs.push_back(node.input(1));
746 op->outputs.push_back(node.name());
748 if (HasAttr(node, "keep_dims")) {
749 op->keep_dims = GetBoolAttr(node, "keep_dims");
753 void ConvertTileOperator(const NodeDef& node,
756 CHECK_EQ(node.op(), "Tile");
757 CheckInputsCount(node, tf_import_flags, 2);
759 op->inputs.push_back(node.input(0));
760 op->inputs.push_back(node.input(1));
761 op->outputs.push_back(node.name());
765 void ConvertSliceOperator(const NodeDef& node,
768 CHECK_EQ(node.op(), "Slice");
769 CheckInputsCount(node, tf_import_flags, 3);
772 op->inputs.push_back(node.input(i));
774 op->outputs.push_back(node.name());
778 void ConvertPadOperator(const NodeDef& node,
781 CHECK_EQ(node.op(), "Pad");
782 CheckInputsCount(node, tf_import_flags, 2);
784 op->inputs.push_back(node.input(0));
785 op->inputs.push_back(node.input(1));
786 op->outputs.push_back(node.name());
790 void ConvertShapeOperator(const NodeDef& node,
793 CHECK_EQ(node.op(), "Shape");
794 CheckInputsCount(node, tf_import_flags, 1);
796 op->inputs.push_back(node.input(0));
797 op->outputs.push_back(node.name());
801 void ConvertSplitOperator(const NodeDef& node,
804 CHECK_EQ(node.op(), "Split");
805 CheckInputsCount(node, tf_import_flags, 2);
807 op->inputs.push_back(node.input(0));
808 op->inputs.push_back(node.input(1));
809 const int num_split = GetIntAttr(node, "num_split");
810 op->outputs.push_back(node.name());
812 op->outputs.push_back(absl::StrCat(node.name(), ":", i));
818 void ConvertMergeOperator(const NodeDef& node,
821 CHECK_EQ(node.op(), "Merge");
822 CheckInputsCount(node, tf_import_flags, 2);
824 op->inputs.push_back(node.input(0));
825 op->inputs.push_back(node.input(1));
826 op->outputs.push_back(node.name());
830 void ConvertSwitchOperator(const NodeDef& node,
833 CHECK_EQ(node.op(), "Switch");
834 CheckInputsCount(node, tf_import_flags, 2);
836 op->inputs.push_back(node.input(0));
837 op->inputs.push_back(node.input(1));
838 op->outputs.push_back(node.name());
840 op->outputs.push_back(node.name() + ":1");
844 void ConvertSoftmaxOperator(const NodeDef& node,
847 CHECK_EQ(node.op(), "Softmax");
848 CheckInputsCount(node, tf_import_flags, 1);
849 const auto& input_name = node.input(0);
852 softmax->outputs.push_back(node.name());
854 CHECK(!node.attr().count("beta")); // Stab in the dark, just in case.
859 void ConvertLogSoftmaxOperator(const NodeDef& node,
862 CHECK_EQ(node.op(), "LogSoftmax");
863 CheckInputsCount(node, tf_import_flags, 1);
864 const auto& input_name = node.input(0);
867 log_softmax->outputs.push_back(node.name());
871 void ConvertLRNOperator(const NodeDef& node,
874 CHECK_EQ(node.op(), "LRN");
875 CheckInputsCount(node, tf_import_flags, 1);
876 const auto& input_name = node.input(0);
879 lrn->outputs.push_back(node.name());
880 lrn->range = GetIntAttr(node, "depth_radius");
881 lrn->bias = GetFloatAttr(node, "bias");
882 lrn->alpha = GetFloatAttr(node, "alpha");
883 lrn->beta = GetFloatAttr(node, "beta");
887 void ConvertMaxPoolOperator(const NodeDef& node,
890 CHECK_EQ(node.op(), "MaxPool");
891 CheckInputsCount(node, tf_import_flags, 1);
892 const auto& input_name = node.input(0);
895 if (node.attr().count("data_format")) {
896 CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
898 if (HasAttr(node, "T")) {
899 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
905 maxpool->outputs.push_back(node.name());
906 const auto& strides = GetListAttr(node, "strides");
912 const auto& ksize = GetListAttr(node, "ksize");
918 const auto& padding = GetStringAttr(node, "padding");
929 void ConvertAvgPoolOperator(const NodeDef& node,
932 CHECK_EQ(node.op(), "AvgPool");
933 CheckInputsCount(node, tf_import_flags, 1);
934 const auto& input_name = node.input(0);
937 if (node.attr().count("data_format")) {
938 CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
940 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
943 avgpool->outputs.push_back(node.name());
944 const auto& strides = GetListAttr(node, "strides");
950 const auto& ksize = GetListAttr(node, "ksize");
956 const auto& padding = GetStringAttr(node, "padding");
967 void ConvertReshapeOperator(const NodeDef& node,
970 CHECK_EQ(node.op(), "Reshape");
971 CheckInputsCount(node, tf_import_flags, 2);
973 op->inputs.push_back(node.input(0));
974 op->inputs.push_back(node.input(1));
975 op->outputs.push_back(node.name());
979 void ConvertBatchMatMulOperator(const NodeDef& node,
982 CheckInputsCount(node, tf_import_flags, 2);
985 CHECK(!HasAttr(node, "adj_a") || (GetBoolAttr(node, "adj_a") == false));
986 CHECK(!HasAttr(node, "adj_b") || (GetBoolAttr(node, "adj_b") == false));
989 batch_matmul->inputs = {node.input(0), node.input(1)};
990 batch_matmul->outputs = {node.name()};
994 void ConvertMatMulOperator(const NodeDef& node,
997 CheckInputsCount(node, tf_import_flags, 2);
1001 CHECK_EQ(GetBoolAttr(node, "transpose_a"), false);
1002 CHECK_EQ(GetBoolAttr(node, "transpose_b"), false);
1003 CHECK(!HasAttr(node, "adjoint_a") ||
1004 (GetBoolAttr(node, "adjoint_a") == false));
1005 CHECK(!HasAttr(node, "adjoint_b") ||
1006 (GetBoolAttr(node, "adjoint_b") == false));
1009 matmul->inputs = {node.input(0), node.input(1)};
1010 matmul->outputs = {node.name()};
1014 void ConvertConcatOperator(const NodeDef& node,
1018 if (node.op() == "Concat") {
1020 } else if (node.op() == "ConcatV2") {
1025 const int num_inputs = GetInputsCount(node, tf_import_flags);
1027 << node.op()
1028 << " node expects at least 2 inputs other than control dependencies: "
1029 << node.DebugString();
1030 CHECK_EQ(num_inputs, 1 + GetIntAttr(node, "N"));
1032 op->inputs.push_back(node.input(i));
1034 op->outputs.push_back(node.name());
1038 void ConvertAllOperator(const NodeDef& node,
1041 CHECK_EQ(node.op(), "All");
1043 const int num_inputs = GetInputsCount(node, tf_import_flags);
1045 op->inputs.push_back(node.input(i));
1047 op->outputs.push_back(node.name());
1051 void ConvertAssertOperator(const NodeDef& node,
1054 CHECK_EQ(node.op(), "Assert");
1056 const int num_inputs = GetInputsCount(node, tf_import_flags);
1058 op->inputs.push_back(node.input(i));
1060 op->outputs.push_back(node.name());
1064 void ConvertLessOperator(const NodeDef& node,
1067 CHECK_EQ(node.op(), "Less");
1069 const int num_inputs = GetInputsCount(node, tf_import_flags);
1071 op->inputs.push_back(node.input(i));
1073 op->outputs.push_back(node.name());
1077 void ConvertLessEqualOperator(const NodeDef& node,
1080 CHECK_EQ(node.op(), "LessEqual");
1082 const int num_inputs = GetInputsCount(node, tf_import_flags);
1084 op->inputs.push_back(node.input(i));
1086 op->outputs.push_back(node.name());
1090 void ConvertGreaterOperator(const NodeDef& node,
1093 CHECK_EQ(node.op(), "Greater");
1095 const int num_inputs = GetInputsCount(node, tf_import_flags);
1097 op->inputs.push_back(node.input(i));
1099 op->outputs.push_back(node.name());
1103 void ConvertGreaterEqualOperator(const NodeDef& node,
1106 CHECK_EQ(node.op(), "GreaterEqual");
1108 const int num_inputs = GetInputsCount(node, tf_import_flags);
1110 op->inputs.push_back(node.input(i));
1112 op->outputs.push_back(node.name());
1116 void ConvertMaxOperator(const NodeDef& node,
1119 CHECK_EQ(node.op(), "Max");
1120 CheckInputsCount(node, tf_import_flags, 2);
1122 op->inputs.push_back(node.input(0));
1123 op->inputs.push_back(node.input(1));
1124 op->outputs.push_back(node.name());
1126 if (HasAttr(node, "keep_dims")) {
1127 op->keep_dims = GetBoolAttr(node, "keep_dims");
1131 void ConvertMinOperator(const NodeDef& node,
1134 CHECK_EQ(node.op(), "Min");
1135 CheckInputsCount(node, tf_import_flags, 2);
1137 op->inputs.push_back(node.input(0));
1138 op->inputs.push_back(node.input(1));
1139 op->outputs.push_back(node.name());
1141 if (HasAttr(node, "keep_dims")) {
1142 op->keep_dims = GetBoolAttr(node, "keep_dims");
1146 void ConvertMaximumOperator(const NodeDef& node,
1149 CHECK_EQ(node.op(), "Maximum");
1150 CheckInputsCount(node, tf_import_flags, 2);
1152 op->inputs.push_back(node.input(0));
1153 op->inputs.push_back(node.input(1));
1154 op->outputs.push_back(node.name());
1158 void ConvertMinimumOperator(const NodeDef& node,
1161 CHECK_EQ(node.op(), "Minimum");
1162 CheckInputsCount(node, tf_import_flags, 2);
1164 op->inputs.push_back(node.input(0));
1165 op->inputs.push_back(node.input(1));
1166 op->outputs.push_back(node.name());
1170 void ConvertUnsupportedOperator(const NodeDef& node,
1173 LOG(INFO) << "Converting unsupported operation: " << node.op();
1175 const int num_inputs = GetInputsCount(node, tf_import_flags);
1177 op->inputs.push_back(node.input(i));
1179 op->outputs.push_back(node.name());
1180 op->tensorflow_op = node.op();
1181 node.SerializeToString(&op->tensorflow_node_def);
1183 if (HasAttr(node, "_output_quantized")) {
1184 op->quantized = GetBoolAttr(node, "_output_quantized");
1186 if (HasAttr(node, "_output_types")) {
1187 const auto& output_types = GetListAttr(node, "_output_types");
1194 void ConvertStridedSliceOperator(const NodeDef& node,
1197 CHECK_EQ(node.op(), "StridedSlice");
1200 CheckInputsCount(node, tf_import_flags, 4);
1203 for (const auto& input : node.input()) {
1206 op->outputs.push_back(node.name());
1208 op->begin_mask = GetIntAttr(node, "begin_mask");
1209 op->ellipsis_mask = GetIntAttr(node, "ellipsis_mask");
1210 op->end_mask = GetIntAttr(node, "end_mask");
1211 op->new_axis_mask = GetIntAttr(node, "new_axis_mask");
1212 op->shrink_axis_mask = GetIntAttr(node, "shrink_axis_mask");
1216 void ConvertPlaceholderOperator(const NodeDef& node,
1219 CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput");
1220 if (node.op() == "Placeholder") {
1221 CheckInputsCount(node, tf_import_flags, 0);
1223 auto& array = model->GetOrCreateArray(node.name());
1224 if (node.attr().count("dtype")) {
1225 array.data_type = ConvertDataType(GetDataTypeAttr(node, "dtype"));
1227 if (node.attr().count("shape")) {
1228 const auto& shape = GetShapeAttr(node, "shape");
1248 void ConvertNoOpOperator(const NodeDef& node,
1252 void ConvertCastOperator(const NodeDef& node,
1255 CHECK_EQ(node.op(), "Cast");
1256 CheckInputsCount(node, tf_import_flags, 1);
1257 const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT");
1258 const auto tf_dst_dtype = GetDataTypeAttr(node, "DstT");
1262 op->inputs.push_back(node.input(0));
1263 op->outputs.push_back(node.name());
1267 void ConvertFloorOperator(const NodeDef& node,
1270 CHECK_EQ(node.op(), "Floor");
1271 CheckInputsCount(node, tf_import_flags, 1);
1272 const auto data_type = GetDataTypeAttr(node, "T");
1275 op->inputs.push_back(node.input(0));
1276 op->outputs.push_back(node.name());
1280 void ConvertGatherOperator(const NodeDef& node,
1283 CHECK_EQ(node.op(), "Gather");
1284 CheckInputsCount(node, tf_import_flags, 2);
1285 const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
1288 op->inputs.push_back(node.input(0));
1289 op->inputs.push_back(node.input(1));
1290 op->outputs.push_back(node.name());
1294 void ConvertArgMaxOperator(const NodeDef& node,
1297 CHECK_EQ(node.op(), "ArgMax");
1298 CheckInputsCount(node, tf_import_flags, 2);
1299 const auto axis_data_type = GetDataTypeAttr(node, "Tidx");
1300 const auto output_type = GetDataTypeAttr(node, "output_type");
1305 op->inputs.push_back(node.input(0));
1306 op->inputs.push_back(node.input(1));
1307 op->outputs.push_back(node.name());
1311 void ConvertResizeBilinearOperator(const NodeDef& node,
1314 CHECK_EQ(node.op(), "ResizeBilinear");
1315 CheckInputsCount(node, tf_import_flags, 2);
1319 if (HasAttr(node, "align_corners")) {
1320 op->align_corners = GetBoolAttr(node, "align_corners");
1323 op->inputs.push_back(node.input(0));
1324 op->inputs.push_back(node.input(1));
1325 op->outputs.push_back(node.name());
1330 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1332 CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization");
1333 CheckInputsCount(node, tf_import_flags, 5);
1337 // CHECK_EQ(GetFloatAttr(node, "variance_epsilon"), 0.001f);
1339 string multiplier = node.name() + "_mul";
1340 if (GetBoolAttr(node, "scale_after_normalization")) {
1345 string rsqrt = node.name() + "_rsqrt";
1348 rsqrt_op->inputs.push_back(node.input(2));
1354 mul_op->inputs.push_back(node.input(4));
1361 rsqrt_op->inputs.push_back(node.input(2));
1369 op->inputs.push_back(node.input(0));
1370 op->inputs.push_back(node.input(1));
1372 op->inputs.push_back(node.input(3));
1373 op->outputs.push_back(node.name());
1378 void ConvertFusedBatchNormOperator(const NodeDef& node,
1381 CHECK_EQ(node.op(), "FusedBatchNorm");
1382 CheckInputsCount(node, tf_import_flags, 5);
1385 const string& gamma_input = node.input(1);
1386 const string& beta_input = node.input(2);
1387 const string& moving_mean_input = node.input(3);
1388 const string& moving_variance_input = node.input(4);
1391 const string epsilon_array_name = node.name() + "_epsilon_array";
1396 GetFloatAttr(node, "epsilon"));
1399 const string epsilon_add_op_name = node.name() + "_epsilon";
1407 const string rsqrt_op_name = node.name() + "_rsqrt";
1414 const string multiplier = node.name() + "_mul";
1425 op->inputs.push_back(node.input(0));
1429 op->outputs.push_back(node.name());
1434 void ConvertSpaceToBatchNDOperator(const NodeDef& node,
1437 CHECK_EQ(node.op(), "SpaceToBatchND");
1438 CheckInputsCount(node, tf_import_flags, 3);
1439 CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
1440 CHECK_EQ(GetDataTypeAttr(node, "Tpaddings"), DT_INT32);
1442 op->inputs.push_back(node.input(0));
1443 op->inputs.push_back(node.input(1));
1444 op->inputs.push_back(node.input(2));
1445 op->outputs.push_back(node.name());
1449 void ConvertBatchToSpaceNDOperator(const NodeDef& node,
1452 CHECK_EQ(node.op(), "BatchToSpaceND");
1453 CheckInputsCount(node, tf_import_flags, 3);
1454 CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
1455 CHECK_EQ(GetDataTypeAttr(node, "Tcrops"), DT_INT32);
1457 op->inputs.push_back(node.input(0));
1458 op->inputs.push_back(node.input(1));
1459 op->inputs.push_back(node.input(2));
1460 op->outputs.push_back(node.name());
1464 void ConvertExpOperator(const NodeDef& node,
1467 CHECK_EQ(node.op(), "Exp");
1468 CheckInputsCount(node, tf_import_flags, 1);
1470 op->inputs.push_back(node.input(0));
1471 op->outputs.push_back(node.name());
1475 void ConvertMeanOperator(const NodeDef& node,
1478 CHECK_EQ(node.op(), "Mean");
1479 CheckInputsCount(node, tf_import_flags, 2);
1481 op->inputs.push_back(node.input(0));
1482 op->inputs.push_back(node.input(1));
1483 op->outputs.push_back(node.name());
1485 if (HasAttr(node, "keep_dims")) {
1486 op->keep_dims = GetBoolAttr(node, "keep_dims");
1490 void ConvertSvdfOperator(const NodeDef& node,
1493 CHECK_EQ(node.op(), "Svdf");
1494 const int input_size = GetInputsCount(node, tf_import_flags);
1496 << "Svdf node expects 3 or 4 inputs other than control dependencies: "
1497 << node.DebugString();
1500 op->inputs.push_back(node.input(0));
1501 op->inputs.push_back(node.input(1));
1502 op->inputs.push_back(node.input(2));
1504 op->inputs.push_back(node.input(3));
1506 op->outputs.push_back(node.name() + "_state");
1507 op->outputs.push_back(node.name());
1508 if (node.attr().at("ActivationFunction").s() == "Relu") {
1513 op->rank = node.attr().at("Rank").i();
1518 void ConvertTransposeConvOperator(const NodeDef& node,
1521 CHECK_EQ(node.op(), "Conv2DBackpropInput");
1522 CheckInputsCount(node, tf_import_flags, 3);
1524 op->inputs.push_back(node.input(2));
1525 op->inputs.push_back(node.input(1));
1526 op->inputs.push_back(node.input(0));
1527 op->outputs.push_back(node.name());
1528 const auto& strides = GetListAttr(node, "strides");
1534 auto const& padding = GetStringAttr(node, "padding");
1546 void ConvertExpandDimsOperator(const NodeDef& node,
1549 CHECK_EQ(node.op(), "ExpandDims");
1550 CheckInputsCount(node, tf_import_flags, 2);
1552 op->inputs.push_back(node.input(0));
1553 op->inputs.push_back(node.input(1));
1554 op->outputs.push_back(node.name());
1558 void ConvertFillOperator(const NodeDef& node,
1561 CHECK_EQ(node.op(), "Fill");
1562 CheckInputsCount(node, tf_import_flags, 2);
1564 op->inputs.push_back(node.input(0));
1565 op->inputs.push_back(node.input(1));
1566 op->outputs.push_back(node.name());
1570 void ConvertFloorDivOperator(const NodeDef& node,
1573 CHECK_EQ(node.op(), "FloorDiv");
1574 CheckInputsCount(node, tf_import_flags, 2);
1576 op->inputs.push_back(node.input(0));
1577 op->inputs.push_back(node.input(1));
1578 op->outputs.push_back(node.name());
1582 void ConvertFloorModOperator(const NodeDef& node,
1585 CHECK_EQ(node.op(), "FloorMod");
1586 CheckInputsCount(node, tf_import_flags, 2);
1588 op->inputs.push_back(node.input(0));
1589 op->inputs.push_back(node.input(1));
1590 op->outputs.push_back(node.name());
1594 void ConvertRangeOperator(const NodeDef& node,
1597 CHECK_EQ(node.op(), "Range");
1598 CheckInputsCount(node, tf_import_flags, 3);
1600 if (HasAttr(node, "Tidx")) {
1601 const auto dtype = toco::GetDataTypeAttr(node, "Tidx");
1606 op->inputs.push_back(node.input(0));
1607 op->inputs.push_back(node.input(1));
1608 op->inputs.push_back(node.input(2));
1609 op->outputs.push_back(node.name());
1613 void ConvertRankOperator(const NodeDef& node,
1616 CHECK_EQ(node.op(), "Rank");
1617 CheckInputsCount(node, tf_import_flags, 1);
1619 op->inputs.push_back(node.input(0));
1620 op->outputs.push_back(node.name());
1624 void ConvertStackOperator(const NodeDef& node,
1627 CHECK((node.op() == "Stack") || (node.op() == "Pack"));
1629 const int num_inputs = GetInputsCount(node, tf_import_flags);
1631 << node.op()
1632 << " node expects at least 1 input other than control dependencies: "
1633 << node.DebugString();
1634 CHECK_EQ(num_inputs, GetIntAttr(node, "N"));
1636 op->inputs.push_back(node.input(i));
1639 op->axis = GetIntAttr(node, "axis");
1640 op->outputs.push_back(node.name());
1644 void ConvertTransposeOperator(const NodeDef& node,
1647 CHECK_EQ(node.op(), "Transpose");
1648 CheckInputsCount(node, tf_import_flags, 2);
1650 op->inputs.push_back(node.input(0));
1651 op->inputs.push_back(node.input(1));
1652 op->outputs.push_back(node.name());
1667 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1671 CHECK_EQ(node.op(), "NextIteration");
1672 CHECK_EQ(node.input_size(), 1);
1677 rnn_state->set_state_array(node.name());
1678 rnn_state->set_back_edge_source_array(node.input(0));
1697 void StripZeroOutputIndexFromInputs(NodeDef* node) {
1698 for (auto& input : *node->mutable_input()) {
1703 // In TensorFlow GraphDef, when a node has multiple outputs, they are named
1705 // where 'name' is the node's name(). Just 'name' is an equivalent shorthand
1707 // A TensorFlow GraphDef does not explicitly list all the outputs of each node
1708 // (unlike inputs), it being implied by the node's name and operator type
1713 // at least each node lists explicitly its inputs, so after we've loaded
1738 // output of a node with multiple outputs, so nothing to do here.
1747 // has its name. We use that to identify the producer node.
1752 // Add extra outputs to that producer node, all the way to the
1821 void ConvertTopKV2Operator(const NodeDef& node,
1824 CHECK((node.op() == "TopK") || (node.op() == "TopKV2"));
1826 op->inputs.push_back(node.input(0));
1828 if (HasAttr(node, "k")) {
1830 const string array_name = node.name() + "k";
1839 output_int_data[0] = GetIntAttr(node, "k");
1843 CheckInputsCount(node, tf_import_flags, 2);
1844 op->inputs.push_back(node.input(1));
1847 op->outputs.push_back(node.name() + ":0");
1848 op->outputs.push_back(node.name() + ":1");
1876 for (auto node : inlined_graph.node()) {
1877 StripZeroOutputIndexFromInputs(&node);
1878 if (node.op() == "Const") {
1879 ConvertConstOperator(node, tf_import_flags, model);
1880 } else if (node.op() == "Conv2D") {
1881 ConvertConvOperator(node, tf_import_flags, model);
1882 } else if (node.op() == "Conv2DBackpropInput") {
1883 ConvertTransposeConvOperator(node, tf_import_flags, model);
1884 } else if (node.op() == "DepthwiseConv2dNative") {
1885 ConvertDepthwiseConvOperator(node, tf_import_flags, model);
1886 } else if (node.op() == "DepthToSpace") {
1887 ConvertDepthToSpaceOperator(node, tf_import_flags, model);
1888 } else if (node.op() == "SpaceToDepth") {
1889 ConvertSpaceToDepthOperator(node, tf_import_flags, model);
1890 } else if (node.op() == "BiasAdd") {
1891 ConvertBiasAddOperator(node, tf_import_flags, model);
1892 } else if (node.op() == "Relu") {
1893 ConvertReluOperator(node, tf_import_flags, model);
1894 } else if (node.op() == "Relu6") {
1895 ConvertRelu6Operator(node, tf_import_flags, model);
1896 } else if (node.op() == "Sigmoid") {
1897 ConvertLogisticOperator(node, tf_import_flags, model);
1898 } else if (node.op() == "Tanh") {
1899 ConvertTanhOperator(node, tf_import_flags, model);
1900 } else if (node.op() == "MaxPool") {
1901 ConvertMaxPoolOperator(node, tf_import_flags, model);
1902 } else if (node.op() == "AvgPool") {
1903 ConvertAvgPoolOperator(node, tf_import_flags, model);
1904 } else if (node.op() == "Reshape") {
1905 ConvertReshapeOperator(node, tf_import_flags, model);
1906 } else if (node.op() == "BatchMatMul") {
1907 ConvertBatchMatMulOperator(node, tf_import_flags, model);
1908 } else if (node.op() == "MatMul") {
1909 ConvertMatMulOperator(node, tf_import_flags, model);
1910 } else if (node.op() == "Div" || node.op() == "RealDiv") {
1911 ConvertDivOperator(node, tf_import_flags, model);
1912 } else if (node.op() == "Identity" || node.op() == "CheckNumerics" ||
1913 node.op() == "StopGradient") {
1914 ConvertIdentityOperator(node, tf_import_flags, model);
1915 } else if (node.op() == "FakeQuantWithMinMaxVars") {
1916 ConvertFakeQuantWithMinMaxVars(node, tf_import_flags, model);
1917 } else if (node.op() == "FakeQuantWithMinMaxArgs") {
1918 ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model);
1919 } else if (node.op() == "Neg") {
1920 ConvertNegOperator(node, tf_import_flags, model);
1921 } else if (node.op() == "Rsqrt") {
1922 ConvertRsqrtOperator(node, tf_import_flags, model);
1923 } else if (node.op() == "Squeeze") {
1924 ConvertSqueezeOperator(node, tf_import_flags, model);
1925 } else if (node.op() == "Sqrt") {
1926 ConvertSqrtOperator(node, tf_import_flags, model);
1927 } else if (node.op() == "Square") {
1928 ConvertSquareOperator(node, tf_import_flags, model);
1929 } else if (node.op() == "Add") {
1930 ConvertAddOperator(node, tf_import_flags, model);
1931 } else if (node.op() == "AddN") {
1932 ConvertAddNOperator(node, tf_import_flags, model);
1933 } else if (node.op() == "Mul") {
1934 ConvertMulOperator(node, tf_import_flags, model);
1935 } else if (node.op() == "Sub") {
1936 ConvertSubOperator(node, tf_import_flags, model);
1937 } else if (node.op() == "Sum") {
1938 ConvertSumOperator(node, tf_import_flags, model);
1939 } else if (node.op() == "Tile") {
1940 ConvertTileOperator(node, tf_import_flags, model);
1941 } else if (node.op() == "Concat" || node.op() == "ConcatV2") {
1942 ConvertConcatOperator(node, tf_import_flags, model);
1943 } else if (node.op() == "LRN") {
1944 ConvertLRNOperator(node, tf_import_flags, model);
1945 } else if (node.op() == "Softmax") {
1946 ConvertSoftmaxOperator(node, tf_import_flags, model);
1947 } else if (node.op() == "LogSoftmax") {
1948 ConvertLogSoftmaxOperator(node, tf_import_flags, model);
1949 } else if (node.op() == "All") {
1950 ConvertAllOperator(node, tf_import_flags, model);
1951 } else if (node.op() == "Assert") {
1952 ConvertAssertOperator(node, tf_import_flags, model);
1953 } else if (node.op() == "Less") {
1954 ConvertLessOperator(node, tf_import_flags, model);
1955 } else if (node.op() == "LessEqual") {
1956 ConvertLessEqualOperator(node, tf_import_flags, model);
1957 } else if (node.op() == "Greater") {
1958 ConvertGreaterOperator(node, tf_import_flags, model);
1959 } else if (node.op() == "GreaterEqual") {
1960 ConvertGreaterEqualOperator(node, tf_import_flags, model);
1961 } else if (node.op() == "Max") {
1962 ConvertMaxOperator(node, tf_import_flags, model);
1963 } else if (node.op() == "Min") {
1964 ConvertMinOperator(node, tf_import_flags, model);
1965 } else if (node.op() == "Maximum") {
1966 ConvertMaximumOperator(node, tf_import_flags, model);
1967 } else if (node.op() == "Minimum") {
1968 ConvertMinimumOperator(node, tf_import_flags, model);
1969 } else if (node.op() == "Merge") {
1970 ConvertMergeOperator(node, tf_import_flags, model);
1971 } else if (node.op() == "Pad") {
1972 ConvertPadOperator(node, tf_import_flags, model);
1973 } else if (node.op() == "StridedSlice") {
1974 ConvertStridedSliceOperator(node, tf_import_flags, model);
1975 } else if (node.op() == "Shape") {
1976 ConvertShapeOperator(node, tf_import_flags, model);
1977 } else if (node.op() == "Slice") {
1978 ConvertSliceOperator(node, tf_import_flags, model);
1979 } else if (node.op() == "Split") {
1980 ConvertSplitOperator(node, tf_import_flags, model);
1981 } else if (node.op() == "Switch") {
1982 ConvertSwitchOperator(node, tf_import_flags, model);
1983 } else if (node.op() == "Placeholder") {
1984 ConvertPlaceholderOperator(node, tf_import_flags, model);
1985 } else if (node.op() == "PlaceholderWithDefault") {
1986 ConvertIdentityOperator(node, tf_import_flags, model);
1987 } else if (node.op() == "LegacyFedInput") {
1988 ConvertPlaceholderOperator(node, tf_import_flags, model);
1989 } else if (node.op() == "NoOp") {
1990 ConvertNoOpOperator(node, tf_import_flags, model);
1991 } else if (node.op() == "Cast") {
1992 ConvertCastOperator(node, tf_import_flags, model);
1993 } else if (node.op() == "Floor") {
1994 ConvertFloorOperator(node, tf_import_flags, model);
1995 } else if (node.op() == "Gather") {
1996 ConvertGatherOperator(node, tf_import_flags, model);
1997 } else if (node.op() == "ResizeBilinear") {
1998 ConvertResizeBilinearOperator(node, tf_import_flags, model);
1999 } else if (node.op() == "BatchNormWithGlobalNormalization") {
2000 ConvertBatchNormWithGlobalNormalizationOperator(node, tf_import_flags,
2002 } else if (node.op() == "FusedBatchNorm") {
2003 ConvertFusedBatchNormOperator(node, tf_import_flags, model);
2004 } else if (node.op() == "SpaceToBatchND") {
2005 ConvertSpaceToBatchNDOperator(node, tf_import_flags, model);
2006 } else if (node.op() == "BatchToSpaceND") {
2007 ConvertBatchToSpaceNDOperator(node, tf_import_flags, model);
2008 } else if (node.op() == "Mean") {
2009 ConvertMeanOperator(node, tf_import_flags, model);
2010 } else if (node.op() == "Svdf") {
2011 ConvertSvdfOperator(node, tf_import_flags, model);
2012 } else if (node.op() == "NextIteration") {
2013 ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model);
2014 } else if (node.op() == "ExpandDims") {
2015 ConvertExpandDimsOperator(node, tf_import_flags, model);
2016 } else if (node.op() == "Fill") {
2017 ConvertFillOperator(node, tf_import_flags, model);
2018 } else if (node.op() == "FloorDiv") {
2019 ConvertFloorDivOperator(node, tf_import_flags, model);
2020 } else if (node.op() == "FloorMod") {
2021 ConvertFloorModOperator(node, tf_import_flags, model);
2022 } else if (node.op() == "Range") {
2023 ConvertRangeOperator(node, tf_import_flags, model);
2024 } else if (node.op() == "Rank") {
2025 ConvertRankOperator(node, tf_import_flags, model);
2026 } else if (node.op() == "Stack" || node.op() == "Pack") {
2027 ConvertStackOperator(node, tf_import_flags, model);
2028 } else if (node.op() == "Transpose") {
2029 ConvertTransposeOperator(node, tf_import_flags, model);
2030 } else if (node.op() == "ArgMax") {
2031 ConvertArgMaxOperator(node, tf_import_flags, model);
2032 } else if (node.op() == "Exp") {
2033 ConvertExpOperator(node, tf_import_flags, model);
2034 } else if (node.op() == "TopK" || node.op() == "TopKV2") {
2035 ConvertTopKV2Operator(node, tf_import_flags, model);
2037 ConvertUnsupportedOperator(node, tf_import_flags, model);