Edit on GitHub

mitmproxy.dns

  1from __future__ import annotations
  2
  3import base64
  4import itertools
  5import random
  6import struct
  7import time
  8from collections.abc import Iterable
  9from dataclasses import dataclass
 10from ipaddress import IPv4Address
 11from ipaddress import IPv6Address
 12from typing import ClassVar
 13
 14from mitmproxy import flow
 15from mitmproxy.coretypes import serializable
 16from mitmproxy.net.dns import classes
 17from mitmproxy.net.dns import domain_names
 18from mitmproxy.net.dns import https_records
 19from mitmproxy.net.dns import op_codes
 20from mitmproxy.net.dns import response_codes
 21from mitmproxy.net.dns import types
 22from mitmproxy.net.dns.https_records import HTTPSRecord
 23from mitmproxy.net.dns.https_records import SVCParamKeys
 24
 25# DNS parameters taken from https://www.iana.org/assignments/dns-parameters/dns-parameters.xml
 26
 27
 28@dataclass
 29class Question(serializable.SerializableDataclass):
 30    HEADER: ClassVar[struct.Struct] = struct.Struct("!HH")
 31
 32    name: str
 33    type: int
 34    class_: int
 35
 36    def __str__(self) -> str:
 37        return self.name
 38
 39    def to_json(self) -> dict:
 40        """
 41        Converts the question into json for mitmweb.
 42        Sync with web/src/flow.ts.
 43        """
 44        return {
 45            "name": self.name,
 46            "type": types.to_str(self.type),
 47            "class": classes.to_str(self.class_),
 48        }
 49
 50
 51@dataclass
 52class ResourceRecord(serializable.SerializableDataclass):
 53    DEFAULT_TTL: ClassVar[int] = 60
 54    HEADER: ClassVar[struct.Struct] = struct.Struct("!HHIH")
 55
 56    name: str
 57    type: int
 58    class_: int
 59    ttl: int
 60    data: bytes
 61
 62    def __str__(self) -> str:
 63        try:
 64            if self.type == types.A:
 65                return str(self.ipv4_address)
 66            if self.type == types.AAAA:
 67                return str(self.ipv6_address)
 68            if self.type in (types.NS, types.CNAME, types.PTR):
 69                return self.domain_name
 70            if self.type == types.TXT:
 71                return self.text
 72            if self.type == types.HTTPS:
 73                return str(https_records.unpack(self.data))
 74        except Exception:
 75            return f"0x{self.data.hex()} (invalid {types.to_str(self.type)} data)"
 76        return f"0x{self.data.hex()}"
 77
 78    @property
 79    def text(self) -> str:
 80        return self.data.decode("utf-8")
 81
 82    @text.setter
 83    def text(self, value: str) -> None:
 84        self.data = value.encode("utf-8")
 85
 86    @property
 87    def ipv4_address(self) -> IPv4Address:
 88        return IPv4Address(self.data)
 89
 90    @ipv4_address.setter
 91    def ipv4_address(self, ip: IPv4Address) -> None:
 92        self.data = ip.packed
 93
 94    @property
 95    def ipv6_address(self) -> IPv6Address:
 96        return IPv6Address(self.data)
 97
 98    @ipv6_address.setter
 99    def ipv6_address(self, ip: IPv6Address) -> None:
