/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SPACETOBATCH_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_SPACETOBATCH_FUNCTOR_H_

#include <type_traits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Maximum number of non-collapsible blocked dimensions supported by the
// {SpaceToBatch,BatchToSpace}ND operation.  To change the limit, modify this
// constant and the TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS macro definition
// below.
constexpr int kMaxSpaceToBatchBlockDims = 4;
// Expands to:
//   MACRO(1, ## __VA_ARGS__)
//   ...
//   MACRO(kMaxSpaceToBatchBlockDims, ## __VA_ARGS__)
//
// Note: The space between the number and the comma is necessary for proper GCC
// comma handling: https://gcc.gnu.org/onlinedocs/cpp/Variadic-Macros.html
#define TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(MACRO, ...) \
  MACRO(1 /**/, ##__VA_ARGS__)                              \
  MACRO(2 /**/, ##__VA_ARGS__)                              \
  MACRO(3 /**/, ##__VA_ARGS__)                              \
  MACRO(4 /**/, ##__VA_ARGS__)                              \
  /**/

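// Example usage (illustrative sketch, not part of this header): pair the macro
// with a caller-supplied per-NUM_BLOCK_DIMS macro, e.g. to instantiate the
// SpaceToBatchFunctor declared below for every supported block-dimension
// count.  The TF_SPACETOBATCH_DECLARE name and the use of
// Eigen::ThreadPoolDevice are assumptions for the example only.
//
//   #define TF_SPACETOBATCH_DECLARE(NUM_BLOCK_DIMS)                         \
//     template struct functor::SpaceToBatchFunctor<Eigen::ThreadPoolDevice, \
//                                                  float, NUM_BLOCK_DIMS>;
//   TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(TF_SPACETOBATCH_DECLARE)
//   #undef TF_SPACETOBATCH_DECLARE
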
namespace internal {
namespace spacetobatch {

// Implementation helper for SubtleMustCopyFlat below: copies the flat contents
// of `t`, interpreted as `InputType`, into the std::vector-like `*output`.
template <typename InputType, typename OutputType>
void SubtleMustCopyFlatHelper(const Tensor& t, OutputType* output) {
  const int64 num_elements = t.shape().num_elements();
  output->resize(num_elements);
  auto eigen_vec = t.flat<InputType>();
  for (int64 i = 0; i < num_elements; ++i) {
    (*output)[i] = SubtleMustCopy(eigen_vec(i));
  }
}

// Copies flat contents of `t` to std::vector-like `*output`, which is resized
// as needed.  `OutputType` may be either `std::vector<int64>` or
// `gtl::InlinedVector<int64>`.
//
// Precondition: t.dtype() must be either DT_INT32 or DT_INT64.
template <typename OutputType>
void SubtleMustCopyFlat(const Tensor& t, OutputType* output) {
  if (t.dtype() == DT_INT32) {
    SubtleMustCopyFlatHelper<int32, OutputType>(t, output);
  } else {
    SubtleMustCopyFlatHelper<int64, OutputType>(t, output);
  }
}
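
// Example usage (illustrative sketch): copy an int32 or int64 input tensor
// into a vector inside an op kernel.  The `block_shape_tensor` name is an
// assumption for the example only.
//
//   std::vector<int64> block_shape;
//   internal::spacetobatch::SubtleMustCopyFlat(block_shape_tensor,
//                                              &block_shape);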

}  // namespace spacetobatch
}  // namespace internal

namespace functor {

// Functor used by {SpaceToBatch,BatchToSpace}{ND,}Op to do the conversion.
//
// If B2S is false, then this performs the space-to-batch conversion.  If B2S
// is true, then this performs the inverse batch-to-space conversion.
template <typename Device, typename T, int NUM_BLOCK_DIMS, bool B2S = false>
struct SpaceToBatchFunctor {
  using InputT = typename std::conditional<B2S, T, const T>::type;
  using OutputT = typename std::conditional<B2S, const T, T>::type;
  // Implements the space-to-batch conversion.
  //
  // space_tensor: space tensor of the space-to-batch operation.  If
  //     B2S = false, then this is the input to the conversion.  If B2S = true,
  //     then this is the output of the conversion.
  // block_shape: array of shape [NUM_BLOCK_DIMS] specifying the block sizes
  //     for dimensions 1 through NUM_BLOCK_DIMS.
  // paddings: row-major array of shape [NUM_BLOCK_DIMS, 2] specifying the
  //     start and end padding for dimensions 1 through NUM_BLOCK_DIMS.
  // batch_tensor: batch tensor of the space-to-batch operation.  If
  //     B2S = false, then this is the output of the conversion.  If B2S = true,
  //     then this is the input to the conversion.
  //
  // The caller must ensure that the dimensions of the tensors are correct.
  Status operator()(
      const Device& d,
      typename TTypes<InputT, NUM_BLOCK_DIMS + 2>::Tensor space_tensor,
      const int64 block_shape[NUM_BLOCK_DIMS],
      const int64 paddings[NUM_BLOCK_DIMS * 2],
      typename TTypes<OutputT, NUM_BLOCK_DIMS + 2>::Tensor batch_tensor);
};
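
// Example usage (illustrative sketch from an op kernel's Compute method;
// `context`, the `CPUDevice` alias, the const input Tensor `space_tensor`,
// and the allocated output Tensor* `batch_tensor` are assumptions, not part
// of this header, and real callers must validate shapes first):
//
//   const int64 block_shape[2] = {2, 2};         // NUM_BLOCK_DIMS == 2
//   const int64 paddings[2 * 2] = {0, 0, 0, 0};  // no padding
//   OP_REQUIRES_OK(context,
//                  SpaceToBatchFunctor<CPUDevice, float, 2>()(
//                      context->eigen_device<CPUDevice>(),
//                      space_tensor.tensor<float, 4>(), block_shape, paddings,
//                      batch_tensor->tensor<float, 4>()));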

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SPACETOBATCH_FUNCTOR_H_