Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
     17 #define TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
     18 
     19 #include <string>
     20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     21 #include "tensorflow/core/framework/tensor_shape.h"
     22 #include "tensorflow/core/framework/tensor_slice.pb.h"
     23 #include "tensorflow/core/lib/core/status.h"
     24 #include "tensorflow/core/lib/core/stringpiece.h"
     25 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     26 #include "tensorflow/core/platform/logging.h"
     27 
     28 namespace tensorflow {
     29 
     30 // A tensor slice represents a slice of a given tensor. It is represented by a
     31 // list of (start, length) pairs, where the size of the list is the rank of the
     32 // tensor.
     33 
     34 class TensorSlice {
     35  public:
     36   // Construct a tensor slice: you have a number of ways:
     37   // -- creating an empty slice
     38   // -- from just a dimension (in this case it will create a full slice)
     39   // -- from an array of pairs of integers.
     40   // -- from a TensorSliceProto protocol buffer
     41   // -- from a string format of "start,length:start,length..." where each
     42   //    "start,length" pair represents the slice on one dimension. We allow a
     43   //    special "-" that means "everything for this dimension". One such example
     44   //    is:  0,10:-:14,1:-:-
     45   TensorSlice() {}
     46   explicit TensorSlice(int dim);
     47   explicit TensorSlice(const TensorSliceProto& proto);
     48   explicit TensorSlice(std::initializer_list<std::pair<int64, int64>> extents);
     49 
     50   static Status Parse(const string& str, TensorSlice* output);
     51   static TensorSlice ParseOrDie(const string& str) {
     52     TensorSlice ret;
     53     Status s = Parse(str, &ret);
     54     if (!s.ok()) {
     55       LOG(FATAL) << "Could not parse TensorSlice";
     56     }
     57     return ret;
     58   }
     59 
     60   void Clear();
     61 
     62   // Accessors
     63   int dims() const { return starts_.size(); }
     64 
     65   int64 start(int d) const {
     66     DCHECK_GE(d, 0);
     67     DCHECK_LT(d, dims());
     68     return starts_[d];
     69   }
     70 
     71   int64 length(int d) const {
     72     DCHECK_GE(d, 0);
     73     DCHECK_LT(d, dims());
     74     return lengths_[d];
     75   }
     76 
     77   int64 end(int d) const {
     78     DCHECK_GE(d, 0);
     79     DCHECK_LT(d, dims());
     80     return start(d) + length(d);
     81   }
     82 
     83   void set_start(int d, int64 x) {
     84     DCHECK_GE(d, 0);
     85     DCHECK_LT(d, dims());
     86     DCHECK_GE(x, 0);
     87     starts_[d] = x;
     88   }
     89 
     90   void set_length(int d, int64 x) {
     91     DCHECK_GE(d, 0);
     92     DCHECK_LT(d, dims());
     93     lengths_[d] = x;
     94   }
     95 
     96   // If we have a full slice along dimension "d".
     97   bool IsFullAt(int d) const {
     98     return lengths_[d] == kFullExtent && starts_[d] == 0;
     99   }
    100 
    101   // If this is a full slice, i.e. IsFullAt(d) for every d.
    102   bool IsFull() const;
    103 
    104   // Set the slice to be a full slice of "dim" dimensions
    105   void SetFullSlice(int dim);
    106 
    107   // Extend a slice to "dim" dimensions: all the added dimensions are full.
    108   // Requires: dim >= dims().
    109   void Extend(int dim);
    110 
    111   // Conversion of a TensorSlice to other formats
    112   void AsProto(TensorSliceProto* proto) const;
    113   string DebugString() const;
    114 
    115   // Fill *indices and *sizes from *this (so that we can use the slice()
    116   // function in eigen tensor). We need a tensor shape in case some of the
    117   // slices are full slices.
    118   // We allow NDIMS to be greater than dims(), in which case we will pad the
    119   // higher dimensions with trivial dimensions.
    120   template <int NDIMS>
    121   void FillIndicesAndSizes(
    122       const TensorShape& shape,
    123       Eigen::DSizes<Eigen::DenseIndex, NDIMS>* indices,
    124       Eigen::DSizes<Eigen::DenseIndex, NDIMS>* sizes) const;
    125 
    126   // Interaction with other TensorSlices.
    127 
    128   // Compute the intersection with another slice and if "result" is not
    129   // nullptr, store the results in *result; returns true if there is any real
    130   // intersection.
    131   bool Intersect(const TensorSlice& other, TensorSlice* result) const;
    132   // A short hand.
    133   bool Overlaps(const TensorSlice& other) const {
    134     return Intersect(other, nullptr);
    135   }
    136 
    137   // Equals iff "*this" and "other" are logically equivalent.
    138   bool operator==(const TensorSlice& other) const;
    139   bool operator!=(const TensorSlice& other) const { return !(*this == other); }
    140 
    141   // Interaction with TensorShape.
    142 
    143   // Slices a shape and stores the result into *result_shape.
    144   // Requires that the shape and *this have the same rank.
    145   // For example, given a tensor shape of {3, 4, 5}, and a slice of
    146   // 1,2:-:0,2, the result shape is {2, 4, 2}.
    147   Status SliceTensorShape(const TensorShape& shape,
    148                           TensorShape* result_shape) const;
    149 
    150   // Given slice "sub" where "sub" is fully contained in *this,
    151   // (meaning that the intersection of "sub" and *this equals "sub"), computes
    152   // the "relative" slice of "sub" with respect to *this.
    153   //
    154   // In other words, if we use A>S to denote slicing a shape S with a slice A,
    155   // then the function is computing a slice X such that:
    156   //   X > (this > S) = sub > S
    157   // for any shape S.
    158   //
    159   // In general, along every dimension, the start of the relative slice is the
    160   // start of the "sub" slice minus the start of *this; the length of the
    161   // relative slice is the length of the "sub" slice.
    162   //
    163   // For example, say we have a shape of {3, 4, 5}, "this" is 0,2:-:1,2, and
    164   // "sub" is 1,1:2:2,1,2, then the related slice is 1,1:2,2:0,2.
    165   //
    166   // The caller needs to make sure that "sub" is indeed a sub-slice of *this;
    167   // otherwise the result is undefined.
    168   void ComputeRelative(const TensorSlice& sub, TensorSlice* relative) const;
    169 
    170   // Updates the slice in such a way that it fully covers "other" slice.
    171   // Note, "other" slice should refer to the same tensor shape.
    172   // Example:
    173   //   given a slice [2:4, :, 3:] and "other" slice [:, 1:4, 2:4] the
    174   //   updated slice would be [:, :, 2:]. Here is why:
    175   //   dim 0: "2:4"  U  ":"    ->  ":"
    176   //   dim 1: ":"    U  "1-4"  ->  ":"
    177   //   dim 2: "3:"   U  "2:4"  ->  "2:"
    178   void UpdateToCover(const TensorSlice& other);
    179 
    180   // Returns true if the length field was specified in an Extent.
    181   static bool HasExtentLength(const TensorSliceProto::Extent& extent);
    182 
    183   // Returns the value of the length field in an Extent, or -1 if it
    184   // is not present.
    185   static int64 GetExtentLength(const TensorSliceProto::Extent& extent);
    186 
    187  private:
    188   // a length value of kFullExtent (-1) means we have a full slice at this
    189   // dimension. It's defined in tensor_slice.cc.
    190   static const int64 kFullExtent;
    191 
    192   // TODO(yangke): switch to Eigen once it supports variable size arrays.
    193   // A value of
    194   gtl::InlinedVector<int64, 4> starts_;
    195   gtl::InlinedVector<int64, 4> lengths_;
    196 };
    197 
    198 template <int NDIMS>
    199 void TensorSlice::FillIndicesAndSizes(
    200     const TensorShape& shape, Eigen::DSizes<Eigen::DenseIndex, NDIMS>* indices,
    201     Eigen::DSizes<Eigen::DenseIndex, NDIMS>* sizes) const {
    202   CHECK_EQ(shape.dims(), dims()) << "Incompatible dimensions between shape "
    203                                  << "slices: shape = " << shape.DebugString()
    204                                  << ", slice = " << DebugString();
    205   CHECK_GE(NDIMS, dims()) << "Asking for a " << NDIMS << "-dim slice from "
    206                           << "a slice of dimension " << dims();
    207   for (int d = 0; d < dims(); ++d) {
    208     if (IsFullAt(d)) {
    209       (*indices)[d] = 0;
    210       (*sizes)[d] = shape.dim_size(d);
    211     } else {
    212       (*indices)[d] = starts_[d];
    213       (*sizes)[d] = lengths_[d];
    214     }
    215   }
    216   for (int d = dims(); d < NDIMS; ++d) {
    217     (*indices)[d] = 0;
    218     (*sizes)[d] = 1;
    219   }
    220 }
    221 
    222 }  // namespace tensorflow
    223 
    224 #endif  // TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
    225