Home | History | Annotate | Download | only in util
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_
     17 #define TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_
     18 
     19 #include "tensorflow/core/framework/tensor_shape.h"
     20 #include "tensorflow/core/framework/tensor_slice.h"
     21 #include "tensorflow/core/platform/logging.h"
     22 
     23 namespace tensorflow {
     24 
     25 namespace {
     26 
     27 // Some hackery to invoke eigen tensor to copy over tensor slices with variable
     28 // dimension tensors.
     29 // TODO(yangke): get rid of that once the variable dimension tensor support is
     30 // in.
     31 static const int kTensorSliceMaxRank = 8;
     32 
     33 // Create a tensor map with the given shape: we support up to 8 dimensions. If
     34 // the shape has less than 8 dimensions, we pad the remaining dimension with 1.
     35 template <typename T>
     36 Eigen::TensorMap<Eigen::Tensor<T, kTensorSliceMaxRank, Eigen::RowMajor>>
     37 GetEigenTensorMapFromTensorShape(const TensorShape& shape, T* data) {
     38   Eigen::DSizes<Eigen::DenseIndex, kTensorSliceMaxRank> dsizes =
     39       shape.AsEigenDSizesWithPadding<kTensorSliceMaxRank>();
     40   Eigen::TensorMap<Eigen::Tensor<T, kTensorSliceMaxRank, Eigen::RowMajor>> eig(
     41       data, dsizes);
     42   return eig;
     43 }
     44 
     45 // For everything except string, a standard Eigen cast and assignment works
     46 template <typename DstT>
     47 struct CopyThatWorksWithStringPointer {
     48   template <typename SrcTensor, typename DstTensor, typename Shape>
     49   static void Copy(const SrcTensor& s, Shape s_start, Shape len, DstTensor& d,
     50                    Shape d_start) {
     51     d.slice(d_start, len) = s.slice(s_start, len).template cast<DstT>();
     52   }
     53 };
     54 
     55 // Eigen makes it extremely difficult to dereference a tensor of string* into
     56 // string, so we roll our own loop instead.
     57 template <>
     58 struct CopyThatWorksWithStringPointer<string> {
     59   template <typename SrcTensor, typename DstTensor, typename Shape>
     60   static void Copy(const SrcTensor& s, Shape s_start, Shape len, DstTensor& d,
     61                    Shape d_start) {
     62     typedef typename SrcTensor::Index Index;
     63     static_assert(kTensorSliceMaxRank == 8,
     64                   "If kTensorSliceMaxRank changes, modify the loop below.");
     65     for (Index i0 = 0; i0 < len[0]; i0++) {
     66       for (Index i1 = 0; i1 < len[1]; i1++) {
     67         for (Index i2 = 0; i2 < len[2]; i2++) {
     68           for (Index i3 = 0; i3 < len[3]; i3++) {
     69             for (Index i4 = 0; i4 < len[4]; i4++) {
     70               for (Index i5 = 0; i5 < len[5]; i5++) {
     71                 for (Index i6 = 0; i6 < len[6]; i6++) {
     72                   for (Index i7 = 0; i7 < len[7]; i7++) {
     73                     d(d_start[0] + i0, d_start[1] + i1, d_start[2] + i2,
     74                       d_start[3] + i3, d_start[4] + i4, d_start[5] + i5,
     75                       d_start[6] + i6, d_start[7] + i7) =
     76                         *s(s_start[0] + i0, s_start[1] + i1, s_start[2] + i2,
     77                            s_start[3] + i3, s_start[4] + i4, s_start[5] + i5,
     78                            s_start[6] + i6, s_start[7] + i7);
     79                   }
     80                 }
     81               }
     82             }
     83           }
     84         }
     85       }
     86     }
     87   }
     88 };
     89 
     90 // Checkpointing of half is done by storing the raw 16 bits as a signed 32bit
     91 // integer. To restore the checkpoint we need to do the reverse operation by
     92 // reinterpreting the integer as a 16 bit float. This prevents us from using
     93 // the default cast operation.
     94 template <>
     95 struct CopyThatWorksWithStringPointer<Eigen::half> {
     96   template <typename SrcTensor, typename DstTensor, typename Shape>
     97   static void Copy(const SrcTensor& s, Shape s_start, Shape len, DstTensor& d,
     98                    Shape d_start) {
     99     typedef typename SrcTensor::Index Index;
    100     static_assert(kTensorSliceMaxRank == 8,
    101                   "If kTensorSliceMaxRank changes, modify the loop below.");
    102     for (Index i0 = 0; i0 < len[0]; i0++) {
    103       for (Index i1 = 0; i1 < len[1]; i1++) {
    104         for (Index i2 = 0; i2 < len[2]; i2++) {
    105           for (Index i3 = 0; i3 < len[3]; i3++) {
    106             for (Index i4 = 0; i4 < len[4]; i4++) {
    107               for (Index i5 = 0; i5 < len[5]; i5++) {
    108                 for (Index i6 = 0; i6 < len[6]; i6++) {
    109                   for (Index i7 = 0; i7 < len[7]; i7++) {
    110                     d(d_start[0] + i0, d_start[1] + i1, d_start[2] + i2,
    111                       d_start[3] + i3, d_start[4] + i4, d_start[5] + i5,
    112                       d_start[6] + i6, d_start[7] + i7) =
    113                         Eigen::half_impl::raw_uint16_to_half(
    114                             s(s_start[0] + i0, s_start[1] + i1, s_start[2] + i2,
    115                               s_start[3] + i3, s_start[4] + i4, s_start[5] + i5,
    116                               s_start[6] + i6, s_start[7] + i7));
    117                   }
    118                 }
    119               }
    120             }
    121           }
    122         }
    123       }
    124     }
    125   }
    126 };
    127 
    128 // Given a tensor described by "shape", two slices "slice_s" and "slice_d",
    129 // and two pointers "ptr_s" and "ptr_d", where "ptr_s" points to a chunk of
    130 // memory that stores the data for "slice_s" and "ptr_d" points to a chunk of
    131 // memory that stores the data for "slice_d". This function copies the data
    132 // that belongs to the intersection of the two slices from slice_s to
    133 // slice_d.  Uses Tensor cast<DstT>() to convert from SrcT to DstT. Returns true
    134 // iff the two slices share any intersection (and thus some data is copied).
    135 // TODO(yangke): figure out if we can make it private.
    136 template <typename SrcT, typename DstT>
    137 static bool CopyDataFromTensorSliceToTensorSlice(const TensorShape& shape,
    138                                                  const TensorSlice& slice_s,
    139                                                  const TensorSlice& slice_d,
    140                                                  const SrcT* ptr_s,
    141                                                  DstT* ptr_d) {
    142   CHECK_LE(shape.dims(), kTensorSliceMaxRank)
    143       << "Only tensors of size up to " << kTensorSliceMaxRank
    144       << " are supported";
    145   // We need to compute the intersection of the two slices.
    146   TensorSlice inter;
    147   if (!slice_s.Intersect(slice_d, &inter)) {
    148     // There is no intersection: returns false.
    149     return false;
    150   } else {
    151     // We need to compute the applied shapes after applying slice_s and
    152     // slice_d.
    153     TensorShape shp_s, shp_d;
    154     Status s;
    155     s = slice_s.SliceTensorShape(shape, &shp_s);
    156     if (!s.ok()) {
    157       LOG(WARNING) << s;
    158       return false;
    159     }
    160     s = slice_d.SliceTensorShape(shape, &shp_d);
    161     if (!s.ok()) {
    162       LOG(WARNING) << s;
    163       return false;
    164     }
    165 
    166     // We need to compute the relative slice of "inter" w.r.t. both slice_s and
    167     // slice_d.
    168     TensorSlice rel_s, rel_d;
    169     slice_s.ComputeRelative(inter, &rel_s);
    170     slice_d.ComputeRelative(inter, &rel_d);
    171 
    172     // Get the eigen tensor maps to the data.
    173     auto t_s = GetEigenTensorMapFromTensorShape(shp_s, ptr_s);
    174     auto t_d = GetEigenTensorMapFromTensorShape(shp_d, ptr_d);
    175 
    176     Eigen::DSizes<Eigen::DenseIndex, kTensorSliceMaxRank> s_start, s_len,
    177         d_start, d_len;
    178 
    179     rel_s.FillIndicesAndSizes<kTensorSliceMaxRank>(shp_s, &s_start, &s_len);
    180     rel_d.FillIndicesAndSizes<kTensorSliceMaxRank>(shp_d, &d_start, &d_len);
    181     CopyThatWorksWithStringPointer<DstT>::Copy(t_s, s_start, s_len, t_d,
    182                                                d_start);
    183     return true;
    184   }
    185 }
    186 
    187 }  // namespace
    188 
    189 }  // namespace tensorflow
    190 
    191 #endif  // TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_
    192