1 /* 2 * Copyright (C) 2014 Square, Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 package com.squareup.okhttp; 17 18 import com.squareup.okhttp.internal.NamedRunnable; 19 import com.squareup.okhttp.internal.Util; 20 import java.io.IOException; 21 import java.net.InetAddress; 22 import java.net.InetSocketAddress; 23 import java.net.ProtocolException; 24 import java.net.Proxy; 25 import java.net.ServerSocket; 26 import java.net.Socket; 27 import java.net.SocketException; 28 import java.util.concurrent.ExecutorService; 29 import java.util.concurrent.Executors; 30 import java.util.concurrent.TimeUnit; 31 import java.util.concurrent.atomic.AtomicInteger; 32 import java.util.logging.Level; 33 import java.util.logging.Logger; 34 import okio.Buffer; 35 import okio.BufferedSink; 36 import okio.BufferedSource; 37 import okio.Okio; 38 39 /** 40 * A limited implementation of SOCKS Protocol Version 5, intended to be similar to MockWebServer. 41 * See <a href="https://www.ietf.org/rfc/rfc1928.txt">RFC 1928</a>. 42 */ 43 public final class SocksProxy { 44 private static final int VERSION_5 = 5; 45 private static final int METHOD_NONE = 0xff; 46 private static final int METHOD_NO_AUTHENTICATION_REQUIRED = 0; 47 private static final int ADDRESS_TYPE_IPV4 = 1; 48 private static final int ADDRESS_TYPE_DOMAIN_NAME = 3; 49 private static final int COMMAND_CONNECT = 1; 50 private static final int REPLY_SUCCEEDED = 0; 51 52 private static final Logger logger = Logger.getLogger(SocksProxy.class.getName()); 53 54 private final ExecutorService executor = Executors.newCachedThreadPool( 55 Util.threadFactory("SocksProxy", false)); 56 57 private ServerSocket serverSocket; 58 private AtomicInteger connectionCount = new AtomicInteger(); 59 60 public void play() throws IOException { 61 serverSocket = new ServerSocket(0); 62 executor.execute(new NamedRunnable("SocksProxy %s", serverSocket.getLocalPort()) { 63 @Override protected void execute() { 64 try { 65 while (true) { 66 Socket socket = serverSocket.accept(); 67 connectionCount.incrementAndGet(); 68 service(socket); 69 } 70 } catch (SocketException e) { 71 logger.info(name + " done accepting connections: " + e.getMessage()); 72 } catch (IOException e) { 73 logger.log(Level.WARNING, name + " failed unexpectedly", e); 74 } 75 } 76 }); 77 } 78 79 public Proxy proxy() { 80 return new Proxy(Proxy.Type.SOCKS, InetSocketAddress.createUnresolved( 81 "localhost", serverSocket.getLocalPort())); 82 } 83 84 public int connectionCount() { 85 return connectionCount.get(); 86 } 87 88 public void shutdown() throws Exception { 89 serverSocket.close(); 90 executor.shutdown(); 91 if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { 92 throw new IOException("Gave up waiting for executor to shut down"); 93 } 94 } 95 96 private void service(final Socket from) { 97 executor.execute(new NamedRunnable("SocksProxy %s", from.getRemoteSocketAddress()) { 98 @Override protected void execute() { 99 try { 100 BufferedSource fromSource = Okio.buffer(Okio.source(from)); 101 BufferedSink fromSink = Okio.buffer(Okio.sink(from)); 102 hello(fromSource, fromSink); 103 acceptCommand(from.getInetAddress(), fromSource, fromSink); 104 } catch (IOException e) { 105 logger.log(Level.WARNING, name + " failed", e); 106 Util.closeQuietly(from); 107 } 108 } 109 }); 110 } 111 112 private void hello(BufferedSource fromSource, BufferedSink fromSink) throws IOException { 113 int version = fromSource.readByte() & 0xff; 114 int methodCount = fromSource.readByte() & 0xff; 115 int selectedMethod = METHOD_NONE; 116 117 if (version != VERSION_5) { 118 throw new ProtocolException("unsupported version: " + version); 119 } 120 121 for (int i = 0; i < methodCount; i++) { 122 int candidateMethod = fromSource.readByte() & 0xff; 123 if (candidateMethod == METHOD_NO_AUTHENTICATION_REQUIRED) { 124 selectedMethod = candidateMethod; 125 } 126 } 127 128 switch (selectedMethod) { 129 case METHOD_NO_AUTHENTICATION_REQUIRED: 130 fromSink.writeByte(VERSION_5); 131 fromSink.writeByte(selectedMethod); 132 fromSink.emit(); 133 break; 134 135 default: 136 throw new ProtocolException("unsupported method: " + selectedMethod); 137 } 138 } 139 140 private void acceptCommand(InetAddress fromAddress, BufferedSource fromSource, 141 BufferedSink fromSink) throws IOException { 142 // Read the command. 143 int version = fromSource.readByte() & 0xff; 144 if (version != VERSION_5) throw new ProtocolException("unexpected version: " + version); 145 int command = fromSource.readByte() & 0xff; 146 int reserved = fromSource.readByte() & 0xff; 147 if (reserved != 0) throw new ProtocolException("unexpected reserved: " + reserved); 148 149 int addressType = fromSource.readByte() & 0xff; 150 InetAddress toAddress; 151 switch (addressType) { 152 case ADDRESS_TYPE_IPV4: 153 toAddress = InetAddress.getByAddress(fromSource.readByteArray(4L)); 154 break; 155 156 case ADDRESS_TYPE_DOMAIN_NAME: 157 int domainNameLength = fromSource.readByte() & 0xff; 158 String domainName = fromSource.readUtf8(domainNameLength); 159 toAddress = InetAddress.getByName(domainName); 160 break; 161 162 default: 163 throw new ProtocolException("unsupported address type: " + addressType); 164 } 165 166 int port = fromSource.readShort() & 0xffff; 167 168 switch (command) { 169 case COMMAND_CONNECT: 170 Socket toSocket = new Socket(toAddress, port); 171 byte[] localAddress = toSocket.getLocalAddress().getAddress(); 172 if (localAddress.length != 4) { 173 throw new ProtocolException("unexpected address: " + toSocket.getLocalAddress()); 174 } 175 176 // Write the reply. 177 fromSink.writeByte(VERSION_5); 178 fromSink.writeByte(REPLY_SUCCEEDED); 179 fromSink.writeByte(0); 180 fromSink.writeByte(ADDRESS_TYPE_IPV4); 181 fromSink.write(localAddress); 182 fromSink.writeShort(toSocket.getLocalPort()); 183 fromSink.emit(); 184 185 logger.log(Level.INFO, "SocksProxy connected " + fromAddress + " to " + toAddress); 186 187 // Copy sources to sinks in both directions. 188 BufferedSource toSource = Okio.buffer(Okio.source(toSocket)); 189 BufferedSink toSink = Okio.buffer(Okio.sink(toSocket)); 190 transfer(fromAddress, toAddress, fromSource, toSink); 191 transfer(fromAddress, toAddress, toSource, fromSink); 192 break; 193 194 default: 195 throw new ProtocolException("unexpected command: " + command); 196 } 197 } 198 199 private void transfer(final InetAddress fromAddress, final InetAddress toAddress, 200 final BufferedSource source, final BufferedSink sink) { 201 executor.execute(new NamedRunnable("SocksProxy %s to %s", fromAddress, toAddress) { 202 @Override protected void execute() { 203 Buffer buffer = new Buffer(); 204 try { 205 while (true) { 206 long byteCount = source.read(buffer, 2048L); 207 if (byteCount == -1L) break; 208 sink.write(buffer, byteCount); 209 sink.emit(); 210 } 211 } catch (SocketException e) { 212 logger.info(name + " done: " + e.getMessage()); 213 } catch (IOException e) { 214 logger.log(Level.WARNING, name + " failed", e); 215 } 216 217 try { 218 source.close(); 219 } catch (IOException e) { 220 logger.log(Level.WARNING, name + " failed", e); 221 } 222 223 try { 224 sink.close(); 225 } catch (IOException e) { 226 logger.log(Level.WARNING, name + " failed", e); 227 } 228 } 229 }); 230 } 231 } 232