100        self.data = ip.packed
101
102    @property
103    def domain_name(self) -> str:
104        return domain_names.unpack(self.data)
105
106    @domain_name.setter
107    def domain_name(self, name: str) -> None:
108        self.data = domain_names.pack(name)
109
110    @property
111    def https_alpn(self) -> tuple[bytes, ...] | None:
112        record = https_records.unpack(self.data)
113        alpn_bytes = record.params.get(SVCParamKeys.ALPN.value, None)
114        if alpn_bytes is not None:
115            i = 0
116            ret = []
117            while i < len(alpn_bytes):
118                token_len = alpn_bytes[i]
119                ret.append(alpn_bytes[i + 1 : i + 1 + token_len])
120                i += token_len + 1
121            return tuple(ret)
122        else:
123            return None
124
125    @https_alpn.setter
126    def https_alpn(self, alpn: Iterable[bytes] | None) -> None:
127        record = https_records.unpack(self.data)
128        if alpn is None:
129            record.params.pop(SVCParamKeys.ALPN.value, None)
130        else:
131            alpn_bytes = b"".join(bytes([len(a)]) + a for a in alpn)
132            record.params[SVCParamKeys.ALPN.value] = alpn_bytes
133        self.data = https_records.pack(record)
134
135    @property
136    def https_ech(self) -> str | None:
137        record = https_records.unpack(self.data)
138        ech_bytes = record.params.get(SVCParamKeys.ECH.value, None)
139        if ech_bytes is not None:
140            return base64.b64encode(ech_bytes).decode("utf-8")
141        else:
142            return None
143
144    @https_ech.setter
145    def https_ech(self, ech: str | None) -> None:
146        record = https_records.unpack(self.data)
147        if ech is None:
148            record.params.pop(SVCParamKeys.ECH.value, None)
149        else:
150            ech_bytes = base64.b64decode(ech.encode("utf-8"))
151            record.params[SVCParamKeys.ECH.value] = ech_bytes
152        self.data = https_records.pack(record)
153
154    def to_json(self) -> dict:
155        """
156        Converts the resource record into json for mitmweb.
157        Sync with web/src/flow.ts.
158        """
159        return {
160            "name": self.name,
161            "type": types.to_str(self.type),
162            "class": classes.to_str(self.class_),
163            "ttl": self.ttl,
164            "data": str(self),
165        }
166
167    @classmethod
168    def A(cls, name: str, ip: IPv4Address, *, ttl: int = DEFAULT_TTL) -> ResourceRecord:
169        """Create an IPv4 resource record."""
170        return cls(name, types.A, classes.IN, ttl, ip.packed)
171
172    @classmethod
173    def AAAA(
174        cls, name: str, ip: IPv6Address, *, ttl: int = DEFAULT_TTL
175    ) -> ResourceRecord:
176        """Create an IPv6 resource record."""
177        return cls(name, types.AAAA, classes.IN, ttl, ip.packed)
178
179    @classmethod
180    def CNAME(
181        cls, alias: str, canonical: str, *, ttl: int = DEFAULT_TTL
182    ) -> ResourceRecord:
183        """Create a canonical internet name resource record."""
184        return cls(alias, types.CNAME, classes.IN, ttl, domain_names.pack(canonical))
185
186    @classmethod
187    def PTR(cls, inaddr: str, ptr: str, *, ttl: int = DEFAULT_TTL) -> ResourceRecord:
188        """Create a canonical internet name resource record."""
189        return cls(inaddr, types.PTR, classes.IN, ttl, domain_names.pack(ptr))
190
191    @classmethod
192    def TXT(cls, name: str, text: str, *, ttl: int = DEFAULT_TTL) -> ResourceRecord:
193        """Create a textual resource record."""
194        return cls(name, types.TXT, classes.IN, ttl, text.encode("utf-8"))
195
196    @classmethod
197    def HTTPS(
198        cls, name: str, record: HTTPSRecord, ttl: int = DEFAULT_TTL
199    ) -> ResourceRecord:
200        """Create a HTTPS resource record"""
201        return cls(name, types.HTTPS, classes.IN, ttl, https_records.pack(record))
202
203
204# comments are taken from rfc1035
@dataclass
class Message(serializable.SerializableDataclass):
    """A DNS message: header fields plus question and resource record sections."""

    # Header: id, flags, and the four section counts, each a big-endian u16.
    HEADER: ClassVar[struct.Struct] = struct.Struct("!HHHHHH")

    timestamp: float
    """The time at which the message was sent or received."""
    id: int
    """An identifier assigned by the program that generates any kind of query."""
    query: bool
    """A field that specifies whether this message is a query."""
    op_code: int
    """
    A field that specifies kind of query in this message.
    This value is set by the originator of a request and copied into the response.
    """
    authoritative_answer: bool
    """
    This field is valid in responses, and specifies that the responding name server
    is an authority for the domain name in question section.
    """
    truncation: bool
    """Specifies that this message was truncated due to length greater than that permitted on the transmission channel."""
    recursion_desired: bool
    """
    This field may be set in a query and is copied into the response.
    If set, it directs the name server to pursue the query recursively.
    """
    recursion_available: bool
    """This field is set or cleared in a response, and denotes whether recursive query support is available in the name server."""
    reserved: int
    """Reserved for future use.  Must be zero in all queries and responses."""
    response_code: int
    """This field is set as part of responses."""
    questions: list[Question]
    """
    The question section is used to carry the "question" in most queries, i.e.
    the parameters that define what is being asked.
    """
    answers: list[ResourceRecord]
    """First resource record section."""
    authorities: list[ResourceRecord]
    """Second resource record section."""
    additionals: list[ResourceRecord]
    """Third resource record section."""

    def __str__(self) -> str:
        """All questions and records, one per line, separated by CRLF."""
        return "\r\n".join(
            map(
                str,
                itertools.chain(
                    self.questions, self.answers, self.authorities, self.additionals
                ),
            )
        )

    @property
    def content(self) -> bytes:
        """Returns the user-friendly content of all parts as encoded bytes."""
        return str(self).encode()

    @property
    def question(self) -> Question | None:
        """DNS practically only supports a single question at the
        same time, so this is a shorthand for this.

        Returns None unless there is exactly one question."""
        if len(self.questions) == 1:
            return self.questions[0]
        return None

    @property
    def size(self) -> int:
        """Returns the cumulative data size of all resource record sections."""
        return sum(
            len(x.data)
            for x in itertools.chain.from_iterable(
                [self.answers, self.authorities, self.additionals]
            )
        )

    def fail(self, response_code: int) -> Message:
        """
        Build an error response to this message.

        The response keeps this message's id, op_code and recursion_desired flag,
        repeats its questions, and has empty record sections.

        Raises:
            ValueError: if `response_code` is NOERROR.
        """
        if response_code == response_codes.NOERROR:
            raise ValueError("response_code must be an error code.")
        return Message(
            timestamp=time.time(),
            id=self.id,
            query=False,
            op_code=self.op_code,
            authoritative_answer=False,
            truncation=False,
            recursion_desired=self.recursion_desired,
            recursion_available=False,
            reserved=0,
            response_code=response_code,
            # Copy so that editing the response's questions does not alter this message.
            questions=list(self.questions),
            answers=[],
            authorities=[],
            additionals=[],
        )

    def succeed(self, answers: list[ResourceRecord]) -> Message:
        """
        Build a NOERROR response to this message carrying `answers`.

        The response keeps this message's id, op_code and recursion_desired flag
        and repeats its questions.
        """
        return Message(
            timestamp=time.time(),
            id=self.id,
            query=False,
            op_code=self.op_code,
            authoritative_answer=False,
            truncation=False,
            recursion_desired=self.recursion_desired,
            recursion_available=True,
            reserved=0,
            response_code=response_codes.NOERROR,
            # Copy so that editing the response's questions does not alter this message.
            questions=list(self.questions),
            answers=answers,
            authorities=[],
            additionals=[],
        )

    @classmethod
    def unpack(cls, buffer: bytes) -> Message:
        """
        Converts the entire given buffer into a DNS message.

        Raises:
            struct.error: if the buffer is malformed or has trailing bytes.
        """
        length, msg = cls.unpack_from(buffer, 0)
        if length != len(buffer):
            raise struct.error(f"unpack requires a buffer of {length} bytes")
        return msg

    @classmethod
    def unpack_from(cls, buffer: bytes | bytearray, offset: int) -> tuple[int, Message]:
        """
        Converts the buffer from a given offset into a DNS message.

        Returns a tuple of the offset just past the end of the parsed message
        (i.e. the starting `offset` plus the message's length) and the message.

        Raises:
            struct.error: if the buffer is too short for a header, question or
                record, with the failing section and index in the message.
        """
        (
            msg_id,
            flags,
            len_questions,
            len_answers,
            len_authorities,
            len_additionals,
        ) = cls.HEADER.unpack_from(buffer, offset)
        msg = cls(
            timestamp=time.time(),
            id=msg_id,
            query=(flags & (1 << 15)) == 0,
            op_code=(flags >> 11) & 0b1111,
            authoritative_answer=(flags & (1 << 10)) != 0,
            truncation=(flags & (1 << 9)) != 0,
            recursion_desired=(flags & (1 << 8)) != 0,
            recursion_available=(flags & (1 << 7)) != 0,
            reserved=(flags >> 4) & 0b111,
            response_code=flags & 0b1111,
            questions=[],
            answers=[],
            authorities=[],
            additionals=[],
        )
        offset += cls.HEADER.size
        # Shared by all names in this message so compression pointers can be resolved.
        cached_names = domain_names.cache()

        def unpack_domain_name() -> str:
            nonlocal offset
            name, length = domain_names.unpack_from_with_compression(
                buffer, offset, cached_names
            )
            offset += length
            return name

        for i in range(len_questions):
            try:
                name = unpack_domain_name()
                q_type, q_class = Question.HEADER.unpack_from(buffer, offset)
                offset += Question.HEADER.size
                msg.questions.append(Question(name=name, type=q_type, class_=q_class))
            except struct.error as e:
                raise struct.error(f"question #{i}: {e}") from e

        def unpack_rrs(
            section: list[ResourceRecord], section_name: str, count: int
        ) -> None:
            nonlocal offset
            for i in range(count):
                try:
                    name = unpack_domain_name()
                    rr_type, rr_class, ttl, len_data = (
                        ResourceRecord.HEADER.unpack_from(buffer, offset)
                    )
                    offset += ResourceRecord.HEADER.size
                    end_data = offset + len_data
                    if len(buffer) < end_data:
                        raise struct.error(
                            f"unpack requires a data buffer of {len_data} bytes"
                        )
                    # bytes() so a bytearray input still yields immutable `bytes` data
                    # (IPv4Address/IPv6Address do not accept bytearray).
                    data = bytes(buffer[offset:end_data])

                    if domain_names.record_data_can_have_compression(rr_type):
                        data = domain_names.decompress_from_record_data(
                            buffer, offset, end_data, cached_names
                        )

                    section.append(ResourceRecord(name, rr_type, rr_class, ttl, data))
                    offset += len_data
                except struct.error as e:
                    raise struct.error(f"{section_name} #{i}: {e}") from e

        unpack_rrs(msg.answers, "answer", len_answers)
        unpack_rrs(msg.authorities, "authority", len_authorities)
        unpack_rrs(msg.additionals, "additional", len_additionals)
        return (offset, msg)

    @property
    def packed(self) -> bytes:
        """
        Converts the message into network bytes.

        Raises:
            ValueError: if id, op_code, reserved or response_code do not fit
                their header fields.
        """
        if self.id < 0 or self.id > 65535:
            raise ValueError(f"DNS message's id {self.id} is out of bounds.")
        flags = 0
        if not self.query:
            flags |= 1 << 15
        if self.op_code < 0 or self.op_code > 0b1111:
            raise ValueError(f"DNS message's op_code {self.op_code} is out of bounds.")
        flags |= self.op_code << 11
        if self.authoritative_answer:
            flags |= 1 << 10
        if self.truncation:
            flags |= 1 << 9
        if self.recursion_desired:
            flags |= 1 << 8
        if self.recursion_available:
            flags |= 1 << 7
        if self.reserved < 0 or self.reserved > 0b111:
            raise ValueError(
                f"DNS message's reserved value of {self.reserved} is out of bounds."
            )
        flags |= self.reserved << 4
        if self.response_code < 0 or self.response_code > 0b1111:
            raise ValueError(
                f"DNS message's response_code {self.response_code} is out of bounds."
            )
        flags |= self.response_code
        data = bytearray()
        data.extend(
            Message.HEADER.pack(
                self.id,
                flags,
                len(self.questions),
                len(self.answers),
                len(self.authorities),
                len(self.additionals),
            )
        )
        # TODO implement compression
        for question in self.questions:
            data.extend(domain_names.pack(question.name))
            data.extend(Question.HEADER.pack(question.type, question.class_))
        for rr in (*self.answers, *self.authorities, *self.additionals):
            data.extend(domain_names.pack(rr.name))
            data.extend(
                ResourceRecord.HEADER.pack(rr.type, rr.class_, rr.ttl, len(rr.data))
            )
            data.extend(rr.data)
        return bytes(data)

    def to_json(self) -> dict:
        """
        Converts the message into json for mitmweb.
        Sync with web/src/flow.ts.
        """
        return {
            "id": self.id,
            "query": self.query,
            "op_code": op_codes.to_str(self.op_code),
            "authoritative_answer": self.authoritative_answer,
            "truncation": self.truncation,
            "recursion_desired": self.recursion_desired,
            "recursion_available": self.recursion_available,
            "response_code": response_codes.to_str(self.response_code),
            "status_code": response_codes.http_equiv_status_code(self.response_code),
            "questions": [question.to_json() for question in self.questions],
            "answers": [rr.to_json() for rr in self.answers],
            "authorities": [rr.to_json() for rr in self.authorities],
            "additionals": [rr.to_json() for rr in self.additionals],
            "size": self.size,
            "timestamp": self.timestamp,
        }

    def copy(self) -> Message:
        """Return a copy of this message with a freshly randomized id."""
        # we keep the copy semantics but change the ID generation
        state = self.get_state()
        state["id"] = random.randint(0, 65535)
        return Message.from_state(state)
