Home | History | Annotate | Download | only in generator
      1 /*
      2  * Copyright (C) 2016 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 import com.android.jack.annotations.CalledByInvokeCustom;
     18 import com.android.jack.annotations.Constant;
     19 import com.android.jack.annotations.LinkerMethodHandle;
     20 import com.android.jack.annotations.MethodHandleKind;
     21 
     22 import java.lang.invoke.CallSite;
     23 import java.lang.invoke.ConstantCallSite;
     24 import java.lang.invoke.MethodHandle;
     25 import java.lang.invoke.MethodHandles;
     26 import java.lang.invoke.MethodType;
     27 
     28 import java.lang.Thread;
     29 import java.lang.ThreadLocal;
     30 import java.util.concurrent.atomic.AtomicInteger;
     31 import java.util.concurrent.CyclicBarrier;
     32 
     33 public class TestInvokeCustomWithConcurrentThreads extends Thread {
     34   private static final int NUMBER_OF_THREADS = 16;
     35 
     36   private static final AtomicInteger nextIndex = new AtomicInteger(0);
     37 
     38   private static final ThreadLocal<Integer> threadIndex =
     39       new ThreadLocal<Integer>() {
     40         @Override
     41         protected Integer initialValue() {
     42           return nextIndex.getAndIncrement();
     43         }
     44       };
     45 
     46   // Array of call sites instantiated, one per thread
     47   private static final CallSite[] instantiated = new CallSite[NUMBER_OF_THREADS];
     48 
     49   // Array of counters for how many times each instantiated call site is called
     50   private static final AtomicInteger[] called = new AtomicInteger[NUMBER_OF_THREADS];
     51 
     52   // Array of call site indicies of which call site a thread invoked
     53   private static final AtomicInteger[] targetted = new AtomicInteger[NUMBER_OF_THREADS];
     54 
     55   // Synchronization barrier all threads will wait on in the bootstrap method.
     56   private static final CyclicBarrier barrier = new CyclicBarrier(NUMBER_OF_THREADS);
     57 
     58   private TestInvokeCustomWithConcurrentThreads() {}
     59 
     60   private static int getThreadIndex() {
     61     return threadIndex.get().intValue();
     62   }
     63 
     64   public static int notUsed(int x) {
     65     return x;
     66   }
     67 
     68   @Override
     69   public void run() {
     70     int x = setCalled(-1 /* argument dropped */);
     71     notUsed(x);
     72   }
     73 
     74   @CalledByInvokeCustom(
     75       invokeMethodHandle = @LinkerMethodHandle(kind = MethodHandleKind.INVOKE_STATIC,
     76           enclosingType = TestInvokeCustomWithConcurrentThreads.class,
     77           name = "linkerMethod",
     78           argumentTypes = {MethodHandles.Lookup.class, String.class, MethodType.class}),
     79       name = "setCalled",
     80       returnType = int.class,
     81       argumentTypes = {int.class})
     82   private static int setCalled(int index) {
     83     called[index].getAndIncrement();
     84     targetted[getThreadIndex()].set(index);
     85     return 0;
     86   }
     87 
     88   @SuppressWarnings("unused")
     89   private static CallSite linkerMethod(MethodHandles.Lookup caller,
     90                                        String name,
     91                                        MethodType methodType) throws Throwable {
     92     int threadIndex = getThreadIndex();
     93     MethodHandle mh =
     94         caller.findStatic(TestInvokeCustomWithConcurrentThreads.class, name, methodType);
     95     assertEquals(methodType, mh.type());
     96     assertEquals(mh.type().parameterCount(), 1);
     97     mh = MethodHandles.insertArguments(mh, 0, threadIndex);
     98     mh = MethodHandles.dropArguments(mh, 0, int.class);
     99     assertEquals(mh.type().parameterCount(), 1);
    100     assertEquals(methodType, mh.type());
    101 
    102     // Wait for all threads to be in this method.
    103     // Multiple call sites should be created, but only one
    104     // invoked.
    105     barrier.await();
    106 
    107     instantiated[getThreadIndex()] = new ConstantCallSite(mh);
    108     return instantiated[getThreadIndex()];
    109   }
    110 
    111   public static void test() throws Throwable {
    112     // Initialize counters for which call site gets invoked
    113     for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
    114       called[i] = new AtomicInteger(0);
    115       targetted[i] = new AtomicInteger(0);
    116     }
    117 
    118     // Run threads that each invoke-custom the call site
    119     Thread [] threads = new Thread[NUMBER_OF_THREADS];
    120     for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
    121       threads[i] = new TestInvokeCustomWithConcurrentThreads();
    122       threads[i].start();
    123     }
    124 
    125     // Wait for all threads to complete
    126     for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
    127       threads[i].join();
    128     }
    129 
    130     // Check one call site instance won
    131     int winners = 0;
    132     int votes = 0;
    133     for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
    134       assertNotEquals(instantiated[i], null);
    135       if (called[i].get() != 0) {
    136         winners++;
    137         votes += called[i].get();
    138       }
    139     }
    140 
    141     System.out.println("Winners " + winners + " Votes " + votes);
    142 
    143     // We assert this below but output details when there's an error as
    144     // it's non-deterministic.
    145     if (winners != 1) {
    146       System.out.println("Threads did not the same call-sites:");
    147       for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
    148         System.out.format(" Thread % 2d invoked call site instance #%02d\n",
    149                           i, targetted[i].get());
    150       }
    151     }
    152 
    153     // We assert this below but output details when there's an error as
    154     // it's non-deterministic.
    155     if (votes != NUMBER_OF_THREADS) {
    156       System.out.println("Call-sites invocations :");
    157       for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
    158         System.out.format(" Call site instance #%02d was invoked % 2d times\n",
    159                           i, called[i].get());
    160       }
    161     }
    162 
    163     assertEquals(winners, 1);
    164     assertEquals(votes, NUMBER_OF_THREADS);
    165   }
    166 
    167   public static void assertTrue(boolean value) {
    168     if (!value) {
    169       throw new AssertionError("assertTrue value: " + value);
    170     }
    171   }
    172 
    173   public static void assertEquals(byte b1, byte b2) {
    174     if (b1 == b2) { return; }
    175     throw new AssertionError("assertEquals b1: " + b1 + ", b2: " + b2);
    176   }
    177 
    178   public static void assertEquals(char c1, char c2) {
    179     if (c1 == c2) { return; }
    180     throw new AssertionError("assertEquals c1: " + c1 + ", c2: " + c2);
    181   }
    182 
    183   public static void assertEquals(short s1, short s2) {
    184     if (s1 == s2) { return; }
    185     throw new AssertionError("assertEquals s1: " + s1 + ", s2: " + s2);
    186   }
    187 
    188   public static void assertEquals(int i1, int i2) {
    189     if (i1 == i2) { return; }
    190     throw new AssertionError("assertEquals i1: " + i1 + ", i2: " + i2);
    191   }
    192 
    193   public static void assertEquals(long l1, long l2) {
    194     if (l1 == l2) { return; }
    195     throw new AssertionError("assertEquals l1: " + l1 + ", l2: " + l2);
    196   }
    197 
    198   public static void assertEquals(float f1, float f2) {
    199     if (f1 == f2) { return; }
    200     throw new AssertionError("assertEquals f1: " + f1 + ", f2: " + f2);
    201   }
    202 
    203   public static void assertEquals(double d1, double d2) {
    204     if (d1 == d2) { return; }
    205     throw new AssertionError("assertEquals d1: " + d1 + ", d2: " + d2);
    206   }
    207 
    208   public static void assertEquals(Object o, Object p) {
    209     if (o == p) { return; }
    210     if (o != null && p != null && o.equals(p)) { return; }
    211     throw new AssertionError("assertEquals: o1: " + o + ", o2: " + p);
    212   }
    213 
    214   public static void assertNotEquals(Object o, Object p) {
    215     if (o != p) { return; }
    216     if (o != null && p != null && !o.equals(p)) { return; }
    217     throw new AssertionError("assertNotEquals: o1: " + o + ", o2: " + p);
    218   }
    219 
    220   public static void assertEquals(String s1, String s2) {
    221     if (s1 == s2) {
    222       return;
    223     }
    224 
    225     if (s1 != null && s2 != null && s1.equals(s2)) {
    226       return;
    227     }
    228 
    229     throw new AssertionError("assertEquals s1: " + s1 + ", s2: " + s2);
    230   }
    231 }
    232