Home | History | Annotate | Download | only in collect
      1 /*
      2  * Copyright (C) 2007 The Guava Authors
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  * http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package com.google.common.collect;
     18 
     19 import static com.google.common.base.Preconditions.checkArgument;
     20 import static com.google.common.base.Preconditions.checkState;
     21 import static com.google.common.collect.CollectPreconditions.checkNonnegative;
     22 import static com.google.common.collect.CollectPreconditions.checkRemove;
     23 
     24 import com.google.common.annotations.GwtCompatible;
     25 import com.google.common.annotations.GwtIncompatible;
     26 import com.google.common.base.MoreObjects;
     27 import com.google.common.primitives.Ints;
     28 
     29 import java.io.IOException;
     30 import java.io.ObjectInputStream;
     31 import java.io.ObjectOutputStream;
     32 import java.io.Serializable;
     33 import java.util.Comparator;
     34 import java.util.ConcurrentModificationException;
     35 import java.util.Iterator;
     36 import java.util.NoSuchElementException;
     37 
     38 import javax.annotation.Nullable;
     39 
     40 /**
     41  * A multiset which maintains the ordering of its elements, according to either their natural order
     42  * or an explicit {@link Comparator}. In all cases, this implementation uses
     43  * {@link Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to
     44  * determine equivalence of instances.
     45  *
     46  * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
     47  * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the
     48  * {@link java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
     49  *
     50  * <p>See the Guava User Guide article on <a href=
     51  * "http://code.google.com/p/guava-libraries/wiki/NewCollectionTypesExplained#Multiset">
     52  * {@code Multiset}</a>.
     53  *
     54  * @author Louis Wasserman
     55  * @author Jared Levy
     56  * @since 2.0 (imported from Google Collections Library)
     57  */
     58 @GwtCompatible(emulated = true)
     59 public final class TreeMultiset<E> extends AbstractSortedMultiset<E> implements Serializable {
     60 
     61   /**
     62    * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
     63    * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
     64    * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
     65    * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
     66    * user attempts to add an element to the multiset that violates this constraint (for example,
     67    * the user attempts to add a string element to a set whose elements are integers), the
     68    * {@code add(Object)} call will throw a {@code ClassCastException}.
     69    *
     70    * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
     71    * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
     72    */
     73   public static <E extends Comparable> TreeMultiset<E> create() {
     74     return new TreeMultiset<E>(Ordering.natural());
     75   }
     76 
     77   /**
     78    * Creates a new, empty multiset, sorted according to the specified comparator. All elements
     79    * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
     80    * {@code comparator.compare(e1,
     81    * e2)} must not throw a {@code ClassCastException} for any elements {@code e1} and {@code e2} in
     82    * the multiset. If the user attempts to add an element to the multiset that violates this
     83    * constraint, the {@code add(Object)} call will throw a {@code ClassCastException}.
     84    *
     85    * @param comparator
     86    *          the comparator that will be used to sort this multiset. A null value indicates that
     87    *          the elements' <i>natural ordering</i> should be used.
     88    */
     89   @SuppressWarnings("unchecked")
     90   public static <E> TreeMultiset<E> create(@Nullable Comparator<? super E> comparator) {
     91     return (comparator == null)
     92         ? new TreeMultiset<E>((Comparator) Ordering.natural())
     93         : new TreeMultiset<E>(comparator);
     94   }
     95 
     96   /**
     97    * Creates an empty multiset containing the given initial elements, sorted according to the
     98    * elements' natural order.
     99    *
    100    * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
    101    *
    102    * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
    103    * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
    104    */
    105   public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
    106     TreeMultiset<E> multiset = create();
    107     Iterables.addAll(multiset, elements);
    108     return multiset;
    109   }
    110 
    111   private final transient Reference<AvlNode<E>> rootReference;
    112   private final transient GeneralRange<E> range;
    113   private final transient AvlNode<E> header;
    114 
    115   TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
    116     super(range.comparator());
    117     this.rootReference = rootReference;
    118     this.range = range;
    119     this.header = endLink;
    120   }
    121 
    122   TreeMultiset(Comparator<? super E> comparator) {
    123     super(comparator);
    124     this.range = GeneralRange.all(comparator);
    125     this.header = new AvlNode<E>(null, 1);
    126     successor(header, header);
    127     this.rootReference = new Reference<AvlNode<E>>();
    128   }
    129 
    130   /**
    131    * A function which can be summed across a subtree.
    132    */
    133   private enum Aggregate {
    134     SIZE {
    135       @Override
    136       int nodeAggregate(AvlNode<?> node) {
    137         return node.elemCount;
    138       }
    139 
    140       @Override
    141       long treeAggregate(@Nullable AvlNode<?> root) {
    142         return (root == null) ? 0 : root.totalCount;
    143       }
    144     },
    145     DISTINCT {
    146       @Override
    147       int nodeAggregate(AvlNode<?> node) {
    148         return 1;
    149       }
    150 
    151       @Override
    152       long treeAggregate(@Nullable AvlNode<?> root) {
    153         return (root == null) ? 0 : root.distinctElements;
    154       }
    155     };
    156     abstract int nodeAggregate(AvlNode<?> node);
    157 
    158     abstract long treeAggregate(@Nullable AvlNode<?> root);
    159   }
    160 
    161   private long aggregateForEntries(Aggregate aggr) {
    162     AvlNode<E> root = rootReference.get();
    163     long total = aggr.treeAggregate(root);
    164     if (range.hasLowerBound()) {
    165       total -= aggregateBelowRange(aggr, root);
    166     }
    167     if (range.hasUpperBound()) {
    168       total -= aggregateAboveRange(aggr, root);
    169     }
    170     return total;
    171   }
    172 
    173   private long aggregateBelowRange(Aggregate aggr, @Nullable AvlNode<E> node) {
    174     if (node == null) {
    175       return 0;
    176     }
    177     int cmp = comparator().compare(range.getLowerEndpoint(), node.elem);
    178     if (cmp < 0) {
    179       return aggregateBelowRange(aggr, node.left);
    180     } else if (cmp == 0) {
    181       switch (range.getLowerBoundType()) {
    182         case OPEN:
    183           return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
    184         case CLOSED:
    185           return aggr.treeAggregate(node.left);
    186         default:
    187           throw new AssertionError();
    188       }
    189     } else {
    190       return aggr.treeAggregate(node.left) + aggr.nodeAggregate(node)
    191           + aggregateBelowRange(aggr, node.right);
    192     }
    193   }
    194 
    195   private long aggregateAboveRange(Aggregate aggr, @Nullable AvlNode<E> node) {
    196     if (node == null) {
    197       return 0;
    198     }
    199     int cmp = comparator().compare(range.getUpperEndpoint(), node.elem);
    200     if (cmp > 0) {
    201       return aggregateAboveRange(aggr, node.right);
    202     } else if (cmp == 0) {
    203       switch (range.getUpperBoundType()) {
    204         case OPEN:
    205           return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
    206         case CLOSED:
    207           return aggr.treeAggregate(node.right);
    208         default:
    209           throw new AssertionError();
    210       }
    211     } else {
    212       return aggr.treeAggregate(node.right) + aggr.nodeAggregate(node)
    213           + aggregateAboveRange(aggr, node.left);
    214     }
    215   }
    216 
    217   @Override
    218   public int size() {
    219     return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
    220   }
    221 
    222   @Override
    223   int distinctElements() {
    224     return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
    225   }
    226 
    227   @Override
    228   public int count(@Nullable Object element) {
    229     try {
    230       @SuppressWarnings("unchecked")
    231       E e = (E) element;
    232       AvlNode<E> root = rootReference.get();
    233       if (!range.contains(e) || root == null) {
    234         return 0;
    235       }
    236       return root.count(comparator(), e);
    237     } catch (ClassCastException e) {
    238       return 0;
    239     } catch (NullPointerException e) {
    240       return 0;
    241     }
    242   }
    243 
    244   @Override
    245   public int add(@Nullable E element, int occurrences) {
    246     checkNonnegative(occurrences, "occurrences");
    247     if (occurrences == 0) {
    248       return count(element);
    249     }
    250     checkArgument(range.contains(element));
    251     AvlNode<E> root = rootReference.get();
    252     if (root == null) {
    253       comparator().compare(element, element);
    254       AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
    255       successor(header, newRoot, header);
    256       rootReference.checkAndSet(root, newRoot);
    257       return 0;
    258     }
    259     int[] result = new int[1]; // used as a mutable int reference to hold result
    260     AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
    261     rootReference.checkAndSet(root, newRoot);
    262     return result[0];
    263   }
    264 
    265   @Override
    266   public int remove(@Nullable Object element, int occurrences) {
    267     checkNonnegative(occurrences, "occurrences");
    268     if (occurrences == 0) {
    269       return count(element);
    270     }
    271     AvlNode<E> root = rootReference.get();
    272     int[] result = new int[1]; // used as a mutable int reference to hold result
    273     AvlNode<E> newRoot;
    274     try {
    275       @SuppressWarnings("unchecked")
    276       E e = (E) element;
    277       if (!range.contains(e) || root == null) {
    278         return 0;
    279       }
    280       newRoot = root.remove(comparator(), e, occurrences, result);
    281     } catch (ClassCastException e) {
    282       return 0;
    283     } catch (NullPointerException e) {
    284       return 0;
    285     }
    286     rootReference.checkAndSet(root, newRoot);
    287     return result[0];
    288   }
    289 
    290   @Override
    291   public int setCount(@Nullable E element, int count) {
    292     checkNonnegative(count, "count");
    293     if (!range.contains(element)) {
    294       checkArgument(count == 0);
    295       return 0;
    296     }
    297 
    298     AvlNode<E> root = rootReference.get();
    299     if (root == null) {
    300       if (count > 0) {
    301         add(element, count);
    302       }
    303       return 0;
    304     }
    305     int[] result = new int[1]; // used as a mutable int reference to hold result
    306     AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
    307     rootReference.checkAndSet(root, newRoot);
    308     return result[0];
    309   }
    310 
    311   @Override
    312   public boolean setCount(@Nullable E element, int oldCount, int newCount) {
    313     checkNonnegative(newCount, "newCount");
    314     checkNonnegative(oldCount, "oldCount");
    315     checkArgument(range.contains(element));
    316 
    317     AvlNode<E> root = rootReference.get();
    318     if (root == null) {
    319       if (oldCount == 0) {
    320         if (newCount > 0) {
    321           add(element, newCount);
    322         }
    323         return true;
    324       } else {
    325         return false;
    326       }
    327     }
    328     int[] result = new int[1]; // used as a mutable int reference to hold result
    329     AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
    330     rootReference.checkAndSet(root, newRoot);
    331     return result[0] == oldCount;
    332   }
    333 
    334   private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
    335     return new Multisets.AbstractEntry<E>() {
    336       @Override
    337       public E getElement() {
    338         return baseEntry.getElement();
    339       }
    340 
    341       @Override
    342       public int getCount() {
    343         int result = baseEntry.getCount();
    344         if (result == 0) {
    345           return count(getElement());
    346         } else {
    347           return result;
    348         }
    349       }
    350     };
    351   }
    352 
    353   /**
    354    * Returns the first node in the tree that is in range.
    355    */
    356   @Nullable private AvlNode<E> firstNode() {
    357     AvlNode<E> root = rootReference.get();
    358     if (root == null) {
    359       return null;
    360     }
    361     AvlNode<E> node;
    362     if (range.hasLowerBound()) {
    363       E endpoint = range.getLowerEndpoint();
    364       node = rootReference.get().ceiling(comparator(), endpoint);
    365       if (node == null) {
    366         return null;
    367       }
    368       if (range.getLowerBoundType() == BoundType.OPEN
    369           && comparator().compare(endpoint, node.getElement()) == 0) {
    370         node = node.succ;
    371       }
    372     } else {
    373       node = header.succ;
    374     }
    375     return (node == header || !range.contains(node.getElement())) ? null : node;
    376   }
    377 
    378   @Nullable private AvlNode<E> lastNode() {
    379     AvlNode<E> root = rootReference.get();
    380     if (root == null) {
    381       return null;
    382     }
    383     AvlNode<E> node;
    384     if (range.hasUpperBound()) {
    385       E endpoint = range.getUpperEndpoint();
    386       node = rootReference.get().floor(comparator(), endpoint);
    387       if (node == null) {
    388         return null;
    389       }
    390       if (range.getUpperBoundType() == BoundType.OPEN
    391           && comparator().compare(endpoint, node.getElement()) == 0) {
    392         node = node.pred;
    393       }
    394     } else {
    395       node = header.pred;
    396     }
    397     return (node == header || !range.contains(node.getElement())) ? null : node;
    398   }
    399 
    400   @Override
    401   Iterator<Entry<E>> entryIterator() {
    402     return new Iterator<Entry<E>>() {
    403       AvlNode<E> current = firstNode();
    404       Entry<E> prevEntry;
    405 
    406       @Override
    407       public boolean hasNext() {
    408         if (current == null) {
    409           return false;
    410         } else if (range.tooHigh(current.getElement())) {
    411           current = null;
    412           return false;
    413         } else {
    414           return true;
    415         }
    416       }
    417 
    418       @Override
    419       public Entry<E> next() {
    420         if (!hasNext()) {
    421           throw new NoSuchElementException();
    422         }
    423         Entry<E> result = wrapEntry(current);
    424         prevEntry = result;
    425         if (current.succ == header) {
    426           current = null;
    427         } else {
    428           current = current.succ;
    429         }
    430         return result;
    431       }
    432 
    433       @Override
    434       public void remove() {
    435         checkRemove(prevEntry != null);
    436         setCount(prevEntry.getElement(), 0);
    437         prevEntry = null;
    438       }
    439     };
    440   }
    441 
    442   @Override
    443   Iterator<Entry<E>> descendingEntryIterator() {
    444     return new Iterator<Entry<E>>() {
    445       AvlNode<E> current = lastNode();
    446       Entry<E> prevEntry = null;
    447 
    448       @Override
    449       public boolean hasNext() {
    450         if (current == null) {
    451           return false;
    452         } else if (range.tooLow(current.getElement())) {
    453           current = null;
    454           return false;
    455         } else {
    456           return true;
    457         }
    458       }
    459 
    460       @Override
    461       public Entry<E> next() {
    462         if (!hasNext()) {
    463           throw new NoSuchElementException();
    464         }
    465         Entry<E> result = wrapEntry(current);
    466         prevEntry = result;
    467         if (current.pred == header) {
    468           current = null;
    469         } else {
    470           current = current.pred;
    471         }
    472         return result;
    473       }
    474 
    475       @Override
    476       public void remove() {
    477         checkRemove(prevEntry != null);
    478         setCount(prevEntry.getElement(), 0);
    479         prevEntry = null;
    480       }
    481     };
    482   }
    483 
    484   @Override
    485   public SortedMultiset<E> headMultiset(@Nullable E upperBound, BoundType boundType) {
    486     return new TreeMultiset<E>(rootReference, range.intersect(GeneralRange.upTo(
    487         comparator(),
    488         upperBound,
    489         boundType)), header);
    490   }
    491 
    492   @Override
    493   public SortedMultiset<E> tailMultiset(@Nullable E lowerBound, BoundType boundType) {
    494     return new TreeMultiset<E>(rootReference, range.intersect(GeneralRange.downTo(
    495         comparator(),
    496         lowerBound,
    497         boundType)), header);
    498   }
    499 
    500   static int distinctElements(@Nullable AvlNode<?> node) {
    501     return (node == null) ? 0 : node.distinctElements;
    502   }
    503 
    504   private static final class Reference<T> {
    505     @Nullable private T value;
    506 
    507     @Nullable public T get() {
    508       return value;
    509     }
    510 
    511     public void checkAndSet(@Nullable T expected, T newValue) {
    512       if (value != expected) {
    513         throw new ConcurrentModificationException();
    514       }
    515       value = newValue;
    516     }
    517   }
    518 
    519   private static final class AvlNode<E> extends Multisets.AbstractEntry<E> {
    520     @Nullable private final E elem;
    521 
    522     // elemCount is 0 iff this node has been deleted.
    523     private int elemCount;
    524 
    525     private int distinctElements;
    526     private long totalCount;
    527     private int height;
    528     private AvlNode<E> left;
    529     private AvlNode<E> right;
    530     private AvlNode<E> pred;
    531     private AvlNode<E> succ;
    532 
    /**
     * Creates a childless node holding {@code elemCount} occurrences of {@code elem}.
     *
     * @throws IllegalArgumentException if {@code elemCount} is not positive
     */
    AvlNode(@Nullable E elem, int elemCount) {
      checkArgument(elemCount > 0);
      this.elem = elem;
      this.left = null;
      this.right = null;
      this.height = 1;
      // A new node is a one-node subtree, so its aggregates are just its own values.
      this.elemCount = elemCount;
      this.totalCount = elemCount;
      this.distinctElements = 1;
    }
    543 
    544     public int count(Comparator<? super E> comparator, E e) {
    545       int cmp = comparator.compare(e, elem);
    546       if (cmp < 0) {
    547         return (left == null) ? 0 : left.count(comparator, e);
    548       } else if (cmp > 0) {
    549         return (right == null) ? 0 : right.count(comparator, e);
    550       } else {
    551         return elemCount;
    552       }
    553     }
    554 
    555     private AvlNode<E> addRightChild(E e, int count) {
    556       right = new AvlNode<E>(e, count);
    557       successor(this, right, succ);
    558       height = Math.max(2, height);
    559       distinctElements++;
    560       totalCount += count;
    561       return this;
    562     }
    563 
    564     private AvlNode<E> addLeftChild(E e, int count) {
    565       left = new AvlNode<E>(e, count);
    566       successor(pred, left, this);
    567       height = Math.max(2, height);
    568       distinctElements++;
    569       totalCount += count;
    570       return this;
    571     }
    572 
    573     AvlNode<E> add(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
    574       /*
    575        * It speeds things up considerably to unconditionally add count to totalCount here,
    576        * but that destroys failure atomicity in the case of count overflow. =(
    577        */
    578       int cmp = comparator.compare(e, elem);
    579       if (cmp < 0) {
    580         AvlNode<E> initLeft = left;
    581         if (initLeft == null) {
    582           result[0] = 0;
    583           return addLeftChild(e, count);
    584         }
    585         int initHeight = initLeft.height;
    586 
    587         left = initLeft.add(comparator, e, count, result);
    588         if (result[0] == 0) {
    589           distinctElements++;
    590         }
    591         this.totalCount += count;
    592         return (left.height == initHeight) ? this : rebalance();
    593       } else if (cmp > 0) {
    594         AvlNode<E> initRight = right;
    595         if (initRight == null) {
    596           result[0] = 0;
    597           return addRightChild(e, count);
    598         }
    599         int initHeight = initRight.height;
    600 
    601         right = initRight.add(comparator, e, count, result);
    602         if (result[0] == 0) {
    603           distinctElements++;
    604         }
    605         this.totalCount += count;
    606         return (right.height == initHeight) ? this : rebalance();
    607       }
    608 
    609       // adding count to me!  No rebalance possible.
    610       result[0] = elemCount;
    611       long resultCount = (long) elemCount + count;
    612       checkArgument(resultCount <= Integer.MAX_VALUE);
    613       this.elemCount += count;
    614       this.totalCount += count;
    615       return this;
    616     }
    617 
    618     AvlNode<E> remove(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
    619       int cmp = comparator.compare(e, elem);
    620       if (cmp < 0) {
    621         AvlNode<E> initLeft = left;
    622         if (initLeft == null) {
    623           result[0] = 0;
    624           return this;
    625         }
    626 
    627         left = initLeft.remove(comparator, e, count, result);
    628 
    629         if (result[0] > 0) {
    630           if (count >= result[0]) {
    631             this.distinctElements--;
    632             this.totalCount -= result[0];
    633           } else {
    634             this.totalCount -= count;
    635           }
    636         }
    637         return (result[0] == 0) ? this : rebalance();
    638       } else if (cmp > 0) {
    639         AvlNode<E> initRight = right;
    640         if (initRight == null) {
    641           result[0] = 0;
    642           return this;
    643         }
    644 
    645         right = initRight.remove(comparator, e, count, result);
    646 
    647         if (result[0] > 0) {
    648           if (count >= result[0]) {
    649             this.distinctElements--;
    650             this.totalCount -= result[0];
    651           } else {
    652             this.totalCount -= count;
    653           }
    654         }
    655         return rebalance();
    656       }
    657 
    658       // removing count from me!
    659       result[0] = elemCount;
    660       if (count >= elemCount) {
    661         return deleteMe();
    662       } else {
    663         this.elemCount -= count;
    664         this.totalCount -= count;
    665         return this;
    666       }
    667     }
    668 
    669     AvlNode<E> setCount(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
    670       int cmp = comparator.compare(e, elem);
    671       if (cmp < 0) {
    672         AvlNode<E> initLeft = left;
    673         if (initLeft == null) {
    674           result[0] = 0;
    675           return (count > 0) ? addLeftChild(e, count) : this;
    676         }
    677 
    678         left = initLeft.setCount(comparator, e, count, result);
    679 
    680         if (count == 0 && result[0] != 0) {
    681           this.distinctElements--;
    682         } else if (count > 0 && result[0] == 0) {
    683           this.distinctElements++;
    684         }
    685 
    686         this.totalCount += count - result[0];
    687         return rebalance();
    688       } else if (cmp > 0) {
    689         AvlNode<E> initRight = right;
    690         if (initRight == null) {
    691           result[0] = 0;
    692           return (count > 0) ? addRightChild(e, count) : this;
    693         }
    694 
    695         right = initRight.setCount(comparator, e, count, result);
    696 
    697         if (count == 0 && result[0] != 0) {
    698           this.distinctElements--;
    699         } else if (count > 0 && result[0] == 0) {
    700           this.distinctElements++;
    701         }
    702 
    703         this.totalCount += count - result[0];
    704         return rebalance();
    705       }
    706 
    707       // setting my count
    708       result[0] = elemCount;
    709       if (count == 0) {
    710         return deleteMe();
    711       }
    712       this.totalCount += count - elemCount;
    713       this.elemCount = count;
    714       return this;
    715     }
    716 
    717     AvlNode<E> setCount(
    718         Comparator<? super E> comparator,
    719         @Nullable E e,
    720         int expectedCount,
    721         int newCount,
    722         int[] result) {
    723       int cmp = comparator.compare(e, elem);
    724       if (cmp < 0) {
    725         AvlNode<E> initLeft = left;
    726         if (initLeft == null) {
    727           result[0] = 0;
    728           if (expectedCount == 0 && newCount > 0) {
    729             return addLeftChild(e, newCount);
    730           }
    731           return this;
    732         }
    733 
    734         left = initLeft.setCount(comparator, e, expectedCount, newCount, result);
    735 
    736         if (result[0] == expectedCount) {
    737           if (newCount == 0 && result[0] != 0) {
    738             this.distinctElements--;
    739           } else if (newCount > 0 && result[0] == 0) {
    740             this.distinctElements++;
    741           }
    742           this.totalCount += newCount - result[0];
    743         }
    744         return rebalance();
    745       } else if (cmp > 0) {
    746         AvlNode<E> initRight = right;
    747         if (initRight == null) {
    748           result[0] = 0;
    749           if (expectedCount == 0 && newCount > 0) {
    750             return addRightChild(e, newCount);
    751           }
    752           return this;
    753         }
    754 
    755         right = initRight.setCount(comparator, e, expectedCount, newCount, result);
    756 
    757         if (result[0] == expectedCount) {
    758           if (newCount == 0 && result[0] != 0) {
    759             this.distinctElements--;
    760           } else if (newCount > 0 && result[0] == 0) {
    761             this.distinctElements++;
    762           }
    763           this.totalCount += newCount - result[0];
    764         }
    765         return rebalance();
    766       }
    767 
    768       // setting my count
    769       result[0] = elemCount;
    770       if (expectedCount == elemCount) {
    771         if (newCount == 0) {
    772           return deleteMe();
    773         }
    774         this.totalCount += newCount - elemCount;
    775         this.elemCount = newCount;
    776       }
    777       return this;
    778     }
    779 
    780     private AvlNode<E> deleteMe() {
    781       int oldElemCount = this.elemCount;
    782       this.elemCount = 0;
    783       successor(pred, succ);
    784       if (left == null) {
    785         return right;
    786       } else if (right == null) {
    787         return left;
    788       } else if (left.height >= right.height) {
    789         AvlNode<E> newTop = pred;
    790         // newTop is the maximum node in my left subtree
    791         newTop.left = left.removeMax(newTop);
    792         newTop.right = right;
    793         newTop.distinctElements = distinctElements - 1;
    794         newTop.totalCount = totalCount - oldElemCount;
    795         return newTop.rebalance();
    796       } else {
    797         AvlNode<E> newTop = succ;
    798         newTop.right = right.removeMin(newTop);
    799         newTop.left = left;
    800         newTop.distinctElements = distinctElements - 1;
    801         newTop.totalCount = totalCount - oldElemCount;
    802         return newTop.rebalance();
    803       }
    804     }
    805 
    806     // Removes the minimum node from this subtree to be reused elsewhere
    807     private AvlNode<E> removeMin(AvlNode<E> node) {
    808       if (left == null) {
    809         return right;
    810       } else {
    811         left = left.removeMin(node);
    812         distinctElements--;
    813         totalCount -= node.elemCount;
    814         return rebalance();
    815       }
    816     }
    817 
    818     // Removes the maximum node from this subtree to be reused elsewhere
    819     private AvlNode<E> removeMax(AvlNode<E> node) {
    820       if (right == null) {
    821         return left;
    822       } else {
    823         right = right.removeMax(node);
    824         distinctElements--;
    825         totalCount -= node.elemCount;
    826         return rebalance();
    827       }
    828     }
    829 
    830     private void recomputeMultiset() {
    831       this.distinctElements = 1 + TreeMultiset.distinctElements(left)
    832           + TreeMultiset.distinctElements(right);
    833       this.totalCount = elemCount + totalCount(left) + totalCount(right);
    834     }
    835 
    836     private void recomputeHeight() {
    837       this.height = 1 + Math.max(height(left), height(right));
    838     }
    839 
    840     private void recompute() {
    841       recomputeMultiset();
    842       recomputeHeight();
    843     }
    844 
    845     private AvlNode<E> rebalance() {
    846       switch (balanceFactor()) {
    847         case -2:
    848           if (right.balanceFactor() > 0) {
    849             right = right.rotateRight();
    850           }
    851           return rotateLeft();
    852         case 2:
    853           if (left.balanceFactor() < 0) {
    854             left = left.rotateLeft();
    855           }
    856           return rotateRight();
    857         default:
    858           recomputeHeight();
    859           return this;
    860       }
    861     }
    862 
    863     private int balanceFactor() {
    864       return height(left) - height(right);
    865     }
    866 
    867     private AvlNode<E> rotateLeft() {
    868       checkState(right != null);
    869       AvlNode<E> newTop = right;
    870       this.right = newTop.left;
    871       newTop.left = this;
    872       newTop.totalCount = this.totalCount;
    873       newTop.distinctElements = this.distinctElements;
    874       this.recompute();
    875       newTop.recomputeHeight();
    876       return newTop;
    877     }
    878 
    879     private AvlNode<E> rotateRight() {
    880       checkState(left != null);
    881       AvlNode<E> newTop = left;
    882       this.left = newTop.right;
    883       newTop.right = this;
    884       newTop.totalCount = this.totalCount;
    885       newTop.distinctElements = this.distinctElements;
    886       this.recompute();
    887       newTop.recomputeHeight();
    888       return newTop;
    889     }
    890 
    891     private static long totalCount(@Nullable AvlNode<?> node) {
    892       return (node == null) ? 0 : node.totalCount;
    893     }
    894 
    895     private static int height(@Nullable AvlNode<?> node) {
    896       return (node == null) ? 0 : node.height;
    897     }
    898 
    899     @Nullable private AvlNode<E> ceiling(Comparator<? super E> comparator, E e) {
    900       int cmp = comparator.compare(e, elem);
    901       if (cmp < 0) {
    902         return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
    903       } else if (cmp == 0) {
    904         return this;
    905       } else {
    906         return (right == null) ? null : right.ceiling(comparator, e);
    907       }
    908     }
    909 
    910     @Nullable private AvlNode<E> floor(Comparator<? super E> comparator, E e) {
    911       int cmp = comparator.compare(e, elem);
    912       if (cmp > 0) {
    913         return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
    914       } else if (cmp == 0) {
    915         return this;
    916       } else {
    917         return (left == null) ? null : left.floor(comparator, e);
    918       }
    919     }
    920 
    921     @Override
    922     public E getElement() {
    923       return elem;
    924     }
    925 
    926     @Override
    927     public int getCount() {
    928       return elemCount;
    929     }
    930 
    931     @Override
    932     public String toString() {
    933       return Multisets.immutableEntry(getElement(), getCount()).toString();
    934     }
    935   }
    936 
    937   private static <T> void successor(AvlNode<T> a, AvlNode<T> b) {
    938     a.succ = b;
    939     b.pred = a;
    940   }
    941 
    942   private static <T> void successor(AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
    943     successor(a, b);
    944     successor(b, c);
    945   }
    946 
    947   /*
    948    * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
    949    * calls the comparator to compare the two keys. If that change is made,
    950    * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
    951    */
    952 
    953   /**
    954    * @serialData the comparator, the number of distinct elements, the first element, its count, the
    955    *             second element, its count, and so on
    956    */
    957   @GwtIncompatible("java.io.ObjectOutputStream")
    958   private void writeObject(ObjectOutputStream stream) throws IOException {
    959     stream.defaultWriteObject();
    960     stream.writeObject(elementSet().comparator());
    961     Serialization.writeMultiset(this, stream);
    962   }
    963 
    964   @GwtIncompatible("java.io.ObjectInputStream")
    965   private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
    966     stream.defaultReadObject();
    967     @SuppressWarnings("unchecked")
    968     // reading data stored by writeObject
    969     Comparator<? super E> comparator = (Comparator<? super E>) stream.readObject();
    970     Serialization.getFieldSetter(AbstractSortedMultiset.class, "comparator").set(this, comparator);
    971     Serialization.getFieldSetter(TreeMultiset.class, "range").set(
    972         this,
    973         GeneralRange.all(comparator));
    974     Serialization.getFieldSetter(TreeMultiset.class, "rootReference").set(
    975         this,
    976         new Reference<AvlNode<E>>());
    977     AvlNode<E> header = new AvlNode<E>(null, 1);
    978     Serialization.getFieldSetter(TreeMultiset.class, "header").set(this, header);
    979     successor(header, header);
    980     Serialization.populateMultiset(this, stream);
    981   }
    982 
    983   @GwtIncompatible("not needed in emulated source") private static final long serialVersionUID = 1;
    984 }
    985