/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/tensor_shape.h"

#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {

// Test-only accessor for TensorShape's private data-type byte, so the
// DataType test below can verify that it is preserved across shape
// mutations and copies.
class TensorShapeTestHelper {
 public:
  static void set_data_type(TensorShape* s, DataType t) { s->set_data_type(t); }
  static uint8 data_type(const TensorShape* s) { return s->data_type(); }
};

namespace {

TEST(TensorShapeTest, Default) {
  // The default TensorShape constructor constructs a shape of 0-dim
  // and 1-element.
  TensorShape s;
  EXPECT_EQ(s.dims(), 0);
  EXPECT_EQ(s.num_elements(), 1);
}

TEST(TensorShapeTest, set_dim) {
  TensorShape s({10, 5});

  // Resizing a dimension in place keeps the rank and recomputes
  // num_elements.
  s.set_dim(0, 20);
  ASSERT_EQ(2, s.dims());
  EXPECT_EQ(20, s.dim_size(0));
  EXPECT_EQ(100, s.num_elements());

  s.set_dim(1, 2);
  ASSERT_EQ(2, s.dims());
  EXPECT_EQ(2, s.dim_size(1));
  EXPECT_EQ(40, s.num_elements());
}

TEST(TensorShapeTest, RemoveDim) {
  TensorShape s({10, 5});
  s.RemoveDim(0);
  EXPECT_EQ(5, s.num_elements());
  ASSERT_EQ(1, s.dims());
}

TEST(TensorShapeTest, RemoveAndAddDim) {
  TensorShape s({10, 5, 20});
  s.RemoveDim(1);
  s.AddDim(100);

  EXPECT_EQ(20000, s.num_elements());
  ASSERT_EQ(3, s.dims());
}

TEST(TensorShapeTest, RemoveLastDims) {
  TensorShape s({2, 3, 5, 7});
  s.RemoveLastDims(1);

  ASSERT_EQ(3, s.dims());
  EXPECT_EQ(30, s.num_elements());

  // Removing several trailing dims in one call.
  s.RemoveLastDims(2);
  ASSERT_EQ(1, s.dims());
  EXPECT_EQ(2, s.dim_size(0));
}

TEST(TensorShapeTest, RemoveDimRange) {
  TensorShape s0({2, 3, 5, 7});
  // Empty interval => noop.
  for (int i = -4; i <= 4; ++i) {
    s0.RemoveDimRange(i, i);
    ASSERT_EQ(4, s0.dims());
    ASSERT_EQ(210, s0.num_elements());
  }

  // Positive begin and end.
  s0.RemoveDimRange(3, 1);  // Empty interval.
  ASSERT_EQ(4, s0.dims());
  ASSERT_EQ(210, s0.num_elements());
  s0.RemoveDimRange(0, 3);
  ASSERT_EQ(1, s0.dims());
  EXPECT_EQ(7, s0.dim_size(0));
  TensorShape s1({2, 3, 5, 7});
  s1.RemoveDimRange(2, 3);
  ASSERT_EQ(3, s1.dims());
  ASSERT_EQ(42, s1.num_elements());

  // Negative begin or end (indices counted from the end of the shape).
  TensorShape s2({2, 3, 5, 7});
  s2.RemoveDimRange(-2, -3);  // Empty interval.
  ASSERT_EQ(4, s2.dims());
  ASSERT_EQ(210, s2.num_elements());
  s2.RemoveDimRange(0, -2);
  ASSERT_EQ(1, s2.dims());
  ASSERT_EQ(7, s2.dim_size(0));
  TensorShape s3({2, 3, 5, 7});
  s3.RemoveDimRange(-3, -2);
  ASSERT_EQ(3, s3.dims());
  ASSERT_EQ(42, s3.num_elements());
}

TEST(TensorShapeTest, InvalidShapeProto) {
  TensorShapeProto proto;
  EXPECT_TRUE(TensorShape::IsValid(proto));

  proto.add_dim()->set_size(357);
  proto.add_dim()->set_size(982);
  EXPECT_TRUE(TensorShape::IsValid(proto));

  // Negative dimension sizes are rejected.
  proto.Clear();
  proto.add_dim()->set_size(-357);
  proto.add_dim()->set_size(-982);
  EXPECT_FALSE(TensorShape::IsValid(proto));

  // Shapes whose total element count is too large are rejected.
  proto.Clear();
  proto.add_dim()->set_size(1LL << 35);
  proto.add_dim()->set_size((1LL << 35) + 1);
  EXPECT_FALSE(TensorShape::IsValid(proto));
}

TEST(TensorShapeTest, TooManyDimsProto) {
  TensorShapeProto proto;
  // Deliberate redundancy to ensure that both paths work.
  EXPECT_TRUE(TensorShape::IsValid(proto));
  TF_EXPECT_OK(TensorShape::IsValidShape(proto));
  for (int i = 0; i < TensorShape::MaxDimensions(); i++) {
    proto.add_dim()->set_size(1);
  }
  // Exactly MaxDimensions() dims is still valid...
  EXPECT_TRUE(TensorShape::IsValid(proto));
  TF_EXPECT_OK(TensorShape::IsValidShape(proto));
  // ...but one more dim is rejected by both validation paths.
  proto.add_dim()->set_size(1);
  EXPECT_FALSE(TensorShape::IsValid(proto));
  EXPECT_FALSE(TensorShape::IsValidShape(proto).ok());
}

TEST(TensorShapeTest, SetDimForEmptyTensor) {
  TensorShape s({10, 5, 20});
  EXPECT_EQ(1000, s.num_elements());
  // Setting a dimension to 0 empties the tensor; restoring a non-zero
  // size must recompute the element count correctly afterwards.
  s.set_dim(1, 0);
  EXPECT_EQ(0, s.num_elements());
  s.set_dim(1, 7);
  EXPECT_EQ(1400, s.num_elements());
}

TEST(TensorShapeTest, AppendShape64BitIndices) {
  // 2147483648 == 2^31: a dimension size that does not fit in int32.
  TensorShape s({10, 2147483648});

  EXPECT_EQ(10, s.dim_size(0));
  EXPECT_EQ(2147483648, s.dim_size(1));

  TensorShape s2;
  s2.AppendShape(s);
  EXPECT_EQ(10, s2.dim_size(0));
  EXPECT_EQ(2147483648, s2.dim_size(1));
}

TEST(TensorShapeTest, DataType) {
  TensorShape s({});
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_INVALID);
  // The stashed data type must survive AddDim calls...
  TensorShapeTestHelper::set_data_type(&s, DT_INT32);
  s.AddDim(1);
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_INT32);
  s.AddDim(100000);
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_INT32);
  TensorShapeTestHelper::set_data_type(&s, DT_UINT16_REF);
  s.AddDim(2);
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_UINT16_REF);
  s.AddDim(4);
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_UINT16_REF);
  s.AddDim(3);
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_UINT16_REF);

  // ...as well as copies and RemoveDim; Clear() resets it to DT_INVALID.
  TensorShape s2 = s;
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s2), DT_UINT16_REF);
  s2.RemoveDim(2);
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s2), DT_UINT16_REF);
  TensorShapeTestHelper::set_data_type(&s2, DT_FLOAT);
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s2), DT_FLOAT);
  s2.Clear();
  EXPECT_EQ(TensorShapeTestHelper::data_type(&s2), DT_INVALID);
}

