1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/tf2xla_util.h" 17 18 #include "tensorflow/cc/framework/ops.h" 19 #include "tensorflow/cc/ops/data_flow_ops.h" 20 #include "tensorflow/cc/ops/function_ops.h" 21 #include "tensorflow/cc/ops/standard_ops.h" 22 #include "tensorflow/compiler/tf2xla/sharding_util.h" 23 #include "tensorflow/core/framework/node_def.pb.h" 24 #include "tensorflow/core/graph/graph.h" 25 #include "tensorflow/core/lib/core/status.h" 26 #include "tensorflow/core/lib/core/status_test_util.h" 27 #include "tensorflow/core/lib/core/stringpiece.h" 28 #include "tensorflow/core/lib/strings/strcat.h" 29 #include "tensorflow/core/platform/test.h" 30 31 namespace tensorflow { 32 namespace { 33 34 void ExpectErrorContains(const Status& status, StringPiece str) { 35 EXPECT_NE(Status::OK(), status); 36 EXPECT_TRUE(StringPiece(status.error_message()).contains(str)) 37 << "expected error: " << status.error_message() << " to contain: " << str; 38 } 39 40 TEST(ValidateConfig, Good) { 41 tf2xla::Config config; 42 tf2xla::Feed* feed = config.add_feed(); 43 feed->mutable_id()->set_node_name("foo"); 44 feed->mutable_id()->set_output_index(123); 45 feed->set_name("foo_debug"); 46 feed = config.add_feed(); 47 feed->mutable_id()->set_node_name("bar"); 48 feed->mutable_id()->set_output_index(0); 49 tf2xla::Fetch* fetch = config.add_fetch(); 50 fetch->mutable_id()->set_node_name("baz"); 51 fetch->mutable_id()->set_output_index(456); 52 fetch->set_name("baz_debug"); 53 fetch = config.add_fetch(); 54 fetch->mutable_id()->set_node_name("banana"); 55 fetch->mutable_id()->set_output_index(0); 56 TF_EXPECT_OK(ValidateConfig(config)); 57 } 58 59 TEST(ValidateConfig, BadEmpty) { 60 tf2xla::Config config; 61 ExpectErrorContains(ValidateConfig(config), "fetches must be specified"); 62 } 63 64 TEST(ValidateConfig, BadNoFetch) { 65 tf2xla::Config config; 66 tf2xla::Feed* feed = config.add_feed(); 67 feed->mutable_id()->set_node_name("foo"); 68 ExpectErrorContains(ValidateConfig(config), "fetches must be specified"); 69 } 70 71 TEST(ValidateConfig, BadFeedNodeName) { 72 tf2xla::Config config; 73 config.add_feed(); 74 ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty"); 75 } 76 77 TEST(ValidateConfig, BadFeedOutputIndex) { 78 tf2xla::Config config; 79 tf2xla::Feed* feed = config.add_feed(); 80 feed->mutable_id()->set_node_name("foo"); 81 feed->mutable_id()->set_output_index(-1); 82 ExpectErrorContains(ValidateConfig(config), "output_index must be positive"); 83 } 84 85 TEST(ValidateConfig, BadFetchNodeName) { 86 tf2xla::Config config; 87 tf2xla::Feed* feed = config.add_feed(); 88 feed->mutable_id()->set_node_name("foo"); 89 config.add_fetch(); 90 ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty"); 91 } 92 93 TEST(ValidateConfig, BadFetchOutputIndex) { 94 tf2xla::Config config; 95 tf2xla::Feed* feed = config.add_feed(); 96 feed->mutable_id()->set_node_name("foo"); 97 tf2xla::Fetch* fetch = config.add_fetch(); 98 fetch->mutable_id()->set_node_name("bar"); 99 fetch->mutable_id()->set_output_index(-1); 100 ExpectErrorContains(ValidateConfig(config), "output_index must be positive"); 101 } 102 103 TEST(ValidateConfig, DuplicateFeedName) { 104 tf2xla::Config config; 105 tf2xla::Feed* feed = config.add_feed(); 106 feed->mutable_id()->set_node_name("foo"); 107 feed->set_name("dup"); 108 feed = config.add_feed(); 109 feed->mutable_id()->set_node_name("bar"); 110 feed->set_name("dup"); 111 ExpectErrorContains(ValidateConfig(config), "duplicate feed name"); 112 } 113 114 TEST(ValidateConfig, DuplicateFetchName) { 115 tf2xla::Config config; 116 tf2xla::Feed* feed = config.add_feed(); 117 feed->mutable_id()->set_node_name("foo"); 118 tf2xla::Fetch* fetch = config.add_fetch(); 119 fetch->mutable_id()->set_node_name("bar"); 120 fetch->set_name("dup"); 121 fetch = config.add_fetch(); 122 fetch->mutable_id()->set_node_name("baz"); 123 fetch->set_name("dup"); 124 ExpectErrorContains(ValidateConfig(config), "duplicate fetch name"); 125 } 126 127 TEST(ValidateConfig, ConflictingFeedName) { 128 tf2xla::Config config; 129 tf2xla::Feed* feed = config.add_feed(); 130 feed->mutable_id()->set_node_name("foo"); 131 feed->set_name("conflict"); 132 feed = config.add_feed(); 133 feed->mutable_id()->set_node_name("bar"); 134 feed->set_name("conflict_data"); 135 ExpectErrorContains(ValidateConfig(config), "conflicting feed name"); 136 } 137 138 TEST(ValidateConfig, ConflictingFetchName) { 139 tf2xla::Config config; 140 tf2xla::Feed* feed = config.add_feed(); 141 feed->mutable_id()->set_node_name("foo"); 142 tf2xla::Fetch* fetch = config.add_fetch(); 143 fetch->mutable_id()->set_node_name("bar"); 144 fetch->set_name("conflict"); 145 fetch = config.add_fetch(); 146 fetch->mutable_id()->set_node_name("baz"); 147 fetch->set_name("conflict_data"); 148 ExpectErrorContains(ValidateConfig(config), "conflicting fetch name"); 149 } 150 151 static tf2xla::Config FetchesConfig(std::vector<string> fetches) { 152 tf2xla::Config config; 153 for (const auto& fetch_node_name : fetches) { 154 auto* fetch = config.add_fetch(); 155 fetch->set_name(strings::StrCat("fetch_", fetch_node_name)); 156 fetch->mutable_id()->set_node_name(fetch_node_name); 157 } 158 return config; 159 } 160 161 TEST(PruneGraphDefInto, Basic) { 162 GraphDef def; 163 auto* n = def.add_node(); 164 n->set_name("a"); 165 n->add_input("b:0"); 166 n->add_input("^c"); 167 168 GraphDef copy; 169 ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"missing"}), def, ©), 170 "node missing needed"); 171 ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"a"}), def, ©), 172 "node b needed"); 173 174 n = def.add_node(); 175 n->set_name("b"); 176 ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"a"}), def, ©), 177 "node c needed"); 178 n->add_input("d:1"); 179 180 n = def.add_node(); 181 n->set_name("c"); 182 n->add_input("d:1"); 183 184 n = def.add_node(); 185 n->set_name("d"); 186 187 // Graph is full, no pruning done. 188 // Graph right now has diamond from d: 189 // d --> b --> a 190 // d --> c --> a 191 TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a"}), def, ©)); 192 EXPECT_EQ(def.DebugString(), copy.DebugString()); 193 GraphDef pruned_a = copy; 194 195 // Add some unrelated fields that use b and c, but are not needed for a. 196 n = def.add_node(); 197 n->set_name("e"); 198 n->add_input("^d"); 199 n->add_input("b:2"); 200 copy.Clear(); 201 TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a"}), def, ©)); 202 EXPECT_EQ(pruned_a.DebugString(), copy.DebugString()); 203 204 // Fetch "a" and "e" to get the original graph. 205 copy.Clear(); 206 TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a", "e"}), def, ©)); 207 EXPECT_EQ(def.DebugString(), copy.DebugString()); 208 } 209 210 TEST(SetNodeShardingFromNeighbors, Basic) { 211 // Builds a graph that adds two Tensors. 212 Scope scope = Scope::NewRootScope().ExitOnError(); 213 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); 214 auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1); 215 auto c = ops::Add(scope.WithOpName("C"), a, b); 216 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); 217 TF_ASSERT_OK(scope.ToGraph(graph.get())); 218 219 Node* a_node = nullptr; 220 Node* b_node = nullptr; 221 Node* c_node = nullptr; 222 for (Node* n : graph->nodes()) { 223 if (n->name() == "A") a_node = n; 224 if (n->name() == "B") b_node = n; 225 if (n->name() == "C") c_node = n; 226 } 227 228 const int num_cores_per_replica = 4; 229 230 a_node->set_assigned_device_name("foo"); 231 EXPECT_FALSE(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false).ok()); 232 233 // Test where one input to c_node has a device. 234 a_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:2"); 235 TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false)); 236 auto parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica); 237 TF_ASSERT_OK(parse_status.status()); 238 ASSERT_TRUE(parse_status.ValueOrDie().has_value()); 239 EXPECT_EQ(2, parse_status.ValueOrDie().value().tile_assignment_devices(0)); 240 241 // Test where two inputs to c_node have a device. 242 b_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:1"); 243 TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false)); 244 parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica); 245 TF_ASSERT_OK(parse_status.status()); 246 ASSERT_TRUE(parse_status.ValueOrDie().has_value()); 247 EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0)); 248 249 // Test setting based on out edges. 250 TF_ASSERT_OK(SetNodeShardingFromNeighbors(a_node, /*out_edges=*/true)); 251 parse_status = ParseShardingFromDevice(*a_node, num_cores_per_replica); 252 TF_ASSERT_OK(parse_status.status()); 253 ASSERT_TRUE(parse_status.ValueOrDie().has_value()); 254 EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0)); 255 } 256 257 } // namespace 258 } // namespace tensorflow 259