/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implementation notes:
//
// Tensor.cc uses a few templated classes and structs to facilitate
// implementation of the Tensor class.
//
// * Buffer<T>: provides the implementation for a typed array T[n].
//   The array is allocated by the given allocator. It runs T's
//   default constructors and destructors when T is not a simple type
//   (e.g., string), and skips them otherwise.
//
// * Helper<T>: provides various routines given type T. The routines
//   include running the constructor and destructor of T[], encoding
//   and decoding T[] into/from a Cord, etc.

#include "tensorflow/core/framework/tensor.h"

#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/resource_handle.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Allow Tensors to be stored inside Variants with automatic
// encoding/decoding when those Variants are themselves being decoded
// in a Tensor's FromProto.
//
// NOTE(mrry): The corresponding "copy function" registrations can be found in
// ../common_runtime/copy_tensor.cc (due to dependencies on other common_runtime
// code).
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");
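// As an illustration of what this registration enables, a minimal sketch
// (illustrative only, not code in this file) of storing a Tensor inside a
// Variant and round-tripping it through a proto:
//
//   Tensor inner(DT_FLOAT, TensorShape({2}));
//   Tensor outer(DT_VARIANT, TensorShape({}));
//   outer.scalar<Variant>()() = inner;  // wrap a Tensor in a Variant
//   TensorProto proto;
//   outer.AsProtoField(&proto);
//   Tensor decoded;
//   decoded.FromProto(proto);  // uses the decode function registered above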

namespace {

// An un-templated base class for Buffer.
class BufferBase : public TensorBuffer {
 public:
  explicit BufferBase(Allocator* alloc, void* data_ptr)
      : TensorBuffer(data_ptr), alloc_(alloc) {}

  TensorBuffer* root_buffer() override { return this; }
  void FillAllocationDescription(AllocationDescription* proto) const override {
    void* data_ptr = data();
    int64 rb = size();
    proto->set_requested_bytes(rb);
    proto->set_allocator_name(alloc_->Name());
    proto->set_ptr(reinterpret_cast<uintptr_t>(data_ptr));
    if (alloc_->TracksAllocationSizes()) {
      int64 ab = alloc_->AllocatedSize(data_ptr);
      proto->set_allocated_bytes(ab);
      int64 id = alloc_->AllocationId(data_ptr);
      if (id > 0) {
        proto->set_allocation_id(id);
      }
      if (RefCountIsOne()) {
        proto->set_has_single_reference(true);
      }
    }
  }

 protected:
  void RecordDeallocation() {
    LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data()),
                                        alloc_->Name());
  }

  Allocator* const alloc_;
};

// Typed ref-counted buffer: T[n].
template <typename T>
class Buffer : public BufferBase {
 public:
  Buffer(Allocator* a, int64 n);
  Buffer(Allocator* a, int64 n, const AllocationAttributes& allocation_attr);

  size_t size() const override { return sizeof(T) * elem_; }

 private:
  T* data_;
  int64 elem_;

  ~Buffer() override;

  TF_DISALLOW_COPY_AND_ASSIGN(Buffer);
};

void LogUnexpectedSize(int64 actual, int64 expected) {
  LOG(ERROR) << "Input size was " << actual << " and expected " << expected;
}

// A set of helper functions depending on T.
template <typename T>
struct Helper {
  // By default, we assume T is a simple type (float, int32, etc.)
  static_assert(is_simple_type<T>::value, "T is not a simple type.");
  typedef protobuf::RepeatedField<T> RepeatedFieldType;

  // Encoder of simple type T to a string.  We do a copy.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64 n, Destination* out) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in,
                           out);
  }

  // Decoder of simple type T. Copy the bytes from "in" into the
  // tensor buffer.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
    if (in.size() != sizeof(T) * n) {
      LogUnexpectedSize(in.size(), sizeof(T) * n);
      return nullptr;
    }
    Buffer<T>* buf = new Buffer<T>(a, n);
    char* data = buf->template base<char>();
    if (data == nullptr) {
      buf->Unref();
      return nullptr;
    }
    port::CopyToArray(in, data);
    return buf;
  }

  // Memory usage.
  static int64 TotalBytes(TensorBuffer* in, int64 n) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    return in->size();
  }
};

// Helper specialization for string (one of the non-simple types we
// support).
template <>
struct Helper<string> {
  // Proto message uses RepeatedFieldType to hold repeated T.
  typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;

  // Encodes "n" elements of type string stored in "in" into Cord
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64 n, Destination* out) {
    port::EncodeStringList(in->base<const string>(), n, out);
  }

  // Decodes "n" elements of type string from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
    Buffer<string>* buf = new Buffer<string>(a, n);
    string* strings = buf->template base<string>();
    if (strings == nullptr || !port::DecodeStringList(in, strings, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64 TotalBytes(TensorBuffer* in, int n) {
    int64 tot = in->size();
    DCHECK_EQ(tot, sizeof(string) * n);
    const string* p = in->base<const string>();
    for (int i = 0; i < n; ++i, ++p) tot += p->size();
    return tot;
  }
};

template <>
struct Helper<ResourceHandle> {
  // Proto message uses RepeatedFieldType to hold repeated T.
  typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;

  // Encodes "n" elements of type ResourceHandle stored in "in" into destination
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64 n, Destination* out) {
    EncodeResourceHandleList(in->base<const ResourceHandle>(), n,
                             port::NewStringListEncoder(out));
  }

  // Decodes "n" elements of type ResourceHandle from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
    auto* buf = new Buffer<ResourceHandle>(a, n);
    ResourceHandle* ps = buf->template base<ResourceHandle>();
    if (ps == nullptr ||
        !DecodeResourceHandleList(port::NewStringListDecoder(in), ps, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64 TotalBytes(TensorBuffer* in, int n) {
    return n * sizeof(ResourceHandle);
  }
};

template <>
struct Helper<Variant> {
  // Encodes "n" elements of type Variant stored in "in" into destination
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64 n, Destination* out) {
    EncodeVariantList(in->base<const Variant>(), n,
                      port::NewStringListEncoder(out));
  }

  // Decodes "n" elements of type Variant from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
    auto* buf = new Buffer<Variant>(a, n);
    Variant* ps = buf->template base<Variant>();
    if (ps == nullptr ||
        !DecodeVariantList(port::NewStringListDecoder(in), ps, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64 TotalBytes(TensorBuffer* in, int n) {
    return n * sizeof(Variant);
  }
};

template <typename T>
struct ProtoHelper {};

// For a C++ type "T" (float, double, int32, etc.), the repeated field
// "N"_val (float_val, int_val, string_val, etc.) of type "F" (float,
// int32, string, etc.) in the TensorProto is used for serializing the
// tensor of type "T".
#define PROTO_TRAITS(T, F, N)                                          \
  template <>                                                          \
  struct ProtoHelper<T> {                                              \
    typedef Helper<F>::RepeatedFieldType FieldType;                    \
    static FieldType::const_iterator Begin(const TensorProto& proto) { \
      return proto.N##_val().begin();                                  \
    }                                                                  \
    static size_t NumElements(const TensorProto& proto) {              \
      return proto.N##_val().size();                                   \
    }                                                                  \
    static void Fill(const T* data, size_t n, TensorProto* proto) {    \
      typename ProtoHelper<T>::FieldType copy(data, data + n);         \
      proto->mutable_##N##_val()->Swap(&copy);                         \
    }                                                                  \
  };
PROTO_TRAITS(float, float, float);
PROTO_TRAITS(double, double, double);
PROTO_TRAITS(int32, int32, int);
PROTO_TRAITS(uint8, int32, int);
PROTO_TRAITS(uint16, int32, int);
PROTO_TRAITS(uint32, uint32, uint32);
PROTO_TRAITS(int16, int32, int);
PROTO_TRAITS(int8, int32, int);
PROTO_TRAITS(bool, bool, bool);
PROTO_TRAITS(string, string, string);
PROTO_TRAITS(qint8, int32, int);
PROTO_TRAITS(quint8, int32, int);
PROTO_TRAITS(qint16, int32, int);
PROTO_TRAITS(quint16, int32, int);
#undef PROTO_TRAITS
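
// As a concrete illustration, PROTO_TRAITS(float, float, float) above expands
// to (roughly) the following specialization:
//
//   template <>
//   struct ProtoHelper<float> {
//     typedef Helper<float>::RepeatedFieldType FieldType;  // RepeatedField<float>
//     static FieldType::const_iterator Begin(const TensorProto& proto) {
//       return proto.float_val().begin();
//     }
//     static size_t NumElements(const TensorProto& proto) {
//       return proto.float_val().size();
//     }
//     static void Fill(const float* data, size_t n, TensorProto* proto) {
//       FieldType copy(data, data + n);
//       proto->mutable_float_val()->Swap(&copy);
//     }
//   };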

template <>
struct ProtoHelper<int64> {
  static const int64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const int64*>(proto.int64_val().begin());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.int64_val().size();
  }
  static void Fill(const int64* data, size_t n, TensorProto* proto) {
    protobuf::RepeatedField<protobuf_int64> copy(data, data + n);
    proto->mutable_int64_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<uint64> {
  static const uint64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const uint64*>(proto.uint64_val().begin());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.uint64_val().size();
  }
  static void Fill(const uint64* data, size_t n, TensorProto* proto) {
    protobuf::RepeatedField<protobuf_uint64> copy(data, data + n);
    proto->mutable_uint64_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<ResourceHandle> {
  static protobuf::RepeatedPtrField<ResourceHandleProto>::const_iterator Begin(
      const TensorProto& proto) {
    return proto.resource_handle_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.resource_handle_val().size();
  }
  static void Fill(const ResourceHandle* data, size_t n, TensorProto* proto) {
    auto* handles = proto->mutable_resource_handle_val();
    handles->Clear();
    for (size_t i = 0; i < n; i++) {
      data[i].AsProto(handles->Add());
    }
  }
};

template <>
struct ProtoHelper<Variant> {
  static protobuf::RepeatedPtrField<VariantTensorDataProto>::const_iterator
  Begin(const TensorProto& proto) {
    return proto.variant_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.variant_val().size();
  }
  static void Fill(const Variant* data, size_t n, TensorProto* proto) {
    auto* variant_values = proto->mutable_variant_val();
    variant_values->Clear();
    for (size_t i = 0; i < n; ++i) {
      VariantTensorData tmp;
      data[i].Encode(&tmp);
      tmp.ToProto(variant_values->Add());
    }
  }
};

template <>
struct ProtoHelper<complex64> {
  typedef Helper<float>::RepeatedFieldType FieldType;
  static const complex64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const complex64*>(proto.scomplex_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.scomplex_val().size() / 2;
  }
  static void Fill(const complex64* data, size_t n, TensorProto* proto) {
    const float* p = reinterpret_cast<const float*>(data);
    FieldType copy(p, p + n * 2);
    proto->mutable_scomplex_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<complex128> {
  typedef Helper<double>::RepeatedFieldType FieldType;
  static const complex128* Begin(const TensorProto& proto) {
    return reinterpret_cast<const complex128*>(proto.dcomplex_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.dcomplex_val().size() / 2;
  }
  static void Fill(const complex128* data, size_t n, TensorProto* proto) {
    const double* p = reinterpret_cast<const double*>(data);
    FieldType copy(p, p + n * 2);
    proto->mutable_dcomplex_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<qint32> {
  typedef Helper<int32>::RepeatedFieldType FieldType;
  static const qint32* Begin(const TensorProto& proto) {
    return reinterpret_cast<const qint32*>(proto.int_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.int_val().size();
  }
  static void Fill(const qint32* data, size_t n, TensorProto* proto) {
    const int32* p = reinterpret_cast<const int32*>(data);
    FieldType copy(p, p + n);
    proto->mutable_int_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<bfloat16> {
  static void Fill(const bfloat16* data, size_t n, TensorProto* proto) {
    proto->mutable_half_val()->Reserve(n);
    for (size_t i = 0; i < n; ++i) {
      proto->mutable_half_val()->AddAlreadyReserved(data[i].value);
    }
  }
};

template <>
struct ProtoHelper<Eigen::half> {
  static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) {
    proto->mutable_half_val()->Reserve(n);
    for (size_t i = 0; i < n; ++i) {
      proto->mutable_half_val()->AddAlreadyReserved(data[i].x);
    }
  }
};

template <typename T>
Buffer<T>::Buffer(Allocator* a, int64 n)
    : BufferBase(a, a->Allocate<T>(n)), elem_(n) {}

template <typename T>
Buffer<T>::Buffer(Allocator* a, int64 n,
                  const AllocationAttributes& allocation_attr)
    : BufferBase(a, a->Allocate<T>(n, allocation_attr)), elem_(n) {}

template <typename T>
Buffer<T>::~Buffer() {
  if (data()) {
    if (LogMemory::IsEnabled()) {
      RecordDeallocation();
    }
    alloc_->Deallocate<T>(static_cast<T*>(data()), elem_);
  }
}

// Allocates a T[n] buffer. Fills in the buffer with repeated values
// in "in".  If "in" has fewer values than "n", fills the rest of T[n]
// with the last value. If "in" has no values, fills T[n] with the
// default value for T.
//
// This routine uses the typed fields (float_val, etc.) in the tensor
// proto as opposed to the untyped binary representation
// (tensor_content). This is used when we expect the TensorProto to be
// consumed by a client program that may not know how to encode a tensor
// in the compact binary representation.
template <typename T>
TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) {
  CHECK_GT(n, 0);
  Buffer<T>* buf = new Buffer<T>(a, n);
  T* data = buf->template base<T>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }

  const int64 in_n = ProtoHelper<T>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, T());
  } else {
    auto begin = ProtoHelper<T>::Begin(in);
    if (n <= in_n) {
      std::copy_n(begin, n, data);
    } else {
      std::copy_n(begin, in_n, data);
      const T& last = *(data + in_n - 1);
      std::fill_n(data + in_n, n - in_n, last);
    }
  }

  return buf;
}
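
// For example (an illustrative sketch, not code in this file): decoding a
// TensorProto with dtype DT_INT32, int_val = {1, 2}, and a shape holding
// n = 5 elements produces the buffer {1, 2, 2, 2, 2} (the last value is
// repeated), while an empty int_val produces {0, 0, 0, 0, 0}.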

template <>
TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in,
                                      int64 n) {
  CHECK_GT(n, 0);
  Buffer<Variant>* buf = new Buffer<Variant>(a, n);
  Variant* data = buf->template base<Variant>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64 in_n = ProtoHelper<Variant>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, Variant());
  } else {
    for (int64 i = 0; i < in_n; ++i) {
      data[i] = in.variant_val(i);
      if (!DecodeUnaryVariant(&data[i])) {
        LOG(ERROR) << "Could not decode variant with type_name: \""
                   << data[i].TypeName()
                   << "\".  Perhaps you forgot to register a "
                      "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?";
        buf->Unref();
        return nullptr;
      }
    }
    for (int64 i = in_n; i < n; ++i) {
      data[i] = Variant();
    }
  }
  return buf;
}

// fp16 and bfloat16 are opaque to the protobuf, so we deserialize these
// identically to uint16 but with data stored in half_val instead of int_val
// (i.e., we don't use ProtoHelper<uint16>).
template <>
TensorBuffer* FromProtoField<Eigen::half>(Allocator* a, const TensorProto& in,
                                          int64 n) {
  CHECK_GT(n, 0);
  Buffer<Eigen::half>* buf = new Buffer<Eigen::half>(a, n);
  uint16* data = buf->template base<uint16>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64 in_n = in.half_val().size();
  auto begin = in.half_val().begin();
  if (n <= in_n) {
    std::copy_n(begin, n, data);
  } else if (in_n > 0) {
    std::copy_n(begin, in_n, data);
    const uint16 last = *(data + in_n - 1);
    std::fill_n(data + in_n, n - in_n, last);
  } else {
    std::fill_n(data, n, 0);
  }
  return buf;
}

template <>
TensorBuffer* FromProtoField<bfloat16>(Allocator* a, const TensorProto& in,
                                       int64 n) {
  CHECK_GT(n, 0);
  Buffer<bfloat16>* buf = new Buffer<bfloat16>(a, n);
  uint16* data = buf->template base<uint16>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64 in_n = in.half_val().size();
  auto begin = in.half_val().begin();
  if (n <= in_n) {
    std::copy_n(begin, n, data);
  } else if (in_n > 0) {
    std::copy_n(begin, in_n, data);
    const uint16 last = *(data + in_n - 1);
    std::fill_n(data + in_n, n - in_n, last);
  } else {
    std::fill_n(data, n, 0);
  }
  return buf;
}

// Copies T[n] stored in the buffer "in" into the repeated field in
// "out" corresponding to type T.
template <typename T>
void ToProtoField(const TensorBuffer& in, int64 n, TensorProto* out) {
  const T* data = in.base<const T>();
  // NOTE: T may not be the same as
  // ProtoHelper<T>::FieldType::value_type.  E.g., T==int16,
  // ProtoHelper<T>::FieldType::value_type==int32.  If performance is
  // critical, we can specialize T=float and do memcpy directly.
  ProtoHelper<T>::Fill(data, n, out);
}

void RefIfNonNull(core::RefCounted* buf) {
  if (buf) buf->Ref();
}

void UnrefIfNonNull(core::RefCounted* buf) {
  if (buf) buf->Unref();
}

}  // end namespace

Tensor::Tensor() : Tensor(DT_FLOAT) {}

Tensor::Tensor(DataType type) : shape_({0}), buf_(nullptr) { set_dtype(type); }

Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
    : shape_(shape), buf_(buf) {
  set_dtype(type);
  RefIfNonNull(buf);
}

bool Tensor::IsInitialized() const {
  return (buf_ != nullptr && buf_->data() != nullptr) ||
         shape_.num_elements() == 0;
}

void Tensor::CheckType(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype) << " "
      << DataTypeString(expected_dtype) << " expected, got "
      << DataTypeString(dtype());
}

void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype) << " "
      << DataTypeString(expected_dtype) << " expected, got "
      << DataTypeString(dtype());
  CHECK(IsAligned()) << "ptr = " << base<void>();
}

void Tensor::CheckIsAlignedAndSingleElement() const {
  CHECK(IsAligned()) << "Aligned and single element";
  CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
}

Tensor::~Tensor() { UnrefIfNonNull(buf_); }

void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) {
  CHECK_EQ(shape.num_elements(), other.NumElements());
  // Data type will be overwritten if this == &other, since dtype is part of
  // shape.
  DataType other_dtype = other.dtype();
  shape_ = shape;
  set_dtype(other_dtype);
  if (buf_ != other.buf_) {
    UnrefIfNonNull(buf_);
    buf_ = other.buf_;
    RefIfNonNull(buf_);
  }
}

Status Tensor::BitcastFrom(const Tensor& other, DataType dtype,
                           const TensorShape& shape) {
  int in_size = DataTypeSize(other.dtype());
  int out_size = DataTypeSize(dtype);
  if (in_size == 0) {
    return errors::InvalidArgument("other tensor has zero-sized data type");
  }
  if (out_size == 0) {
    return errors::InvalidArgument("specified output type is zero-sized");
  }
  if (shape.num_elements() * out_size !=
      other.shape().num_elements() * in_size) {
    return errors::InvalidArgument(
        "input and output shapes/data type sizes are not compatible");
  }
  shape_ = shape;
  shape_.set_data_type(dtype);
  if (buf_ != other.buf_) {
    UnrefIfNonNull(buf_);
    buf_ = other.buf_;
    RefIfNonNull(buf_);
  }
  return Status::OK();
}
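
// Usage sketch (illustrative only): a 4-element DT_FLOAT tensor occupies
// 16 bytes, so it can be bitcast to a 16-element DT_UINT8 tensor that
// shares the same buffer:
//
//   Tensor floats(DT_FLOAT, TensorShape({4}));
//   Tensor bytes;
//   Status s = bytes.BitcastFrom(floats, DT_UINT8, TensorShape({16}));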

// Notice that buf_ either points to a regular TensorBuffer or a SubBuffer.
// In the latter case, we have to make sure that the refcount is
// one both for the SubBuffer _and_ the underlying TensorBuffer.
bool Tensor::RefCountIsOne() const {
  return buf_ != nullptr && buf_->RefCountIsOne() &&
         buf_->root_buffer()->RefCountIsOne() && buf_->OwnsMemory();
}

// The macro CASES() expands to a switch statement conditioned on
// TYPE_ENUM. Each case expands the STMTS after a typedef for T.
#define SINGLE_ARG(...) __VA_ARGS__
#define CASE(TYPE, STMTS)             \
  case DataTypeToEnum<TYPE>::value: { \
    typedef TYPE T;                   \
    STMTS;                            \
    break;                            \
  }
#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
  switch (TYPE_ENUM) {                                         \
    CASE(float, SINGLE_ARG(STMTS))                             \
    CASE(double, SINGLE_ARG(STMTS))                            \
    CASE(int32, SINGLE_ARG(STMTS))                             \
    CASE(uint8, SINGLE_ARG(STMTS))                             \
    CASE(uint16, SINGLE_ARG(STMTS))                            \
    CASE(uint32, SINGLE_ARG(STMTS))                            \
    CASE(uint64, SINGLE_ARG(STMTS))                            \
    CASE(int16, SINGLE_ARG(STMTS))                             \
    CASE(int8, SINGLE_ARG(STMTS))                              \
    CASE(string, SINGLE_ARG(STMTS))                            \
    CASE(complex64, SINGLE_ARG(STMTS))                         \
    CASE(complex128, SINGLE_ARG(STMTS))                        \
    CASE(int64, SINGLE_ARG(STMTS))                             \
    CASE(bool, SINGLE_ARG(STMTS))                              \
    CASE(qint32, SINGLE_ARG(STMTS))                            \
    CASE(quint8, SINGLE_ARG(STMTS))                            \
    CASE(qint8, SINGLE_ARG(STMTS))                             \
    CASE(quint16, SINGLE_ARG(STMTS))                           \
    CASE(qint16, SINGLE_ARG(STMTS))                            \
    CASE(bfloat16, SINGLE_ARG(STMTS))                          \
    CASE(Eigen::half, SINGLE_ARG(STMTS))                       \
    CASE(ResourceHandle, SINGLE_ARG(STMTS))                    \
    CASE(Variant, SINGLE_ARG(STMTS))                           \
    case DT_INVALID:                                           \
      INVALID;                                                 \
      break;                                                   \
    default:                                                   \
      DEFAULT;                                                 \
      break;                                                   \
  }

#define CASES(TYPE_ENUM, STMTS)                                      \
  CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
                     , LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)
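
// For reference, CASES(type, buf_ = new Buffer<T>(a, n)) expands each CASE
// above into (roughly) the following:
//
//   case DT_FLOAT: {
//     typedef float T;
//     buf_ = new Buffer<T>(a, n);
//     break;
//   }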

Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
    : shape_(shape), buf_(nullptr) {
  set_dtype(type);
  CHECK_NOTNULL(a);
  if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
    CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
  }
  if (buf_ != nullptr && buf_->data() != nullptr && LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID,
                                      *this);
  }
}

Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
               const AllocationAttributes& allocation_attr)
    : shape_(shape), buf_(nullptr) {
  set_dtype(type);
  CHECK_NOTNULL(a);
  if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
    CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr));
  }
  if (!allocation_attr.allocation_will_be_logged && buf_ != nullptr &&
      buf_->data() != nullptr && LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation("Unknown (with attributes)",
                                      LogMemory::UNKNOWN_STEP_ID, *this);
  }
}

Tensor::Tensor(DataType type, const TensorShape& shape)
    : Tensor(cpu_allocator(), type, shape) {}
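
// Construction sketch (illustrative only): both forms below allocate from
// the CPU allocator; the first simply delegates to the second.
//
//   Tensor t1(DT_FLOAT, TensorShape({2, 3}));
//   Tensor t2(cpu_allocator(), DT_FLOAT, TensorShape({2, 3}));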

void Tensor::HostScalarTensorBufferBase::FillAllocationDescription(
    AllocationDescription* proto) const {
  proto->set_requested_bytes(size());
  proto->set_allocator_name("HostScalarTensorBuffer");
  proto->set_ptr(reinterpret_cast<uintptr_t>(data()));
}

template <typename T>
class SubBuffer : public TensorBuffer {
 public:
  // This buffer is an alias to buf[delta, delta + n).
  SubBuffer(TensorBuffer* buf, int64 delta, int64 n)
      : TensorBuffer(buf->base<T>() + delta),
        root_(buf->root_buffer()),
        elem_(n) {
    // Sanity check. The caller should ensure the sub buffer is valid.
    CHECK_LE(root_->base<T>(), this->base<T>());
    T* root_limit = root_->base<T>() + root_->size() / sizeof(T);
    CHECK_LE(this->base<T>(), root_limit);
    CHECK_LE(this->base<T>() + n, root_limit);
    // Hold a ref of the underlying root buffer.
    // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer.
    root_->Ref();
  }

  size_t size() const override { return sizeof(T) * elem_; }
  TensorBuffer* root_buffer() override { return root_; }
  void FillAllocationDescription(AllocationDescription* proto) const override {
    root_->FillAllocationDescription(proto);
  }

 private:
  TensorBuffer* root_;
  T* data_;
  int64 elem_;

  ~SubBuffer() override { root_->Unref(); }

  TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer);
};

Tensor Tensor::Slice(int64 start, int64 limit) const {
  CHECK_GE(dims(), 1);
  CHECK_LE(0, start);
  CHECK_LE(start, limit);
  int64 dim0_size = shape_.dim_size(0);
  CHECK_LE(limit, dim0_size);
  if ((start == 0) && (limit == dim0_size)) {
    return *this;
  }
  Tensor ret;
  ret.shape_ = shape_;
  ret.set_dtype(dtype());
  ret.buf_ = nullptr;
  if (dim0_size > 0) {
    const int64 elems_per_dim0 = NumElements() / dim0_size;
    const int64 delta = start * elems_per_dim0;
    dim0_size = limit - start;
    ret.shape_.set_dim(0, dim0_size);
    const int64 num_elems = dim0_size * elems_per_dim0;
    if (buf_) {
      DataType dt = dtype();
      CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
    }
  }
  return ret;
}
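
// Example (a sketch): slicing rows [1, 3) of a [4, 2] tensor yields a
// [2, 2] alias into the same underlying buffer; no data is copied.
//
//   Tensor t(DT_FLOAT, TensorShape({4, 2}));
//   Tensor rows = t.Slice(1, 3);  // shape [2, 2], shares t's buffer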

Tensor Tensor::SubSlice(int64 index) const {
  CHECK_GE(dims(), 1);  // Crash ok.
  CHECK_LE(0, index);   // Crash ok.
  int64 dim0_size = shape_.dim_size(0);
  CHECK_LE(index, dim0_size);  // Crash ok.
  Tensor ret;
  ret.shape_ = shape_;
  ret.shape_.RemoveDim(0);
  ret.set_dtype(dtype());
  ret.buf_ = nullptr;
  if (dim0_size > 0) {
    const int64 elems_per_dim0 = NumElements() / dim0_size;
    const int64 delta = index * elems_per_dim0;
    const int64 num_elems = elems_per_dim0;
    if (buf_) {
      DataType dt = dtype();
      CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
    }
  }
  return ret;
}
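
// Example (a sketch): unlike Slice(), SubSlice() drops the first dimension
// entirely.
//
//   Tensor t(DT_FLOAT, TensorShape({4, 2}));
//   Tensor row = t.SubSlice(1);  // shape [2], shares t's buffer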

bool Tensor::FromProto(const TensorProto& proto) {
  return FromProto(cpu_allocator(), proto);
}

bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
  CHECK_NOTNULL(a);
  TensorBuffer* p = nullptr;
  if (!TensorShape::IsValid(proto.tensor_shape())) return false;
  if (proto.dtype() == DT_INVALID) return false;
  TensorShape shape(proto.tensor_shape());
  const int64 N = shape.num_elements();
  if (N > 0 && proto.dtype()) {
    bool dtype_error = false;
    if (!proto.tensor_content().empty()) {
      const auto& content = proto.tensor_content();
      CASES_WITH_DEFAULT(proto.dtype(), p = Helper<T>::Decode(a, content, N),
                         dtype_error = true, dtype_error = true);
    } else {
      CASES_WITH_DEFAULT(proto.dtype(), p = FromProtoField<T>(a, proto, N),
                         dtype_error = true, dtype_error = true);
    }
    if (dtype_error || p == nullptr) return false;
  }
  shape_ = shape;
  set_dtype(proto.dtype());
  UnrefIfNonNull(buf_);
  buf_ = p;
  // TODO(misard) add tracking of which kernels and steps are calling
  // FromProto.
  if (buf_ != nullptr && buf_->data() != nullptr && LogMemory::IsEnabled()) {
    LogMemory::RecordTensorAllocation("Unknown (from Proto)",
                                      LogMemory::UNKNOWN_STEP_ID, *this);
  }
  return true;
}

void Tensor::AsProtoField(TensorProto* proto) const {
  proto->Clear();
  shape_.AsProto(proto->mutable_tensor_shape());
  proto->set_dtype(dtype());
  if (buf_) {
    CASES(dtype(), ToProtoField<T>(*buf_, shape_.num_elements(), proto));
  }
}

void Tensor::AsProtoTensorContent(TensorProto* proto) const {
  proto->Clear();
  proto->set_dtype(dtype());
  shape_.AsProto(proto->mutable_tensor_shape());
  if (buf_) {
    CASES(dtype(), Helper<T>::Encode(buf_, shape_.num_elements(),
                                     proto->mutable_tensor_content()));
  }
}
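
// The two serialization paths above differ only in representation:
// AsProtoField() writes the typed repeated fields (float_val, etc.), while
// AsProtoTensorContent() writes the compact bytes. A round-trip sketch
// through the compact form (illustrative only):
//
//   TensorProto proto;
//   t.AsProtoTensorContent(&proto);   // bytes land in tensor_content
//   Tensor copy;
//   bool ok = copy.FromProto(proto);  // Helper<T>::Decode reads them back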

size_t Tensor::TotalBytes() const {
  if (shape_.num_elements() == 0) return 0;
  CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements();
  CASES(dtype(), return Helper<T>::TotalBytes(buf_, shape_.num_elements()));
  return 0;  // Makes compiler happy.
}

size_t Tensor::AllocatedBytes() const {
  TensorDescription tensor_description;
  FillDescription(&tensor_description);
  if (tensor_description.has_allocation_description() &&
      tensor_description.allocation_description().allocated_bytes() > 0) {
    return tensor_description.allocation_description().allocated_bytes();
  } else {
    // Fall back to TotalBytes() if the allocator doesn't have its size.
    return TotalBytes();
  }
}

bool Tensor::CanUseDMA() const {
  CASES(dtype(), return is_simple_type<T>::value);
  return false;  // Makes compiler happy.
}

#undef CASES
#undef CASE

namespace {

// StrCat and StrAppend don't support Eigen::half directly at the moment, and
// we would like to keep them compatible with their absl counterparts, for ease
// of migration. We could rely on errors::internal::PrepareForStrCat() but the
// logic is so simple we can just replicate it here, where it is close to its
// usage and easy to change later. And there's the extra benefit of not
// accessing an 'internal' namespace.
inline const strings::AlphaNum& PrintOneElement(const strings::AlphaNum& a,
                                                bool print_v2) {
  return a;
}
inline string PrintOneElement(const string& a, bool print_v2) {
  if (print_v2) {
    return "\"" + str_util::CEscape(a) + "\"";
  } else {
    return str_util::CEscape(a);
  }
}
inline float PrintOneElement(const Eigen::half& h, bool print_v2) {
  return static_cast<float>(h);
}

// Print from left dim to right dim recursively.
template <typename T>
void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
                 int64 limit, int shape_size, const T* data, int64* data_index,
                 string* result) {
  if (*data_index >= limit) return;
  int64 element_count = shape[dim_index];
  // We have reached the right-most dimension of the tensor.
  if (dim_index == shape_size - 1) {
    for (int64 i = 0; i < element_count; i++) {
      if (*data_index >= limit) {
        // If not enough elements have been printed, append "...".
        if (dim_index != 0 && i < element_count) {
          strings::StrAppend(result, "...");
        }
        return;
      }
      if (i > 0) strings::StrAppend(result, " ");
      strings::StrAppend(result, PrintOneElement(data[(*data_index)++], false));
    }
    return;
  }
  // Loop over every element of one dim.
  for (int64 i = 0; i < element_count; i++) {
    bool flag = false;
    if (*data_index < limit) {
      strings::StrAppend(result, "[");
      flag = true;
    }
    // For each element, print the sub-dim.
    PrintOneDim(dim_index + 1, shape, limit, shape_size, data, data_index,
                result);
    if (*data_index < limit || flag) {
      strings::StrAppend(result, "]");
      flag = false;
    }
  }
}

// Appends the spacing between elements for a given dim onto a result string.
void PrintDimSpacing(int dim_index, int num_dims, string* result) {
  if (dim_index == num_dims - 1) {
    strings::StrAppend(result, " ");
    return;
  }
  for (int j = 0; j < num_dims - dim_index - 1; j++) {
    strings::StrAppend(result, "\n");
  }
  for (int j = 0; j <= dim_index; j++) {
    strings::StrAppend(result, " ");
  }
}

// Print from left dim to right dim recursively.
template <typename T>
void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
                   int64 num_elts_at_ends, int num_dims, const T* data,
                   int64 data_index, string* result) {
  // We have recursed beyond all the dimensions into a single element
  // of the tensor.
  if (dim_index == num_dims) {
    strings::StrAppend(result, PrintOneElement(data[data_index], true));
    return;
  }

  strings::StrAppend(result, "[");
  int64 element_count = shape[dim_index];
  int64 start_of_end =
      std::max(num_elts_at_ends, element_count - num_elts_at_ends);

   1030   int64 elements_per_iter = 1;
   1031   for (int i = dim_index + 1; i < num_dims; i++) {
   1032     elements_per_iter *= shape[i];
   1033   }
   1034   for (int64 i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
   1035     if (i > 0) {
   1036       PrintDimSpacing(dim_index, num_dims, result);
   1037     }
   1038 
    // For each element, print the sub-dim.
    PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
                  data_index + elements_per_iter * i, result);
  }
  if (element_count > 2 * num_elts_at_ends) {
    PrintDimSpacing(dim_index, num_dims, result);
    strings::StrAppend(result, "...");
  }
  for (int64 i = start_of_end; i < element_count; i++) {
    // For each element, print the sub-dim.
    PrintDimSpacing(dim_index, num_dims, result);
    PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
                  data_index + elements_per_iter * i, result);
  }

  strings::StrAppend(result, "]");
}

template <typename T>
string SummarizeArray(int64 limit, int64 num_elts,
                      const TensorShape& tensor_shape, const char* data,
                      const bool print_v2) {
  string ret;
  const T* array = reinterpret_cast<const T*>(data);

  const gtl::InlinedVector<int64, 4> shape = tensor_shape.dim_sizes();
  if (shape.empty()) {
    for (int64 i = 0; i < limit; ++i) {
      if (i > 0) strings::StrAppend(&ret, " ");
      strings::StrAppend(&ret, PrintOneElement(array[i], print_v2));
    }
    if (num_elts > limit) strings::StrAppend(&ret, "...");
    return ret;
  }
  if (print_v2) {
    const int num_dims = tensor_shape.dims();
    PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
  } else {
    int64 data_index = 0;
    const int shape_size = tensor_shape.dims();
    PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);

    if (num_elts > limit) strings::StrAppend(&ret, "...");
  }

  return ret;
}
}  // namespace

string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
  const int64 num_elts = NumElements();
  if (max_entries < 0) {
    max_entries = num_elts;
  }
  size_t limit = std::min(max_entries, num_elts);
  if ((limit > 0) && (buf_ == nullptr)) {
    return strings::StrCat("uninitialized Tensor of ", num_elts,
                           " elements of type ", dtype());
  }
  const char* data = limit > 0 ? tensor_data().data() : nullptr;
  switch (dtype()) {
    case DT_HALF:
      return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
                                         print_v2);
      break;
    case DT_FLOAT:
      return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_DOUBLE:
      return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT32:
      return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT32:
      return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT8:
    case DT_QUINT8:
      return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT16:
    case DT_QUINT16:
      return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT16:
    case DT_QINT16:
      return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT8:
    case DT_QINT8:
      return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT64:
      return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT64:
      return SummarizeArray<int64>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_BOOL:
      // TODO(tucker): Is it better to emit "True False..."?  This
      // will emit "1 0..." which is more compact.
      return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_STRING:
      return SummarizeArray<string>(limit, num_elts, shape_, data, print_v2);
      break;
    default: {
      // All irregular cases
      string ret;
      if (print_v2) {
        strings::StrAppend(&ret, "[");
      }
      // TODO(irving): Don't call flat every time around this
      // loop.
      for (size_t i = 0; i < limit; ++i) {
        if (i > 0) strings::StrAppend(&ret, " ");
        switch (dtype()) {
          case DT_VARIANT: {
            const Variant& v = flat<Variant>()(i);
            strings::StrAppend(&ret, v.DebugString());
          } break;
          default:
            // TODO(zhifengc, josh11b): Pretty-print other types (bool,
            // complex64, quantized).
            strings::StrAppend(&ret, "?");
        }
      }
      if (max_entries < num_elts) strings::StrAppend(&ret, "...");
      if (print_v2) {
        strings::StrAppend(&ret, "]");
      }
      return ret;
    }
  }
}
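
// For example (illustrative): summarizing a [2, 2] DT_INT32 tensor holding
// {1, 2, 3, 4} with max_entries >= 4 produces "[1 2][3 4]" when
// print_v2 == false, and the numpy-style "[[1 2]\n [3 4]]" when
// print_v2 == true.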

StringPiece Tensor::tensor_data() const {
  if (buf_ == nullptr) return StringPiece();  // Don't die for empty tensors
  return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
}

bool Tensor::SharesBufferWith(const Tensor& b) const {
  return buf_ != nullptr && b.buf_ != nullptr &&
         buf_->root_buffer() == b.buf_->root_buffer();
}

string Tensor::DebugString(int num_values) const {
  return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
                         " shape: ", shape().DebugString(),
                         " values: ", SummarizeValue(num_values), ">");
}

string Tensor::DeviceSafeDebugString() const {
  return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
                         " shape: ", shape().DebugString(), ">");
}

void Tensor::FillDescription(TensorDescription* description) const {
  description->set_dtype(dtype());
  shape().AsProto(description->mutable_shape());
  if (buf_ != nullptr && buf_->data() != nullptr) {
    buf_->FillAllocationDescription(
        description->mutable_allocation_description());
  }
}

gtl::InlinedVector<int64, 4> Tensor::ComputeFlatInnerDims(
    gtl::ArraySlice<int64> orig, int64 num_out_dims) {
  gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
  int64 offset = orig.size() - num_out_dims;
  for (int64 out_dim = num_out_dims - 1; out_dim >= 0; --out_dim) {
    const int64 in_dim = out_dim + offset;
    out_dims[out_dim] = in_dim < 0 ? 1 : orig[in_dim];
  }
  for (int64 in_dim = 0; in_dim < offset; ++in_dim) {
    out_dims[0] *= orig[in_dim];
  }
  return out_dims;
}
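
// Worked example (illustrative): ComputeFlatInnerDims({2, 3, 4, 5}, 2)
// keeps the trailing dims and collapses the leading ones into the first
// output dim, giving {2 * 3 * 4, 5} = {24, 5}.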

gtl::InlinedVector<int64, 4> Tensor::ComputeFlatOuterDims(
    gtl::ArraySlice<int64> orig, int64 num_out_dims) {
  gtl::InlinedVector<int64, 4> out_dims(num_out_dims, 0);
  for (int64 out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) {
    out_dims[out_dim] = out_dim >= orig.size() ? 1 : orig[out_dim];
  }
  for (int64 in_dim = num_out_dims; in_dim < orig.size(); ++in_dim) {
    out_dims[num_out_dims - 1] *= orig[in_dim];
  }
  return out_dims;
}
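
// Worked example (illustrative): ComputeFlatOuterDims({2, 3, 4, 5}, 2)
// keeps the leading dims and collapses the trailing ones into the last
// output dim, giving {2, 3 * 4 * 5} = {2, 60}.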

}  // namespace tensorflow