Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
     17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
     18 
     19 #include <string>
     20 
     21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     22 #include "tensorflow/core/framework/types.pb.h"
     23 #include "tensorflow/core/lib/core/errors.h"
     24 #include "tensorflow/core/lib/core/status.h"
     25 #include "tensorflow/core/lib/core/stringpiece.h"
     26 #include "tensorflow/core/lib/gtl/array_slice.h"
     27 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     28 #include "tensorflow/core/lib/strings/str_util.h"
     29 #include "tensorflow/core/platform/logging.h"
     30 
     31 namespace tensorflow {
     32 
     33 // START_SKIP_DOXYGEN
     34 template <class Shape>
     35 class TensorShapeIter;
     36 class TensorShape;
     37 class TensorShapeProto;
     38 class PartialTensorShape;
     39 // END_SKIP_DOXYGEN
     40 
/// Internal representation for both TensorShape and PartialTensorShape.
///
/// The object consists of a 16-byte buffer (`u_.buf`) plus a cached element
/// count (`num_elements_`).  The accessors below fix the buffer layout:
///   - byte 13: 8-bit slot used to hold a DataType (see data_type())
///   - byte 14: number of dimensions (255 == kUnknownRank)
///   - byte 15: RepTag selecting how the leading bytes are interpreted
/// Member order and these byte offsets are part of the representation; do not
/// reorder members without auditing every accessor.
class TensorShapeRep {
 public:
  ~TensorShapeRep();

  /// Copy the specified shape
  TensorShapeRep(const TensorShapeRep& b);
  void operator=(const TensorShapeRep& b);

  /// Move the specified shape.  After moving, `b` is safe for destruction and
  /// can be reassigned into, but its dimensions and number of elements can be
  /// nonsensical (e.g., negative dimension sizes, or number of elements not
  /// properly recomputed).
  TensorShapeRep(TensorShapeRep&& b);
  void operator=(TensorShapeRep&& b);

  /// Clear a tensor shape, producing the scalar shape.
  void Clear();

  /// Maximum number of dimensions in a tensor.
  /// It's 254 because 255 = kUnknownRank is used to represent unknown rank.
  static constexpr int MaxDimensions() { return 254; }

  /// \brief Returns the number of elements in the tensor.
  ///
  /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
  /// which uses `ptrdiff_t`.  For PartialTensorShape, -1 means not fully
  /// defined.
  int64 num_elements() const { return num_elements_; }

  /// For error messages.
  string DebugString() const;
  static string DebugString(const TensorShapeProto& proto);

  void DumpRep() const;  // XXX

 protected:
  // Constructable only via TensorShapeBase
  TensorShapeRep() = default;

  void ClearAllButDataType();

  // We use 16 bytes to represent a TensorShape.  Because we need to
  // be able to support full 64-bit dimension sizes and an arbitrary
  // number of dimensions for a Tensor, but most tensor dimensions are
  // significantly smaller than 64 bits and most tensors are 1, 2, or 3
  // dimensions, we have several representations.
  // Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1
  // Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1
  // Rep64: Supports arbitrary dimensionality, 64-bit dimensions using
  //        an out of line vector.
  // For PartialTensorShape, a dimension of static_cast<uint??>(-1) is unknown.
  // This value is not allowed in TensorShape either for format compatibility.
  struct Rep16 {
    uint16 dims_[6];
  };
  struct Rep32 {
    uint32 dims_[3];
  };
  struct Rep64 {
    // Pointer to the out-of-line dimension vector (see REP_OUT_OF_LINE).
    gtl::InlinedVector<int64, 4>* dims_;
  };

  // We use the max value of uint16 or uint32 to represent unknown shapes, so
  // the maximum representable valid shape in these representations is one less.
  static const int64 kMaxRep16 = std::numeric_limits<uint16>::max() - 1;
  static const int64 kMaxRep32 = std::numeric_limits<uint32>::max() - 1;
  static const uint16 kUnknownRep16 = std::numeric_limits<uint16>::max();
  static const uint32 kUnknownRep32 = std::numeric_limits<uint32>::max();

  // Typed views over the start of the 16-byte buffer.  Which view is
  // meaningful is determined by tag().
  Rep16* as16() { return reinterpret_cast<Rep16*>(buf()); }
  Rep32* as32() { return reinterpret_cast<Rep32*>(buf()); }
  Rep64* as64() { return reinterpret_cast<Rep64*>(buf()); }

  const Rep16* as16() const { return reinterpret_cast<const Rep16*>(buf()); }
  const Rep32* as32() const { return reinterpret_cast<const Rep32*>(buf()); }
  const Rep64* as64() const { return reinterpret_cast<const Rep64*>(buf()); }

  // Discriminator stored in byte 15 of the buffer.
  enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 };

  // Since we have a convenient extra byte available, we allow the
  // Tensor class to store an 8-bit value in this extra storage.  This
  // allows it to store the Tensor's datatype enum value here and avoid
  // an extra word of storage.
  friend class Tensor;
  friend class TensorShapeTestHelper;
  DataType data_type() const { return static_cast<DataType>(buf()[13]); }
  void set_data_type(DataType dt) {
    // We only have 8 bits available to store DataType, so make sure it fits
    DCHECK_LT(static_cast<uint32>(dt), 256u);
    buf()[13] = static_cast<uint8>(dt);
  }

  // We store the number of dimensions in byte 14, and the RepTag in byte 15.
  // Byte 13 is the data type slot above; the bytes before it hold the
  // representation-specific payload (Rep16, Rep32 or Rep64).
  // A value of 255 indicates unknown rank in the PartialTensorShape case.
  static const uint8 kUnknownRank = 255;
  uint8 ndims_byte() const { return buf()[14]; }
  void set_ndims_byte(uint8 nd) { buf()[14] = nd; }

  RepTag tag() const { return static_cast<RepTag>(buf()[15]); }
  void set_tag(RepTag tag) { buf()[15] = static_cast<uint8>(tag); }

  // Overwrites the cached element count; no validation is performed here.
  void set_num_elements(int64 n) { num_elements_ = n; }

 private:
  // Out-of-line helpers used by the inline special member functions when an
  // out-of-line (REP_OUT_OF_LINE) representation is involved.
  void DestructorOutOfLine();
  void SlowCopyFrom(const TensorShapeRep& b);

  uint8* buf() { return &u_.buf[0]; }
  const uint8* buf() const { return &u_.buf[0]; }

  union {
    uint8 buf[16];
    // Force data to be aligned enough for a pointer.
    Rep64* unused_aligner;
  } u_;
  // Cached element count; see num_elements().
  int64 num_elements_;
};
    160 
    161 /// Base class for TensorShape and PartialTensorShape.
    162 /// The class is templatized by either TensorShape or PartialTensorShape to
    163 /// allow skipping known/unknown checks in the TensorShape case, but the
    164 /// representation is shared exactly for fast conversion.
    165 template <class Shape>
    166 class TensorShapeBase : public TensorShapeRep {
    167  public:
    168   /// \brief Construct a `TensorShapeBase` from the provided sizes.
    169   /// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape)
    170   explicit TensorShapeBase(gtl::ArraySlice<int64> dim_sizes);
    171   TensorShapeBase(std::initializer_list<int64> dim_sizes)
    172       : TensorShapeBase(gtl::ArraySlice<int64>(dim_sizes)) {}
    173 
    174   /// Construct an empty TensorShape, or an unknown rank PartialTensorShape
    175   TensorShapeBase();
    176 
    177   TensorShapeBase(const TensorShapeProto& proto);
    178 
    179   /// Returns `true` iff `proto` is a valid tensor shape.
    180   // For TensorShape, the proto shape must be fully defined.
    181   static bool IsValid(const TensorShapeProto& proto);
    182 
    183   /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
    184   /// status otherwise.
    185   static Status IsValidShape(const TensorShapeProto& proto);
    186 
    187   /// \brief Add a dimension to the end ("inner-most").
    188   /// REQUIRES: `size >= 0`
    189   void AddDim(int64 size);
    190 
    191   /// Appends all the dimensions from `shape`.
    192   void AppendShape(const TensorShapeBase& shape);
    193 
    194   /// \brief Insert a dimension somewhere in the `TensorShape`.
    195   /// REQUIRES: `0 <= d <= dims()`
    196   /// REQUIRES: `size >= 0`
    197   void InsertDim(int d, int64 size);
    198 
    199   /// \brief Modifies the size of the dimension `d` to be `size`
    200   /// REQUIRES: `0 <= d < dims()`
    201   /// REQUIRES: `size >= 0`
    202   void set_dim(int d, int64 size);
    203 
    204   /// \brief Removes dimension `d` from the `TensorShape`.
    205   /// REQUIRES: `0 <= d < dims()`
    206   void RemoveDim(int d) {
    207     CHECK_GE(d, 0);
    208     RemoveDimRange(d, d + 1);
    209   }
    210 
    211   /// \brief Removes last `n` dimensions from the `TensorShape`.
    212   /// REQUIRES: `0 <= n <= dims()`
    213   void RemoveLastDims(int n) {
    214     CHECK_LE(n, dims());
    215     RemoveDimRange(dims() - n, dims());
    216   }
    217 
    218   /// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`.
    219   /// Negative values of `end` are interpreted as `dims() + end + 1` (as in
    220   /// Python). The same is true for negative values of `begin`. REQUIRES:
    221   /// `-(dims()+1) <= begin <= dims()` REQUIRES: `-(dims()+1) <= end <= dims()`
    222   void RemoveDimRange(int begin, int end);
    223 
    224   /// Return whether the rank is unknown
    225   bool unknown_rank() const {
    226     return kIsPartial && ndims_byte() == kUnknownRank;
    227   }
    228 
    229   /// Return the number of dimensions in the tensor.
    230   /// Can be -1 meaning unknown rank for PartialTensorShape.
    231   int dims() const {
    232     uint8 dims = ndims_byte();
    233     return kIsPartial && dims == kUnknownRank ? -1 : dims;
    234   }
    235 
    236   /// \brief Returns the number of elements in dimension `d`.
    237   /// REQUIRES: `0 <= d < dims()`
    238   // TODO(touts): Rename to `dimension()` to match
    239   // `Eigen::Tensor::dimension()`?
    240   int64 dim_size(int d) const;
    241 
    242   /// Returns sizes of all dimensions.
    243   // Returns an empty list for unknown rank PartialTensorShape.
    244   gtl::InlinedVector<int64, 4> dim_sizes() const;
    245 
    246   /// Return true iff the rank and all of the dimensions are well defined
    247   // TODO(irving): Rename to is_fully_defined now that it's fast.
    248   bool IsFullyDefined() const { return !kIsPartial || num_elements() != -1; }
    249 
    250   /// Fill `*proto` from `*this`.
    251   void AsProto(TensorShapeProto* proto) const;
    252 
    253   /// For iterating through the dimensions.
    254   TensorShapeIter<Shape> begin() const;
    255   TensorShapeIter<Shape> end() const;
    256 
    257  private:
    258   void RecomputeNumElements();
    259   void InitDims(gtl::ArraySlice<int64> dim_sizes);
    260 
    261   // True for PartialTensorShape, false for TensorShape
    262   static constexpr bool kIsPartial =
    263       std::is_same<Shape, PartialTensorShape>::value;
    264   static_assert(kIsPartial || std::is_same<Shape, TensorShape>::value,
    265                 "Shape is neither TensorShape nor PartialTensorShape");
    266 
    267   // Used by AddDim and MakeShapeHelper.  Does no error checking.
    268   void UnsafeAddDim(int64 size, int64 new_num_elements);
    269 
    270   // For use by TensorShapeUtils::MakeShape
    271   template <class T, class S>
    272   friend Status MakeShapeHelper(const T*, int64, S*);
    273 };
    274 
    275 /// Outputs `TensorShapeBase` to `std::ostream`.
    276 template <typename Shape>
    277 std::ostream& operator<<(std::ostream& os, const TensorShapeBase<Shape>& tsb) {
    278   return os << tsb.DebugString();
    279 }
    280 
    281 /// Represents the shape of a Tensor.
    282 ///
    283 /// A tensor's shape is denoted by its number of dimensions and a size for each
    284 /// dimension.  For example, a Tensor represented by a 3 x 4 matrix would have
    285 /// a shape of 2-D, [3,4].
    286 ///
    287 /// If you know the exact shape of your Tensor when you create the TensorShape
    288 /// object, you can specify it then, or you can create a TensorShape with
    289 /// zero dimensions and one element, and call AddDim() to add dimensions later.
    290 class TensorShape : public TensorShapeBase<TensorShape> {
    291  public:
    292   using TensorShapeBase<TensorShape>::TensorShapeBase;
    293 
    294   /// Allow a TensorShape to be used as a PartialTensorShape without copying
    295   operator const PartialTensorShape&() const;  // NOLINT(runtime/explicit)
    296 
    297   /// Returns true if `*this` and `b` have the same sizes. Ignores
    298   /// dimension names.
    299   bool IsSameSize(const TensorShape& b) const;
    300   bool operator==(const TensorShape& b) const { return IsSameSize(b); }
    301   bool operator!=(const TensorShape& b) const { return !IsSameSize(b); }
    302 
    303   /// Fill `*dsizes` from `*this`.
    304   /// Notice: Using IndexType=int32 in combination with To32Bit() can
    305   /// significantly improve performance on GPU.
    306   template <int NDIMS, typename IndexType = Eigen::DenseIndex>
    307   Eigen::DSizes<IndexType, NDIMS> AsEigenDSizes() const;
    308 
    309   /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
    310   /// which case we pad the rest of the sizes with 1.
    311   /// Notice: Using IndexType=int32 in combination with To32Bit() can
    312   /// significantly improve performance on GPU.
    313   template <int NDIMS, typename IndexType = Eigen::DenseIndex>
    314   Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesWithPadding() const;
    315 
    316  private:
    317   // These CHECK fail to ease debugging.
    318   // REQUIRES: dims() == NDIMS
    319   void CheckDimsEqual(int NDIMS) const;
    320   // REQUIRES: dims() >= NDIMS
    321   void CheckDimsAtLeast(int NDIMS) const;
    322 };
    323 
    324 /// Represents the value of one dimension in a TensorShape.
    325 struct TensorShapeDim {
    326   explicit TensorShapeDim(int64 s) : size(s) {}
    327   int64 size;
    328 };
    329 
    330 // START_SKIP_DOXYGEN
    331 template <class Shape>
    332 class TensorShapeIter {
    333  public:
    334   TensorShapeIter(const Shape* shape, int d) : shape_(shape), d_(d) {}
    335   bool operator==(const TensorShapeIter& rhs) {
    336     DCHECK(shape_ == rhs.shape_);
    337     return d_ == rhs.d_;
    338   }
    339   bool operator!=(const TensorShapeIter& rhs) {
    340     DCHECK(shape_ == rhs.shape_);
    341     return d_ != rhs.d_;
    342   }
    343   void operator++() { ++d_; }
    344   TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); }
    345 
    346  private:
    347   const Shape* shape_;
    348   int d_;
    349 };
    350 // END_SKIP_DOXYGEN
    351 
    352 /// \brief Static helper routines for `TensorShape`. Includes a few common
    353 /// predicates on a tensor shape.
    354 class TensorShapeUtils {
    355  public:
    356   static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; }
    357 
    358   static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; }
    359 
    360   static bool IsVectorOrHigher(const TensorShape& shape) {
    361     return shape.dims() >= 1;
    362   }
    363 
    364   static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; }
    365 
    366   static bool IsSquareMatrix(const TensorShape& shape) {
    367     return shape.dims() == 2 && shape.dim_size(0) == shape.dim_size(1);
    368   }
    369 
    370   static bool IsMatrixOrHigher(const TensorShape& shape) {
    371     return shape.dims() >= 2;
    372   }
    373 
    374   /// \brief Returns a `TensorShape` whose dimensions are
    375   /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
    376   static Status MakeShape(const int32* dims, int64 n, TensorShape* out);
    377   static Status MakeShape(const int64* dims, int64 n, TensorShape* out);
    378   static Status MakeShape(gtl::ArraySlice<int32> shape, TensorShape* out);
    379   static Status MakeShape(gtl::ArraySlice<int64> shape, TensorShape* out);
    380   static Status MakeShape(const int32* dims, int64 n, PartialTensorShape* out);
    381   static Status MakeShape(const int64* dims, int64 n, PartialTensorShape* out);
    382   static Status MakeShape(gtl::ArraySlice<int32> shape,
    383                           PartialTensorShape* out);
    384   static Status MakeShape(gtl::ArraySlice<int64> shape,
    385                           PartialTensorShape* out);
    386 
    387   static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes);
    388 
    389   /// \brief Returns true iff `shape` starts with `prefix`.
    390   static bool StartsWith(const TensorShape& shape, const TensorShape& prefix);
    391 
    392   /// \brief Returns true iff `shape` ends with `suffix`.
    393   static bool EndsWith(const TensorShape& shape, const TensorShape& suffix);
    394 
    395   /// \brief Returns the product of values in an int64 array,
    396   /// or a failing Status if the array represents a value larger than
    397   /// a `TensorShape` can hold.
    398   static Status NumElements(gtl::ArraySlice<int64> shape, int64* num_elements);
    399 };
    400 
