1 // Copyright (c) 2005 Brian Wellington (bwelling (at) xbill.org) 2 3 package org.xbill.DNS; 4 5 import java.io.*; 6 import java.net.*; 7 import java.nio.*; 8 import java.nio.channels.*; 9 10 final class TCPClient extends Client { 11 12 public 13 TCPClient(long endTime) throws IOException { 14 super(SocketChannel.open(), endTime); 15 } 16 17 void 18 bind(SocketAddress addr) throws IOException { 19 SocketChannel channel = (SocketChannel) key.channel(); 20 channel.socket().bind(addr); 21 } 22 23 void 24 connect(SocketAddress addr) throws IOException { 25 SocketChannel channel = (SocketChannel) key.channel(); 26 if (channel.connect(addr)) 27 return; 28 key.interestOps(SelectionKey.OP_CONNECT); 29 try { 30 while (!channel.finishConnect()) { 31 if (!key.isConnectable()) 32 blockUntil(key, endTime); 33 } 34 } 35 finally { 36 if (key.isValid()) 37 key.interestOps(0); 38 } 39 } 40 41 void 42 send(byte [] data) throws IOException { 43 SocketChannel channel = (SocketChannel) key.channel(); 44 verboseLog("TCP write", data); 45 byte [] lengthArray = new byte[2]; 46 lengthArray[0] = (byte)(data.length >>> 8); 47 lengthArray[1] = (byte)(data.length & 0xFF); 48 ByteBuffer [] buffers = new ByteBuffer[2]; 49 buffers[0] = ByteBuffer.wrap(lengthArray); 50 buffers[1] = ByteBuffer.wrap(data); 51 int nsent = 0; 52 key.interestOps(SelectionKey.OP_WRITE); 53 try { 54 while (nsent < data.length + 2) { 55 if (key.isWritable()) { 56 long n = channel.write(buffers); 57 if (n < 0) 58 throw new EOFException(); 59 nsent += (int) n; 60 if (nsent < data.length + 2 && 61 System.currentTimeMillis() > endTime) 62 throw new SocketTimeoutException(); 63 } else 64 blockUntil(key, endTime); 65 } 66 } 67 finally { 68 if (key.isValid()) 69 key.interestOps(0); 70 } 71 } 72 73 private byte [] 74 _recv(int length) throws IOException { 75 SocketChannel channel = (SocketChannel) key.channel(); 76 int nrecvd = 0; 77 byte [] data = new byte[length]; 78 ByteBuffer buffer = ByteBuffer.wrap(data); 79 key.interestOps(SelectionKey.OP_READ); 80 try { 81 while (nrecvd < length) { 82 if (key.isReadable()) { 83 long n = channel.read(buffer); 84 if (n < 0) 85 throw new EOFException(); 86 nrecvd += (int) n; 87 if (nrecvd < length && 88 System.currentTimeMillis() > endTime) 89 throw new SocketTimeoutException(); 90 } else 91 blockUntil(key, endTime); 92 } 93 } 94 finally { 95 if (key.isValid()) 96 key.interestOps(0); 97 } 98 return data; 99 } 100 101 byte [] 102 recv() throws IOException { 103 byte [] buf = _recv(2); 104 int length = ((buf[0] & 0xFF) << 8) + (buf[1] & 0xFF); 105 byte [] data = _recv(length); 106 verboseLog("TCP read", data); 107 return data; 108 } 109 110 static byte [] 111 sendrecv(SocketAddress local, SocketAddress remote, byte [] data, long endTime) 112 throws IOException 113 { 114 TCPClient client = new TCPClient(endTime); 115 try { 116 if (local != null) 117 client.bind(local); 118 client.connect(remote); 119 client.send(data); 120 return client.recv(); 121 } 122 finally { 123 client.cleanup(); 124 } 125 } 126 127 static byte [] 128 sendrecv(SocketAddress addr, byte [] data, long endTime) throws IOException { 129 return sendrecv(null, addr, data, endTime); 130 } 131 132 } 133