1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ 17 #define TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ 18 19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 20 21 namespace Eigen { 22 23 /** ExtractGlimpses 24 * \ingroup CXX11_NeuralNetworks_Module 25 * 26 * \brief Extract glimpses from an input tensor. 27 * 28 * The input parameter is expected to be a col-major tensor with a rank of 4 29 * (depth, x, y, and batch). The width and height parameters specify the 30 * extension of the returned glimpses. The offsets parameter specifies the x, y 31 * locations of the center of the glimpses relative to the center of the input 32 * image. The vector is expected to contain one IndexPair for each image in the 33 * batch dimension. The normalized boolean indicates if incoming coordinates are 34 * normalized so that 0.0 and 1.0 correspond to the minimum and maximum of each 35 * height and width dimension. The centered boolean indicates if incoming 36 * coordinates are centered relative to the image, in which case -1.0 and 1.0 37 * correspond to minimum and maximum of each dimension while 0.0 corresponds to 38 * the center. 39 * 40 * The result can be assigned to a tensor of rank equal to that of the input. 41 * The result will be laid out in col-major order (depth, x, y, batch). The 42 * dimensions of the result will be equal to the dimensions of the input except 43 * for width and height which will be equal to the requested glimpse size. 44 */ 45 namespace { 46 template <typename Index> 47 struct GlimpseExtractionOp { 48 GlimpseExtractionOp(const Index width, const Index height, 49 const std::vector<IndexPair<float> >& offsets, 50 const bool normalized, const bool centered, 51 const bool uniform_noise) 52 : width_(width), 53 height_(height), 54 offsets_(offsets), 55 normalized_(normalized), 56 centered_(centered), 57 uniform_noise_(uniform_noise) {} 58 59 template <typename Input> 60 DSizes<Index, 4> dimensions(const Input& input) const { 61 typedef typename internal::traits<Input>::Index IndexType; 62 typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4, 63 internal::traits<Input>::Layout, IndexType> > 64 Ref; 65 Ref in(input); 66 67 DSizes<Index, 4> dims = in.dimensions(); 68 69 dims[0] = in.dimension(0); 70 dims[1] = width_; 71 dims[2] = height_; 72 dims[3] = in.dimension(3); 73 return dims; 74 } 75 76 template <typename Input, typename Output, typename Device> 77 EIGEN_DEVICE_FUNC void eval(const Input& input, Output& output, 78 const Device& device) const { 79 typedef typename internal::traits<Input>::Index IndexType; 80 typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4, 81 internal::traits<Input>::Layout, IndexType> > 82 Ref; 83 Ref in(input); 84 const Index num_channels = in.dimension(0); 85 const Index input_width = in.dimension(1); 86 const Index input_height = in.dimension(2); 87 const Index batch_size = in.dimension(3); 88 eigen_assert(input_width > 0); 89 eigen_assert(input_height > 0); 90 internal::NormalRandomGenerator<float> gen; 91 internal::UniformRandomGenerator<float> unigen; 92 93 for (Index i = 0; i < batch_size; ++i) { 94 float x = offsets_[i].first, y = offsets_[i].second; 95 96 // Un-normalize coordinates back to pixel space if normalized. 97 if (normalized_) { 98 x *= input_width; 99 y *= input_height; 100 } 101 // Un-center if coordinates are centered on the image center. 102 if (centered_) { 103 x /= 2.0f; 104 y /= 2.0f; 105 x += input_width / 2.0f; 106 y += input_height / 2.0f; 107 } 108 // Remove half of the glimpse window. 109 x -= width_ / 2.0f; 110 y -= height_ / 2.0f; 111 112 const Index offset_x = (Index)x; 113 const Index offset_y = (Index)y; 114 Index glimpse_width = width_; 115 Index glimpse_height = height_; 116 bool partial_overlap = false; 117 DSizes<Index, 3> slice_offset(0, offset_x, offset_y); 118 DSizes<Index, 3> slice_extent(num_channels, width_, height_); 119 DSizes<Index, 3> base_offset(0, 0, 0); 120 121 if (offset_x < 0) { 122 slice_offset[1] = 0; 123 glimpse_width = (std::max<Index>)(0, width_ + offset_x); 124 slice_extent[1] = glimpse_width; 125 base_offset[1] = width_ - glimpse_width; 126 partial_overlap = true; 127 } else if (offset_x + width_ >= input_width) { 128 glimpse_width = (std::max<Index>)(0, input_width - offset_x); 129 slice_extent[1] = glimpse_width; 130 partial_overlap = true; 131 } 132 if (offset_y < 0) { 133 slice_offset[2] = 0; 134 glimpse_height = (std::max<Index>)(0, height_ + offset_y); 135 slice_extent[2] = glimpse_height; 136 base_offset[2] = height_ - glimpse_height; 137 partial_overlap = true; 138 } else if (offset_y + height_ >= input_height) { 139 glimpse_height = (std::max<Index>)(0, input_height - offset_y); 140 slice_extent[2] = glimpse_height; 141 partial_overlap = true; 142 } 143 slice_extent[1] = std::min<Index>(input_width, slice_extent[1]); 144 slice_extent[2] = std::min<Index>(input_height, slice_extent[2]); 145 146 if (partial_overlap) { 147 if (uniform_noise_) { 148 // Initialize the glimpse with uniform noise. 149 typedef typename internal::remove_const< 150 typename internal::traits<Input>::Scalar>::type Scalar; 151 TensorFixedSize<Scalar, Sizes<> > mini; 152 mini.device(device) = input.template chip<3>(i).minimum(); 153 TensorFixedSize<float, Sizes<> > range; 154 range.device(device) = (input.template chip<3>(i).maximum() - mini) 155 .template cast<float>(); 156 157 DSizes<Index, 3> glimpse_size(num_channels, width_, height_); 158 TensorMap<Tensor<float, 3> > tmp(NULL, glimpse_size); 159 output.template chip<3>(i).device(device) = 160 mini.reshape(Sizes<1, 1, 1>()).broadcast(glimpse_size) + 161 (tmp.random(unigen) * 162 range.reshape(Sizes<1, 1, 1>()).broadcast(glimpse_size)) 163 .template cast<Scalar>(); 164 } else { 165 // Initialize the glimpse with white noise: compute the mean and sigma 166 // of each channel, and use them to shape the gaussian. 167 DSizes<Index, 2> glimpse_size(width_, height_); 168 DSizes<Index, 2> input_size(input_width, input_height); 169 typedef typename internal::remove_const< 170 typename internal::traits<Input>::Scalar>::type Scalar; 171 172 for (int j = 0; j < num_channels; ++j) { 173 TensorFixedSize<Scalar, Sizes<> > mean; 174 mean.device(device) = input.template chip<3>(i) 175 .template chip<0>(j) 176 .template cast<float>() 177 .mean(); 178 TensorFixedSize<float, Sizes<> > sigma; 179 sigma.device(device) = 180 (input.template chip<3>(i) 181 .template chip<0>(j) 182 .template cast<float>() - 183 mean.reshape(Sizes<1, 1>()).broadcast(input_size)) 184 .square() 185 .mean() 186 .sqrt(); 187 TensorFixedSize<Scalar, Sizes<> > mini; 188 mini.device(device) = 189 input.template chip<3>(i).template chip<0>(j).minimum(); 190 TensorFixedSize<float, Sizes<> > maxi; 191 maxi.device(device) = 192 input.template chip<3>(i).template chip<0>(j).maximum(); 193 194 TensorMap<Tensor<float, 2> > tmp(NULL, glimpse_size); 195 output.template chip<3>(i).template chip<0>(j).device(device) = 196 (mean.reshape(Sizes<1, 1>()).broadcast(glimpse_size) + 197 (tmp.random(gen) * 198 sigma.reshape(Sizes<1, 1>()).broadcast(glimpse_size)) 199 .template cast<Scalar>()) 200 .cwiseMin( 201 maxi.reshape(Sizes<1, 1>()).broadcast(glimpse_size)) 202 .cwiseMax( 203 mini.reshape(Sizes<1, 1>()).broadcast(glimpse_size)); 204 } 205 } 206 207 // Copy the part of the glimpse that cover the input image if any. 208 if (glimpse_width == 0 || glimpse_height == 0) { 209 continue; 210 } 211 output.template chip<3>(i) 212 .slice(base_offset, slice_extent) 213 .device(device) = 214 input.template chip<3>(i).slice(slice_offset, slice_extent); 215 } else { 216 output.template chip<3>(i).device(device) = 217 input.template chip<3>(i).slice(slice_offset, slice_extent); 218 } 219 } 220 } 221 222 private: 223 const Index width_; 224 const Index height_; 225 const std::vector<IndexPair<float> > offsets_; 226 const bool normalized_; 227 const bool centered_; 228 const bool uniform_noise_; 229 }; 230 } // namespace 231 232 template <typename Input> 233 EIGEN_ALWAYS_INLINE static const TensorCustomUnaryOp< 234 const GlimpseExtractionOp<typename internal::traits<Input>::Index>, 235 const Input> 236 ExtractGlimpses(const Input& input, 237 const typename internal::traits<Input>::Index width, 238 const typename internal::traits<Input>::Index height, 239 const std::vector<IndexPair<float> >& offsets, 240 const bool normalized = true, const bool centered = true, 241 const bool uniform_noise = true) { 242 EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == ColMajor, 243 YOU_MADE_A_PROGRAMMING_MISTAKE); 244 EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 4, 245 YOU_MADE_A_PROGRAMMING_MISTAKE); 246 247 typedef typename internal::traits<Input>::Index Index; 248 const GlimpseExtractionOp<Index> op(width, height, offsets, normalized, 249 centered, uniform_noise); 250 return input.customOp(op); 251 } 252 253 } // end namespace Eigen 254 255 #endif // TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ 256