1 package com.android.server.wifi; 2 3 import static org.mockito.Matchers.anyInt; 4 import static org.mockito.Matchers.anyString; 5 import static org.mockito.Mockito.mock; 6 import static org.mockito.Mockito.when; 7 8 import android.security.KeyStore; 9 import android.util.SparseArray; 10 11 import org.mockito.Matchers; 12 import org.mockito.invocation.InvocationOnMock; 13 import org.mockito.stubbing.Answer; 14 15 import java.util.Arrays; 16 import java.util.HashMap; 17 18 class MockKeyStore { 19 20 public static class KeyBlob { 21 public byte[] blob; 22 public int flag; 23 24 public void update(byte[] blob, int flag) { 25 this.blob = Arrays.copyOf(blob, blob.length); 26 this.flag = flag; 27 } 28 } 29 private SparseArray<HashMap<String, KeyBlob>> mStore; 30 31 public MockKeyStore() { 32 mStore = new SparseArray<HashMap<String, KeyBlob>>(); 33 } 34 35 public KeyStore createMock() { 36 KeyStore mock = mock(KeyStore.class); 37 when(mock.state()).thenReturn(KeyStore.State.UNLOCKED); 38 39 when(mock.put(anyString(), Matchers.any(byte[].class), anyInt(), anyInt())) 40 .thenAnswer(new Answer<Boolean>() { 41 42 @Override 43 public Boolean answer(InvocationOnMock invocation) throws Throwable { 44 Object[] args = invocation.getArguments(); 45 return put((String) args[0], (byte[]) args[1], (Integer) args[2], 46 (Integer) args[3]); 47 } 48 }); 49 50 when(mock.importKey(anyString(), Matchers.any(byte[].class), anyInt(), anyInt())) 51 .thenAnswer(new Answer<Boolean>() { 52 53 @Override 54 public Boolean answer(InvocationOnMock invocation) throws Throwable { 55 Object[] args = invocation.getArguments(); 56 return importKey((String) args[0], (byte[]) args[1], (Integer) args[2], 57 (Integer) args[3]); 58 } 59 }); 60 61 when(mock.delete(anyString(), anyInt())).thenAnswer(new Answer<Boolean>() { 62 63 @Override 64 public Boolean answer(InvocationOnMock invocation) throws Throwable { 65 Object[] args = invocation.getArguments(); 66 return delete((String) args[0], (Integer) args[1]); 67 } 68 }); 69 70 when(mock.contains(anyString(), anyInt())).thenAnswer(new Answer<Boolean>() { 71 72 @Override 73 public Boolean answer(InvocationOnMock invocation) throws Throwable { 74 Object[] args = invocation.getArguments(); 75 return contains((String) args[0], (Integer) args[1]); 76 } 77 }); 78 79 when(mock.duplicate(anyString(), anyInt(), anyString(), anyInt())) 80 .thenAnswer(new Answer<Boolean>() { 81 @Override 82 public Boolean answer(InvocationOnMock invocation) throws Throwable { 83 Object[] args = invocation.getArguments(); 84 return duplicate((String) args[0], (Integer) args[1], (String) args[2], 85 (Integer) args[3]); 86 } 87 }); 88 return mock; 89 } 90 91 private KeyBlob access(int uid, String key, boolean createIfNotExist) { 92 if (mStore.get(uid) == null) { 93 mStore.put(uid, new HashMap<String, KeyBlob>()); 94 } 95 HashMap<String, KeyBlob> map = mStore.get(uid); 96 if (map.containsKey(key)) { 97 return map.get(key); 98 } else { 99 if (createIfNotExist) { 100 KeyBlob blob = new KeyBlob(); 101 map.put(key, blob); 102 return blob; 103 } else { 104 return null; 105 } 106 } 107 } 108 109 public KeyBlob getKeyBlob(int uid, String key) { 110 return access(uid, key, false); 111 } 112 113 private boolean put(String key, byte[] value, int uid, int flags) { 114 access(uid, key, true).update(value, flags); 115 return true; 116 } 117 118 private boolean importKey(String keyName, byte[] key, int uid, int flags) { 119 return put(keyName, key, uid, flags); 120 } 121 122 private boolean delete(String key, int uid) { 123 if (mStore.get(uid) != null) { 124 mStore.get(uid).remove(key); 125 } 126 return true; 127 } 128 129 private boolean contains(String key, int uid) { 130 return access(uid, key, false) != null; 131 } 132 133 private boolean duplicate(String srcKey, int srcUid, String destKey, int destUid) { 134 for (int i = 0; i < mStore.size(); i++) { 135 int key = mStore.keyAt(i); 136 // Cannot copy to itself 137 if (srcKey.equals(destKey) && key == destUid) { 138 continue; 139 } 140 if (srcUid == -1 || srcUid == key) { 141 HashMap<String, KeyBlob> map = mStore.get(key); 142 if (map.containsKey(srcKey)) { 143 KeyBlob blob = map.get(srcKey); 144 access(destUid, destKey, true).update(blob.blob, blob.flag); 145 break; 146 } 147 } 148 } 149 return true; 150 } 151 152 @Override 153 public String toString() { 154 StringBuilder sb = new StringBuilder(); 155 sb.append("KeyStore {"); 156 for (int i = 0; i < mStore.size(); i++) { 157 int uid = mStore.keyAt(i); 158 for (String keyName : mStore.get(uid).keySet()) { 159 sb.append(String.format("%d:%s, ", uid, keyName)); 160 } 161 } 162 sb.append("}"); 163 return sb.toString(); 164 } 165 }