1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/tensor_slice.h" 17 #include <vector> 18 #include "tensorflow/core/lib/core/errors.h" 19 #include "tensorflow/core/lib/strings/numbers.h" 20 #include "tensorflow/core/lib/strings/str_util.h" 21 #include "tensorflow/core/lib/strings/strcat.h" 22 #include "tensorflow/core/platform/logging.h" 23 24 namespace tensorflow { 25 26 TensorSlice::TensorSlice(int dim) { SetFullSlice(dim); } 27 28 TensorSlice::TensorSlice(const TensorSliceProto& proto) { 29 starts_.reserve(proto.extent_size()); 30 lengths_.reserve(proto.extent_size()); 31 for (const auto& e : proto.extent()) { 32 starts_.push_back(e.start()); 33 lengths_.push_back(GetExtentLength(e)); 34 } 35 } 36 37 TensorSlice::TensorSlice( 38 std::initializer_list<std::pair<int64, int64>> extents) { 39 starts_.reserve(extents.size()); 40 lengths_.reserve(extents.size()); 41 for (const auto& e : extents) { 42 starts_.push_back(e.first); 43 lengths_.push_back(e.second); 44 } 45 } 46 47 Status TensorSlice::Parse(const string& str, TensorSlice* slice) { 48 std::vector<string> items = str_util::Split(str, ':', str_util::SkipEmpty()); 49 slice->starts_.reserve(items.size()); 50 slice->lengths_.reserve(items.size()); 51 for (const string& x : items) { 52 int64 s, l; 53 if (x == "-") { 54 // "everything" 55 s = 0; 56 l = kFullExtent; 57 } else { 58 std::vector<string> sl = str_util::Split(x, ',', str_util::SkipEmpty()); 59 if (sl.size() != 2 || !strings::safe_strto64(sl[0], &s) || 60 !strings::safe_strto64(sl[1], &l)) { 61 return errors::InvalidArgument( 62 "Expected a pair of numbers or '-' " 63 "but got '", 64 x, "': string = ", str); 65 } 66 if (s < 0 || l <= 0) { 67 return errors::InvalidArgument( 68 "Expected non-negative start and " 69 "positive length but got start = ", 70 s, ", length = ", l, ": string = ", str); 71 } 72 } 73 slice->starts_.push_back(s); 74 slice->lengths_.push_back(l); 75 } 76 77 return Status::OK(); 78 } 79 80 void TensorSlice::Clear() { 81 starts_.clear(); 82 lengths_.clear(); 83 } 84 85 bool TensorSlice::IsFull() const { 86 for (int d = 0; d < dims(); ++d) { 87 if (!IsFullAt(d)) return false; 88 } 89 return true; 90 } 91 92 void TensorSlice::SetFullSlice(int dim) { 93 Clear(); 94 starts_.reserve(dim); 95 lengths_.reserve(dim); 96 for (int d = 0; d < dim; ++d) { 97 starts_.push_back(0); 98 lengths_.push_back(kFullExtent); 99 } 100 } 101 102 void TensorSlice::Extend(int dim) { 103 int old_dim = dims(); 104 DCHECK_LE(old_dim, dim); 105 starts_.resize(dim); 106 lengths_.resize(dim); 107 for (int d = old_dim; d < dim; ++d) { 108 starts_[d] = 0; 109 lengths_[d] = kFullExtent; 110 } 111 } 112 113 void TensorSlice::AsProto(TensorSliceProto* proto) const { 114 for (int d = 0; d < dims(); ++d) { 115 TensorSliceProto::Extent* e = proto->add_extent(); 116 // We only need to record the explicit slice for non-full slices 117 if (!IsFullAt(d)) { 118 e->set_start(starts_[d]); 119 e->set_length(lengths_[d]); 120 } 121 } 122 } 123 124 string TensorSlice::DebugString() const { 125 string buffer; 126 bool first = true; 127 for (int d = 0; d < dims(); ++d) { 128 if (!first) { 129 buffer.append(":"); 130 } 131 string s; 132 if (IsFullAt(d)) { 133 buffer.append("-"); 134 } else { 135 strings::StrAppend(&buffer, starts_[d], ",", lengths_[d]); 136 } 137 first = false; 138 } 139 return buffer; 140 } 141 142 bool TensorSlice::Intersect(const TensorSlice& other, 143 TensorSlice* result) const { 144 // First, if two slices have different ranks, they obviously don't overlap 145 // -- in fact they are not compatible. 146 if (dims() != other.dims()) { 147 return false; 148 } 149 150 // Setting the result to the right dimension 151 if (result) { 152 result->SetFullSlice(dims()); 153 } 154 // The two slices overlap if they overlap in all dimensions. 155 for (int d = 0; d < dims(); ++d) { 156 if (IsFullAt(d)) { 157 if (result) { 158 result->set_start(d, other.start(d)); 159 result->set_length(d, other.length(d)); 160 } 161 } else if (other.IsFullAt(d)) { 162 if (result) { 163 result->set_start(d, start(d)); 164 result->set_length(d, length(d)); 165 } 166 } else { 167 // If we have an intersection here, it should have a start that is the 168 // max of the two starts and an end that is the min of the two ends. 169 int64 s = std::max(start(d), other.start(d)); 170 int64 l = std::min(end(d), other.end(d)) - s; 171 if (l > 0) { 172 // We have a real intersection 173 if (result) { 174 result->set_start(d, s); 175 result->set_length(d, l); 176 } 177 } else { 178 // We don't have an intersection for this dimension -- thus we don't 179 // have any intersection at all. 180 if (result) { 181 result->Clear(); 182 } 183 return false; 184 } 185 } 186 } 187 // If we are here, we know there is overlap in every dimension. 188 return true; 189 } 190 191 bool TensorSlice::operator==(const TensorSlice& other) const { 192 return dims() == other.dims() && starts_ == other.starts_ && 193 lengths_ == other.lengths_; 194 } 195 196 void TensorSlice::ComputeRelative(const TensorSlice& sub, 197 TensorSlice* relative) const { 198 DCHECK_EQ(dims(), sub.dims()); 199 relative->SetFullSlice(dims()); 200 for (int d = 0; d < dims(); ++d) { 201 if (IsFullAt(d)) { 202 relative->set_start(d, sub.start(d)); 203 relative->set_length(d, sub.length(d)); 204 } else { 205 // Otherwise the relative start is the difference between the start of 206 // sub and the start of base 207 relative->set_start(d, sub.start(d) - start(d)); 208 relative->set_length(d, sub.length(d)); 209 } 210 } 211 } 212 213 void TensorSlice::UpdateToCover(const TensorSlice& other) { 214 DCHECK_EQ(dims(), other.dims()); 215 for (int d = 0; d < dims(); ++d) { 216 if (!IsFullAt(d)) { 217 if (other.IsFullAt(d)) { 218 starts_[d] = 0; 219 lengths_[d] = kFullExtent; 220 } else { 221 const auto new_end = std::max(end(d), other.end(d)); 222 set_start(d, std::min(start(d), other.start(d))); 223 set_length(d, new_end - start(d)); 224 } 225 } 226 } 227 } 228 229 // static 230 bool TensorSlice::HasExtentLength(const TensorSliceProto::Extent& extent) { 231 return extent.has_length_case() == TensorSliceProto::Extent::kLength; 232 } 233 234 // static 235 int64 TensorSlice::GetExtentLength(const TensorSliceProto::Extent& extent) { 236 if (!HasExtentLength(extent)) return -1; 237 return extent.length(); 238 } 239 240 Status TensorSlice::SliceTensorShape(const TensorShape& shape, 241 TensorShape* result_shape) const { 242 result_shape->Clear(); 243 // Mismatching ranks: we can't apply the slice at all. 244 if (shape.dims() != dims()) { 245 return errors::Internal("Mismatching ranks: shape = ", shape.DebugString(), 246 ", slice = ", DebugString()); 247 } 248 for (int d = 0; d < dims(); ++d) { 249 if (IsFullAt(d)) { 250 result_shape->AddDim(shape.dim_size(d)); 251 } else { 252 // Check if the extent applies to the dimension 253 if (end(d) <= shape.dim_size(d)) { 254 // Yes: the end is within the range of the dim -- we adjust the result 255 // shape so that its size along this dimension is the length of the 256 // slice. 257 result_shape->AddDim(length(d)); 258 } else { 259 // The extent doesn't apply to the dimension 260 result_shape->Clear(); 261 return errors::Internal("Extent in dimension ", d, 262 " out of bounds: shape = ", shape.DebugString(), 263 ", slice = ", DebugString()); 264 } 265 } 266 } 267 // If we are here, we have successfully applied the shape. 268 return Status::OK(); 269 } 270 271 const int64 TensorSlice::kFullExtent = -1; 272 273 } // namespace tensorflow 274