Home | History | Annotate | Download | only in tf2xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
     17 
     18 #include "tensorflow/cc/framework/ops.h"
     19 #include "tensorflow/cc/ops/data_flow_ops.h"
     20 #include "tensorflow/cc/ops/function_ops.h"
     21 #include "tensorflow/cc/ops/standard_ops.h"
     22 #include "tensorflow/compiler/tf2xla/sharding_util.h"
     23 #include "tensorflow/core/framework/node_def.pb.h"
     24 #include "tensorflow/core/graph/graph.h"
     25 #include "tensorflow/core/lib/core/status.h"
     26 #include "tensorflow/core/lib/core/status_test_util.h"
     27 #include "tensorflow/core/lib/core/stringpiece.h"
     28 #include "tensorflow/core/lib/strings/strcat.h"
     29 #include "tensorflow/core/platform/test.h"
     30 
     31 namespace tensorflow {
     32 namespace {
     33 
     34 void ExpectErrorContains(const Status& status, StringPiece str) {
     35   EXPECT_NE(Status::OK(), status);
     36   EXPECT_TRUE(StringPiece(status.error_message()).contains(str))
     37       << "expected error: " << status.error_message() << " to contain: " << str;
     38 }
     39 
     40 TEST(ValidateConfig, Good) {
     41   tf2xla::Config config;
     42   tf2xla::Feed* feed = config.add_feed();
     43   feed->mutable_id()->set_node_name("foo");
     44   feed->mutable_id()->set_output_index(123);
     45   feed->set_name("foo_debug");
     46   feed = config.add_feed();
     47   feed->mutable_id()->set_node_name("bar");
     48   feed->mutable_id()->set_output_index(0);
     49   tf2xla::Fetch* fetch = config.add_fetch();
     50   fetch->mutable_id()->set_node_name("baz");
     51   fetch->mutable_id()->set_output_index(456);
     52   fetch->set_name("baz_debug");
     53   fetch = config.add_fetch();
     54   fetch->mutable_id()->set_node_name("banana");
     55   fetch->mutable_id()->set_output_index(0);
     56   TF_EXPECT_OK(ValidateConfig(config));
     57 }
     58 
     59 TEST(ValidateConfig, BadEmpty) {
     60   tf2xla::Config config;
     61   ExpectErrorContains(ValidateConfig(config), "fetches must be specified");
     62 }
     63 
     64 TEST(ValidateConfig, BadNoFetch) {
     65   tf2xla::Config config;
     66   tf2xla::Feed* feed = config.add_feed();
     67   feed->mutable_id()->set_node_name("foo");
     68   ExpectErrorContains(ValidateConfig(config), "fetches must be specified");
     69 }
     70 
     71 TEST(ValidateConfig, BadFeedNodeName) {
     72   tf2xla::Config config;
     73   config.add_feed();
     74   ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
     75 }
     76 
     77 TEST(ValidateConfig, BadFeedOutputIndex) {
     78   tf2xla::Config config;
     79   tf2xla::Feed* feed = config.add_feed();
     80   feed->mutable_id()->set_node_name("foo");
     81   feed->mutable_id()->set_output_index(-1);
     82   ExpectErrorContains(ValidateConfig(config), "output_index must be positive");
     83 }
     84 
     85 TEST(ValidateConfig, BadFetchNodeName) {
     86   tf2xla::Config config;
     87   tf2xla::Feed* feed = config.add_feed();
     88   feed->mutable_id()->set_node_name("foo");
     89   config.add_fetch();
     90   ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
     91 }
     92 
     93 TEST(ValidateConfig, BadFetchOutputIndex) {
     94   tf2xla::Config config;
     95   tf2xla::Feed* feed = config.add_feed();
     96   feed->mutable_id()->set_node_name("foo");
     97   tf2xla::Fetch* fetch = config.add_fetch();
     98   fetch->mutable_id()->set_node_name("bar");
     99   fetch->mutable_id()->set_output_index(-1);
    100   ExpectErrorContains(ValidateConfig(config), "output_index must be positive");
    101 }
    102 
    103 TEST(ValidateConfig, DuplicateFeedName) {
    104   tf2xla::Config config;
    105   tf2xla::Feed* feed = config.add_feed();
    106   feed->mutable_id()->set_node_name("foo");
    107   feed->set_name("dup");
    108   feed = config.add_feed();
    109   feed->mutable_id()->set_node_name("bar");
    110   feed->set_name("dup");
    111   ExpectErrorContains(ValidateConfig(config), "duplicate feed name");
    112 }
    113 
    114 TEST(ValidateConfig, DuplicateFetchName) {
    115   tf2xla::Config config;
    116   tf2xla::Feed* feed = config.add_feed();
    117   feed->mutable_id()->set_node_name("foo");
    118   tf2xla::Fetch* fetch = config.add_fetch();
    119   fetch->mutable_id()->set_node_name("bar");
    120   fetch->set_name("dup");
    121   fetch = config.add_fetch();
    122   fetch->mutable_id()->set_node_name("baz");
    123   fetch->set_name("dup");
    124   ExpectErrorContains(ValidateConfig(config), "duplicate fetch name");
    125 }
    126 
    127 TEST(ValidateConfig, ConflictingFeedName) {
    128   tf2xla::Config config;
    129   tf2xla::Feed* feed = config.add_feed();
    130   feed->mutable_id()->set_node_name("foo");
    131   feed->set_name("conflict");
    132   feed = config.add_feed();
    133   feed->mutable_id()->set_node_name("bar");
    134   feed->set_name("conflict_data");
    135   ExpectErrorContains(ValidateConfig(config), "conflicting feed name");
    136 }
    137 
    138 TEST(ValidateConfig, ConflictingFetchName) {
    139   tf2xla::Config config;
    140   tf2xla::Feed* feed = config.add_feed();
    141   feed->mutable_id()->set_node_name("foo");
    142   tf2xla::Fetch* fetch = config.add_fetch();
    143   fetch->mutable_id()->set_node_name("bar");
    144   fetch->set_name("conflict");
    145   fetch = config.add_fetch();
    146   fetch->mutable_id()->set_node_name("baz");
    147   fetch->set_name("conflict_data");
    148   ExpectErrorContains(ValidateConfig(config), "conflicting fetch name");
    149 }
    150 
    151 static tf2xla::Config FetchesConfig(std::vector<string> fetches) {
    152   tf2xla::Config config;
    153   for (const auto& fetch_node_name : fetches) {
    154     auto* fetch = config.add_fetch();
    155     fetch->set_name(strings::StrCat("fetch_", fetch_node_name));
    156     fetch->mutable_id()->set_node_name(fetch_node_name);
    157   }
    158   return config;
    159 }
    160 
    161 TEST(PruneGraphDefInto, Basic) {
    162   GraphDef def;
    163   auto* n = def.add_node();
    164   n->set_name("a");
    165   n->add_input("b:0");
    166   n->add_input("^c");
    167 
    168   GraphDef copy;
    169   ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"missing"}), def, &copy),
    170                       "node missing needed");
    171   ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy),
    172                       "node b needed");
    173 
    174   n = def.add_node();
    175   n->set_name("b");
    176   ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy),
    177                       "node c needed");
    178   n->add_input("d:1");
    179 
    180   n = def.add_node();
    181   n->set_name("c");
    182   n->add_input("d:1");
    183 
    184   n = def.add_node();
    185   n->set_name("d");
    186 
    187   // Graph is full, no pruning done.
    188   // Graph right now has diamond from d:
    189   //   d --> b --> a
    190   //   d --> c --> a
    191   TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy));
    192   EXPECT_EQ(def.DebugString(), copy.DebugString());
    193   GraphDef pruned_a = copy;
    194 
    195   // Add some unrelated fields that use b and c, but are not needed for a.
    196   n = def.add_node();
    197   n->set_name("e");
    198   n->add_input("^d");
    199   n->add_input("b:2");
    200   copy.Clear();
    201   TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a"}), def, &copy));
    202   EXPECT_EQ(pruned_a.DebugString(), copy.DebugString());
    203 
    204   // Fetch "a" and "e" to get the original graph.
    205   copy.Clear();
    206   TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a", "e"}), def, &copy));
    207   EXPECT_EQ(def.DebugString(), copy.DebugString());
    208 }
    209 
    210 TEST(SetNodeShardingFromNeighbors, Basic) {
    211   // Builds a graph that adds two Tensors.
    212   Scope scope = Scope::NewRootScope().ExitOnError();
    213   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
    214   auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
    215   auto c = ops::Add(scope.WithOpName("C"), a, b);
    216   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    217   TF_ASSERT_OK(scope.ToGraph(graph.get()));
    218 
    219   Node* a_node = nullptr;
    220   Node* b_node = nullptr;
    221   Node* c_node = nullptr;
    222   for (Node* n : graph->nodes()) {
    223     if (n->name() == "A") a_node = n;
    224     if (n->name() == "B") b_node = n;
    225     if (n->name() == "C") c_node = n;
    226   }
    227 
    228   const int num_cores_per_replica = 4;
    229 
    230   a_node->set_assigned_device_name("foo");
    231   EXPECT_FALSE(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false).ok());
    232 
    233   // Test where one input to c_node has a device.
    234   a_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:2");
    235   TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false));
    236   auto parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica);
    237   TF_ASSERT_OK(parse_status.status());
    238   ASSERT_TRUE(parse_status.ValueOrDie().has_value());
    239   EXPECT_EQ(2, parse_status.ValueOrDie().value().tile_assignment_devices(0));
    240 
    241   // Test where two inputs to c_node have a device.
    242   b_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:1");
    243   TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false));
    244   parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica);
    245   TF_ASSERT_OK(parse_status.status());
    246   ASSERT_TRUE(parse_status.ValueOrDie().has_value());
    247   EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0));
    248 
    249   // Test setting based on out edges.
    250   TF_ASSERT_OK(SetNodeShardingFromNeighbors(a_node, /*out_edges=*/true));
    251   parse_status = ParseShardingFromDevice(*a_node, num_cores_per_replica);
    252   TF_ASSERT_OK(parse_status.status());
    253   ASSERT_TRUE(parse_status.ValueOrDie().has_value());
    254   EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0));
    255 }
    256 
    257 }  // namespace
    258 }  // namespace tensorflow
    259