1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/kernels/batch_util.h" 17 18 #include "tensorflow/core/framework/register_types.h" 19 #include "tensorflow/core/framework/types.h" 20 #include "tensorflow/core/lib/core/errors.h" 21 22 #define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m) 23 24 namespace tensorflow { 25 namespace batch_util { 26 27 namespace { 28 29 Status ValidateInput(const Tensor& parent, const Tensor& element, int64 index) { 30 DCHECK_NE(parent.dim_size(0), 0); 31 DCHECK_GE(index, 0); 32 if (element.NumElements() != (parent.NumElements() / parent.dim_size(0))) { 33 TensorShape chip_shape = parent.shape(); 34 chip_shape.RemoveDim(0); 35 return errors::Internal( 36 "ValidateInput Cannot perform copy: number of elements does not match. " 37 " Shapes are: [element]: ", 38 element.shape().DebugString(), 39 ", [parent slice]: ", chip_shape.DebugString()); 40 } 41 return Status::OK(); 42 } 43 44 template <typename T> 45 Status HandleElementToSlice(Tensor element, Tensor* parent, int64 index, 46 bool /* can_move */) { 47 parent->flat_outer_dims<T>().chip(index, 0) = element.flat<T>(); 48 return Status::OK(); 49 } 50 51 template <> 52 Status HandleElementToSlice<string>(Tensor element, Tensor* parent, int64 index, 53 bool can_move) { 54 auto parent_as_matrix = parent->flat_outer_dims<string>(); 55 auto element_flat = element.flat<string>(); 56 if (can_move) { 57 for (int64 i = 0; i < element.NumElements(); ++i) { 58 parent_as_matrix(index, i) = std::move(element_flat(i)); 59 } 60 } else { 61 parent_as_matrix.chip(index, 0) = element_flat; 62 } 63 return Status::OK(); 64 } 65 66 template <> 67 Status HandleElementToSlice<Variant>(Tensor element, Tensor* parent, 68 int64 index, bool can_move) { 69 auto parent_as_matrix = parent->flat_outer_dims<Variant>(); 70 auto element_flat = element.flat<Variant>(); 71 if (can_move) { 72 for (int64 i = 0; i < element.NumElements(); ++i) { 73 parent_as_matrix(index, i) = std::move(element_flat(i)); 74 } 75 } else { 76 parent_as_matrix.chip(index, 0) = element_flat; 77 } 78 return Status::OK(); 79 } 80 81 // TODO(jsimsa): Add HandleElementToSlice<variant> specialization that moves 82 // the data when possible. 83 84 template <typename T> 85 static Status HandleSliceToElement(const Tensor& parent, Tensor* element, 86 int64 index) { 87 element->flat<T>() = parent.flat_outer_dims<T>().chip(index, 0); 88 return Status::OK(); 89 } 90 91 } // namespace 92 93 // Copies element into the index^th slice of parent (in the 0th dimension). 94 Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index) { 95 TF_RETURN_IF_ERROR(ValidateInput(*parent, element, index)); 96 97 bool can_move = element.RefCountIsOne(); 98 #define HANDLE_TYPE(T) \ 99 case DataTypeToEnum<T>::value: { \ 100 return HandleElementToSlice<T>(std::move(element), parent, index, \ 101 can_move); \ 102 } 103 104 switch (element.dtype()) { 105 TF_CALL_ALL_TYPES(HANDLE_TYPE); 106 TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); 107 #undef HANDLE_TYPE 108 default: 109 return errors::Unimplemented("CopyElementToSlice Unhandled data type: ", 110 element.dtype()); 111 } 112 } 113 114 // Copies the index^th slice of parent (in the 0th dimension) into element. 115 Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) { 116 TF_RETURN_IF_ERROR(ValidateInput(parent, *element, index)); 117 118 #define HANDLE_TYPE(T) \ 119 case DataTypeToEnum<T>::value: { \ 120 return HandleSliceToElement<T>(parent, element, index); \ 121 } 122 123 switch (parent.dtype()) { 124 TF_CALL_ALL_TYPES(HANDLE_TYPE); 125 TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); 126 #undef HANDLE_TYPE 127 default: 128 return errors::Unimplemented("CopySliceToElement Unhandled data type: ", 129 element->dtype()); 130 } 131 } 132 133 // The following five functions are copied from padding_fifo_queue.cc. 134 // TODO(mrry): Reconcile these functions with the similar methods in the 135 // queue implementation. 136 Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) { 137 DCHECK_NE(parent->dim_size(0), 0); 138 if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) { 139 TensorShape chip_shape = parent->shape(); 140 chip_shape.RemoveDim(0); 141 return errors::Internal( 142 "HandleElementToLargerSlice Cannot copy slice: number of entries in " 143 "element is greater than number of elements in parent slice. ", 144 "Shapes are: [element]: ", element.shape().DebugString(), 145 ", [parent slice]: ", chip_shape.DebugString()); 146 } 147 return Status::OK(); 148 } 149 150 template <typename T, int NDIMS> 151 Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent, 152 int index) { 153 TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent)); 154 if (element.NumElements() == 0) { 155 return Status::OK(); 156 } 157 auto element_t = element.tensor<T, NDIMS>(); 158 auto parent_t = parent->tensor<T, NDIMS + 1>(); 159 Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices; 160 slice_indices[0] = index; 161 Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size; 162 slice_size[0] = 1; 163 for (size_t i = 1; i < slice_size.size(); ++i) { 164 slice_size[i] = element_t.dimension(i - 1); 165 } 166 parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size); 167 return Status::OK(); 168 } 169 170 template <int NDIMS> 171 Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent, 172 int index) { 173 #define HANDLE_TYPE(T) \ 174 case DataTypeToEnum<T>::value: { \ 175 return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \ 176 } 177 178 switch (element.dtype()) { 179 TF_CALL_DATASET_TYPES(HANDLE_TYPE); 180 #undef HANDLE_TYPE 181 default: 182 return errors::Unimplemented( 183 "HandleElementToLargerSliceWithRank Unhandled data type: ", 184 element.dtype()); 185 } 186 } 187 188 Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent, 189 int index) { 190 if (parent->dims() != element.dims() + 1) { 191 return errors::Internal( 192 "Mismatched ranks. Element's rank is: ", element.dims(), 193 " but element is meant to be a slice in output Tensor having rank: ", 194 parent->dims(), " (should be: ", element.dims() + 1, ")"); 195 } 196 197 #define HANDLE_DIMS(NDIMS) \ 198 case NDIMS: { \ 199 TF_RETURN_IF_ERROR( \ 200 HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \ 201 return Status::OK(); \ 202 } 203 204 switch (element.dims()) { 205 HANDLE_DIMS(0); 206 HANDLE_DIMS(1); 207 HANDLE_DIMS(2); 208 HANDLE_DIMS(3); 209 HANDLE_DIMS(4); 210 #undef HANDLE_DIMS 211 default: 212 return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ", 213 element.dims()); 214 } 215 } 216 217 Status SetElementZero(Tensor* element, const Tensor& padding) { 218 #define HANDLE_TYPE(T) \ 219 if (element->dtype() == DataTypeToEnum<T>::value) { \ 220 element->flat<T>().setConstant(padding.scalar<T>()()); \ 221 return Status::OK(); \ 222 } 223 TF_CALL_DATASET_TYPES(HANDLE_TYPE); 224 #undef HANDLE_TYPE 225 return errors::Unimplemented("SetElementZero Unhandled data type: ", 226 element->dtype()); 227 } 228 229 } // namespace batch_util 230 } // namespace tensorflow 231