1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/image_ops.cc. 17 18 #define EIGEN_USE_THREADS 19 20 #include <vector> 21 #include "tensorflow/core/framework/op.h" 22 #include "tensorflow/core/framework/op_kernel.h" 23 #include "tensorflow/core/framework/tensor.h" 24 #include "tensorflow/core/framework/tensor_shape.h" 25 #include "tensorflow/core/framework/types.h" 26 #include "tensorflow/core/kernels/eigen_attention.h" 27 #include "tensorflow/core/platform/logging.h" 28 #include "tensorflow/core/platform/types.h" 29 30 namespace tensorflow { 31 32 class ExtractGlimpseOp : public OpKernel { 33 public: 34 explicit ExtractGlimpseOp(OpKernelConstruction* context) : OpKernel(context) { 35 OP_REQUIRES_OK(context, context->GetAttr("normalized", &normalized_)); 36 OP_REQUIRES_OK(context, context->GetAttr("centered", ¢ered_)); 37 OP_REQUIRES_OK(context, context->GetAttr("uniform_noise", &uniform_noise_)); 38 } 39 40 // Expect input tensor of rank 4 with dimensions (batch_size, height, width, 41 // depth). 42 void Compute(OpKernelContext* context) override { 43 const Tensor& input = context->input(0); 44 const TensorShape& input_shape = input.shape(); 45 const int32 num_dims = input_shape.dims(); 46 OP_REQUIRES( 47 context, num_dims == 4, 48 errors::InvalidArgument( 49 "input must be 4-dimensional (batch_size, height, width, depth)", 50 input_shape.DebugString())); 51 52 const int64 batch_size = input_shape.dim_size(0); 53 54 const Tensor& window_size = context->input(1); 55 OP_REQUIRES(context, 56 (window_size.shape().dims() == 1) && 57 window_size.shape().dim_size(0) == 2, 58 errors::InvalidArgument( 59 "input must be a vector of size 2 (height, width)", 60 window_size.shape().DebugString())); 61 62 const int64 output_height = window_size.tensor<int, 1>()(0); 63 const int64 output_width = window_size.tensor<int, 1>()(1); 64 TensorShape output_shape = input_shape; 65 output_shape.set_dim(1, output_height); 66 output_shape.set_dim(2, output_width); 67 68 const Tensor& offsets = context->input(2); 69 OP_REQUIRES(context, offsets.shape().dims() == 2, 70 errors::InvalidArgument("input must be a matrix", 71 offsets.shape().DebugString())); 72 OP_REQUIRES(context, offsets.shape().dim_size(0) == batch_size, 73 errors::InvalidArgument("first dimension should be batch", 74 offsets.shape().DebugString())); 75 OP_REQUIRES( 76 context, offsets.shape().dim_size(1) == 2, 77 errors::InvalidArgument("second dimension should be of size 2 (y,x)", 78 offsets.shape().DebugString())); 79 80 Tensor* output = nullptr; 81 OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); 82 if (output->NumElements() == 0) { 83 // Nothing else to do. 84 return; 85 } 86 87 std::vector<Eigen::IndexPair<float> > offset_vec; 88 offset_vec.reserve(batch_size); 89 for (int i = 0; i < batch_size; ++i) { 90 float offset_y = offsets.tensor<float, 2>()(i, 0); 91 float offset_x = offsets.tensor<float, 2>()(i, 1); 92 // Eigen::ExtractGlimpses expects offsets as (x,y), whereas the 93 // calling TensorFlow operates with (y,x) as indices. 94 offset_vec.push_back(Eigen::IndexPair<float>(offset_x, offset_y)); 95 } 96 97 output->tensor<float, 4>().swap_layout().device( 98 context->eigen_cpu_device()) = 99 Eigen::ExtractGlimpses(input.tensor<float, 4>().swap_layout(), 100 output_width, output_height, offset_vec, 101 normalized_, centered_, uniform_noise_); 102 } 103 104 private: 105 bool normalized_; 106 bool centered_; 107 bool uniform_noise_; 108 }; 109 110 REGISTER_KERNEL_BUILDER(Name("ExtractGlimpse").Device(DEVICE_CPU), 111 ExtractGlimpseOp); 112 113 } // end namespace tensorflow 114