Home | History | Annotate | Download | only in okhttp
      1 /*
      2  * Copyright (C) 2014 Square, Inc.
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 package com.squareup.okhttp;
     17 
import com.squareup.okhttp.internal.NamedRunnable;
import com.squareup.okhttp.internal.Util;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ProtocolException;
import java.net.Proxy;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import okio.Buffer;
import okio.BufferedSink;
import okio.BufferedSource;
import okio.Okio;
     38 
     39 /**
     40  * A limited implementation of SOCKS Protocol Version 5, intended to be similar to MockWebServer.
     41  * See <a href="https://www.ietf.org/rfc/rfc1928.txt">RFC 1928</a>.
     42  */
     43 public final class SocksProxy {
     44   private static final int VERSION_5 = 5;
     45   private static final int METHOD_NONE = 0xff;
     46   private static final int METHOD_NO_AUTHENTICATION_REQUIRED = 0;
     47   private static final int ADDRESS_TYPE_IPV4 = 1;
     48   private static final int ADDRESS_TYPE_DOMAIN_NAME = 3;
     49   private static final int COMMAND_CONNECT = 1;
     50   private static final int REPLY_SUCCEEDED = 0;
     51 
     52   private static final Logger logger = Logger.getLogger(SocksProxy.class.getName());
     53 
     54   private final ExecutorService executor = Executors.newCachedThreadPool(
     55       Util.threadFactory("SocksProxy", false));
     56 
     57   private ServerSocket serverSocket;
     58   private AtomicInteger connectionCount = new AtomicInteger();
     59 
     60   public void play() throws IOException {
     61     serverSocket = new ServerSocket(0);
     62     executor.execute(new NamedRunnable("SocksProxy %s", serverSocket.getLocalPort()) {
     63       @Override protected void execute() {
     64         try {
     65           while (true) {
     66             Socket socket = serverSocket.accept();
     67             connectionCount.incrementAndGet();
     68             service(socket);
     69           }
     70         } catch (SocketException e) {
     71           logger.info(name + " done accepting connections: " + e.getMessage());
     72         } catch (IOException e) {
     73           logger.log(Level.WARNING, name + " failed unexpectedly", e);
     74         }
     75       }
     76     });
     77   }
     78 
     79   public Proxy proxy() {
     80     return new Proxy(Proxy.Type.SOCKS, InetSocketAddress.createUnresolved(
     81         "localhost", serverSocket.getLocalPort()));
     82   }
     83 
     84   public int connectionCount() {
     85     return connectionCount.get();
     86   }
     87 
     88   public void shutdown() throws Exception {
     89     serverSocket.close();
     90     executor.shutdown();
     91     if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
     92       throw new IOException("Gave up waiting for executor to shut down");
     93     }
     94   }
     95 
     96   private void service(final Socket from) {
     97     executor.execute(new NamedRunnable("SocksProxy %s", from.getRemoteSocketAddress()) {
     98       @Override protected void execute() {
     99         try {
    100           BufferedSource fromSource = Okio.buffer(Okio.source(from));
    101           BufferedSink fromSink = Okio.buffer(Okio.sink(from));
    102           hello(fromSource, fromSink);
    103           acceptCommand(from.getInetAddress(), fromSource, fromSink);
    104         } catch (IOException e) {
    105           logger.log(Level.WARNING, name + " failed", e);
    106           Util.closeQuietly(from);
    107         }
    108       }
    109     });
    110   }
    111 
    112   private void hello(BufferedSource fromSource, BufferedSink fromSink) throws IOException {
    113     int version = fromSource.readByte() & 0xff;
    114     int methodCount = fromSource.readByte() & 0xff;
    115     int selectedMethod = METHOD_NONE;
    116 
    117     if (version != VERSION_5) {
    118       throw new ProtocolException("unsupported version: " + version);
    119     }
    120 
    121     for (int i = 0; i < methodCount; i++) {
    122       int candidateMethod = fromSource.readByte() & 0xff;
    123       if (candidateMethod == METHOD_NO_AUTHENTICATION_REQUIRED) {
    124         selectedMethod = candidateMethod;
    125       }
    126     }
    127 
    128     switch (selectedMethod) {
    129       case METHOD_NO_AUTHENTICATION_REQUIRED:
    130         fromSink.writeByte(VERSION_5);
    131         fromSink.writeByte(selectedMethod);
    132         fromSink.emit();
    133         break;
    134 
    135       default:
    136         throw new ProtocolException("unsupported method: " + selectedMethod);
    137     }
    138   }
    139 
    140   private void acceptCommand(InetAddress fromAddress, BufferedSource fromSource,
    141       BufferedSink fromSink) throws IOException {
    142     // Read the command.
    143     int version = fromSource.readByte() & 0xff;
    144     if (version != VERSION_5) throw new ProtocolException("unexpected version: " + version);
    145     int command = fromSource.readByte() & 0xff;
    146     int reserved = fromSource.readByte() & 0xff;
    147     if (reserved != 0) throw new ProtocolException("unexpected reserved: " + reserved);
    148 
    149     int addressType = fromSource.readByte() & 0xff;
    150     InetAddress toAddress;
    151     switch (addressType) {
    152       case ADDRESS_TYPE_IPV4:
    153         toAddress = InetAddress.getByAddress(fromSource.readByteArray(4L));
    154         break;
    155 
    156       case ADDRESS_TYPE_DOMAIN_NAME:
    157         int domainNameLength = fromSource.readByte() & 0xff;
    158         String domainName = fromSource.readUtf8(domainNameLength);
    159         toAddress = InetAddress.getByName(domainName);
    160         break;
    161 
    162       default:
    163         throw new ProtocolException("unsupported address type: " + addressType);
    164     }
    165 
    166     int port = fromSource.readShort() & 0xffff;
    167 
    168     switch (command) {
    169       case COMMAND_CONNECT:
    170         Socket toSocket = new Socket(toAddress, port);
    171         byte[] localAddress = toSocket.getLocalAddress().getAddress();
    172         if (localAddress.length != 4) {
    173           throw new ProtocolException("unexpected address: " + toSocket.getLocalAddress());
    174         }
    175 
    176         // Write the reply.
    177         fromSink.writeByte(VERSION_5);
    178         fromSink.writeByte(REPLY_SUCCEEDED);
    179         fromSink.writeByte(0);
    180         fromSink.writeByte(ADDRESS_TYPE_IPV4);
    181         fromSink.write(localAddress);
    182         fromSink.writeShort(toSocket.getLocalPort());
    183         fromSink.emit();
    184 
    185         logger.log(Level.INFO, "SocksProxy connected " + fromAddress + " to " + toAddress);
    186 
    187         // Copy sources to sinks in both directions.
    188         BufferedSource toSource = Okio.buffer(Okio.source(toSocket));
    189         BufferedSink toSink = Okio.buffer(Okio.sink(toSocket));
    190         transfer(fromAddress, toAddress, fromSource, toSink);
    191         transfer(fromAddress, toAddress, toSource, fromSink);
    192         break;
    193 
    194       default:
    195         throw new ProtocolException("unexpected command: " + command);
    196     }
    197   }
    198 
    199   private void transfer(final InetAddress fromAddress, final InetAddress toAddress,
    200       final BufferedSource source, final BufferedSink sink) {
    201     executor.execute(new NamedRunnable("SocksProxy %s to %s", fromAddress, toAddress) {
    202       @Override protected void execute() {
    203         Buffer buffer = new Buffer();
    204         try {
    205           while (true) {
    206             long byteCount = source.read(buffer, 2048L);
    207             if (byteCount == -1L) break;
    208             sink.write(buffer, byteCount);
    209             sink.emit();
    210           }
    211         } catch (SocketException e) {
    212           logger.info(name + " done: " + e.getMessage());
    213         } catch (IOException e) {
    214           logger.log(Level.WARNING, name + " failed", e);
    215         }
    216 
    217         try {
    218           source.close();
    219         } catch (IOException e) {
    220           logger.log(Level.WARNING, name + " failed", e);
    221         }
    222 
    223         try {
    224           sink.close();
    225         } catch (IOException e) {
    226           logger.log(Level.WARNING, name + " failed", e);
    227         }
    228       }
    229     });
    230   }
    231 }
    232