/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

// Given shapes of two tensors, computes the broadcast shape.
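// For example, the shapes [2, 3, 1] and [3, 5] broadcast to [2, 3, 5]:
// dimensions are aligned from the right and size-1 dimensions are
// stretched to match the other operand.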
template <typename T>
class BCastArgsOp : public OpKernel {
 public:
  explicit BCastArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES(
        ctx, ctx->num_inputs() == 2,
        errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
    // Read each input shape vector into a BCast::Vec.
    gtl::InlinedVector<BCast::Vec, 4> shapes;
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      const Tensor& in = ctx->input(i);
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()),
                  errors::InvalidArgument("In[", i, "] must be a vector. ",
                                          in.shape().DebugString()));
      BCast::Vec vec;
      for (int64 j = 0; j < in.NumElements(); ++j) {
        vec.push_back(in.vec<T>()(j));
      }
      shapes.push_back(vec);
    }
    BCast bcast(shapes[0], shapes[1]);
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "Incompatible shapes: [", str_util::Join(shapes[0], ","),
                    "] vs. [", str_util::Join(shapes[1], ","), "]"));
    Output(ctx, 0, bcast.output_shape());
  }

  bool IsExpensive() override { return false; }

 private:
  // Writes the shape vector `v` as the idx-th output of the kernel.
  void Output(OpKernelContext* ctx, int idx, const BCast::Vec& v) {
    const int64 len = v.size();
    Tensor* o = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(idx, TensorShape({len}), &o));
    for (int64 i = 0; i < len; ++i) {
      o->flat<T>()(i) = static_cast<T>(v[i]);
    }
  }

  TF_DISALLOW_COPY_AND_ASSIGN(BCastArgsOp);
};


// Given shapes of two tensors, computes the reduction indices for the
// gradient computation.
//
// TODO(zhifengc):
//   1. Add support for n-ary operations (n > 2).
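//
// For example, with x of shape [2, 3, 1] and y of shape [3, 5] the
// broadcast output has shape [2, 3, 5]. The gradient w.r.t. x must be
// summed over output dimension 2 (where x was stretched from size 1 to 5)
// and the gradient w.r.t. y over output dimension 0 (where y was
// implicitly padded), so the op returns r0 = [2] and r1 = [0].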
template <typename T>
class BCastGradArgsOp : public OpKernel {
 public:
  explicit BCastGradArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES(
        ctx, ctx->num_inputs() == 2,
        errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
    // Read each input shape vector into a BCast::Vec.
    gtl::InlinedVector<BCast::Vec, 4> shapes;
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      const Tensor& in = ctx->input(i);
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()),
                  errors::InvalidArgument("In[", i, "] must be a vector. ",
                                          in.shape().DebugString()));
      BCast::Vec vec;
      for (int64 j = 0; j < in.NumElements(); ++j) {
        vec.push_back(in.vec<T>()(j));
      }
      shapes.push_back(vec);
    }
    BCast bcast(shapes[0], shapes[1]);
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "Incompatible shapes: [", str_util::Join(shapes[0], ","),
                    "] vs. [", str_util::Join(shapes[1], ","), "]"));
    Output(ctx, 0, bcast.grad_x_reduce_idx());
    Output(ctx, 1, bcast.grad_y_reduce_idx());
  }

  bool IsExpensive() override { return false; }

 private:
  // Writes the index vector `v` as the idx-th output of the kernel.
  void Output(OpKernelContext* ctx, int idx, const BCast::Vec& v) {
    const int64 len = v.size();
    Tensor* o = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(idx, TensorShape({len}), &o));
    for (int64 i = 0; i < len; ++i) {
      o->flat<T>()(i) = static_cast<T>(v[i]);
    }
  }

  TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp);
};

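// The shape inputs (s0, s1) and the resulting index vectors are small 1-D
// integer tensors, so the registrations below keep them in host memory
// even for the GPU and SYCL devices.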
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int64>);
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int64>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int64>);

#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int64>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int64>);
#endif

REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int64>);
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int64>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int64>);

#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int64>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int64>);
#endif
}  // end namespace tensorflow