Home | History | Annotate | Download | only in kernels
      1 // =============================================================================
      2 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //     http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 // =============================================================================
     16 
     17 #ifndef TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_
     18 #define TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_
     19 
     20 #include <cmath>
     21 #include <type_traits>
     22 #include <vector>
     23 #include "tensorflow/core/framework/op.h"
     24 #include "tensorflow/core/framework/op_kernel.h"
     25 #include "tensorflow/core/framework/shape_inference.h"
     26 #include "tensorflow/core/framework/tensor_shape.h"
     27 #include "tensorflow/core/lib/core/status.h"
     28 
     29 namespace {
     30 
     31 template <class IndexVecT, class IndexT>
     32 IndexT compute_input_index(
     33     IndexVecT* target_dimensions, const IndexT& output_index,
     34     const IndexVecT& original_dimensions, const int& adjustable_dimension,
     35     const std::vector<tensorflow::int64>& dimension_ceiling,
     36     const std::vector<tensorflow::int64>& cumulative_dimensions, IndexT* result,
     37     std::vector<IndexT>* output_indices, const int& rank) {
     38   *result = 0;
     39   output_indices->clear();
     40 
     41   // un-rasterize the output index
     42   auto last_reduced_i = output_index;
     43   for (auto r = rank - 1; r >= 0; --r) {
     44     (*output_indices)[r] = last_reduced_i % (*target_dimensions)[r];
     45     last_reduced_i =
     46         (last_reduced_i - (*output_indices)[r]) / (*target_dimensions)[r];
     47   }
     48 
     49   // rasterize the input index
     50   IndexT last_index_factor = 1;
     51   for (auto r = rank - 1; r >= 0; --r) {
     52     IndexT index = 0;
     53     if (r != adjustable_dimension)
     54       index = (*output_indices)[r] / dimension_ceiling[r];
     55     else {
     56       for (int qi = 0; qi < rank; ++qi) {
     57         if (qi == adjustable_dimension) continue;
     58         index += cumulative_dimensions[qi] *
     59                  ((*output_indices)[qi] % dimension_ceiling[qi]);
     60       }
     61       index *= (*target_dimensions)[adjustable_dimension];
     62       index += (*output_indices)[r];
     63     }
     64     *result += last_index_factor * index;
     65     last_index_factor *= original_dimensions[r];
     66   }
     67 
     68   return *result;
     69 }
     70 
     71 template <class InputDataT,
     72           class IndexVecT>  // both types are needed here b/c IndexVecT and
     73                             // InputDataT are not related
     74                             void
     75                             fill_periodic_tensor(
     76                                 tensorflow::OpKernelContext* context,
     77                                 const IndexVecT& desired_shape,
     78                                 const tensorflow::Tensor& input_tensor) {
     79   // input is a strided array (last index is fastest, C-ordered)
     80   auto input = input_tensor.flat<InputDataT>();
     81   const int rank = input_tensor.dims();
     82   // original and target dimensions
     83   std::vector<tensorflow::int64> original_dimensions(rank),
     84       target_dimensions(rank);
     85   tensorflow::int64 total_size(input_tensor.NumElements()), new_sliced_size(1);
     86   // factors by which original_dimensions increases/decreases w.r.t.
     87   // target_dimensions
     88   std::vector<tensorflow::int64> dimension_ceiling(rank),
     89       cumulative_dimensions(rank);
     90   // index of adjustable dimension
     91   int adjustable_dimension;
     92   tensorflow::TensorShape output_shape;
     93 
     94   // requires that the rank of the input tensor and length of the desired shape
     95   // are equal
     96   OP_REQUIRES(context, rank == desired_shape.size(),
     97               tensorflow::errors::InvalidArgument(
     98                   "periodic_resample expects the rank of the input tensor, ",
     99                   rank, ", to be the same as the length of the desired shape, ",
    100                   desired_shape.size(), "."));
    101 
    102   bool found = false;
    103   const auto& input_tensor_shape = input_tensor.shape();
    104 
    105   for (int i = 0; i < rank; ++i) {
    106     // if (desired_shape(i) < 1) {
    107     if (desired_shape[i] < 1) {
    108       // only one index can be adjustable
    109       OP_REQUIRES(context, !found,
    110                   tensorflow::errors::InvalidArgument(
    111                       "periodic_resample expects only "
    112                       "one index to be marked as adjustable."));
    113       adjustable_dimension = i;
    114       found = true;
    115     } else {
    116       OP_REQUIRES(
    117           context, desired_shape[i] >= input_tensor_shape.dim_size(i),
    118           tensorflow::errors::InvalidArgument(
    119               "periodic_resample expects the size of non-adjustable "
    120               "dimensions be at least as large as size of input tensor."
    121               " Dimension ",
    122               i, " input tensor has size ", input_tensor_shape.dim_size(i),
    123               ", desired shape has size ", desired_shape[i], "."));
    124 
    125       // target_dimensions[i] = desired_shape(i);
    126       target_dimensions[i] = desired_shape[i];
    127       new_sliced_size *= target_dimensions[i];
    128     }
    129   }
    130   // at least one index needs to be adjustable
    131   OP_REQUIRES(context, found,
    132               tensorflow::errors::InvalidArgument(
    133                   "periodic_resample expects at least "
    134                   "one index to be marked as adjustable."));
    135 
    136   int count = 0;
    137   for (const auto dim_info : input_tensor.shape()) {
    138     original_dimensions[count] = dim_info.size;
    139     ++count;
    140   }
    141 
    142   target_dimensions[adjustable_dimension] = total_size / new_sliced_size;
    143 
    144   count = 0;
    145   for (int i = 0; i < input_tensor.shape().dims(); ++i) {
    146     dimension_ceiling[count] = tensorflow::int64(std::ceil(
    147         float(target_dimensions[count]) / float(original_dimensions[count])));
    148     if (count == 0)
    149       cumulative_dimensions[count] = 1;
    150     else
    151       cumulative_dimensions[count] =
    152           cumulative_dimensions[count - 1] * dimension_ceiling[count - 1];
    153     ++count;
    154   }
    155 
    156   // ensure that the new dimension is greater than zero
    157   OP_REQUIRES(context, target_dimensions[adjustable_dimension] > 0,
    158               tensorflow::errors::InvalidArgument(
    159                   "periodic_resample found that the "
    160                   "adjustable dimension, ",
    161                   adjustable_dimension, ", isn't greater than zero, ",
    162                   target_dimensions[adjustable_dimension], "."));
    163   for (int i = 0; i < rank; ++i) {
    164     output_shape.AddDim(target_dimensions[i]);
    165   }
    166   const auto new_size =
    167       new_sliced_size * target_dimensions[adjustable_dimension];
    168 
    169   // Create an output tensor and attach it to the current context
    170   tensorflow::Tensor* output_tensor = nullptr;
    171   OP_REQUIRES_OK(context,
    172                  context->allocate_output(0, output_shape, &output_tensor));
    173   auto output = output_tensor->flat<InputDataT>();
    174 
    175   // memory is allocated for these variables outside the inner loop for
    176   // efficiency (although, I could create a separate class scope for
    177   // this purpose instead)
    178   tensorflow::int64 result = 0;
    179   std::vector<tensorflow::int64> output_indices(target_dimensions.size());
    180 
    181   // Fill output tensor with periodically resampled input tensor values
    182   for (tensorflow::int64 output_index = 0; output_index < new_size;
    183        ++output_index) {
    184     output(output_index) = input(compute_input_index(
    185         &target_dimensions, output_index, original_dimensions,
    186         adjustable_dimension, dimension_ceiling, cumulative_dimensions, &result,
    187         &output_indices, rank));
    188   }
    189 }
    190 
    191 void create_output_tensor(
    192     tensorflow::OpKernelContext* context,
    193     const tensorflow::Tensor& input_tensor,
    194     const tensorflow::DataType& input_tensor_type,
    195     const tensorflow::PartialTensorShape& desired_shape_tensor) {
    196   auto desired_shape = desired_shape_tensor.dim_sizes();
    197 
    198   // obligatory type switch
    199   switch (input_tensor_type) {
    200     case tensorflow::DataTypeToEnum<float>::value:
    201       fill_periodic_tensor<float>(context, desired_shape, input_tensor);
    202       break;
    203     case tensorflow::DataTypeToEnum<double>::value:
    204       fill_periodic_tensor<double>(context, desired_shape, input_tensor);
    205       break;
    206     case tensorflow::DataTypeToEnum<tensorflow::int32>::value:
    207       fill_periodic_tensor<tensorflow::int32>(context, desired_shape,
    208                                               input_tensor);
    209       break;
    210     case tensorflow::DataTypeToEnum<tensorflow::int64>::value:
    211       fill_periodic_tensor<tensorflow::int64>(context, desired_shape,
    212                                               input_tensor);
    213       break;
    214     default:;
    215   }
    216 }
    217 
    218 }  // namespace
    219 
    220 class PeriodicResampleOp : public tensorflow::OpKernel {
    221  public:
    222   explicit PeriodicResampleOp(tensorflow::OpKernelConstruction* context)
    223       : tensorflow::OpKernel(context) {
    224     // Get the desired shape
    225     OP_REQUIRES_OK(context, context->GetAttr("shape", &desired_shape));
    226   }
    227 
    228   void Compute(tensorflow::OpKernelContext* context) override {
    229     // Grab the input tensor
    230     const tensorflow::Tensor& input_tensor = context->input(0);
    231     const tensorflow::DataType input_tensor_type = context->input_dtype(0);
    232 
    233     create_output_tensor(context, input_tensor, input_tensor_type,
    234                          desired_shape);
    235   }
    236 
    237  private:
    238   tensorflow::PartialTensorShape desired_shape;
    239 };
    240 
    241 #endif  // TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_
    242