1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/graph/testlib.h" 17 18 #include <vector> 19 #include "tensorflow/core/framework/graph.pb.h" 20 #include "tensorflow/core/framework/node_def_builder.h" 21 #include "tensorflow/core/framework/node_def_util.h" 22 #include "tensorflow/core/framework/op.h" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/types.h" 25 #include "tensorflow/core/framework/types.pb.h" 26 #include "tensorflow/core/graph/graph.h" 27 #include "tensorflow/core/graph/node_builder.h" 28 #include "tensorflow/core/kernels/constant_op.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/platform/logging.h" 31 32 namespace tensorflow { 33 34 // HostConst: forced to generate output on the host. 35 // Only used by testlib; no op is registered for this kernel 36 // externally (i.e., in array_ops.cc) 37 REGISTER_KERNEL_BUILDER(Name("HostConst").Device(DEVICE_CPU), HostConstantOp); 38 REGISTER_KERNEL_BUILDER( 39 Name("HostConst").Device(DEVICE_GPU).HostMemory("output"), HostConstantOp); 40 #ifdef TENSORFLOW_USE_SYCL 41 REGISTER_KERNEL_BUILDER( 42 Name("HostConst").Device(DEVICE_SYCL).HostMemory("output"), HostConstantOp); 43 #endif // TENSORFLOW_USE_SYCL 44 45 // Register the HostConst Op 46 // Returns a constant tensor on the host. Useful for writing C++ tests 47 // and benchmarks which run on GPU but require arguments pinned to the host. 48 // Used by test::graph::HostConstant. 49 // value: Attr `value` is the tensor to return. 50 REGISTER_OP("HostConst") 51 .Output("output: dtype") 52 .Attr("value: tensor") 53 .Attr("dtype: type"); 54 55 namespace test { 56 namespace graph { 57 58 Node* Send(Graph* g, Node* input, const string& tensor, const string& sender, 59 const uint64 sender_incarnation, const string& receiver) { 60 Node* ret; 61 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send") 62 .Input(input, 0) 63 .Attr("tensor_name", tensor) 64 .Attr("send_device", sender) 65 .Attr("send_device_incarnation", 66 static_cast<int64>(sender_incarnation)) 67 .Attr("recv_device", receiver) 68 .Finalize(g, &ret)); 69 return ret; 70 } 71 72 Node* Recv(Graph* g, const string& tensor, const string& type, 73 const string& sender, const uint64 sender_incarnation, 74 const string& receiver) { 75 Node* ret; 76 DataType dtype; 77 CHECK(DataTypeFromString(type, &dtype)); 78 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv") 79 .Attr("tensor_type", dtype) 80 .Attr("tensor_name", tensor) 81 .Attr("send_device", sender) 82 .Attr("send_device_incarnation", 83 static_cast<int64>(sender_incarnation)) 84 .Attr("recv_device", receiver) 85 .Finalize(g, &ret)); 86 return ret; 87 } 88 89 Node* Constant(Graph* g, const Tensor& tensor) { 90 Node* ret; 91 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const") 92 .Attr("dtype", tensor.dtype()) 93 .Attr("value", tensor) 94 .Finalize(g, &ret)); 95 return ret; 96 } 97 98 Node* Constant(Graph* g, const Tensor& tensor, const string& name) { 99 Node* ret; 100 TF_CHECK_OK(NodeBuilder(name, "Const") 101 .Attr("dtype", tensor.dtype()) 102 .Attr("value", tensor) 103 .Finalize(g, &ret)); 104 return ret; 105 } 106 107 Node* HostConstant(Graph* g, const Tensor& tensor) { 108 return HostConstant(g, tensor, g->NewName("n")); 109 } 110 111 Node* HostConstant(Graph* g, const Tensor& tensor, const string& name) { 112 Node* ret; 113 TF_CHECK_OK(NodeBuilder(name, "HostConst") 114 .Attr("dtype", tensor.dtype()) 115 .Attr("value", tensor) 116 .Finalize(g, &ret)); 117 return ret; 118 } 119 120 Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) { 121 Node* ret; 122 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable") 123 .Attr("dtype", dtype) 124 .Attr("shape", shape) 125 .Finalize(g, &ret)); 126 return ret; 127 } 128 129 Node* Var(Graph* g, const DataType dtype, const TensorShape& shape, 130 const string& name) { 131 Node* ret; 132 TF_CHECK_OK(NodeBuilder(name, "Variable") 133 .Attr("dtype", dtype) 134 .Attr("shape", shape) 135 .Finalize(g, &ret)); 136 return ret; 137 } 138 139 Node* Assign(Graph* g, Node* var, Node* val) { 140 Node* ret; 141 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign") 142 .Input(var) 143 .Input(val) 144 .Attr("use_locking", true) 145 .Finalize(g, &ret)); 146 return ret; 147 } 148 149 Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes, 150 bool keep_dims) { 151 Node* ret; 152 TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce, g->op_registry()) 153 .Input(data) 154 .Input(axes) 155 .Attr("keep_dims", keep_dims) 156 .Finalize(g, &ret)); 157 return ret; 158 } 159 160 Node* QuantizeToUINT8(Graph* g, Node* data) { 161 Node* ret; 162 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize") 163 .Input(data) 164 .Attr("T", DT_QUINT8) 165 .Attr("max_range", 1.0f) 166 .Attr("min_range", -1.0f) 167 .Finalize(g, &ret)); 168 return ret; 169 } 170 171 Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a, 172 bool transpose_b) { 173 Node* ret; 174 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul") 175 .Input(in0) 176 .Input(in1) 177 .Attr("transpose_a", transpose_a) 178 .Attr("transpose_b", transpose_b) 179 .Finalize(g, &ret)); 180 return ret; 181 } 182 183 Node* BatchMatmul(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) { 184 Node* ret; 185 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMul") 186 .Input(in0) 187 .Input(in1) 188 .Attr("adj_x", adj_x) 189 .Attr("adj_y", adj_y) 190 .Finalize(g, &ret)); 191 return ret; 192 } 193 194 Node* RandomNumberGenerator(const string& op, Graph* g, Node* input, 195 DataType dtype) { 196 Node* ret; 197 TF_CHECK_OK(NodeBuilder(g->NewName("n"), op, g->op_registry()) 198 .Input(input) 199 .Attr("dtype", dtype) 200 .Attr("seed", 0) 201 .Finalize(g, &ret)); 202 return ret; 203 } 204 205 Node* RandomUniform(Graph* g, Node* input, DataType dtype) { 206 return RandomNumberGenerator("RandomUniform", g, input, dtype); 207 } 208 209 Node* RandomGaussian(Graph* g, Node* input, DataType dtype) { 210 return RandomNumberGenerator("RandomStandardNormal", g, input, dtype); 211 } 212 213 Node* TruncatedNormal(Graph* g, Node* input, DataType dtype) { 214 return RandomNumberGenerator("TruncatedNormal", g, input, dtype); 215 } 216 217 Node* RandomGamma(Graph* g, Node* shape, Node* alpha) { 218 Node* ret; 219 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomGamma") 220 .Input(shape) 221 .Input(alpha) 222 .Attr("seed", 0) 223 .Finalize(g, &ret)); 224 return ret; 225 } 226 227 Node* RandomPoisson(Graph* g, Node* shape, Node* lam) { 228 Node* ret; 229 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomPoisson") 230 .Input(shape) 231 .Input(lam) 232 .Attr("seed", 0) 233 .Finalize(g, &ret)); 234 return ret; 235 } 236 237 Node* Unary(Graph* g, const string& func, Node* input, int index) { 238 Node* ret; 239 TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry()) 240 .Input(input, index) 241 .Finalize(g, &ret)); 242 return ret; 243 } 244 245 Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) { 246 Node* ret; 247 TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry()) 248 .Input(in0) 249 .Input(in1) 250 .Finalize(g, &ret)); 251 return ret; 252 } 253 254 Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) { 255 Node* ret; 256 auto b = NodeBuilder(g->NewName("n"), func, g->op_registry()); 257 for (Node* n : ins) b = b.Input(n); 258 TF_CHECK_OK(b.Finalize(g, &ret)); 259 return ret; 260 } 261 262 Node* Identity(Graph* g, Node* input, int index) { 263 Node* ret; 264 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Identity") 265 .Input(input, index) 266 .Finalize(g, &ret)); 267 return ret; 268 } 269 270 Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); } 271 272 Node* Reverse(Graph* g, Node* tensor, Node* axis) { 273 return Binary(g, "ReverseV2", tensor, axis); 274 } 275 276 Node* Roll(Graph* g, Node* input, Node* shift, Node* axis) { 277 Node* ret; 278 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Roll", g->op_registry()) 279 .Input(input) 280 .Input(shift) 281 .Input(axis) 282 .Finalize(g, &ret)); 283 return ret; 284 } 285 286 Node* Error(Graph* g, Node* input, const string& errmsg) { 287 Node* ret; 288 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error") 289 .Input(input) 290 .Attr("message", errmsg) 291 .Finalize(g, &ret)); 292 return ret; 293 } 294 295 Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) { 296 DCHECK(out_type != invalid_type); 297 Node* ret; 298 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "InvalidRefType") 299 .Attr("TIn", out_type) 300 .Attr("TOut", invalid_type) 301 .Finalize(g, &ret)); 302 return ret; 303 } 304 305 Node* Delay(Graph* g, Node* input, Microseconds delay_micros) { 306 Node* ret; 307 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay") 308 .Input(input) 309 .Attr("micros", delay_micros.value()) 310 .Finalize(g, &ret)); 311 return ret; 312 } 313 314 Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs) { 315 Node* ret; 316 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp") 317 .ControlInputs(control_inputs) 318 .Finalize(g, &ret)); 319 return ret; 320 } 321 322 Node* Switch(Graph* g, Node* in0, Node* in1) { 323 Node* ret; 324 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch") 325 .Input(in0) 326 .Input(in1) 327 .Finalize(g, &ret)); 328 return ret; 329 } 330 331 Node* Enter(Graph* g, Node* input, const string& frame_name) { 332 Node* ret; 333 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter") 334 .Input(input) 335 .Attr("frame_name", frame_name) 336 .Finalize(g, &ret)); 337 return ret; 338 } 339 340 Node* Exit(Graph* g, Node* input) { 341 Node* ret; 342 TF_CHECK_OK( 343 NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret)); 344 return ret; 345 } 346 347 Node* Merge(Graph* g, Node* in0, Node* in1) { 348 Node* ret; 349 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge") 350 .Input({in0, in1}) 351 .Finalize(g, &ret)); 352 return ret; 353 } 354 355 Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in) { 356 std::vector<NodeBuilder::NodeOut> inputs; 357 inputs.reserve(remaining_in.size() + 1); 358 inputs.emplace_back(in0); 359 for (const string& in_name : remaining_in) { 360 inputs.emplace_back(in_name, 0, inputs[0].dt); 361 } 362 363 Node* ret; 364 TF_CHECK_OK( 365 NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret)); 366 return ret; 367 } 368 369 Node* Concat(Graph* g, Node* concat_dim, gtl::ArraySlice<Node*> tensors) { 370 std::vector<NodeBuilder::NodeOut> nodeouts; 371 nodeouts.reserve(tensors.size()); 372 for (auto const t : tensors) { 373 nodeouts.emplace_back(t); 374 } 375 Node* ret; 376 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Concat") 377 .Input(concat_dim) 378 .Input(nodeouts) 379 .Finalize(g, &ret)); 380 return ret; 381 } 382 383 Node* ConcatV2(Graph* g, gtl::ArraySlice<Node*> tensors, Node* concat_dim) { 384 std::vector<NodeBuilder::NodeOut> nodeouts; 385 nodeouts.reserve(tensors.size()); 386 for (auto const t : tensors) { 387 nodeouts.emplace_back(t); 388 } 389 Node* ret; 390 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "ConcatV2") 391 .Input(nodeouts) 392 .Input(concat_dim) 393 .Finalize(g, &ret)); 394 return ret; 395 } 396 397 Node* Next(Graph* g, const string& name, Node* input) { 398 Node* ret; 399 TF_CHECK_OK( 400 NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret)); 401 return ret; 402 } 403 404 Node* LoopCond(Graph* g, Node* input) { 405 Node* ret; 406 TF_CHECK_OK( 407 NodeBuilder(g->NewName("n"), "LoopCond").Input(input).Finalize(g, &ret)); 408 return ret; 409 } 410 411 Node* Less(Graph* g, Node* in0, Node* in1) { 412 return Binary(g, "Less", in0, in1); 413 } 414 415 Node* Select(Graph* g, Node* c, Node* inx, Node* iny) { 416 Node* ret; 417 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select") 418 .Input(c) 419 .Input(inx) 420 .Input(iny) 421 .Finalize(g, &ret)); 422 return ret; 423 } 424 425 Node* Cast(Graph* g, Node* in, DataType dst) { 426 Node* ret; 427 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast") 428 .Input(in) 429 .Attr("DstT", dst) 430 .Finalize(g, &ret)); 431 return ret; 432 } 433 434 Node* Gather(Graph* g, Node* in0, Node* in1, Node* axis) { 435 Node* ret; 436 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GatherV2") 437 .Input(in0) 438 .Input(in1) 439 .Input(axis) 440 .Finalize(g, &ret)); 441 return ret; 442 } 443 444 Node* GetSessionTensor(Graph* g, Node* in) { 445 Node* ret; 446 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GetSessionTensor") 447 .Input(in, 0) 448 .Attr("dtype", DT_FLOAT) 449 .Finalize(g, &ret)); 450 return ret; 451 } 452 453 Node* Relu(Graph* g, Node* in) { 454 Node* ret; 455 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu") 456 .Input(in, 0) 457 .Attr("T", DT_FLOAT) 458 .Finalize(g, &ret)); 459 return ret; 460 } 461 462 Node* Relu6(Graph* g, Node* in) { 463 Node* ret; 464 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu6") 465 .Input(in, 0) 466 .Attr("T", DT_FLOAT) 467 .Finalize(g, &ret)); 468 return ret; 469 } 470 471 Node* BiasAdd(Graph* g, Node* value, Node* bias) { 472 Node* ret; 473 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BiasAdd") 474 .Input(value) 475 .Input(bias) 476 .Attr("T", DT_FLOAT) 477 .Finalize(g, &ret)); 478 return ret; 479 } 480 481 Node* Conv2D(Graph* g, Node* in0, Node* in1) { 482 Node* ret; 483 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Conv2D") 484 .Input(in0) 485 .Input(in1) 486 .Attr("T", DT_FLOAT) 487 .Attr("strides", {1, 1, 1, 1}) 488 .Attr("padding", "SAME") 489 .Finalize(g, &ret)); 490 return ret; 491 } 492 493 Node* Diag(Graph* g, Node* in, DataType type) { 494 Node* ret; 495 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Diag") 496 .Input(in) 497 .Attr("T", type) 498 .Finalize(g, &ret)); 499 return ret; 500 } 501 502 Node* DiagPart(Graph* g, Node* in, DataType type) { 503 Node* ret; 504 TF_CHECK_OK(NodeBuilder(g->NewName("n"), "DiagPart") 505 .Input(in) 506 .Attr("T", type) 507 .Finalize(g, &ret)); 508 return ret; 509 } 510 511 void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); } 512 513 } // end namespace graph 514 } // end namespace test 515 } // end namespace tensorflow 516