1 /* 2 * Copyright 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package org.conscrypt; 18 19 import java.io.EOFException; 20 import java.io.IOException; 21 import java.io.InputStream; 22 import java.io.OutputStream; 23 import java.net.InetAddress; 24 import java.net.SocketException; 25 import java.nio.channels.ClosedChannelException; 26 import javax.net.ssl.SSLException; 27 import javax.net.ssl.SSLSocket; 28 import javax.net.ssl.SSLSocketFactory; 29 30 /** 31 * Client-side endpoint. Provides basic services for sending/receiving messages from the client 32 * socket. 33 */ 34 final class ClientEndpoint { 35 private final SSLSocket socket; 36 private InputStream input; 37 private OutputStream output; 38 39 ClientEndpoint(SSLSocketFactory socketFactory, ChannelType channelType, int port, 40 String[] protocols, String[] ciphers) throws IOException { 41 socket = channelType.newClientSocket(socketFactory, InetAddress.getLoopbackAddress(), port); 42 socket.setEnabledProtocols(protocols); 43 socket.setEnabledCipherSuites(ciphers); 44 } 45 46 void start() { 47 try { 48 socket.startHandshake(); 49 input = socket.getInputStream(); 50 output = socket.getOutputStream(); 51 } catch (IOException e) { 52 e.printStackTrace(); 53 throw new RuntimeException(e); 54 } 55 } 56 57 void stop() { 58 try { 59 socket.close(); 60 } catch (IOException e) { 61 throw new RuntimeException(e); 62 } 63 } 64 65 int readMessage(byte[] buffer) { 66 try { 67 int totalBytesRead = 0; 68 while (totalBytesRead < buffer.length) { 69 int remaining = buffer.length - totalBytesRead; 70 int bytesRead = input.read(buffer, totalBytesRead, remaining); 71 if (bytesRead == -1) { 72 break; 73 } 74 totalBytesRead += bytesRead; 75 } 76 return totalBytesRead; 77 } catch (SSLException e) { 78 if (e.getCause() instanceof EOFException) { 79 return -1; 80 } 81 throw new RuntimeException(e); 82 } catch (ClosedChannelException e) { 83 // Thrown for channel-based sockets. Just treat like EOF. 84 return -1; 85 } catch (SocketException e) { 86 // The socket was broken. Just treat like EOF. 87 return -1; 88 } catch (IOException e) { 89 throw new RuntimeException(e); 90 } 91 } 92 93 void sendMessage(byte[] data) { 94 try { 95 output.write(data); 96 } catch (IOException e) { 97 throw new RuntimeException(e); 98 } 99 } 100 101 void flush() { 102 try { 103 output.flush(); 104 } catch (IOException e) { 105 throw new RuntimeException(e); 106 } 107 } 108 } 109