Home | History | Annotate | Download | only in graph
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/graph/testlib.h"
     17 
     18 #include <vector>
     19 #include "tensorflow/core/framework/graph.pb.h"
     20 #include "tensorflow/core/framework/node_def_builder.h"
     21 #include "tensorflow/core/framework/node_def_util.h"
     22 #include "tensorflow/core/framework/op.h"
     23 #include "tensorflow/core/framework/op_kernel.h"
     24 #include "tensorflow/core/framework/types.h"
     25 #include "tensorflow/core/framework/types.pb.h"
     26 #include "tensorflow/core/graph/graph.h"
     27 #include "tensorflow/core/graph/node_builder.h"
     28 #include "tensorflow/core/kernels/constant_op.h"
     29 #include "tensorflow/core/lib/core/status.h"
     30 #include "tensorflow/core/platform/logging.h"
     31 
     32 namespace tensorflow {
     33 
     34 // HostConst: forced to generate output on the host.
     35 // Only used by testlib; no op is registered for this kernel
     36 // externally (i.e., in array_ops.cc)
     37 REGISTER_KERNEL_BUILDER(Name("HostConst").Device(DEVICE_CPU), HostConstantOp);
     38 REGISTER_KERNEL_BUILDER(
     39     Name("HostConst").Device(DEVICE_GPU).HostMemory("output"), HostConstantOp);
     40 #ifdef TENSORFLOW_USE_SYCL
     41 REGISTER_KERNEL_BUILDER(
     42     Name("HostConst").Device(DEVICE_SYCL).HostMemory("output"), HostConstantOp);
     43 #endif  // TENSORFLOW_USE_SYCL
     44 
     45 // Register the HostConst Op
     46 // Returns a constant tensor on the host.  Useful for writing C++ tests
     47 // and benchmarks which run on GPU but require arguments pinned to the host.
     48 // Used by test::graph::HostConstant.
     49 // value: Attr `value` is the tensor to return.
     50 REGISTER_OP("HostConst")
     51     .Output("output: dtype")
     52     .Attr("value: tensor")
     53     .Attr("dtype: type");
     54 
     55 namespace test {
     56 namespace graph {
     57 
     58 Node* Send(Graph* g, Node* input, const string& tensor, const string& sender,
     59            const uint64 sender_incarnation, const string& receiver) {
     60   Node* ret;
     61   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send")
     62                   .Input(input, 0)
     63                   .Attr("tensor_name", tensor)
     64                   .Attr("send_device", sender)
     65                   .Attr("send_device_incarnation",
     66                         static_cast<int64>(sender_incarnation))
     67                   .Attr("recv_device", receiver)
     68                   .Finalize(g, &ret));
     69   return ret;
     70 }
     71 
     72 Node* Recv(Graph* g, const string& tensor, const string& type,
     73            const string& sender, const uint64 sender_incarnation,
     74            const string& receiver) {
     75   Node* ret;
     76   DataType dtype;
     77   CHECK(DataTypeFromString(type, &dtype));
     78   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv")
     79                   .Attr("tensor_type", dtype)
     80                   .Attr("tensor_name", tensor)
     81                   .Attr("send_device", sender)
     82                   .Attr("send_device_incarnation",
     83                         static_cast<int64>(sender_incarnation))
     84                   .Attr("recv_device", receiver)
     85                   .Finalize(g, &ret));
     86   return ret;
     87 }
     88 
     89 Node* Constant(Graph* g, const Tensor& tensor) {
     90   Node* ret;
     91   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const")
     92                   .Attr("dtype", tensor.dtype())
     93                   .Attr("value", tensor)
     94                   .Finalize(g, &ret));
     95   return ret;
     96 }
     97 
     98 Node* Constant(Graph* g, const Tensor& tensor, const string& name) {
     99   Node* ret;
    100   TF_CHECK_OK(NodeBuilder(name, "Const")
    101                   .Attr("dtype", tensor.dtype())
    102                   .Attr("value", tensor)
    103                   .Finalize(g, &ret));
    104   return ret;
    105 }
    106 
    107 Node* HostConstant(Graph* g, const Tensor& tensor) {
    108   return HostConstant(g, tensor, g->NewName("n"));
    109 }
    110 
    111 Node* HostConstant(Graph* g, const Tensor& tensor, const string& name) {
    112   Node* ret;
    113   TF_CHECK_OK(NodeBuilder(name, "HostConst")
    114                   .Attr("dtype", tensor.dtype())
    115                   .Attr("value", tensor)
    116                   .Finalize(g, &ret));
    117   return ret;
    118 }
    119 
    120 Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) {
    121   Node* ret;
    122   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable")
    123                   .Attr("dtype", dtype)
    124                   .Attr("shape", shape)
    125                   .Finalize(g, &ret));
    126   return ret;
    127 }
    128 
    129 Node* Var(Graph* g, const DataType dtype, const TensorShape& shape,
    130           const string& name) {
    131   Node* ret;
    132   TF_CHECK_OK(NodeBuilder(name, "Variable")
    133                   .Attr("dtype", dtype)
    134                   .Attr("shape", shape)
    135                   .Finalize(g, &ret));
    136   return ret;
    137 }
    138 
    139 Node* Assign(Graph* g, Node* var, Node* val) {
    140   Node* ret;
    141   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign")
    142                   .Input(var)
    143                   .Input(val)
    144                   .Attr("use_locking", true)
    145                   .Finalize(g, &ret));
    146   return ret;
    147 }
    148 
    149 Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes,
    150              bool keep_dims) {
    151   Node* ret;
    152   TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce, g->op_registry())
    153                   .Input(data)
    154                   .Input(axes)
    155                   .Attr("keep_dims", keep_dims)
    156                   .Finalize(g, &ret));
    157   return ret;
    158 }
    159 
    160 Node* QuantizeToUINT8(Graph* g, Node* data) {
    161   Node* ret;
    162   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize")
    163                   .Input(data)
    164                   .Attr("T", DT_QUINT8)
    165                   .Attr("max_range", 1.0f)
    166                   .Attr("min_range", -1.0f)
    167                   .Finalize(g, &ret));
    168   return ret;
    169 }
    170 
    171 Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a,
    172              bool transpose_b) {
    173   Node* ret;
    174   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul")
    175                   .Input(in0)
    176                   .Input(in1)
    177                   .Attr("transpose_a", transpose_a)
    178                   .Attr("transpose_b", transpose_b)
    179                   .Finalize(g, &ret));
    180   return ret;
    181 }
    182 
    183 Node* BatchMatmul(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y) {
    184   Node* ret;
    185   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BatchMatMul")
    186                   .Input(in0)
    187                   .Input(in1)
    188                   .Attr("adj_x", adj_x)
    189                   .Attr("adj_y", adj_y)
    190                   .Finalize(g, &ret));
    191   return ret;
    192 }
    193 
    194 Node* RandomNumberGenerator(const string& op, Graph* g, Node* input,
    195                             DataType dtype) {
    196   Node* ret;
    197   TF_CHECK_OK(NodeBuilder(g->NewName("n"), op, g->op_registry())
    198                   .Input(input)
    199                   .Attr("dtype", dtype)
    200                   .Attr("seed", 0)
    201                   .Finalize(g, &ret));
    202   return ret;
    203 }
    204 
    205 Node* RandomUniform(Graph* g, Node* input, DataType dtype) {
    206   return RandomNumberGenerator("RandomUniform", g, input, dtype);
    207 }
    208 
    209 Node* RandomGaussian(Graph* g, Node* input, DataType dtype) {
    210   return RandomNumberGenerator("RandomStandardNormal", g, input, dtype);
    211 }
    212 
    213 Node* TruncatedNormal(Graph* g, Node* input, DataType dtype) {
    214   return RandomNumberGenerator("TruncatedNormal", g, input, dtype);
    215 }
    216 
    217 Node* RandomGamma(Graph* g, Node* shape, Node* alpha) {
    218   Node* ret;
    219   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomGamma")
    220                   .Input(shape)
    221                   .Input(alpha)
    222                   .Attr("seed", 0)
    223                   .Finalize(g, &ret));
    224   return ret;
    225 }
    226 
    227 Node* RandomPoisson(Graph* g, Node* shape, Node* lam) {
    228   Node* ret;
    229   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "RandomPoisson")
    230                   .Input(shape)
    231                   .Input(lam)
    232                   .Attr("seed", 0)
    233                   .Finalize(g, &ret));
    234   return ret;
    235 }
    236 
    237 Node* Unary(Graph* g, const string& func, Node* input, int index) {
    238   Node* ret;
    239   TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
    240                   .Input(input, index)
    241                   .Finalize(g, &ret));
    242   return ret;
    243 }
    244 
    245 Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) {
    246   Node* ret;
    247   TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
    248                   .Input(in0)
    249                   .Input(in1)
    250                   .Finalize(g, &ret));
    251   return ret;
    252 }
    253 
    254 Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) {
    255   Node* ret;
    256   auto b = NodeBuilder(g->NewName("n"), func, g->op_registry());
    257   for (Node* n : ins) b = b.Input(n);
    258   TF_CHECK_OK(b.Finalize(g, &ret));
    259   return ret;
    260 }
    261 
    262 Node* Identity(Graph* g, Node* input, int index) {
    263   Node* ret;
    264   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Identity")
    265                   .Input(input, index)
    266                   .Finalize(g, &ret));
    267   return ret;
    268 }
    269 
    270 Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); }
    271 
    272 Node* Reverse(Graph* g, Node* tensor, Node* axis) {
    273   return Binary(g, "ReverseV2", tensor, axis);
    274 }
    275 
    276 Node* Roll(Graph* g, Node* input, Node* shift, Node* axis) {
    277   Node* ret;
    278   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Roll", g->op_registry())
    279                   .Input(input)
    280                   .Input(shift)
    281                   .Input(axis)
    282                   .Finalize(g, &ret));
    283   return ret;
    284 }
    285 
    286 Node* Error(Graph* g, Node* input, const string& errmsg) {
    287   Node* ret;
    288   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error")
    289                   .Input(input)
    290                   .Attr("message", errmsg)
    291                   .Finalize(g, &ret));
    292   return ret;
    293 }
    294 
    295 Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) {
    296   DCHECK(out_type != invalid_type);
    297   Node* ret;
    298   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "InvalidRefType")
    299                   .Attr("TIn", out_type)
    300                   .Attr("TOut", invalid_type)
    301                   .Finalize(g, &ret));
    302   return ret;
    303 }
    304 
    305 Node* Delay(Graph* g, Node* input, Microseconds delay_micros) {
    306   Node* ret;
    307   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay")
    308                   .Input(input)
    309                   .Attr("micros", delay_micros.value())
    310                   .Finalize(g, &ret));
    311   return ret;
    312 }
    313 
    314 Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs) {
    315   Node* ret;
    316   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp")
    317                   .ControlInputs(control_inputs)
    318                   .Finalize(g, &ret));
    319   return ret;
    320 }
    321 
    322 Node* Switch(Graph* g, Node* in0, Node* in1) {
    323   Node* ret;
    324   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch")
    325                   .Input(in0)
    326                   .Input(in1)
    327                   .Finalize(g, &ret));
    328   return ret;
    329 }
    330 
    331 Node* Enter(Graph* g, Node* input, const string& frame_name) {
    332   Node* ret;
    333   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter")
    334                   .Input(input)
    335                   .Attr("frame_name", frame_name)
    336                   .Finalize(g, &ret));
    337   return ret;
    338 }
    339 
    340 Node* Exit(Graph* g, Node* input) {
    341   Node* ret;
    342   TF_CHECK_OK(
    343       NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret));
    344   return ret;
    345 }
    346 
    347 Node* Merge(Graph* g, Node* in0, Node* in1) {
    348   Node* ret;
    349   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge")
    350                   .Input({in0, in1})
    351                   .Finalize(g, &ret));
    352   return ret;
    353 }
    354 
    355 Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in) {
    356   std::vector<NodeBuilder::NodeOut> inputs;
    357   inputs.reserve(remaining_in.size() + 1);
    358   inputs.emplace_back(in0);
    359   for (const string& in_name : remaining_in) {
    360     inputs.emplace_back(in_name, 0, inputs[0].dt);
    361   }
    362 
    363   Node* ret;
    364   TF_CHECK_OK(
    365       NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret));
    366   return ret;
    367 }
    368 
    369 Node* Concat(Graph* g, Node* concat_dim, gtl::ArraySlice<Node*> tensors) {
    370   std::vector<NodeBuilder::NodeOut> nodeouts;
    371   nodeouts.reserve(tensors.size());
    372   for (auto const t : tensors) {
    373     nodeouts.emplace_back(t);
    374   }
    375   Node* ret;
    376   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Concat")
    377                   .Input(concat_dim)
    378                   .Input(nodeouts)
    379                   .Finalize(g, &ret));
    380   return ret;
    381 }
    382 
    383 Node* ConcatV2(Graph* g, gtl::ArraySlice<Node*> tensors, Node* concat_dim) {
    384   std::vector<NodeBuilder::NodeOut> nodeouts;
    385   nodeouts.reserve(tensors.size());
    386   for (auto const t : tensors) {
    387     nodeouts.emplace_back(t);
    388   }
    389   Node* ret;
    390   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "ConcatV2")
    391                   .Input(nodeouts)
    392                   .Input(concat_dim)
    393                   .Finalize(g, &ret));
    394   return ret;
    395 }
    396 
    397 Node* Next(Graph* g, const string& name, Node* input) {
    398   Node* ret;
    399   TF_CHECK_OK(
    400       NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret));
    401   return ret;
    402 }
    403 
    404 Node* LoopCond(Graph* g, Node* input) {
    405   Node* ret;
    406   TF_CHECK_OK(
    407       NodeBuilder(g->NewName("n"), "LoopCond").Input(input).Finalize(g, &ret));
    408   return ret;
    409 }
    410 
    411 Node* Less(Graph* g, Node* in0, Node* in1) {
    412   return Binary(g, "Less", in0, in1);
    413 }
    414 
    415 Node* Select(Graph* g, Node* c, Node* inx, Node* iny) {
    416   Node* ret;
    417   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select")
    418                   .Input(c)
    419                   .Input(inx)
    420                   .Input(iny)
    421                   .Finalize(g, &ret));
    422   return ret;
    423 }
    424 
    425 Node* Cast(Graph* g, Node* in, DataType dst) {
    426   Node* ret;
    427   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast")
    428                   .Input(in)
    429                   .Attr("DstT", dst)
    430                   .Finalize(g, &ret));
    431   return ret;
    432 }
    433 
    434 Node* Gather(Graph* g, Node* in0, Node* in1, Node* axis) {
    435   Node* ret;
    436   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GatherV2")
    437                   .Input(in0)
    438                   .Input(in1)
    439                   .Input(axis)
    440                   .Finalize(g, &ret));
    441   return ret;
    442 }
    443 
    444 Node* GetSessionTensor(Graph* g, Node* in) {
    445   Node* ret;
    446   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GetSessionTensor")
    447                   .Input(in, 0)
    448                   .Attr("dtype", DT_FLOAT)
    449                   .Finalize(g, &ret));
    450   return ret;
    451 }
    452 
    453 Node* Relu(Graph* g, Node* in) {
    454   Node* ret;
    455   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu")
    456                   .Input(in, 0)
    457                   .Attr("T", DT_FLOAT)
    458                   .Finalize(g, &ret));
    459   return ret;
    460 }
    461 
    462 Node* Relu6(Graph* g, Node* in) {
    463   Node* ret;
    464   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Relu6")
    465                   .Input(in, 0)
    466                   .Attr("T", DT_FLOAT)
    467                   .Finalize(g, &ret));
    468   return ret;
    469 }
    470 
    471 Node* BiasAdd(Graph* g, Node* value, Node* bias) {
    472   Node* ret;
    473   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BiasAdd")
    474                   .Input(value)
    475                   .Input(bias)
    476                   .Attr("T", DT_FLOAT)
    477                   .Finalize(g, &ret));
    478   return ret;
    479 }
    480 
    481 Node* Conv2D(Graph* g, Node* in0, Node* in1) {
    482   Node* ret;
    483   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Conv2D")
    484                   .Input(in0)
    485                   .Input(in1)
    486                   .Attr("T", DT_FLOAT)
    487                   .Attr("strides", {1, 1, 1, 1})
    488                   .Attr("padding", "SAME")
    489                   .Finalize(g, &ret));
    490   return ret;
    491 }
    492 
    493 Node* Diag(Graph* g, Node* in, DataType type) {
    494   Node* ret;
    495   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Diag")
    496                   .Input(in)
    497                   .Attr("T", type)
    498                   .Finalize(g, &ret));
    499   return ret;
    500 }
    501 
    502 Node* DiagPart(Graph* g, Node* in, DataType type) {
    503   Node* ret;
    504   TF_CHECK_OK(NodeBuilder(g->NewName("n"), "DiagPart")
    505                   .Input(in)
    506                   .Attr("T", type)
    507                   .Finalize(g, &ret));
    508   return ret;
    509 }
    510 
    511 void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
    512 
    513 }  // end namespace graph
    514 }  // end namespace test
    515 }  // end namespace tensorflow
    516