TEST(TensorShapeTest, ostream) {
  TensorShape s({10, 5, 4});
  std::stringstream ss;
  ss << s;
  EXPECT_EQ(ss.str(), "[10,5,4]");
}

// -----------------------------------------------------------------------
// An old implementation of TensorShape using a different representation,
// preserved here in the unittest to allow us to have a randomized unittest
// that makes sure the behavior of TensorShape and TensorShapeOld are
// the same.
class TensorShapeIterOld;  // Declared below

/// Manages the dimensions of a Tensor and their sizes.
class TensorShapeOld {
 public:
  /// \brief Construct a `TensorShape` from the provided sizes.
  /// REQUIRES: `dim_sizes[i] >= 0`
  explicit TensorShapeOld(gtl::ArraySlice<int64> dim_sizes);
  TensorShapeOld(std::initializer_list<int64> dim_sizes)
      : TensorShapeOld(gtl::ArraySlice<int64>(dim_sizes)) {}

  /// REQUIRES: `IsValid(proto)`
  explicit TensorShapeOld(const TensorShapeProto& proto);

  /// Create a tensor shape with no dimensions and one element, which you can
  /// then call `AddDim()` on.
  TensorShapeOld();

  /// Returns `true` iff `proto` is a valid tensor shape.
  static bool IsValid(const TensorShapeProto& proto);