/// Manages the partially known dimensions of a Tensor and their sizes.
///
/// A dimension size of -1 means "unknown", and the rank itself may be unknown
/// (the default-constructed state).
class PartialTensorShape : public TensorShapeBase<PartialTensorShape> {
 public:
  PartialTensorShape() {}
  using TensorShapeBase<PartialTensorShape>::TensorShapeBase;

  /// Add a dimension to the end ("inner-most"), returns a new
  /// PartialTensorShape.
  /// REQUIRES: `size >= -1`, where -1 means unknown
  PartialTensorShape Concatenate(int64 size) const;

  /// Appends all the dimensions from `shape`.  Returns a new
  /// PartialTensorShape.
  PartialTensorShape Concatenate(const PartialTensorShape& shape) const;

  /// Merges all the dimensions from `shape` into `*result`.  Returns an
  /// `InvalidArgument` error if `shape` has a different rank or if any of
  /// the dimensions are incompatible.
  Status MergeWith(const PartialTensorShape& shape,
                   PartialTensorShape* result) const;

  /// Exact equality test. Returns true iff the ranks match (i.e., both are
  /// unknown, or both are known and equal), and all dimensions are equal (i.e.,
  /// both dimensions are unknown, or both are known and equal). This is a
  /// stronger condition than IsCompatibleWith.
  bool IsIdenticalTo(const PartialTensorShape& shape) const;

