Home | History | Annotate | Download | only in graph_transforms
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/cc/ops/const_op.h"
     17 #include "tensorflow/cc/ops/image_ops.h"
     18 #include "tensorflow/cc/ops/nn_ops.h"
     19 #include "tensorflow/cc/ops/sendrecv_ops.h"
     20 #include "tensorflow/cc/ops/standard_ops.h"
     21 #include "tensorflow/core/framework/tensor_testutil.h"
     22 #include "tensorflow/core/lib/core/status_test_util.h"
     23 #include "tensorflow/core/platform/test.h"
     24 #include "tensorflow/core/platform/test_benchmark.h"
     25 #include "tensorflow/core/public/session.h"
     26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
     27 
     28 namespace tensorflow {
     29 namespace graph_transforms {
     30 
// Declare here, so we don't need a public header.
// SetDevice is the graph transform under test: it reads input_graph_def and
// writes the transformed graph into *output_graph_def, returning a Status.
// As exercised by the tests in this file, context.params["device"] names the
// device to assign, and context.params["if_default"] == "true" limits the
// assignment to nodes whose device field is empty (nodes that already carry a
// device keep it). The definition is not in this file.
Status SetDevice(const GraphDef& input_graph_def,
                 const TransformFuncContext& context,
                 GraphDef* output_graph_def);
     35 
     36 namespace {
     37 GraphDef CreateDeviceGraph() {
     38   GraphDef graph_def;
     39 
     40   NodeDef* mul_node1 = graph_def.add_node();
     41   mul_node1->set_name("mul_node1");
     42   mul_node1->set_op("Mul");
     43   mul_node1->set_device("/device:CPU:0");
     44   mul_node1->add_input("add_node2");
     45   mul_node1->add_input("add_node3");
     46 
     47   NodeDef* add_node2 = graph_def.add_node();
     48   add_node2->set_name("add_node2");
     49   add_node2->set_op("Add");
     50   add_node2->add_input("const_node1");
     51   add_node2->add_input("const_node2");
     52   add_node2->set_device("/device:GPU:1");
     53 
     54   NodeDef* add_node3 = graph_def.add_node();
     55   add_node3->set_name("add_node3");
     56   add_node3->set_op("Add");
     57   add_node3->add_input("const_node1");
     58   add_node3->add_input("const_node3");
     59 
     60   NodeDef* const_node1 = graph_def.add_node();
     61   const_node1->set_name("const_node1");
     62   const_node1->set_op("Const");
     63 
     64   NodeDef* const_node2 = graph_def.add_node();
     65   const_node2->set_name("const_node2");
     66   const_node2->set_op("Const");
     67 
     68   NodeDef* const_node3 = graph_def.add_node();
     69   const_node3->set_name("const_node3");
     70   const_node3->set_op("Const");
     71 
     72   NodeDef* add_node4 = graph_def.add_node();
     73   add_node4->set_name("add_node4");
     74   add_node4->set_op("Add");
     75   add_node4->add_input("add_node2");
     76   add_node4->add_input("add_node3");
     77 
     78   return graph_def;
     79 }
     80 }  // namespace
     81 
     82 TEST(SetDeviceTest, TestSetDevice) {
     83   GraphDef graph_def = CreateDeviceGraph();
     84   GraphDef result;
     85   TransformFuncContext context;
     86   context.input_names = {};
     87   context.output_names = {"mul_node1"};
     88   context.params.insert(std::pair<string, std::vector<string>>(
     89       {"device", {string("/device:CPU:0")}}));
     90   TF_ASSERT_OK(SetDevice(graph_def, context, &result));
     91 
     92   std::map<string, const NodeDef*> node_lookup;
     93   MapNamesToNodes(result, &node_lookup);
     94   EXPECT_EQ("/device:CPU:0", node_lookup.at("mul_node1")->device());
     95   EXPECT_EQ("/device:CPU:0", node_lookup.at("add_node2")->device());
     96   EXPECT_EQ("/device:CPU:0", node_lookup.at("add_node3")->device());
     97   EXPECT_EQ("/device:CPU:0", node_lookup.at("const_node1")->device());
     98   EXPECT_EQ("/device:CPU:0", node_lookup.at("const_node2")->device());
     99   EXPECT_EQ("/device:CPU:0", node_lookup.at("const_node3")->device());
    100   EXPECT_EQ("/device:CPU:0", node_lookup.at("add_node4")->device());
    101 }
    102 
    103 TEST(SetDeviceTest, TestSetDeviceIfDefault) {
    104   GraphDef graph_def = CreateDeviceGraph();
    105   GraphDef result;
    106   TransformFuncContext context;
    107   context.input_names = {};
    108   context.output_names = {"mul_node1"};
    109   context.params.insert(std::pair<string, std::vector<string>>(
    110       {"device", {string("/device:GPU:0")}}));
    111   context.params.insert(
    112       std::pair<string, std::vector<string>>({"if_default", {string("true")}}));
    113   TF_ASSERT_OK(SetDevice(graph_def, context, &result));
    114 
    115   std::map<string, const NodeDef*> node_lookup;
    116   MapNamesToNodes(result, &node_lookup);
    117   EXPECT_EQ("/device:CPU:0", node_lookup.at("mul_node1")->device());
    118   EXPECT_EQ("/device:GPU:1", node_lookup.at("add_node2")->device());
    119   EXPECT_EQ("/device:GPU:0", node_lookup.at("add_node3")->device());
    120   EXPECT_EQ("/device:GPU:0", node_lookup.at("const_node1")->device());
    121   EXPECT_EQ("/device:GPU:0", node_lookup.at("const_node2")->device());
    122   EXPECT_EQ("/device:GPU:0", node_lookup.at("const_node3")->device());
    123   EXPECT_EQ("/device:GPU:0", node_lookup.at("add_node4")->device());
    124 }
    125 
    126 }  // namespace graph_transforms
    127 }  // namespace tensorflow
    128