Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/function_testlib.h"
     17 
     18 #include "tensorflow/core/framework/function.h"
     19 #include "tensorflow/core/framework/node_def.pb.h"
     20 #include "tensorflow/core/framework/tensor_testutil.h"
     21 #include "tensorflow/core/framework/versions.pb.h"
     22 #include "tensorflow/core/lib/core/threadpool.h"
     23 #include "tensorflow/core/public/version.h"
     24 
     25 namespace tensorflow {
     26 namespace test {
     27 namespace function {
     28 
     29 typedef FunctionDefHelper FDH;
     30 
     31 GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
     32               gtl::ArraySlice<FunctionDef> funcs) {
     33   GraphDef g;
     34   VersionDef* versions = g.mutable_versions();
     35   versions->set_producer(TF_GRAPH_DEF_VERSION);
     36   versions->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
     37   for (const auto& n : nodes) {
     38     *(g.add_node()) = n;
     39   }
     40   auto lib = g.mutable_library();
     41   for (const auto& f : funcs) {
     42     *(lib->add_function()) = f;
     43   }
     44   return g;
     45 }
     46 
     47 // Helper to construct a NodeDef.
     48 NodeDef NDef(StringPiece name, StringPiece op, gtl::ArraySlice<string> inputs,
     49              gtl::ArraySlice<std::pair<string, FDH::AttrValueWrapper>> attrs,
     50              const string& device) {
     51   NodeDef n;
     52   n.set_name(string(name));
     53   n.set_op(string(op));
     54   for (const auto& in : inputs) n.add_input(in);
     55   n.set_device(device);
     56   for (auto na : attrs) n.mutable_attr()->insert({na.first, na.second.proto});
     57   return n;
     58 }
     59 
     60 FunctionDef NonZero() {
     61   return FDH::Define(
     62       // Name
     63       "NonZero",
     64       // Args
     65       {"x:T"},
     66       // Return values
     67       {"y:T"},
     68       // Attr def
     69       {"T:{float, double, int32, int64, string}"},
     70       // Nodes
     71       {
     72           {{"y"}, "Identity", {"x"}, {{"T", "$T"}}},
     73       });
     74 }
     75 
     76 FunctionDef IsZero() {
     77   const Tensor kZero = test::AsScalar<int64>(0);
     78   return FDH::Define(
     79       // Name
     80       "IsZero",
     81       // Args
     82       {"x: T"},
     83       // Return values
     84       {"equal: T"},
     85       // Attr def
     86       {"T:{float, double, int32, int64, string}"},
     87       {
     88           {{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT64}}},
     89           {{"cast"}, "Cast", {"zero"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
     90           {{"equal"}, "Equal", {"x", "cast"}, {{"T", "$T"}}},
     91       });
     92 }
     93 
     94 FunctionDef RandomUniform() {
     95   const Tensor kZero = test::AsScalar<int64>(0);
     96 
     97   return FDH::Define(
     98       // Name
     99       "RandomUniform",
    100       // Args
    101       {"x: T"},
    102       // Return values
    103       {"random_uniform: int64"},
    104       // Attr def
    105       {"T:{float, double, int32, int64, string}"},
    106       {{{"random_uniform/shape"},
    107         "Const",
    108         {},
    109         {{"value", kZero}, {"dtype", DT_INT64}}},
    110        {{"random_uniform"},
    111         "RandomUniform",
    112         {"random_uniform/shape"},
    113         {{"T", DT_INT32},
    114          {"Tout", DT_FLOAT},
    115          {"seed", 87654321},
    116          {"seed2", 42}}}});
    117 }
    118 
    119 FunctionDef XTimesTwo() {
    120   const Tensor kTwo = test::AsScalar<int64>(2);
    121   return FDH::Define(
    122       // Name
    123       "XTimesTwo",
    124       // Args
    125       {"x: T"},
    126       // Return values
    127       {"y: T"},
    128       // Attr def
    129       {"T: {float, double, int32, int64}"},
    130       // Nodes
    131       {
    132           {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
    133           {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
    134           {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}},
    135       });
    136 }
    137 
    138 FunctionDef TwoDeviceMult() {
    139   const Tensor kTwo = test::AsScalar<int64>(2);
    140   const Tensor kThree = test::AsScalar<int64>(3);
    141   return FDH::Create(
    142       // Name
    143       "TwoDeviceMult",
    144       // Args
    145       {"x: T"},
    146       // Return values
    147       {"y_cpu: T", "y_gpu: T"},
    148       // Attr def
    149       {"T: {float, double, int32, int64}"},
    150       // Nodes
    151       {
    152           {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
    153           {{"num_3"}, "Const", {}, {{"value", kThree}, {"dtype", DT_INT64}}},
    154           {{"factor_2"},
    155            "Cast",
    156            {"num_2:output:0"},
    157            {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
    158           {{"factor_3"},
    159            "Cast",
    160            {"num_3:output:0"},
    161            {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
    162           {{"y_cpu"},
    163            "Mul",
    164            {"x", "factor_2:y:0"},
    165            {{"T", "$T"}},
    166            {},
    167            "/device:CPU:0"},
    168           {{"y_gpu"},
    169            "Mul",
    170            {"x", "factor_3:y:0"},
    171            {{"T", "$T"}},
    172            {},
    173            "/device:GPU:0"},
    174       },
    175       {{"y_cpu", "y_cpu:z:0"}, {"y_gpu", "y_gpu:z:0"}});
    176 }
    177 
    178 FunctionDef TwoDeviceInputOutput() {
    179   const Tensor kTwo = test::AsScalar<float>(2);
    180   const Tensor kThree = test::AsScalar<float>(3);
    181   return FDH::Create(
    182       // Name
    183       "TwoDeviceInputOutput",
    184       // Args
    185       {"x1: T", "x2: T"},
    186       // Return values
    187       {"y_cpu: T", "y_gpu: T"},
    188       // Attr def
    189       {"T: {float}"},
    190       // Nodes
    191       {
    192           {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
    193           {{"num_3"}, "Const", {}, {{"value", kThree}, {"dtype", DT_FLOAT}}},
    194           {{"y_cpu"},
    195            "Mul",
    196            {"x1", "num_2:output:0"},
    197            {{"T", "$T"}},
    198            {},
    199            "/device:CPU:0"},
    200           {{"y_gpu"},
    201            "Mul",
    202            {"x2", "num_3:output:0"},
    203            {{"T", "$T"}},
    204            {},
    205            "/device:GPU:0"},
    206       },
    207       {{"y_cpu", "y_cpu:z:0"}, {"y_gpu", "y_gpu:z:0"}});
    208 }
    209 
    210 FunctionDef FuncWithListInput() {
    211   const Tensor kTwo = test::AsScalar<float>(2);
    212   return FDH::Create(
    213       // Name
    214       "FuncWithListInput",
    215       // Args
    216       {"x1: N * T"},
    217       // Return values
    218       {},
    219       // Attr def
    220       {"T: {float}", "N: int >= 1"},
    221       // Nodes
    222       {
    223           {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
    224       },
    225       {});
    226 }
    227 
    228 FunctionDef FuncWithListOutput() {
    229   const Tensor kTwo = test::AsScalar<float>(2);
    230   return FDH::Create(
    231       // Name
    232       "FuncWithListOutput",
    233       // Args
    234       {},
    235       // Return values
    236       {"y: N * T"},
    237       // Attr def
    238       {"T: {float}", "N: int >= 1"},
    239       // Nodes
    240       {
    241           {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
    242       },
    243       {{"y", "num_2:output:0"}});
    244 }
    245 
    246 FunctionDef XAddX() {
    247   return FDH::Define(
    248       // Name
    249       "XAddX",
    250       // Args
    251       {"x: T"},
    252       // Return values
    253       {"y: T"},
    254       // Attr def
    255       {"T: {float, double, int32, int64}"},
    256       // Nodes
    257       {
    258           {{"y"}, "Add", {"x", "x"}, {{"T", "$T"}}},
    259       });
    260 }
    261 
    262 FunctionDef XTimesTwoInt32() {
    263   const Tensor kTwo = test::AsScalar<int64>(2);
    264   return FDH::Define(
    265       // Name
    266       "XTimesTwoInt32",
    267       // Args
    268       {"x: int32"},
    269       // Return values
    270       {"y: int32"}, {},
    271       // Nodes
    272       {
    273           {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
    274           {{"scale"},
    275            "Cast",
    276            {"two"},
    277            {{"SrcT", DT_INT64}, {"DstT", DT_INT32}}},
    278           {{"y"}, "Mul", {"x", "scale"}, {{"T", DT_INT32}}},
    279       });
    280 }
    281 
    282 FunctionDef XTimesFour() {
    283   return FDH::Create(
    284       // Name
    285       "XTimesFour",
    286       // Args
    287       {"x: T"},
    288       // Return values
    289       {"y: T"},
    290       // Attr def
    291       {"T: {float, double, int32, int64}"},
    292       // Nodes
    293       {
    294           {{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}},
    295           {{"y"}, "XTimesTwo", {"x2:y:0"}, {{"T", "$T"}}},
    296       },
    297       {{"y", "y:y:0"}});
    298 }
    299 
    300 FunctionDef XTimes16() {
    301   return FDH::Create(
    302       // Name
    303       "XTimes16",
    304       // Args
    305       {"x: T"},
    306       // Return values
    307       {"y: T"},
    308       // Attr def
    309       {"T: {float, double, int32, int64}"},
    310       // Nodes
    311       {
    312           {{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}},
    313           {{"y"}, "XTimesFour", {"x4:y:0"}, {{"T", "$T"}}},
    314       },
    315       {{"y", "y:y:0"}});
    316 }
    317 
    318 FunctionDef WXPlusB() {
    319   return FDH::Define(
    320       // Name
    321       "WXPlusB",
    322       // Args
    323       {"w: T", "x: T", "b: T"},
    324       // Return values
    325       {"y: T"},
    326       // Attr def
    327       {"T: {float, double}"},
    328       // Nodes
    329       {{{"mm"},
    330         "MatMul",
    331         {"w", "x"},
    332         {{"T", "$T"},
    333          {"transpose_a", false},
    334          {"transpose_b", false},
    335          {"_kernel", "eigen"}}},
    336        {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
    337 }
    338 
    339 FunctionDef Swap() {
    340   return FDH::Define(
    341       // Name
    342       "Swap",
    343       // Args
    344       {"i0: T", "i1: T"},
    345       // Return values
    346       {"o0: T", "o1: T"},
    347       // Attr def
    348       {"T: {float, double}"},
    349       // Nodes
    350       {{{"o0"}, "Identity", {"i1"}, {{"T", "$T"}}},
    351        {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}});
    352 }
    353 
    354 FunctionDef EmptyBodySwap() {
    355   return FDH::Create(
    356       // Name
    357       "EmptyBodySwap",
    358       // Args
    359       {"i0: T", "i1: T"},
    360       // Return values
    361       {"o0: T", "o1: T"},
    362       // Attr def
    363       {"T: {float, double}"},
    364       // Nodes
    365       {},
    366       // Output mapping
    367       {{"o0", "i1"}, {"o1", "i0"}});
    368 }
    369 
    370 FunctionDef ResourceOutput() {
    371   const Tensor kTwo = test::AsScalar<float>(2);
    372   return FDH::Create(
    373       // Name
    374       "ResourceOutput",
    375       // Args
    376       {"x: float", "y: resource"},
    377       // Return values
    378       {"y_out: resource", "two_x: float"},
    379       // Attr def
    380       {},
    381       // Nodes
    382       {
    383           {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
    384           {{"mul"}, "Mul", {"x", "two:output:0"}, {{"T", DT_FLOAT}}, {}},
    385       },
    386       {{"y_out", "y"}, {"two_x", "mul:z:0"}});
    387 }
    388 
    389 FunctionDef ReadResourceVariable() {
    390   return FDH::Create(
    391       // Name
    392       "ReadResourceVariable",
    393       // Args
    394       {"x: resource"},
    395       // Return values
    396       {"y: float"},
    397       // Attr def
    398       {},
    399       // Nodes
    400       {
    401           {{"read"}, "ReadVariableOp", {"x"}, {{"dtype", DT_FLOAT}}, {}},
    402       },
    403       {{"y", "read:value:0"}});
    404 }
    405 
    406 FunctionDef InvalidControlFlow() {
    407   return FDH::Create(
    408       // Name
    409       "InvalidControlFlow",
    410       // Args
    411       {"i: int32"},
    412       // Return values
    413       {"o: int32"},
    414       // Attr def
    415       {},
    416       // Nodes
    417       {{{"enter"}, "Enter", {"i"}, {{"T", DT_INT32}, {"frame_name", "while"}}},
    418        {{"add"}, "Add", {"enter:output", "i"}, {{"T", DT_INT32}}}},
    419       // Output mapping
    420       {{"o", "add:z"}});
    421 }
    422 
    423 FunctionDef LessThanOrEqualToN(int64 N) {
    424   const Tensor kN = test::AsScalar<int64>(N);
    425   return FDH::Define(
    426       // Name
    427       "LessThanOrEqualToN",
    428       // Args
    429       {"x: T"},
    430       // Return values
    431       {"z: bool"},
    432       // Attr def
    433       {"T: {float, double, int32, int64}"},
    434       // Nodes
    435       {
    436           {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
    437           {{"y"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
    438           {{"z"}, "LessEqual", {"x", "y"}, {{"T", "$T"}}},
    439       });
    440 }
    441 
    442 FunctionDef XPlusOneXTimesY() {
    443   const Tensor kOne = test::AsScalar<int64>(1);
    444   return FDH::Define(
    445       // Name
    446       "XPlusOneXTimesY",
    447       // Args
    448       {"x: T", "y: T"},
    449       // Return values
    450       {"s: T", "t: T"},
    451       // Attr def
    452       {"T: {float, double, int32, int64}"},
    453       // Nodes
    454       {{{"one"}, "Const", {}, {{"value", kOne}, {"dtype", DT_INT64}}},
    455        {{"increment"}, "Cast", {"one"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
    456        {{"s"}, "Add", {"x", "increment"}, {{"T", "$T"}}},
    457        {{"t"}, "Mul", {"x", "y"}, {{"T", "$T"}}}});
    458 }
    459 
    460 FunctionDef XYXLessThanOrEqualToN(int64 N) {
    461   const Tensor kN = test::AsScalar<int64>(N);
    462   return FDH::Define(
    463       // Name
    464       "XYXLessThanOrEqualToN",
    465       // Args
    466       {"x: T", "y: T"},
    467       // Return values
    468       {"z: bool"},
    469       // Attr def
    470       {"T: {float, double, int32, int64}"},
    471       // Nodes
    472       {
    473           {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
    474           {{"N1"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
    475           {{"z"}, "LessEqual", {"x", "N1"}, {{"T", "$T"}}},
    476       });
    477 }
    478 
    479 void FunctionTestSchedClosure(std::function<void()> fn) {
    480   static thread::ThreadPool* w =
    481       new thread::ThreadPool(Env::Default(), "Test", 8);
    482   w->Schedule(std::move(fn));
    483 }
    484 
    485 }  // end namespace function
    486 }  // end namespace test
    487 }  // end namespace tensorflow
    488