/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

// Given shapes of two tensors, computes the broadcast shape.
template <typename T>
class BCastArgsOp : public OpKernel {
 public:
  explicit BCastArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES(
        ctx, ctx->num_inputs() == 2,
        errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
    gtl::InlinedVector<BCast::Vec, 4> shapes;
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      const Tensor& in = ctx->input(i);
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()),
                  errors::InvalidArgument("In[", i, "] must be a vector.",
                                          in.shape().DebugString()));
      BCast::Vec vec;
      for (int64 i = 0; i < in.NumElements(); ++i) {
        vec.push_back(in.vec<T>()(i));
      }
      shapes.push_back(vec);
    }
    BCast bcast(shapes[0], shapes[1]);
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "Incompatible shapes: [", str_util::Join(shapes[0], ","),
                    "] vs. [", str_util::Join(shapes[1], ","), "]"));
    Output(ctx, 0, bcast.output_shape());
  }

  bool IsExpensive() override { return false; }

 private:
  void Output(OpKernelContext* ctx, int idx, const BCast::Vec& v) {
    const int64 len = v.size();
    Tensor* o = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(idx, TensorShape({len}), &o));
    for (int64 i = 0; i < len; ++i) {
      o->flat<T>()(i) = static_cast<T>(v[i]);
    }
  }

  TF_DISALLOW_COPY_AND_ASSIGN(BCastArgsOp);
};

// Given shapes of two tensors, computes the reduction indices for the
// gradient computation.
//
// TODO(zhifengc):
//   1. Adds support for n-ary (n >= 2).
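//
// Illustrative example (documentation only; the values are chosen here for
// clarity): for shape vectors s0 = [2, 3, 1] and s1 = [3, 5], the broadcast
// output shape is [2, 3, 5]. The gradient w.r.t. the first input must be
// summed over output axis 2 (where its size-1 dimension was broadcast), and
// the gradient w.r.t. the second input over output axis 0 (which it lacked),
// so this op returns r0 = [2] and r1 = [0].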
template <typename T>
class BCastGradArgsOp : public OpKernel {
 public:
  explicit BCastGradArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES(
        ctx, ctx->num_inputs() == 2,
        errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
    gtl::InlinedVector<BCast::Vec, 4> shapes;
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      const Tensor& in = ctx->input(i);
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()),
                  errors::InvalidArgument("In[", i, "] must be a vector.",
                                          in.shape().DebugString()));
      BCast::Vec vec;
      for (int64 i = 0; i < in.NumElements(); ++i) {
        vec.push_back(in.vec<T>()(i));
      }
      shapes.push_back(vec);
    }
    BCast bcast(shapes[0], shapes[1]);
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "Incompatible shapes: [", str_util::Join(shapes[0], ","),
                    "] vs. [", str_util::Join(shapes[1], ","), "]"));
    Output(ctx, 0, bcast.grad_x_reduce_idx());
    Output(ctx, 1, bcast.grad_y_reduce_idx());
  }

  bool IsExpensive() override { return false; }

 private:
  void Output(OpKernelContext* ctx, int idx, const BCast::Vec& v) {
    const int64 len = v.size();
    Tensor* o = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(idx, TensorShape({len}), &o));
    for (int64 i = 0; i < len; ++i) {
      o->flat<T>()(i) = static_cast<T>(v[i]);
    }
  }

  TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp);
};

REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int64>);
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int64>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int64>);

#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastArgs")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int64>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0"),
                        BCastArgsOp<int64>);
#endif

REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int64>);
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int64>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int64>);

#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int32>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_SYCL)
                            .TypeConstraint<int64>("T")
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp<int64>);
#endif

}  // end namespace tensorflow
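// Context note (informational, describing how these ops are typically
// consumed rather than anything defined in this file): the Python gradient
// functions for broadcasting binary ops call BroadcastGradientArgs on the two
// input shapes, reduce_sum the incoming gradient over the returned axes, and
// reshape the result back to each input's original shape. All registrations
// above keep the shape vectors (s0, s1) and result vectors (r0, r1) in host
// memory, since the kernels only manipulate small shape metadata.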