Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/tensor_slice.h"
     17 #include <vector>
     18 #include "tensorflow/core/lib/core/errors.h"
     19 #include "tensorflow/core/lib/strings/numbers.h"
     20 #include "tensorflow/core/lib/strings/str_util.h"
     21 #include "tensorflow/core/lib/strings/strcat.h"
     22 #include "tensorflow/core/platform/logging.h"
     23 
     24 namespace tensorflow {
     25 
// Constructs a slice of rank `dim` covering the full extent of every
// dimension (start 0, length kFullExtent in each).
TensorSlice::TensorSlice(int dim) { SetFullSlice(dim); }
     27 
     28 TensorSlice::TensorSlice(const TensorSliceProto& proto) {
     29   starts_.reserve(proto.extent_size());
     30   lengths_.reserve(proto.extent_size());
     31   for (const auto& e : proto.extent()) {
     32     starts_.push_back(e.start());
     33     lengths_.push_back(GetExtentLength(e));
     34   }
     35 }
     36 
     37 TensorSlice::TensorSlice(
     38     std::initializer_list<std::pair<int64, int64>> extents) {
     39   starts_.reserve(extents.size());
     40   lengths_.reserve(extents.size());
     41   for (const auto& e : extents) {
     42     starts_.push_back(e.first);
     43     lengths_.push_back(e.second);
     44   }
     45 }
     46 
     47 Status TensorSlice::Parse(const string& str, TensorSlice* slice) {
     48   std::vector<string> items = str_util::Split(str, ':', str_util::SkipEmpty());
     49   slice->starts_.reserve(items.size());
     50   slice->lengths_.reserve(items.size());
     51   for (const string& x : items) {
     52     int64 s, l;
     53     if (x == "-") {
     54       // "everything"
     55       s = 0;
     56       l = kFullExtent;
     57     } else {
     58       std::vector<string> sl = str_util::Split(x, ',', str_util::SkipEmpty());
     59       if (sl.size() != 2 || !strings::safe_strto64(sl[0], &s) ||
     60           !strings::safe_strto64(sl[1], &l)) {
     61         return errors::InvalidArgument(
     62             "Expected a pair of numbers or '-' "
     63             "but got '",
     64             x, "': string = ", str);
     65       }
     66       if (s < 0 || l <= 0) {
     67         return errors::InvalidArgument(
     68             "Expected non-negative start and "
     69             "positive length but got start = ",
     70             s, ", length = ", l, ": string = ", str);
     71       }
     72     }
     73     slice->starts_.push_back(s);
     74     slice->lengths_.push_back(l);
     75   }
     76 
     77   return Status::OK();
     78 }
     79 
     80 void TensorSlice::Clear() {
     81   starts_.clear();
     82   lengths_.clear();
     83 }
     84 
     85 bool TensorSlice::IsFull() const {
     86   for (int d = 0; d < dims(); ++d) {
     87     if (!IsFullAt(d)) return false;
     88   }
     89   return true;
     90 }
     91 
     92 void TensorSlice::SetFullSlice(int dim) {
     93   Clear();
     94   starts_.reserve(dim);
     95   lengths_.reserve(dim);
     96   for (int d = 0; d < dim; ++d) {
     97     starts_.push_back(0);
     98     lengths_.push_back(kFullExtent);
     99   }
    100 }
    101 
    102 void TensorSlice::Extend(int dim) {
    103   int old_dim = dims();
    104   DCHECK_LE(old_dim, dim);
    105   starts_.resize(dim);
    106   lengths_.resize(dim);
    107   for (int d = old_dim; d < dim; ++d) {
    108     starts_[d] = 0;
    109     lengths_[d] = kFullExtent;
    110   }
    111 }
    112 
    113 void TensorSlice::AsProto(TensorSliceProto* proto) const {
    114   for (int d = 0; d < dims(); ++d) {
    115     TensorSliceProto::Extent* e = proto->add_extent();
    116     // We only need to record the explicit slice for non-full slices
    117     if (!IsFullAt(d)) {
    118       e->set_start(starts_[d]);
    119       e->set_length(lengths_[d]);
    120     }
    121   }
    122 }
    123 
    124 string TensorSlice::DebugString() const {
    125   string buffer;
    126   bool first = true;
    127   for (int d = 0; d < dims(); ++d) {
    128     if (!first) {
    129       buffer.append(":");
    130     }
    131     string s;
    132     if (IsFullAt(d)) {
    133       buffer.append("-");
    134     } else {
    135       strings::StrAppend(&buffer, starts_[d], ",", lengths_[d]);
    136     }
    137     first = false;
    138   }
    139   return buffer;
    140 }
    141 
    142 bool TensorSlice::Intersect(const TensorSlice& other,
    143                             TensorSlice* result) const {
    144   // First, if two slices have different ranks, they obviously don't overlap
    145   // -- in fact they are not compatible.
    146   if (dims() != other.dims()) {
    147     return false;
    148   }
    149 
    150   // Setting the result to the right dimension
    151   if (result) {
    152     result->SetFullSlice(dims());
    153   }
    154   // The two slices overlap if they overlap in all dimensions.
    155   for (int d = 0; d < dims(); ++d) {
    156     if (IsFullAt(d)) {
    157       if (result) {
    158         result->set_start(d, other.start(d));
    159         result->set_length(d, other.length(d));
    160       }
    161     } else if (other.IsFullAt(d)) {
    162       if (result) {
    163         result->set_start(d, start(d));
    164         result->set_length(d, length(d));
    165       }
    166     } else {
    167       // If we have an intersection here, it should have a start that is the
    168       // max of the two starts and an end that is the min of the two ends.
    169       int64 s = std::max(start(d), other.start(d));
    170       int64 l = std::min(end(d), other.end(d)) - s;
    171       if (l > 0) {
    172         // We have a real intersection
    173         if (result) {
    174           result->set_start(d, s);
    175           result->set_length(d, l);
    176         }
    177       } else {
    178         // We don't have an intersection for this dimension -- thus we don't
    179         // have any intersection at all.
    180         if (result) {
    181           result->Clear();
    182         }
    183         return false;
    184       }
    185     }
    186   }
    187   // If we are here, we know there is overlap in every dimension.
    188   return true;
    189 }
    190 
// Two slices are equal iff they have the same rank and identical
// (start, length) pairs in every dimension; kFullExtent lengths compare
// like any other value, so "full" only equals "full".
bool TensorSlice::operator==(const TensorSlice& other) const {
  return dims() == other.dims() && starts_ == other.starts_ &&
         lengths_ == other.lengths_;
}
    195 
    196 void TensorSlice::ComputeRelative(const TensorSlice& sub,
    197                                   TensorSlice* relative) const {
    198   DCHECK_EQ(dims(), sub.dims());
    199   relative->SetFullSlice(dims());
    200   for (int d = 0; d < dims(); ++d) {
    201     if (IsFullAt(d)) {
    202       relative->set_start(d, sub.start(d));
    203       relative->set_length(d, sub.length(d));
    204     } else {
    205       // Otherwise the relative start is the difference between the start of
    206       // sub and the start of base
    207       relative->set_start(d, sub.start(d) - start(d));
    208       relative->set_length(d, sub.length(d));
    209     }
    210   }
    211 }
    212 
// Grows each dimension of *this just enough to also cover `other`'s
// extent in that dimension (a per-dimension bounding union, which may
// include points that lie in neither original slice).  Ranks must
// match.
void TensorSlice::UpdateToCover(const TensorSlice& other) {
  DCHECK_EQ(dims(), other.dims());
  for (int d = 0; d < dims(); ++d) {
    // A dimension that is already full cannot grow any further.
    if (!IsFullAt(d)) {
      if (other.IsFullAt(d)) {
        starts_[d] = 0;
        lengths_[d] = kFullExtent;
      } else {
        // Compute the union's end before mutating start: the length set
        // below is measured from the (possibly lowered) new start.
        const auto new_end = std::max(end(d), other.end(d));
        set_start(d, std::min(start(d), other.start(d)));
        set_length(d, new_end - start(d));
      }
    }
  }
}
    228 
// static
// Returns true iff `extent` carries an explicit length, i.e. its proto
// oneof is set to the `length` field.  An unset oneof is how the proto
// encodes a full extent (see GetExtentLength / AsProto).
bool TensorSlice::HasExtentLength(const TensorSliceProto::Extent& extent) {
  return extent.has_length_case() == TensorSliceProto::Extent::kLength;
}
    233 
    234 // static
    235 int64 TensorSlice::GetExtentLength(const TensorSliceProto::Extent& extent) {
    236   if (!HasExtentLength(extent)) return -1;
    237   return extent.length();
    238 }
    239 
    240 Status TensorSlice::SliceTensorShape(const TensorShape& shape,
    241                                      TensorShape* result_shape) const {
    242   result_shape->Clear();
    243   // Mismatching ranks: we can't apply the slice at all.
    244   if (shape.dims() != dims()) {
    245     return errors::Internal("Mismatching ranks: shape = ", shape.DebugString(),
    246                             ", slice = ", DebugString());
    247   }
    248   for (int d = 0; d < dims(); ++d) {
    249     if (IsFullAt(d)) {
    250       result_shape->AddDim(shape.dim_size(d));
    251     } else {
    252       // Check if the extent applies to the dimension
    253       if (end(d) <= shape.dim_size(d)) {
    254         // Yes: the end is within the range of the dim -- we adjust the result
    255         // shape so that its size along this dimension is the length of the
    256         // slice.
    257         result_shape->AddDim(length(d));
    258       } else {
    259         // The extent doesn't apply to the dimension
    260         result_shape->Clear();
    261         return errors::Internal("Extent in dimension ", d,
    262                                 " out of bounds: shape = ", shape.DebugString(),
    263                                 ", slice = ", DebugString());
    264       }
    265     }
    266   }
    267   // If we are here, we have successfully applied the shape.
    268   return Status::OK();
    269 }
    270 
// Sentinel length meaning "the full extent of this dimension".
const int64 TensorSlice::kFullExtent = -1;
    272 
    273 }  // namespace tensorflow
    274