1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_ 18 19 #include <functional> 20 #include <iterator> 21 #include <memory> 22 #include <vector> 23 24 #include "tensorflow/compiler/xla/layout_util.h" 25 #include "tensorflow/compiler/xla/ptr_util.h" 26 #include "tensorflow/compiler/xla/shape_util.h" 27 #include "tensorflow/compiler/xla/status_macros.h" 28 #include "tensorflow/compiler/xla/xla_data.pb.h" 29 #include "tensorflow/core/lib/core/errors.h" 30 #include "tensorflow/core/lib/core/status.h" 31 #include "tensorflow/core/lib/gtl/array_slice.h" 32 #include "tensorflow/core/lib/gtl/iterator_range.h" 33 #include "tensorflow/core/lib/gtl/optional.h" 34 #include "tensorflow/core/platform/logging.h" 35 #include "tensorflow/core/platform/types.h" 36 37 namespace xla { 38 39 namespace internal { 40 41 // Internal representation of each node in a ShapeTree. 42 template <typename T> 43 struct ShapeTreeNode { 44 // Data corresponding to this node. 45 T data; 46 47 // Children of this node. 48 std::vector<std::unique_ptr<ShapeTreeNode>> children; 49 50 ShapeTreeNode() = default; 51 explicit ShapeTreeNode(const T& data) : data(data) {} 52 53 ShapeTreeNode(const ShapeTreeNode& other) 54 : data(other.data), children(other.children.size()) { 55 for (size_t i = 0; i < children.size(); ++i) { 56 children[i] = MakeUnique<ShapeTreeNode>(*other.children[i]); 57 } 58 } 59 60 ShapeTreeNode& operator=(const ShapeTreeNode& other) { 61 if (this != &other) { 62 data = other.data; 63 children.resize(other.children.size()); 64 for (size_t i = 0; i < children.size(); ++i) { 65 children[i] = MakeUnique<ShapeTreeNode>(*other.children[i]); 66 } 67 } 68 return *this; 69 } 70 }; 71 72 } // namespace internal 73 74 template <typename T, bool is_const> 75 class ShapeTreeIterator; 76 77 // A ShapeTree<T> is a recursive data structure which mirrors the structure of a 78 // XLA shape and holds a value of type T for each subshape (i.e. tuple or array) 79 // in the shape. For array shapes, a ShapeTree trivially holds a single value of 80 // type T. 81 // 82 // For tuple shapes which can be an arbitrary tree with arrays at the leaves, a 83 // ShapeTree is an identically structured tree with data elements of type T at 84 // every node. I.e. the root is a tuple by definition, all interior nodes are 85 // also tuples, and all leaves are arrays. 86 // 87 // Like the Shape data structure, this is a tree and tuple elements cannot be 88 // duplicated. That is, every distinct ShapeIndex in the Shape has a unique T 89 // object. 90 // 91 // Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes 92 // it's helpful not to copy a Shape just to make a ShapeTree. In these cases, 93 // you can pass a Shape* instead of a Shape& to the ShapeTree constructor. It's 94 // then up to you to ensure that the pointed-to Shape doesn't die or mutate 95 // before its ShapeTree goes away. 96 template <typename T> 97 class ShapeTree { 98 friend class ShapeTreeIterator<T, /*is_const=*/true>; 99 friend class ShapeTreeIterator<T, /*is_const=*/false>; 100 101 public: 102 // Default constructor creates a tree with a nil shape (i.e. an empty tuple). 103 ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} 104 105 // Create ShapeTree with the given shape, and default-constructed T values for 106 // all nodes. 107 // 108 // The version that takes a pointer may be cheaper because it doesn't require 109 // any Shape copies, but then it's up to you to ensure that the pointer stays 110 // alive longer than this ShapeTree. 111 explicit ShapeTree(Shape shape); 112 explicit ShapeTree(const Shape* shape); 113 114 // Create ShapeTree with the given shape, and init_value for all nodes. 115 ShapeTree(Shape shape, const T& init_value); 116 ShapeTree(const Shape* shape, const T& init_value); 117 118 ShapeTree(const ShapeTree& other) { *this = other; } 119 ShapeTree(ShapeTree&&) = default; 120 121 ShapeTree& operator=(const ShapeTree& other) { 122 root_ = other.root_; 123 124 // Fix up internal pointer if necessary. 125 if (other.shape_storage_) { 126 CHECK_EQ(other.shape_, other.shape_storage_.get()); 127 shape_storage_.reset(new Shape(*other.shape_)); 128 shape_ = shape_storage_.get(); 129 } else { 130 shape_ = other.shape_; 131 } 132 133 return *this; 134 } 135 136 ShapeTree& operator=(ShapeTree&& other) = default; 137 138 // Returns the data element associated with the array in the shape at the 139 // given index (see ShapeUtil::GetSubshape for how indexes are defined). 140 const T& element(const ShapeIndex& index) const; 141 T* mutable_element(const ShapeIndex& index); 142 143 // Return the shape represented with this ShapeTree. 144 const Shape& shape() const { return *shape_; } 145 146 // Replaces *only* the underlying shape of this ShapeTree. The caller must own 147 // the Shape object and hence shape_storage_ is not updated. 148 // 149 // Only safe to use this if the ShapeTree was constructed with 'explicit 150 // ShapeTree(const Shape* shape)' or is moved from one such ShapeTree. The 151 // caller must ensure that the input shape is consistent with the underlying 152 // tree. 153 void replace_shape_ptr(const Shape* shape) { 154 CHECK(shape_storage_.get() == nullptr); 155 shape_ = shape; 156 } 157 158 // Returns true if the node at the given index is a leaf node (an array 159 // shape). 160 bool IsLeaf(const ShapeIndex& index) const { 161 return Lookup(index)->children.empty(); 162 } 163 164 // iterator implements a forward_iterator with value_type = 165 // std::pair<ShapeIndex, T&> 166 using iterator = ShapeTreeIterator<T, /*is_const=*/false>; 167 using const_iterator = ShapeTreeIterator<T, /*is_const=*/true>; 168 169 // begin/end for iterating over all nodes. 170 iterator begin() { 171 return iterator(&root_, /*iterate_leaves_only=*/false, 172 /*reverse=*/false); 173 } 174 iterator end() { 175 return iterator(nullptr, /*iterate_leaves_only=*/false, 176 /*reverse=*/false); 177 } 178 const_iterator begin() const { 179 return const_iterator(&root_, /*iterate_leaves_only=*/false, 180 /*reverse=*/false); 181 } 182 const_iterator end() const { 183 return const_iterator(nullptr, /*iterate_leaves_only=*/false, 184 /*reverse=*/false); 185 } 186 187 // rbegin/rend for iterating over all nodes in reverse. 188 iterator rbegin() { 189 return iterator(&root_, /*iterate_leaves_only=*/false, 190 /*reverse=*/true); 191 } 192 iterator rend() { 193 return iterator(nullptr, /*iterate_leaves_only=*/false, 194 /*reverse=*/true); 195 } 196 const_iterator rbegin() const { 197 return const_iterator(&root_, /*iterate_leaves_only=*/false, 198 /*reverse=*/true); 199 } 200 const_iterator rend() const { 201 return const_iterator(nullptr, /*iterate_leaves_only=*/false, 202 /*reverse=*/true); 203 } 204 205 // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no 206 // children). 207 iterator leaf_begin() { 208 return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/false); 209 } 210 iterator leaf_end() { 211 return iterator(nullptr, /*iterate_leaves_only=*/true, 212 /*reverse=*/false); 213 } 214 const_iterator leaf_begin() const { 215 return const_iterator(&root_, /*iterate_leaves_only=*/true, 216 /*reverse=*/false); 217 } 218 const_iterator leaf_end() const { 219 return const_iterator(nullptr, /*iterate_leaves_only=*/true, 220 /*reverse=*/false); 221 } 222 // range-based iterator for leaf_begin()/leaf_end(). 223 tensorflow::gtl::iterator_range<iterator> leaves() { 224 return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); 225 } 226 tensorflow::gtl::iterator_range<const_iterator> leaves() const { 227 return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); 228 } 229 230 iterator leaf_rbegin() { 231 return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/true); 232 } 233 iterator leaf_rend() { 234 return iterator(nullptr, /*iterate_leaves_only=*/true, 235 /*reverse=*/true); 236 } 237 const_iterator leaf_rbegin() const { 238 return const_iterator(&root_, /*iterate_leaves_only=*/true, 239 /*reverse=*/true); 240 } 241 const_iterator leaf_rend() const { 242 return const_iterator(nullptr, /*iterate_leaves_only=*/true, 243 /*reverse=*/true); 244 } 245 246 // Recursively traverses the shape and calls the given function at each 247 // element. The function has the following arguments: 248 // 249 // Fn : A callable of type void(const ShapeIndex& index, const T& data) 250 // (or compatible). 251 // index : the index of the element in the shape. See ShapeUtil::GetSubshape 252 // for definition of index. 253 // data : The data value at this element. 254 template <typename Fn> 255 void ForEachElement(const Fn& func) const; 256 257 // Like ForEachElement, but the callable has type 258 // 259 // void (const ShapeIndex& index, T* data). 260 // 261 template <typename Fn> 262 void ForEachMutableElement(const Fn& func); 263 264 // Like ForEach(Mutable)Element, but the callable returns a Status instead of 265 // void. The first non-OK return value is returned by the ForEach* function. 266 template <typename Fn> 267 Status ForEachElementWithStatus(const Fn& func) const; 268 template <typename Fn> 269 Status ForEachMutableElementWithStatus(const Fn& func); 270 271 // Copy the subtree of values from 'other' rooted at ShapeIndex 272 // 'source_base_index' into the subtree of value in this ShapeTree rooted at 273 // 'target_base_index'. 274 // 275 // Precondition: The subshape of other.shape() at index source_base_index must 276 // be compatible with the subshape of shape() at index target_base_index. 277 void CopySubtreeFrom(const ShapeTree<T>& other, 278 const ShapeIndex& source_base_index, 279 const ShapeIndex& target_base_index); 280 281 bool operator==(const ShapeTree<T>& other) const; 282 bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); } 283 284 private: 285 using Node = internal::ShapeTreeNode<T>; 286 287 // Initialize node->children based on 'shape'. All children are assigned the 288 // the given 'init_value'. 289 void InitChildren(const Shape& shape, const T& init_value, Node* node); 290 291 // Initialize node->children based on 'shape'. All children have 292 // default-constructed data values. 293 void InitChildren(const Shape& shape, Node* node); 294 295 // Helpers for traversing the shape via ForEachElement. The helpers 296 // recursively traverse the subtree rooted at "index" (defined as in 297 // ShapeUtil::GetSubshape). 298 template <typename Fn> 299 static Status ForEachHelper(const Fn& func, const Node& node, 300 ShapeIndex* index); 301 template <typename Fn> 302 static Status ForEachMutableHelper(const Fn& func, Node* node, 303 ShapeIndex* index); 304 305 // Return the tree node at the given index. 306 Node* Lookup(const ShapeIndex& index); 307 const Node* Lookup(const ShapeIndex& index) const; 308 309 // The root node, which contains all other nodes. 310 Node root_; 311 312 // If we own our Shape, this field contains it, and shape_ is a pointer into 313 // here. Otherwise if we don't own our shape, this is nullptr. 314 std::unique_ptr<Shape> shape_storage_; 315 316 // The XLA shape mirrored in this ShapeTree. This is either 317 // shape_storage_.get() or the Shape pointer passed to our constructor. 318 const Shape* shape_; 319 }; 320 321 // Internal iterator that performs a pre-order walk. This is copyable, but 322 // contains a vector so isn't cheap to copy. This also means post-increment is 323 // expensive. The iterator value_type is equivalent to a std::pair<ShapeIndex, 324 // T&>, similar to std::map. The non-const iterator's T& type can be mutated 325 // in-place. 326 template <typename T, bool is_const> 327 class ShapeTreeIterator : public std::iterator<std::forward_iterator_tag, 328 std::pair<ShapeIndex, T&>> { 329 public: 330 using value_type = 331 typename std::conditional<is_const, std::pair<ShapeIndex, const T&>, 332 std::pair<ShapeIndex, T&>>::type; 333 using NodeType = 334 typename std::conditional<is_const, const typename ShapeTree<T>::Node, 335 typename ShapeTree<T>::Node>::type; 336 337 // Construct an iterator pointing at node. Node must either be the tree root 338 // or nullptr (which is equivalent to end() and should not be dereferenced or 339 // incremented). If iterate_leaves_only is true, the iterator will not include 340 // interior tree nodes, only leaves. If reverse is true, the iterator will 341 // visit nodes in the reverse of pre-order traversal. 342 ShapeTreeIterator(NodeType* node, bool iterate_leaves_only, bool reverse) 343 : node_(node), 344 iterate_leaves_only_(iterate_leaves_only), 345 reverse_(reverse) { 346 if (node_) { 347 if (reverse_) { 348 while (!node_->children.empty()) { 349 const int child_index = node_->children.size() - 1; 350 stack_.push_back({node_, child_index}); 351 node_ = node_->children[child_index].get(); 352 } 353 } else { 354 if (!node_->children.empty() && iterate_leaves_only) { 355 ++*this; 356 } 357 } 358 } 359 } 360 ShapeTreeIterator(const ShapeTreeIterator& other) 361 : node_(other.node_), 362 stack_(other.stack_), 363 iterate_leaves_only_(other.iterate_leaves_only_), 364 reverse_(other.reverse_) {} 365 366 ShapeTreeIterator& operator++() { 367 CHECK_NE(nullptr, node_) << "walking off the end() of an iterator!"; 368 if (reverse_) { 369 while (!stack_.empty()) { 370 node_ = stack_.back().first; 371 int64 next_child_index = stack_.back().second - 1; 372 stack_.pop_back(); 373 if (next_child_index < 0) { 374 if (!iterate_leaves_only_) { 375 // All children are visited, yield <node_>. 376 return *this; 377 } 378 } else { 379 stack_.push_back({node_, next_child_index}); 380 node_ = node_->children[next_child_index].get(); 381 while (!node_->children.empty()) { 382 const int child_index = node_->children.size() - 1; 383 stack_.push_back({node_, child_index}); 384 node_ = node_->children[child_index].get(); 385 } 386 return *this; 387 } 388 } 389 } else { 390 // We're doing a pre-order walk, so if our current node has children take 391 // the first child. 392 if (!node_->children.empty()) { 393 stack_.push_back({node_, /*child-index=*/0}); 394 node_ = node_->children[0].get(); 395 if (node_->children.empty() || !iterate_leaves_only_) { 396 return *this; 397 } else { 398 // This is a non-leaf; tail-recurse. 399 return ++(*this); 400 } 401 } 402 // Otherwise we are currently at a leaf. Walk back up until a node 403 // contains a child we haven't visited yet. 404 while (!stack_.empty()) { 405 node_ = stack_.back().first; 406 int64 next_child_index = stack_.back().second + 1; 407 stack_.pop_back(); 408 if (node_->children.size() > next_child_index) { 409 stack_.push_back({node_, next_child_index}); 410 node_ = node_->children[next_child_index].get(); 411 412 if (node_->children.empty() || !iterate_leaves_only_) { 413 return *this; 414 } else { 415 // This is a non-leaf; tail-recurse. 416 return ++(*this); 417 } 418 } 419 } 420 } 421 // We've walked off the end of the tree. Set node_ to nullptr to signify 422 // end(). 423 node_ = nullptr; 424 current_.reset(); 425 return *this; 426 } 427 ShapeTreeIterator operator++(int) { 428 auto i = *this; 429 ++(*this); 430 return i; 431 } 432 bool operator==(const ShapeTreeIterator& other) const { 433 return node_ == other.node_; 434 } 435 bool operator!=(const ShapeTreeIterator& other) const { 436 return node_ != other.node_; 437 } 438 value_type& operator*() { return UpdateCurrent(); } 439 value_type* operator->() { return &UpdateCurrent(); } 440 441 private: 442 // Updates the current_ member to reflect the current state. 443 value_type& UpdateCurrent() { 444 ShapeIndex index; 445 for (auto& node_and_index : stack_) { 446 index.push_back(node_and_index.second); 447 } 448 current_ = MakeUnique<value_type>(index, node_->data); 449 return *current_; 450 } 451 452 // The node to which this iterator is pointing. This is the source of truth in 453 // the iterator - the stack only exists to facilitate walking back from 454 // children to parents. 455 NodeType* node_; 456 // Stack of {node, child-index} pairs of the path taken from the root to get 457 // to node_. This allows us to backtrack and know where to go next. 458 std::vector<std::pair<NodeType*, int64>> stack_; 459 // True if we should not include interior nodes in our walk. 460 bool iterate_leaves_only_; 461 // True if we should yield the reverse of the pre-order traversal. 462 bool reverse_; 463 // Placeholder for the current value. Ideally this wouldn't exist and would 464 // just be an rvalue, but operator -> needs to return a pointer to something. 465 // We cannot just use a plain old value_type as it contains a reference so 466 // cannot be default-constructed. 467 std::unique_ptr<value_type> current_; 468 }; 469 470 template <typename T> 471 void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value, 472 Node* node) { 473 if (ShapeUtil::IsTuple(shape)) { 474 for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { 475 node->children.emplace_back(new Node(init_value)); 476 InitChildren(shape.tuple_shapes(i), init_value, 477 node->children.back().get()); 478 } 479 } 480 } 481 482 template <typename T> 483 void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) { 484 if (ShapeUtil::IsTuple(shape)) { 485 for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { 486 node->children.emplace_back(new Node()); 487 InitChildren(shape.tuple_shapes(i), node->children.back().get()); 488 } 489 } 490 } 491 492 template <typename T> 493 ShapeTree<T>::ShapeTree(Shape shape) 494 : root_(), 495 shape_storage_(MakeUnique<Shape>(std::move(shape))), 496 shape_(shape_storage_.get()) { 497 // The shape_ field is just used to hold the structure of the shape. 498 // It should not be relied upon to store layout information. 499 LayoutUtil::ClearLayout(shape_storage_.get()); 500 InitChildren(*shape_, &root_); 501 } 502 503 template <typename T> 504 ShapeTree<T>::ShapeTree(const Shape* shape) : root_(), shape_(shape) { 505 InitChildren(*shape_, &root_); 506 } 507 508 template <typename T> 509 ShapeTree<T>::ShapeTree(Shape shape, const T& init_value) 510 : root_(init_value), 511 shape_storage_(MakeUnique<Shape>(std::move(shape))), 512 shape_(shape_storage_.get()) { 513 // The shape_ field is just used to hold the structure of the shape. 514 // It should not be relied upon to store layout information. 515 LayoutUtil::ClearLayout(shape_storage_.get()); 516 InitChildren(*shape_, init_value, &root_); 517 } 518 519 template <typename T> 520 ShapeTree<T>::ShapeTree(const Shape* shape, const T& init_value) 521 : root_(init_value), shape_(shape) { 522 InitChildren(*shape_, init_value, &root_); 523 } 524 525 template <typename T> 526 const T& ShapeTree<T>::element(const ShapeIndex& index) const { 527 return Lookup(index)->data; 528 } 529 530 template <typename T> 531 T* ShapeTree<T>::mutable_element(const ShapeIndex& index) { 532 return &Lookup(index)->data; 533 } 534 535 template <typename T> 536 internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(const ShapeIndex& index) { 537 Node* node = &root_; 538 for (const int64 i : index) { 539 CHECK_GE(i, 0); 540 CHECK_LT(i, node->children.size()); 541 node = node->children[i].get(); 542 } 543 return node; 544 } 545 546 template <typename T> 547 const internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup( 548 const ShapeIndex& index) const { 549 return const_cast<ShapeTree*>(this)->Lookup(index); 550 } 551 552 /* static */ 553 template <typename T> 554 template <typename Fn> 555 Status ShapeTree<T>::ForEachHelper(const Fn& func, const Node& node, 556 ShapeIndex* index) { 557 TF_RETURN_IF_ERROR(func(*index, node.data)); 558 for (int64 i = 0; i < node.children.size(); ++i) { 559 index->push_back(i); 560 TF_RETURN_IF_ERROR(ForEachHelper(func, *node.children[i], index)); 561 index->pop_back(); 562 } 563 return Status::OK(); 564 } 565 566 /* static */ 567 template <typename T> 568 template <typename Fn> 569 Status ShapeTree<T>::ForEachMutableHelper(const Fn& func, Node* node, 570 ShapeIndex* index) { 571 TF_RETURN_IF_ERROR(func(*index, &node->data)); 572 for (int64 i = 0; i < node->children.size(); ++i) { 573 index->push_back(i); 574 TF_RETURN_IF_ERROR( 575 ForEachMutableHelper(func, node->children[i].get(), index)); 576 index->pop_back(); 577 } 578 return Status::OK(); 579 } 580 581 template <typename T> 582 template <typename Fn> 583 Status ShapeTree<T>::ForEachElementWithStatus(const Fn& func) const { 584 ShapeIndex index; 585 return ForEachHelper(func, root_, &index); 586 } 587 588 template <typename T> 589 template <typename Fn> 590 Status ShapeTree<T>::ForEachMutableElementWithStatus(const Fn& func) { 591 ShapeIndex index; 592 return ForEachMutableHelper(func, &root_, &index); 593 } 594 595 template <typename T> 596 template <typename Fn> 597 void ShapeTree<T>::ForEachElement(const Fn& func) const { 598 ShapeIndex index; 599 return ForEachHelper( 600 [&func](const ShapeIndex& index, const T& data) { 601 func(index, data); 602 return Status::OK(); 603 }, 604 root_, &index) 605 .IgnoreError(); 606 } 607 608 template <typename T> 609 template <typename Fn> 610 void ShapeTree<T>::ForEachMutableElement(const Fn& func) { 611 ShapeIndex index; 612 return ForEachMutableHelper( 613 [&func](const ShapeIndex& index, T* data) { 614 func(index, data); 615 return Status::OK(); 616 }, 617 &root_, &index) 618 .IgnoreError(); 619 } 620 621 template <typename T> 622 void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other, 623 const ShapeIndex& source_base_index, 624 const ShapeIndex& target_base_index) { 625 CHECK(ShapeUtil::Compatible( 626 ShapeUtil::GetSubshape(shape(), target_base_index), 627 ShapeUtil::GetSubshape(other.shape(), source_base_index))); 628 ForEachMutableElement([this, &other, &source_base_index, &target_base_index]( 629 const ShapeIndex& index, T* data) { 630 // Copy the data element only if index is in the 631 // subtree rooted at target_base_index. 632 for (int i = 0; i < target_base_index.size(); ++i) { 633 if (i >= index.size() || index[i] != target_base_index[i]) { 634 return; 635 } 636 } 637 // Construct source element index to copy from. 638 ShapeIndex source_index = source_base_index; 639 for (int i = target_base_index.size(); i < index.size(); ++i) { 640 source_index.push_back(index[i]); 641 } 642 *data = other.element(source_index); 643 }); 644 } 645 646 template <typename T> 647 bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const { 648 bool equal = true; 649 ForEachElement( 650 [this, &other, &equal](const ShapeIndex& index, const T& data) { 651 if (data != other.element(index)) { 652 equal = false; 653 } 654 }); 655 return equal; 656 } 657 658 } // namespace xla 659 660 #endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_ 661