1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_ 17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_ 18 19 #include <cstdint> 20 #include <type_traits> 21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 22 #include "tensorflow/core/framework/allocator.h" 23 #include "tensorflow/core/framework/tensor_shape.h" 24 #include "tensorflow/core/framework/tensor_types.h" 25 #include "tensorflow/core/framework/types.h" 26 #include "tensorflow/core/framework/types.pb.h" 27 #include "tensorflow/core/lib/core/refcount.h" 28 #include "tensorflow/core/lib/core/status.h" 29 #include "tensorflow/core/lib/core/stringpiece.h" 30 #include "tensorflow/core/lib/gtl/inlined_vector.h" 31 #include "tensorflow/core/platform/logging.h" 32 #include "tensorflow/core/platform/macros.h" 33 #include "tensorflow/core/platform/mem.h" 34 #include "tensorflow/core/platform/types.h" 35 36 namespace tensorflow { 37 38 // Forward declarations. In particular, we forward declare protos so that their 39 // symbols can be removed from .so exports. 
40 class AllocationDescription; 41 class Allocator; 42 class OpKernelContext; 43 class Tensor; 44 class TensorBuffer; 45 class TensorCApi; 46 class TensorDescription; 47 class TensorProto; 48 class Var; 49 50 namespace batch_util { 51 Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index); 52 Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index); 53 } // namespace batch_util 54 55 /// @ingroup core 56 /// Represents an n-dimensional array of values. 57 class Tensor { 58 public: 59 /// \brief Creates a 1-dimensional, 0-element float tensor. 60 /// 61 /// The returned Tensor is not a scalar (shape {}), but is instead 62 /// an empty one-dimensional Tensor (shape {0}, NumElements() == 63 /// 0). Since it has no elements, it does not need to be assigned a 64 /// value and is initialized by default (IsInitialized() is 65 /// true). If this is undesirable, consider creating a one-element 66 /// scalar which does require initialization: 67 /// 68 /// ```c++ 69 /// 70 /// Tensor(DT_FLOAT, TensorShape({})) 71 /// 72 /// ``` 73 Tensor(); 74 75 /// \brief Creates a Tensor of the given `type` and `shape`. If 76 /// LogMemory::IsEnabled() the allocation is logged as coming from 77 /// an unknown kernel and step. Calling the Tensor constructor 78 /// directly from within an Op is deprecated: use the 79 /// OpKernelConstruction/OpKernelContext allocate_* methods to 80 /// allocate a new tensor, which record the kernel and step. 81 /// 82 /// The underlying buffer is allocated using a `CPUAllocator`. 83 Tensor(DataType type, const TensorShape& shape); 84 85 /// \brief Creates a tensor with the input `type` and `shape`, using 86 /// the allocator `a` to allocate the underlying buffer. If 87 /// LogMemory::IsEnabled() the allocation is logged as coming from 88 /// an unknown kernel and step. 
Calling the Tensor constructor 89 /// directly from within an Op is deprecated: use the 90 /// OpKernelConstruction/OpKernelContext allocate_* methods to 91 /// allocate a new tensor, which record the kernel and step. 92 /// 93 /// `a` must outlive the lifetime of this Tensor. 94 Tensor(Allocator* a, DataType type, const TensorShape& shape); 95 96 /// \brief Creates a tensor with the input `type` and `shape`, using 97 /// the allocator `a` and the specified "allocation_attr" to 98 /// allocate the underlying buffer. If the kernel and step are known 99 /// allocation_attr.allocation_will_be_logged should be set to true 100 /// and LogMemory::RecordTensorAllocation should be called after the 101 /// tensor is constructed. Calling the Tensor constructor directly 102 /// from within an Op is deprecated: use the 103 /// OpKernelConstruction/OpKernelContext allocate_* methods to 104 /// allocate a new tensor, which record the kernel and step. 105 /// 106 /// `a` must outlive the lifetime of this Tensor. 107 Tensor(Allocator* a, DataType type, const TensorShape& shape, 108 const AllocationAttributes& allocation_attr); 109 110 /// \brief Creates an empty Tensor of the given data type. 111 /// 112 /// Like Tensor(), returns a 1-dimensional, 0-element Tensor with 113 /// IsInitialized() returning True. See the Tensor() documentation 114 /// for details. 115 explicit Tensor(DataType type); 116 117 private: 118 // A tag type for selecting the `Tensor` constructor overload that creates a 119 // scalar tensor in host memory. 120 struct host_scalar_tag {}; 121 122 class HostScalarTensorBufferBase; 123 template <typename T> 124 struct ValueAndTensorBuffer; 125 126 // Creates a tensor with the given scalar `value` in CPU memory. 127 template <typename T> 128 Tensor(T value, host_scalar_tag tag); 129 130 public: 131 // A series of specialized constructors for scalar tensors in host memory. 
132 // 133 // NOTE: The `Variant` host-scalar constructor is not defined, because Variant 134 // is implicitly constructible from many different types, and this causes 135 // ambiguities with some compilers. 136 explicit Tensor(float scalar_value) 137 : Tensor(scalar_value, host_scalar_tag{}) {} 138 explicit Tensor(double scalar_value) 139 : Tensor(scalar_value, host_scalar_tag{}) {} 140 explicit Tensor(int32 scalar_value) 141 : Tensor(scalar_value, host_scalar_tag{}) {} 142 explicit Tensor(uint32 scalar_value) 143 : Tensor(scalar_value, host_scalar_tag{}) {} 144 explicit Tensor(uint16 scalar_value) 145 : Tensor(scalar_value, host_scalar_tag{}) {} 146 explicit Tensor(uint8 scalar_value) 147 : Tensor(scalar_value, host_scalar_tag{}) {} 148 explicit Tensor(int16 scalar_value) 149 : Tensor(scalar_value, host_scalar_tag{}) {} 150 explicit Tensor(int8 scalar_value) 151 : Tensor(scalar_value, host_scalar_tag{}) {} 152 explicit Tensor(string scalar_value) 153 : Tensor(std::move(scalar_value), host_scalar_tag{}) {} 154 explicit Tensor(complex64 scalar_value) 155 : Tensor(scalar_value, host_scalar_tag{}) {} 156 explicit Tensor(complex128 scalar_value) 157 : Tensor(scalar_value, host_scalar_tag{}) {} 158 explicit Tensor(int64 scalar_value) 159 : Tensor(scalar_value, host_scalar_tag{}) {} 160 explicit Tensor(uint64 scalar_value) 161 : Tensor(scalar_value, host_scalar_tag{}) {} 162 explicit Tensor(bool scalar_value) 163 : Tensor(scalar_value, host_scalar_tag{}) {} 164 explicit Tensor(qint8 scalar_value) 165 : Tensor(scalar_value, host_scalar_tag{}) {} 166 explicit Tensor(quint8 scalar_value) 167 : Tensor(scalar_value, host_scalar_tag{}) {} 168 explicit Tensor(qint16 scalar_value) 169 : Tensor(scalar_value, host_scalar_tag{}) {} 170 explicit Tensor(quint16 scalar_value) 171 : Tensor(scalar_value, host_scalar_tag{}) {} 172 explicit Tensor(qint32 scalar_value) 173 : Tensor(scalar_value, host_scalar_tag{}) {} 174 explicit Tensor(bfloat16 scalar_value) 175 : Tensor(scalar_value, 
host_scalar_tag{}) {} 176 explicit Tensor(Eigen::half scalar_value) 177 : Tensor(scalar_value, host_scalar_tag{}) {} 178 explicit Tensor(ResourceHandle scalar_value) 179 : Tensor(std::move(scalar_value), host_scalar_tag{}) {} 180 181 // NOTE: The `const char*` host-scalar constructor is provided as a 182 // convenience because otherwise passing a string literal would surprisingly 183 // construct a DT_BOOL tensor. 184 explicit Tensor(const char* scalar_value) 185 : Tensor(string(scalar_value), host_scalar_tag{}) {} 186 187 /// Copy constructor. 188 Tensor(const Tensor& other); 189 190 /// \brief Move constructor. After this call, <other> is safely destructible 191 /// and can be assigned to, but other calls on it (e.g. shape manipulation) 192 /// are not valid. 193 Tensor(Tensor&& other); 194 195 ~Tensor(); 196 197 /// Returns the data type. 198 DataType dtype() const { return shape_.data_type(); } 199 200 /// Returns the shape of the tensor. 201 const TensorShape& shape() const { return shape_; } 202 203 /// \brief Convenience accessor for the tensor shape. 204 /// 205 /// For all shape accessors, see comments for relevant methods of 206 /// `TensorShape` in `tensor_shape.h`. 207 int dims() const { return shape().dims(); } 208 209 /// Convenience accessor for the tensor shape. 210 int64 dim_size(int d) const { return shape().dim_size(d); } 211 212 /// Convenience accessor for the tensor shape. 213 int64 NumElements() const { return shape().num_elements(); } 214 215 bool IsSameSize(const Tensor& b) const { 216 return shape().IsSameSize(b.shape()); 217 } 218 219 // True iff the two tensors use the same underlying refcounted storage 220 bool SharesBufferWith(const Tensor& b) const; 221 222 /// \brief If necessary, has this Tensor been initialized? 223 /// 224 /// Zero-element Tensors are always considered initialized, even if they 225 /// have never been assigned to and do not have any memory allocated. 
226 bool IsInitialized() const; 227 228 /// Returns the estimated memory usage of this tensor. 229 size_t TotalBytes() const; 230 231 // Returns the size of allocated memory for this tensor. 232 size_t AllocatedBytes() const; 233 234 /// Returns true iff this tensor is aligned. 235 bool IsAligned() const { 236 #if EIGEN_MAX_ALIGN_BYTES == 0 237 return true; 238 #else 239 void* ptr = base<void>(); 240 return reinterpret_cast<intptr_t>(ptr) % EIGEN_MAX_ALIGN_BYTES == 0; 241 #endif 242 } 243 244 /// Assign operator. This tensor shares other's underlying storage. 245 Tensor& operator=(const Tensor& other) { 246 CopyFromInternal(other, other.shape()); 247 return *this; 248 } 249 250 /// Move operator. See move constructor for details. 251 Tensor& operator=(Tensor&& other); 252 253 /// \brief Copy the other tensor into this tensor and reshape it. 254 /// 255 /// This tensor shares other's underlying storage. Returns `true` 256 /// iff `other.shape()` has the same number of elements of the given 257 /// `shape`. 258 bool CopyFrom(const Tensor& other, 259 const TensorShape& shape) TF_MUST_USE_RESULT { 260 if (other.NumElements() != shape.num_elements()) return false; 261 CopyFromInternal(other, shape); 262 return true; 263 } 264 265 /// \brief Slice this tensor along the 1st dimension. 266 267 /// I.e., the returned tensor satisfies 268 /// returned[i, ...] == this[dim0_start + i, ...]. 269 /// The returned tensor shares the underlying tensor buffer with this 270 /// tensor. 271 /// 272 /// NOTE: The returned tensor may not satisfy the same alignment 273 /// requirement as this tensor depending on the shape. The caller 274 /// must check the returned tensor's alignment before calling certain 275 /// methods that have alignment requirement (e.g., `flat()`, `tensor()`). 276 /// 277 /// NOTE: When fed with an N-dimensional tensor, this method returns a tensor 278 /// also with N dimensions. If you want to select a sub tensor, see SubSlice. 
279 /// 280 /// REQUIRES: `dims()` >= 1 281 /// REQUIRES: `0 <= dim0_start <= dim0_limit <= dim_size(0)` 282 Tensor Slice(int64 dim0_start, int64 dim0_limit) const; 283 284 /// \brief Select a subslice from this tensor along the 1st dimension. 285 /// 286 /// When fed with an N-dimensional tensor, this method returns a tensor with 287 /// N-1 dimensions, where the returned tensor is a subslice of the input 288 /// tensor along the first dimension. The N-1 dimensions of the returned 289 /// tensor are the last N-1 dimensions of the input tensor. 290 /// 291 /// NOTE: The returned tensor may not satisfy the same alignment 292 /// requirement as this tensor depending on the shape. The caller 293 /// must check the returned tensor's alignment before calling certain 294 /// methods that have alignment requirement (e.g., `flat()`, `tensor()`). 295 /// 296 /// REQUIRES: `dims()` >= 1 297 /// REQUIRES: `0 <= index < dim_size(0)` 298 Tensor SubSlice(int64 index) const; 299 300 /// \brief Parse `other` and construct the tensor. 301 302 /// Returns `true` iff the parsing succeeds. If the parsing fails, 303 /// the state of `*this` is unchanged. 304 bool FromProto(const TensorProto& other) TF_MUST_USE_RESULT; 305 bool FromProto(Allocator* a, const TensorProto& other) TF_MUST_USE_RESULT; 306 307 /// \brief Fills in `proto` with `*this` tensor's content. 308 /// 309 /// `AsProtoField()` fills in the repeated field for `proto.dtype()`, while 310 /// `AsProtoTensorContent()` encodes the content in `proto.tensor_content()` 311 /// in a compact form. 312 void AsProtoField(TensorProto* proto) const; 313 void AsProtoTensorContent(TensorProto* proto) const; 314 315 /// \brief Return the tensor data as an `Eigen::Tensor` with the type and 316 /// sizes of this `Tensor`. 317 /// 318 /// Use these methods when you know the data type and the number of 319 /// dimensions of the Tensor and you want an `Eigen::Tensor` 320 /// automatically sized to the `Tensor` sizes. 
The implementation check 321 /// fails if either type or sizes mismatch. 322 /// 323 /// Example: 324 /// 325 /// ```c++ 326 /// 327 /// typedef float T; 328 /// Tensor my_mat(...built with Shape{rows: 3, cols: 5}...); 329 /// auto mat = my_mat.matrix<T>(); // 2D Eigen::Tensor, 3 x 5. 330 /// auto mat = my_mat.tensor<T, 2>(); // 2D Eigen::Tensor, 3 x 5. 331 /// auto vec = my_mat.vec<T>(); // CHECK fails as my_mat is 2D. 332 /// auto vec = my_mat.tensor<T, 3>(); // CHECK fails as my_mat is 2D. 333 /// auto mat = my_mat.matrix<int32>();// CHECK fails as type mismatch. 334 /// 335 /// ``` 336 template <typename T> 337 typename TTypes<T>::Vec vec() { 338 return tensor<T, 1>(); 339 } 340 341 template <typename T> 342 typename TTypes<T>::Matrix matrix() { 343 return tensor<T, 2>(); 344 } 345 346 template <typename T, size_t NDIMS> 347 typename TTypes<T, NDIMS>::Tensor tensor(); 348 349 /// \brief Return the tensor data to an `Eigen::Tensor` with the 350 /// same size but a bitwise cast to the specified dtype `T`. 351 /// 352 /// Using a bitcast is useful for move and copy operations. 353 /// NOTE: this is the same as `tensor()` except a bitcast is allowed. 354 template <typename T, size_t NDIMS> 355 typename TTypes<T, NDIMS>::Tensor bit_casted_tensor(); 356 357 /// \brief Return the tensor data to an `Eigen::Tensor` with the 358 /// last dimension elements converted into single elements of a larger type. 359 /// 360 /// For example, this is useful for kernels that can treat NCHW_VECT_C int8 361 /// tensors as NCHW int32 tensors. The sizeof(T) should equal the size of 362 /// the original element type * num elements in the original last dimension. 363 /// NDIMS should be 1 less than the original number of dimensions. 364 template <typename T, size_t NDIMS> 365 typename TTypes<T, NDIMS>::Tensor reinterpret_last_dimension(); 366 367 /// \brief Return the tensor data as an `Eigen::Tensor` of the data type and a 368 /// specified shape. 
369 /// 370 /// These methods allow you to access the data with the dimensions 371 /// and sizes of your choice. You do not need to know the number of 372 /// dimensions of the Tensor to call them. However, they `CHECK` that 373 /// the type matches and the dimensions requested creates an 374 /// `Eigen::Tensor` with the same number of elements as the tensor. 375 /// 376 /// Example: 377 /// 378 /// ```c++ 379 /// 380 /// typedef float T; 381 /// Tensor my_ten(...built with Shape{planes: 4, rows: 3, cols: 5}...); 382 /// // 1D Eigen::Tensor, size 60: 383 /// auto flat = my_ten.flat<T>(); 384 /// // 2D Eigen::Tensor 12 x 5: 385 /// auto inner = my_ten.flat_inner_dims<T>(); 386 /// // 2D Eigen::Tensor 4 x 15: 387 /// auto outer = my_ten.shaped<T, 2>({4, 15}); 388 /// // CHECK fails, bad num elements: 389 /// auto outer = my_ten.shaped<T, 2>({4, 8}); 390 /// // 3D Eigen::Tensor 6 x 5 x 2: 391 /// auto weird = my_ten.shaped<T, 3>({6, 5, 2}); 392 /// // CHECK fails, type mismatch: 393 /// auto bad = my_ten.flat<int32>(); 394 /// 395 /// ``` 396 template <typename T> 397 typename TTypes<T>::Flat flat() { 398 return shaped<T, 1>({NumElements()}); 399 } 400 401 template <typename T> 402 typename TTypes<T>::UnalignedFlat unaligned_flat() { 403 return unaligned_shaped<T, 1>({NumElements()}); 404 } 405 406 /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all 407 /// Tensor dimensions but the last NDIMS-1 into the first dimension of the 408 /// result. If NDIMS > dims() then leading dimensions of size 1 will be 409 /// added to make the output rank NDIMS. 410 template <typename T, size_t NDIMS = 2> 411 typename TTypes<T, NDIMS>::Tensor flat_inner_dims(); 412 413 /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all 414 /// Tensor dimensions but the first NDIMS-1 into the last dimension of the 415 /// result. If NDIMS > dims() then trailing dimensions of size 1 will be 416 /// added to make the output rank NDIMS. 
417 template <typename T, size_t NDIMS = 2> 418 typename TTypes<T, NDIMS>::Tensor flat_outer_dims(); 419 420 /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing the 421 /// first 'begin' Tensor dimensions into the first dimension of the result and 422 /// the Tensor dimensions of the last dims() - 'begin' - NDIMS into the last 423 /// dimension of the result. If 'begin' < 0 then the |'begin'| leading 424 /// dimensions of size 1 will be added. If 'begin' + NDIMS > dims() then 425 /// 'begin' + NDIMS - dims() trailing dimensions of size 1 will be added. 426 template <typename T, size_t NDIMS = 3> 427 typename TTypes<T, NDIMS>::Tensor flat_inner_outer_dims(int64 begin); 428 429 template <typename T, size_t NDIMS> 430 typename TTypes<T, NDIMS>::Tensor shaped(gtl::ArraySlice<int64> new_sizes); 431 432 /// \brief Return the tensor data to an `Eigen::Tensor` with the new 433 /// shape specified in `new_sizes` and cast to a new dtype `T`. 434 /// 435 /// Using a bitcast is useful for move and copy operations. 436 /// The allowed bitcast is the only difference from `shaped()`. 437 template <typename T, size_t NDIMS> 438 typename TTypes<T, NDIMS>::Tensor bit_casted_shaped( 439 gtl::ArraySlice<int64> new_sizes); 440 441 template <typename T, size_t NDIMS> 442 typename TTypes<T, NDIMS>::UnalignedTensor unaligned_shaped( 443 gtl::ArraySlice<int64> new_sizes); 444 445 /// \brief Return the Tensor data as a `TensorMap` of fixed size 1: 446 /// `TensorMap<TensorFixedSize<T, 1>>`. 447 448 /// Using `scalar()` allows the compiler to perform optimizations as 449 /// the size of the tensor is known at compile time. 450 template <typename T> 451 typename TTypes<T>::Scalar scalar(); 452 453 /// Const versions of all the methods above. 
454 template <typename T> 455 typename TTypes<T>::ConstVec vec() const { 456 return tensor<T, 1>(); 457 } 458 459 template <typename T> 460 typename TTypes<T>::ConstMatrix matrix() const { 461 return tensor<T, 2>(); 462 } 463 464 template <typename T, size_t NDIMS> 465 typename TTypes<T, NDIMS>::ConstTensor tensor() const; 466 467 /// \brief Return the tensor data to an `Eigen::Tensor` with the 468 /// same size but a bitwise cast to the specified dtype `T`. 469 /// 470 /// Using a bitcast is useful for move and copy operations. 471 /// NOTE: this is the same as `tensor()` except a bitcast is allowed. 472 template <typename T, size_t NDIMS> 473 typename TTypes<T, NDIMS>::ConstTensor bit_casted_tensor() const; 474 475 /// \brief Return the tensor data to an `Eigen::Tensor` with the 476 /// last dimension elements converted into single elements of a larger type. 477 /// 478 /// For example, this is useful for kernels that can treat NCHW_VECT_C int8 479 /// tensors as NCHW int32 tensors. The sizeof(T) should equal the size of 480 /// the original element type * num elements in the original last dimension. 481 /// NDIMS should be 1 less than the original number of dimensions. 482 template <typename T, size_t NDIMS> 483 typename TTypes<T, NDIMS>::ConstTensor reinterpret_last_dimension() const; 484 485 template <typename T> 486 typename TTypes<T>::ConstFlat flat() const { 487 return shaped<T, 1>({NumElements()}); 488 } 489 490 template <typename T> 491 typename TTypes<T>::UnalignedConstFlat unaligned_flat() const { 492 return unaligned_shaped<T, 1>({NumElements()}); 493 } 494 495 template <typename T, size_t NDIMS> 496 typename TTypes<T, NDIMS>::ConstTensor shaped( 497 gtl::ArraySlice<int64> new_sizes) const; 498 499 /// \brief Return the tensor data to an `Eigen::Tensor` with the new 500 /// shape specified in `new_sizes` and cast to a new dtype `T`. 501 /// 502 /// Using a bitcast is useful for move and copy operations. 
503 /// The allowed bitcast is the only difference from `shaped()`. 504 template <typename T, size_t NDIMS> 505 typename TTypes<T, NDIMS>::ConstTensor bit_casted_shaped( 506 gtl::ArraySlice<int64> new_sizes) const; 507 508 template <typename T, size_t NDIMS> 509 typename TTypes<T, NDIMS>::UnalignedConstTensor unaligned_shaped( 510 gtl::ArraySlice<int64> new_sizes) const; 511 512 template <typename T> 513 typename TTypes<T>::ConstScalar scalar() const; 514 515 template <typename T, size_t NDIMS = 2> 516 typename TTypes<T, NDIMS>::ConstTensor flat_inner_dims() const; 517 518 template <typename T, size_t NDIMS = 2> 519 typename TTypes<T, NDIMS>::ConstTensor flat_outer_dims() const; 520 521 template <typename T, size_t NDIMS = 3> 522 typename TTypes<T, NDIMS>::ConstTensor flat_inner_outer_dims( 523 int64 begin) const; 524 525 /// Render the first `max_entries` values in `*this` into a string. 526 string SummarizeValue(int64 max_entries, bool print_v2 = false) const; 527 528 /// A human-readable summary of the tensor suitable for debugging. 529 // `num_values` is the number of actual data values in the tensor 530 // included in the message. If the tensor might be resident in 531 // GPU/TPU memory use DeviceSafeDebugString instead. 532 string DebugString(int num_values) const; 533 string DebugString() const { return DebugString(3); } 534 535 // Variant of DebugString() that should be used for possibly non-CPU tensors. 536 // If the tensor is not resident on CPU, we can't read its values as 537 // DebugString() does. 538 string DeviceSafeDebugString() const; 539 540 /// Fill in the `TensorDescription` proto with metadata about the 541 /// tensor that is useful for monitoring and debugging. 542 void FillDescription(TensorDescription* description) const; 543 544 /// \brief Returns a `StringPiece` mapping the current tensor's buffer. 545 /// 546 /// The returned `StringPiece` may point to memory location on devices 547 /// that the CPU cannot address directly. 
548 /// 549 /// NOTE: The underlying tensor buffer is refcounted, so the lifetime 550 /// of the contents mapped by the `StringPiece` matches the lifetime of 551 /// the buffer; callers should arrange to make sure the buffer does 552 /// not get destroyed while the `StringPiece` is still used. 553 /// 554 /// REQUIRES: `DataTypeCanUseMemcpy(dtype())`. 555 StringPiece tensor_data() const; 556 557 /// Copy the other tensor into this tensor, reshape it and reinterpret the 558 /// buffer's datatype. If Status::OK() is returned, the two tensors now share 559 /// the same underlying storage. 560 /// 561 /// This call requires that the `other` tensor and the given type and shape 562 /// are "compatible" (i.e. they occupy the same number of bytes). 563 /// 564 /// Specifically: 565 /// 566 /// shape.num_elements() * DataTypeSize(type) 567 /// 568 /// must equal 569 /// 570 /// other.num_elements() * DataTypeSize(other.dtype()) 571 /// 572 /// In addition, this function requires: 573 /// * DataTypeSize(other.dtype()) != 0 574 /// * DataTypeSize(type) != 0 575 /// 576 /// If any of the requirements are not met, errors::InvalidArgument is 577 /// returned. 578 Status BitcastFrom(const Tensor& other, DataType dtype, 579 const TensorShape& shape); 580 581 /// Like BitcastFrom, but CHECK fails if any preconditions are not met. 582 /// 583 /// Deprecated. Use BitcastFrom instead and check the returned Status. 584 void UnsafeCopyFromInternal(const Tensor& other, DataType dtype, 585 const TensorShape& shape) { 586 TF_CHECK_OK(BitcastFrom(other, dtype, shape)); 587 } 588 589 private: 590 // Returns true if the refcount on buf_ and any possible underlying root 591 // buffer is one. 592 bool RefCountIsOne() const; 593 void CheckType(DataType expected_dtype) const; 594 void CheckTypeAndIsAligned(DataType expected_dtype) const; 595 void CheckIsAlignedAndSingleElement() const; 596 void set_dtype(DataType t) { shape_.set_data_type(t); } 597 598 // TensorShape's InlineVector. 
599 static gtl::InlinedVector<int64, 4> ComputeFlatInnerDims( 600 gtl::ArraySlice<int64> orig, int64 num_out_dims); 601 static gtl::InlinedVector<int64, 4> ComputeFlatOuterDims( 602 gtl::ArraySlice<int64> orig, int64 num_out_dims); 603 604 TensorShape shape_; 605 TensorBuffer* buf_; 606 607 friend class DMAHelper; 608 friend class TensorCApi; 609 friend class TensorReference; // For access to buf_ 610 friend class VariableOp; // For access to set_shape 611 friend class AutoReloadVariableOp; // For access to set_shape 612 friend class TensorTestHelper; // For access to set_shape 613 friend class CastOpBase; // For access to set_dtype; 614 friend class OpKernelContext; // For access to RefCountIsOne(). 615 friend class ScopedAllocator; // For access to buf_. 616 friend class XlaTensor; // For access to RefCountIsOne(). 617 friend class XlaTensorBuffer; // For access to the private constructor taking 618 // the buffer 619 friend class Var; 620 template <typename Device, typename T> 621 friend class AssignVariableOp; // For access to RefCountIsOne(). 622 template <typename Device, typename T> 623 friend Status PrepareToUpdateVariable( 624 OpKernelContext* ctx, Tensor* tensor, 625 bool copy_on_read_mode); // For access to RefCountIsOne(). 626 template <typename Device, typename T> 627 friend Status EnsureSparseVariableAccess( 628 OpKernelContext* ctx, Var* var); // For access to RefCountIsOne(). 629 friend Status batch_util::CopyElementToSlice( 630 Tensor element, Tensor* parent, 631 int64 index); // For access to RefCountIsOne(). 632 friend Status batch_util::MaybeMoveSliceToElement( 633 Tensor* parent, Tensor* element, 634 int64 index); // For access to RefCountIsOne(). 635 636 friend class NumpyTensorBuffer; // For access to the private constructor 637 // taking the buffer. 638 639 // Creates a tensor with the input datatype, shape and buf. 640 // 641 // Acquires a ref on buf that belongs to this Tensor. 
642 Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf); 643 644 bool CanUseDMA() const; 645 646 // Only needed by variable op to set the shape of an uninitialized 647 // Tensor. 648 // TODO: Remove this when we have a better story for detecting 649 // uninitialized tensors. 650 void set_shape(const TensorShape& shape) { 651 DataType dt = dtype(); 652 shape_ = shape; 653 set_dtype(dt); 654 } 655 656 void CopyFromInternal(const Tensor& other, const TensorShape& shape); 657 658 template <typename T> 659 T* base() const; 660 661 template <size_t NDIMS> 662 void FillDimsAndValidateCompatibleShape( 663 gtl::ArraySlice<int64> new_sizes, 664 Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const; 665 666 template <typename T, size_t NDIMS> 667 void FillDimsAndValidateCompatibleShape( 668 gtl::ArraySlice<int64> new_sizes, 669 Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const; 670 }; 671 672 // Implementation details 673 674 // START_SKIP_DOXYGEN 675 676 // Interface to access the raw ref-counted data buffer. 677 class TensorBuffer : public core::RefCounted { 678 public: 679 explicit TensorBuffer(void* data_ptr) : data_(data_ptr) {} 680 ~TensorBuffer() override {} 681 682 // data() points to a memory region of size() bytes. 683 // 684 // NOTE(mrry): The `data()` method is not virtual for performance reasons. 685 // It can be called multiple times when the contents of a `Tensor` are 686 // accessed, and so making it non-virtual allows the body to be inlined. 687 void* data() const { return data_; } 688 virtual size_t size() const = 0; 689 690 // If this TensorBuffer is sub-buffer of another TensorBuffer, 691 // returns that TensorBuffer. Otherwise, returns this. 692 virtual TensorBuffer* root_buffer() = 0; 693 694 // Fill metadata about the allocation into the proto. 
695 virtual void FillAllocationDescription( 696 AllocationDescription* proto) const = 0; 697 698 template <typename T> 699 T* base() const { 700 return reinterpret_cast<T*>(data()); 701 } 702 703 // Whether this TensorBuffer owns the underlying memory. 704 virtual bool OwnsMemory() const { return true; } 705 706 private: 707 void* const data_; 708 }; 709 710 template <typename T> 711 T* Tensor::base() const { 712 return buf_ == nullptr ? nullptr : buf_->base<T>(); 713 } 714 715 template <typename T, size_t NDIMS> 716 typename TTypes<T, NDIMS>::Tensor Tensor::tensor() { 717 CheckTypeAndIsAligned(DataTypeToEnum<T>::v()); 718 return typename TTypes<T, NDIMS>::Tensor(base<T>(), 719 shape().AsEigenDSizes<NDIMS>()); 720 } 721 722 template <typename T, size_t NDIMS> 723 typename TTypes<T, NDIMS>::ConstTensor Tensor::tensor() const { 724 CheckTypeAndIsAligned(DataTypeToEnum<T>::v()); 725 return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(), 726 shape().AsEigenDSizes<NDIMS>()); 727 } 728 729 template <typename T, size_t NDIMS> 730 typename TTypes<T, NDIMS>::Tensor Tensor::bit_casted_tensor() { 731 CHECK(IsAligned()); 732 return typename TTypes<T, NDIMS>::Tensor(base<T>(), 733 shape().AsEigenDSizes<NDIMS>()); 734 } 735 736 template <typename T, size_t NDIMS> 737 typename TTypes<T, NDIMS>::ConstTensor Tensor::bit_casted_tensor() const { 738 CHECK(IsAligned()); 739 return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(), 740 shape().AsEigenDSizes<NDIMS>()); 741 } 742 743 template <typename T, size_t NDIMS> 744 typename TTypes<T, NDIMS>::Tensor Tensor::reinterpret_last_dimension() { 745 if (NDIMS == dims()) { 746 return tensor<T, NDIMS>(); 747 } 748 CHECK(IsAligned()); 749 CHECK_EQ(NDIMS, dims() - 1); 750 CHECK_EQ(sizeof(T), shape_.dim_sizes()[NDIMS] * DataTypeSize(dtype())); 751 Eigen::array<Eigen::DenseIndex, NDIMS> dims; 752 for (int d = 0; d < NDIMS; ++d) { 753 dims[d] = shape_.dim_sizes()[d]; 754 } 755 return typename TTypes<T, NDIMS>::Tensor(base<T>(), 
dims); 756 } 757 758 template <typename T, size_t NDIMS> 759 typename TTypes<T, NDIMS>::ConstTensor Tensor::reinterpret_last_dimension() 760 const { 761 if (NDIMS == dims()) { 762 return tensor<T, NDIMS>(); 763 } 764 CHECK(IsAligned()); 765 CHECK_EQ(NDIMS, dims() - 1); 766 CHECK_EQ(sizeof(T), shape_.dim_sizes()[NDIMS] * DataTypeSize(dtype())); 767 Eigen::array<Eigen::DenseIndex, NDIMS> dims; 768 for (int d = 0; d < NDIMS; ++d) { 769 dims[d] = shape_.dim_sizes()[d]; 770 } 771 return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(), dims); 772 } 773 774 template <size_t NDIMS> 775 void Tensor::FillDimsAndValidateCompatibleShape( 776 gtl::ArraySlice<int64> new_sizes, 777 Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const { 778 CHECK_EQ(NDIMS, new_sizes.size()); 779 int64 new_num_elements = 1; 780 for (size_t d = 0; d < NDIMS; d++) { 781 new_num_elements *= new_sizes[d]; 782 (*dims)[d] = new_sizes[d]; 783 } 784 CHECK_EQ(new_num_elements, NumElements()); 785 } 786 787 template <typename T, size_t NDIMS> 788 void Tensor::FillDimsAndValidateCompatibleShape( 789 gtl::ArraySlice<int64> new_sizes, 790 Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const { 791 CHECK_EQ(NDIMS, new_sizes.size()); 792 int64 new_num_elements = 1; 793 for (size_t d = 0; d < NDIMS; d++) { 794 new_num_elements *= new_sizes[d]; 795 (*dims)[d] = new_sizes[d]; 796 } 797 const int element_size = DataTypeSize(BaseType(dtype())); 798 if (element_size > 0) { 799 CHECK_EQ(new_num_elements * sizeof(T), NumElements() * element_size); 800 } else { 801 // DataTypeSize() returns 0 for some data types. In this case, assume that T 802 // has the same size as the buffer type. 803 // NOTE: If we can be sure that DataTypeSize() does not return 0 for all POD 804 // types, then we should check DataTypeToEnum<T>::v() == dtype(). Or simply 805 // check if `element_size > 0` to err when bit cast is attempted on Tensor 806 // of unknown data type size. 
807 CHECK_EQ(new_num_elements, NumElements()); 808 } 809 } 810 811 template <typename T, size_t NDIMS> 812 typename TTypes<T, NDIMS>::Tensor Tensor::shaped( 813 gtl::ArraySlice<int64> new_sizes) { 814 CheckTypeAndIsAligned(DataTypeToEnum<T>::v()); 815 Eigen::array<Eigen::DenseIndex, NDIMS> dims; 816 FillDimsAndValidateCompatibleShape(new_sizes, &dims); 817 return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims); 818 } 819 820 template <typename T, size_t NDIMS> 821 typename TTypes<T, NDIMS>::Tensor Tensor::bit_casted_shaped( 822 gtl::ArraySlice<int64> new_sizes) { 823 CHECK(IsAligned()); 824 Eigen::array<Eigen::DenseIndex, NDIMS> dims; 825 FillDimsAndValidateCompatibleShape<T>(new_sizes, &dims); 826 return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims); 827 } 828 829 template <typename T, size_t NDIMS> 830 typename TTypes<T, NDIMS>::UnalignedTensor Tensor::unaligned_shaped( 831 gtl::ArraySlice<int64> new_sizes) { 832 CheckType(DataTypeToEnum<T>::v()); 833 Eigen::array<Eigen::DenseIndex, NDIMS> dims; 834 FillDimsAndValidateCompatibleShape(new_sizes, &dims); 835 return typename TTypes<T, NDIMS>::UnalignedTensor(base<T>(), dims); 836 } 837 838 template <typename T, size_t NDIMS> 839 typename TTypes<T, NDIMS>::ConstTensor Tensor::shaped( 840 gtl::ArraySlice<int64> new_sizes) const { 841 CheckType(DataTypeToEnum<T>::v()); 842 CHECK(IsAligned()); 843 Eigen::array<Eigen::DenseIndex, NDIMS> dims; 844 FillDimsAndValidateCompatibleShape(new_sizes, &dims); 845 return typename TTypes<T, NDIMS>::ConstTensor(base<T>(), dims); 846 } 847 848 template <typename T, size_t NDIMS> 849 typename TTypes<T, NDIMS>::ConstTensor Tensor::bit_casted_shaped( 850 gtl::ArraySlice<int64> new_sizes) const { 851 CHECK(IsAligned()); 852 Eigen::array<Eigen::DenseIndex, NDIMS> dims; 853 FillDimsAndValidateCompatibleShape<T>(new_sizes, &dims); 854 return typename TTypes<T, NDIMS>::ConstTensor(base<T>(), dims); 855 } 856 857 template <typename T, size_t NDIMS> 858 typename TTypes<T, 
NDIMS>::UnalignedConstTensor Tensor::unaligned_shaped( 859 gtl::ArraySlice<int64> new_sizes) const { 860 CheckType(DataTypeToEnum<T>::v()); 861 Eigen::array<Eigen::DenseIndex, NDIMS> dims; 862 FillDimsAndValidateCompatibleShape(new_sizes, &dims); 863 return typename TTypes<T, NDIMS>::UnalignedConstTensor(base<T>(), dims); 864 } 865 866 template <typename T> 867 typename TTypes<T>::Scalar Tensor::scalar() { 868 CheckIsAlignedAndSingleElement(); 869 return typename TTypes<T>::Scalar(base<T>()); 870 } 871 872 template <typename T> 873 typename TTypes<T>::ConstScalar Tensor::scalar() const { 874 CheckIsAlignedAndSingleElement(); 875 return typename TTypes<T>::ConstScalar(base<T>()); 876 } 877 878 template <typename T, size_t NDIMS> 879 typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_dims() { 880 return shaped<T, NDIMS>(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS)); 881 } 882 883 template <typename T, size_t NDIMS> 884 typename TTypes<T, NDIMS>::Tensor Tensor::flat_outer_dims() { 885 return shaped<T, NDIMS>(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS)); 886 } 887 888 template <typename T, size_t NDIMS> 889 typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_outer_dims(int64 begin) { 890 gtl::InlinedVector<int64, 4> flat_outer = 891 ComputeFlatOuterDims(shape_.dim_sizes(), begin + NDIMS); 892 return shaped<T, NDIMS>(ComputeFlatInnerDims(flat_outer, NDIMS)); 893 } 894 895 template <typename T, size_t NDIMS> 896 typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_inner_dims() const { 897 return shaped<T, NDIMS>(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS)); 898 } 899 900 template <typename T, size_t NDIMS> 901 typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_outer_dims() const { 902 return shaped<T, NDIMS>(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS)); 903 } 904 905 template <typename T, size_t NDIMS> 906 typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_inner_outer_dims( 907 int64 begin) const { 908 gtl::InlinedVector<int64, 4> flat_outer = 909 
ComputeFlatOuterDims(shape_.dim_sizes(), begin + NDIMS); 910 return shaped<T, NDIMS>(ComputeFlatInnerDims(flat_outer, NDIMS)); 911 } 912 913 inline Tensor::Tensor(const Tensor& other) 914 : shape_(other.shape()), buf_(other.buf_) { 915 if (buf_) buf_->Ref(); 916 } 917 918 inline Tensor::Tensor(Tensor&& other) 919 : shape_(std::move(other.shape())), buf_(other.buf_) { 920 other.buf_ = nullptr; 921 } 922 923 class Tensor::HostScalarTensorBufferBase : public TensorBuffer { 924 public: 925 using TensorBuffer::TensorBuffer; 926 void FillAllocationDescription(AllocationDescription* proto) const final; 927 }; 928 929 // A packed representation for a single scalar value of type `T`, and a 930 // `TensorBuffer` implementation that describes (and manages the lifetime of) 931 // that value. 932 template <typename T> 933 struct Tensor::ValueAndTensorBuffer { 934 class HostScalarTensorBuffer : public Tensor::HostScalarTensorBufferBase { 935 public: 936 HostScalarTensorBuffer(void* data) : HostScalarTensorBufferBase(data) {} 937 size_t size() const final { return sizeof(T); } 938 TensorBuffer* root_buffer() final { return this; } 939 940 // Override `operator delete` so that calling `delete this` in 941 // `core::Refcounted::Unref()` for an object of this type will free 942 // the enclosing `ValueAndTensorBuffer` for the tensor buffer. 943 // 944 // NOTE(mrry): The definition of this method must be outside the class 945 // definition in order to satisfy some compilers. 946 static void operator delete(void* ptr); 947 948 static void operator delete(void*, void*) { 949 // Some compilers require an overridden class-specific deallocation 950 // function, which will be called if placement `new` throws an 951 // exception. 
952 } 953 954 private: 955 ~HostScalarTensorBuffer() override { static_cast<T*>(data())->~T(); } 956 }; 957 958 T value; 959 HostScalarTensorBuffer tensor_buffer; 960 }; 961 962 /* static */ 963 template <typename T> 964 void Tensor::ValueAndTensorBuffer<T>::HostScalarTensorBuffer::operator delete( 965 void* ptr) { 966 // Use a dummy object to compute to offset of 967 // `ValueAndTensorBuffer::tensor_buffer`, because `offsetof()` is not 968 // necessarily defined on this non-POD type (until C++17). 969 // 970 // NOTE(mrry): Using `sizeof(Tensor::ValueAndTensorBuffer<T>)` here requires 971 // us to define this method outside the class definition, so that it is not 972 // considered an incomplete type. 973 typename std::aligned_storage<sizeof(Tensor::ValueAndTensorBuffer<T>), 974 alignof(Tensor::ValueAndTensorBuffer<T>)>::type 975 dummy_storage_; 976 Tensor::ValueAndTensorBuffer<T>* dummy_object = 977 reinterpret_cast<Tensor::ValueAndTensorBuffer<T>*>(&dummy_storage_); 978 intptr_t offset = reinterpret_cast<intptr_t>(&dummy_object->tensor_buffer) - 979 reinterpret_cast<intptr_t>(dummy_object); 980 981 port::AlignedFree(static_cast<char*>(ptr) - offset); 982 } 983 984 template <typename T> 985 Tensor::Tensor(T value, host_scalar_tag tag) { 986 auto* value_and_buf = static_cast<Tensor::ValueAndTensorBuffer<T>*>( 987 port::AlignedMalloc(sizeof(typename Tensor::ValueAndTensorBuffer<T>), 988 EIGEN_MAX_ALIGN_BYTES)); 989 new (&value_and_buf->value) T(std::move(value)); 990 new (&value_and_buf->tensor_buffer) 991 typename Tensor::ValueAndTensorBuffer<T>::HostScalarTensorBuffer( 992 value_and_buf); 993 buf_ = &value_and_buf->tensor_buffer; 994 set_dtype(DataTypeToEnum<T>::value); 995 } 996 997 inline Tensor& Tensor::operator=(Tensor&& other) { 998 // Avoid self-assignment, since we might destroy our underlying buffer. 
999 if (&other != this) { 1000 shape_ = std::move(other.shape_); 1001 if (buf_) buf_->Unref(); 1002 buf_ = other.buf_; 1003 other.buf_ = nullptr; 1004 } 1005 return *this; 1006 } 1007 1008 // END_SKIP_DOXYGEN 1009 1010 } // namespace tensorflow 1011 1012 #endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_ 1013