# Copyright (C) 2001-2007, 2009, 2010 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

"""Help for building DNS wire format messages"""

import cStringIO
import struct
import random
import time

import dns.exception
# dns.rdataclass and dns.rdatatype are referenced directly below (one of them
# in a default argument evaluated at class-definition time), so import them
# explicitly instead of relying on another module having loaded them.
import dns.rdataclass
import dns.rdatatype
import dns.tsig

# Section indices; they double as indices into Renderer.counts.
QUESTION = 0
ANSWER = 1
AUTHORITY = 2
ADDITIONAL = 3


class Renderer(object):
    """Helper class for building DNS wire-format messages.

    Most applications can use the higher-level L{dns.message.Message}
    class and its to_wire() method to generate wire-format messages.
    This class is for those applications which need finer control
    over the generation of messages.

    Typical use::

        r = dns.renderer.Renderer(id=1, flags=0x80, max_size=512)
        r.add_question(qname, qtype, qclass)
        r.add_rrset(dns.renderer.ANSWER, rrset_1)
        r.add_rrset(dns.renderer.ANSWER, rrset_2)
        r.add_rrset(dns.renderer.AUTHORITY, ns_rrset)
        r.add_edns(0, 0, 4096)
        r.add_rrset(dns.renderer.ADDITIONAL, ad_rrset_1)
        r.add_rrset(dns.renderer.ADDITIONAL, ad_rrset_2)
        r.write_header()
        r.add_tsig(keyname, secret, 300, 1, 0, '', request_mac)
        wire = r.get_wire()

    @ivar output: where rendering is written
    @type output: cStringIO.StringIO object
    @ivar id: the message id
    @type id: int
    @ivar flags: the message flags
    @type flags: int
    @ivar max_size: the maximum size of the message
    @type max_size: int
    @ivar origin: the origin to use when rendering relative names
    @type origin: dns.name.Name object
    @ivar compress: the compression table
    @type compress: dict
    @ivar section: the section currently being rendered
    @type section: int (dns.renderer.QUESTION, dns.renderer.ANSWER,
    dns.renderer.AUTHORITY, or dns.renderer.ADDITIONAL)
    @ivar counts: list of the number of RRs in each section
    @type counts: int list of length 4
    @ivar mac: the MAC of the rendered message (if TSIG was used)
    @type mac: string
    """

    def __init__(self, id=None, flags=0, max_size=65535, origin=None):
        """Initialize a new renderer.

        @param id: the message id; if None, a random id in the range
        0..65535 is chosen.
        @type id: int
        @param flags: the DNS message flags
        @type flags: int
        @param max_size: the maximum message size; the default is 65535.
        If rendering results in a message greater than I{max_size},
        then L{dns.exception.TooBig} will be raised.
        @type max_size: int
        @param origin: the origin to use when rendering relative names
        @type origin: dns.name.Name or None.
        """

        self.output = cStringIO.StringIO()
        if id is None:
            self.id = random.randint(0, 65535)
        else:
            self.id = id
        self.flags = flags
        self.max_size = max_size
        self.origin = origin
        self.compress = {}
        self.section = QUESTION
        self.counts = [0, 0, 0, 0]
        # Reserve the 12 header bytes; write_header() fills them in later.
        self.output.write('\x00' * 12)
        self.mac = ''

    def _rollback(self, where):
        """Truncate the output buffer at offset I{where}, and remove any
        compression table entries that pointed beyond the truncation
        point.

        @param where: the offset
        @type where: int
        """

        self.output.seek(where)
        self.output.truncate()
        # Collect the stale keys first so the dict is not mutated while it
        # is being iterated.
        stale = [k for k, v in self.compress.iteritems() if v >= where]
        for k in stale:
            del self.compress[k]

    def _enforce_max_size(self, before):
        """Undo the most recent addition if it made the message too big.

        If the output now holds more than I{max_size} bytes, the buffer
        is rolled back to offset I{before} and TooBig is raised.  A
        message of exactly I{max_size} bytes is accepted.

        @param before: the offset at which the most recent addition began
        @type before: int
        @raises dns.exception.TooBig: the message exceeds I{max_size}.
        """

        if self.output.tell() > self.max_size:
            self._rollback(before)
            raise dns.exception.TooBig

    def _set_section(self, section):
        """Set the renderer's current section.

        Sections must be rendered in order: QUESTION, ANSWER, AUTHORITY,
        ADDITIONAL.  Sections may be empty.

        @param section: the section
        @type section: int
        @raises dns.exception.FormError: an attempt was made to set
        a section value less than the current section.
        """

        if self.section != section:
            if self.section > section:
                raise dns.exception.FormError
            self.section = section

    def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN):
        """Add a question to the message.

        @param qname: the question name
        @type qname: dns.name.Name
        @param rdtype: the question rdata type
        @type rdtype: int
        @param rdclass: the question rdata class
        @type rdclass: int
        @raises dns.exception.TooBig: adding the question would make the
        message larger than I{max_size}; the output is left as it was
        before the call.
        """

        self._set_section(QUESTION)
        before = self.output.tell()
        qname.to_wire(self.output, self.compress, self.origin)
        self.output.write(struct.pack("!HH", rdtype, rdclass))
        self._enforce_max_size(before)
        self.counts[QUESTION] += 1

    def add_rrset(self, section, rrset, **kw):
        """Add the rrset to the specified section.

        Any keyword arguments are passed on to the rrset's to_wire()
        routine.

        @param section: the section
        @type section: int
        @param rrset: the rrset
        @type rrset: dns.rrset.RRset object
        @raises dns.exception.TooBig: adding the rrset would make the
        message larger than I{max_size}; the output is left as it was
        before the call.
        """

        self._set_section(section)
        before = self.output.tell()
        n = rrset.to_wire(self.output, self.compress, self.origin, **kw)
        self._enforce_max_size(before)
        self.counts[section] += n

    def add_rdataset(self, section, name, rdataset, **kw):
        """Add the rdataset to the specified section, using the specified
        name as the owner name.

        Any keyword arguments are passed on to the rdataset's to_wire()
        routine.

        @param section: the section
        @type section: int
        @param name: the owner name
        @type name: dns.name.Name object
        @param rdataset: the rdataset
        @type rdataset: dns.rdataset.Rdataset object
        @raises dns.exception.TooBig: adding the rdataset would make the
        message larger than I{max_size}; the output is left as it was
        before the call.
        """

        self._set_section(section)
        before = self.output.tell()
        n = rdataset.to_wire(name, self.output, self.compress, self.origin,
                             **kw)
        self._enforce_max_size(before)
        self.counts[section] += n

    def add_edns(self, edns, ednsflags, payload, options=None):
        """Add an EDNS OPT record to the message.

        @param edns: The EDNS level to use.
        @type edns: int
        @param ednsflags: EDNS flag values.
        @type ednsflags: int
        @param payload: The EDNS sender's payload field, which is the maximum
        size of UDP datagram the sender can handle.
        @type payload: int
        @param options: The EDNS options list
        @type options: list of dns.edns.Option instances
        @raises dns.exception.TooBig: adding the OPT record would make the
        message larger than I{max_size}; the output is left as it was
        before the call.
        @see: RFC 2671
        """

        # make sure the EDNS version in ednsflags agrees with edns
        ednsflags &= 0xFF00FFFF
        ednsflags |= (edns << 16)
        self._set_section(ADDITIONAL)
        before = self.output.tell()
        # Fixed part of the record: a zero name byte, type OPT, the payload
        # size in the class field, ednsflags in the TTL field, and an rdata
        # length of 0 that is patched below if options are present.
        self.output.write(struct.pack('!BHHIH', 0, dns.rdatatype.OPT, payload,
                                      ednsflags, 0))
        if options is not None:
            lstart = self.output.tell()
            for opt in options:
                # Write the option code with a placeholder length, render
                # the option, then go back and patch in the real length.
                stuff = struct.pack("!HH", opt.otype, 0)
                self.output.write(stuff)
                start = self.output.tell()
                opt.to_wire(self.output)
                end = self.output.tell()
                assert end - start < 65536
                self.output.seek(start - 2)
                stuff = struct.pack("!H", end - start)
                self.output.write(stuff)
                self.output.seek(0, 2)
            lend = self.output.tell()
            assert lend - lstart < 65536
            # Patch the record's rdata length now that all options are out.
            self.output.seek(lstart - 2)
            stuff = struct.pack("!H", lend - lstart)
            self.output.write(stuff)
            self.output.seek(0, 2)
        self._enforce_max_size(before)
        self.counts[ADDITIONAL] += 1

    def add_tsig(self, keyname, secret, fudge, id, tsig_error, other_data,
                 request_mac, algorithm=dns.tsig.default_algorithm):
        """Add a TSIG signature to the message.

        The signature is computed over the output rendered so far, so
        write_header() should be called before this method.  On success
        the additional-section count in the header is rewritten and
        I{self.mac} is set to the new MAC.

        @param keyname: the TSIG key name
        @type keyname: dns.name.Name object
        @param secret: the secret to use
        @type secret: string
        @param fudge: TSIG time fudge
        @type fudge: int
        @param id: the message id to encode in the tsig signature
        @type id: int
        @param tsig_error: TSIG error code; default is 0.
        @type tsig_error: int
        @param other_data: TSIG other data.
        @type other_data: string
        @param request_mac: This message is a response to the request which
        had the specified MAC.
        @type request_mac: string
        @param algorithm: the TSIG algorithm to use
        @raises dns.exception.TooBig: adding the TSIG record would make the
        message larger than I{max_size}; the output and I{self.mac} are
        left as they were before the call.
        """

        self._set_section(ADDITIONAL)
        before = self.output.tell()
        s = self.output.getvalue()
        (tsig_rdata, mac, _ctx) = dns.tsig.sign(s,
                                                keyname,
                                                secret,
                                                int(time.time()),
                                                fudge,
                                                id,
                                                tsig_error,
                                                other_data,
                                                request_mac,
                                                algorithm=algorithm)
        keyname.to_wire(self.output, self.compress, self.origin)
        # Type TSIG, class ANY, TTL 0, and a placeholder rdata length.
        self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG,
                                      dns.rdataclass.ANY, 0, 0))
        rdata_start = self.output.tell()
        self.output.write(tsig_rdata)
        after = self.output.tell()
        assert after - rdata_start < 65536
        self._enforce_max_size(before)
        # Patch in the real rdata length.
        self.output.seek(rdata_start - 2)
        self.output.write(struct.pack('!H', after - rdata_start))
        self.counts[ADDITIONAL] += 1
        # The header may already have been written, so rewrite the
        # additional-section count (the last two bytes of the 12-byte
        # header) in place.
        self.output.seek(10)
        self.output.write(struct.pack('!H', self.counts[ADDITIONAL]))
        self.output.seek(0, 2)
        # Only publish the MAC once the TSIG record is really in the message.
        self.mac = mac

    def write_header(self):
        """Write the DNS message header.

        Writing the DNS message header is done after all sections
        have been rendered, but before the optional TSIG signature
        is added.
        """

        self.output.seek(0)
        self.output.write(struct.pack('!HHHHHH', self.id, self.flags,
                                      self.counts[0], self.counts[1],
                                      self.counts[2], self.counts[3]))
        self.output.seek(0, 2)

    def get_wire(self):
        """Return the wire format message.

        @rtype: string
        """

        return self.output.getvalue()