Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/tensor_shape.h"
     17 
     18 #include "tensorflow/core/framework/tensor_shape.pb.h"
     19 #include "tensorflow/core/kernels/bounds_check.h"
     20 #include "tensorflow/core/lib/core/errors.h"
     21 #include "tensorflow/core/lib/strings/str_util.h"
     22 #include "tensorflow/core/lib/strings/strcat.h"
     23 #include "tensorflow/core/platform/logging.h"
     24 #include "tensorflow/core/util/overflow.h"
     25 
     26 namespace tensorflow {
     27 
     28 // TensorShape and PartialTensorShape should have no fields beyond
     29 // TensorShapeRep.  In particular, their sizes should be the same.
     30 static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape),
     31               "TensorShape must have no fields beyond TensorShapeRep");
     32 static_assert(sizeof(TensorShapeRep) == sizeof(PartialTensorShape),
     33               "PartialTensorShape must have no fields beyond TensorShapeRep");
     34 
     35 template <class Shape>
     36 static void AppendTo(const TensorShapeBase<Shape>& s,
     37                      gtl::InlinedVector<int64, 8>* vals) {
     38   for (auto dim : s) {
     39     vals->push_back(dim.size);
     40   }
     41 }
     42 
     43 void TensorShape::CheckDimsEqual(int NDIMS) const {
     44   CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS << " dimensions"
     45                           << " from a tensor of " << dims() << " dimensions";
     46 }
     47 
     48 void TensorShape::CheckDimsAtLeast(int NDIMS) const {
     49   CHECK_GE(NDIMS, dims()) << "Asking for tensor of at least " << NDIMS
     50                           << " dimensions from a tensor of " << dims()
     51                           << " dimensions";
     52 }
     53 
     54 template <class Shape>
     55 bool TensorShapeBase<Shape>::IsValid(const TensorShapeProto& proto) {
     56   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
     57   // unknown_shape() set, and it seems hard to remove this without backwards
     58   // compatibility issues.
     59   if (kIsPartial && proto.unknown_rank()) return proto.dim_size() == 0;
     60   int64 num_elements = 1;
     61   if (proto.dim().size() > MaxDimensions()) return false;
     62   for (const auto& d : proto.dim()) {
     63     if (d.size() < (kIsPartial ? -1 : 0)) return false;
     64     if (d.size() == -1) {
     65       num_elements = -1;
     66     } else if (!kIsPartial || num_elements >= 0) {
     67       num_elements = MultiplyWithoutOverflow(num_elements, d.size());
     68       if (num_elements < 0) return false;
     69     }
     70   }
     71   return true;
     72 }
     73 
     74 template <class Shape>
     75 Status TensorShapeBase<Shape>::IsValidShape(const TensorShapeProto& proto) {
     76   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
     77   // unknown_shape() set, and it seems hard to remove this without backwards
     78   // compatibility issues.
     79   if (kIsPartial && proto.unknown_rank()) {
     80     if (proto.dim_size() > 0) {
     81       return errors::InvalidArgument(
     82           "An unknown shape must not have any dimensions set.");
     83     }
     84     return Status::OK();
     85   }
     86   int64 num_elements = 1;
     87   if (proto.dim().size() > MaxDimensions()) {
     88     return errors::InvalidArgument("Shape ", DebugString(proto),
     89                                    " has too many dimensions");
     90   }
     91   for (const auto& d : proto.dim()) {
     92     if (d.size() < (kIsPartial ? -1 : 0)) {
     93       if (kIsPartial) {
     94         return errors::InvalidArgument(
     95             "Shape ", DebugString(proto),
     96             " has dimensions with values below -1 (where -1 means unknown)");
     97       } else {
     98         return errors::InvalidArgument("Shape ", DebugString(proto),
     99                                        " is not fully defined");
    100       }
    101     }
    102     if (d.size() == -1) {
    103       num_elements = -1;
    104     } else if (!kIsPartial || num_elements >= 0) {
    105       num_elements = MultiplyWithoutOverflow(num_elements, d.size());
    106       if (num_elements < 0) {
    107         return errors::InvalidArgument(
    108             "Shape ", DebugString(proto),
    109             " is too large (more than 2**63 - 1 entries)");
    110       }
    111     }
    112   }
    113   return Status::OK();
    114 }
    115 
    116 template <class Shape>
    117 TensorShapeBase<Shape>::TensorShapeBase(const TensorShapeProto& proto) {
    118   set_tag(REP16);
    119   set_data_type(DT_INVALID);
    120   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
    121   // unknown_shape() set, and it seems hard to remove this without backwards
    122   // compatibility issues.
    123   if (kIsPartial && proto.unknown_rank()) {
    124     set_ndims_byte(kUnknownRank);
    125     set_num_elements(-1);
    126   } else {
    127     set_ndims_byte(0);
    128     set_num_elements(1);
    129     for (const auto& d : proto.dim()) {
    130       AddDim(d.size());
    131     }
    132   }
    133 }
    134 
    135 template <class Shape>
    136 TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64> dim_sizes) {
    137   set_tag(REP16);
    138   set_data_type(DT_INVALID);
    139   set_ndims_byte(0);
    140   set_num_elements(1);
    141   for (int64 s : dim_sizes) {
    142     AddDim(internal::SubtleMustCopy(s));
    143   }
    144 }
    145 
    146 template <class Shape>
    147 TensorShapeBase<Shape>::TensorShapeBase() {
    148   set_tag(REP16);
    149   set_data_type(DT_INVALID);
    150   if (kIsPartial) {
    151     set_ndims_byte(kUnknownRank);
    152     set_num_elements(-1);
    153   } else {
    154     set_ndims_byte(0);
    155     set_num_elements(1);
    156   }
    157 }
    158 
    159 void TensorShapeRep::DestructorOutOfLine() {
    160   DCHECK(tag() == REP_OUT_OF_LINE);
    161   delete as64()->dims_;
    162 }
    163 
