/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/tensor_util.h"

#include <cmath>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace tensor {

Tensor DeepCopy(const Tensor& other) {
  Tensor tmp = Tensor(other.dtype(), other.shape());
  if (DataTypeCanUseMemcpy(other.dtype())) {
    if (other.NumElements() > 0) {
      StringPiece other_data = other.tensor_data();

      // We use StringPiece as a convenient view over the tensor buffer, but
      // cast away constness to reach the underlying buffer for the copy.
      StringPiece tmp_data = tmp.tensor_data();
      memcpy(const_cast<char*>(tmp_data.data()), other_data.data(),
             other_data.size());
    }
  } else if (other.dtype() == DT_STRING) {
    tmp.unaligned_flat<string>() = other.unaligned_flat<string>();
  } else {
    CHECK_EQ(DT_VARIANT, other.dtype());
    tmp.unaligned_flat<Variant>() = other.unaligned_flat<Variant>();
  }
  return tmp;
}
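// Example usage (a minimal sketch; assumes a small CPU tensor). Unlike the
// Tensor copy constructor, which shares the underlying buffer, DeepCopy gives
// the result its own storage, so writes to the copy do not affect the source:
//
//   Tensor a(DT_FLOAT, TensorShape({2}));
//   a.flat<float>().setZero();
//   Tensor b = tensor::DeepCopy(a);
//   b.flat<float>()(0) = 1.0f;  // 'a' still contains zeros.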

Status Concat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) {
  if (tensors.empty()) {
    return errors::InvalidArgument("Cannot concatenate zero tensors");
  }
  int64 total_dim0_size = 0;
  for (const Tensor& tensor : tensors) {
    if (tensor.dims() == 0) {
      return errors::InvalidArgument(
          "Cannot concatenate a zero-dimensional tensor");
    }
    total_dim0_size += tensor.dim_size(0);
  }
  TensorShape shape = tensors[0].shape();
  shape.set_dim(0, total_dim0_size);

  const DataType dtype = tensors[0].dtype();
  for (int i = 1; i < tensors.size(); ++i) {
    if (tensors[i].dtype() != dtype) {
      return errors::InvalidArgument(
          "Cannot concatenate tensors that have different data types");
    }
  }
  *result = Tensor(dtype, shape);

  // We use StringPiece as a convenient view over the tensor buffer, but
  // cast away constness to reach the underlying buffer for the copy.
  StringPiece to_data = result->tensor_data();

  if (DataTypeCanUseMemcpy(dtype)) {
    int64 offset = 0;
    for (const Tensor& tensor : tensors) {
      StringPiece from_data = tensor.tensor_data();
      CHECK_LE(offset + from_data.size(), to_data.size());
      memcpy(const_cast<char*>(to_data.data()) + offset, from_data.data(),
             from_data.size());

      offset += from_data.size();
    }
  } else {
    if (dtype != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    string* to_strings =
        reinterpret_cast<string*>(const_cast<char*>(to_data.data()));

    int64 offset = 0;
    for (const Tensor& tensor : tensors) {
      auto from_strings = tensor.flat<string>();
      CHECK_LE(offset + tensor.NumElements(), result->NumElements());
      for (int i = 0; i < tensor.NumElements(); ++i) {
        to_strings[offset + i] = from_strings(i);
      }

      offset += tensor.NumElements();
    }
  }

  return Status::OK();
}
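// Example usage (a minimal sketch): concatenating two matrices along the
// zeroth dimension. Inputs are expected to share dtype and trailing
// dimensions; only dimension 0 may differ:
//
//   Tensor a(DT_INT32, TensorShape({2, 3}));
//   Tensor b(DT_INT32, TensorShape({4, 3}));
//   Tensor out;
//   TF_CHECK_OK(tensor::Concat({a, b}, &out));  // out has shape {6, 3}.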

Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes,
             std::vector<Tensor>* result) {
  if (tensor.dims() == 0) {
    return errors::InvalidArgument("Cannot split a zero-dimensional tensor");
  }
  int64 total_size = 0;
  for (int64 size : sizes) {
    total_size += size;
  }
  if (total_size != tensor.dim_size(0)) {
    return errors::InvalidArgument(
        "The values in 'sizes' do not sum to the zeroth-dimension size of "
        "'tensor'");
  }

  StringPiece from_data = tensor.tensor_data();

  if (DataTypeCanUseMemcpy(tensor.dtype())) {
    int64 offset = 0;
    for (int64 size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor* split = &(*result)[result->size() - 1];

      // We use StringPiece as a convenient view over the tensor buffer, but
      // cast away constness to reach the underlying buffer for the copy.
      StringPiece to_data = split->tensor_data();
      CHECK_LE(offset + to_data.size(), from_data.size());
      memcpy(const_cast<char*>(to_data.data()), from_data.data() + offset,
             to_data.size());

      offset += to_data.size();
    }
  } else {
    if (tensor.dtype() != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    auto from_strings = tensor.flat<string>();

    int64 offset = 0;
    for (int64 size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor& split = (*result)[result->size() - 1];
      string* to_strings = reinterpret_cast<string*>(
          const_cast<char*>(split.tensor_data().data()));

      CHECK_LE(offset + split.NumElements(), tensor.NumElements());
      for (int i = 0; i < split.NumElements(); ++i) {
        to_strings[i] = from_strings(offset + i);
      }

      offset += split.NumElements();
    }
  }

  return Status::OK();
}
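// Example usage (a minimal sketch): splitting along dimension 0 is the
// inverse of Concat above, provided the sizes sum to dimension 0 of the
// input:
//
//   Tensor in(DT_INT32, TensorShape({6, 3}));
//   std::vector<Tensor> pieces;
//   TF_CHECK_OK(tensor::Split(in, {2, 4}, &pieces));
//   // pieces[0] has shape {2, 3}, pieces[1] has shape {4, 3}.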

namespace internal {
void SetTensorProtoShape(std::vector<size_t> shape,
                         TensorShapeProto* shape_proto) {
  for (auto dim : shape) {
    shape_proto->mutable_dim()->Add()->set_size(dim);
  }
}

template <typename T>
bool CompressTensorContent(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64 num_tensor_values = shape.num_elements();
  const int64 num_bytes = tensor->tensor_content().size();
  const int64 num_raw_values = num_bytes / sizeof(T);
  if (num_raw_values != num_tensor_values) {
    // Invalid or too small.
    return false;
  }
  int64 last_offset = num_bytes - 1;
  int64 prev_offset = last_offset - sizeof(T);
  // Inspect individual raw bytes sizeof(T) bytes apart, i.e. corresponding
  // bytes of adjacent elements, starting from the end, to find the last pair
  // of elements that differ.
  while (prev_offset >= 0) {
    if (tensor->tensor_content()[prev_offset] !=
        tensor->tensor_content()[last_offset]) {
      break;
    }
    --last_offset;
    --prev_offset;
  }
  // Round up to the next whole number of elements of type T.
  const int64 new_num_values = last_offset / sizeof(T) + 1;
  if (new_num_values * (is_complex<T>::value ? 2 : 1) * sizeof(FieldType) >
      static_cast<int64>(num_bytes / min_compression_ratio)) {
    return false;
  }
  // Copy values to the truncated repeated field.
  if (sizeof(FieldType) == sizeof(T)) {
    FieldType* dst_ptr =
        TypeHelper::AppendUninitialized(new_num_values, tensor);
    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(dst_ptr));
    tensor->clear_tensor_content();
  } else if (sizeof(T) > 1) {
    // Copy raw bytes to a temp array first, then cast.
    gtl::InlinedVector<T, 64> tmp(new_num_values);
    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(tmp.data()));
    tensor->clear_tensor_content();
    const T* begin = tmp.begin();
    const T* end = tmp.end();
    TypeHelper::AddValues(begin, end, tensor);
  } else {
    // Copy and cast, one byte at a time.
    for (int64 i = 0; i < new_num_values; ++i) {
      char c = tensor->tensor_content()[i];
      TypeHelper::AddValue(static_cast<T>(c), tensor);
    }
    tensor->clear_tensor_content();
  }
  return true;
}
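// For intuition (an illustrative example, not taken from a test): a DT_FLOAT
// tensor of shape {4} holding {2.0, 3.0, 3.0, 3.0} serialized into
// tensor_content (16 bytes) ends in a run of identical elements, so it can be
// re-encoded as the repeated field {2.0, 3.0}; the proto decoding path is
// expected to fill the remaining elements with the last stored value.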

template <typename T>
inline bool PackedValuesNotEqual(T a, T b) {
  return a != b;
}
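// The float/double specializations below compare bit patterns rather than
// values: with operator!=, NaN compares unequal to itself, so a trailing run
// of NaNs would never be detected as repeated; comparing the raw bits also
// keeps +0.0 and -0.0 distinct, so the compressed encoding stays lossless.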
template <>
inline bool PackedValuesNotEqual(float a, float b) {
  return reinterpret_cast<int32_t&>(a) != reinterpret_cast<int32_t&>(b);
}
template <>
inline bool PackedValuesNotEqual(double a, double b) {
  return reinterpret_cast<int64_t&>(a) != reinterpret_cast<int64_t&>(b);
}
template <typename RealType>
inline bool PackedValuesNotEqual(const std::complex<RealType>& a,
                                 const std::complex<RealType>& b) {
  return PackedValuesNotEqual(a.real(), b.real()) ||
         PackedValuesNotEqual(a.imag(), b.imag());
}

template <typename T>
bool CompressRepeatedField(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64 num_tensor_values = shape.num_elements();
  // Notice that for complex types the tensor is stored as an array of up to
  // 2 * num_tensor_values real values (real and imaginary parts), possibly
  // truncated.
  const int64 num_proto_values = TypeHelper::NumValues(*tensor);
  if (num_proto_values != num_tensor_values) {
    // Already compressed or invalid.
    return false;
  }
  const T last_value = TypeHelper::GetValue(num_proto_values - 1, *tensor);
  int64 last_index = 0;
  for (int64 i = num_proto_values - 2; i >= 0 && last_index == 0; --i) {
    const T cur_value = TypeHelper::GetValue(i, *tensor);
    if (PackedValuesNotEqual(cur_value, last_value)) {
      last_index = i + 1;
    }
  }
  const int64 num_truncated_proto_values = last_index + 1;
  const int64 num_bytes_as_field =
      num_truncated_proto_values * sizeof(FieldType);
  const int64 num_bytes_as_tensor_content = num_tensor_values * sizeof(T);
  const int64 num_bytes_before = num_proto_values * sizeof(FieldType);
  if (std::min(num_bytes_as_field, num_bytes_as_tensor_content) >
      static_cast<int64>(num_bytes_before / min_compression_ratio)) {
    return false;
  }
  if (num_bytes_as_field <= num_bytes_as_tensor_content) {
    TypeHelper::Truncate(num_truncated_proto_values, tensor);
  } else {
    gtl::InlinedVector<T, 64> tmp(num_tensor_values);
    TypeHelper::CopyValues(tmp.begin(), *tensor);
    TypeHelper::Truncate(0, tensor);
    port::CopyFromArray(tensor->mutable_tensor_content(),
                        reinterpret_cast<const char*>(tmp.data()),
                        num_bytes_as_tensor_content);
  }
  return true;
}
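// For intuition (an illustrative example): if the repeated field holds
// {7, 7, 1, 5, 5, 5}, the backward scan finds the last differing pair at
// index 3, so the field is truncated to {7, 7, 1, 5}, assuming the ratio
// check passes. When the truncated field would still be larger than packing
// all values into tensor_content, the tensor_content encoding is used
// instead.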

template <typename T>
bool CompressTensorProtoInPlaceImpl(int64 min_num_elements,
                                    float min_compression_ratio,
                                    TensorProto* tensor) {
  const TensorShape shape(tensor->tensor_shape());
  const int64 num_tensor_values = shape.num_elements();
  if (num_tensor_values < min_num_elements) {
    return false;
  }
  if (tensor->tensor_content().empty()) {
    return CompressRepeatedField<T>(min_compression_ratio, shape, tensor);
  } else {
    return CompressTensorContent<T>(min_compression_ratio, shape, tensor);
  }
  return true;
}

}  // namespace internal

#define HANDLE_COMPRESS_CASE(TF_TYPE)                                  \
  case TF_TYPE:                                                        \
    return internal::CompressTensorProtoInPlaceImpl<                   \
        EnumToDataType<TF_TYPE>::Type>(min_num_elements,               \
                                       min_compression_ratio, tensor); \
    break

bool CompressTensorProtoInPlace(int64 min_num_elements,
                                float min_compression_ratio,
                                TensorProto* tensor) {
  switch (tensor->dtype()) {
    HANDLE_COMPRESS_CASE(DT_FLOAT);
    HANDLE_COMPRESS_CASE(DT_DOUBLE);
    HANDLE_COMPRESS_CASE(DT_COMPLEX64);
    HANDLE_COMPRESS_CASE(DT_COMPLEX128);
    HANDLE_COMPRESS_CASE(DT_UINT8);
    HANDLE_COMPRESS_CASE(DT_INT8);
    HANDLE_COMPRESS_CASE(DT_UINT16);
    HANDLE_COMPRESS_CASE(DT_INT16);
    HANDLE_COMPRESS_CASE(DT_UINT32);
    HANDLE_COMPRESS_CASE(DT_INT32);
    HANDLE_COMPRESS_CASE(DT_UINT64);
    HANDLE_COMPRESS_CASE(DT_INT64);
    HANDLE_COMPRESS_CASE(DT_BOOL);
    HANDLE_COMPRESS_CASE(DT_QUINT8);
    HANDLE_COMPRESS_CASE(DT_QINT8);
    HANDLE_COMPRESS_CASE(DT_QUINT16);
    HANDLE_COMPRESS_CASE(DT_QINT16);
    HANDLE_COMPRESS_CASE(DT_QINT32);
    HANDLE_COMPRESS_CASE(DT_HALF);
    HANDLE_COMPRESS_CASE(DT_BFLOAT16);
    default:
      return false;
  }
}
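// Example usage (a minimal sketch): serializing a constant-filled tensor and
// compressing the proto in place. The thresholds shown are arbitrary choices
// for illustration, not defaults defined in this file:
//
//   Tensor t(DT_FLOAT, TensorShape({1024}));
//   t.flat<float>().setConstant(3.0f);
//   TensorProto proto;
//   t.AsProtoTensorContent(&proto);
//   // Replaces the 4 KiB tensor_content with a single repeated float_val
//   // entry, since the size reduction clears min_compression_ratio.
//   tensor::CompressTensorProtoInPlace(/*min_num_elements=*/64,
//                                      /*min_compression_ratio=*/2.0f, &proto);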

#undef HANDLE_COMPRESS_CASE

}  // namespace tensor
}  // namespace tensorflow