mitmproxy.dns
1from __future__ import annotations 2 3import base64 4import itertools 5import random 6import struct 7import time 8from collections.abc import Iterable 9from dataclasses import dataclass 10from ipaddress import IPv4Address 11from ipaddress import IPv6Address 12from typing import ClassVar 13 14from mitmproxy import flow 15from mitmproxy.coretypes import serializable 16from mitmproxy.net.dns import classes 17from mitmproxy.net.dns import domain_names 18from mitmproxy.net.dns import https_records 19from mitmproxy.net.dns import op_codes 20from mitmproxy.net.dns import response_codes 21from mitmproxy.net.dns import types 22from mitmproxy.net.dns.https_records import HTTPSRecord 23from mitmproxy.net.dns.https_records import SVCParamKeys 24 25# DNS parameters taken from https://www.iana.org/assignments/dns-parameters/dns-parameters.xml 26 27 28@dataclass 29class Question(serializable.SerializableDataclass): 30 HEADER: ClassVar[struct.Struct] = struct.Struct("!HH") 31 32 name: str 33 type: int 34 class_: int 35 36 def __str__(self) -> str: 37 return self.name 38 39 def to_json(self) -> dict: 40 """ 41 Converts the question into json for mitmweb. 42 Sync with web/src/flow.ts. 
43 """ 44 return { 45 "name": self.name, 46 "type": types.to_str(self.type), 47 "class": classes.to_str(self.class_), 48 } 49 50 51@dataclass 52class ResourceRecord(serializable.SerializableDataclass): 53 DEFAULT_TTL: ClassVar[int] = 60 54 HEADER: ClassVar[struct.Struct] = struct.Struct("!HHIH") 55 56 name: str 57 type: int 58 class_: int 59 ttl: int 60 data: bytes 61 62 def __str__(self) -> str: 63 try: 64 if self.type == types.A: 65 return str(self.ipv4_address) 66 if self.type == types.AAAA: 67 return str(self.ipv6_address) 68 if self.type in (types.NS, types.CNAME, types.PTR): 69 return self.domain_name 70 if self.type == types.TXT: 71 return self.text 72 if self.type == types.HTTPS: 73 return str(https_records.unpack(self.data)) 74 except Exception: 75 return f"0x{self.data.hex()} (invalid {types.to_str(self.type)} data)" 76 return f"0x{self.data.hex()}" 77 78 @property 79 def text(self) -> str: 80 return self.data.decode("utf-8") 81 82 @text.setter 83 def text(self, value: str) -> None: 84 self.data = value.encode("utf-8") 85 86 @property 87 def ipv4_address(self) -> IPv4Address: 88 return IPv4Address(self.data) 89 90 @ipv4_address.setter 91 def ipv4_address(self, ip: IPv4Address) -> None: 92 self.data = ip.packed 93 94 @property 95 def ipv6_address(self) -> IPv6Address: 96 return IPv6Address(self.data) 97 98 @ipv6_address.setter 99 def ipv6_address(self, ip: IPv6Address) -> None: 100 self.data = ip.packed 101 102 @property 103 def domain_name(self) -> str: 104 return domain_names.unpack(self.data) 105 106 @domain_name.setter 107 def domain_name(self, name: str) -> None: 108 self.data = domain_names.pack(name) 109 110 @property 111 def https_alpn(self) -> tuple[bytes, ...] 
| None: 112 record = https_records.unpack(self.data) 113 alpn_bytes = record.params.get(SVCParamKeys.ALPN.value, None) 114 if alpn_bytes is not None: 115 i = 0 116 ret = [] 117 while i < len(alpn_bytes): 118 token_len = alpn_bytes[i] 119 ret.append(alpn_bytes[i + 1 : i + 1 + token_len]) 120 i += token_len + 1 121 return tuple(ret) 122 else: 123 return None 124 125 @https_alpn.setter 126 def https_alpn(self, alpn: Iterable[bytes] | None) -> None: 127 record = https_records.unpack(self.data) 128 if alpn is None: 129 record.params.pop(SVCParamKeys.ALPN.value, None) 130 else: 131 alpn_bytes = b"".join(bytes([len(a)]) + a for a in alpn) 132 record.params[SVCParamKeys.ALPN.value] = alpn_bytes 133 self.data = https_records.pack(record) 134 135 @property 136 def https_ech(self) -> str | None: 137 record = https_records.unpack(self.data) 138 ech_bytes = record.params.get(SVCParamKeys.ECH.value, None) 139 if ech_bytes is not None: 140 return base64.b64encode(ech_bytes).decode("utf-8") 141 else: 142 return None 143 144 @https_ech.setter 145 def https_ech(self, ech: str | None) -> None: 146 record = https_records.unpack(self.data) 147 if ech is None: 148 record.params.pop(SVCParamKeys.ECH.value, None) 149 else: 150 ech_bytes = base64.b64decode(ech.encode("utf-8")) 151 record.params[SVCParamKeys.ECH.value] = ech_bytes 152 self.data = https_records.pack(record) 153 154 def to_json(self) -> dict: 155 """ 156 Converts the resource record into json for mitmweb. 157 Sync with web/src/flow.ts. 
158 """ 159 return { 160 "name": self.name, 161 "type": types.to_str(self.type), 162 "class": classes.to_str(self.class_), 163 "ttl": self.ttl, 164 "data": str(self), 165 } 166 167 @classmethod 168 def A(cls, name: str, ip: IPv4Address, *, ttl: int = DEFAULT_TTL) -> ResourceRecord: 169 """Create an IPv4 resource record.""" 170 return cls(name, types.A, classes.IN, ttl, ip.packed) 171 172 @classmethod 173 def AAAA( 174 cls, name: str, ip: IPv6Address, *, ttl: int = DEFAULT_TTL 175 ) -> ResourceRecord: 176 """Create an IPv6 resource record.""" 177 return cls(name, types.AAAA, classes.IN, ttl, ip.packed) 178 179 @classmethod 180 def CNAME( 181 cls, alias: str, canonical: str, *, ttl: int = DEFAULT_TTL 182 ) -> ResourceRecord: 183 """Create a canonical internet name resource record.""" 184 return cls(alias, types.CNAME, classes.IN, ttl, domain_names.pack(canonical)) 185 186 @classmethod 187 def PTR(cls, inaddr: str, ptr: str, *, ttl: int = DEFAULT_TTL) -> ResourceRecord: 188 """Create a canonical internet name resource record.""" 189 return cls(inaddr, types.PTR, classes.IN, ttl, domain_names.pack(ptr)) 190 191 @classmethod 192 def TXT(cls, name: str, text: str, *, ttl: int = DEFAULT_TTL) -> ResourceRecord: 193 """Create a textual resource record.""" 194 return cls(name, types.TXT, classes.IN, ttl, text.encode("utf-8")) 195 196 @classmethod 197 def HTTPS( 198 cls, name: str, record: HTTPSRecord, ttl: int = DEFAULT_TTL 199 ) -> ResourceRecord: 200 """Create a HTTPS resource record""" 201 return cls(name, types.HTTPS, classes.IN, ttl, https_records.pack(record)) 202 203 204# comments are taken from rfc1035 205@dataclass 206class Message(serializable.SerializableDataclass): 207 HEADER: ClassVar[struct.Struct] = struct.Struct("!HHHHHH") 208 209 timestamp: float 210 """The time at which the message was sent or received.""" 211 id: int 212 """An identifier assigned by the program that generates any kind of query.""" 213 query: bool 214 """A field that specifies whether this 
message is a query.""" 215 op_code: int 216 """ 217 A field that specifies kind of query in this message. 218 This value is set by the originator of a request and copied into the response. 219 """ 220 authoritative_answer: bool 221 """ 222 This field is valid in responses, and specifies that the responding name server 223 is an authority for the domain name in question section. 224 """ 225 truncation: bool 226 """Specifies that this message was truncated due to length greater than that permitted on the transmission channel.""" 227 recursion_desired: bool 228 """ 229 This field may be set in a query and is copied into the response. 230 If set, it directs the name server to pursue the query recursively. 231 """ 232 recursion_available: bool 233 """This field is set or cleared in a response, and denotes whether recursive query support is available in the name server.""" 234 reserved: int 235 """Reserved for future use. Must be zero in all queries and responses.""" 236 response_code: int 237 """This field is set as part of responses.""" 238 questions: list[Question] 239 """ 240 The question section is used to carry the "question" in most queries, i.e. 241 the parameters that define what is being asked. 
242 """ 243 answers: list[ResourceRecord] 244 """First resource record section.""" 245 authorities: list[ResourceRecord] 246 """Second resource record section.""" 247 additionals: list[ResourceRecord] 248 """Third resource record section.""" 249 250 def __str__(self) -> str: 251 return "\r\n".join( 252 map( 253 str, 254 itertools.chain( 255 self.questions, self.answers, self.authorities, self.additionals 256 ), 257 ) 258 ) 259 260 @property 261 def content(self) -> bytes: 262 """Returns the user-friendly content of all parts as encoded bytes.""" 263 return str(self).encode() 264 265 @property 266 def question(self) -> Question | None: 267 """DNS practically only supports a single question at the 268 same time, so this is a shorthand for this.""" 269 if len(self.questions) == 1: 270 return self.questions[0] 271 return None 272 273 @property 274 def size(self) -> int: 275 """Returns the cumulative data size of all resource record sections.""" 276 return sum( 277 len(x.data) 278 for x in itertools.chain.from_iterable( 279 [self.answers, self.authorities, self.additionals] 280 ) 281 ) 282 283 def fail(self, response_code: int) -> Message: 284 if response_code == response_codes.NOERROR: 285 raise ValueError("response_code must be an error code.") 286 return Message( 287 timestamp=time.time(), 288 id=self.id, 289 query=False, 290 op_code=self.op_code, 291 authoritative_answer=False, 292 truncation=False, 293 recursion_desired=self.recursion_desired, 294 recursion_available=False, 295 reserved=0, 296 response_code=response_code, 297 questions=self.questions, 298 answers=[], 299 authorities=[], 300 additionals=[], 301 ) 302 303 def succeed(self, answers: list[ResourceRecord]) -> Message: 304 return Message( 305 timestamp=time.time(), 306 id=self.id, 307 query=False, 308 op_code=self.op_code, 309 authoritative_answer=False, 310 truncation=False, 311 recursion_desired=self.recursion_desired, 312 recursion_available=True, 313 reserved=0, 314 
response_code=response_codes.NOERROR, 315 questions=self.questions, 316 answers=answers, 317 authorities=[], 318 additionals=[], 319 ) 320 321 @classmethod 322 def unpack(cls, buffer: bytes) -> Message: 323 """Converts the entire given buffer into a DNS message.""" 324 length, msg = cls.unpack_from(buffer, 0) 325 if length != len(buffer): 326 raise struct.error(f"unpack requires a buffer of {length} bytes") 327 return msg 328 329 @classmethod 330 def unpack_from(cls, buffer: bytes | bytearray, offset: int) -> tuple[int, Message]: 331 """Converts the buffer from a given offset into a DNS message and also returns its length.""" 332 ( 333 id, 334 flags, 335 len_questions, 336 len_answers, 337 len_authorities, 338 len_additionals, 339 ) = Message.HEADER.unpack_from(buffer, offset) 340 msg = Message( 341 timestamp=time.time(), 342 id=id, 343 query=(flags & (1 << 15)) == 0, 344 op_code=(flags >> 11) & 0b1111, 345 authoritative_answer=(flags & (1 << 10)) != 0, 346 truncation=(flags & (1 << 9)) != 0, 347 recursion_desired=(flags & (1 << 8)) != 0, 348 recursion_available=(flags & (1 << 7)) != 0, 349 reserved=(flags >> 4) & 0b111, 350 response_code=flags & 0b1111, 351 questions=[], 352 answers=[], 353 authorities=[], 354 additionals=[], 355 ) 356 offset += Message.HEADER.size 357 cached_names = domain_names.cache() 358 359 def unpack_domain_name() -> str: 360 nonlocal buffer, offset, cached_names 361 name, length = domain_names.unpack_from_with_compression( 362 buffer, offset, cached_names 363 ) 364 offset += length 365 return name 366 367 for i in range(0, len_questions): 368 try: 369 name = unpack_domain_name() 370 type, class_ = Question.HEADER.unpack_from(buffer, offset) 371 offset += Question.HEADER.size 372 msg.questions.append(Question(name=name, type=type, class_=class_)) 373 except struct.error as e: 374 raise struct.error(f"question #{i}: {str(e)}") 375 376 def unpack_rrs( 377 section: list[ResourceRecord], section_name: str, count: int 378 ) -> None: 379 nonlocal 
buffer, offset 380 for i in range(0, count): 381 try: 382 name = unpack_domain_name() 383 type, class_, ttl, len_data = ResourceRecord.HEADER.unpack_from( 384 buffer, offset 385 ) 386 offset += ResourceRecord.HEADER.size 387 end_data = offset + len_data 388 if len(buffer) < end_data: 389 raise struct.error( 390 f"unpack requires a data buffer of {len_data} bytes" 391 ) 392 data = buffer[offset:end_data] 393 394 if domain_names.record_data_can_have_compression(type): 395 data = domain_names.decompress_from_record_data( 396 buffer, offset, end_data, cached_names 397 ) 398 399 section.append(ResourceRecord(name, type, class_, ttl, data)) 400 offset += len_data 401 except struct.error as e: 402 raise struct.error(f"{section_name} #{i}: {str(e)}") 403 404 unpack_rrs(msg.answers, "answer", len_answers) 405 unpack_rrs(msg.authorities, "authority", len_authorities) 406 unpack_rrs(msg.additionals, "additional", len_additionals) 407 return (offset, msg) 408 409 @property 410 def packed(self) -> bytes: 411 """Converts the message into network bytes.""" 412 if self.id < 0 or self.id > 65535: 413 raise ValueError(f"DNS message's id {self.id} is out of bounds.") 414 flags = 0 415 if not self.query: 416 flags |= 1 << 15 417 if self.op_code < 0 or self.op_code > 0b1111: 418 raise ValueError(f"DNS message's op_code {self.op_code} is out of bounds.") 419 flags |= self.op_code << 11 420 if self.authoritative_answer: 421 flags |= 1 << 10 422 if self.truncation: 423 flags |= 1 << 9 424 if self.recursion_desired: 425 flags |= 1 << 8 426 if self.recursion_available: 427 flags |= 1 << 7 428 if self.reserved < 0 or self.reserved > 0b111: 429 raise ValueError( 430 f"DNS message's reserved value of {self.reserved} is out of bounds." 431 ) 432 flags |= self.reserved << 4 433 if self.response_code < 0 or self.response_code > 0b1111: 434 raise ValueError( 435 f"DNS message's response_code {self.response_code} is out of bounds." 
436 ) 437 flags |= self.response_code 438 data = bytearray() 439 data.extend( 440 Message.HEADER.pack( 441 self.id, 442 flags, 443 len(self.questions), 444 len(self.answers), 445 len(self.authorities), 446 len(self.additionals), 447 ) 448 ) 449 # TODO implement compression 450 for question in self.questions: 451 data.extend(domain_names.pack(question.name)) 452 data.extend(Question.HEADER.pack(question.type, question.class_)) 453 for rr in (*self.answers, *self.authorities, *self.additionals): 454 data.extend(domain_names.pack(rr.name)) 455 data.extend( 456 ResourceRecord.HEADER.pack(rr.type, rr.class_, rr.ttl, len(rr.data)) 457 ) 458 data.extend(rr.data) 459 return bytes(data) 460 461 def to_json(self) -> dict: 462 """ 463 Converts the message into json for mitmweb. 464 Sync with web/src/flow.ts. 465 """ 466 return { 467 "id": self.id, 468 "query": self.query, 469 "op_code": op_codes.to_str(self.op_code), 470 "authoritative_answer": self.authoritative_answer, 471 "truncation": self.truncation, 472 "recursion_desired": self.recursion_desired, 473 "recursion_available": self.recursion_available, 474 "response_code": response_codes.to_str(self.response_code), 475 "status_code": response_codes.http_equiv_status_code(self.response_code), 476 "questions": [question.to_json() for question in self.questions], 477 "answers": [rr.to_json() for rr in self.answers], 478 "authorities": [rr.to_json() for rr in self.authorities], 479 "additionals": [rr.to_json() for rr in self.additionals], 480 "size": self.size, 481 "timestamp": self.timestamp, 482 } 483 484 def copy(self) -> Message: 485 # we keep the copy semantics but change the ID generation 486 state = self.get_state() 487 state["id"] = random.randint(0, 65535) 488 return Message.from_state(state) 489 490 491class DNSFlow(flow.Flow): 492 """A DNSFlow is a collection of DNS messages representing a single DNS query.""" 493 494 request: Message 495 """The DNS request.""" 496 response: Message | None = None 497 """The DNS 
response.""" 498 499 def get_state(self) -> serializable.State: 500 return { 501 **super().get_state(), 502 "request": self.request.get_state(), 503 "response": self.response.get_state() if self.response else None, 504 } 505 506 def set_state(self, state: serializable.State) -> None: 507 self.request = Message.from_state(state.pop("request")) 508 self.response = Message.from_state(r) if (r := state.pop("response")) else None 509 super().set_state(state) 510 511 def __repr__(self) -> str: 512 return f"<DNSFlow\r\n request={repr(self.request)}\r\n response={repr(self.response)}\r\n>"
class DNSFlow(flow.Flow):
    """A flow that pairs one DNS request message with its optional response."""

    request: Message
    """The DNS request."""
    response: Message | None = None
    """The DNS response."""

    def get_state(self) -> serializable.State:
        """Returns the base flow state extended by the serialized messages."""
        state = dict(super().get_state())
        state["request"] = self.request.get_state()
        if self.response:
            state["response"] = self.response.get_state()
        else:
            state["response"] = None
        return state

    def set_state(self, state: serializable.State) -> None:
        """Restores both messages from `state`, then delegates the remaining keys."""
        # Both entries are popped so the base class never sees them.
        self.request = Message.from_state(state.pop("request"))
        response_state = state.pop("response")
        if response_state:
            self.response = Message.from_state(response_state)
        else:
            self.response = None
        super().set_state(state)

    def __repr__(self) -> str:
        return f"<DNSFlow\r\n request={repr(self.request)}\r\n response={repr(self.response)}\r\n>"
A DNSFlow is a collection of DNS messages representing a single DNS query.