// Copies `b`'s dimensions, rank byte, data type and tag into this rep. This
// is the path used when either side may hold the heap-allocated out-of-line
// representation, so a plain memcpy of the buffer is not enough.
//
// NOTE(review): num_elements is not written here; presumably the caller
// (copy construction/assignment, not visible in this file) copies it. Confirm.
void TensorShapeRep::SlowCopyFrom(const TensorShapeRep& b) {
  if (b.tag() != REP_OUT_OF_LINE) {
    // Source is inline. Free our own heap vector (if any) before the memcpy
    // below overwrites the only pointer to it.
    if (tag() == REP_OUT_OF_LINE) {
      delete as64()->dims_;
    }
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    //   set_tag(b.tag());
    //   set_ndims_byte(b.ndims_byte());
    //   set_data_type(b.data_type());
  } else {
    // Source is out-of-line: deep-copy its heap vector rather than sharing
    // the pointer.
    DCHECK_EQ(b.tag(), REP_OUT_OF_LINE);
    set_ndims_byte(b.ndims_byte());
    set_data_type(b.data_type());
    if (tag() == REP_OUT_OF_LINE) {
      // vector already allocated
      *(as64()->dims_) = *(b.as64()->dims_);
    } else {
      // Switch to out-of-line and allocate our own copy of b's dimensions.
      set_tag(REP_OUT_OF_LINE);
      as64()->dims_ = new gtl::InlinedVector<int64, 4>(*(b.as64()->dims_));
    }
  }
}
    187 
    188 template <class Shape>
    189 int64 TensorShapeBase<Shape>::dim_size(int d) const {
    190   if (unknown_rank()) return -1;
    191   DCHECK_GE(d, 0);
    192   DCHECK_LT(d, dims());
    193   if (tag() == REP16) {
    194     uint16 dim = as16()->dims_[d];
    195     if (kIsPartial && dim == kUnknownRep16) return -1;
    196     return dim;
    197   } else if (tag() == REP32) {
    198     uint32 dim = as32()->dims_[d];
    199     if (kIsPartial && dim == kUnknownRep32) return -1;
    200     return dim;
    201   } else {
    202     return (*as64()->dims_)[d];
    203   }
    204 }
    205 
    206 void TensorShapeRep::Clear() {
    207   ClearAllButDataType();
    208   set_data_type(DT_INVALID);
    209 }
    210 
    211 void TensorShapeRep::ClearAllButDataType() {
    212   if (tag() == REP_OUT_OF_LINE) {
    213     delete as64()->dims_;
    214   }
    215   set_tag(REP16);
    216   set_ndims_byte(0);
    217   // Leaves data_type alone
    218   set_num_elements(1);
    219 }
    220 
    221 template <class Shape>
    222 void TensorShapeBase<Shape>::RecomputeNumElements() {
    223   if (unknown_rank()) {
    224     set_num_elements(-1);
    225     return;
    226   }
    227   int64 n = 1;
    228   for (auto dim : *this) {
    229     if (kIsPartial && dim.size < 0) {
    230       n = -1;
    231       break;
    232     }
    233     n = MultiplyWithoutOverflow(n, dim.size);
    234     CHECK_LE(0, n);
    235   }
    236   set_num_elements(n);
    237 }
    238 
    239 template <class Shape>
    240 void TensorShapeBase<Shape>::AddDim(int64 size) {
    241   if (!kIsPartial) CHECK_GE(size, 0);
    242   if (unknown_rank()) return;
    243   CHECK_LT(ndims_byte(), MaxDimensions()) << "Too many dimensions in tensor";
    244   int64 new_num_elements;
    245   if (kIsPartial && (num_elements() < 0 || size < 0)) {
    246     new_num_elements = -1;
    247   } else {
    248     new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
    249     CHECK_LE(0, new_num_elements);
    250   }
    251   UnsafeAddDim(size, new_num_elements);
    252 }
    253 
