Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/image_ops.cc.
     17 
     18 #define EIGEN_USE_THREADS
     19 
     20 #include <vector>
     21 #include "tensorflow/core/framework/op.h"
     22 #include "tensorflow/core/framework/op_kernel.h"
     23 #include "tensorflow/core/framework/tensor.h"
     24 #include "tensorflow/core/framework/tensor_shape.h"
     25 #include "tensorflow/core/framework/types.h"
     26 #include "tensorflow/core/kernels/eigen_attention.h"
     27 #include "tensorflow/core/platform/logging.h"
     28 #include "tensorflow/core/platform/types.h"
     29 
     30 namespace tensorflow {
     31 
     32 class ExtractGlimpseOp : public OpKernel {
     33  public:
     34   explicit ExtractGlimpseOp(OpKernelConstruction* context) : OpKernel(context) {
     35     OP_REQUIRES_OK(context, context->GetAttr("normalized", &normalized_));
     36     OP_REQUIRES_OK(context, context->GetAttr("centered", &centered_));
     37     OP_REQUIRES_OK(context, context->GetAttr("uniform_noise", &uniform_noise_));
     38   }
     39 
     40   // Expect input tensor of rank 4 with dimensions (batch_size, height, width,
     41   // depth).
     42   void Compute(OpKernelContext* context) override {
     43     const Tensor& input = context->input(0);
     44     const TensorShape& input_shape = input.shape();
     45     const int32 num_dims = input_shape.dims();
     46     OP_REQUIRES(
     47         context, num_dims == 4,
     48         errors::InvalidArgument(
     49             "input must be 4-dimensional (batch_size, height, width, depth)",
     50             input_shape.DebugString()));
     51 
     52     const int64 batch_size = input_shape.dim_size(0);
     53 
     54     const Tensor& window_size = context->input(1);
     55     OP_REQUIRES(context,
     56                 (window_size.shape().dims() == 1) &&
     57                     window_size.shape().dim_size(0) == 2,
     58                 errors::InvalidArgument(
     59                     "input must be a vector of size 2 (height, width)",
     60                     window_size.shape().DebugString()));
     61 
     62     const int64 output_height = window_size.tensor<int, 1>()(0);
     63     const int64 output_width = window_size.tensor<int, 1>()(1);
     64     TensorShape output_shape = input_shape;
     65     output_shape.set_dim(1, output_height);
     66     output_shape.set_dim(2, output_width);
     67 
     68     const Tensor& offsets = context->input(2);
     69     OP_REQUIRES(context, offsets.shape().dims() == 2,
     70                 errors::InvalidArgument("input must be a matrix",
     71                                         offsets.shape().DebugString()));
     72     OP_REQUIRES(context, offsets.shape().dim_size(0) == batch_size,
     73                 errors::InvalidArgument("first dimension should be batch",
     74                                         offsets.shape().DebugString()));
     75     OP_REQUIRES(
     76         context, offsets.shape().dim_size(1) == 2,
     77         errors::InvalidArgument("second dimension should be of size 2 (y,x)",
     78                                 offsets.shape().DebugString()));
     79 
     80     Tensor* output = nullptr;
     81     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
     82     if (output->NumElements() == 0) {
     83       // Nothing else to do.
     84       return;
     85     }
     86 
     87     std::vector<Eigen::IndexPair<float> > offset_vec;
     88     offset_vec.reserve(batch_size);
     89     for (int i = 0; i < batch_size; ++i) {
     90       float offset_y = offsets.tensor<float, 2>()(i, 0);
     91       float offset_x = offsets.tensor<float, 2>()(i, 1);
     92       // Eigen::ExtractGlimpses expects offsets as (x,y), whereas the
     93       // calling TensorFlow operates with (y,x) as indices.
     94       offset_vec.push_back(Eigen::IndexPair<float>(offset_x, offset_y));
     95     }
     96 
     97     output->tensor<float, 4>().swap_layout().device(
     98         context->eigen_cpu_device()) =
     99         Eigen::ExtractGlimpses(input.tensor<float, 4>().swap_layout(),
    100                                output_width, output_height, offset_vec,
    101                                normalized_, centered_, uniform_noise_);
    102   }
    103 
    104  private:
    105   bool normalized_;
    106   bool centered_;
    107   bool uniform_noise_;
    108 };
    109 
    110 REGISTER_KERNEL_BUILDER(Name("ExtractGlimpse").Device(DEVICE_CPU),
    111                         ExtractGlimpseOp);
    112 
    113 }  // end namespace tensorflow
    114