Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/tensor_shape.h"
     17 
     18 #include "tensorflow/core/framework/tensor_shape.pb.h"
     19 #include "tensorflow/core/lib/core/status_test_util.h"
     20 #include "tensorflow/core/lib/random/simple_philox.h"
     21 #include "tensorflow/core/lib/strings/str_util.h"
     22 #include "tensorflow/core/lib/strings/strcat.h"
     23 #include "tensorflow/core/platform/test.h"
     24 #include "tensorflow/core/platform/test_benchmark.h"
     25 
     26 namespace tensorflow {
     27 class TensorShapeTestHelper {
     28  public:
     29   static void set_data_type(TensorShape* s, DataType t) { s->set_data_type(t); }
     30   static uint8 data_type(const TensorShape* s) { return s->data_type(); }
     31 };
     32 
     33 namespace {
     34 
     35 TEST(TensorShapeTest, Default) {
     36   // The default TensorShape constructor constructs a shape of 0-dim
     37   // and 1-element.
     38   TensorShape s;
     39   EXPECT_EQ(s.dims(), 0);
     40   EXPECT_EQ(s.num_elements(), 1);
     41 }
     42 
     43 TEST(TensorShapeTest, set_dim) {
     44   TensorShape s({10, 5});
     45 
     46   s.set_dim(0, 20);
     47   ASSERT_EQ(2, s.dims());
     48   EXPECT_EQ(20, s.dim_size(0));
     49   EXPECT_EQ(100, s.num_elements());
     50 
     51   s.set_dim(1, 2);
     52   ASSERT_EQ(2, s.dims());
     53   EXPECT_EQ(2, s.dim_size(1));
     54   EXPECT_EQ(40, s.num_elements());
     55 }
     56 
     57 TEST(TensorShapeTest, RemoveDim) {
     58   TensorShape s({10, 5});
     59   s.RemoveDim(0);
     60   EXPECT_EQ(5, s.num_elements());
     61   ASSERT_EQ(1, s.dims());
     62 }
     63 
     64 TEST(TensorShapeTest, RemoveAndAddDim) {
     65   TensorShape s({10, 5, 20});
     66   s.RemoveDim(1);
     67   s.AddDim(100);
     68 
     69   EXPECT_EQ(20000, s.num_elements());
     70   ASSERT_EQ(3, s.dims());
     71 }
     72 
     73 TEST(TensorShapeTest, RemoveLastDims) {
     74   TensorShape s({2, 3, 5, 7});
     75   s.RemoveLastDims(1);
     76 
     77   ASSERT_EQ(3, s.dims());
     78   EXPECT_EQ(30, s.num_elements());
     79 
     80   s.RemoveLastDims(2);
     81   ASSERT_EQ(1, s.dims());
     82   EXPECT_EQ(2, s.dim_size(0));
     83 }
     84 
     85 TEST(TensorShapeTest, RemoveDimRange) {
     86   TensorShape s0({2, 3, 5, 7});
     87   // Empty interval => noop.
     88   for (int i = -4; i <= 4; ++i) {
     89     s0.RemoveDimRange(i, i);
     90     ASSERT_EQ(4, s0.dims());
     91     ASSERT_EQ(210, s0.num_elements());
     92   }
     93 
     94   // Positive begin and end.
     95   s0.RemoveDimRange(3, 1);  // Empty interval.
     96   ASSERT_EQ(4, s0.dims());
     97   ASSERT_EQ(210, s0.num_elements());
     98   s0.RemoveDimRange(0, 3);
     99   ASSERT_EQ(1, s0.dims());
    100   EXPECT_EQ(7, s0.dim_size(0));
    101   TensorShape s1({2, 3, 5, 7});
    102   s1.RemoveDimRange(2, 3);
    103   ASSERT_EQ(3, s1.dims());
    104   ASSERT_EQ(42, s1.num_elements());
    105 
    106   // Negative begin or end.
    107   TensorShape s2({2, 3, 5, 7});
    108   s2.RemoveDimRange(-2, -3);  // Empty interval.
    109   ASSERT_EQ(4, s2.dims());
    110   ASSERT_EQ(210, s2.num_elements());
    111   s2.RemoveDimRange(0, -2);
    112   ASSERT_EQ(1, s2.dims());
    113   ASSERT_EQ(7, s2.dim_size(0));
    114   TensorShape s3({2, 3, 5, 7});
    115   s3.RemoveDimRange(-3, -2);
    116   ASSERT_EQ(3, s3.dims());
    117   ASSERT_EQ(42, s3.num_elements());
    118 }
    119 
    120 TEST(TensorShapeTest, InvalidShapeProto) {
    121   TensorShapeProto proto;
    122   EXPECT_TRUE(TensorShape::IsValid(proto));
    123 
    124   proto.add_dim()->set_size(357);
    125   proto.add_dim()->set_size(982);
    126   EXPECT_TRUE(TensorShape::IsValid(proto));
    127 
    128   proto.Clear();
    129   proto.add_dim()->set_size(-357);
    130   proto.add_dim()->set_size(-982);
    131   EXPECT_FALSE(TensorShape::IsValid(proto));
    132 
    133   proto.Clear();
    134   proto.add_dim()->set_size(1LL << 35);
    135   proto.add_dim()->set_size((1LL << 35) + 1);
    136   EXPECT_FALSE(TensorShape::IsValid(proto));
    137 }
    138 
    139 TEST(TensorShapeTest, TooManyDimsProto) {
    140   TensorShapeProto proto;
    141   // Deliberate redundancy to ensure that both paths work.
    142   EXPECT_TRUE(TensorShape::IsValid(proto));
    143   TF_EXPECT_OK(TensorShape::IsValidShape(proto));
    144   for (int i = 0; i < TensorShape::MaxDimensions(); i++) {
    145     proto.add_dim()->set_size(1);
    146   }
    147   EXPECT_TRUE(TensorShape::IsValid(proto));
    148   TF_EXPECT_OK(TensorShape::IsValidShape(proto));
    149   proto.add_dim()->set_size(1);
    150   EXPECT_FALSE(TensorShape::IsValid(proto));
    151   EXPECT_FALSE(TensorShape::IsValidShape(proto).ok());
    152 }
    153 
    154 TEST(TensorShapeTest, SetDimForEmptyTensor) {
    155   TensorShape s({10, 5, 20});
    156   EXPECT_EQ(1000, s.num_elements());
    157   s.set_dim(1, 0);
    158   EXPECT_EQ(0, s.num_elements());
    159   s.set_dim(1, 7);
    160   EXPECT_EQ(1400, s.num_elements());
    161 }
    162 
    163 TEST(TensorShapeTest, AppendShape64BitIndices) {
    164   TensorShape s({10, 2147483648});
    165 
    166   EXPECT_EQ(10, s.dim_size(0));
    167   EXPECT_EQ(2147483648, s.dim_size(1));
    168 
    169   TensorShape s2;
    170   s2.AppendShape(s);
    171   EXPECT_EQ(10, s2.dim_size(0));
    172   EXPECT_EQ(2147483648, s2.dim_size(1));
    173 }
    174 
    175 TEST(TensorShapeTest, DataType) {
    176   TensorShape s({});
    177   EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_INVALID);
    178   TensorShapeTestHelper::set_data_type(&s, DT_INT32);
    179   s.AddDim(1);
    180   EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_INT32);
    181   s.AddDim(100000);
    182   EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_INT32);
    183   TensorShapeTestHelper::set_data_type(&s, DT_UINT16_REF);
    184   s.AddDim(2);
    185   EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_UINT16_REF);
    186   s.AddDim(4);
    187   EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_UINT16_REF);
    188   s.AddDim(3);
    189   EXPECT_EQ(TensorShapeTestHelper::data_type(&s), DT_UINT16_REF);
    190 
    191   TensorShape s2 = s;
    192   EXPECT_EQ(TensorShapeTestHelper::data_type(&s2), DT_UINT16_REF);
    193   s2.RemoveDim(2);
    194   EXPECT_EQ(TensorShapeTestHelper::data_type(&s2), DT_UINT16_REF);
    195   TensorShapeTestHelper::set_data_type(&s2, DT_FLOAT);
    196   EXPECT_EQ(TensorShapeTestHelper::data_type(&s2), DT_FLOAT);
    197   s2.Clear();
    198   EXPECT_EQ(TensorShapeTestHelper::data_type(&s2), DT_INVALID);
    199 }
    200 
    201 TEST(TensorShapeTest, ostream) {
    202   TensorShape s({10, 5, 4});
    203   std::stringstream ss;
    204   ss << s;
    205   EXPECT_EQ(ss.str(), "[10,5,4]");
    206 }
    207 
    208 // -----------------------------------------------------------------------
    209 // An old implementation of TensorShape using a different representation,
    210 // preserved here in the unittest to allow us to have a randomized unittest
    211 // that makes sure the behavior of TensorShape and TensorShapeOld are
    212 // the same.
    213 class TensorShapeIterOld;  // Declared below
    214 
    215 /// Manages the dimensions of a Tensor and their sizes.
    216 class TensorShapeOld {
    217  public:
    218   /// \brief Construct a `TensorShape` from the provided sizes.
    219   /// REQUIRES: `dim_sizes[i] >= 0`
    220   explicit TensorShapeOld(gtl::ArraySlice<int64> dim_sizes);
    221   TensorShapeOld(std::initializer_list<int64> dim_sizes)
    222       : TensorShapeOld(gtl::ArraySlice<int64>(dim_sizes)) {}
    223 
    224   /// REQUIRES: `IsValid(proto)`
    225   explicit TensorShapeOld(const TensorShapeProto& proto);
    226 
    227   /// Create a tensor shape with no dimensions and one element, which you can
    228   /// then call `AddDim()` on.
    229   TensorShapeOld();
    230 
    231   /// Returns `true` iff `proto` is a valid tensor shape.
    232   static bool IsValid(const TensorShapeProto& proto);
    233 
    234   /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
    235   /// status otherwise.
    236   static Status IsValidShape(const TensorShapeProto& proto);
    237 
    238   /// Clear a tensor shape
    239   void Clear();
    240 
    241   /// \brief Add a dimension to the end ("inner-most").
    242   /// REQUIRES: `size >= 0`
    243   void AddDim(int64 size);
    244 
    245   /// Appends all the dimensions from `shape`.
    246   void AppendShape(const TensorShapeOld& shape);
    247 
    248   /// \brief Insert a dimension somewhere in the `TensorShape`.
    249   /// REQUIRES: `0 <= d <= dims()`
    250   /// REQUIRES: `size >= 0`
    251   void InsertDim(int d, int64 size);
    252 
    253   /// \brief Modifies the size of the dimension `d` to be `size`
    254   /// REQUIRES: `0 <= d < dims()`
    255   /// REQUIRES: `size >= 0`
    256   void set_dim(int d, int64 size);
    257 
    258   /// \brief Removes dimension `d` from the `TensorShape`.
    259   /// REQUIRES: `0 <= d < dims()`
    260   void RemoveDim(int d);
    261 
    262   /// Return the number of dimensions in the tensor.
    263   int dims() const { return dim_sizes_.size(); }
    264 
    265   /// \brief Returns the number of elements in dimension `d`.
    266   /// REQUIRES: `0 <= d < dims()`
    267   // TODO(touts): Rename to `dimension()` to match
    268   // `Eigen::Tensor::dimension()`?
    269   int64 dim_size(int d) const {
    270     DCHECK_GE(d, 0);
    271     DCHECK_LT(d, dims());
    272     return dim_sizes_[d];
    273   }
    274 
    275   /// Returns sizes of all dimensions.
    276   gtl::ArraySlice<int64> dim_sizes() const { return dim_sizes_; }
    277 
    278   /// \brief Returns the number of elements in the tensor.
    279   ///
    280   /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
    281   /// which uses `ptrdiff_t`.
    282   int64 num_elements() const { return num_elements_; }
    283 
    284   /// Returns true if `*this` and `b` have the same sizes. Ignores
    285   /// dimension names.
    286   bool IsSameSize(const TensorShapeOld& b) const;
    287   bool operator==(const TensorShapeOld& b) const { return IsSameSize(b); }
    288 
    289   /// Fill `*proto` from `*this`.
    290   void AsProto(TensorShapeProto* proto) const;
    291 
    292   /// Fill `*dsizes` from `*this`.
    293   template <int NDIMS>
    294   Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizes() const;
    295 
    296   /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
    297   /// which case we pad the rest of the sizes with 1.
    298   template <int NDIMS>
    299   Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizesWithPadding() const;
    300 
    301   /// For iterating through the dimensions.
    302   TensorShapeIterOld begin() const;
    303   TensorShapeIterOld end() const;
    304 
    305   /// For error messages.
    306   string DebugString() const;
    307 
    308   /// Same as `TensorShape(proto).DebugString()` but doesn't crash for
    309   /// invalid protos.
    310   static string DebugString(const TensorShapeProto& proto);
    311 
    312  private:
    313   // Recalculates the dimensions of this tensor after they are modified.
    314   void recompute_dims();
    315 
    316   // TODO(josh11b): Maybe use something from the Eigen Tensor library
    317   // for the sizes.
    318   gtl::InlinedVector<int64, 4> dim_sizes_;
    319 
    320   // total number of elements (avoids recomputing it each time).
    321   int64 num_elements_;
    322 };
    323 
    324 struct TensorShapeDimOld {
    325   explicit TensorShapeDimOld(int64 s) : size(s) {}
    326   int64 size;
    327 };
    328 
    329 class TensorShapeIterOld {
    330  public:
    331   TensorShapeIterOld(const TensorShapeOld* shape, int d)
    332       : shape_(shape), d_(d) {}
    333   bool operator==(const TensorShapeIterOld& rhs) {
    334     DCHECK(shape_ == rhs.shape_);
    335     return d_ == rhs.d_;
    336   }
    337   bool operator!=(const TensorShapeIterOld& rhs) {
    338     DCHECK(shape_ == rhs.shape_);
    339     return d_ != rhs.d_;
    340   }
    341   void operator++() { ++d_; }
    342   TensorShapeDimOld operator*() {
    343     return TensorShapeDimOld(shape_->dim_size(d_));
    344   }
    345 
    346  private:
    347   const TensorShapeOld* shape_;
    348   int d_;
    349 };
    350 
    351 // An upper limit of the total number of elements in a tensor.
    352 static const int64 kMaxElements = (1LL << 40);
    353 
    354 bool TensorShapeOld::IsValid(const TensorShapeProto& proto) {
    355   int64 num_elements = 1;
    356   for (const auto& d : proto.dim()) {
    357     if (d.size() < 0) return false;
    358     num_elements *= d.size();
    359     if (num_elements > kMaxElements) return false;
    360   }
    361   return true;
    362 }
    363 
    364 Status TensorShapeOld::IsValidShape(const TensorShapeProto& proto) {
    365   int64 num_elements = 1;
    366   for (const auto& d : proto.dim()) {
    367     if (d.size() < 0) {
    368       return errors::InvalidArgument("Shape ", DebugString(proto),
    369                                      " has negative dimensions; ",
    370                                      "perhaps an un-fed placeholder?");
    371     }
    372     num_elements *= d.size();
    373     if (num_elements > kMaxElements) {
    374       return errors::InvalidArgument("Shape ", DebugString(proto),
    375                                      " is too large (more than ", kMaxElements,
    376                                      " entries)");
    377     }
    378   }
    379   return Status::OK();
    380 }
    381 
    382 TensorShapeOld::TensorShapeOld(const TensorShapeProto& proto) {
    383   dim_sizes_.reserve(proto.dim_size());
    384   num_elements_ = 1;
    385   for (const auto& d : proto.dim()) {
    386     AddDim(d.size());
    387   }
    388 }
    389 
    390 TensorShapeOld::TensorShapeOld(gtl::ArraySlice<int64> dim_sizes) {
    391   dim_sizes_.reserve(dim_sizes.size());
    392   num_elements_ = 1;
    393   for (auto s : dim_sizes) {
    394     AddDim(s);
    395   }
    396 }
    397 
    398 TensorShapeOld::TensorShapeOld() : num_elements_(1) {}
    399 
    400 void TensorShapeOld::Clear() {
    401   dim_sizes_.clear();
    402   num_elements_ = 1;
    403 }
    404 
    405 void TensorShapeOld::AddDim(int64 size) {
    406   CHECK_GE(size, 0);
    407   dim_sizes_.push_back(size);
    408   num_elements_ *= size;
    409   CHECK_LE(0, num_elements_);
    410   CHECK_LE(num_elements_, kMaxElements);
    411 }
    412 
    413 void TensorShapeOld::AppendShape(const TensorShapeOld& shape) {
    414   for (auto d : shape) AddDim(d.size);
    415 }
    416 
    417 void TensorShapeOld::InsertDim(int d, int64 size) {
    418   CHECK_GE(d, 0);
    419   CHECK_LE(d, dims());
    420   CHECK_GE(size, 0);
    421   dim_sizes_.insert(dim_sizes_.begin() + d, size);
    422   num_elements_ *= size;
    423   CHECK_LE(0, num_elements_);
    424   CHECK_LE(num_elements_, kMaxElements);
    425 }
    426 
    427 void TensorShapeOld::set_dim(int d, int64 size) {
    428   CHECK_GE(d, 0);
    429   CHECK_LT(d, dims());
    430   CHECK_GE(size, 0);
    431 
    432   // Update the number of elements. num_elements_ is int64.
    433   dim_sizes_[d] = size;
    434   recompute_dims();
    435 }
    436 
    437 void TensorShapeOld::RemoveDim(int d) {
    438   CHECK_GE(d, 0);
    439   CHECK_LT(d, dims());
    440 
    441   // Update the number of elements and remove the dimension from the
    442   // sizes.
    443   dim_sizes_.erase(dim_sizes_.begin() + d);
    444   recompute_dims();
    445 }
    446 
    447 void TensorShapeOld::recompute_dims() {
    448   num_elements_ = 1;
    449   for (auto s : dim_sizes_) {
    450     num_elements_ *= s;
    451     CHECK_LE(0, num_elements_);
    452     CHECK_LE(num_elements_, kMaxElements);
    453   }
    454 }
    455 
    456 bool TensorShapeOld::IsSameSize(const TensorShapeOld& b) const {
    457   if (b.dims() != dims()) return false;
    458   for (int d = 0; d < dims(); d++) {
    459     if (dim_size(d) != b.dim_size(d)) return false;
    460   }
    461   return true;
    462 }
    463 
    464 void TensorShapeOld::AsProto(TensorShapeProto* proto) const {
    465   proto->Clear();
    466   for (size_t d = 0; d < dim_sizes_.size(); ++d) {
    467     auto* dim = proto->add_dim();
    468     dim->set_size(dim_sizes_[d]);
    469   }
    470 }
    471 
    472 TensorShapeIterOld TensorShapeOld::begin() const {
    473   return TensorShapeIterOld(this, 0);
    474 }
    475 
    476 TensorShapeIterOld TensorShapeOld::end() const {
    477   return TensorShapeIterOld(this, dims());
    478 }
    479 
    480 string TensorShapeOld::DebugString() const {
    481   return strings::StrCat(
    482       "[", str_util::Join(gtl::ArraySlice<int64>(dim_sizes_), ","), "]");
    483 }
    484 
    485 string TensorShapeOld::DebugString(const TensorShapeProto& proto) {
    486   string s = "[";
    487   bool first = true;
    488   for (const auto& d : proto.dim()) {
    489     strings::StrAppend(&s, first ? "" : ",", d.size());
    490     first = false;
    491   }
    492   strings::StrAppend(&s, "]");
    493   return s;
    494 }
    495 // End of old implementation
    496 // ------------------------------------------------------------------------
    497 
    498 static int64 SkewedSize(random::SimplePhilox* gen, int64 current_elements) {
    499   int64 result = 0;
    500   do {
    501     if (current_elements < 100) {
    502       result = gen->Uniform(100000);
    503     } else {
    504       result = gen->Uniform(2);
    505     }
    506   } while ((result * current_elements >= 1LL << 34) ||
    507            (result * current_elements < 0));
    508   return result;
    509 }
    510 
    511 TEST(TensorShapeTest, Randomized) {
    512   // We do a randomized test to verify that the behavior of the
    513   // TensorShape implementation (which changes representations depending
    514   // on the values) is identical to our older, more straightforward (but
    515   // more memory hungry) implementation (TensorShapeOld).
    516   random::PhiloxRandom philox(7, 7);
    517   random::SimplePhilox gen(&philox);
    518   TensorShape s;
    519   TensorShapeOld sold;
    520   TensorShapeProto sp;
    521   TensorShapeProto spold;
    522   LOG(INFO) << "Sizes: " << sizeof(TensorShape) << " vs "
    523             << sizeof(TensorShapeOld);
    524   for (int i = 0; i < 100000; i++) {
    525     s.AsProto(&sp);
    526     sold.AsProto(&spold);
    527     EXPECT_EQ(sp.DebugString(), spold.DebugString());
    528     if ((i % 1000) == 0) {
    529       fprintf(stderr, "ITERATION %d: %s\n", i, sp.DebugString().c_str());
    530     }
    531     EXPECT_EQ(s.num_elements(), sold.num_elements());
    532 
    533     // Test moves.
    534     TensorShape copy = s;
    535     TensorShape moved(std::move(copy));
    536     EXPECT_EQ(s, moved);
    537     copy = s;
    538     moved = std::move(copy);
    539     EXPECT_EQ(s, moved);
    540 
    541     int64 ne = sold.num_elements();
    542     int r = gen.Uniform(100);
    543     if (r < 10) {
    544       int64 sz = SkewedSize(&gen, sold.num_elements());
    545       s.AddDim(sz);
    546       sold.AddDim(sz);
    547     } else if (r < 15) {
    548       s.Clear();
    549       sold.Clear();
    550     } else if (r < 35 && s.dims() > 0 && ne > 0 && ne < 100000000) {
    551       int dim = gen.Uniform(s.dims());
    552       s.RemoveDim(dim);
    553       sold.RemoveDim(dim);
    554     } else if (r < 50 && ne > 0 && ne < 100000000) {
    555       int dim = gen.Uniform(s.dims() + 1);
    556       int64 sz = SkewedSize(&gen, sold.num_elements());
    557       s.InsertDim(dim, sz);
    558       sold.InsertDim(dim, sz);
    559     } else {
    560       std::vector<int64> sizes;
    561       const int N = (gen.Uniform(4) == 0) ? gen.Uniform(10) : gen.Uniform(3);
    562       int64 num_elements = 1;
    563       for (int i = 0; i < N; i++) {
    564         int64 sz = SkewedSize(&gen, num_elements);
    565         sizes.push_back(sz);
    566         num_elements *= std::max<int64>(1, sz);
    567       }
    568 
    569       s = TensorShape(sizes);
    570       sold = TensorShapeOld(sizes);
    571     }
    572   }
    573 }
    574 
    575 TEST(TensorShapeTest, Large) {
    576   // We used to cap shapes at 2**40 elements.  Ensure the
    577   // bound is now higher.
    578   int64 one = 1;
    579   int64 max = std::numeric_limits<int64>::max();
    580   EXPECT_EQ(TensorShape({max}).num_elements(), max);
    581   EXPECT_EQ(TensorShape({1, max}).num_elements(), max);
    582   EXPECT_EQ(TensorShape({max, 1}).num_elements(), max);
    583   EXPECT_EQ(TensorShape({one << 62}).num_elements(), one << 62);
    584   EXPECT_EQ(TensorShape({one << 20, one << 41}).num_elements(), one << 61);
    585   EXPECT_EQ(TensorShape({1000, 1000, 1000, 1000, 1000, 1000}).num_elements(),
    586             1e18);
    587 }
    588 
    589 TEST(TensorShapeTest, Overflow) {
    590   int64 one = 1;
    591   std::vector<std::vector<int64>> overflows = {
    592       {1 << 30, 1 << 30, 1 << 30},
    593       {1 << 5, (one << 60) + 1},
    594   };
    595   for (const auto& overflow : overflows) {
    596     TensorShapeProto proto;
    597     for (auto dim : overflow) {
    598       proto.add_dim()->set_size(dim);
    599     }
    600     EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT,
    601               TensorShape::IsValidShape(proto).code());
    602     TensorShape shape;
    603     EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT,
    604               TensorShapeUtils::MakeShape(overflow, &shape).code());
    605   }
    606 }
    607 
    608 TEST(TensorShapeTest, UnknownRank) {
    609   // NOTE(irving): Unfortunately, for historical reasons we have to allow an
    610   // TensorShapeProto with unknown_rank() set to be parsed as a TensorShape.
    611   // Would be nice to tighten this, but it's tricky given backwards
    612   // compatibility requirements.
    613   TensorShapeProto proto;
    614   proto.set_unknown_rank(true);
    615   EXPECT_TRUE(TensorShape::IsValid(proto));
    616   TF_EXPECT_OK(TensorShape::IsValidShape(proto));
    617   EXPECT_EQ(TensorShape(), TensorShape(proto));
    618 
    619   proto.add_dim()->set_size(7);
    620   EXPECT_TRUE(TensorShape::IsValid(proto));
    621   TF_EXPECT_OK(TensorShape::IsValidShape(proto));
    622   EXPECT_EQ(TensorShape({7}), TensorShape(proto));
    623 }
    624 
    625 TEST(TensorShapeUtilsTest, StartsWith) {
    626   EXPECT_TRUE(TensorShapeUtils::StartsWith(TensorShape({}), TensorShape({})));
    627   EXPECT_TRUE(
    628       TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({})));
    629   EXPECT_TRUE(
    630       TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({2})));
    631   EXPECT_TRUE(
    632       TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({2, 3})));
    633   EXPECT_TRUE(TensorShapeUtils::StartsWith(TensorShape({2, 3, 4}),
    634                                            TensorShape({2, 3})));
    635   EXPECT_FALSE(
    636       TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({3})));
    637   EXPECT_FALSE(
    638       TensorShapeUtils::StartsWith(TensorShape({2, 3}), TensorShape({2, 4})));
    639   EXPECT_FALSE(TensorShapeUtils::StartsWith(TensorShape({2, 3}),
    640                                             TensorShape({2, 3, 4})));
    641   EXPECT_FALSE(TensorShapeUtils::StartsWith(TensorShape({2, 3, 4}),
    642                                             TensorShape({3, 4})));
    643 }
    644 
    645 TEST(TensorShapeUtilsTest, EndsWith) {
    646   EXPECT_TRUE(TensorShapeUtils::EndsWith(TensorShape({}), TensorShape({})));
    647   EXPECT_TRUE(TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({})));
    648   EXPECT_TRUE(
    649       TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({3})));
    650   EXPECT_TRUE(
    651       TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({2, 3})));
    652   EXPECT_TRUE(
    653       TensorShapeUtils::EndsWith(TensorShape({2, 3, 4}), TensorShape({3, 4})));
    654   EXPECT_FALSE(
    655       TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({2})));
    656   EXPECT_FALSE(
    657       TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({2, 4})));
    658   EXPECT_FALSE(
    659       TensorShapeUtils::EndsWith(TensorShape({2, 3}), TensorShape({2, 3, 4})));
    660   EXPECT_FALSE(
    661       TensorShapeUtils::EndsWith(TensorShape({2, 3, 4}), TensorShape({2, 3})));
    662 }
    663 
    664 // A few different test cases for tensor sizes for benchmarks
    665 static std::vector<int64> MakeSizes(int arg) {
    666   std::vector<int64> sizes;
    667   switch (arg) {
    668     case 0:
    669       sizes = {100};
    670       break;
    671     case 1:
    672       sizes = {100, 1000};
    673       break;
    674     case 2:
    675       sizes = {100, 1000000};
    676       break;
    677     case 3:
    678       sizes = {100, 256, 192, 3};
    679       break;
    680     case 4:
    681       sizes = {1, 2, 1ll << 34, 1, 1, 1};
    682       break;
    683   }
    684   return sizes;
    685 }
    686 
    687 static void BM_TensorShape_Init(int iters, int arg) {
    688   auto sizes = MakeSizes(arg);
    689   while (--iters > 0) {
    690     TensorShape shape(sizes);
    691     tensorflow::testing::DoNotOptimize(shape.num_elements());
    692   }
    693 }
    694 BENCHMARK(BM_TensorShape_Init)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4);
    695 
    696 static void BM_TensorShape_Assign(int iters, int arg) {
    697   TensorShape s(MakeSizes(arg));
    698   while (--iters > 0) {
    699     TensorShape s2 = s;
    700   }
    701 }
    702 BENCHMARK(BM_TensorShape_Assign)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4);
    703 
    704 }  // namespace
    705 }  // namespace tensorflow
    706