/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_
#define TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/kernels/cwise_ops_sycl_common.h"
#endif

#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_gradients.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif

class BinaryOpShared : public OpKernel {
 public:
  explicit BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in);

 protected:
  struct BinaryOpState {
    // Sets up bcast with the shapes of in0 and in1, ensures that the bcast
    // is valid, and if so, sets out, either by allocating a new buffer using
    // ctx->output(...) or by creating an alias for an owned input buffer for
    // in-place computation.
    // Caller must check ctx->status() upon return for non-ok status.
    // If ctx->status().ok() is true, then out is guaranteed to be allocated.
    BinaryOpState(OpKernelContext* ctx);

    const Tensor& in0;
    const Tensor& in1;

    BCast bcast;
    Tensor* out = nullptr;
    int64 out_num_elements;

    int64 in0_num_elements;
    int64 in1_num_elements;

    int ndims;
  };

  void SetUnimplementedError(OpKernelContext* ctx);
  void SetComputeError(OpKernelContext* ctx);
};

// Coefficient-wise binary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::add.
template <typename Device, typename Functor>
class BinaryOp : public BinaryOpShared {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit BinaryOp(OpKernelConstruction* ctx)
      : BinaryOpShared(ctx, DataTypeToEnum<Tout>::v(),
                       DataTypeToEnum<Tin>::v()) {}

  void Compute(OpKernelContext* ctx) override {
    // 'state': Shared helper not dependent on T to reduce code size.
    BinaryOpState state(ctx);
    if (!ctx->status().ok()) return;
    Tensor* out = state.out;
    BCast* bcast = &state.bcast;
    auto& in0 = state.in0;
    auto& in1 = state.in1;
    if (state.out_num_elements == 0) {
      return;
    }
    const int ndims = state.ndims;
    const Device& eigen_device = ctx->eigen_device<Device>();
    bool error = false;
    bool* const error_ptr = Functor::has_errors ? &error : nullptr;
    if (ndims <= 1) {
      auto out_flat = out->flat<Tout>();
      if (state.in1_num_elements == 1) {
        // tensor op scalar
        functor::BinaryFunctor<Device, Functor, 1>().Right(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template scalar<Tin>(), error_ptr);
      } else if (state.in0_num_elements == 1) {
        // scalar op tensor
        functor::BinaryFunctor<Device, Functor, 1>().Left(
            eigen_device, out_flat, in0.template scalar<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      } else {
        functor::BinaryFunctor<Device, Functor, 1>()(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      }
    } else if (ndims == 2) {
      functor::BinaryFunctor<Device, Functor, 2>().BCast(
          eigen_device, out->shaped<Tout, 2>(bcast->result_shape()),
          in0.template shaped<Tin, 2>(bcast->x_reshape()),
          BCast::ToIndexArray<2>(bcast->x_bcast()),
          in1.template shaped<Tin, 2>(bcast->y_reshape()),
          BCast::ToIndexArray<2>(bcast->y_bcast()), error_ptr);
    } else if (ndims == 3) {
      functor::BinaryFunctor<Device, Functor, 3>().BCast(
          eigen_device, out->shaped<Tout, 3>(bcast->result_shape()),
          in0.template shaped<Tin, 3>(bcast->x_reshape()),
          BCast::ToIndexArray<3>(bcast->x_bcast()),
          in1.template shaped<Tin, 3>(bcast->y_reshape()),
          BCast::ToIndexArray<3>(bcast->y_bcast()), error_ptr);
    } else if (ndims == 4) {
      functor::BinaryFunctor<Device, Functor, 4>().BCast(
          eigen_device, out->shaped<Tout, 4>(bcast->result_shape()),
          in0.template shaped<Tin, 4>(bcast->x_reshape()),
          BCast::ToIndexArray<4>(bcast->x_bcast()),
          in1.template shaped<Tin, 4>(bcast->y_reshape()),
          BCast::ToIndexArray<4>(bcast->y_bcast()), error_ptr);
    } else if (ndims == 5) {
      functor::BinaryFunctor<Device, Functor, 5>().BCast(
          eigen_device, out->shaped<Tout, 5>(bcast->result_shape()),
          in0.template shaped<Tin, 5>(bcast->x_reshape()),
          BCast::ToIndexArray<5>(bcast->x_bcast()),
          in1.template shaped<Tin, 5>(bcast->y_reshape()),
          BCast::ToIndexArray<5>(bcast->y_bcast()), error_ptr);
    } else {
      SetUnimplementedError(ctx);
    }
    if (Functor::has_errors && error) {
      SetComputeError(ctx);
    }
  }
};

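// Example of how the rank dispatch in BinaryOp::Compute above stays small
// (shapes chosen purely for illustration): for inputs of shapes [2, 3, 4]
// and [4], BCast merges adjacent dimensions that broadcast the same way,
// collapsing the shapes to [6, 4] and [1, 4], so the ndims == 2 case above
// handles the computation.
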
template <typename Device, typename T>
class ApproximateEqualOp : public OpKernel {
 public:
  explicit ApproximateEqualOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float tolerance;
    OP_REQUIRES_OK(context, context->GetAttr("tolerance", &tolerance));
    tolerance_ = T(tolerance);
  }
  void Compute(OpKernelContext* context) override {
    const Tensor& x_input = context->input(0);
    const Tensor& y_input = context->input(1);
    OP_REQUIRES(
        context, x_input.shape() == y_input.shape(),
        errors::InvalidArgument("x and y must be of the same shape. ",
                                "x shape: ", x_input.shape().DebugString(),
                                ". y shape: ", y_input.shape().DebugString()));
    Tensor* z_output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, x_input.shape(), &z_output));
    const Device& d = context->eigen_device<Device>();
    typename TTypes<T>::ConstFlat x(x_input.flat<T>());
    typename TTypes<T>::ConstFlat y(y_input.flat<T>());
    typename TTypes<bool>::Flat z(z_output->flat<bool>());
    functor::ApproximateEqual<Device, T>()(d, x, y, tolerance_, z);
  }

 private:
  T tolerance_;
};

// Basic coefficient-wise binary operations that are known to not require
// any broadcasting. This is the case, for example, for the gradients of
// unary operations.
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined above. E.g., functor::tanh_grad.
template <typename Device, typename Functor>
class SimpleBinaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit SimpleBinaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    auto in0_flat = in0.flat<Tin>();
    auto in1_flat = in1.flat<Tin>();
    const Device& eigen_device = ctx->eigen_device<Device>();

    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, in0.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
    }
    auto out_flat = out->flat<Tout>();
    functor::SimpleBinaryFunctor<Device, Functor>()(eigen_device, out_flat,
                                                    in0_flat, in1_flat);
  }
};

// Coefficient-wise unary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::sqrt.
template <typename Device, typename Functor>
class UnaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.
  // Tin may be different from Tout.  E.g., abs: complex64 -> float

  explicit UnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    auto in = DataTypeToEnum<Tin>::v();
    auto out = DataTypeToEnum<Tout>::v();
    OP_REQUIRES_OK(ctx, ctx->MatchSignature({in}, {out}));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, inp.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    }
    functor::UnaryFunctor<Device, Functor>()(
        ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>());
  }
};

template <typename Device, VariantUnaryOp OpEnum>
class UnaryVariantOp : public OpKernel {
 public:
  explicit UnaryVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(inp.shape()),
        errors::InvalidArgument("Non-scalar variants are not supported."));
    const Variant& v = inp.scalar<Variant>()();
    Variant v_out;
    OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(ctx, OpEnum, v, &v_out));
    Tensor out(cpu_allocator(), DT_VARIANT, TensorShape());
    out.scalar<Variant>()() = std::move(v_out);
    ctx->set_output(0, std::move(out));
  }
};

namespace functor {

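// Evaluates the Eigen expression 'rhs' into 'out' on device 'd'. The functor
// specializations below funnel their final assignments through this helper.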
template <typename D, typename Out, typename Rhs>
void Assign(const D& d, Out out, Rhs rhs) {
  out.device(d) = rhs;
}

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, NDIMS>
// for functors with no error checking.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func;
    if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
      Assign(dev, out, in0.binaryExpr(in1, func));
    } else if (AllOne<NDIMS>(bcast0)) {
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, in0.binaryExpr(rhs, func));
    } else if (AllOne<NDIMS>(bcast1)) {
      auto lhs = in0.broadcast(bcast0);
      Assign(dev, out, lhs.binaryExpr(in1, func));
    } else {
      auto lhs = in0.broadcast(bcast0);
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, lhs.binaryExpr(rhs, func));
    }
  }
};

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, 2>
// for functors with no error checking.
template <typename Functor>
struct BinaryFunctor<CPUDevice, Functor, 2, false> {
  enum { NDIMS = 2 };

  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

#if !defined(EIGEN_HAS_INDEX_LIST)
  inline Eigen::DSizes<int, 2> NByOne(int n) {
    return Eigen::DSizes<int, 2>(n, 1);
  }
  inline Eigen::DSizes<int, 2> OneByM(int m) {
    return Eigen::DSizes<int, 2>(1, m);
  }
#else
  inline Eigen::IndexList<int, Eigen::type2index<1>> NByOne(int n) {
    Eigen::IndexList<int, Eigen::type2index<1>> ret;
    ret.set(0, n);
    return ret;
  }
  inline Eigen::IndexList<Eigen::type2index<1>, int> OneByM(int m) {
    Eigen::IndexList<Eigen::type2index<1>, int> ret;
    ret.set(1, m);
    return ret;
  }
#endif

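  // NByOne(n) and OneByM(m) above describe [n, 1] and [1, m] shapes,
  // respectively. When Eigen::IndexList is available, the unit dimension is
  // the compile-time constant Eigen::type2index<1>, which lets Eigen
  // specialize the reshape/broadcast expressions in BCast below at compile
  // time.
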
  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typedef typename Functor::in_type T;
    typename Functor::func func;
    if (Functor::use_bcast_optimization && use_bcast_optimization<T>::value) {
      // Optimize for speed by using Eigen::type2index and avoiding
      // .broadcast() when we know it is a no-op.
      //
      // Here, we need to handle 6 cases depending on how many "1"s
      // exist in in0's and in1's shapes (4 numbers in total). It's not
      // possible for the two shapes to have more than two 1s because
      // those cases are simplified to the NDIMS == 1 case.
      //
      // Because this optimization increases the binary size for each
      // Functor (+, -, *, /, <, <=, etc.), type, and ndim combination,
      // we only apply it for selected ops/types/ndims.
      //
      // Because NDIMS, Functor::use_bcast_optimization, and
      // use_bcast_optimization<T> are compile-time constants, gcc
      // does a decent job of avoiding generating code when the
      // conditions are not met.
      const int a = in0.dimension(0);  // in0 is shape [a, b]
      const int b = in0.dimension(1);
      const int c = in1.dimension(0);  // in1 is shape [c, d]
      const int d = in1.dimension(1);
      if ((a == 1) && (d == 1)) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if ((b == 1) && (c == 1)) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (a == 1) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (b == 1) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (c == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (d == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      const bool bcast0_all_one = AllOne<NDIMS>(bcast0);
      const bool bcast1_all_one = AllOne<NDIMS>(bcast1);
      if (bcast0_all_one && !bcast1_all_one) {
        auto lhs = in0;  // No need to do broadcast for in0
        auto rhs = in1.broadcast(bcast1);
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      if (!bcast0_all_one && bcast1_all_one) {
        auto lhs = in0.broadcast(bcast0);
        auto rhs = in1;  // No need to do broadcast for in1
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
    }

    // Fallback path. Always works and probably slower.
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

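// Illustrative walkthrough of the fast path in the 2-d specialization above
// (shapes are hypothetical, and the path is only taken when the functor and
// type enable the bcast optimization): for a [2, 3] tensor op a [1, 3]
// tensor, a = 2, b = 3, c = 1, d = 3, so the "c == 1" branch fires; in0 is
// used as-is and only in1 is reshaped to [1, 3] and broadcast to [2, 3].
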
// Version of BinaryFunctor with error handling.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, true> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func(error)));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func(error);
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Partial specialization of UnaryFunctor<Device=CPUDevice, Functor>.
template <typename Functor>
struct UnaryFunctor<CPUDevice, Functor> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in) {
    Assign(d, out, in.unaryExpr(typename Functor::func()));
  }
};

// Partial specialization of ApproximateEqual<Device=CPUDevice, T>.
template <typename T>
struct ApproximateEqual<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat x,
                  typename TTypes<T>::ConstFlat y, T tolerance,
                  typename TTypes<bool>::Flat z) {
    auto diff = x - y;
    z.device(d) = diff.abs() <= tolerance;
  }
};

}  // end namespace functor

#define REGISTER(OP, D, N, F, T)                                             \
  REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint<T>("T"), \
                          OP<D##Device, F<T>>);

#define REGISTER_VARIANT(OP, D, N, ENUM)                       \
  REGISTER_KERNEL_BUILDER(                                     \
      Name(N).Device(DEVICE_##D).TypeConstraint<Variant>("T"), \
      OP<D##Device, ENUM>);

// Macros to register kernels for multiple types (T0, T1, etc.) on
// device type "D" (CPU or GPU) for operation "N" (e.g., sqrt) using
// the functor "F" (e.g., functor::sqrt).

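// For example (illustrative only; the real op names and type lists live in
// the cwise_op_*.cc shards), a shard might contain
//
//   REGISTER5(BinaryOp, CPU, "Add", functor::add, float, Eigen::half, double,
//             int32, int64);
//
// which expands to one REGISTER_KERNEL_BUILDER call per listed type (or only
// the first type when __ANDROID_TYPES_SLIM__ is defined).
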
#if defined(__ANDROID_TYPES_SLIM__)
// Note that __ANDROID_TYPES_SLIM__ is also checked in the cwise_ops*.cc files.
// Normally Android TensorFlow is built with a reduced number of types (float).
// Override on the command-line using "--copt=-D__ANDROID_TYPES_FULL__"
// to generate a library with full type support, with a consequent increase in
// code size.
#define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0)
#define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) REGISTER(OP, D, N, F, T0)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER(OP, D, N, F, T0)
#else  // !defined(__ANDROID_TYPES_SLIM__)
#define REGISTER2(OP, D, N, F, T0, T1) \
  REGISTER(OP, D, N, F, T0)            \
  REGISTER(OP, D, N, F, T1)
#define REGISTER3(OP, D, N, F, T0, T1, T2) \
  REGISTER2(OP, D, N, F, T0, T1)           \
  REGISTER(OP, D, N, F, T2)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
  REGISTER2(OP, D, N, F, T0, T1)               \
  REGISTER2(OP, D, N, F, T2, T3)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \
  REGISTER3(OP, D, N, F, T0, T1, T2)               \
  REGISTER2(OP, D, N, F, T3, T4)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) \
  REGISTER3(OP, D, N, F, T0, T1, T2)                   \
  REGISTER3(OP, D, N, F, T3, T4, T5)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                   \
  REGISTER3(OP, D, N, F, T4, T5, T6)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                       \
  REGISTER4(OP, D, N, F, T4, T5, T6, T7)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4)                       \
  REGISTER4(OP, D, N, F, T5, T6, T7, T8)

// Instead of adding REGISTER10, etc., shard the .cc files - see
// cwise_op_equal_to_*.cc for an example.

#endif  // defined(__ANDROID_TYPES_SLIM__)

}  // end namespace tensorflow

#endif  // TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_