  /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
  /// status otherwise.
  static Status IsValidShape(const TensorShapeProto& proto);

  /// Clear a tensor shape
  void Clear();

  /// \brief Add a dimension to the end ("inner-most").
  /// REQUIRES: `size >= 0`
  void AddDim(int64 size);

  /// Appends all the dimensions from `shape`.
  void AppendShape(const TensorShapeOld& shape);

  /// \brief Insert a dimension somewhere in the `TensorShape`.
  /// REQUIRES: `0 <= d <= dims()`
  /// REQUIRES: `size >= 0`
  void InsertDim(int d, int64 size);

  /// \brief Modifies the size of the dimension `d` to be `size`
  /// REQUIRES: `0 <= d < dims()`
  /// REQUIRES: `size >= 0`
  void set_dim(int d, int64 size);

  /// \brief Removes dimension `d` from the `TensorShape`.
  /// REQUIRES: `0 <= d < dims()`
  void RemoveDim(int d);

  /// Return the number of dimensions in the tensor.
  int dims() const { return dim_sizes_.size(); }

  /// \brief Returns the number of elements in dimension `d`.
  /// REQUIRES: `0 <= d < dims()`
  // TODO(touts): Rename to `dimension()` to match
  // `Eigen::Tensor::dimension()`?
  int64 dim_size(int d) const {
    DCHECK_GE(d, 0);
    DCHECK_LT(d, dims());
    return dim_sizes_[d];
  }

  /// Returns sizes of all dimensions.
  gtl::ArraySlice<int64> dim_sizes() const { return dim_sizes_; }

  /// \brief Returns the number of elements in the tensor.
  ///
  /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
  /// which uses `ptrdiff_t`.
  int64 num_elements() const { return num_elements_; }

  /// Returns true if `*this` and `b` have the same sizes. Ignores
  /// dimension names.
  bool IsSameSize(const TensorShapeOld& b) const;
  bool operator==(const TensorShapeOld& b) const { return IsSameSize(b); }

  /// Fill `*proto` from `*this`.
  void AsProto(TensorShapeProto* proto) const;

  /// Fill `*dsizes` from `*this`.
  template <int NDIMS>
  Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizes() const;

  /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
  /// which case we pad the rest of the sizes with 1.
  template <int NDIMS>
  Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizesWithPadding() const;

  /// For iterating through the dimensions.
  TensorShapeIterOld begin() const;
  TensorShapeIterOld end() const;

  /// For error messages.
  string DebugString() const;

  /// Same as `TensorShape(proto).DebugString()` but doesn't crash for
  /// invalid protos.
  static string DebugString(const TensorShapeProto& proto);

 private:
  // Recalculates the dimensions of this tensor after they are modified.
  void recompute_dims();

  // TODO(josh11b): Maybe use something from the Eigen Tensor library
  // for the sizes.
  gtl::InlinedVector<int64, 4> dim_sizes_;

  // total number of elements (avoids recomputing it each time).
  int64 num_elements_;
};

// Value type yielded by TensorShapeIterOld: a single dimension size.
struct TensorShapeDimOld {
  explicit TensorShapeDimOld(int64 s) : size(s) {}
  int64 size;
};

// Forward iterator over the dimensions of a TensorShapeOld.
class TensorShapeIterOld {
 public:
  TensorShapeIterOld(const TensorShapeOld* shape, int d)
      : shape_(shape), d_(d) {}
  bool operator==(const TensorShapeIterOld& rhs) {
    DCHECK(shape_ == rhs.shape_);
    return d_ == rhs.d_;
  }
  bool operator!=(const TensorShapeIterOld& rhs) {
    DCHECK(shape_ == rhs.shape_);
    return d_ != rhs.d_;
  }
  void operator++() { ++d_; }
  TensorShapeDimOld operator*() {
    return TensorShapeDimOld(shape_->dim_size(d_));
  }

 private:
  const TensorShapeOld* shape_;
  int d_;
};

// An upper limit of the total number of elements in a tensor.
// (This is the old 2**40 cap; the current TensorShape allows more --
// see the Large test below.)
static const int64 kMaxElements = (1LL << 40);

bool TensorShapeOld::IsValid(const TensorShapeProto& proto) {
  int64 num_elements = 1;
  for (const auto& d : proto.dim()) {
    if (d.size() < 0) return false;
    num_elements *= d.size();
    if (num_elements > kMaxElements) return false;
  }
  return true;
}