  /// Return true iff the ranks match, and if the
  /// dimensions all either match or one is unknown.
  bool IsCompatibleWith(const PartialTensorShape& shape) const;

  /// Fill `*shape` from `*this`.
  /// If `*this` is not fully defined, returns false and
  /// `*shape` is left in an intermediate state.  Otherwise
  /// returns true.
  bool AsTensorShape(TensorShape* shape) const;

  /// \brief Returns a `PartialTensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.  Values of -1 are
  /// considered "unknown".
  ///
  /// Thin wrapper that forwards to `TensorShapeUtils::MakeShape`; `T` must be
  /// an element type for which a `PartialTensorShape*` overload of that
  /// function exists.
  template <class T>
  static Status MakePartialShape(const T* dims, int n,
                                 PartialTensorShape* out) {
    return TensorShapeUtils::MakeShape(dims, n, out);
  }
};
    447 
/// \brief Static helper routines for `PartialTensorShape`. Includes a few
/// common predicates on lists of partially known tensor shapes.
class PartialTensorShapeUtils {
 public:
  /// Returns a string describing the list `shapes`.
  static string PartialShapeListString(
      const gtl::ArraySlice<PartialTensorShape>& shapes);

  /// Compares two lists of shapes.
  // NOTE(review): presumably an element-wise PartialTensorShape::IsIdenticalTo
  // over lists of equal length -- confirm against the .cc file.
  static bool AreIdentical(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                           const gtl::ArraySlice<PartialTensorShape>& shapes1);

