Home | History | Annotate | Download | only in DNS
      1 // Copyright (c) 2005 Brian Wellington (bwelling (at) xbill.org)
      2 
      3 package org.xbill.DNS;
      4 
      5 import java.io.*;
      6 import java.net.*;
      7 import java.nio.*;
      8 import java.nio.channels.*;
      9 
     10 final class TCPClient extends Client {
     11 
     12 public
     13 TCPClient(long endTime) throws IOException {
     14 	super(SocketChannel.open(), endTime);
     15 }
     16 
     17 void
     18 bind(SocketAddress addr) throws IOException {
     19 	SocketChannel channel = (SocketChannel) key.channel();
     20 	channel.socket().bind(addr);
     21 }
     22 
     23 void
     24 connect(SocketAddress addr) throws IOException {
     25 	SocketChannel channel = (SocketChannel) key.channel();
     26 	if (channel.connect(addr))
     27 		return;
     28 	key.interestOps(SelectionKey.OP_CONNECT);
     29 	try {
     30 		while (!channel.finishConnect()) {
     31 			if (!key.isConnectable())
     32 				blockUntil(key, endTime);
     33 		}
     34 	}
     35 	finally {
     36 		if (key.isValid())
     37 			key.interestOps(0);
     38 	}
     39 }
     40 
     41 void
     42 send(byte [] data) throws IOException {
     43 	SocketChannel channel = (SocketChannel) key.channel();
     44 	verboseLog("TCP write", data);
     45 	byte [] lengthArray = new byte[2];
     46 	lengthArray[0] = (byte)(data.length >>> 8);
     47 	lengthArray[1] = (byte)(data.length & 0xFF);
     48 	ByteBuffer [] buffers = new ByteBuffer[2];
     49 	buffers[0] = ByteBuffer.wrap(lengthArray);
     50 	buffers[1] = ByteBuffer.wrap(data);
     51 	int nsent = 0;
     52 	key.interestOps(SelectionKey.OP_WRITE);
     53 	try {
     54 		while (nsent < data.length + 2) {
     55 			if (key.isWritable()) {
     56 				long n = channel.write(buffers);
     57 				if (n < 0)
     58 					throw new EOFException();
     59 				nsent += (int) n;
     60 				if (nsent < data.length + 2 &&
     61 				    System.currentTimeMillis() > endTime)
     62 					throw new SocketTimeoutException();
     63 			} else
     64 				blockUntil(key, endTime);
     65 		}
     66 	}
     67 	finally {
     68 		if (key.isValid())
     69 			key.interestOps(0);
     70 	}
     71 }
     72 
     73 private byte []
     74 _recv(int length) throws IOException {
     75 	SocketChannel channel = (SocketChannel) key.channel();
     76 	int nrecvd = 0;
     77 	byte [] data = new byte[length];
     78 	ByteBuffer buffer = ByteBuffer.wrap(data);
     79 	key.interestOps(SelectionKey.OP_READ);
     80 	try {
     81 		while (nrecvd < length) {
     82 			if (key.isReadable()) {
     83 				long n = channel.read(buffer);
     84 				if (n < 0)
     85 					throw new EOFException();
     86 				nrecvd += (int) n;
     87 				if (nrecvd < length &&
     88 				    System.currentTimeMillis() > endTime)
     89 					throw new SocketTimeoutException();
     90 			} else
     91 				blockUntil(key, endTime);
     92 		}
     93 	}
     94 	finally {
     95 		if (key.isValid())
     96 			key.interestOps(0);
     97 	}
     98 	return data;
     99 }
    100 
    101 byte []
    102 recv() throws IOException {
    103 	byte [] buf = _recv(2);
    104 	int length = ((buf[0] & 0xFF) << 8) + (buf[1] & 0xFF);
    105 	byte [] data = _recv(length);
    106 	verboseLog("TCP read", data);
    107 	return data;
    108 }
    109 
    110 static byte []
    111 sendrecv(SocketAddress local, SocketAddress remote, byte [] data, long endTime)
    112 throws IOException
    113 {
    114 	TCPClient client = new TCPClient(endTime);
    115 	try {
    116 		if (local != null)
    117 			client.bind(local);
    118 		client.connect(remote);
    119 		client.send(data);
    120 		return client.recv();
    121 	}
    122 	finally {
    123 		client.cleanup();
    124 	}
    125 }
    126 
    127 static byte []
    128 sendrecv(SocketAddress addr, byte [] data, long endTime) throws IOException {
    129 	return sendrecv(null, addr, data, endTime);
    130 }
    131 
    132 }
    133