Status TensorShapeOld::IsValidShape(const TensorShapeProto& proto) {
  int64 num_elements = 1;
  for (const auto& d : proto.dim()) {
    if (d.size() < 0) {
      return errors::InvalidArgument("Shape ", DebugString(proto),
                                     " has negative dimensions; ",
                                     "perhaps an un-fed placeholder?");
    }
    num_elements *= d.size();
    if (num_elements > kMaxElements) {
      return errors::InvalidArgument("Shape ", DebugString(proto),
                                     " is too large (more than ", kMaxElements,
                                     " entries)");
    }
  }
  return Status::OK();
}

TensorShapeOld::TensorShapeOld(const TensorShapeProto& proto) {
  dim_sizes_.reserve(proto.dim_size());
  num_elements_ = 1;
  for (const auto& d : proto.dim()) {
    AddDim(d.size());
  }
}

TensorShapeOld::TensorShapeOld(gtl::ArraySlice<int64> dim_sizes) {
  dim_sizes_.reserve(dim_sizes.size());
  num_elements_ = 1;
  for (auto s : dim_sizes) {
    AddDim(s);
  }
}

TensorShapeOld::TensorShapeOld() : num_elements_(1) {}

void TensorShapeOld::Clear() {
  dim_sizes_.clear();
  num_elements_ = 1;
}

void TensorShapeOld::AddDim(int64 size) {
  CHECK_GE(size, 0);
  dim_sizes_.push_back(size);
  // Maintain the invariant: num_elements_ is the product of dim_sizes_
  // and stays within [0, kMaxElements].
  num_elements_ *= size;
  CHECK_LE(0, num_elements_);
  CHECK_LE(num_elements_, kMaxElements);
}

void TensorShapeOld::AppendShape(const TensorShapeOld& shape) {
  for (auto d : shape) AddDim(d.size);
}

void TensorShapeOld::InsertDim(int d, int64 size) {
  CHECK_GE(d, 0);
  CHECK_LE(d, dims());
  CHECK_GE(size, 0);
  dim_sizes_.insert(dim_sizes_.begin() + d, size);
  num_elements_ *= size;
  CHECK_LE(0, num_elements_);
  CHECK_LE(num_elements_, kMaxElements);
}

void TensorShapeOld::set_dim(int d, int64 size) {
  CHECK_GE(d, 0);
  CHECK_LT(d, dims());
  CHECK_GE(size, 0);

  // Update the number of elements. num_elements_ is int64.
  dim_sizes_[d] = size;
  recompute_dims();
}

void TensorShapeOld::RemoveDim(int d) {
  CHECK_GE(d, 0);
  CHECK_LT(d, dims());

  // Update the number of elements and remove the dimension from the
  // sizes.
  dim_sizes_.erase(dim_sizes_.begin() + d);
  recompute_dims();
}

// Recomputes num_elements_ from scratch after a mutation that cannot be
// tracked incrementally (set_dim / RemoveDim).
void TensorShapeOld::recompute_dims() {
  num_elements_ = 1;
  for (auto s : dim_sizes_) {
    num_elements_ *= s;
    CHECK_LE(0, num_elements_);
    CHECK_LE(num_elements_, kMaxElements);
  }
}

bool TensorShapeOld::IsSameSize(const TensorShapeOld& b) const {
  if (b.dims() != dims()) return false;
  for (int d = 0; d < dims(); d++) {
    if (dim_size(d) != b.dim_size(d)) return false;
  }
  return true;
}

void TensorShapeOld::AsProto(TensorShapeProto* proto) const {
  proto->Clear();
  for (size_t d = 0; d < dim_sizes_.size(); ++d) {
    auto* dim = proto->add_dim();
    dim->set_size(dim_sizes_[d]);
  }
}

TensorShapeIterOld TensorShapeOld::begin() const {
  return TensorShapeIterOld(this, 0);
}

TensorShapeIterOld TensorShapeOld::end() const {
  return TensorShapeIterOld(this, dims());
}

string TensorShapeOld::DebugString() const {
  return strings::StrCat(
      "[", str_util::Join(gtl::ArraySlice<int64>(dim_sizes_), ","), "]");
}

string TensorShapeOld::DebugString(const TensorShapeProto& proto) {
  string s = "[";
  bool first = true;
  for (const auto& d : proto.dim()) {
    strings::StrAppend(&s, first ?
                           "" : ",", d.size());
    first = false;
  }
  strings::StrAppend(&s, "]");
  return s;
}
// End of old implementation
// ------------------------------------------------------------------------

