Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
     17 #define TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
     18 
     19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     20 
     21 namespace Eigen {
     22 
     23 /** ExtractGlimpses
     24  * \ingroup CXX11_NeuralNetworks_Module
     25  *
     26  * \brief Extract glimpses from an input tensor.
     27  *
     28  * The input parameter is expected to be a col-major tensor with a rank of 4
     29  * (depth, x, y, and batch). The width and height parameters specify the
     30  * extension of the returned glimpses. The offsets parameter specifies the x, y
     31  * locations of the center of the glimpses relative to the center of the input
     32  * image. The vector is expected to contain one IndexPair for each image in the
     33  * batch dimension. The normalized boolean indicates if incoming coordinates are
     34  * normalized so that 0.0 and 1.0 correspond to the minimum and maximum of each
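 * height and width dimension. The centered boolean indicates if incoming
 * coordinates are centered relative to the image, in which case -1.0 and 1.0
 * correspond to minimum and maximum of each dimension while 0.0 corresponds to
 * the center. The uniform_noise boolean selects how a glimpse that is not
 * fully contained in the input is initialized before the overlapping region
 * is copied in: if true, with uniform noise spanning the minimum and maximum
 * of the image; if false, with per-channel gaussian noise shaped by the mean
 * and standard deviation of the channel and clamped to the channel's minimum
 * and maximum.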
     39  *
     40  * The result can be assigned to a tensor of rank equal to that of the input.
     41  * The result will be laid out in col-major order (depth, x, y, batch). The
     42  * dimensions of the result will be equal to the dimensions of the input except
     43  * for width and height which will be equal to the requested glimpse size.
     44  */
     45 namespace {
     46 template <typename Index>
     47 struct GlimpseExtractionOp {
     48   GlimpseExtractionOp(const Index width, const Index height,
     49                       const std::vector<IndexPair<float> >& offsets,
     50                       const bool normalized, const bool centered,
     51                       const bool uniform_noise)
     52       : width_(width),
     53         height_(height),
     54         offsets_(offsets),
     55         normalized_(normalized),
     56         centered_(centered),
     57         uniform_noise_(uniform_noise) {}
     58 
     59   template <typename Input>
     60   DSizes<Index, 4> dimensions(const Input& input) const {
     61     typedef typename internal::traits<Input>::Index IndexType;
     62     typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4,
     63                              internal::traits<Input>::Layout, IndexType> >
     64         Ref;
     65     Ref in(input);
     66 
     67     DSizes<Index, 4> dims = in.dimensions();
     68 
     69     dims[0] = in.dimension(0);
     70     dims[1] = width_;
     71     dims[2] = height_;
     72     dims[3] = in.dimension(3);
     73     return dims;
     74   }
     75 
     76   template <typename Input, typename Output, typename Device>
     77   EIGEN_DEVICE_FUNC void eval(const Input& input, Output& output,
     78                               const Device& device) const {
     79     typedef typename internal::traits<Input>::Index IndexType;
     80     typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4,
     81                              internal::traits<Input>::Layout, IndexType> >
     82         Ref;
     83     Ref in(input);
     84     const Index num_channels = in.dimension(0);
     85     const Index input_width = in.dimension(1);
     86     const Index input_height = in.dimension(2);
     87     const Index batch_size = in.dimension(3);
     88     eigen_assert(input_width > 0);
     89     eigen_assert(input_height > 0);
     90     internal::NormalRandomGenerator<float> gen;
     91     internal::UniformRandomGenerator<float> unigen;
     92 
     93     for (Index i = 0; i < batch_size; ++i) {
     94       float x = offsets_[i].first, y = offsets_[i].second;
     95 
     96       // Un-normalize coordinates back to pixel space if normalized.
     97       if (normalized_) {
     98         x *= input_width;
     99         y *= input_height;
    100       }
    101       // Un-center if coordinates are centered on the image center.
    102       if (centered_) {
    103         x /= 2.0f;
    104         y /= 2.0f;
    105         x += input_width / 2.0f;
    106         y += input_height / 2.0f;
    107       }
    108       // Remove half of the glimpse window.
    109       x -= width_ / 2.0f;
    110       y -= height_ / 2.0f;
    111 
    112       const Index offset_x = (Index)x;
    113       const Index offset_y = (Index)y;
    114       Index glimpse_width = width_;
    115       Index glimpse_height = height_;
    116       bool partial_overlap = false;
    117       DSizes<Index, 3> slice_offset(0, offset_x, offset_y);
    118       DSizes<Index, 3> slice_extent(num_channels, width_, height_);
    119       DSizes<Index, 3> base_offset(0, 0, 0);
    120 
    121       if (offset_x < 0) {
    122         slice_offset[1] = 0;
    123         glimpse_width = (std::max<Index>)(0, width_ + offset_x);
    124         slice_extent[1] = glimpse_width;
    125         base_offset[1] = width_ - glimpse_width;
    126         partial_overlap = true;
    127       } else if (offset_x + width_ >= input_width) {
    128         glimpse_width = (std::max<Index>)(0, input_width - offset_x);
    129         slice_extent[1] = glimpse_width;
    130         partial_overlap = true;
    131       }
    132       if (offset_y < 0) {
    133         slice_offset[2] = 0;
    134         glimpse_height = (std::max<Index>)(0, height_ + offset_y);
    135         slice_extent[2] = glimpse_height;
    136         base_offset[2] = height_ - glimpse_height;
    137         partial_overlap = true;
    138       } else if (offset_y + height_ >= input_height) {
    139         glimpse_height = (std::max<Index>)(0, input_height - offset_y);
    140         slice_extent[2] = glimpse_height;
    141         partial_overlap = true;
    142       }
    143       slice_extent[1] = std::min<Index>(input_width, slice_extent[1]);
    144       slice_extent[2] = std::min<Index>(input_height, slice_extent[2]);
    145 
    146       if (partial_overlap) {
    147         if (uniform_noise_) {
    148           // Initialize the glimpse with uniform noise.
    149           typedef typename internal::remove_const<
    150               typename internal::traits<Input>::Scalar>::type Scalar;
    151           TensorFixedSize<Scalar, Sizes<> > mini;
    152           mini.device(device) = input.template chip<3>(i).minimum();
    153           TensorFixedSize<float, Sizes<> > range;
    154           range.device(device) = (input.template chip<3>(i).maximum() - mini)
    155                                      .template cast<float>();
    156 
    157           DSizes<Index, 3> glimpse_size(num_channels, width_, height_);
    158           TensorMap<Tensor<float, 3> > tmp(NULL, glimpse_size);
    159           output.template chip<3>(i).device(device) =
    160               mini.reshape(Sizes<1, 1, 1>()).broadcast(glimpse_size) +
    161               (tmp.random(unigen) *
    162                range.reshape(Sizes<1, 1, 1>()).broadcast(glimpse_size))
    163                   .template cast<Scalar>();
    164         } else {
    165           // Initialize the glimpse with white noise: compute the mean and sigma
    166           // of each channel, and use them to shape the gaussian.
    167           DSizes<Index, 2> glimpse_size(width_, height_);
    168           DSizes<Index, 2> input_size(input_width, input_height);
    169           typedef typename internal::remove_const<
    170               typename internal::traits<Input>::Scalar>::type Scalar;
    171 
    172           for (int j = 0; j < num_channels; ++j) {
    173             TensorFixedSize<Scalar, Sizes<> > mean;
    174             mean.device(device) = input.template chip<3>(i)
    175                                       .template chip<0>(j)
    176                                       .template cast<float>()
    177                                       .mean();
    178             TensorFixedSize<float, Sizes<> > sigma;
    179             sigma.device(device) =
    180                 (input.template chip<3>(i)
    181                      .template chip<0>(j)
    182                      .template cast<float>() -
    183                  mean.reshape(Sizes<1, 1>()).broadcast(input_size))
    184                     .square()
    185                     .mean()
    186                     .sqrt();
    187             TensorFixedSize<Scalar, Sizes<> > mini;
    188             mini.device(device) =
    189                 input.template chip<3>(i).template chip<0>(j).minimum();
    190             TensorFixedSize<float, Sizes<> > maxi;
    191             maxi.device(device) =
    192                 input.template chip<3>(i).template chip<0>(j).maximum();
    193 
    194             TensorMap<Tensor<float, 2> > tmp(NULL, glimpse_size);
    195             output.template chip<3>(i).template chip<0>(j).device(device) =
    196                 (mean.reshape(Sizes<1, 1>()).broadcast(glimpse_size) +
    197                  (tmp.random(gen) *
    198                   sigma.reshape(Sizes<1, 1>()).broadcast(glimpse_size))
    199                      .template cast<Scalar>())
    200                     .cwiseMin(
    201                         maxi.reshape(Sizes<1, 1>()).broadcast(glimpse_size))
    202                     .cwiseMax(
    203                         mini.reshape(Sizes<1, 1>()).broadcast(glimpse_size));
    204           }
    205         }
    206 
    207         // Copy the part of the glimpse that cover the input image if any.
    208         if (glimpse_width == 0 || glimpse_height == 0) {
    209           continue;
    210         }
    211         output.template chip<3>(i)
    212             .slice(base_offset, slice_extent)
    213             .device(device) =
    214             input.template chip<3>(i).slice(slice_offset, slice_extent);
    215       } else {
    216         output.template chip<3>(i).device(device) =
    217             input.template chip<3>(i).slice(slice_offset, slice_extent);
    218       }
    219     }
    220   }
    221 
    222  private:
    223   const Index width_;
    224   const Index height_;
    225   const std::vector<IndexPair<float> > offsets_;
    226   const bool normalized_;
    227   const bool centered_;
    228   const bool uniform_noise_;
    229 };
    230 }  // namespace
    231 
    232 template <typename Input>
    233 EIGEN_ALWAYS_INLINE static const TensorCustomUnaryOp<
    234     const GlimpseExtractionOp<typename internal::traits<Input>::Index>,
    235     const Input>
    236 ExtractGlimpses(const Input& input,
    237                 const typename internal::traits<Input>::Index width,
    238                 const typename internal::traits<Input>::Index height,
    239                 const std::vector<IndexPair<float> >& offsets,
    240                 const bool normalized = true, const bool centered = true,
    241                 const bool uniform_noise = true) {
    242   EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == ColMajor,
    243                       YOU_MADE_A_PROGRAMMING_MISTAKE);
    244   EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 4,
    245                       YOU_MADE_A_PROGRAMMING_MISTAKE);
    246 
    247   typedef typename internal::traits<Input>::Index Index;
    248   const GlimpseExtractionOp<Index> op(width, height, offsets, normalized,
    249                                       centered, uniform_noise);
    250   return input.customOp(op);
    251 }
    252 
    253 }  // end namespace Eigen
    254 
    255 #endif  // TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
    256