dns.rs
//! Houses the DNS-specific code, including the structs that we pack the bytes
//! into and suitable traits and implementations to convert between bytes and
//! structs.
//!
//! ### Disclaimer
//! This is a very barebones DNS client implementation. It hard-codes a lot of
//! values and is intended only to demonstrate how even custom protocols over
//! TCP can be tunnelled through Tor. It is not meant for any real production
//! usage.
 10  use anyhow::Result;
 11  use std::fmt::Display;
 12  use thiserror::Error;
 13  use tracing::{debug, error};
 14  
 15  #[derive(Error, Debug)]
 16  #[error("Failed to parse bytes into struct!")]
 17  /// Generic error we return if we fail to parse bytes into the struct
 18  struct FromBytesError;
 19  
 20  #[derive(Error, Debug)]
 21  #[error("Invalid domain name passed")]
 22  /// Error we return if a bad domain name is passed
 23  pub struct DomainError;
 24  
 25  /// Hardcoded DNS server, stored as (&str, u16) detailing host and port
 26  pub const DNS_SERVER: (&str, u16) = ("1.1.1.1", 53);
 27  
 28  /// Default value for QTYPE field
 29  const QTYPE: u16 = 0x0001;
 30  /// Default value for QCLASS field
 31  const QCLASS: u16 = 0x0001;
 32  
 33  /// Used to convert struct to raw bytes to be sent over the network
 34  ///
 35  /// Example:
 36  /// ```
 37  /// // We have some struct S that implements this trait
 38  /// let s = S::new();
 39  /// // This prints the raw bytes as debug output
 40  /// dbg!("{}", s.as_bytes());
 41  /// ```
 42  pub trait AsBytes {
 43      /// Return a `Vec<u8>` of the same information stored in struct
 44      ///
 45      /// This is ideal to convert typed values into raw bytes to be sent
 46      /// over the network.
 47      fn as_bytes(&self) -> Vec<u8>;
 48  }
 49  
 50  /// Used to convert raw bytes representation into a Rust struct
 51  ///
 52  /// Example:
 53  /// ```
 54  /// let mut buf: Vec<u8> = Vec::new();
 55  /// // Read the response from a stream
 56  /// stream.read_to_end(&mut buf).await.unwrap();
 57  /// // Interpret the response into a struct S
 58  /// let resp = S::from_bytes(&buf);
 59  /// ```
 60  ///
 61  /// In the above code, `resp` is `Option<Box<S>>` type, so you will have to
 62  /// deal with the `None` value appropriately. This helps denote invalid
 63  /// situations, ie, parse failures
 64  ///
 65  /// You will have to interpret each byte and convert it into each field
 66  /// of your struct yourself when implementing this trait.
 67  pub trait FromBytes {
 68      /// Convert two u8's into a u16
 69      ///
 70      /// It is just a thin wrapper over [u16::from_be_bytes()]
 71      fn u8_to_u16(upper: u8, lower: u8) -> u16 {
 72          let bytes = [upper, lower];
 73          u16::from_be_bytes(bytes)
 74      }
 75      /// Convert four u8's contained in a slice into a u32
 76      ///
 77      /// It is just a thin wrapper over [u32::from_be_bytes()] but also deals
 78      /// with converting &\[u8\] (u8 slice) into [u8; 4] (a fixed size array of u8)
 79      fn u8_to_u32(bytes_slice: &[u8]) -> Result<u32> {
 80          let bytes: [u8; 4] = bytes_slice.try_into()?;
 81          Ok(u32::from_be_bytes(bytes))
 82      }
 83      /// Try converting given bytes into the struct
 84      ///
 85      /// Returns an `Option<Box>` of the struct which implements
 86      /// this trait to help denote parsing failures
 87      fn from_bytes(bytes: &[u8]) -> Result<Box<Self>>;
 88  }
 89  
 90  /// Report length of the struct as in byte stream
 91  ///
 92  /// Note that this doesn't mean length of struct
 93  ///
 94  /// It is simply used to denote how long the struct is if it were
 95  /// sent over the wire
 96  trait Len {
 97      /// Report length of the struct as in byte stream
 98      fn len(&self) -> usize;
 99  }