// Returns a random dimension size for the randomized test below, skewed
// so that the running product of sizes stays non-negative and below 2^34:
// small current products get a large random size, large products get 0/1.
static int64 SkewedSize(random::SimplePhilox* gen, int64 current_elements) {
  int64 result = 0;
  do {
    if (current_elements < 100) {
      result = gen->Uniform(100000);
    } else {
      result = gen->Uniform(2);
    }
  } while ((result * current_elements >= 1LL << 34) ||
           (result * current_elements < 0));
  return result;
}

TEST(TensorShapeTest, Randomized) {
  // We do a randomized test to verify that the behavior of the
  // TensorShape implementation (which changes representations depending
  // on the values) is identical to our older, more straightforward (but
  // more memory hungry) implementation (TensorShapeOld).
  random::PhiloxRandom philox(7, 7);
  random::SimplePhilox gen(&philox);
  TensorShape s;
  TensorShapeOld sold;
  TensorShapeProto sp;
  TensorShapeProto spold;
  LOG(INFO) << "Sizes: " << sizeof(TensorShape) << " vs "
            << sizeof(TensorShapeOld);
  for (int i = 0; i < 100000; i++) {
    // Both implementations must serialize to identical protos and agree
    // on the element count at every step.
    s.AsProto(&sp);
    sold.AsProto(&spold);
    EXPECT_EQ(sp.DebugString(), spold.DebugString());
    if ((i % 1000) == 0) {
      fprintf(stderr, "ITERATION %d: %s\n", i, sp.DebugString().c_str());
    }
    EXPECT_EQ(s.num_elements(), sold.num_elements());

    // Test moves.
    TensorShape copy = s;
    TensorShape moved(std::move(copy));
    EXPECT_EQ(s, moved);
    copy = s;
    moved = std::move(copy);
    EXPECT_EQ(s, moved);

    // Apply the same randomly chosen mutation to both implementations.
    int64 ne = sold.num_elements();
    int r = gen.Uniform(100);
    if (r < 10) {
      int64 sz = SkewedSize(&gen, sold.num_elements());
      s.AddDim(sz);
      sold.AddDim(sz);
    } else if (r < 15) {
      s.Clear();
      sold.Clear();
    } else if (r < 35 && s.dims() > 0 && ne > 0 && ne < 100000000) {
      int dim = gen.Uniform(s.dims());
      s.RemoveDim(dim);
      sold.RemoveDim(dim);
    } else if (r < 50 && ne > 0 && ne < 100000000) {
      int dim = gen.Uniform(s.dims() + 1);
      int64 sz = SkewedSize(&gen, sold.num_elements());
      s.InsertDim(dim, sz);
      sold.InsertDim(dim, sz);
    } else {
      // Rebuild both shapes from a fresh random size list.
      std::vector<int64> sizes;
      const int N = (gen.Uniform(4) == 0) ? gen.Uniform(10) : gen.Uniform(3);
      int64 num_elements = 1;
      for (int i = 0; i < N; i++) {
        int64 sz = SkewedSize(&gen, num_elements);
        sizes.push_back(sz);
        num_elements *= std::max<int64>(1, sz);
      }

      s = TensorShape(sizes);
      sold = TensorShapeOld(sizes);
    }
  }
}

TEST(TensorShapeTest, Large) {
  // We used to cap shapes at 2**40 elements. Ensure the
  // bound is now higher.
  int64 one = 1;
  int64 max = std::numeric_limits<int64>::max();
  EXPECT_EQ(TensorShape({max}).num_elements(), max);
  EXPECT_EQ(TensorShape({1, max}).num_elements(), max);
  EXPECT_EQ(TensorShape({max, 1}).num_elements(), max);
  EXPECT_EQ(TensorShape({one << 62}).num_elements(), one << 62);
  EXPECT_EQ(TensorShape({one << 20, one << 41}).num_elements(), one << 61);
  EXPECT_EQ(TensorShape({1000, 1000, 1000, 1000, 1000, 1000}).num_elements(),
            1e18);
}