// Appends a dimension of `size` and sets num_elements() to `new_num_elements`
// WITHOUT validating either value. Callers in this file (AddDim,
// MakeShapeHelper) do the range and overflow checks first.
//
// Keeps the most compact representation that can hold the result:
//   REP16           up to 6 dims, each below kMaxRep16
//   REP32           up to 3 dims, each below kMaxRep32
//   REP_OUT_OF_LINE heap-allocated int64 vector otherwise
// For PartialTensorShape a negative `size` is stored as kUnknownRep16 or
// kUnknownRep32 in the inline forms, and as the raw value out of line. This
// function only ever widens the representation, migrating existing dims when
// it has to.
template <class Shape>
void TensorShapeBase<Shape>::UnsafeAddDim(int64 size, int64 new_num_elements) {
  const int nd = ndims_byte();
  if (tag() == REP16 && nd < 6 && size < kMaxRep16) {
    // Fits in the existing 16-bit inline slots.
    as16()->dims_[nd] =
        kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
  } else if (tag() == REP32 && nd < 3 && size < kMaxRep32) {
    // Fits in the existing 32-bit inline slots.
    as32()->dims_[nd] =
        kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
  } else if (tag() == REP_OUT_OF_LINE) {
    as64()->dims_->push_back(size);
  } else {
    // Need to change representation
    gtl::InlinedVector<int64, 8> vals;
    AppendTo(*this, &vals);
    vals.push_back(size);
    // We know we can't be REP16.  See if we have a small enough
    // number of dimensions and each dimension's size is small enough
    // to allow REP32.
    bool can_be_rep32 = (vals.size() <= 3);
    if (can_be_rep32) {
      for (size_t i = 0; i < vals.size(); i++) {
        if (vals[i] >= kMaxRep32) {
          can_be_rep32 = false;
          break;
        }
      }
    }
    if (can_be_rep32) {
      set_tag(REP32);
      for (size_t d = 0; d < vals.size(); d++) {
        as32()->dims_[d] = kIsPartial && vals[d] < 0
                               ? kUnknownRep32
                               : static_cast<uint32>(vals[d]);
      }
    } else {
      // Widen to heap storage; ownership lives in as64()->dims_ and is
      // released by DestructorOutOfLine / ClearAllButDataType / SlowCopyFrom.
      set_tag(REP_OUT_OF_LINE);
      as64()->dims_ =
          new gtl::InlinedVector<int64, 4>(vals.begin(), vals.end());
    }
  }
  set_ndims_byte(nd + 1);
  set_num_elements(new_num_elements);
}
    298 
    299 template <class Shape>
    300 void TensorShapeBase<Shape>::AppendShape(const TensorShapeBase& shape) {
    301   for (auto d : shape) AddDim(d.size);
    302 }
    303 
    304 template <class Shape>
    305 void TensorShapeBase<Shape>::InsertDim(int d, int64 size) {
    306   CHECK_GE(d, 0);
    307   CHECK_LE(d, dims());
    308   if (!kIsPartial) CHECK_GE(size, 0);
    309   CHECK_LT(dims(), MaxDimensions());
    310   gtl::InlinedVector<int64, 8> vals;
    311   AppendTo(*this, &vals);
    312   vals.insert(vals.begin() + d, size);
    313   ClearAllButDataType();
    314   for (auto dval : vals) {
    315     AddDim(dval);
    316   }
    317 }
    318 
    319 template <class Shape>
    320 gtl::InlinedVector<int64, 4> TensorShapeBase<Shape>::dim_sizes() const {
    321   gtl::InlinedVector<int64, 4> result;
    322   for (auto dim : *this) {
    323     result.push_back(dim.size);
    324   }
    325   return result;
    326 }
    327 
    328 template <class Shape>
    329 void TensorShapeBase<Shape>::set_dim(int d, int64 size) {
    330   CHECK_GE(d, 0);
    331   CHECK_LT(d, dims());
    332   CHECK_GE(size, 0);
    333   if (tag() == REP16 && size < kMaxRep16) {
    334     as16()->dims_[d] =
    335         kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
    336   } else if (tag() == REP32 && size < kMaxRep32) {
    337     as32()->dims_[d] =
    338         kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
    339   } else if (tag() == REP_OUT_OF_LINE) {
    340     (*as64()->dims_)[d] = size;
    341   } else {
    342     // Must upgrade
    343     gtl::InlinedVector<int64, 8> vals;
    344     AppendTo(*this, &vals);
    345     vals[d] = size;
    346     ClearAllButDataType();
    347     for (auto dval : vals) {
    348       AddDim(dval);
    349     }
    350   }
    351   RecomputeNumElements();
    352 }
    353 
    354 template <class Shape>
    355 void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) {
    356   if (unknown_rank()) return;
    357   begin = begin < 0 ? dims() + begin + 1 : begin;
    358   end = end < 0 ? dims() + end + 1 : end;
    359   CHECK_GE(begin, 0);
    360   CHECK_LE(begin, dims());
    361   CHECK_GE(end, 0);
    362   CHECK_LE(end, dims());
    363   if (begin >= end) return;
    364   gtl::InlinedVector<int64, 8> vals;
    365   AppendTo(*this, &vals);
    366   vals.erase(vals.begin() + begin, vals.begin() + end);
    367   ClearAllButDataType();
    368   for (auto dval : vals) {
    369     AddDim(dval);
    370   }
    371   RecomputeNumElements();
    372 }
    373 
    374 bool TensorShape::IsSameSize(const TensorShape& b) const {
    375   if (b.dims() != dims()) return false;
    376   for (int d = 0; d < dims(); d++) {
    377     if (dim_size(d) != b.dim_size(d)) return false;
    378   }
    379   return true;
    380 }
    381 
    382 template <class Shape>
    383 void TensorShapeBase<Shape>::AsProto(TensorShapeProto* proto) const {
    384   proto->Clear();
    385   if (unknown_rank()) {
    386     proto->set_unknown_rank(true);
    387   } else {
    388     for (int i = 0; i < dims(); i++) {
    389       proto->add_dim()->set_size(dim_size(i));
    390     }
    391   }
    392 }
    393 
    394 void TensorShapeRep::DumpRep() const {
    395 #if 0
    396   fprintf(stderr, "Rep: %d %d dims\n", tag(), dims());
    397   if (tag() == REP16) {
    398     fprintf(stderr, "REP16 NDIMS: %d\n", ndims_byte());
    399     for (int i = 0; i < ndims_byte(); i++) {
    400       fprintf(stderr, "dim %d: %d\n", i, as16()->dims_[i]);
    401     }
    402   } else if (tag_ == REP32) {
    403     fprintf(stderr, "REP32 NDIMS: %d\n", ndims_);
    404     for (int i = 0; i < ndims_byte(); i++) {
    405       fprintf(stderr, "dim %d: %d\n", i, as32()->dims_[i]);
    406     }
    407   } else if (tag_ == REP_OUT_OF_LINE) {
    408     fprintf(stderr, "REP_OUT_OF_LINE NDIMS: %d %p\n", ndims_, as16()->dims_);
    409     for (int i = 0; i < ndims_byte(); i++) {
    410       fprintf(stderr, "dim %d: %lld\n", i, (*as64()->dims_)[i]);
    411     }
    412   }
    413 #endif
    414 }
    415 
    416 template <class Shape>
    417 TensorShapeIter<Shape> TensorShapeBase<Shape>::begin() const {
    418   return TensorShapeIter<Shape>(static_cast<const Shape*>(this), 0);
    419 }
    420 
    421 template <class Shape>
    422 TensorShapeIter<Shape> TensorShapeBase<Shape>::end() const {
    423   CHECK(!unknown_rank());
    424   return TensorShapeIter<Shape>(static_cast<const Shape*>(this), dims());
    425 }
    426 
    427 string TensorShapeRep::DebugString() const {
    428   const auto& shape = *static_cast<const PartialTensorShape*>(this);
    429   if (shape.unknown_rank()) return "<unknown>";
    430   string s = "[";
    431   for (int i = 0; i < shape.dims(); i++) {
    432     if (i > 0) strings::StrAppend(&s, ",");
    433     int64 dim = shape.dim_size(i);
    434     if (dim < 0) {
    435       strings::StrAppend(&s, "?");
    436     } else {
    437       strings::StrAppend(&s, dim);
    438     }
    439   }
    440   strings::StrAppend(&s, "]");
    441   return s;
    442 }
    443 
    444 string TensorShapeRep::DebugString(const TensorShapeProto& proto) {
    445   string s;
    446   if (proto.unknown_rank()) {
    447     strings::StrAppend(&s, "<unknown>");
    448     if (proto.dim_size() == 0) return s;
    449   }
    450   strings::StrAppend(&s, "[");
    451   bool first = true;
    452   for (const auto& d : proto.dim()) {
    453     if (!first) strings::StrAppend(&s, ",");
    454     if (d.size() == -1) {
    455       strings::StrAppend(&s, "?");
    456     } else {
    457       strings::StrAppend(&s, d.size());
    458     }
    459     first = false;
    460   }
    461   strings::StrAppend(&s, "]");
    462   return s;
    463 }
    464 
    465 bool TensorShapeUtils::StartsWith(const TensorShape& shape,
    466                                   const TensorShape& prefix) {
    467   if (shape.dims() < prefix.dims()) return false;
    468   for (int i = 0; i < prefix.dims(); ++i) {
    469     if (shape.dim_size(i) != prefix.dim_size(i)) return false;
    470   }
    471   return true;
    472 }
    473 
    474 bool TensorShapeUtils::EndsWith(const TensorShape& shape,
    475                                 const TensorShape& suffix) {
    476   const int suffix_size = suffix.dims();
    477   if (shape.dims() < suffix_size) return false;
    478   for (int i = 0; i < suffix_size; ++i) {
    479     if (shape.dim_size(shape.dims() - suffix_size + i) != suffix.dim_size(i)) {
    480       return false;
    481     }
    482   }
    483   return true;
    484 }
    485 
    486 template <typename T, class Shape>
    487 Status MakeShapeHelper(const T* dims, int64 n, Shape* out) {
    488   out->Clear();
    489   if (n > TensorShape::MaxDimensions()) {
    490     return errors::InvalidArgument("Too many dimensions");
    491   }
    492   if (n < 0) {
    493     return errors::InvalidArgument("Negative number of dimensions ", n);
    494   }
    495   for (int64 i = 0; i < n; ++i) {
    496     T dim = internal::SubtleMustCopy(dims[i]);
    497     int64 new_num_elements;
    498     if (dim < 0) {
    499       if (!out->kIsPartial) {
    500         return errors::InvalidArgument("Dimension ", dim, " must be >= 0");
    501       }
    502       if (dim < -1) {
    503         return errors::InvalidArgument("Dimension ", dim, " must be >= -1");
    504       }
    505       dim = -1;
    506       new_num_elements = -1;
    507     } else if (out->num_elements() < 0) {
    508       new_num_elements = -1;
    509     } else {
    510       new_num_elements = MultiplyWithoutOverflow(out->num_elements(), dim);
    511       if (TF_PREDICT_FALSE(new_num_elements < 0)) {
    512         TensorShapeProto proto;
    513         for (int64 j = 0; j < n; ++j) {
    514           proto.add_dim()->set_size(dim);
    515         }
    516         return errors::InvalidArgument(
    517             "Shape ", TensorShape::DebugString(proto),
    518             " would have more than 2**63 - 1 elements");
    519       }
    520     }
    521     out->UnsafeAddDim(dim, new_num_elements);
    522   }
    523   return Status::OK();
    524 }
    525 
    526 #define MAKE_SHAPE(T, Shape)                                                 \
    527   Status TensorShapeUtils::MakeShape(const T* dims, int64 n, Shape* out) {   \
    528     return MakeShapeHelper(dims, n, out);                                    \
    529   }                                                                          \
    530   Status TensorShapeUtils::MakeShape(gtl::ArraySlice<T> shape, Shape* out) { \
    531     return MakeShapeHelper(shape.data(), shape.size(), out);                 \
    532   }
    533 MAKE_SHAPE(int32, TensorShape)
    534 MAKE_SHAPE(int64, TensorShape)
    535 MAKE_SHAPE(int32, PartialTensorShape)
    536 MAKE_SHAPE(int64, PartialTensorShape)
    537 #undef MAKE_SHAPE
    538 
    539 string TensorShapeUtils::ShapeListString(
    540     const gtl::ArraySlice<TensorShape>& shapes) {
    541   string result = "[";
    542   bool first = true;
    543   for (const TensorShape& shape : shapes) {
    544     strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
    545     first = false;
    546   }
    547   strings::StrAppend(&result, "]");
    548   return result;
    549 }
    550 
    551 PartialTensorShape PartialTensorShape::Concatenate(int64 size) const {
    552   PartialTensorShape out = *this;
    553   out.AddDim(size);
    554   return out;
    555 }
    556 
    557 PartialTensorShape PartialTensorShape::Concatenate(
    558     const PartialTensorShape& shape) const {
    559   if (unknown_rank() || shape.unknown_rank()) {
    560     return PartialTensorShape();
    561   }
    562   PartialTensorShape out = *this;
    563   for (auto dim : shape) out.AddDim(dim.size);
    564   return out;
    565 }
    566 
    567 Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
    568                                      PartialTensorShape* result) const {
    569   if (unknown_rank()) {
    570     *result = shape;
    571     return Status::OK();
    572   }
    573   if (shape.unknown_rank()) {
    574     *result = *this;
    575     return Status::OK();
    576   }
    577   const int dims_ = dims();
    578   if (dims_ != shape.dims()) {
    579     return errors::InvalidArgument(
    580         "PartialTensorShape: Incompatible ranks during merge: ", dims_, " vs. ",
    581         shape.dims());
    582   }
    583   CHECK(result != this);
    584   result->Clear();
    585   for (int i = 0; i < dims_; ++i) {
    586     const int64 dim0 = dim_size(i);
    587     const int64 dim1 = shape.dim_size(i);
    588     if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) {
    589       return errors::InvalidArgument(
    590           "PartialTensorShape: Incompatible shapes during merge: ",
    591           DebugString(), " vs. ", shape.DebugString());
    592     }
    593     result->AddDim(dim0 >= 0 ? dim0 : dim1);
    594   }
    595   return Status::OK();
    596 }
    597 
    598 bool PartialTensorShape::AsTensorShape(TensorShape* shape) const {
    599   if (IsFullyDefined()) {
    600     const TensorShapeRep* rep = this;
    601     *shape = *static_cast<const TensorShape*>(rep);
    602     return true;
    603   }
    604   return false;
    605 }
    606 
    607 bool PartialTensorShape::IsIdenticalTo(const PartialTensorShape& shape) const {
    608   if (unknown_rank() || shape.unknown_rank()) {
    609     return unknown_rank() == shape.unknown_rank();
    610   }
    611   if (dims() != shape.dims()) return false;
    612   for (int i = 0; i < dims(); i++) {
    613     if (dim_size(i) != shape.dim_size(i)) return false;
    614   }
    615   return true;
    616 }
    617 
    618 bool PartialTensorShape::IsCompatibleWith(
    619     const PartialTensorShape& shape) const {
    620   if (unknown_rank() || shape.unknown_rank()) return true;
    621   if (dims() != shape.dims()) return false;
    622   for (int i = 0; i < dims(); i++) {
    623     const int64 dim0 = dim_size(i);
    624     const int64 dim1 = shape.dim_size(i);
    625     if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) return false;
    626   }
    627   return true;
    628 }
    629 
    630 string PartialTensorShapeUtils::PartialShapeListString(
    631     const gtl::ArraySlice<PartialTensorShape>& shapes) {
    632   string result = "[";
    633   bool first = true;
    634   for (const PartialTensorShape& shape : shapes) {
    635     strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
    636     first = false;
    637   }
    638   strings::StrAppend(&result, "]");
    639   return result;
    640 }
    641 
    642 bool PartialTensorShapeUtils::AreCompatible(
    643     const gtl::ArraySlice<PartialTensorShape>& shapes0,
    644     const gtl::ArraySlice<PartialTensorShape>& shapes1) {
    645   if (shapes0.size() == shapes1.size()) {
    646     for (size_t i = 0; i < shapes0.size(); ++i) {
    647       if (!shapes0[i].IsCompatibleWith(shapes1[i])) {
    648         return false;
    649       }
    650     }
    651     return true;
    652   } else {
    653     return false;
    654   }
    655 }
    656 
    657 bool PartialTensorShapeUtils::AreIdentical(
    658     const gtl::ArraySlice<PartialTensorShape>& shapes0,
    659     const gtl::ArraySlice<PartialTensorShape>& shapes1) {
    660   if (shapes0.size() == shapes1.size()) {
    661     for (size_t i = 0; i < shapes0.size(); ++i) {
    662       if (!shapes0[i].IsIdenticalTo(shapes1[i])) {
    663         return false;
    664       }
    665     }
    666     return true;
    667   } else {
    668     return false;
    669   }
    670 }
    671 
    672 Status TensorShapeUtils::NumElements(gtl::ArraySlice<int64> shape,
    673                                      int64* num_elements) {
    674   int64 n = 1;
    675   for (auto dim : shape) {
    676     n = MultiplyWithoutOverflow(n, dim);
    677     if (n < 0) {
    678       return errors::InvalidArgument("Can't compute total size of shape [",
    679                                      str_util::Join(shape, ","),
    680                                      "]; product would overflow int64");
    681     }
    682   }
    683   *num_elements = n;
    684   return Status::OK();
    685 }
    686 
    687 template class TensorShapeBase<TensorShape>;
    688 template class TensorShapeBase<PartialTensorShape>;
    689 
    690 }  // namespace tensorflow
    691