/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/tensor_util.h"

#include <cmath>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace tensor {

Tensor DeepCopy(const Tensor& other) {
  Tensor tmp = Tensor(other.dtype(), other.shape());
  if (DataTypeCanUseMemcpy(other.dtype())) {
    if (other.NumElements() > 0) {
      StringPiece other_data = other.tensor_data();

      // We use StringPiece as a convenient map over the tensor buffer,
      // but we cast the type to get to the underlying buffer to do the
      // copy.
      StringPiece tmp_data = tmp.tensor_data();
      memcpy(const_cast<char*>(tmp_data.data()), other_data.data(),
             other_data.size());
    }
  } else if (other.dtype() == DT_STRING) {
    tmp.unaligned_flat<string>() = other.unaligned_flat<string>();
  } else {
    CHECK_EQ(DT_VARIANT, other.dtype());
    tmp.unaligned_flat<Variant>() = other.unaligned_flat<Variant>();
  }
  return tmp;
}
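// A minimal usage sketch for DeepCopy (illustrative only; the dtype, shape,
// and values below are hypothetical and not taken from this file):
//
//   Tensor a(DT_FLOAT, TensorShape({2, 2}));
//   a.flat<float>().setConstant(1.0f);
//   Tensor b = tensor::DeepCopy(a);  // b owns a separate buffer.
//   a.flat<float>()(0) = 5.0f;       // Leaves b unchanged.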
Status Concat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) {
  if (tensors.empty()) {
    return errors::InvalidArgument("Cannot concatenate zero tensors");
  }
  int64 total_dim0_size = 0;
  for (const Tensor& tensor : tensors) {
    if (tensor.dims() == 0) {
      return errors::InvalidArgument(
          "Cannot concatenate a zero-dimensional tensor");
    }
    total_dim0_size += tensor.dim_size(0);
  }
  TensorShape shape = tensors[0].shape();
  shape.set_dim(0, total_dim0_size);

  const DataType dtype = tensors[0].dtype();
  for (int i = 1; i < tensors.size(); ++i) {
    if (tensors[i].dtype() != dtype) {
      return errors::InvalidArgument(
          "Cannot concatenate tensors that have different data types");
    }
  }
  *result = Tensor(dtype, shape);

  // We use StringPiece as a convenient map over the tensor buffer,
  // but we cast the type to get to the underlying buffer to do the
  // copy.
  StringPiece to_data = result->tensor_data();

  if (DataTypeCanUseMemcpy(dtype)) {
    int64 offset = 0;
    for (const Tensor& tensor : tensors) {
      StringPiece from_data = tensor.tensor_data();
      CHECK_LE(offset + from_data.size(), to_data.size());
      memcpy(const_cast<char*>(to_data.data()) + offset, from_data.data(),
             from_data.size());

      offset += from_data.size();
    }
  } else {
    if (dtype != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    string* to_strings =
        reinterpret_cast<string*>(const_cast<char*>(to_data.data()));

    int64 offset = 0;
    for (const Tensor& tensor : tensors) {
      auto from_strings = tensor.flat<string>();
      CHECK_LE(offset + tensor.NumElements(), result->NumElements());
      for (int i = 0; i < tensor.NumElements(); ++i) {
        to_strings[offset + i] = from_strings(i);
      }

      offset += tensor.NumElements();
    }
  }

  return Status::OK();
}
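// A minimal usage sketch for Concat, which stacks tensors along dimension 0
// (illustrative only; the shapes are hypothetical):
//
//   Tensor x(DT_INT32, TensorShape({2, 3}));
//   Tensor y(DT_INT32, TensorShape({4, 3}));
//   Tensor xy;
//   TF_CHECK_OK(tensor::Concat({x, y}, &xy));  // xy has shape [6, 3].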
Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes,
             std::vector<Tensor>* result) {
  if (tensor.dims() == 0) {
    return errors::InvalidArgument("Cannot split a zero-dimensional tensor");
  }
  int64 total_size = 0;
  for (int64 size : sizes) {
    total_size += size;
  }
  if (total_size != tensor.dim_size(0)) {
    return errors::InvalidArgument(
        "The values in 'sizes' do not sum to the zeroth-dimension size of "
        "'tensor'");
  }

  StringPiece from_data = tensor.tensor_data();

  if (DataTypeCanUseMemcpy(tensor.dtype())) {
    int64 offset = 0;
    for (int64 size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor* split = &(*result)[result->size() - 1];

      // We use StringPiece as a convenient map over the tensor buffer,
      // but we cast the type to get to the underlying buffer to do the
      // copy.
      StringPiece to_data = split->tensor_data();
      CHECK_LE(offset + to_data.size(), from_data.size());
      memcpy(const_cast<char*>(to_data.data()), from_data.data() + offset,
             to_data.size());

      offset += to_data.size();
    }
  } else {
    if (tensor.dtype() != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    auto from_strings = tensor.flat<string>();

    int64 offset = 0;
    for (int64 size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor& split = (*result)[result->size() - 1];
      string* to_strings = reinterpret_cast<string*>(
          const_cast<char*>(split.tensor_data().data()));

      CHECK_LE(offset + split.NumElements(), tensor.NumElements());
      for (int i = 0; i < split.NumElements(); ++i) {
        to_strings[i] = from_strings(offset + i);
      }

      offset += split.NumElements();
    }
  }

  return Status::OK();
}
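// A minimal usage sketch for Split, the inverse of Concat along dimension 0
// (illustrative only; the sizes are hypothetical):
//
//   Tensor t(DT_FLOAT, TensorShape({6, 3}));
//   std::vector<Tensor> parts;
//   TF_CHECK_OK(tensor::Split(t, {2, 4}, &parts));
//   // parts[0] has shape [2, 3] and parts[1] has shape [4, 3].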
namespace internal {
void SetTensorProtoShape(std::vector<size_t> shape,
                         TensorShapeProto* shape_proto) {
  for (auto dim : shape) {
    shape_proto->mutable_dim()->Add()->set_size(dim);
  }
}

template <typename T>
bool CompressTensorContent(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64 num_tensor_values = shape.num_elements();
  const int64 num_bytes = tensor->tensor_content().size();
  const int64 num_raw_values = num_bytes / sizeof(T);
  if (num_raw_values != num_tensor_values) {
    // Invalid or too small.
    return false;
  }
  int64 last_offset = num_bytes - 1;
  int64 prev_offset = last_offset - sizeof(T);
  // Inspect individual raw bytes sizeof(T) bytes apart in adjacent elements,
  // starting from the end, to find the last pair of elements that are not
  // identical.
  while (prev_offset >= 0) {
    if (tensor->tensor_content()[prev_offset] !=
        tensor->tensor_content()[last_offset]) {
      break;
    }
    --last_offset;
    --prev_offset;
  }
  // Round up to the next whole number of elements of type T.
  const int64 new_num_values = last_offset / sizeof(T) + 1;
  if (new_num_values * (is_complex<T>::value ? 2 : 1) * sizeof(FieldType) >
      static_cast<int64>(num_bytes / min_compression_ratio)) {
    return false;
  }
  // Copy values to truncated repeated field.
  if (sizeof(FieldType) == sizeof(T)) {
    FieldType* dst_ptr =
        TypeHelper::AppendUninitialized(new_num_values, tensor);
    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(dst_ptr));
    tensor->clear_tensor_content();
  } else if (sizeof(T) > 1) {
    // Copy raw bytes to temp array first, then cast.
    gtl::InlinedVector<T, 64> tmp(new_num_values);
    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(tmp.data()));
    tensor->clear_tensor_content();
    const T* begin = tmp.begin();
    const T* end = tmp.end();
    TypeHelper::AddValues(begin, end, tensor);
  } else {
    // Copy and cast, one byte at a time.
    for (int64 i = 0; i < new_num_values; ++i) {
      char c = tensor->tensor_content()[i];
      TypeHelper::AddValue(static_cast<T>(c), tensor);
    }
    tensor->clear_tensor_content();
  }
  return true;
}

// Equality test used for repeat detection below. The float and double
// specializations compare raw bit patterns, so NaNs with identical payloads
// count as repeats while +0.0 and -0.0 stay distinct.
template <typename T>
inline bool PackedValuesNotEqual(T a, T b) {
  return a != b;
}
template <>
inline bool PackedValuesNotEqual(float a, float b) {
  return reinterpret_cast<int32_t&>(a) != reinterpret_cast<int32_t&>(b);
}
template <>
inline bool PackedValuesNotEqual(double a, double b) {
  return reinterpret_cast<int64_t&>(a) != reinterpret_cast<int64_t&>(b);
}
template <typename RealType>
inline bool PackedValuesNotEqual(const std::complex<RealType>& a,
                                 const std::complex<RealType>& b) {
  return PackedValuesNotEqual(a.real(), b.real()) ||
         PackedValuesNotEqual(a.imag(), b.imag());
}

template <typename T>
bool CompressRepeatedField(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64 num_tensor_values = shape.num_elements();
  // Notice that for complex types the tensor is stored as an array of up to
  // 2 * num_tensor_values real values (real and imaginary parts), possibly
  // truncated.
  const int64 num_proto_values = TypeHelper::NumValues(*tensor);
  if (num_proto_values != num_tensor_values) {
    // Already compressed or invalid.
    return false;
  }
  const T last_value = TypeHelper::GetValue(num_proto_values - 1, *tensor);
  int64 last_index = 0;
  for (int64 i = num_proto_values - 2; i >= 0 && last_index == 0; --i) {
    const T cur_value = TypeHelper::GetValue(i, *tensor);
    if (PackedValuesNotEqual(cur_value, last_value)) {
      last_index = i + 1;
    }
  }
  const int64 num_truncated_proto_values = last_index + 1;
  const int64 num_bytes_as_field =
      num_truncated_proto_values * sizeof(FieldType);
  const int64 num_bytes_as_tensor_content = num_tensor_values * sizeof(T);
  const int64 num_bytes_before = num_proto_values * sizeof(FieldType);
  if (std::min(num_bytes_as_field, num_bytes_as_tensor_content) >
      static_cast<int64>(num_bytes_before / min_compression_ratio)) {
    return false;
  }
  if (num_bytes_as_field <= num_bytes_as_tensor_content) {
    TypeHelper::Truncate(num_truncated_proto_values, tensor);
  } else {
    gtl::InlinedVector<T, 64> tmp(num_tensor_values);
    TypeHelper::CopyValues(tmp.begin(), *tensor);
    TypeHelper::Truncate(0, tensor);
    port::CopyFromArray(tensor->mutable_tensor_content(),
                        reinterpret_cast<const char*>(tmp.data()),
                        num_bytes_as_tensor_content);
  }
  return true;
}

template <typename T>
bool CompressTensorProtoInPlaceImpl(int64 min_num_elements,
                                    float min_compression_ratio,
                                    TensorProto* tensor) {
  const TensorShape shape(tensor->tensor_shape());
  const int64 num_tensor_values = shape.num_elements();
  if (num_tensor_values < min_num_elements) {
    return false;
  }
  if (tensor->tensor_content().empty()) {
    return CompressRepeatedField<T>(min_compression_ratio, shape, tensor);
  } else {
    return CompressTensorContent<T>(min_compression_ratio, shape, tensor);
  }
  return true;
}

}  // namespace internal

#define HANDLE_COMPRESS_CASE(TF_TYPE)                                  \
  case TF_TYPE:                                                        \
    return internal::CompressTensorProtoInPlaceImpl<                   \
        EnumToDataType<TF_TYPE>::Type>(min_num_elements,               \
                                       min_compression_ratio, tensor); \
    break
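// A minimal call-site sketch for CompressTensorProtoInPlace below: it
// truncates the stored values of a proto whose trailing values all repeat,
// provided the encoding shrinks by at least min_compression_ratio. The
// shape, values, and thresholds here are hypothetical:
//
//   Tensor t(DT_FLOAT, TensorShape({1024}));
//   t.flat<float>().setZero();
//   TensorProto proto;
//   t.AsProtoTensorContent(&proto);
//   if (tensor::CompressTensorProtoInPlace(/*min_num_elements=*/64,
//                                          /*min_compression_ratio=*/2.0f,
//                                          &proto)) {
//     // proto now keeps only the non-repeating prefix of the values.
//   }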
bool CompressTensorProtoInPlace(int64 min_num_elements,
                                float min_compression_ratio,
                                TensorProto* tensor) {
  switch (tensor->dtype()) {
    HANDLE_COMPRESS_CASE(DT_FLOAT);
    HANDLE_COMPRESS_CASE(DT_DOUBLE);
    HANDLE_COMPRESS_CASE(DT_COMPLEX64);
    HANDLE_COMPRESS_CASE(DT_COMPLEX128);
    HANDLE_COMPRESS_CASE(DT_UINT8);
    HANDLE_COMPRESS_CASE(DT_INT8);
    HANDLE_COMPRESS_CASE(DT_UINT16);
    HANDLE_COMPRESS_CASE(DT_INT16);
    HANDLE_COMPRESS_CASE(DT_UINT32);
    HANDLE_COMPRESS_CASE(DT_INT32);
    HANDLE_COMPRESS_CASE(DT_UINT64);
    HANDLE_COMPRESS_CASE(DT_INT64);
    HANDLE_COMPRESS_CASE(DT_BOOL);
    HANDLE_COMPRESS_CASE(DT_QUINT8);
    HANDLE_COMPRESS_CASE(DT_QINT8);
    HANDLE_COMPRESS_CASE(DT_QUINT16);
    HANDLE_COMPRESS_CASE(DT_QINT16);
    HANDLE_COMPRESS_CASE(DT_QINT32);
    HANDLE_COMPRESS_CASE(DT_HALF);
    HANDLE_COMPRESS_CASE(DT_BFLOAT16);
    default:
      return false;
  }
}

#undef HANDLE_COMPRESS_CASE

}  // namespace tensor
}  // namespace tensorflow