/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/batch_util.h"

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"

#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)

namespace tensorflow {
namespace batch_util {

namespace {

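// Verifies that `element` has the same number of elements as a single slice of
// `parent` along its 0th dimension, so that the copy routines below can treat
// the slice as a reshaped view of `element`.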
Status ValidateInput(const Tensor& parent, const Tensor& element, int64 index) {
  DCHECK_NE(parent.dim_size(0), 0);
  DCHECK_GE(index, 0);
  if (element.NumElements() != (parent.NumElements() / parent.dim_size(0))) {
    TensorShape chip_shape = parent.shape();
    chip_shape.RemoveDim(0);
    return errors::Internal(
        "ValidateInput Cannot perform copy: number of elements does not match. "
        " Shapes are: [element]: ",
        element.shape().DebugString(),
        ", [parent slice]: ", chip_shape.DebugString());
  }
  return Status::OK();
}

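// Default implementation: copies `element` into the `index`^th slice of
// `parent` with an Eigen chip assignment. The `can_move` hint is ignored
// because a move offers no benefit for plain-old-data element types.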
template <typename T>
Status HandleElementToSlice(Tensor element, Tensor* parent, int64 index,
                            bool /* can_move */) {
  parent->flat_outer_dims<T>().chip(index, 0) = element.flat<T>();
  return Status::OK();
}

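// Specialization for DT_STRING: when the caller holds the only reference to
// `element` (`can_move` is true), each string is moved into `parent` to avoid
// copying its buffer; otherwise the strings are copied via a chip assignment.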
template <>
Status HandleElementToSlice<string>(Tensor element, Tensor* parent, int64 index,
                                    bool can_move) {
  auto parent_as_matrix = parent->flat_outer_dims<string>();
  auto element_flat = element.flat<string>();
  if (can_move) {
    for (int64 i = 0; i < element.NumElements(); ++i) {
      parent_as_matrix(index, i) = std::move(element_flat(i));
    }
  } else {
    parent_as_matrix.chip(index, 0) = element_flat;
  }
  return Status::OK();
}

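// Specialization for DT_VARIANT: as with strings, the variant values are moved
// into `parent` when `can_move` is true and copied otherwise.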
template <>
Status HandleElementToSlice<Variant>(Tensor element, Tensor* parent,
                                     int64 index, bool can_move) {
  auto parent_as_matrix = parent->flat_outer_dims<Variant>();
  auto element_flat = element.flat<Variant>();
  if (can_move) {
    for (int64 i = 0; i < element.NumElements(); ++i) {
      parent_as_matrix(index, i) = std::move(element_flat(i));
    }
  } else {
    parent_as_matrix.chip(index, 0) = element_flat;
  }
  return Status::OK();
}

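// Copies the `index`^th slice of `parent` (in the 0th dimension) into
// `element` with an Eigen chip assignment.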
template <typename T>
static Status HandleSliceToElement(const Tensor& parent, Tensor* element,
                                   int64 index) {
  element->flat<T>() = parent.flat_outer_dims<T>().chip(index, 0);
  return Status::OK();
}

}  // namespace

// Copies element into the index^th slice of parent (in the 0th dimension).
Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index) {
  TF_RETURN_IF_ERROR(ValidateInput(*parent, element, index));

  bool can_move = element.RefCountIsOne();
#define HANDLE_TYPE(T)                                                \
  case DataTypeToEnum<T>::value: {                                    \
    return HandleElementToSlice<T>(std::move(element), parent, index, \
                                   can_move);                         \
  }

  switch (element.dtype()) {
    TF_CALL_ALL_TYPES(HANDLE_TYPE);
    TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented("CopyElementToSlice Unhandled data type: ",
                                   element.dtype());
  }
}
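
// Example use of CopyElementToSlice (an illustrative sketch; `elements` is
// assumed to be a caller-owned std::vector<Tensor> of 2x3 float tensors):
// each element is moved into the corresponding slice of a preallocated batch
// tensor, so the string/variant specializations above can move rather than
// copy when they apply.
//
//   Tensor batch(DT_FLOAT,
//                TensorShape({static_cast<int64>(elements.size()), 2, 3}));
//   for (int64 i = 0; i < batch.dim_size(0); ++i) {
//     TF_RETURN_IF_ERROR(
//         CopyElementToSlice(std::move(elements[i]), &batch, i));
//   }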

// Copies the index^th slice of parent (in the 0th dimension) into element.
Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) {
  TF_RETURN_IF_ERROR(ValidateInput(parent, *element, index));

#define HANDLE_TYPE(T)                                      \
  case DataTypeToEnum<T>::value: {                          \
    return HandleSliceToElement<T>(parent, element, index); \
  }

  switch (parent.dtype()) {
    TF_CALL_ALL_TYPES(HANDLE_TYPE);
    TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented("CopySliceToElement Unhandled data type: ",
                                   element->dtype());
  }
}
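
// Example use of CopySliceToElement (an illustrative sketch; `batch` is
// assumed to be a float tensor of shape [N, 2, 3]): the destination element
// must already be allocated with a matching dtype and per-slice element count.
//
//   Tensor element(DT_FLOAT, TensorShape({2, 3}));
//   TF_RETURN_IF_ERROR(CopySliceToElement(batch, &element, /*index=*/0));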

// The following five functions are copied from padding_fifo_queue.cc.
// TODO(mrry): Reconcile these functions with the similar methods in the
// queue implementation.
Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) {
  DCHECK_NE(parent->dim_size(0), 0);
  if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
    TensorShape chip_shape = parent->shape();
    chip_shape.RemoveDim(0);
    return errors::Internal(
        "HandleElementToLargerSlice Cannot copy slice: number of entries in "
        "element is greater than number of elements in parent slice.  ",
        "Shapes are: [element]: ", element.shape().DebugString(),
        ", [parent slice]: ", chip_shape.DebugString());
  }
  return Status::OK();
}

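// Copies `element` into the `index`^th slice of `parent`, where the slice may
// have larger dimensions than `element` (as in a padded batch). The element is
// written into the leading sub-block of the slice via an Eigen slice
// assignment; any remaining entries of the slice are left untouched.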
template <typename T, int NDIMS>
Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
                                  int index) {
  TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent));
  if (element.NumElements() == 0) {
    return Status::OK();
  }
  auto element_t = element.tensor<T, NDIMS>();
  auto parent_t = parent->tensor<T, NDIMS + 1>();
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
  slice_indices[0] = index;
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
  slice_size[0] = 1;
  for (size_t i = 1; i < slice_size.size(); ++i) {
    slice_size[i] = element_t.dimension(i - 1);
  }
  parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
  return Status::OK();
}

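// Dispatches HandleElementToLargerSlice on the element's data type for a
// statically known element rank NDIMS.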
template <int NDIMS>
Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent,
                                          int index) {
#define HANDLE_TYPE(T)                                                   \
  case DataTypeToEnum<T>::value: {                                       \
    return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \
  }

  switch (element.dtype()) {
    TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented(
          "HandleElementToLargerSliceWithRank Unhandled data type: ",
          element.dtype());
  }
}

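// Copies `element` into the (possibly larger) `index`^th slice of `parent`.
// The element's rank must be exactly one less than the parent's; element ranks
// of up to 4 are supported.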
Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent,
                                int index) {
  if (parent->dims() != element.dims() + 1) {
    return errors::Internal(
        "Mismatched ranks.  Element's rank is: ", element.dims(),
        " but element is meant to be a slice in output Tensor having rank: ",
        parent->dims(), " (should be: ", element.dims() + 1, ")");
  }

#define HANDLE_DIMS(NDIMS)                                                  \
  case NDIMS: {                                                             \
    TF_RETURN_IF_ERROR(                                                     \
        HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
    return Status::OK();                                                    \
  }

  switch (element.dims()) {
    HANDLE_DIMS(0);
    HANDLE_DIMS(1);
    HANDLE_DIMS(2);
    HANDLE_DIMS(3);
    HANDLE_DIMS(4);
#undef HANDLE_DIMS
    default:
      return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
                                   element.dims());
  }
}

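// Fills `element` with the scalar `padding` value, e.g. to initialize a padded
// batch before its slices are populated with CopyElementToLargerSlice.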
Status SetElementZero(Tensor* element, const Tensor& padding) {
#define HANDLE_TYPE(T)                                     \
  if (element->dtype() == DataTypeToEnum<T>::value) {      \
    element->flat<T>().setConstant(padding.scalar<T>()()); \
    return Status::OK();                                   \
  }
  TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
  return errors::Unimplemented("SetElementZero Unhandled data type: ",
                               element->dtype());
}
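
// Example of assembling a padded batch with the two functions above (an
// illustrative sketch; `elements` is assumed to be a std::vector<Tensor> of
// float tensors whose shapes fit within [2, 3], and `zero` a scalar float
// tensor holding the padding value):
//
//   Tensor padded(DT_FLOAT,
//                 TensorShape({static_cast<int64>(elements.size()), 2, 3}));
//   TF_RETURN_IF_ERROR(SetElementZero(&padded, zero));
//   for (int i = 0; i < padded.dim_size(0); ++i) {
//     TF_RETURN_IF_ERROR(CopyElementToLargerSlice(elements[i], &padded, i));
//   }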

}  // namespace batch_util
}  // namespace tensorflow