1 // ============================================================================= 2 // Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 // ============================================================================= 16 17 #ifndef TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_ 18 #define TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_ 19 20 #include <cmath> 21 #include <type_traits> 22 #include <vector> 23 #include "tensorflow/core/framework/op.h" 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/framework/shape_inference.h" 26 #include "tensorflow/core/framework/tensor_shape.h" 27 #include "tensorflow/core/lib/core/status.h" 28 29 namespace { 30 31 template <class IndexVecT, class IndexT> 32 IndexT compute_input_index( 33 IndexVecT* target_dimensions, const IndexT& output_index, 34 const IndexVecT& original_dimensions, const int& adjustable_dimension, 35 const std::vector<tensorflow::int64>& dimension_ceiling, 36 const std::vector<tensorflow::int64>& cumulative_dimensions, IndexT* result, 37 std::vector<IndexT>* output_indices, const int& rank) { 38 *result = 0; 39 output_indices->clear(); 40 41 // un-rasterize the output index 42 auto last_reduced_i = output_index; 43 for (auto r = rank - 1; r >= 0; --r) { 44 (*output_indices)[r] = last_reduced_i % (*target_dimensions)[r]; 45 last_reduced_i = 46 (last_reduced_i - (*output_indices)[r]) / (*target_dimensions)[r]; 47 } 48 49 // rasterize the input index 50 IndexT last_index_factor = 1; 51 for (auto r = rank - 1; r >= 0; --r) { 52 IndexT index = 0; 53 if (r != adjustable_dimension) 54 index = (*output_indices)[r] / dimension_ceiling[r]; 55 else { 56 for (int qi = 0; qi < rank; ++qi) { 57 if (qi == adjustable_dimension) continue; 58 index += cumulative_dimensions[qi] * 59 ((*output_indices)[qi] % dimension_ceiling[qi]); 60 } 61 index *= (*target_dimensions)[adjustable_dimension]; 62 index += (*output_indices)[r]; 63 } 64 *result += last_index_factor * index; 65 last_index_factor *= original_dimensions[r]; 66 } 67 68 return *result; 69 } 70 71 template <class InputDataT, 72 class IndexVecT> // both types are needed here b/c IndexVecT and 73 // InputDataT are not related 74 void 75 fill_periodic_tensor( 76 tensorflow::OpKernelContext* context, 77 const IndexVecT& desired_shape, 78 const tensorflow::Tensor& input_tensor) { 79 // input is a strided array (last index is fastest, C-ordered) 80 auto input = input_tensor.flat<InputDataT>(); 81 const int rank = input_tensor.dims(); 82 // original and target dimensions 83 std::vector<tensorflow::int64> original_dimensions(rank), 84 target_dimensions(rank); 85 tensorflow::int64 total_size(input_tensor.NumElements()), new_sliced_size(1); 86 // factors by which original_dimensions increases/decreases w.r.t. 87 // target_dimensions 88 std::vector<tensorflow::int64> dimension_ceiling(rank), 89 cumulative_dimensions(rank); 90 // index of adjustable dimension 91 int adjustable_dimension; 92 tensorflow::TensorShape output_shape; 93 94 // requires that the rank of the input tensor and length of the desired shape 95 // are equal 96 OP_REQUIRES(context, rank == desired_shape.size(), 97 tensorflow::errors::InvalidArgument( 98 "periodic_resample expects the rank of the input tensor, ", 99 rank, ", to be the same as the length of the desired shape, ", 100 desired_shape.size(), ".")); 101 102 bool found = false; 103 const auto& input_tensor_shape = input_tensor.shape(); 104 105 for (int i = 0; i < rank; ++i) { 106 // if (desired_shape(i) < 1) { 107 if (desired_shape[i] < 1) { 108 // only one index can be adjustable 109 OP_REQUIRES(context, !found, 110 tensorflow::errors::InvalidArgument( 111 "periodic_resample expects only " 112 "one index to be marked as adjustable.")); 113 adjustable_dimension = i; 114 found = true; 115 } else { 116 OP_REQUIRES( 117 context, desired_shape[i] >= input_tensor_shape.dim_size(i), 118 tensorflow::errors::InvalidArgument( 119 "periodic_resample expects the size of non-adjustable " 120 "dimensions be at least as large as size of input tensor." 121 " Dimension ", 122 i, " input tensor has size ", input_tensor_shape.dim_size(i), 123 ", desired shape has size ", desired_shape[i], ".")); 124 125 // target_dimensions[i] = desired_shape(i); 126 target_dimensions[i] = desired_shape[i]; 127 new_sliced_size *= target_dimensions[i]; 128 } 129 } 130 // at least one index needs to be adjustable 131 OP_REQUIRES(context, found, 132 tensorflow::errors::InvalidArgument( 133 "periodic_resample expects at least " 134 "one index to be marked as adjustable.")); 135 136 int count = 0; 137 for (const auto dim_info : input_tensor.shape()) { 138 original_dimensions[count] = dim_info.size; 139 ++count; 140 } 141 142 target_dimensions[adjustable_dimension] = total_size / new_sliced_size; 143 144 count = 0; 145 for (int i = 0; i < input_tensor.shape().dims(); ++i) { 146 dimension_ceiling[count] = tensorflow::int64(std::ceil( 147 float(target_dimensions[count]) / float(original_dimensions[count]))); 148 if (count == 0) 149 cumulative_dimensions[count] = 1; 150 else 151 cumulative_dimensions[count] = 152 cumulative_dimensions[count - 1] * dimension_ceiling[count - 1]; 153 ++count; 154 } 155 156 // ensure that the new dimension is greater than zero 157 OP_REQUIRES(context, target_dimensions[adjustable_dimension] > 0, 158 tensorflow::errors::InvalidArgument( 159 "periodic_resample found that the " 160 "adjustable dimension, ", 161 adjustable_dimension, ", isn't greater than zero, ", 162 target_dimensions[adjustable_dimension], ".")); 163 for (int i = 0; i < rank; ++i) { 164 output_shape.AddDim(target_dimensions[i]); 165 } 166 const auto new_size = 167 new_sliced_size * target_dimensions[adjustable_dimension]; 168 169 // Create an output tensor and attach it to the current context 170 tensorflow::Tensor* output_tensor = nullptr; 171 OP_REQUIRES_OK(context, 172 context->allocate_output(0, output_shape, &output_tensor)); 173 auto output = output_tensor->flat<InputDataT>(); 174 175 // memory is allocated for these variables outside the inner loop for 176 // efficiency (although, I could create a separate class scope for 177 // this purpose instead) 178 tensorflow::int64 result = 0; 179 std::vector<tensorflow::int64> output_indices(target_dimensions.size()); 180 181 // Fill output tensor with periodically resampled input tensor values 182 for (tensorflow::int64 output_index = 0; output_index < new_size; 183 ++output_index) { 184 output(output_index) = input(compute_input_index( 185 &target_dimensions, output_index, original_dimensions, 186 adjustable_dimension, dimension_ceiling, cumulative_dimensions, &result, 187 &output_indices, rank)); 188 } 189 } 190 191 void create_output_tensor( 192 tensorflow::OpKernelContext* context, 193 const tensorflow::Tensor& input_tensor, 194 const tensorflow::DataType& input_tensor_type, 195 const tensorflow::PartialTensorShape& desired_shape_tensor) { 196 auto desired_shape = desired_shape_tensor.dim_sizes(); 197 198 // obligatory type switch 199 switch (input_tensor_type) { 200 case tensorflow::DataTypeToEnum<float>::value: 201 fill_periodic_tensor<float>(context, desired_shape, input_tensor); 202 break; 203 case tensorflow::DataTypeToEnum<double>::value: 204 fill_periodic_tensor<double>(context, desired_shape, input_tensor); 205 break; 206 case tensorflow::DataTypeToEnum<tensorflow::int32>::value: 207 fill_periodic_tensor<tensorflow::int32>(context, desired_shape, 208 input_tensor); 209 break; 210 case tensorflow::DataTypeToEnum<tensorflow::int64>::value: 211 fill_periodic_tensor<tensorflow::int64>(context, desired_shape, 212 input_tensor); 213 break; 214 default:; 215 } 216 } 217 218 } // namespace 219 220 class PeriodicResampleOp : public tensorflow::OpKernel { 221 public: 222 explicit PeriodicResampleOp(tensorflow::OpKernelConstruction* context) 223 : tensorflow::OpKernel(context) { 224 // Get the desired shape 225 OP_REQUIRES_OK(context, context->GetAttr("shape", &desired_shape)); 226 } 227 228 void Compute(tensorflow::OpKernelContext* context) override { 229 // Grab the input tensor 230 const tensorflow::Tensor& input_tensor = context->input(0); 231 const tensorflow::DataType input_tensor_type = context->input_dtype(0); 232 233 create_output_tensor(context, input_tensor, input_tensor_type, 234 desired_shape); 235 } 236 237 private: 238 tensorflow::PartialTensorShape desired_shape; 239 }; 240 241 #endif // TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_ 242