// (Code-viewer navigation header removed; this file lives under tensorflow/core/kernels.)
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/math_ops.cc.
     17 
#include <cmath>
#include <limits>
#include <type_traits>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
     25 
     26 namespace tensorflow {
     27 
     28 int32 GetValue(int32 v) { return v; }
     29 
     30 template <typename T>
     31 class RangeOp : public OpKernel {
     32  public:
     33   explicit RangeOp(OpKernelConstruction* context) : OpKernel(context) {}
     34 
     35   void Compute(OpKernelContext* context) override {
     36     const Tensor& start_in = context->input(0);
     37     const Tensor& limit_in = context->input(1);
     38     const Tensor& delta_in = context->input(2);
     39     OP_REQUIRES(context, IsLegacyScalar(start_in.shape()),
     40                 errors::InvalidArgument("start must be a scalar, not shape ",
     41                                         start_in.shape().DebugString()));
     42     OP_REQUIRES(context, IsLegacyScalar(limit_in.shape()),
     43                 errors::InvalidArgument("limit must be a scalar, not shape ",
     44                                         limit_in.shape().DebugString()));
     45     OP_REQUIRES(context, IsLegacyScalar(delta_in.shape()),
     46                 errors::InvalidArgument("delta must be a scalar, not shape ",
     47                                         delta_in.shape().DebugString()));
     48     const T start = start_in.scalar<T>()();
     49     const T limit = limit_in.scalar<T>()();
     50     const T delta = delta_in.scalar<T>()();
     51     OP_REQUIRES(context, delta != 0,
     52                 errors::InvalidArgument("Requires delta != 0: ", delta));
     53     if (delta > 0) {
     54       OP_REQUIRES(
     55           context, start <= limit,
     56           errors::InvalidArgument(
     57               "Requires start <= limit when delta > 0: ", start, "/", limit));
     58     } else {
     59       OP_REQUIRES(
     60           context, start >= limit,
     61           errors::InvalidArgument(
     62               "Requires start >= limit when delta < 0: ", start, "/", limit));
     63     }
     64     int64 size = (std::is_integral<T>::value
     65                       ? ((std::abs(limit - start) + std::abs(delta) - 1) /
     66                          std::abs(delta))
     67                       : std::ceil(std::abs((limit - start) / delta)));
     68     Tensor* out = nullptr;
     69     OP_REQUIRES_OK(context,
     70                    context->allocate_output(0, TensorShape({size}), &out));
     71     auto flat = out->flat<T>();
     72     T val = start;
     73     for (int64 i = 0; i < size; ++i) {
     74       flat(i) = T(val);
     75       val += delta;
     76     }
     77   }
     78 };
     79 
// Registers a RangeOp kernel for one (device, index type) pair.  All three
// inputs and the output are pinned to host memory, so the scalar loop in
// Compute() always runs on the CPU regardless of the nominal device.
#define REGISTER_KERNEL(DEV, TYPE)                           \
  REGISTER_KERNEL_BUILDER(Name("Range")                      \
                              .Device(DEV)                   \
                              .HostMemory("start")           \
                              .HostMemory("limit")           \
                              .HostMemory("delta")           \
                              .HostMemory("output")          \
                              .TypeConstraint<TYPE>("Tidx"), \
                          RangeOp<TYPE>);

#define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL(DEVICE_CPU, T)
#define REGISTER_GPU_KERNEL(T) REGISTER_KERNEL(DEVICE_GPU, T)
#ifdef TENSORFLOW_USE_SYCL
// SYCL builds register the same four element types as CPU/GPU below.
#define REGISTER_SYCL_KERNEL(T) REGISTER_KERNEL(DEVICE_SYCL, T)
TF_CALL_float(REGISTER_SYCL_KERNEL);
TF_CALL_double(REGISTER_SYCL_KERNEL);
TF_CALL_int32(REGISTER_SYCL_KERNEL);
TF_CALL_int64(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
#endif  // TENSORFLOW_USE_SYCL

// CPU registrations: float, double, int32, int64.
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
TF_CALL_int32(REGISTER_CPU_KERNEL);
TF_CALL_int64(REGISTER_CPU_KERNEL);

#if GOOGLE_CUDA

// GPU registrations still execute on the host (see HostMemory tags above).
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
TF_CALL_int32(REGISTER_GPU_KERNEL);
TF_CALL_int64(REGISTER_GPU_KERNEL);

#endif  // GOOGLE_CUDA

#undef REGISTER_KERNEL
#undef REGISTER_CPU_KERNEL
#undef REGISTER_GPU_KERNEL
    118 
    119 template <typename T, typename Tnum>
    120 class LinSpaceOp : public OpKernel {
    121  public:
    122   explicit LinSpaceOp(OpKernelConstruction* context) : OpKernel(context) {}
    123 
    124   void Compute(OpKernelContext* context) override {
    125     const Tensor& start_in = context->input(0);
    126     const Tensor& stop_in = context->input(1);
    127     const Tensor& num_in = context->input(2);
    128     OP_REQUIRES(context, TensorShapeUtils::IsScalar(start_in.shape()),
    129                 errors::InvalidArgument("start must be a scalar, not shape ",
    130                                         start_in.shape().DebugString()));
    131     OP_REQUIRES(context, TensorShapeUtils::IsScalar(stop_in.shape()),
    132                 errors::InvalidArgument("stop must be a scalar, not shape ",
    133                                         stop_in.shape().DebugString()));
    134     OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_in.shape()),
    135                 errors::InvalidArgument("num must be a scalar, not shape ",
    136                                         num_in.shape().DebugString()));
    137     const T start = start_in.scalar<T>()();
    138     const T stop = stop_in.scalar<T>()();
    139     const Tnum num = num_in.scalar<Tnum>()();
    140     OP_REQUIRES(context, num > 0,
    141                 errors::InvalidArgument("Requires num > 0: ", num));
    142     Tensor* out = nullptr;
    143     OP_REQUIRES_OK(context,
    144                    context->allocate_output(0, TensorShape({num}), &out));
    145     auto flat = out->flat<T>();
    146     if (num == 1) {
    147       flat(0) = start;
    148     } else {
    149       const T step = (stop - start) / (num - 1);
    150       for (Tnum i = 0; i < num; ++i) flat(i) = start + step * i;
    151     }
    152   }
    153 };
    154 
