Home | History | Annotate | Download | only in lockedregioncodeinjection
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
      5  * in compliance with the License. You may obtain a copy of the License at
      6  *
      7  * http://www.apache.org/licenses/LICENSE-2.0
      8  *
      9  * Unless required by applicable law or agreed to in writing, software distributed under the License
     10  * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
     11  * or implied. See the License for the specific language governing permissions and limitations under
     12  * the License.
     13  */
     14 package lockedregioncodeinjection;
     15 
     16 import java.util.ArrayList;
     17 import java.util.Arrays;
     18 import java.util.LinkedList;
     19 import java.util.List;
     20 import org.objectweb.asm.ClassVisitor;
     21 import org.objectweb.asm.MethodVisitor;
     22 import org.objectweb.asm.Opcodes;
     23 import org.objectweb.asm.commons.TryCatchBlockSorter;
     24 import org.objectweb.asm.tree.AbstractInsnNode;
     25 import org.objectweb.asm.tree.InsnList;
     26 import org.objectweb.asm.tree.LabelNode;
     27 import org.objectweb.asm.tree.MethodInsnNode;
     28 import org.objectweb.asm.tree.MethodNode;
     29 import org.objectweb.asm.tree.TryCatchBlockNode;
     30 import org.objectweb.asm.tree.analysis.Analyzer;
     31 import org.objectweb.asm.tree.analysis.AnalyzerException;
     32 import org.objectweb.asm.tree.analysis.BasicValue;
     33 import org.objectweb.asm.tree.analysis.Frame;
     34 
     35 /**
     36  * This visitor does two things:
     37  *
     38  * 1. Finds all the MONITOR_ENTER / MONITOR_EXIT in the byte code and insert the corresponding pre
     39  * and post methods calls should it matches one of the given target type in the Configuration.
     40  *
     41  * 2. Find all methods that are synchronized and insert pre method calls in the beginning and post
     42  * method calls just before all return instructions.
     43  */
     44 class LockFindingClassVisitor extends ClassVisitor {
     45     private String className = null;
     46     private final List<LockTarget> targets;
     47 
     48     public LockFindingClassVisitor(List<LockTarget> targets, ClassVisitor chain) {
     49         super(Utils.ASM_VERSION, chain);
     50         this.targets = targets;
     51     }
     52 
     53     @Override
     54     public MethodVisitor visitMethod(int access, String name, String desc, String signature,
     55             String[] exceptions) {
     56         assert this.className != null;
     57         MethodNode mn = new TryCatchBlockSorter(null, access, name, desc, signature, exceptions);
     58         MethodVisitor chain = super.visitMethod(access, name, desc, signature, exceptions);
     59         return new LockFindingMethodVisitor(this.className, mn, chain);
     60     }
     61 
     62     @Override
     63     public void visit(int version, int access, String name, String signature, String superName,
     64             String[] interfaces) {
     65         this.className = name;
     66         super.visit(version, access, name, signature, superName, interfaces);
     67     }
     68 
     69     class LockFindingMethodVisitor extends MethodVisitor {
     70         private String owner;
     71         private MethodVisitor chain;
     72 
     73         public LockFindingMethodVisitor(String owner, MethodNode mn, MethodVisitor chain) {
     74             super(Opcodes.ASM5, mn);
     75             assert owner != null;
     76             this.owner = owner;
     77             this.chain = chain;
     78         }
     79 
     80         @SuppressWarnings("unchecked")
     81         @Override
     82         public void visitEnd() {
     83             MethodNode mn = (MethodNode) mv;
     84 
     85             Analyzer a = new Analyzer(new LockTargetStateAnalysis(targets));
     86 
     87             LockTarget ownerMonitor = null;
     88             if ((mn.access & Opcodes.ACC_SYNCHRONIZED) != 0) {
     89                 for (LockTarget t : targets) {
     90                     if (t.getTargetDesc().equals("L" + owner + ";")) {
     91                         ownerMonitor = t;
     92                     }
     93                 }
     94             }
     95 
     96             try {
     97                 a.analyze(owner, mn);
     98             } catch (AnalyzerException e) {
     99                 e.printStackTrace();
    100             }
    101             InsnList instructions = mn.instructions;
    102 
    103             Frame[] frames = a.getFrames();
    104             List<Frame> frameMap = new LinkedList<>();
    105             frameMap.addAll(Arrays.asList(frames));
    106 
    107             List<List<TryCatchBlockNode>> handlersMap = new LinkedList<>();
    108 
    109             for (int i = 0; i < instructions.size(); i++) {
    110                 handlersMap.add(a.getHandlers(i));
    111             }
    112 
    113             if (ownerMonitor != null) {
    114                 AbstractInsnNode s = instructions.getFirst();
    115                 MethodInsnNode call = new MethodInsnNode(Opcodes.INVOKESTATIC,
    116                         ownerMonitor.getPreOwner(), ownerMonitor.getPreMethod(), "()V", false);
    117                 insertMethodCallBefore(mn, frameMap, handlersMap, s, 0, call);
    118             }
    119 
    120             for (int i = 0; i < instructions.size(); i++) {
    121                 AbstractInsnNode s = instructions.get(i);
    122 
    123                 if (s.getOpcode() == Opcodes.MONITORENTER) {
    124                     Frame f = frameMap.get(i);
    125                     BasicValue operand = (BasicValue) f.getStack(f.getStackSize() - 1);
    126                     if (operand instanceof LockTargetState) {
    127                         LockTargetState state = (LockTargetState) operand;
    128                         for (int j = 0; j < state.getTargets().size(); j++) {
    129                             LockTarget target = state.getTargets().get(j);
    130                             MethodInsnNode call = new MethodInsnNode(Opcodes.INVOKESTATIC,
    131                                     target.getPreOwner(), target.getPreMethod(), "()V", false);
    132                             insertMethodCallAfter(mn, frameMap, handlersMap, s, i, call);
    133                         }
    134                     }
    135                 }
    136 
    137                 if (s.getOpcode() == Opcodes.MONITOREXIT) {
    138                     Frame f = frameMap.get(i);
    139                     BasicValue operand = (BasicValue) f.getStack(f.getStackSize() - 1);
    140                     if (operand instanceof LockTargetState) {
    141                         LockTargetState state = (LockTargetState) operand;
    142                         for (int j = 0; j < state.getTargets().size(); j++) {
    143                             LockTarget target = state.getTargets().get(j);
    144                             MethodInsnNode call = new MethodInsnNode(Opcodes.INVOKESTATIC,
    145                                     target.getPostOwner(), target.getPostMethod(), "()V", false);
    146                             insertMethodCallAfter(mn, frameMap, handlersMap, s, i, call);
    147                         }
    148                     }
    149                 }
    150 
    151                 if (ownerMonitor != null && (s.getOpcode() == Opcodes.RETURN
    152                         || s.getOpcode() == Opcodes.ARETURN || s.getOpcode() == Opcodes.DRETURN
    153                         || s.getOpcode() == Opcodes.FRETURN || s.getOpcode() == Opcodes.IRETURN)) {
    154                     MethodInsnNode call =
    155                             new MethodInsnNode(Opcodes.INVOKESTATIC, ownerMonitor.getPostOwner(),
    156                                     ownerMonitor.getPostMethod(), "()V", false);
    157                     insertMethodCallBefore(mn, frameMap, handlersMap, s, i, call);
    158                     i++; // Skip ahead. Otherwise, we will revisit this instruction again.
    159                 }
    160             }
    161             super.visitEnd();
    162             mn.accept(chain);
    163         }
    164     }
    165 
    166     public static void insertMethodCallBefore(MethodNode mn, List<Frame> frameMap,
    167             List<List<TryCatchBlockNode>> handlersMap, AbstractInsnNode node, int index,
    168             MethodInsnNode call) {
    169         List<TryCatchBlockNode> handlers = handlersMap.get(index);
    170         InsnList instructions = mn.instructions;
    171         LabelNode end = new LabelNode();
    172         instructions.insert(node, end);
    173         frameMap.add(index, null);
    174         handlersMap.add(index, null);
    175         instructions.insertBefore(node, call);
    176         frameMap.add(index, null);
    177         handlersMap.add(index, null);
    178 
    179         LabelNode start = new LabelNode();
    180         instructions.insert(node, start);
    181         frameMap.add(index, null);
    182         handlersMap.add(index, null);
    183         updateCatchHandler(mn, handlers, start, end, handlersMap);
    184     }
    185 
    186     public static void insertMethodCallAfter(MethodNode mn, List<Frame> frameMap,
    187             List<List<TryCatchBlockNode>> handlersMap, AbstractInsnNode node, int index,
    188             MethodInsnNode call) {
    189         List<TryCatchBlockNode> handlers = handlersMap.get(index + 1);
    190         InsnList instructions = mn.instructions;
    191 
    192         LabelNode end = new LabelNode();
    193         instructions.insert(node, end);
    194         frameMap.add(index + 1, null);
    195         handlersMap.add(index + 1, null);
    196 
    197         instructions.insert(node, call);
    198         frameMap.add(index + 1, null);
    199         handlersMap.add(index + 1, null);
    200 
    201         LabelNode start = new LabelNode();
    202         instructions.insert(node, start);
    203         frameMap.add(index + 1, null);
    204         handlersMap.add(index + 1, null);
    205 
    206         updateCatchHandler(mn, handlers, start, end, handlersMap);
    207     }
    208 
    209     @SuppressWarnings("unchecked")
    210     public static void updateCatchHandler(MethodNode mn, List<TryCatchBlockNode> handlers,
    211             LabelNode start, LabelNode end, List<List<TryCatchBlockNode>> handlersMap) {
    212         if (handlers == null || handlers.size() == 0) {
    213             return;
    214         }
    215 
    216         InsnList instructions = mn.instructions;
    217         List<TryCatchBlockNode> newNodes = new ArrayList<>(handlers.size());
    218         for (TryCatchBlockNode handler : handlers) {
    219             if (!(instructions.indexOf(handler.start) <= instructions.indexOf(start)
    220                     && instructions.indexOf(end) <= instructions.indexOf(handler.end))) {
    221                 TryCatchBlockNode newNode =
    222                         new TryCatchBlockNode(start, end, handler.handler, handler.type);
    223                 newNodes.add(newNode);
    224                 for (int i = instructions.indexOf(start); i <= instructions.indexOf(end); i++) {
    225                     if (handlersMap.get(i) == null) {
    226                         handlersMap.set(i, new ArrayList<>());
    227                     }
    228                     handlersMap.get(i).add(newNode);
    229                 }
    230             } else {
    231                 for (int i = instructions.indexOf(start); i <= instructions.indexOf(end); i++) {
    232                     if (handlersMap.get(i) == null) {
    233                         handlersMap.set(i, new ArrayList<>());
    234                     }
    235                     handlersMap.get(i).add(handler);
    236                 }
    237             }
    238         }
    239         mn.tryCatchBlocks.addAll(0, newNodes);
    240     }
    241 }
    242