  /// Compares two lists of shapes.
  // NOTE(review): presumably an element-wise
  // PartialTensorShape::IsCompatibleWith over lists of equal length -- confirm
  // against the .cc file.
  static bool AreCompatible(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                            const gtl::ArraySlice<PartialTensorShape>& shapes1);
};
    461 
    462 // ----------------------------------------------------------------------------
    463 // Template method implementation details below
    464 // ----------------------------------------------------------------------------
    465 
    466 template <int NDIMS, typename IndexType>
    467 Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizes() const {
    468   CheckDimsEqual(NDIMS);
    469   return AsEigenDSizesWithPadding<NDIMS, IndexType>();
    470 }
    471 
    472 template <int NDIMS, typename IndexType>
    473 Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesWithPadding() const {
    474   CheckDimsAtLeast(NDIMS);
    475   static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions");
    476   Eigen::DSizes<IndexType, NDIMS> dsizes;
    477   for (int d = 0; d < dims(); d++) {
    478     dsizes[d] = static_cast<IndexType>(dim_size(d));
    479   }
    480   for (int d = dims(); d < NDIMS; d++) {
    481     dsizes[d] = 1;
    482   }
    483   return dsizes;
    484 }
    485 
    486 // ----------------------------------------------------------------------------
    487 // Inlining of some performance critical routines
    488 // ----------------------------------------------------------------------------
    489 
    490 inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) {
    491   num_elements_ = b.num_elements_;
    492   if (b.tag() != REP_OUT_OF_LINE) {
    493     memcpy(buf(), b.buf(), sizeof(u_.buf));
    494     // memcpy above Implicitly does:
    495     //   set_ndims_byte(b.ndims_byte());
    496     //   set_tag(b.tag());
    497   } else {
    498     set_tag(REP16);  // So that SlowCopyFrom does not try to deallocate
    499     SlowCopyFrom(b);
    500   }
    501 }
    502 
    503 inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) {
    504   num_elements_ = b.num_elements_;
    505   memcpy(buf(), b.buf(), sizeof(u_.buf));
    506   // memcpy above Implicitly does:
    507   //   set_ndims_byte(b.ndims_byte());
    508   //   set_tag(b.tag());
    509   b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
    510 }
    511 
    512 inline TensorShapeRep::~TensorShapeRep() {
    513   if (tag() == REP_OUT_OF_LINE) {
    514     DestructorOutOfLine();
    515   }
    516 }
    517 
    518 inline void TensorShapeRep::operator=(const TensorShapeRep& b) {
    519   num_elements_ = b.num_elements_;
    520   if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) {
    521     memcpy(buf(), b.buf(), sizeof(u_.buf));
    522     // memcpy above implicitly also does:
    523     //   set_tag(b.tag());
    524     //   set_ndims_byte(b.ndims_byte());
    525   } else {
    526     SlowCopyFrom(b);
    527   }
    528 }
    529 
    530 inline void TensorShapeRep::operator=(TensorShapeRep&& b) {
    531   if (tag() == REP_OUT_OF_LINE) {
    532     DestructorOutOfLine();
    533   }
    534   num_elements_ = b.num_elements_;
    535   memcpy(buf(), b.buf(), sizeof(u_.buf));
    536   // memcpy above Implicitly does:
    537   //   set_ndims_byte(b.ndims_byte());
    538   //   set_tag(b.tag());
    539   b.set_tag(REP16);  // other shape no longer owns out-of-line data, if any.
    540 }
    541 
    542 inline TensorShape::operator const PartialTensorShape&() const {
    543   // Downcast to the shared representation and upcast to PartialTensorShape
    544   const TensorShapeRep* rep = this;
    545   return *static_cast<const PartialTensorShape*>(rep);
    546 }
    547 
// Explicit instantiation declarations: translation units including this
// header do not implicitly instantiate TensorShapeBase for these two shape
// types; the matching explicit instantiation definitions belong in the .cc
// file.
extern template class TensorShapeBase<TensorShape>;
extern template class TensorShapeBase<PartialTensorShape>;
    551 
    552 }  // namespace tensorflow
    553 
    554 #endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
    555