// Registers a LinSpaceOp kernel for one (device, element type, index type)
// triple.  All inputs and the output are tagged HostMemory, so the kernel
// computes on the CPU even when placed on GPU/SYCL.
#define REGISTER_KERNEL(DEV, T, Tidx)                       \
  REGISTER_KERNEL_BUILDER(Name("LinSpace")                  \
                              .Device(DEV)                  \
                              .TypeConstraint<T>("T")       \
                              .TypeConstraint<Tidx>("Tidx") \
                              .HostMemory("start")          \
                              .HostMemory("stop")           \
                              .HostMemory("num")            \
                              .HostMemory("output"),        \
                          LinSpaceOp<T, Tidx>);

// Registers both supported index types (int32, int64) for a device/type.
#define REGISTER_KERNEL_ALL_NUMS(dev, T) \
  REGISTER_KERNEL(dev, T, int32);        \
  REGISTER_KERNEL(dev, T, int64)

#define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL_ALL_NUMS(DEVICE_CPU, T)
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);

// NOTE(touts): We register the op on GPU but it still runs on CPU
// because its inputs and outputs are tagged as HostMemory.
#define REGISTER_GPU_KERNEL(T) REGISTER_KERNEL_ALL_NUMS(DEVICE_GPU, T)
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#ifdef TENSORFLOW_USE_SYCL
// SYCL mirrors the CPU/GPU registrations (float and double only).
#define REGISTER_SYCL_KERNEL(T) REGISTER_KERNEL_ALL_NUMS(DEVICE_SYCL, T)
TF_CALL_float(REGISTER_SYCL_KERNEL);
TF_CALL_double(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
#endif  // TENSORFLOW_USE_SYCL

#undef REGISTER_CPU_KERNEL
#undef REGISTER_KERNEL_ALL_NUMS
#undef REGISTER_KERNEL
    191 
    192 }  // namespace tensorflow
    193