1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/function_testlib.h" 17 18 #include "tensorflow/core/framework/function.h" 19 #include "tensorflow/core/framework/node_def.pb.h" 20 #include "tensorflow/core/framework/tensor_testutil.h" 21 #include "tensorflow/core/framework/versions.pb.h" 22 #include "tensorflow/core/lib/core/threadpool.h" 23 #include "tensorflow/core/public/version.h" 24 25 namespace tensorflow { 26 namespace test { 27 namespace function { 28 29 typedef FunctionDefHelper FDH; 30 31 GraphDef GDef(gtl::ArraySlice<NodeDef> nodes, 32 gtl::ArraySlice<FunctionDef> funcs) { 33 GraphDef g; 34 VersionDef* versions = g.mutable_versions(); 35 versions->set_producer(TF_GRAPH_DEF_VERSION); 36 versions->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER); 37 for (const auto& n : nodes) { 38 *(g.add_node()) = n; 39 } 40 auto lib = g.mutable_library(); 41 for (const auto& f : funcs) { 42 *(lib->add_function()) = f; 43 } 44 return g; 45 } 46 47 // Helper to construct a NodeDef. 48 NodeDef NDef(StringPiece name, StringPiece op, gtl::ArraySlice<string> inputs, 49 gtl::ArraySlice<std::pair<string, FDH::AttrValueWrapper>> attrs, 50 const string& device) { 51 NodeDef n; 52 n.set_name(string(name)); 53 n.set_op(string(op)); 54 for (const auto& in : inputs) n.add_input(in); 55 n.set_device(device); 56 for (auto na : attrs) n.mutable_attr()->insert({na.first, na.second.proto}); 57 return n; 58 } 59 60 FunctionDef NonZero() { 61 return FDH::Define( 62 // Name 63 "NonZero", 64 // Args 65 {"x:T"}, 66 // Return values 67 {"y:T"}, 68 // Attr def 69 {"T:{float, double, int32, int64, string}"}, 70 // Nodes 71 { 72 {{"y"}, "Identity", {"x"}, {{"T", "$T"}}}, 73 }); 74 } 75 76 FunctionDef IsZero() { 77 const Tensor kZero = test::AsScalar<int64>(0); 78 return FDH::Define( 79 // Name 80 "IsZero", 81 // Args 82 {"x: T"}, 83 // Return values 84 {"equal: T"}, 85 // Attr def 86 {"T:{float, double, int32, int64, string}"}, 87 { 88 {{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT64}}}, 89 {{"cast"}, "Cast", {"zero"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, 90 {{"equal"}, "Equal", {"x", "cast"}, {{"T", "$T"}}}, 91 }); 92 } 93 94 FunctionDef RandomUniform() { 95 const Tensor kZero = test::AsScalar<int64>(0); 96 97 return FDH::Define( 98 // Name 99 "RandomUniform", 100 // Args 101 {"x: T"}, 102 // Return values 103 {"random_uniform: int64"}, 104 // Attr def 105 {"T:{float, double, int32, int64, string}"}, 106 {{{"random_uniform/shape"}, 107 "Const", 108 {}, 109 {{"value", kZero}, {"dtype", DT_INT64}}}, 110 {{"random_uniform"}, 111 "RandomUniform", 112 {"random_uniform/shape"}, 113 {{"T", DT_INT32}, 114 {"Tout", DT_FLOAT}, 115 {"seed", 87654321}, 116 {"seed2", 42}}}}); 117 } 118 119 FunctionDef XTimesTwo() { 120 const Tensor kTwo = test::AsScalar<int64>(2); 121 return FDH::Define( 122 // Name 123 "XTimesTwo", 124 // Args 125 {"x: T"}, 126 // Return values 127 {"y: T"}, 128 // Attr def 129 {"T: {float, double, int32, int64}"}, 130 // Nodes 131 { 132 {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}}, 133 {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, 134 {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}}, 135 }); 136 } 137 138 FunctionDef TwoDeviceMult() { 139 const Tensor kTwo = test::AsScalar<int64>(2); 140 const Tensor kThree = test::AsScalar<int64>(3); 141 return FDH::Create( 142 // Name 143 "TwoDeviceMult", 144 // Args 145 {"x: T"}, 146 // Return values 147 {"y_cpu: T", "y_gpu: T"}, 148 // Attr def 149 {"T: {float, double, int32, int64}"}, 150 // Nodes 151 { 152 {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}}, 153 {{"num_3"}, "Const", {}, {{"value", kThree}, {"dtype", DT_INT64}}}, 154 {{"factor_2"}, 155 "Cast", 156 {"num_2:output:0"}, 157 {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, 158 {{"factor_3"}, 159 "Cast", 160 {"num_3:output:0"}, 161 {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, 162 {{"y_cpu"}, 163 "Mul", 164 {"x", "factor_2:y:0"}, 165 {{"T", "$T"}}, 166 {}, 167 "/device:CPU:0"}, 168 {{"y_gpu"}, 169 "Mul", 170 {"x", "factor_3:y:0"}, 171 {{"T", "$T"}}, 172 {}, 173 "/device:GPU:0"}, 174 }, 175 {{"y_cpu", "y_cpu:z:0"}, {"y_gpu", "y_gpu:z:0"}}); 176 } 177 178 FunctionDef TwoDeviceInputOutput() { 179 const Tensor kTwo = test::AsScalar<float>(2); 180 const Tensor kThree = test::AsScalar<float>(3); 181 return FDH::Create( 182 // Name 183 "TwoDeviceInputOutput", 184 // Args 185 {"x1: T", "x2: T"}, 186 // Return values 187 {"y_cpu: T", "y_gpu: T"}, 188 // Attr def 189 {"T: {float}"}, 190 // Nodes 191 { 192 {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}}, 193 {{"num_3"}, "Const", {}, {{"value", kThree}, {"dtype", DT_FLOAT}}}, 194 {{"y_cpu"}, 195 "Mul", 196 {"x1", "num_2:output:0"}, 197 {{"T", "$T"}}, 198 {}, 199 "/device:CPU:0"}, 200 {{"y_gpu"}, 201 "Mul", 202 {"x2", "num_3:output:0"}, 203 {{"T", "$T"}}, 204 {}, 205 "/device:GPU:0"}, 206 }, 207 {{"y_cpu", "y_cpu:z:0"}, {"y_gpu", "y_gpu:z:0"}}); 208 } 209 210 FunctionDef FuncWithListInput() { 211 const Tensor kTwo = test::AsScalar<float>(2); 212 return FDH::Create( 213 // Name 214 "FuncWithListInput", 215 // Args 216 {"x1: N * T"}, 217 // Return values 218 {}, 219 // Attr def 220 {"T: {float}", "N: int >= 1"}, 221 // Nodes 222 { 223 {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}}, 224 }, 225 {}); 226 } 227 228 FunctionDef FuncWithListOutput() { 229 const Tensor kTwo = test::AsScalar<float>(2); 230 return FDH::Create( 231 // Name 232 "FuncWithListOutput", 233 // Args 234 {}, 235 // Return values 236 {"y: N * T"}, 237 // Attr def 238 {"T: {float}", "N: int >= 1"}, 239 // Nodes 240 { 241 {{"num_2"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}}, 242 }, 243 {{"y", "num_2:output:0"}}); 244 } 245 246 FunctionDef XAddX() { 247 return FDH::Define( 248 // Name 249 "XAddX", 250 // Args 251 {"x: T"}, 252 // Return values 253 {"y: T"}, 254 // Attr def 255 {"T: {float, double, int32, int64}"}, 256 // Nodes 257 { 258 {{"y"}, "Add", {"x", "x"}, {{"T", "$T"}}}, 259 }); 260 } 261 262 FunctionDef XTimesTwoInt32() { 263 const Tensor kTwo = test::AsScalar<int64>(2); 264 return FDH::Define( 265 // Name 266 "XTimesTwoInt32", 267 // Args 268 {"x: int32"}, 269 // Return values 270 {"y: int32"}, {}, 271 // Nodes 272 { 273 {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}}, 274 {{"scale"}, 275 "Cast", 276 {"two"}, 277 {{"SrcT", DT_INT64}, {"DstT", DT_INT32}}}, 278 {{"y"}, "Mul", {"x", "scale"}, {{"T", DT_INT32}}}, 279 }); 280 } 281 282 FunctionDef XTimesFour() { 283 return FDH::Create( 284 // Name 285 "XTimesFour", 286 // Args 287 {"x: T"}, 288 // Return values 289 {"y: T"}, 290 // Attr def 291 {"T: {float, double, int32, int64}"}, 292 // Nodes 293 { 294 {{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}}, 295 {{"y"}, "XTimesTwo", {"x2:y:0"}, {{"T", "$T"}}}, 296 }, 297 {{"y", "y:y:0"}}); 298 } 299 300 FunctionDef XTimes16() { 301 return FDH::Create( 302 // Name 303 "XTimes16", 304 // Args 305 {"x: T"}, 306 // Return values 307 {"y: T"}, 308 // Attr def 309 {"T: {float, double, int32, int64}"}, 310 // Nodes 311 { 312 {{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}}, 313 {{"y"}, "XTimesFour", {"x4:y:0"}, {{"T", "$T"}}}, 314 }, 315 {{"y", "y:y:0"}}); 316 } 317 318 FunctionDef WXPlusB() { 319 return FDH::Define( 320 // Name 321 "WXPlusB", 322 // Args 323 {"w: T", "x: T", "b: T"}, 324 // Return values 325 {"y: T"}, 326 // Attr def 327 {"T: {float, double}"}, 328 // Nodes 329 {{{"mm"}, 330 "MatMul", 331 {"w", "x"}, 332 {{"T", "$T"}, 333 {"transpose_a", false}, 334 {"transpose_b", false}, 335 {"_kernel", "eigen"}}}, 336 {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}}); 337 } 338 339 FunctionDef Swap() { 340 return FDH::Define( 341 // Name 342 "Swap", 343 // Args 344 {"i0: T", "i1: T"}, 345 // Return values 346 {"o0: T", "o1: T"}, 347 // Attr def 348 {"T: {float, double}"}, 349 // Nodes 350 {{{"o0"}, "Identity", {"i1"}, {{"T", "$T"}}}, 351 {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}}); 352 } 353 354 FunctionDef EmptyBodySwap() { 355 return FDH::Create( 356 // Name 357 "EmptyBodySwap", 358 // Args 359 {"i0: T", "i1: T"}, 360 // Return values 361 {"o0: T", "o1: T"}, 362 // Attr def 363 {"T: {float, double}"}, 364 // Nodes 365 {}, 366 // Output mapping 367 {{"o0", "i1"}, {"o1", "i0"}}); 368 } 369 370 FunctionDef ResourceOutput() { 371 const Tensor kTwo = test::AsScalar<float>(2); 372 return FDH::Create( 373 // Name 374 "ResourceOutput", 375 // Args 376 {"x: float", "y: resource"}, 377 // Return values 378 {"y_out: resource", "two_x: float"}, 379 // Attr def 380 {}, 381 // Nodes 382 { 383 {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}}, 384 {{"mul"}, "Mul", {"x", "two:output:0"}, {{"T", DT_FLOAT}}, {}}, 385 }, 386 {{"y_out", "y"}, {"two_x", "mul:z:0"}}); 387 } 388 389 FunctionDef ReadResourceVariable() { 390 return FDH::Create( 391 // Name 392 "ReadResourceVariable", 393 // Args 394 {"x: resource"}, 395 // Return values 396 {"y: float"}, 397 // Attr def 398 {}, 399 // Nodes 400 { 401 {{"read"}, "ReadVariableOp", {"x"}, {{"dtype", DT_FLOAT}}, {}}, 402 }, 403 {{"y", "read:value:0"}}); 404 } 405 406 FunctionDef InvalidControlFlow() { 407 return FDH::Create( 408 // Name 409 "InvalidControlFlow", 410 // Args 411 {"i: int32"}, 412 // Return values 413 {"o: int32"}, 414 // Attr def 415 {}, 416 // Nodes 417 {{{"enter"}, "Enter", {"i"}, {{"T", DT_INT32}, {"frame_name", "while"}}}, 418 {{"add"}, "Add", {"enter:output", "i"}, {{"T", DT_INT32}}}}, 419 // Output mapping 420 {{"o", "add:z"}}); 421 } 422 423 FunctionDef LessThanOrEqualToN(int64 N) { 424 const Tensor kN = test::AsScalar<int64>(N); 425 return FDH::Define( 426 // Name 427 "LessThanOrEqualToN", 428 // Args 429 {"x: T"}, 430 // Return values 431 {"z: bool"}, 432 // Attr def 433 {"T: {float, double, int32, int64}"}, 434 // Nodes 435 { 436 {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}}, 437 {{"y"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, 438 {{"z"}, "LessEqual", {"x", "y"}, {{"T", "$T"}}}, 439 }); 440 } 441 442 FunctionDef XPlusOneXTimesY() { 443 const Tensor kOne = test::AsScalar<int64>(1); 444 return FDH::Define( 445 // Name 446 "XPlusOneXTimesY", 447 // Args 448 {"x: T", "y: T"}, 449 // Return values 450 {"s: T", "t: T"}, 451 // Attr def 452 {"T: {float, double, int32, int64}"}, 453 // Nodes 454 {{{"one"}, "Const", {}, {{"value", kOne}, {"dtype", DT_INT64}}}, 455 {{"increment"}, "Cast", {"one"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, 456 {{"s"}, "Add", {"x", "increment"}, {{"T", "$T"}}}, 457 {{"t"}, "Mul", {"x", "y"}, {{"T", "$T"}}}}); 458 } 459 460 FunctionDef XYXLessThanOrEqualToN(int64 N) { 461 const Tensor kN = test::AsScalar<int64>(N); 462 return FDH::Define( 463 // Name 464 "XYXLessThanOrEqualToN", 465 // Args 466 {"x: T", "y: T"}, 467 // Return values 468 {"z: bool"}, 469 // Attr def 470 {"T: {float, double, int32, int64}"}, 471 // Nodes 472 { 473 {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}}, 474 {{"N1"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, 475 {{"z"}, "LessEqual", {"x", "N1"}, {{"T", "$T"}}}, 476 }); 477 } 478 479 void FunctionTestSchedClosure(std::function<void()> fn) { 480 static thread::ThreadPool* w = 481 new thread::ThreadPool(Env::Default(), "Test", 8); 482 w->Schedule(std::move(fn)); 483 } 484 485 } // end namespace function 486 } // end namespace test 487 } // end namespace tensorflow 488