Home | History | Annotate | Download | only in optimizers

Lines Matching refs:node_

472         node_(opt_cxt.node),
519 if (node_->attr().find("data_format") != node_->attr().end()) {
520 if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
528 auto outputs = node_map_->GetOutputs(node_->name());
541 return nodes_to_preserve_.find(node_->name()) != nodes_to_preserve_.end();
546 if (node_->device().empty()) {
547 device_name = virtual_placer_.get_canonical_device_name(*node_);
549 device_name = node_->device();
562 return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
567 if (node_->attr().find("_output_shapes") != node_->attr().end()) {
569 auto shape = node_->mutable_attr()
586 auto input_node = node_map_->GetNode(node_->input(input_index));
590 // to ensure added_node is in the same frame with node_.
593 string base_name = strings::StrCat(node_->name(), "-", input_index);
596 *node_->mutable_input(input_index) = node_name;
598 node_map_->AddOutput(node_name, node_->name());
616 strings::StrCat(node_->name(), "-", pos, "-", kTransposeNHWCToNCHW));
618 graph_properties_.GetInputProperties(node_->name())[pos].dtype();
619 auto input_node = node_map_->GetNode(node_->input(pos));
623 ParseNodeName(node_->input(pos), &output_pos);
625 node_name, node_->input(pos), const_name, dtype,
628 node_map_->UpdateOutput(NodeName(node_->input(pos)), node_->name(),
630 node_map_->AddOutput(node_name, node_->name());
631 *node_->mutable_input(pos) = node_name;
637 auto outputs = node_map_->GetOutputs(node_->name());
648 if (input_name == node_->name()) {
653 strings::StrCat(node_->name(), "-", output_count, "-", i);
656 graph_properties_.GetOutputProperties(node_->name())[input_port]
661 TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
664 node_->attr().at("_output_shapes").list().shape(input_port),
674 node_map_->AddOutput(node_->name(), added_node_name);
680 node_map_->RemoveOutput(node_->name(), output->name());
695 auto param_node = node_map_->GetNode(node_->input(param_index));
705 NodeDef* node_;
710 if (node_->attr().find("ksize") != node_->attr().end()) {
711 auto list = node_->mutable_attr()->at("ksize").mutable_list();
717 if (node_->attr().find("strides") != node_->attr().end()) {
718 auto list = node_->mutable_attr()->at("strides").mutable_list();
724 if (node_->attr().find("data_format") != node_->attr().end()) {
725 if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
727 node_->mutable_attr()->at("data_format").mutable_s();
798 node->set_device(node_->device());
851 string base_name = strings::StrCat(node_->name(), "-", pos);
852 string input = NodeName(node_->input(pos));
861 AddNodePermNHWCToNCHW(base_name, depended_node, node_->device());
873 AddNodePermNCHWToNHWC(node_->name(), node_->name(), node_->device());
897 added_node->set_device(node_->device());
917 strings::StrCat(node_->name(), "-", input_pos, "-", suffix));
919 AddNodeDataFormatOp(name, node_->input(input_pos), op, dtype, true);
920 *node_->mutable_input(input_pos) = added_node->name();
921 node_map_->UpdateOutput(NodeName(added_node->input(0)), node_->name(),
923 node_map_->AddOutput(added_node->name(), node_->name());
952 auto input = node_map_->GetNode(node_->input(0));
955 ParseNodeName(node_->input(0), &port);
973 return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
990 if (node_->attr().find("strides") != node_->attr().end()) {
991 auto list = node_->attr().at("strides").list();
998 if (node_->attr().find("padding") != node_->attr().end()) {
999 auto padding = node_->attr().at("padding").s();
1027 auto filter_shape = GetShape(node_->input(1));
1028 auto input_shape = GetShape(node_->input(0));
1042 auto filter_shape = GetShape(node_->name());
1043 auto input_shape = GetShape(node_->input(0));
1063 auto filter_shape = GetShape(node_->input(1));
1064 auto input_shape = GetShape(node_->name());
1089 if (node_->attr().find("is_training") != node_->attr().end()) {
1090 if (node_->attr().at("is_training").b()) {
1132 auto data_input = node_map_->GetNode(node_->input(0));
1134 ParseNodeName(node_->input(0), &port);
1155 return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
1191 bool IsNodeAfterNCHWToNHWC() const { return IsNodeAfterNCHWToNHWC(*node_); }
1201 return NonControlInputs(*node_);
1212 return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
1222 auto input0 = node_map_->GetNode(node_->input(0));
1223 auto input1 = node_map_->GetNode(node_->input(1));
1225 ParseNodeName(node_->input(0), &input0_port);
1227 ParseNodeName(node_->input(1), &input1_port);
1238 auto input0 = node_map_->GetNode(node_->input(0));
1239 auto input1 = node_map_->GetNode(node_->input(1));
1241 ParseNodeName(node_->input(0), &input0_port);
1243 ParseNodeName(node_->input(1), &input1_port);
1261 node->set_device(node_->device());
1291 node->set_device(node_->device());
1311 string base_name = strings::StrCat(node_->name(), "-", vector_index);
1316 auto input_node = node_map_->GetNode(node_->input(vector_index));
1319 ParseNodeName(node_->input(vector_index), &port);
1327 NodeName(node_->input(vector_index)));
1328 TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
1329 AddNodeReshape(reshape_node_name, node_->input(vector_index),
1330 shape_const_node_name, node_->attr().at("T").type());
1332 node_map_->UpdateOutput(NodeName(node_->input(vector_index)),
1333 node_->name(), reshape_node_name);
1334 node_map_->AddOutput(reshape_node_name, node_->name());
1335 *node_->mutable_input(vector_index) = reshape_node_name;
1348 int n = node_->attr().at("N").i();
1349 axis_node_pos_ = (IsConcatV1(*node_)) ? 0 : n;
1354 return DataInputPosConcat(*node_);
1359 (IsConcatV1(*node_)) ? DT_INT32 : node_->attr().at("Tidx").type();
1376 DataType dtype = node_->attr().at("index_type").type();
1388 auto input1 = node_map_->GetNode(node_->input(1));
1390 ParseNodeName(node_->input(1), &port);
1405 for (int i = 0; i < node_->input_size(); i++) {
1406 auto input = node_map_->GetNode(node_->input(i));
1408 ParseNodeName(node_->input(i), &port);
1462 return !MustPreserve() && IsPortZeroDimsFour(*node_) && HasOutputs() &&
1468 int n = node_->attr().at("N").i();
1479 for (const auto& input : node_->input()) {
1503 DataType dtype = node_->attr().at("Tpaddings").type();
1515 DataType dtype = node_->attr().at("Tidx").type();
1532 if (HasAttribute(*node_, "num_split").ok()) {
1533 for (int i = 1; i < node_->attr().at("num_split").i(); i++) {
1575 auto input0 = node_map_->GetNode(node_->input(0));
1577 ParseNodeName(node_->input(0), &input0_port);
1585 auto input0 = node_map_->GetNode(node_->input(0));
1587 ParseNodeName(node_->input(0), &input0_port);
1613 // Note that we can't use node_->input_size() here because there
1621 DataType dtype = node_->attr().at("Index").type();
1656 return node_->attr().at(mask).i() == 0;
1665 int i = node_->attr().at(mask).i();
1696 node_->mutable_attr()->at(mask).set_i(i);
1720 bool is_dims_supported = (IsPortZeroDimsN(*node_, 2) && IsAlongHW()) ||
1721 (IsPortZeroDimsN(*node_, 1) && IsAlongNHW());
1729 TF_RETURN_IF_ERROR(HasAttribute(*node_, "squeeze_dims"));
1730 auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list();
1744 auto input = node_map_->GetNode(node_->input(0));
1745 ParseNodeName(node_->input(0), &input_port);
1763 if (node_->attr().find("squeeze_dims") != node_->attr().end()) {
1764 auto list = node_->attr().at("squeeze_dims").list();
1788 auto input0 = node_map_->GetNode(node_->input(0));
1790 ParseNodeName(node_->input(0), &port);
1798 DataType dtype = node_->attr().at("Tidx").type();
1820 auto axis_node = node_map_->GetNode(node_->input(1));
1851 bool KeepDims() const { return node_->attr().at("keep_dims").b(); }
1870 DataType dtype = node_->attr().at("Tmultiples").type();