100  
/// The 12-byte DNS message header (RFC 1035 section 4.1.1), used by both
/// queries and responses.
///
/// The values `build_query` puts in here are chosen from the client's
/// perspective.
// TODO: For server we will have to interpret given values
struct Header {
    /// 16-bit identifier that pairs a request with its response
    identification: u16,
    /// All header flags packed into one 16-bit word:
    ///
    /// ```text
    ///   0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
    /// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    /// |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
    /// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
    /// ```
    ///
    /// Queries built by this client use `0x0100` (only RD set).
    ///
    /// TODO: don't rely on cryptic packed bits
    packed_second_row: u16,
    /// QDCOUNT: number of questions; always 1 here since each query asks
    /// about exactly one hostname
    qdcount: u16,
    /// ANCOUNT: number of answers; zero in a query, hopefully >= 1 in a response
    ancount: u16,
    /// NSCOUNT (RFC 1035 section 4.1.1); 0 in queries built by this client
    nscount: u16,
    /// ARCOUNT (RFC 1035 section 4.1.1); 0 in queries built by this client
    arcount: u16,
}
134  
135  // Ugly, repetitive code to convert all six 16-bit fields into Vec<u8>
136  impl AsBytes for Header {
137      fn as_bytes(&self) -> Vec<u8> {
138          let mut v: Vec<u8> = Vec::with_capacity(14);
139          // These 2 bytes store size of the rest of the payload (including header)
140          // Right now it denotes 51 byte size packet, excluding these 2 bytes
141          // We will change this when we know the size of Query
142          v.push(0x00);
143          v.push(0x33);
144          // Just break u16 into [u8, u8] array and copy into vector
145          v.extend_from_slice(&u16::to_be_bytes(self.identification));
146          v.extend_from_slice(&u16::to_be_bytes(self.packed_second_row));
147          v.extend_from_slice(&u16::to_be_bytes(self.qdcount));
148          v.extend_from_slice(&u16::to_be_bytes(self.ancount));
149          v.extend_from_slice(&u16::to_be_bytes(self.nscount));
150          v.extend_from_slice(&u16::to_be_bytes(self.arcount));
151          v
152      }
153  }
154  
155  impl Display for Header {
156      fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157          writeln!(f, "ID: 0x{:x}", self.identification)?;
158          writeln!(f, "Flags: 0x{:x}", self.packed_second_row)?;
159          writeln!(f, "QDCOUNT: 0x{:x}", self.qdcount)?;
160          writeln!(f, "ANCOUNT: 0x{:x}", self.ancount)?;
161          writeln!(f, "NSCOUNT: 0x{:x}", self.nscount)?;
162          writeln!(f, "ARCOUNT: 0x{:x}", self.arcount)?;
163          Ok(())
164      }
165  }
166  
167  impl FromBytes for Header {
168      fn from_bytes(bytes: &[u8]) -> Result<Box<Self>> {
169          debug!("Parsing the header");
170          let packed_second_row = Header::u8_to_u16(bytes[2], bytes[3]);
171          // 0x8180 denotes we have a response to a standard query,
172          // that isn't truncated, and has recursion requested to a server
173          // that can do recursion, with some bits reserved for future use
174          // and some that are not relevant for our purposes
175          if packed_second_row == 0x8180 {
176              debug!("Correct flags set in response");
177          } else {
178              error!(
179                  "Incorrect flags set in response, we got {}",
180                  packed_second_row
181              );
182              return Err(FromBytesError.into());
183          }
184          // These offsets were determined by looking at RFC 1035
185          Ok(Box::new(Header {
186              identification: Header::u8_to_u16(bytes[0], bytes[1]),
187              packed_second_row,
188              qdcount: Header::u8_to_u16(bytes[4], bytes[5]),
189              ancount: Header::u8_to_u16(bytes[6], bytes[7]),
190              nscount: Header::u8_to_u16(bytes[8], bytes[9]),
191              arcount: Header::u8_to_u16(bytes[10], bytes[11]),
192          }))
193      }
194  }
195  
196  /// The actual query we will send to a DNS server
197  ///
198  /// For now A records are fetched only
199  // TODO: add support for different records to be fetched
200  pub struct Query {
201      /// Header of the DNS packet, see [Header] for more info
202      header: Header,
203      /// The domain name, stored as a `Vec<u8>`
204      ///
205      /// When we call [Query::from_bytes()], `qname` is automatically
206      /// converted into string stored in a `Vec<u8>` instead of the raw
207      /// byte format used for `qname`
208      qname: Vec<u8>, // domain name
209      /// Denotes the type of record to get.
210      ///
211      /// Here we set to 1 to get an A record, ie, IPv4
212      qtype: u16, // set to 0x0001 for A records
213      /// Denotes the class of the record
214      ///
215      /// Here we set to 1 to get an Internet address
216      qclass: u16, // set to 1 for Internet addresses
217  }
218  
219  impl AsBytes for Query {
220      fn as_bytes(&self) -> Vec<u8> {
221          let mut v: Vec<u8> = Vec::new();
222          let header_bytes = self.header.as_bytes();
223          v.extend(header_bytes);
224          v.extend(&self.qname);
225          v.extend_from_slice(&u16::to_be_bytes(self.qtype));
226          v.extend_from_slice(&u16::to_be_bytes(self.qclass));
227          // Now that the packet is ready, we can calculate size and set that in
228          // first two octets
229          // Subtract 2 since these first 2 bits are never counted when reporting
230          // length like this
231          let len_bits = u16::to_be_bytes((v.len() - 2) as u16);
232          v[0] = len_bits[0];
233          v[1] = len_bits[1];
234          v
235      }
236  }
237  
238  impl Len for Query {
239      fn len(&self) -> usize {
240          // extra 1 is for compensating for how we
241          // use one byte more to store length of domain name
242          12 + 1 + self.qname.len() + 2 + 2
243      }
244  }
245  
246  impl FromBytes for Query {
247      // FIXME: the name struct isn't stored as it was sent over the wire
248      fn from_bytes(bytes: &[u8]) -> Result<Box<Self>> {
249          let header = *Header::from_bytes(&bytes[..12])?;
250          if bytes.len() < 12 {
251              error!("Mismatch between expected number of bytes and given number of bytes!");
252              return Err(FromBytesError.into());
253          }
254          // Parse name
255          let mut name = String::new();
256          // 12 represents size of Header, which we have already parsed, or errored out of
257          let mut lastnamebyte = 12;
258          loop {
259              // bytes[lastnamebytes] denotes the prefix length, we read that many bytes into name
260              let start = lastnamebyte + 1;
261              let end = start + bytes[lastnamebyte] as usize;
262              name.extend(std::str::from_utf8(&bytes[start..end]));
263              lastnamebyte = end;
264              if lastnamebyte >= bytes.len() || bytes[lastnamebyte] == 0 {
265                  // End of domain name, proceed to parse further fields
266                  debug!("Reached end of name, moving on to parse other fields");
267                  lastnamebyte += 1;
268                  break;
269              }
270              name.push('.');
271          }
272          // These offsets were determined by looking at RFC 1035
273          Ok(Box::new(Self {
274              header,
275              qname: name.as_bytes().to_vec(),
276              qtype: Query::u8_to_u16(bytes[lastnamebyte], bytes[lastnamebyte + 1]),
277              qclass: Query::u8_to_u16(bytes[lastnamebyte + 2], bytes[lastnamebyte + 3]),
278          }))
279      }
280  }
281  
/// One parsed resource record (RR) from a response.
///
/// The record's NAME is not stored.
struct ResourceRecord {
    /// TYPE of the record, analogous to [Query::qtype]
    rtype: u16,
    /// CLASS of the record, analogous to [Query::qclass]
    class: u16,
    /// Time to live: how many seconds the answer may be cached before a
    /// fresh request is needed
    ttl: u32,
    /// RDLENGTH: number of RDATA bytes; 4 for the IPv4 (A) records this
    /// client asks for
    rdlength: u16,
    /// The first four RDATA bytes, which hold the IPv4 address for an A record
    rdata: [u8; 4],
}
307  
308  impl Len for ResourceRecord {
309      // return number of bytes it consumes
310      fn len(&self) -> usize {
311          let mut size = 0;
312          size += 2; // name, even though we don't store it here
313          size += 2; // rtype
314          size += 2; // class
315          size += 4; // ttl
316          size += 2; // rdlength
317          size += 4; // rdata
318          size
319      }
320  }
321  
322  impl FromBytes for ResourceRecord {
323      fn from_bytes(bytes: &[u8]) -> Result<Box<Self>> {
324          let lastnamebyte = 1;
325          let mut rdata = [0u8; 4];
326          if bytes.len() < 15 {
327              return Err(FromBytesError.into());
328          }
329          // Copy over IP address into rdata
330          rdata.copy_from_slice(&bytes[lastnamebyte + 10..lastnamebyte + 14]);
331          // These offsets were determined by looking at RFC 1035
332          Ok(Box::new(Self {
333              rtype: ResourceRecord::u8_to_u16(bytes[lastnamebyte], bytes[lastnamebyte + 1]),
334              class: ResourceRecord::u8_to_u16(bytes[lastnamebyte + 2], bytes[lastnamebyte + 3]),
335              ttl: ResourceRecord::u8_to_u32(&bytes[lastnamebyte + 4..lastnamebyte + 8])?,
336              rdlength: Response::u8_to_u16(bytes[lastnamebyte + 8], bytes[lastnamebyte + 9]),
337              rdata,
338          }))
339      }
340  }
341  
342  impl Display for ResourceRecord {
343      fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344          writeln!(f, "RR record type: 0x{:x}", self.rtype)?;
345          writeln!(f, "RR class: 0x{:x}", self.class)?;
346          writeln!(f, "TTL: {}", self.ttl)?;
347          writeln!(f, "RDLENGTH: 0x{:x}", self.rdlength)?;
348          writeln!(
349              f,
350              "IP address: {}.{}.{}.{}",
351              self.rdata[0], self.rdata[1], self.rdata[2], self.rdata[3]
352          )?;
353          Ok(())
354      }
355  }
356  
357  /// Stores the response in easy to interpret manner
358  ///
359  /// A Response is made up of the query given to the server and a bunch of
360  /// Resource Records (RR). Each RR will include the resource type, class, and
361  /// name. For the A records we're requesting, we will get an A record, of Internet class,
362  /// ie an IPv4 address
363  pub struct Response {
364      /// The Query part of the response we obtain from the server
365      query: Query,
366      /// A collection of resource records all parsed neatly and kept separately
367      /// for easy iteration
368      rr: Vec<ResourceRecord>,
369  }
370  
371  impl FromBytes for Response {
372      // Try to construct Response from raw byte data from network
373      // We will also try to check if a valid DNS response has been sent back to us
374      fn from_bytes(bytes: &[u8]) -> Result<Box<Self>> {
375          debug!("Parsing response into struct");
376          // Check message length
377          let l = bytes.len();
378          let messagelen = Response::u8_to_u16(bytes[0], bytes[1]);
379          if messagelen == (l - 2) as u16 {
380              debug!("Appear to have gotten good message from server");
381          } else {
382              error!(
383                  "Expected and observed message length don't match: {} and {} respectively",
384                  l - 2,
385                  messagelen
386              );
387          }
388          // Start index at 2 to skip over message length bytes
389          let mut index = 2;
390          let query = *Query::from_bytes(&bytes[index..])?;
391          index += query.len() + 2; // TODO: needs explanation why it works
392          let mut rrvec: Vec<ResourceRecord> = Vec::new();
393          while index < l {
394              match ResourceRecord::from_bytes(&bytes[index..]) {
395                  Ok(rr) => {
396                      index += rr.len();
397                      rrvec.push(*rr);
398                  }
399                  Err(_) => break,
400              }
401          }
402          Ok(Box::new(Response { query, rr: rrvec }))
403      }
404  }
405  
406  impl Display for Response {
407      fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408          writeln!(f, "{}", self.query.header)?;
409          writeln!(
410              f,
411              "Name: {}",
412              String::from_utf8(self.query.qname.to_owned()).unwrap()
413          )?;
414          writeln!(f, "Res type: 0x{:x}", self.query.qtype)?;
415          writeln!(f, "Class: 0x{:x}", self.query.qclass)?;
416          for record in self.rr.iter() {
417              writeln!(f)?;
418              writeln!(f, "{}", record)?;
419          }
420          Ok(())
421      }
422  }
423  
424  /// Craft the actual query for a particular domain and returns a Query object
425  ///
426  /// The query is made for an A record of type Internet, ie, a normal IPv4 address
427  /// should be returned from the DNS server.
428  ///
429  /// Convert this Query into bytes to be sent over the network by calling [Query::as_bytes()]
430  pub fn build_query(domain: &str) -> Result<Query, DomainError> {
431      // TODO: generate identification randomly
432      let header = Header {
433          identification: 0x304e, // chosen by random dice roll, secure
434          packed_second_row: 0x0100,
435          qdcount: 0x0001,
436          ancount: 0x0000,
437          nscount: 0x0000,
438          arcount: 0x0000,
439      };
440      let mut qname: Vec<u8> = Vec::new();
441      let split_domain: Vec<&str> = domain.split('.').collect();
442      for part in split_domain {
443          if part.is_empty() {
444              return Err(DomainError);
445          }
446          let l = part.len() as u8;
447          if l != 0 {
448              qname.push(l);
449              qname.extend_from_slice(part.as_bytes());
450          }
451      }
452      qname.push(0x00); // Denote that hostname has ended by pushing 0x00
453      debug!("Crafted query successfully!");
454      Ok(Query {
455          header,
456          qname,
457          qtype: QTYPE,
458          qclass: QCLASS,
459      })
460  }