Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
     17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
     18 
     19 #include <algorithm>
     20 #include <vector>
     21 
     22 #include "tensorflow/core/framework/tensor.h"
     23 #include "tensorflow/core/framework/tensor.pb.h"
     24 #include "tensorflow/core/framework/tensor_shape.pb.h"
     25 #include "tensorflow/core/framework/type_traits.h"
     26 #include "tensorflow/core/platform/protobuf.h"
     27 #include "tensorflow/core/platform/types.h"
     28 
     29 namespace tensorflow {
     30 namespace tensor {
     31 
// DeepCopy returns a tensor whose contents are a deep copy of the
// contents of 'other'.  This function is intended only for
// convenience, not speed.
//
// 'other' is taken by const reference and the copy is returned by value.
//
// REQUIRES: 'other' must point to data stored in CPU memory.
// REQUIRES: 'other' must be a Tensor of a copy-able type if
//           'other' is not appropriately memory-aligned.
Tensor DeepCopy(const Tensor& other);
     40 
// Concatenates 'tensors' into a single tensor, along their 0th dimension.
//
// The concatenated tensor is delivered through the 'result' out-parameter and
// the outcome is reported through the returned Status. The declaration is
// annotated with TF_MUST_USE_RESULT, so callers are expected to check it.
//
// REQUIRES: All members of 'tensors' must have the same data type parameter.
// REQUIRES: Each member of 'tensors' must have at least one dimension.
// REQUIRES: Each member of 'tensors' must point to data stored in CPU memory.
// REQUIRES: Each member of 'tensors' must be a Tensor of a copy-able type if it
//           is not appropriately memory-aligned.
Status Concat(const gtl::ArraySlice<Tensor>& tensors,
              Tensor* result) TF_MUST_USE_RESULT;
     50 
// Splits 'tensor' into 'sizes.size()' individual tensors, along the 0th
// dimension. The ith output tensor has 0th-dimension size 'sizes[i]'.
//
// The pieces are delivered through the 'result' out-parameter and the outcome
// is reported through the returned Status. The declaration is annotated with
// TF_MUST_USE_RESULT, so callers are expected to check it.
//
// REQUIRES: 'tensor' must have at least one dimension.
// REQUIRES: 'tensor.dim_size(0)' must equal the sum of the elements of 'sizes'.
// REQUIRES: 'tensor' must point to data stored in CPU memory.
// REQUIRES: 'tensor' must be a Tensor of a copy-able type if it is not
//           appropriately memory-aligned.
//
// Split() and Concat() are inverse operations.
Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes,
             std::vector<Tensor>* result) TF_MUST_USE_RESULT;
     63 
     64 namespace internal {
// Shape helper used by CreateTensorProto(), which passes the caller's 'shape'
// and then builds a TensorShape from 'shape_proto' to validate the element
// count. Only the declaration is visible in this header.
// NOTE(review): presumably each entry of 'shape' becomes one dimension of
// 'shape_proto' -- confirm against the definition.
// NOTE(review): 'shape' is taken by value, so each call copies the vector.
void SetTensorProtoShape(std::vector<size_t> shape,
                         TensorShapeProto* shape_proto);
     67 
     68 template <typename Type>
     69 class TensorProtoFieldHelper : public std::false_type {};
     70 
     71 #define DEFINE_PROTO_FIELD_HELPER(TYPE, FIELDNAME)                            \
     72   template <>                                                                 \
     73   class TensorProtoFieldHelper<TYPE> : public std::true_type {                \
     74    public:                                                                    \
     75     typedef decltype(                                                         \
     76         std::declval<TensorProto>().FIELDNAME##_val(0)) FieldType;            \
     77     typedef decltype(                                                         \
     78         std::declval<TensorProto>().FIELDNAME##_val()) RepeatedFieldType;     \
     79     typedef decltype(std::declval<TensorProto>().mutable_##FIELDNAME##_val()) \
     80         MutableRepeatedFieldType;                                             \
     81     static MutableRepeatedFieldType GetMutableField(TensorProto* proto) {     \
     82       return proto->mutable_##FIELDNAME##_val();                              \
     83     }                                                                         \
     84     static RepeatedFieldType& GetField(const TensorProto& proto) {            \
     85       return proto.FIELDNAME##_val();                                         \
     86     }                                                                         \
     87   }
     88 
     89 // The argument pairs in the following macro instantiations encode the
     90 // mapping from C++ type ($1) to repeated field name "$2_val" used for storing
     91 // values in TensorProto. See tensorflow/core/framework/tensor.proto.
     92 DEFINE_PROTO_FIELD_HELPER(float, float);
     93 DEFINE_PROTO_FIELD_HELPER(double, double);
     94 DEFINE_PROTO_FIELD_HELPER(int8, int);
     95 DEFINE_PROTO_FIELD_HELPER(uint8, int);
     96 DEFINE_PROTO_FIELD_HELPER(int16, int);
     97 DEFINE_PROTO_FIELD_HELPER(uint16, int);
     98 DEFINE_PROTO_FIELD_HELPER(int32, int);
     99 DEFINE_PROTO_FIELD_HELPER(uint32, uint32);
    100 DEFINE_PROTO_FIELD_HELPER(int64, int64);
    101 DEFINE_PROTO_FIELD_HELPER(uint64, uint64);
    102 DEFINE_PROTO_FIELD_HELPER(bool, bool);
    103 DEFINE_PROTO_FIELD_HELPER(qint8, int);
    104 DEFINE_PROTO_FIELD_HELPER(quint8, int);
    105 DEFINE_PROTO_FIELD_HELPER(qint16, int);
    106 DEFINE_PROTO_FIELD_HELPER(quint16, int);
    107 DEFINE_PROTO_FIELD_HELPER(qint32, int);
    108 DEFINE_PROTO_FIELD_HELPER(Eigen::half, half);
    109 DEFINE_PROTO_FIELD_HELPER(bfloat16, half);
    110 DEFINE_PROTO_FIELD_HELPER(complex64, scomplex);
    111 DEFINE_PROTO_FIELD_HELPER(complex128, dcomplex);
    112 
    113 #undef DEFINE_PROTO_HELPER
    114 
    115 template <typename T>
    116 struct CopyHelper {
    117   template <typename SrcIter, typename DstIter>
    118   static void ToArray(SrcIter begin, SrcIter end, DstIter dst) {
    119     using SrcType = typename std::iterator_traits<SrcIter>::value_type;
    120     using DstType = typename std::iterator_traits<DstIter>::value_type;
    121     std::transform(begin, end, dst, [](const SrcType& x) -> DstType {
    122       return static_cast<DstType>(x);
    123     });
    124   }
    125   template <typename SrcIter>
    126   static void ToArray(SrcIter begin, SrcIter end, SrcIter dst) {
    127     std::copy(begin, end, dst);
    128   }
    129   template <typename SrcIter, typename DstIter>
    130   static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
    131     ToArray(begin, end, dst);
    132   }
    133 };
    134 
    135 // Overloads for Eigen::half and bfloat16 that are 16 bits in size but are
    136 // stored in an int32 field.
    137 template <>
    138 struct CopyHelper<Eigen::half> {
    139   template <typename SrcIter>
    140   static void ToArray(SrcIter begin, SrcIter end, Eigen::half* dst) {
    141     std::transform(begin, end, dst, [](int x) -> Eigen::half {
    142       Eigen::half h;
    143       h.x = static_cast<uint16>(x);
    144       return h;
    145     });
    146   }
    147   template <typename SrcIter, typename DstIter>
    148   static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
    149     std::transform(begin, end, dst,
    150                    [](Eigen::half h) -> int { return static_cast<int>(h.x); });
    151   }
    152 };
    153 
    154 template <>
    155 struct CopyHelper<bfloat16> {
    156   template <typename SrcIter>
    157   static void ToArray(SrcIter begin, SrcIter end, bfloat16* dst) {
    158     std::transform(begin, end, dst, [](int x) -> bfloat16 {
    159       bfloat16 bf16;
    160       bf16.value = static_cast<uint16>(x);
    161       return bf16;
    162     });
    163   }
    164   template <typename SrcIter, typename DstIter>
    165   static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
    166     std::transform(begin, end, dst, [](bfloat16 bf16) -> int {
    167       return static_cast<int>(bf16.value);
    168     });
    169   }
    170 };
    171 
    172 // Overloads for complex types that store real and imaginary parts
    173 // at indices 2*i and 2*i+1 in float or double field.
    174 template <typename RealType>
    175 struct CopyHelper<std::complex<RealType>> {
    176   template <typename SrcIter>
    177   static void ToArray(SrcIter begin, SrcIter end, std::complex<RealType>* dst) {
    178     using SrcType = typename std::iterator_traits<SrcIter>::value_type;
    179     RealType* real_dst = reinterpret_cast<RealType*>(dst);
    180     std::copy(begin, end, real_dst);
    181   }
    182 
    183   template <typename SrcIter, typename DstIter>
    184   static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
    185     using DstType = typename std::iterator_traits<DstIter>::value_type;
    186     size_t n = std::distance(begin, end);
    187     const RealType* real_begin = reinterpret_cast<const RealType*>(&(*begin));
    188     std::copy_n(real_begin, 2 * n, dst);
    189   }
    190 };
    191 
    192 // Helper class to extract and insert values into TensorProto represented as
    193 // repeated fields.
    194 template <typename T>
    195 class TensorProtoHelper : public std::true_type {
    196  public:
    197   using FieldHelper = TensorProtoFieldHelper<T>;
    198   using FieldType = typename TensorProtoFieldHelper<T>::FieldType;
    199 
    200   static DataType GetDataType() { return DataTypeToEnum<T>::value; }
    201 
    202   // Returns the number of values of type T encoded in the proto.
    203   static size_t NumValues(const TensorProto& proto) {
    204     size_t raw_size = FieldHelper::GetField(proto).size();
    205     return is_complex<T>::value ? raw_size / 2 : raw_size;
    206   }
    207 
    208   static void AddValue(const T& value, TensorProto* proto) {
    209     const T* val_ptr = &value;
    210     AddValues(val_ptr, val_ptr + 1, proto);
    211   }
    212 
    213   static T GetValue(size_t index, const TensorProto& proto) {
    214     T val;
    215     if (is_complex<T>::value) index *= 2;
    216     CopyHelper<T>::ToArray(FieldHelper::GetField(proto).begin() + index,
    217                            FieldHelper::GetField(proto).begin() + index + 1,
    218                            &val);
    219     return val;
    220   }
    221 
    222   template <typename IterType>
    223   static void AddValues(IterType begin, IterType end, TensorProto* proto) {
    224     size_t n = std::distance(begin, end);
    225     FieldType* dst = AppendUninitialized(n, proto);
    226     CopyHelper<T>::FromArray(begin, end, dst);
    227   }
    228 
    229   template <typename IterType>
    230   static void CopyValues(IterType dst, const TensorProto& proto) {
    231     CopyHelper<T>::ToArray(FieldHelper::GetField(proto).begin(),
    232                            FieldHelper::GetField(proto).end(), dst);
    233   }
    234 
    235   static void Truncate(size_t new_size, TensorProto* proto) {
    236     if (is_complex<T>::value) new_size *= 2;
    237     FieldHelper::GetMutableField(proto)->Truncate(new_size);
    238   }
    239 
    240   static FieldType* AppendUninitialized(size_t n, TensorProto* proto) {
    241     if (is_complex<T>::value) n *= 2;
    242     auto* field = FieldHelper::GetMutableField(proto);
    243     field->Reserve(field->size() + n);
    244     return reinterpret_cast<FieldType*>(field->AddNAlreadyReserved(n));
    245   }
    246 };
    247 
    248 // Specialization for string.
    249 template <>
    250 class TensorProtoHelper<string> : public std::true_type {
    251  public:
    252   static DataType GetDataType() { return DataType::DT_STRING; }
    253   static void AddValue(const string& value, TensorProto* proto) {
    254     *proto->mutable_string_val()->Add() = value;
    255   }
    256   template <typename IterType>
    257   static void AddValues(IterType begin, IterType end, TensorProto* proto) {
    258     for (IterType it = begin; it != end; ++it) {
    259       AddValue(*it, proto);
    260     }
    261   }
    262   template <typename IterType>
    263   static void CopyToTensorContent(IterType begin, IterType end,
    264                                   TensorProto* proto) {
    265     AddValues(begin, end, proto);
    266   }
    267 };
    268 
    269 }  // namespace internal
    270 
    271 // Creates a 'TensorProto' with specified shape and values.
    272 // The dtype and a field to represent data values of the returned 'TensorProto'
    273 // are determined based on type of the 'values' parameter.
    274 template <typename Type>
    275 typename std::enable_if<internal::TensorProtoHelper<Type>::value,
    276                         TensorProto>::type
    277 CreateTensorProto(const std::vector<Type>& values,
    278                   const std::vector<size_t>& shape) {
    279   TensorProto tensor;
    280   TensorShapeProto tensor_shape_proto;
    281   internal::SetTensorProtoShape(shape, &tensor_shape_proto);
    282   if (TensorShape(tensor_shape_proto).num_elements() != values.size()) {
    283     LOG(ERROR) << "Shape and number of values (" << values.size()
    284                << ") are incompatible.";
    285     return tensor;
    286   }
    287   using TypeHelper = internal::TensorProtoHelper<Type>;
    288   tensor.set_dtype(TypeHelper::GetDataType());
    289   tensor.mutable_tensor_shape()->Swap(&tensor_shape_proto);
    290   TypeHelper::AddValues(values.begin(), values.end(), &tensor);
    291   return tensor;
    292 }
    293 
