Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
#include <limits>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     23 
     24 namespace tensorflow {
     25 
     26 namespace {
     27 template <typename T>
     28 struct mod_op {
     29   const T operator()(const T& a, const T& b) const { return a % b; }
     30 };
     31 }  // namespace
     32 
     33 typedef Eigen::ThreadPoolDevice CPUDevice;
     34 
     35 template <typename Tidx>
     36 class UnravelIndexOp : public OpKernel {
     37  public:
     38   explicit UnravelIndexOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
     39 
     40   void Compute(OpKernelContext* ctx) override {
     41     const Tensor& indices_tensor = ctx->input(0);
     42     OP_REQUIRES(ctx,
     43                 TensorShapeUtils::IsVector(indices_tensor.shape()) ||
     44                     TensorShapeUtils::IsScalar(indices_tensor.shape()),
     45                 errors::InvalidArgument(
     46                     "The indices can only be scalar or vector, got \"",
     47                     indices_tensor.shape().DebugString(), "\""));
     48 
     49     const Tensor& dims_tensor = ctx->input(1);
     50     OP_REQUIRES(
     51         ctx, TensorShapeUtils::IsVector(dims_tensor.shape()),
     52         errors::InvalidArgument("The indices can only be 1-D, got \"",
     53                                 dims_tensor.shape().DebugString(), "\""));
     54 
     55     auto dims = dims_tensor.vec<Tidx>();
     56 
     57     Eigen::array<bool, 1> reverse({true});
     58 
     59     Tensor strides_tensor;
     60     OP_REQUIRES_OK(ctx,
     61                    ctx->allocate_temp(DataTypeToEnum<Tidx>::value,
     62                                       TensorShape({dims_tensor.NumElements()}),
     63                                       &strides_tensor));
     64 
     65     auto strides = strides_tensor.vec<Tidx>();
     66     strides = dims.reverse(reverse)
     67                   .scan(0, Eigen::internal::ProdReducer<Tidx>(), false)
     68                   .reverse(reverse);
     69 
     70     Tensor strides_shifted_tensor;
     71     OP_REQUIRES_OK(ctx,
     72                    ctx->allocate_temp(DataTypeToEnum<Tidx>::value,
     73                                       TensorShape({dims_tensor.NumElements()}),
     74                                       &strides_shifted_tensor));
     75 
     76     auto strides_shifted = strides_shifted_tensor.vec<Tidx>();
     77     strides_shifted = dims.reverse(reverse)
     78                           .scan(0, Eigen::internal::ProdReducer<Tidx>(), true)
     79                           .reverse(reverse);
     80 
     81     Tensor* output_tensor = nullptr;
     82     if (TensorShapeUtils::IsScalar(indices_tensor.shape())) {
     83       OP_REQUIRES_OK(
     84           ctx, ctx->allocate_output(0, TensorShape({dims_tensor.NumElements()}),
     85                                     &output_tensor));
     86 
     87       auto output = output_tensor->vec<Tidx>();
     88 
     89       output = output.constant(indices_tensor.scalar<Tidx>()());
     90       output = output.binaryExpr(strides, mod_op<Tidx>()) / strides_shifted;
     91     } else {
     92       OP_REQUIRES_OK(
     93           ctx, ctx->allocate_output(0,
     94                                     TensorShape({dims_tensor.NumElements(),
     95                                                  indices_tensor.NumElements()}),
     96                                     &output_tensor));
     97 
     98       auto output = output_tensor->matrix<Tidx>();
     99 
    100       Eigen::array<int64, 2> reshape{{dims_tensor.NumElements(), 1}};
    101       Eigen::array<int64, 2> bcast({1, indices_tensor.NumElements()});
    102       Eigen::array<int64, 2> indices_reshape{{1, indices_tensor.NumElements()}};
    103       Eigen::array<int64, 2> indices_bcast({dims_tensor.NumElements(), 1});
    104 
    105       output = indices_tensor.vec<Tidx>()
    106                    .reshape(indices_reshape)
    107                    .broadcast(indices_bcast);
    108       output = output.binaryExpr(strides.reshape(reshape).broadcast(bcast),
    109                                  mod_op<Tidx>()) /
    110                strides_shifted.reshape(reshape).broadcast(bcast);
    111     }
    112   }
    113 };
    114 
    115 #define REGISTER_KERNEL(type)                                               \
    116   REGISTER_KERNEL_BUILDER(                                                  \
    117       Name("UnravelIndex").Device(DEVICE_CPU).TypeConstraint<type>("Tidx"), \
    118       UnravelIndexOp<type>);
    119 TF_CALL_int32(REGISTER_KERNEL) TF_CALL_int64(REGISTER_KERNEL)
    120 #undef REGISTER_KERNEL
    121 
    122 }  // namespace tensorflow
    123