1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #include <algorithm> 19 #include <numeric> 20 #include <unordered_map> 21 #include <utility> 22 #include <vector> 23 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/framework/register_types.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/tensor_util.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/lib/gtl/inlined_vector.h" 30 #include "tensorflow/core/util/sparse/sparse_tensor.h" 31 32 namespace tensorflow { 33 34 void Reshape(OpKernelContext *context, const Tensor &input_indices_in, 35 const Tensor &input_shape_in, const Tensor &target_shape_in, 36 int output_indices_idx, int output_shape_idx) { 37 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices_in.shape()), 38 errors::InvalidArgument( 39 "Input indices should be a matrix but received shape ", 40 input_indices_in.shape().DebugString())); 41 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()), 42 errors::InvalidArgument( 43 "Input shape should be a vector but received shape ", 44 input_shape_in.shape().DebugString())); 45 OP_REQUIRES(context, TensorShapeUtils::IsVector(target_shape_in.shape()), 46 errors::InvalidArgument( 47 "Target shape should be a vector but received shape ", 48 target_shape_in.shape().DebugString())); 49 50 const int64 input_rank = input_shape_in.NumElements(); 51 const int64 output_rank = target_shape_in.NumElements(); 52 const TensorShape input_shape(input_shape_in.vec<int64>()); 53 const int64 dense_size = input_shape.num_elements(); 54 const int64 nnz = input_indices_in.shape().dim_size(0); 55 56 // Compute the output shape. Determine product of specified dimensions, and 57 // find the index of the unspecified one. 58 TensorShape output_shape; 59 int64 product = 1; 60 int unknown_index = -1; 61 auto target_shape = target_shape_in.vec<int64>(); 62 for (int d = 0; d < output_rank; ++d) { 63 const int64 size = target_shape(d); 64 if (size == -1) { 65 OP_REQUIRES( 66 context, unknown_index == -1, 67 errors::InvalidArgument("only one output dimension may be -1, " 68 "not both ", 69 unknown_index, " and ", d)); 70 unknown_index = d; 71 output_shape.AddDim(1); 72 } else { 73 OP_REQUIRES(context, size >= 0, 74 errors::InvalidArgument("size ", d, 75 " must be non-negative, not ", size)); 76 product *= size; 77 output_shape.AddDim(size); 78 } 79 } 80 if (unknown_index != -1) { 81 OP_REQUIRES( 82 context, product > 0, 83 errors::InvalidArgument("reshape cannot infer the missing " 84 "input size for an empty tensor unless all " 85 "specified input sizes are non-zero")); 86 const int64 missing = dense_size / product; 87 OP_REQUIRES( 88 context, product * missing == dense_size, 89 errors::InvalidArgument( 90 "Input to reshape is a SparseTensor with ", dense_size, 91 " dense values, but the requested shape requires a multiple of ", 92 product)); 93 output_shape.set_dim(unknown_index, missing); 94 } 95 96 OP_REQUIRES( 97 context, output_shape.num_elements() == dense_size, 98 errors::InvalidArgument("Input to reshape is a tensor with ", dense_size, 99 " dense values, but the requested shape has ", 100 output_shape.num_elements())); 101 102 // Optimize for reshaping to the same shape. 103 if (input_shape == output_shape) { 104 context->set_output(output_indices_idx, input_indices_in); 105 context->set_output(output_shape_idx, input_shape_in); 106 return; 107 } 108 109 gtl::InlinedVector<int64, 8> input_strides(input_rank); 110 input_strides[input_rank - 1] = 1; 111 for (int d = input_rank - 2; d >= 0; --d) { 112 input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1); 113 } 114 115 gtl::InlinedVector<int64, 8> output_strides(output_rank); 116 output_strides[output_rank - 1] = 1; 117 for (int d = output_rank - 2; d >= 0; --d) { 118 output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1); 119 } 120 121 Tensor *result_indices = nullptr; 122 OP_REQUIRES_OK(context, 123 context->allocate_output(output_indices_idx, 124 TensorShape({nnz, output_rank}), 125 &result_indices)); 126 auto input_ind = input_indices_in.matrix<int64>(); 127 auto output_ind = result_indices->matrix<int64>(); 128 for (int i = 0; i < nnz; ++i) { 129 int64 id = 0; 130 for (int j = 0; j < input_rank; ++j) { 131 id += input_ind(i, j) * input_strides[j]; 132 } 133 for (int j = 0; j < output_rank; ++j) { 134 output_ind(i, j) = id / output_strides[j]; 135 id %= output_strides[j]; 136 } 137 } 138 139 Tensor *result_shape = nullptr; 140 OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx, 141 TensorShape({output_rank}), 142 &result_shape)); 143 auto output_shape_vec = result_shape->vec<int64>(); 144 for (int j = 0; j < output_shape.dims(); ++j) { 145 output_shape_vec(j) = output_shape.dim_size(j); 146 } 147 } 148 149 } // namespace tensorflow 150