Home | History | Annotate | Download | only in xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
     17 #define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
     18 
#include <cstddef>
#include <functional>
#include <iterator>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
     36 
     37 namespace xla {
     38 
     39 namespace internal {
     40 
     41 // Internal representation of each node in a ShapeTree.
     42 template <typename T>
     43 struct ShapeTreeNode {
     44   // Data corresponding to this node.
     45   T data;
     46 
     47   // Children of this node.
     48   std::vector<std::unique_ptr<ShapeTreeNode>> children;
     49 
     50   ShapeTreeNode() = default;
     51   explicit ShapeTreeNode(const T& data) : data(data) {}
     52 
     53   ShapeTreeNode(const ShapeTreeNode& other)
     54       : data(other.data), children(other.children.size()) {
     55     for (size_t i = 0; i < children.size(); ++i) {
     56       children[i] = MakeUnique<ShapeTreeNode>(*other.children[i]);
     57     }
     58   }
     59 
     60   ShapeTreeNode& operator=(const ShapeTreeNode& other) {
     61     if (this != &other) {
     62       data = other.data;
     63       children.resize(other.children.size());
     64       for (size_t i = 0; i < children.size(); ++i) {
     65         children[i] = MakeUnique<ShapeTreeNode>(*other.children[i]);
     66       }
     67     }
     68     return *this;
     69   }
     70 };
     71 
     72 }  // namespace internal
     73 
// Forward declaration so that ShapeTree (below) can befriend both the const and
// the non-const iterator flavors before ShapeTreeIterator is defined.
template <typename T, bool is_const>
class ShapeTreeIterator;
     76 
     77 // A ShapeTree<T> is a recursive data structure which mirrors the structure of a
     78 // XLA shape and holds a value of type T for each subshape (i.e. tuple or array)
     79 // in the shape. For array shapes, a ShapeTree trivially holds a single value of
     80 // type T.
     81 //
     82 // For tuple shapes which can be an arbitrary tree with arrays at the leaves, a
     83 // ShapeTree is an identically structured tree with data elements of type T at
     84 // every node. I.e. the root is a tuple by definition, all interior nodes are
     85 // also tuples, and all leaves are arrays.
     86 //
     87 // Like the Shape data structure, this is a tree and tuple elements cannot be
     88 // duplicated. That is, every distinct ShapeIndex in the Shape has a unique T
     89 // object.
     90 //
     91 // Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes
     92 // it's helpful not to copy a Shape just to make a ShapeTree.  In these cases,
     93 // you can pass a Shape* instead of a Shape& to the ShapeTree constructor.  It's
     94 // then up to you to ensure that the pointed-to Shape doesn't die or mutate
     95 // before its ShapeTree goes away.
     96 template <typename T>
     97 class ShapeTree {
     98   friend class ShapeTreeIterator<T, /*is_const=*/true>;
     99   friend class ShapeTreeIterator<T, /*is_const=*/false>;
    100 
    101  public:
    102   // Default constructor creates a tree with a nil shape (i.e. an empty tuple).
    103   ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
    104 
    105   // Create ShapeTree with the given shape, and default-constructed T values for
    106   // all nodes.
    107   //
    108   // The version that takes a pointer may be cheaper because it doesn't require
    109   // any Shape copies, but then it's up to you to ensure that the pointer stays
    110   // alive longer than this ShapeTree.
    111   explicit ShapeTree(Shape shape);
    112   explicit ShapeTree(const Shape* shape);
    113 
    114   // Create ShapeTree with the given shape, and init_value for all nodes.
    115   ShapeTree(Shape shape, const T& init_value);
    116   ShapeTree(const Shape* shape, const T& init_value);
    117 
    118   ShapeTree(const ShapeTree& other) { *this = other; }
    119   ShapeTree(ShapeTree&&) = default;
    120 
    121   ShapeTree& operator=(const ShapeTree& other) {
    122     root_ = other.root_;
    123 
    124     // Fix up internal pointer if necessary.
    125     if (other.shape_storage_) {
    126       CHECK_EQ(other.shape_, other.shape_storage_.get());
    127       shape_storage_.reset(new Shape(*other.shape_));
    128       shape_ = shape_storage_.get();
    129     } else {
    130       shape_ = other.shape_;
    131     }
    132 
    133     return *this;
    134   }
    135 
    136   ShapeTree& operator=(ShapeTree&& other) = default;
    137 
    138   // Returns the data element associated with the array in the shape at the
    139   // given index (see ShapeUtil::GetSubshape for how indexes are defined).
    140   const T& element(const ShapeIndex& index) const;
    141   T* mutable_element(const ShapeIndex& index);
    142 
    143   // Return the shape represented with this ShapeTree.
    144   const Shape& shape() const { return *shape_; }
    145 
    146   // Replaces *only* the underlying shape of this ShapeTree. The caller must own
    147   // the Shape object and hence shape_storage_ is not updated.
    148   //
    149   // Only safe to use this if the ShapeTree was constructed with 'explicit
    150   // ShapeTree(const Shape* shape)' or is moved from one such ShapeTree. The
    151   // caller must ensure that the input shape is consistent with the underlying
    152   // tree.
    153   void replace_shape_ptr(const Shape* shape) {
    154     CHECK(shape_storage_.get() == nullptr);
    155     shape_ = shape;
    156   }
    157 
    158   // Returns true if the node at the given index is a leaf node (an array
    159   // shape).
    160   bool IsLeaf(const ShapeIndex& index) const {
    161     return Lookup(index)->children.empty();
    162   }
    163 
    164   // iterator implements a forward_iterator with value_type =
    165   // std::pair<ShapeIndex, T&>
    166   using iterator = ShapeTreeIterator<T, /*is_const=*/false>;
    167   using const_iterator = ShapeTreeIterator<T, /*is_const=*/true>;
    168 
    169   // begin/end for iterating over all nodes.
    170   iterator begin() {
    171     return iterator(&root_, /*iterate_leaves_only=*/false,
    172                     /*reverse=*/false);
    173   }
    174   iterator end() {
    175     return iterator(nullptr, /*iterate_leaves_only=*/false,
    176                     /*reverse=*/false);
    177   }
    178   const_iterator begin() const {
    179     return const_iterator(&root_, /*iterate_leaves_only=*/false,
    180                           /*reverse=*/false);
    181   }
    182   const_iterator end() const {
    183     return const_iterator(nullptr, /*iterate_leaves_only=*/false,
    184                           /*reverse=*/false);
    185   }
    186 
    187   // rbegin/rend for iterating over all nodes in reverse.
    188   iterator rbegin() {
    189     return iterator(&root_, /*iterate_leaves_only=*/false,
    190                     /*reverse=*/true);
    191   }
    192   iterator rend() {
    193     return iterator(nullptr, /*iterate_leaves_only=*/false,
    194                     /*reverse=*/true);
    195   }
    196   const_iterator rbegin() const {
    197     return const_iterator(&root_, /*iterate_leaves_only=*/false,
    198                           /*reverse=*/true);
    199   }
    200   const_iterator rend() const {
    201     return const_iterator(nullptr, /*iterate_leaves_only=*/false,
    202                           /*reverse=*/true);
    203   }
    204 
    205   // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no
    206   // children).
    207   iterator leaf_begin() {
    208     return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/false);
    209   }
    210   iterator leaf_end() {
    211     return iterator(nullptr, /*iterate_leaves_only=*/true,
    212                     /*reverse=*/false);
    213   }
    214   const_iterator leaf_begin() const {
    215     return const_iterator(&root_, /*iterate_leaves_only=*/true,
    216                           /*reverse=*/false);
    217   }
    218   const_iterator leaf_end() const {
    219     return const_iterator(nullptr, /*iterate_leaves_only=*/true,
    220                           /*reverse=*/false);
    221   }
    222   // range-based iterator for leaf_begin()/leaf_end().
    223   tensorflow::gtl::iterator_range<iterator> leaves() {
    224     return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
    225   }
    226   tensorflow::gtl::iterator_range<const_iterator> leaves() const {
    227     return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
    228   }
    229 
    230   iterator leaf_rbegin() {
    231     return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/true);
    232   }
    233   iterator leaf_rend() {
    234     return iterator(nullptr, /*iterate_leaves_only=*/true,
    235                     /*reverse=*/true);
    236   }
    237   const_iterator leaf_rbegin() const {
    238     return const_iterator(&root_, /*iterate_leaves_only=*/true,
    239                           /*reverse=*/true);
    240   }
    241   const_iterator leaf_rend() const {
    242     return const_iterator(nullptr, /*iterate_leaves_only=*/true,
    243                           /*reverse=*/true);
    244   }
    245 
    246   // Recursively traverses the shape and calls the given function at each
    247   // element. The function has the following arguments:
    248   //
    249   //   Fn :    A callable of type void(const ShapeIndex& index, const T& data)
    250   //           (or compatible).
    251   //   index : the index of the element in the shape. See ShapeUtil::GetSubshape
    252   //           for definition of index.
    253   //   data : The data value at this element.
    254   template <typename Fn>
    255   void ForEachElement(const Fn& func) const;
    256 
    257   // Like ForEachElement, but the callable has type
    258   //
    259   //   void (const ShapeIndex& index, T* data).
    260   //
    261   template <typename Fn>
    262   void ForEachMutableElement(const Fn& func);
    263 
    264   // Like ForEach(Mutable)Element, but the callable returns a Status instead of
    265   // void.  The first non-OK return value is returned by the ForEach* function.
    266   template <typename Fn>
    267   Status ForEachElementWithStatus(const Fn& func) const;
    268   template <typename Fn>
    269   Status ForEachMutableElementWithStatus(const Fn& func);
    270 
    271   // Copy the subtree of values from 'other' rooted at ShapeIndex
    272   // 'source_base_index' into the subtree of value in this ShapeTree rooted at
    273   // 'target_base_index'.
    274   //
    275   // Precondition: The subshape of other.shape() at index source_base_index must
    276   // be compatible with the subshape of shape() at index target_base_index.
    277   void CopySubtreeFrom(const ShapeTree<T>& other,
    278                        const ShapeIndex& source_base_index,
    279                        const ShapeIndex& target_base_index);
    280 
    281   bool operator==(const ShapeTree<T>& other) const;
    282   bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }
    283 
    284  private:
    285   using Node = internal::ShapeTreeNode<T>;
    286 
    287   // Initialize node->children based on 'shape'. All children are assigned the
    288   // the given 'init_value'.
    289   void InitChildren(const Shape& shape, const T& init_value, Node* node);
    290 
    291   // Initialize node->children based on 'shape'. All children have
    292   // default-constructed data values.
    293   void InitChildren(const Shape& shape, Node* node);
    294 
    295   // Helpers for traversing the shape via ForEachElement. The helpers
    296   // recursively traverse the subtree rooted at "index" (defined as in
    297   // ShapeUtil::GetSubshape).
    298   template <typename Fn>
    299   static Status ForEachHelper(const Fn& func, const Node& node,
    300                               ShapeIndex* index);
    301   template <typename Fn>
    302   static Status ForEachMutableHelper(const Fn& func, Node* node,
    303                                      ShapeIndex* index);
    304 
    305   // Return the tree node at the given index.
    306   Node* Lookup(const ShapeIndex& index);
    307   const Node* Lookup(const ShapeIndex& index) const;
    308 
    309   // The root node, which contains all other nodes.
    310   Node root_;
    311 
    312   // If we own our Shape, this field contains it, and shape_ is a pointer into
    313   // here.  Otherwise if we don't own our shape, this is nullptr.
    314   std::unique_ptr<Shape> shape_storage_;
    315 
    316   // The XLA shape mirrored in this ShapeTree.  This is either
    317   // shape_storage_.get() or the Shape pointer passed to our constructor.
    318   const Shape* shape_;
    319 };
    320 
    321 // Internal iterator that performs a pre-order walk. This is copyable, but
    322 // contains a vector so isn't cheap to copy. This also means post-increment is
    323 // expensive. The iterator value_type is equivalent to a std::pair<ShapeIndex,
    324 // T&>, similar to std::map. The non-const iterator's T& type can be mutated
    325 // in-place.
    326 template <typename T, bool is_const>
    327 class ShapeTreeIterator : public std::iterator<std::forward_iterator_tag,
    328                                                std::pair<ShapeIndex, T&>> {
    329  public:
    330   using value_type =
    331       typename std::conditional<is_const, std::pair<ShapeIndex, const T&>,
    332                                 std::pair<ShapeIndex, T&>>::type;
    333   using NodeType =
    334       typename std::conditional<is_const, const typename ShapeTree<T>::Node,
    335                                 typename ShapeTree<T>::Node>::type;
    336 
    337   // Construct an iterator pointing at node. Node must either be the tree root
    338   // or nullptr (which is equivalent to end() and should not be dereferenced or
    339   // incremented). If iterate_leaves_only is true, the iterator will not include
    340   // interior tree nodes, only leaves. If reverse is true, the iterator will
    341   // visit nodes in the reverse of pre-order traversal.
    342   ShapeTreeIterator(NodeType* node, bool iterate_leaves_only, bool reverse)
    343       : node_(node),
    344         iterate_leaves_only_(iterate_leaves_only),
    345         reverse_(reverse) {
    346     if (node_) {
    347       if (reverse_) {
    348         while (!node_->children.empty()) {
    349           const int child_index = node_->children.size() - 1;
    350           stack_.push_back({node_, child_index});
    351           node_ = node_->children[child_index].get();
    352         }
    353       } else {
    354         if (!node_->children.empty() && iterate_leaves_only) {
    355           ++*this;
    356         }
    357       }
    358     }
    359   }
    360   ShapeTreeIterator(const ShapeTreeIterator& other)
    361       : node_(other.node_),
    362         stack_(other.stack_),
    363         iterate_leaves_only_(other.iterate_leaves_only_),
    364         reverse_(other.reverse_) {}
    365 
    366   ShapeTreeIterator& operator++() {
    367     CHECK_NE(nullptr, node_) << "walking off the end() of an iterator!";
    368     if (reverse_) {
    369       while (!stack_.empty()) {
    370         node_ = stack_.back().first;
    371         int64 next_child_index = stack_.back().second - 1;
    372         stack_.pop_back();
    373         if (next_child_index < 0) {
    374           if (!iterate_leaves_only_) {
    375             // All children are visited, yield <node_>.
    376             return *this;
    377           }
    378         } else {
    379           stack_.push_back({node_, next_child_index});
    380           node_ = node_->children[next_child_index].get();
    381           while (!node_->children.empty()) {
    382             const int child_index = node_->children.size() - 1;
    383             stack_.push_back({node_, child_index});
    384             node_ = node_->children[child_index].get();
    385           }
    386           return *this;
    387         }
    388       }
    389     } else {
    390       // We're doing a pre-order walk, so if our current node has children take
    391       // the first child.
    392       if (!node_->children.empty()) {
    393         stack_.push_back({node_, /*child-index=*/0});
    394         node_ = node_->children[0].get();
    395         if (node_->children.empty() || !iterate_leaves_only_) {
    396           return *this;
    397         } else {
    398           // This is a non-leaf; tail-recurse.
    399           return ++(*this);
    400         }
    401       }
    402       // Otherwise we are currently at a leaf. Walk back up until a node
    403       // contains a child we haven't visited yet.
    404       while (!stack_.empty()) {
    405         node_ = stack_.back().first;
    406         int64 next_child_index = stack_.back().second + 1;
    407         stack_.pop_back();
    408         if (node_->children.size() > next_child_index) {
    409           stack_.push_back({node_, next_child_index});
    410           node_ = node_->children[next_child_index].get();
    411 
    412           if (node_->children.empty() || !iterate_leaves_only_) {
    413             return *this;
    414           } else {
    415             // This is a non-leaf; tail-recurse.
    416             return ++(*this);
    417           }
    418         }
    419       }
    420     }
    421     // We've walked off the end of the tree. Set node_ to nullptr to signify
    422     // end().
    423     node_ = nullptr;
    424     current_.reset();
    425     return *this;
    426   }
    427   ShapeTreeIterator operator++(int) {
    428     auto i = *this;
    429     ++(*this);
    430     return i;
    431   }
    432   bool operator==(const ShapeTreeIterator& other) const {
    433     return node_ == other.node_;
    434   }
    435   bool operator!=(const ShapeTreeIterator& other) const {
    436     return node_ != other.node_;
    437   }
    438   value_type& operator*() { return UpdateCurrent(); }
    439   value_type* operator->() { return &UpdateCurrent(); }
    440 
    441  private:
    442   // Updates the current_ member to reflect the current state.
    443   value_type& UpdateCurrent() {
    444     ShapeIndex index;
    445     for (auto& node_and_index : stack_) {
    446       index.push_back(node_and_index.second);
    447     }
    448     current_ = MakeUnique<value_type>(index, node_->data);
    449     return *current_;
    450   }
    451 
    452   // The node to which this iterator is pointing. This is the source of truth in
    453   // the iterator - the stack only exists to facilitate walking back from
    454   // children to parents.
    455   NodeType* node_;
    456   // Stack of {node, child-index} pairs of the path taken from the root to get
    457   // to node_. This allows us to backtrack and know where to go next.
    458   std::vector<std::pair<NodeType*, int64>> stack_;
    459   // True if we should not include interior nodes in our walk.
    460   bool iterate_leaves_only_;
    461   // True if we should yield the reverse of the pre-order traversal.
    462   bool reverse_;
    463   // Placeholder for the current value. Ideally this wouldn't exist and would
    464   // just be an rvalue, but operator -> needs to return a pointer to something.
    465   // We cannot just use a plain old value_type as it contains a reference so
    466   // cannot be default-constructed.
    467   std::unique_ptr<value_type> current_;
    468 };
    469 
    470 template <typename T>
    471 void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
    472                                 Node* node) {
    473   if (ShapeUtil::IsTuple(shape)) {
    474     for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
    475       node->children.emplace_back(new Node(init_value));
    476       InitChildren(shape.tuple_shapes(i), init_value,
    477                    node->children.back().get());
    478     }
    479   }
    480 }
    481 
    482 template <typename T>
    483 void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) {
    484   if (ShapeUtil::IsTuple(shape)) {
    485     for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
    486       node->children.emplace_back(new Node());
    487       InitChildren(shape.tuple_shapes(i), node->children.back().get());
    488     }
    489   }
    490 }
    491 
    492 template <typename T>
    493 ShapeTree<T>::ShapeTree(Shape shape)
    494     : root_(),
    495       shape_storage_(MakeUnique<Shape>(std::move(shape))),
    496       shape_(shape_storage_.get()) {
    497   // The shape_ field is just used to hold the structure of the shape.
    498   // It should not be relied upon to store layout information.
    499   LayoutUtil::ClearLayout(shape_storage_.get());
    500   InitChildren(*shape_, &root_);
    501 }
    502 
    503 template <typename T>
    504 ShapeTree<T>::ShapeTree(const Shape* shape) : root_(), shape_(shape) {
    505   InitChildren(*shape_, &root_);
    506 }
    507 
    508 template <typename T>
    509 ShapeTree<T>::ShapeTree(Shape shape, const T& init_value)
    510     : root_(init_value),
    511       shape_storage_(MakeUnique<Shape>(std::move(shape))),
    512       shape_(shape_storage_.get()) {
    513   // The shape_ field is just used to hold the structure of the shape.
    514   // It should not be relied upon to store layout information.
    515   LayoutUtil::ClearLayout(shape_storage_.get());
    516   InitChildren(*shape_, init_value, &root_);
    517 }
    518 
    519 template <typename T>
    520 ShapeTree<T>::ShapeTree(const Shape* shape, const T& init_value)
    521     : root_(init_value), shape_(shape) {
    522   InitChildren(*shape_, init_value, &root_);
    523 }
    524 
    525 template <typename T>
    526 const T& ShapeTree<T>::element(const ShapeIndex& index) const {
    527   return Lookup(index)->data;
    528 }
    529 
    530 template <typename T>
    531 T* ShapeTree<T>::mutable_element(const ShapeIndex& index) {
    532   return &Lookup(index)->data;
    533 }
    534 
    535 template <typename T>
    536 internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(const ShapeIndex& index) {
    537   Node* node = &root_;
    538   for (const int64 i : index) {
    539     CHECK_GE(i, 0);
    540     CHECK_LT(i, node->children.size());
    541     node = node->children[i].get();
    542   }
    543   return node;
    544 }
    545 
    546 template <typename T>
    547 const internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(
    548     const ShapeIndex& index) const {
    549   return const_cast<ShapeTree*>(this)->Lookup(index);
    550 }
    551 
    552 /* static */
    553 template <typename T>
    554 template <typename Fn>
    555 Status ShapeTree<T>::ForEachHelper(const Fn& func, const Node& node,
    556                                    ShapeIndex* index) {
    557   TF_RETURN_IF_ERROR(func(*index, node.data));
    558   for (int64 i = 0; i < node.children.size(); ++i) {
    559     index->push_back(i);
    560     TF_RETURN_IF_ERROR(ForEachHelper(func, *node.children[i], index));
    561     index->pop_back();
    562   }
    563   return Status::OK();
    564 }
    565 
    566 /* static */
    567 template <typename T>
    568 template <typename Fn>
    569 Status ShapeTree<T>::ForEachMutableHelper(const Fn& func, Node* node,
    570                                           ShapeIndex* index) {
    571   TF_RETURN_IF_ERROR(func(*index, &node->data));
    572   for (int64 i = 0; i < node->children.size(); ++i) {
    573     index->push_back(i);
    574     TF_RETURN_IF_ERROR(
    575         ForEachMutableHelper(func, node->children[i].get(), index));
    576     index->pop_back();
    577   }
    578   return Status::OK();
    579 }
    580 
    581 template <typename T>
    582 template <typename Fn>
    583 Status ShapeTree<T>::ForEachElementWithStatus(const Fn& func) const {
    584   ShapeIndex index;
    585   return ForEachHelper(func, root_, &index);
    586 }
    587 
    588 template <typename T>
    589 template <typename Fn>
    590 Status ShapeTree<T>::ForEachMutableElementWithStatus(const Fn& func) {
    591   ShapeIndex index;
    592   return ForEachMutableHelper(func, &root_, &index);
    593 }
    594 
    595 template <typename T>
    596 template <typename Fn>
    597 void ShapeTree<T>::ForEachElement(const Fn& func) const {
    598   ShapeIndex index;
    599   return ForEachHelper(
    600              [&func](const ShapeIndex& index, const T& data) {
    601                func(index, data);
    602                return Status::OK();
    603              },
    604              root_, &index)
    605       .IgnoreError();
    606 }
    607 
    608 template <typename T>
    609 template <typename Fn>
    610 void ShapeTree<T>::ForEachMutableElement(const Fn& func) {
    611   ShapeIndex index;
    612   return ForEachMutableHelper(
    613              [&func](const ShapeIndex& index, T* data) {
    614                func(index, data);
    615                return Status::OK();
    616              },
    617              &root_, &index)
    618       .IgnoreError();
    619 }
    620 
    621 template <typename T>
    622 void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
    623                                    const ShapeIndex& source_base_index,
    624                                    const ShapeIndex& target_base_index) {
    625   CHECK(ShapeUtil::Compatible(
    626       ShapeUtil::GetSubshape(shape(), target_base_index),
    627       ShapeUtil::GetSubshape(other.shape(), source_base_index)));
    628   ForEachMutableElement([this, &other, &source_base_index, &target_base_index](
    629                             const ShapeIndex& index, T* data) {
    630     // Copy the data element only if index is in the
    631     // subtree rooted at target_base_index.
    632     for (int i = 0; i < target_base_index.size(); ++i) {
    633       if (i >= index.size() || index[i] != target_base_index[i]) {
    634         return;
    635       }
    636     }
    637     // Construct source element index to copy from.
    638     ShapeIndex source_index = source_base_index;
    639     for (int i = target_base_index.size(); i < index.size(); ++i) {
    640       source_index.push_back(index[i]);
    641     }
    642     *data = other.element(source_index);
    643   });
    644 }
    645 
    646 template <typename T>
    647 bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const {
    648   bool equal = true;
    649   ForEachElement(
    650       [this, &other, &equal](const ShapeIndex& index, const T& data) {
    651         if (data != other.element(index)) {
    652           equal = false;
    653         }
    654       });
    655   return equal;
    656 }
    657 
    658 }  // namespace xla
    659 
    660 #endif  // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
    661