// Converts values in tensor to run-length encoded compressed form.
//
// The elements of a tensor can be stored in a TensorProto in one of the
// following two forms:
// 1. As a raw byte string in the field `tensor_content` containing the
//    serialized in-memory representation of the tensor.
// 2. As values of a repeated field depending on the datatype, e.g. the
//    values of a DT_FLOAT tensor would be stored in the repeated field
//    `float_val`.
// Storage scheme 2 may use a simple form of run-length encoding to compress
// data: If the values contain a tail of identical values, the repeated field
// will be truncated such that the number of values in the repeated field is
// less than the number of elements implied by the field `tensor_shape`. The
// original tensor can be recovered by repeating the final value in the repeated
// field.
//
// The TensorProto will be compressed if a) the tensor contains at least
// min_num_elements elements and b) the compressed tensor proto would be at
// most the size of the original tensor proto divided by min_compression_ratio.
//
// 'tensor' is modified in place.
//
// Returns true if the tensor was compressed.
bool CompressTensorProtoInPlace(int64 min_num_elements,
                                float min_compression_ratio,
                                TensorProto* tensor);
    318 
    319 inline bool CompressTensorProtoInPlace(TensorProto* tensor) {
    320   static const int64 kDefaultMinNumElements = 64;
    321   static const float kDefaultMinCompressionRatio = 2.0f;
    322   return CompressTensorProtoInPlace(kDefaultMinNumElements,
    323                                     kDefaultMinCompressionRatio, tensor);
    324 }
    325 
    326 }  // namespace tensor
    327 }  // namespace tensorflow
    328 
    329 #endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
    330