1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/tensor_shape.h" 17 18 #include "tensorflow/core/framework/tensor_shape.pb.h" 19 #include "tensorflow/core/kernels/bounds_check.h" 20 #include "tensorflow/core/lib/core/errors.h" 21 #include "tensorflow/core/lib/strings/str_util.h" 22 #include "tensorflow/core/lib/strings/strcat.h" 23 #include "tensorflow/core/platform/logging.h" 24 #include "tensorflow/core/util/overflow.h" 25 26 namespace tensorflow { 27 28 // TensorShape and PartialTensorShape should have no fields beyond 29 // TensorShapeRep. In particular, their sizes should be the same. 30 static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape), 31 "TensorShape must have no fields beyond TensorShapeRep"); 32 static_assert(sizeof(TensorShapeRep) == sizeof(PartialTensorShape), 33 "PartialTensorShape must have no fields beyond TensorShapeRep"); 34 35 template <class Shape> 36 static void AppendTo(const TensorShapeBase<Shape>& s, 37 gtl::InlinedVector<int64, 8>* vals) { 38 for (auto dim : s) { 39 vals->push_back(dim.size); 40 } 41 } 42 43 void TensorShape::CheckDimsEqual(int NDIMS) const { 44 CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS << " dimensions" 45 << " from a tensor of " << dims() << " dimensions"; 46 } 47 48 void TensorShape::CheckDimsAtLeast(int NDIMS) const { 49 CHECK_GE(NDIMS, dims()) << "Asking for tensor of at least " << NDIMS 50 << " dimensions from a tensor of " << dims() 51 << " dimensions"; 52 } 53 54 template <class Shape> 55 bool TensorShapeBase<Shape>::IsValid(const TensorShapeProto& proto) { 56 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with 57 // unknown_shape() set, and it seems hard to remove this without backwards 58 // compatibility issues. 59 if (kIsPartial && proto.unknown_rank()) return proto.dim_size() == 0; 60 int64 num_elements = 1; 61 if (proto.dim().size() > MaxDimensions()) return false; 62 for (const auto& d : proto.dim()) { 63 if (d.size() < (kIsPartial ? -1 : 0)) return false; 64 if (d.size() == -1) { 65 num_elements = -1; 66 } else if (!kIsPartial || num_elements >= 0) { 67 num_elements = MultiplyWithoutOverflow(num_elements, d.size()); 68 if (num_elements < 0) return false; 69 } 70 } 71 return true; 72 } 73 74 template <class Shape> 75 Status TensorShapeBase<Shape>::IsValidShape(const TensorShapeProto& proto) { 76 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with 77 // unknown_shape() set, and it seems hard to remove this without backwards 78 // compatibility issues. 79 if (kIsPartial && proto.unknown_rank()) { 80 if (proto.dim_size() > 0) { 81 return errors::InvalidArgument( 82 "An unknown shape must not have any dimensions set."); 83 } 84 return Status::OK(); 85 } 86 int64 num_elements = 1; 87 if (proto.dim().size() > MaxDimensions()) { 88 return errors::InvalidArgument("Shape ", DebugString(proto), 89 " has too many dimensions"); 90 } 91 for (const auto& d : proto.dim()) { 92 if (d.size() < (kIsPartial ? -1 : 0)) { 93 if (kIsPartial) { 94 return errors::InvalidArgument( 95 "Shape ", DebugString(proto), 96 " has dimensions with values below -1 (where -1 means unknown)"); 97 } else { 98 return errors::InvalidArgument("Shape ", DebugString(proto), 99 " is not fully defined"); 100 } 101 } 102 if (d.size() == -1) { 103 num_elements = -1; 104 } else if (!kIsPartial || num_elements >= 0) { 105 num_elements = MultiplyWithoutOverflow(num_elements, d.size()); 106 if (num_elements < 0) { 107 return errors::InvalidArgument( 108 "Shape ", DebugString(proto), 109 " is too large (more than 2**63 - 1 entries)"); 110 } 111 } 112 } 113 return Status::OK(); 114 } 115 116 template <class Shape> 117 TensorShapeBase<Shape>::TensorShapeBase(const TensorShapeProto& proto) { 118 set_tag(REP16); 119 set_data_type(DT_INVALID); 120 // NOTE(irving): Unfortunately, TensorShape allows parsing protos with 121 // unknown_shape() set, and it seems hard to remove this without backwards 122 // compatibility issues. 123 if (kIsPartial && proto.unknown_rank()) { 124 set_ndims_byte(kUnknownRank); 125 set_num_elements(-1); 126 } else { 127 set_ndims_byte(0); 128 set_num_elements(1); 129 for (const auto& d : proto.dim()) { 130 AddDim(d.size()); 131 } 132 } 133 } 134 135 template <class Shape> 136 TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64> dim_sizes) { 137 set_tag(REP16); 138 set_data_type(DT_INVALID); 139 set_ndims_byte(0); 140 set_num_elements(1); 141 for (int64 s : dim_sizes) { 142 AddDim(internal::SubtleMustCopy(s)); 143 } 144 } 145 146 template <class Shape> 147 TensorShapeBase<Shape>::TensorShapeBase() { 148 set_tag(REP16); 149 set_data_type(DT_INVALID); 150 if (kIsPartial) { 151 set_ndims_byte(kUnknownRank); 152 set_num_elements(-1); 153 } else { 154 set_ndims_byte(0); 155 set_num_elements(1); 156 } 157 } 158 159 void TensorShapeRep::DestructorOutOfLine() { 160 DCHECK(tag() == REP_OUT_OF_LINE); 161 delete as64()->dims_; 162 } 163 164 void TensorShapeRep::SlowCopyFrom(const TensorShapeRep& b) { 165 if (b.tag() != REP_OUT_OF_LINE) { 166 if (tag() == REP_OUT_OF_LINE) { 167 delete as64()->dims_; 168 } 169 memcpy(buf(), b.buf(), sizeof(u_.buf)); 170 // memcpy above implicitly also does: 171 // set_tag(b.tag()); 172 // set_ndims_byte(b.ndims_byte()); 173 // set_data_type(b.data_type()); 174 } else { 175 DCHECK_EQ(b.tag(), REP_OUT_OF_LINE); 176 set_ndims_byte(b.ndims_byte()); 177 set_data_type(b.data_type()); 178 if (tag() == REP_OUT_OF_LINE) { 179 // vector already allocated 180 *(as64()->dims_) = *(b.as64()->dims_); 181 } else { 182 set_tag(REP_OUT_OF_LINE); 183 as64()->dims_ = new gtl::InlinedVector<int64, 4>(*(b.as64()->dims_)); 184 } 185 } 186 } 187 188 template <class Shape> 189 int64 TensorShapeBase<Shape>::dim_size(int d) const { 190 if (unknown_rank()) return -1; 191 DCHECK_GE(d, 0); 192 DCHECK_LT(d, dims()); 193 if (tag() == REP16) { 194 uint16 dim = as16()->dims_[d]; 195 if (kIsPartial && dim == kUnknownRep16) return -1; 196 return dim; 197 } else if (tag() == REP32) { 198 uint32 dim = as32()->dims_[d]; 199 if (kIsPartial && dim == kUnknownRep32) return -1; 200 return dim; 201 } else { 202 return (*as64()->dims_)[d]; 203 } 204 } 205 206 void TensorShapeRep::Clear() { 207 ClearAllButDataType(); 208 set_data_type(DT_INVALID); 209 } 210 211 void TensorShapeRep::ClearAllButDataType() { 212 if (tag() == REP_OUT_OF_LINE) { 213 delete as64()->dims_; 214 } 215 set_tag(REP16); 216 set_ndims_byte(0); 217 // Leaves data_type alone 218 set_num_elements(1); 219 } 220 221 template <class Shape> 222 void TensorShapeBase<Shape>::RecomputeNumElements() { 223 if (unknown_rank()) { 224 set_num_elements(-1); 225 return; 226 } 227 int64 n = 1; 228 for (auto dim : *this) { 229 if (kIsPartial && dim.size < 0) { 230 n = -1; 231 break; 232 } 233 n = MultiplyWithoutOverflow(n, dim.size); 234 CHECK_LE(0, n); 235 } 236 set_num_elements(n); 237 } 238 239 template <class Shape> 240 void TensorShapeBase<Shape>::AddDim(int64 size) { 241 if (!kIsPartial) CHECK_GE(size, 0); 242 if (unknown_rank()) return; 243 CHECK_LT(ndims_byte(), MaxDimensions()) << "Too many dimensions in tensor"; 244 int64 new_num_elements; 245 if (kIsPartial && (num_elements() < 0 || size < 0)) { 246 new_num_elements = -1; 247 } else { 248 new_num_elements = MultiplyWithoutOverflow(num_elements(), size); 249 CHECK_LE(0, new_num_elements); 250 } 251 UnsafeAddDim(size, new_num_elements); 252 } 253 254 template <class Shape> 255 void TensorShapeBase<Shape>::UnsafeAddDim(int64 size, int64 new_num_elements) { 256 const int nd = ndims_byte(); 257 if (tag() == REP16 && nd < 6 && size < kMaxRep16) { 258 as16()->dims_[nd] = 259 kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size); 260 } else if (tag() == REP32 && nd < 3 && size < kMaxRep32) { 261 as32()->dims_[nd] = 262 kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size); 263 } else if (tag() == REP_OUT_OF_LINE) { 264 as64()->dims_->push_back(size); 265 } else { 266 // Need to change representation 267 gtl::InlinedVector<int64, 8> vals; 268 AppendTo(*this, &vals); 269 vals.push_back(size); 270 // We know we can't be REP16. See if we have a small enough 271 // number of dimensions and each dimension's size is small enough 272 // to allow REP32. 273 bool can_be_rep32 = (vals.size() <= 3); 274 if (can_be_rep32) { 275 for (size_t i = 0; i < vals.size(); i++) { 276 if (vals[i] >= kMaxRep32) { 277 can_be_rep32 = false; 278 break; 279 } 280 } 281 } 282 if (can_be_rep32) { 283 set_tag(REP32); 284 for (size_t d = 0; d < vals.size(); d++) { 285 as32()->dims_[d] = kIsPartial && vals[d] < 0 286 ? kUnknownRep32 287 : static_cast<uint32>(vals[d]); 288 } 289 } else { 290 set_tag(REP_OUT_OF_LINE); 291 as64()->dims_ = 292 new gtl::InlinedVector<int64, 4>(vals.begin(), vals.end()); 293 } 294 } 295 set_ndims_byte(nd + 1); 296 set_num_elements(new_num_elements); 297 } 298 299 template <class Shape> 300 void TensorShapeBase<Shape>::AppendShape(const TensorShapeBase& shape) { 301 for (auto d : shape) AddDim(d.size); 302 } 303 304 template <class Shape> 305 void TensorShapeBase<Shape>::InsertDim(int d, int64 size) { 306 CHECK_GE(d, 0); 307 CHECK_LE(d, dims()); 308 if (!kIsPartial) CHECK_GE(size, 0); 309 CHECK_LT(dims(), MaxDimensions()); 310 gtl::InlinedVector<int64, 8> vals; 311 AppendTo(*this, &vals); 312 vals.insert(vals.begin() + d, size); 313 ClearAllButDataType(); 314 for (auto dval : vals) { 315 AddDim(dval); 316 } 317 } 318 319 template <class Shape> 320 gtl::InlinedVector<int64, 4> TensorShapeBase<Shape>::dim_sizes() const { 321 gtl::InlinedVector<int64, 4> result; 322 for (auto dim : *this) { 323 result.push_back(dim.size); 324 } 325 return result; 326 } 327 328 template <class Shape> 329 void TensorShapeBase<Shape>::set_dim(int d, int64 size) { 330 CHECK_GE(d, 0); 331 CHECK_LT(d, dims()); 332 CHECK_GE(size, 0); 333 if (tag() == REP16 && size < kMaxRep16) { 334 as16()->dims_[d] = 335 kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size); 336 } else if (tag() == REP32 && size < kMaxRep32) { 337 as32()->dims_[d] = 338 kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size); 339 } else if (tag() == REP_OUT_OF_LINE) { 340 (*as64()->dims_)[d] = size; 341 } else { 342 // Must upgrade 343 gtl::InlinedVector<int64, 8> vals; 344 AppendTo(*this, &vals); 345 vals[d] = size; 346 ClearAllButDataType(); 347 for (auto dval : vals) { 348 AddDim(dval); 349 } 350 } 351 RecomputeNumElements(); 352 } 353 354 template <class Shape> 355 void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) { 356 if (unknown_rank()) return; 357 begin = begin < 0 ? dims() + begin + 1 : begin; 358 end = end < 0 ? dims() + end + 1 : end; 359 CHECK_GE(begin, 0); 360 CHECK_LE(begin, dims()); 361 CHECK_GE(end, 0); 362 CHECK_LE(end, dims()); 363 if (begin >= end) return; 364 gtl::InlinedVector<int64, 8> vals; 365 AppendTo(*this, &vals); 366 vals.erase(vals.begin() + begin, vals.begin() + end); 367 ClearAllButDataType(); 368 for (auto dval : vals) { 369 AddDim(dval); 370 } 371 RecomputeNumElements(); 372 } 373 374 bool TensorShape::IsSameSize(const TensorShape& b) const { 375 if (b.dims() != dims()) return false; 376 for (int d = 0; d < dims(); d++) { 377 if (dim_size(d) != b.dim_size(d)) return false; 378 } 379 return true; 380 } 381 382 template <class Shape> 383 void TensorShapeBase<Shape>::AsProto(TensorShapeProto* proto) const { 384 proto->Clear(); 385 if (unknown_rank()) { 386 proto->set_unknown_rank(true); 387 } else { 388 for (int i = 0; i < dims(); i++) { 389 proto->add_dim()->set_size(dim_size(i)); 390 } 391 } 392 } 393 394 void TensorShapeRep::DumpRep() const { 395 #if 0 396 fprintf(stderr, "Rep: %d %d dims\n", tag(), dims()); 397 if (tag() == REP16) { 398 fprintf(stderr, "REP16 NDIMS: %d\n", ndims_byte()); 399 for (int i = 0; i < ndims_byte(); i++) { 400 fprintf(stderr, "dim %d: %d\n", i, as16()->dims_[i]); 401 } 402 } else if (tag_ == REP32) { 403 fprintf(stderr, "REP32 NDIMS: %d\n", ndims_); 404 for (int i = 0; i < ndims_byte(); i++) { 405 fprintf(stderr, "dim %d: %d\n", i, as32()->dims_[i]); 406 } 407 } else if (tag_ == REP_OUT_OF_LINE) { 408 fprintf(stderr, "REP_OUT_OF_LINE NDIMS: %d %p\n", ndims_, as16()->dims_); 409 for (int i = 0; i < ndims_byte(); i++) { 410 fprintf(stderr, "dim %d: %lld\n", i, (*as64()->dims_)[i]); 411 } 412 } 413 #endif 414 } 415 416 template <class Shape> 417 TensorShapeIter<Shape> TensorShapeBase<Shape>::begin() const { 418 return TensorShapeIter<Shape>(static_cast<const Shape*>(this), 0); 419 } 420 421 template <class Shape> 422 TensorShapeIter<Shape> TensorShapeBase<Shape>::end() const { 423 CHECK(!unknown_rank()); 424 return TensorShapeIter<Shape>(static_cast<const Shape*>(this), dims()); 425 } 426 427 string TensorShapeRep::DebugString() const { 428 const auto& shape = *static_cast<const PartialTensorShape*>(this); 429 if (shape.unknown_rank()) return "<unknown>"; 430 string s = "["; 431 for (int i = 0; i < shape.dims(); i++) { 432 if (i > 0) strings::StrAppend(&s, ","); 433 int64 dim = shape.dim_size(i); 434 if (dim < 0) { 435 strings::StrAppend(&s, "?"); 436 } else { 437 strings::StrAppend(&s, dim); 438 } 439 } 440 strings::StrAppend(&s, "]"); 441 return s; 442 } 443 444 string TensorShapeRep::DebugString(const TensorShapeProto& proto) { 445 string s; 446 if (proto.unknown_rank()) { 447 strings::StrAppend(&s, "<unknown>"); 448 if (proto.dim_size() == 0) return s; 449 } 450 strings::StrAppend(&s, "["); 451 bool first = true; 452 for (const auto& d : proto.dim()) { 453 if (!first) strings::StrAppend(&s, ","); 454 if (d.size() == -1) { 455 strings::StrAppend(&s, "?"); 456 } else { 457 strings::StrAppend(&s, d.size()); 458 } 459 first = false; 460 } 461 strings::StrAppend(&s, "]"); 462 return s; 463 } 464 465 bool TensorShapeUtils::StartsWith(const TensorShape& shape, 466 const TensorShape& prefix) { 467 if (shape.dims() < prefix.dims()) return false; 468 for (int i = 0; i < prefix.dims(); ++i) { 469 if (shape.dim_size(i) != prefix.dim_size(i)) return false; 470 } 471 return true; 472 } 473 474 bool TensorShapeUtils::EndsWith(const TensorShape& shape, 475 const TensorShape& suffix) { 476 const int suffix_size = suffix.dims(); 477 if (shape.dims() < suffix_size) return false; 478 for (int i = 0; i < suffix_size; ++i) { 479 if (shape.dim_size(shape.dims() - suffix_size + i) != suffix.dim_size(i)) { 480 return false; 481 } 482 } 483 return true; 484 } 485 486 template <typename T, class Shape> 487 Status MakeShapeHelper(const T* dims, int64 n, Shape* out) { 488 out->Clear(); 489 if (n > TensorShape::MaxDimensions()) { 490 return errors::InvalidArgument("Too many dimensions"); 491 } 492 if (n < 0) { 493 return errors::InvalidArgument("Negative number of dimensions ", n); 494 } 495 for (int64 i = 0; i < n; ++i) { 496 T dim = internal::SubtleMustCopy(dims[i]); 497 int64 new_num_elements; 498 if (dim < 0) { 499 if (!out->kIsPartial) { 500 return errors::InvalidArgument("Dimension ", dim, " must be >= 0"); 501 } 502 if (dim < -1) { 503 return errors::InvalidArgument("Dimension ", dim, " must be >= -1"); 504 } 505 dim = -1; 506 new_num_elements = -1; 507 } else if (out->num_elements() < 0) { 508 new_num_elements = -1; 509 } else { 510 new_num_elements = MultiplyWithoutOverflow(out->num_elements(), dim); 511 if (TF_PREDICT_FALSE(new_num_elements < 0)) { 512 TensorShapeProto proto; 513 for (int64 j = 0; j < n; ++j) { 514 proto.add_dim()->set_size(dim); 515 } 516 return errors::InvalidArgument( 517 "Shape ", TensorShape::DebugString(proto), 518 " would have more than 2**63 - 1 elements"); 519 } 520 } 521 out->UnsafeAddDim(dim, new_num_elements); 522 } 523 return Status::OK(); 524 } 525 526 #define MAKE_SHAPE(T, Shape) \ 527 Status TensorShapeUtils::MakeShape(const T* dims, int64 n, Shape* out) { \ 528 return MakeShapeHelper(dims, n, out); \ 529 } \ 530 Status TensorShapeUtils::MakeShape(gtl::ArraySlice<T> shape, Shape* out) { \ 531 return MakeShapeHelper(shape.data(), shape.size(), out); \ 532 } 533 MAKE_SHAPE(int32, TensorShape) 534 MAKE_SHAPE(int64, TensorShape) 535 MAKE_SHAPE(int32, PartialTensorShape) 536 MAKE_SHAPE(int64, PartialTensorShape) 537 #undef MAKE_SHAPE 538 539 string TensorShapeUtils::ShapeListString( 540 const gtl::ArraySlice<TensorShape>& shapes) { 541 string result = "["; 542 bool first = true; 543 for (const TensorShape& shape : shapes) { 544 strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString()); 545 first = false; 546 } 547 strings::StrAppend(&result, "]"); 548 return result; 549 } 550 551 PartialTensorShape PartialTensorShape::Concatenate(int64 size) const { 552 PartialTensorShape out = *this; 553 out.AddDim(size); 554 return out; 555 } 556 557 PartialTensorShape PartialTensorShape::Concatenate( 558 const PartialTensorShape& shape) const { 559 if (unknown_rank() || shape.unknown_rank()) { 560 return PartialTensorShape(); 561 } 562 PartialTensorShape out = *this; 563 for (auto dim : shape) out.AddDim(dim.size); 564 return out; 565 } 566 567 Status PartialTensorShape::MergeWith(const PartialTensorShape& shape, 568 PartialTensorShape* result) const { 569 if (unknown_rank()) { 570 *result = shape; 571 return Status::OK(); 572 } 573 if (shape.unknown_rank()) { 574 *result = *this; 575 return Status::OK(); 576 } 577 const int dims_ = dims(); 578 if (dims_ != shape.dims()) { 579 return errors::InvalidArgument( 580 "PartialTensorShape: Incompatible ranks during merge: ", dims_, " vs. ", 581 shape.dims()); 582 } 583 CHECK(result != this); 584 result->Clear(); 585 for (int i = 0; i < dims_; ++i) { 586 const int64 dim0 = dim_size(i); 587 const int64 dim1 = shape.dim_size(i); 588 if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) { 589 return errors::InvalidArgument( 590 "PartialTensorShape: Incompatible shapes during merge: ", 591 DebugString(), " vs. ", shape.DebugString()); 592 } 593 result->AddDim(dim0 >= 0 ? dim0 : dim1); 594 } 595 return Status::OK(); 596 } 597 598 bool PartialTensorShape::AsTensorShape(TensorShape* shape) const { 599 if (IsFullyDefined()) { 600 const TensorShapeRep* rep = this; 601 *shape = *static_cast<const TensorShape*>(rep); 602 return true; 603 } 604 return false; 605 } 606 607 bool PartialTensorShape::IsIdenticalTo(const PartialTensorShape& shape) const { 608 if (unknown_rank() || shape.unknown_rank()) { 609 return unknown_rank() == shape.unknown_rank(); 610 } 611 if (dims() != shape.dims()) return false; 612 for (int i = 0; i < dims(); i++) { 613 if (dim_size(i) != shape.dim_size(i)) return false; 614 } 615 return true; 616 } 617 618 bool PartialTensorShape::IsCompatibleWith( 619 const PartialTensorShape& shape) const { 620 if (unknown_rank() || shape.unknown_rank()) return true; 621 if (dims() != shape.dims()) return false; 622 for (int i = 0; i < dims(); i++) { 623 const int64 dim0 = dim_size(i); 624 const int64 dim1 = shape.dim_size(i); 625 if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) return false; 626 } 627 return true; 628 } 629 630 string PartialTensorShapeUtils::PartialShapeListString( 631 const gtl::ArraySlice<PartialTensorShape>& shapes) { 632 string result = "["; 633 bool first = true; 634 for (const PartialTensorShape& shape : shapes) { 635 strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString()); 636 first = false; 637 } 638 strings::StrAppend(&result, "]"); 639 return result; 640 } 641 642 bool PartialTensorShapeUtils::AreCompatible( 643 const gtl::ArraySlice<PartialTensorShape>& shapes0, 644 const gtl::ArraySlice<PartialTensorShape>& shapes1) { 645 if (shapes0.size() == shapes1.size()) { 646 for (size_t i = 0; i < shapes0.size(); ++i) { 647 if (!shapes0[i].IsCompatibleWith(shapes1[i])) { 648 return false; 649 } 650 } 651 return true; 652 } else { 653 return false; 654 } 655 } 656 657 bool PartialTensorShapeUtils::AreIdentical( 658 const gtl::ArraySlice<PartialTensorShape>& shapes0, 659 const gtl::ArraySlice<PartialTensorShape>& shapes1) { 660 if (shapes0.size() == shapes1.size()) { 661 for (size_t i = 0; i < shapes0.size(); ++i) { 662 if (!shapes0[i].IsIdenticalTo(shapes1[i])) { 663 return false; 664 } 665 } 666 return true; 667 } else { 668 return false; 669 } 670 } 671 672 Status TensorShapeUtils::NumElements(gtl::ArraySlice<int64> shape, 673 int64* num_elements) { 674 int64 n = 1; 675 for (auto dim : shape) { 676 n = MultiplyWithoutOverflow(n, dim); 677 if (n < 0) { 678 return errors::InvalidArgument("Can't compute total size of shape [", 679 str_util::Join(shape, ","), 680 "]; product would overflow int64"); 681 } 682 } 683 *num_elements = n; 684 return Status::OK(); 685 } 686 687 template class TensorShapeBase<TensorShape>; 688 template class TensorShapeBase<PartialTensorShape>; 689 690 } // namespace tensorflow 691