TEST(TensorShapeTest, Overflow) {
  // Shapes whose element-count product overflows int64 must be rejected
  // both by proto validation and by MakeShape.
  int64 one = 1;
  std::vector<std::vector<int64>> overflows = {
      {1 << 30, 1 << 30, 1 << 30},
      {1 << 5, (one << 60) + 1},
  };
  for (const auto& overflow : overflows) {
    TensorShapeProto proto;
    for (auto dim : overflow) {
      proto.add_dim()->set_size(dim);
    }
    EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT,
              TensorShape::IsValidShape(proto).code());
    TensorShape shape;
    EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT,
              TensorShapeUtils::MakeShape(overflow, &shape).code());
  }
}

TEST(TensorShapeTest, UnknownRank) {
  // NOTE(irving): Unfortunately, for historical reasons we have to allow an
  // TensorShapeProto with unknown_rank() set to be parsed as a TensorShape.
  // Would be nice to tighten this, but it's tricky given backwards
  // compatibility requirements.
  TensorShapeProto proto;
  proto.set_unknown_rank(true);
  EXPECT_TRUE(TensorShape::IsValid(proto));
  TF_EXPECT_OK(TensorShape::IsValidShape(proto));
  EXPECT_EQ(TensorShape(), TensorShape(proto));

  // Explicit dims on an unknown-rank proto still parse as those dims.
  proto.add_dim()->set_size(7);
  EXPECT_TRUE(TensorShape::IsValid(proto));
  TF_EXPECT_OK(TensorShape::IsValidShape(proto));
  EXPECT_EQ(TensorShape({7}), TensorShape(proto));
}

TEST(TensorShapeUtilsTest, StartsWith) {
  EXPECT_TRUE(TensorShapeUtils::StartsWith(TensorShape({}), TensorShape({})));
  EXPECT_TRUE(
      TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({})));
  EXPECT_TRUE(
      TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({2})));
  EXPECT_TRUE(
      TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({2, 3})));
  EXPECT_TRUE(TensorShapeUtils::StartsWith(TensorShape({2, 3, 4}),
                                           TensorShape({2, 3})));
  EXPECT_FALSE(
      TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({3})));
  EXPECT_FALSE(
      TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({2, 4})));
  EXPECT_FALSE(TensorShapeUtils::StartsWith(TensorShape({2, 3}),
                                            TensorShape({2, 3, 4})));
  EXPECT_FALSE(TensorShapeUtils::StartsWith(TensorShape({2, 3, 4}),
                                            TensorShape({3, 4})));
}

TEST(TensorShapeUtilsTest, EndsWith) {
  EXPECT_TRUE(TensorShapeUtils::EndsWith(TensorShape({}), TensorShape({})));
  EXPECT_TRUE(TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({})));
  EXPECT_TRUE(
      TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({3})));
  EXPECT_TRUE(
      TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({2, 3})));
  EXPECT_TRUE(
      TensorShapeUtils::EndsWith(TensorShape({2, 3, 4}), TensorShape({3, 4})));
  EXPECT_FALSE(
      TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({2})));
  EXPECT_FALSE(
      TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({2, 4})));
  EXPECT_FALSE(
TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({2, 3, 4}))); 660 EXPECT_FALSE( 661 TensorShapeUtils::EndsWith(TensorShape({2, 3, 4}), TensorShape({2, 3}))); 662 } 663 664 // A few different test cases for tensor sizes for benchmarks 665 static std::vector<int64> MakeSizes(int arg) { 666 std::vector<int64> sizes; 667 switch (arg) { 668 case 0: 669 sizes = {100}; 670 break; 671 case 1: 672 sizes = {100, 1000}; 673 break; 674 case 2: 675 sizes = {100, 1000000}; 676 break; 677 case 3: 678 sizes = {100, 256, 192, 3}; 679 break; 680 case 4: 681 sizes = {1, 2, 1ll << 34, 1, 1, 1}; 682 break; 683 } 684 return sizes; 685 } 686 687 static void BM_TensorShape_Init(int iters, int arg) { 688 auto sizes = MakeSizes(arg); 689 while (--iters > 0) { 690 TensorShape shape(sizes); 691 tensorflow::testing::DoNotOptimize(shape.num_elements()); 692 } 693 } 694 BENCHMARK(BM_TensorShape_Init)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4); 695 696 static void BM_TensorShape_Assign(int iters, int arg) { 697 TensorShape s(MakeSizes(arg)); 698 while (--iters > 0) { 699 TensorShape s2 = s; 700 } 701 } 702 BENCHMARK(BM_TensorShape_Assign)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4); 703 704 } // namespace 705 } // namespace tensorflow 706