489
490
class DNSFlow(flow.Flow):
    """A DNSFlow is a collection of DNS messages representing a single DNS query."""

    request: Message
    """The DNS request."""
    response: Message | None = None
    """The DNS response."""

    def get_state(self) -> serializable.State:
        """Base flow state extended with the request and (if present) response state."""
        state = dict(super().get_state())
        state["request"] = self.request.get_state()
        state["response"] = None if not self.response else self.response.get_state()
        return state

    def set_state(self, state: serializable.State) -> None:
        """Restore request/response from `state`, then hand the rest to the base class."""
        request_state = state.pop("request")
        self.request = Message.from_state(request_state)
        response_state = state.pop("response")
        self.response = (
            None if not response_state else Message.from_state(response_state)
        )
        super().set_state(state)

    def __repr__(self) -> str:
        return f"<DNSFlow\r\n  request={self.request!r}\r\n  response={self.response!r}\r\n>"
class DNSFlow(flow.Flow):
    """A DNSFlow is a collection of DNS messages representing a single DNS query."""

    request: Message
    """The DNS request."""
    response: Message | None = None
    """The DNS response."""

    def get_state(self) -> serializable.State:
        """Serialize the flow; the response entry is None when there is no response."""
        base = super().get_state()
        request_state = self.request.get_state()
        response_state = self.response.get_state() if self.response else None
        return {**base, "request": request_state, "response": response_state}

    def set_state(self, state: serializable.State) -> None:
        """Load request and response messages, then pass remaining state upward."""
        self.request = Message.from_state(state.pop("request"))
        raw_response = state.pop("response")
        if raw_response:
            self.response = Message.from_state(raw_response)
        else:
            self.response = None
        super().set_state(state)

    def __repr__(self) -> str:
        return f"<DNSFlow\r\n  request={self.request!r}\r\n  response={self.response!r}\r\n>"

A DNSFlow is a collection of DNS messages representing a single DNS query.

request: mitmproxy.dns.Message

The DNS request.

response: mitmproxy.dns.Message | None = None

The DNS response.

type: ClassVar[str] = 'dns'

The flow type, for example http, tcp, or dns.