Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/kernels/reduction_ops_common.h"
     17 
     18 #include "tensorflow/core/lib/strings/str_util.h"
     19 
     20 namespace tensorflow {
     21 
     22 TensorShape ReductionHelper::out_reshape() const {
     23   TensorShape shape;
     24   for (auto size : out_reshape_) shape.AddDim(size);
     25   return shape;
     26 }
     27 
     28 // The final output shape must be allocated with this shape.
     29 TensorShape ReductionHelper::out_shape() const {
     30   TensorShape shape;
     31   for (auto size : out_shape_) shape.AddDim(size);
     32   return shape;
     33 }
     34 
     35 TensorShape ReductionHelper::shuffled_shape() {
     36   const int dims = data_reshape_.size();
     37   TensorShape shape;
     38   for (int i = reduce_first_axis_; i < dims; i += 2) {
     39     shape.AddDim(data_reshape_[i]);
     40   }
     41   for (int i = !reduce_first_axis_; i < dims; i += 2) {
     42     shape.AddDim(data_reshape_[i]);
     43   }
     44   return shape;
     45 }
     46 
     47 gtl::InlinedVector<int32, 8> ReductionHelper::permutation() {
     48   const int dims = data_reshape_.size();
     49   const int unreduced_dims = (dims + !reduce_first_axis_) / 2;
     50   gtl::InlinedVector<int32, 8> perm(dims);
     51   for (int i = 0; i < unreduced_dims; i++) {
     52     perm[i] = 2 * i + reduce_first_axis_;
     53   }
     54   for (int i = unreduced_dims; i < dims; i++) {
     55     perm[i] = 2 * (i - unreduced_dims) + !reduce_first_axis_;
     56   }
     57   return perm;
     58 }
     59 
     60 template <typename Tperm>
     61 Status SimplifyHelper(const Tensor& data, const Tensor& axis,
     62                       gtl::InlinedVector<bool, 4>& bitmap) {
     63   auto axis_vec = axis.flat<Tperm>();
     64   for (int64 i = 0; i < axis.NumElements(); ++i) {
     65     Tperm index = axis_vec(i);
     66     if (index < -data.dims() || index >= data.dims()) {
     67       return errors::InvalidArgument("Invalid reduction dimension (", index,
     68                                      " for input with ", data.dims(),
     69                                      " dimension(s)");
     70     }
     71     index = (index + data.dims()) % data.dims();
     72     bitmap[index] = true;
     73   }
     74   return Status::OK();
     75 }
     76 
     77 Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
     78                                  const bool keep_dims) {
     79   // bitmap[i] indicates whether to reduce data along i-th axis.
     80   gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
     81   if (axis.dtype() == DT_INT32) {
     82     TF_RETURN_IF_ERROR(SimplifyHelper<int32>(data, axis, bitmap));
     83   } else {
     84     TF_RETURN_IF_ERROR(SimplifyHelper<int64>(data, axis, bitmap));
     85   }
     86   // Output tensor's dim sizes.
     87   out_shape_.clear();
     88   for (int i = 0; i < data.dims(); ++i) {
     89     if (!bitmap[i]) {
     90       // If we are not reducing along dimension i.
     91       out_shape_.push_back(data.dim_size(i));
     92     } else if (keep_dims) {
     93       // We are reducing along dimension i, but we want to keep the
     94       // same number of dimensions, so we set the dimension of i to
     95       // '1'.
     96       out_shape_.push_back(1);
     97     }
     98   }
     99 
    100   // Depending on bitmap[i] and bitmap[i-1], we can collapse axis of
    101   // the input data before doing the reduction on the resulting
    102   // tensor.  The shape of the reduction is a reshape of the final
    103   // output.
    104 
    105   // We'll skip the leading 1s.
    106   int dim_index = 0;
    107   for (; dim_index < data.dims(); ++dim_index) {
    108     if (data.dim_size(dim_index) != 1) break;
    109   }
    110   if (dim_index >= data.dims()) {
    111     // Special case. The input is essentially a scalar.
    112     reduce_first_axis_ = true;
    113   } else {
    114     // Starting from the (dim_index)-th dimension, dimensions
    115     // alternates between runs that need to be reduced and runs that
    116     // don't.
    117     //
    118     // NOTE: If a dimension has size 1, we group it as the current
    119     // run so that we can minimize the number of runs.
    120     //
    121     // E.g., when we want to reduce a tensor of shape [2, 1, 3, 1,
    122     // 5] by axes = [1, 4], we should treat the tensor as a [6, 5]
    123     // and reduce by axes = [1] (i.e., the output is shape [6]).
    124     reduce_first_axis_ = bitmap[dim_index];
    125     data_reshape_.push_back(data.dim_size(dim_index));
    126     ++dim_index;
    127     for (; dim_index < data.dims(); ++dim_index) {
    128       const auto size = data.dim_size(dim_index);
    129       if (size == 1) {
    130         bitmap[dim_index] = bitmap[dim_index - 1];
    131       }
    132       if (bitmap[dim_index - 1] != bitmap[dim_index]) {
    133         // Starts a new run of reduce or !reduce.
    134         data_reshape_.push_back(size);
    135       } else {
    136         // Continue a run of reduce or !reduce.
    137         data_reshape_.back() *= size;
    138       }
    139     }
    140     // If reduce_first_axis_ is true (input's dimension 0, 2, 4, etc
    141     // are reduced), data_reshape_[1, 3, 5, ...]  is out_reshape_,
    142     // otherwise, data_reshape_[0, 2, 4, ...] is.
    143     for (size_t i = reduce_first_axis_ ? 1 : 0; i < data_reshape_.size();
    144          i += 2) {
    145       out_reshape_.push_back(data_reshape_[i]);
    146     }
    147   }
    148 
    149   VLOG(1) << "data reshape: " << str_util::Join(data_reshape_, ",");
    150   VLOG(1) << "out  reshape: " << str_util::Join(out_reshape_, ",");
    151   VLOG(1) << "out    shape: " << str_util::Join(out_shape_, ",");
    152   return Status::OK();
    153 }
    154 
    155 }  // namespace tensorflow
    156