/ Rust / lib / intcode.rs
intcode.rs
  1  use std::collections::VecDeque;
  2  
  3  use rustc_hash::FxHashMap;
  4  use thiserror::Error;
  5  
/// Integer type used for memory cells, addresses, the instruction pointer,
/// the relative base, and input/output values.
pub type Int = i64;
  7  
/// An Intcode virtual machine: sparse memory, an instruction pointer, a
/// relative base, and FIFO input/output queues.
#[derive(Debug, Clone)]
pub struct IntcodeVm {
    /// Sparse memory keyed by address; addresses not present read as 0.
    memory: FxHashMap<Int, Int>,
    /// Address of the next instruction to decode and execute.
    ip: Int,
    /// Relative base added to relative-mode arguments; adjusted by opcode 9.
    base: Int,
    /// Pending input values, consumed front-first by input instructions.
    input: VecDeque<Int>,
    /// Produced output values, appended at the back by output instructions.
    output: VecDeque<Int>,
}
 16  
 17  impl IntcodeVm {
 18      /// Create a new vm and load the given program.
 19      pub fn new(program: impl IntoIterator<Item = Int>) -> Self {
 20          Self::with_input(program, [])
 21      }
 22  
 23      /// Create a new vm and load the given program and input.
 24      pub fn with_input(
 25          program: impl IntoIterator<Item = Int>,
 26          input: impl Into<VecDeque<Int>>,
 27      ) -> Self {
 28          Self {
 29              memory: program
 30                  .into_iter()
 31                  .enumerate()
 32                  .map(|(i, x)| (i as _, x))
 33                  .collect(),
 34              ip: 0,
 35              base: 0,
 36              input: input.into(),
 37              output: Default::default(),
 38          }
 39      }
 40  
 41      /// Push additional input to the end of the queue.
 42      pub fn push_input(&mut self, input: impl IntoIterator<Item = Int>) {
 43          self.input.extend(input);
 44      }
 45  
 46      /// Pop values from the output queue.
 47      pub fn pop_output(&mut self) -> Option<Int> {
 48          self.output.pop_front()
 49      }
 50  
 51      /// Return the next output value.
 52      ///
 53      /// Advances execution only if the output queue is empty.
 54      pub fn next_output(&mut self) -> Result<Option<Int>> {
 55          loop {
 56              if let Some(out) = self.pop_output() {
 57                  return Ok(Some(out));
 58              }
 59              if !self.step()? {
 60                  return Ok(None);
 61              }
 62          }
 63      }
 64  
 65      /// Return and clear the output queue.
 66      pub fn take_output(&mut self) -> VecDeque<Int> {
 67          std::mem::take(&mut self.output)
 68      }
 69  
 70      /// Read the value at the given address from memory.
 71      pub fn read(&self, addr: Int) -> Int {
 72          self.memory.get(&addr).copied().unwrap_or(0)
 73      }
 74  
 75      /// Write the given value to the given address in memory.
 76      pub fn write(&mut self, addr: Int, value: Int) -> Int {
 77          self.memory.insert(addr, value).unwrap_or(0)
 78      }
 79  
 80      /// Run the program until it halts or an error occurs.
 81      pub fn run(&mut self) -> Result<()> {
 82          while self.step()? {}
 83          Ok(())
 84      }
 85  
 86      /// Run a single step of the program (i.e. one operation).
 87      ///
 88      /// Returns `true` if execution can continue and `false` if the program has
 89      /// halted.
 90      pub fn step(&mut self) -> Result<bool> {
 91          match self.parse_instruction()? {
 92              Instruction::Add(in1, in2, out) => {
 93                  self.write_arg(out, self.read_arg(in1) + self.read_arg(in2))?;
 94                  self.ip += 4;
 95              }
 96              Instruction::Mul(in1, in2, out) => {
 97                  self.write_arg(out, self.read_arg(in1) * self.read_arg(in2))?;
 98                  self.ip += 4;
 99              }
100              Instruction::In(arg) => {
101                  let value = self.input.pop_front().ok_or(Error::NeedsInput)?;
102                  self.write_arg(arg, value)?;
103                  self.ip += 2;
104              }
105              Instruction::Out(arg) => {
106                  let value = self.read_arg(arg);
107                  self.output.push_back(value);
108                  self.ip += 2;
109              }
110              Instruction::Jeq(arg, addr) => {
111                  if self.read_arg(arg) != 0 {
112                      self.ip = self.read_arg(addr);
113                  } else {
114                      self.ip += 3;
115                  }
116              }
117              Instruction::Jne(arg, addr) => {
118                  if self.read_arg(arg) == 0 {
119                      self.ip = self.read_arg(addr);
120                  } else {
121                      self.ip += 3;
122                  }
123              }
124              Instruction::Lt(in1, in2, out) => {
125                  self.write_arg(out, (self.read_arg(in1) < self.read_arg(in2)) as _)?;
126                  self.ip += 4;
127              }
128              Instruction::Eq(in1, in2, out) => {
129                  self.write_arg(out, (self.read_arg(in1) == self.read_arg(in2)) as _)?;
130                  self.ip += 4;
131              }
132              Instruction::Base(offset) => {
133                  self.base += self.read_arg(offset);
134                  self.ip += 2;
135              }
136              Instruction::Halt => return Ok(false),
137          }
138  
139          Ok(true)
140      }
141  
142      /// Parse the operation at the current instruction pointer.
143      fn parse_instruction(&self) -> Result<Instruction> {
144          let code = self.read(self.ip) % 100;
145          let mut arg = 0;
146          let mut next_arg = || {
147              arg += 1;
148              self.parse_argument(arg - 1)
149          };
150          Ok(match code {
151              1 => Instruction::Add(next_arg()?, next_arg()?, next_arg()?),
152              2 => Instruction::Mul(next_arg()?, next_arg()?, next_arg()?),
153              3 => Instruction::In(next_arg()?),
154              4 => Instruction::Out(next_arg()?),
155              5 => Instruction::Jeq(next_arg()?, next_arg()?),
156              6 => Instruction::Jne(next_arg()?, next_arg()?),
157              7 => Instruction::Lt(next_arg()?, next_arg()?, next_arg()?),
158              8 => Instruction::Eq(next_arg()?, next_arg()?, next_arg()?),
159              9 => Instruction::Base(next_arg()?),
160              99 => Instruction::Halt,
161              code => return Err(Error::InvalidOpcode { ip: self.ip, code }),
162          })
163      }
164  
165      /// Parse the nth argument of the current instruction.
166      fn parse_argument(&self, n: u32) -> Result<Argument> {
167          let ip = self.ip;
168          let arg = self.read(ip + 1 + n as Int);
169          let mode = self.read(ip) / 10i64.pow(2 + n) % 10;
170          match mode {
171              0 => Ok(Argument::Addr(arg)),
172              1 => Ok(Argument::Immediate(arg)),
173              2 => Ok(Argument::Relative(arg)),
174              _ => Err(Error::InvalidArgMode { ip, n, mode }),
175          }
176      }
177  
178      /// Return the value referenced by the given argument.
179      fn read_arg(&self, arg: Argument) -> Int {
180          match arg {
181              Argument::Addr(addr) => self.read(addr),
182              Argument::Immediate(value) => value,
183              Argument::Relative(offset) => self.read(self.base + offset),
184          }
185      }
186  
187      /// Write the given value to the memory position referenced by the given
188      /// argument.
189      fn write_arg(&mut self, arg: Argument, value: Int) -> Result<Int> {
190          Ok(match arg {
191              Argument::Addr(addr) => self.write(addr, value),
192              Argument::Immediate(_) => return Err(Error::WriteImmediate { ip: self.ip }),
193              Argument::Relative(offset) => self.write(self.base + offset, value),
194          })
195      }
196  }
197  
198  impl Iterator for IntcodeVm {
199      type Item = Result<Int>;
200  
201      fn next(&mut self) -> Option<Self::Item> {
202          self.next_output().transpose()
203      }
204  }
205  
/// A decoded instruction; each variant corresponds to one opcode (listed in
/// the trailing comments).
///
/// NOTE: despite their names, `Jeq` is "jump if non-zero" and `Jne` is "jump
/// if zero"; see `IntcodeVm::step`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Instruction {
    Add(Argument, Argument, Argument), // 1: out = in1 + in2
    Mul(Argument, Argument, Argument), // 2: out = in1 * in2
    In(Argument),                      // 3: write next input value to arg
    Out(Argument),                     // 4: append arg's value to output
    Jeq(Argument, Argument),           // 5: jump to addr if arg != 0
    Jne(Argument, Argument),           // 6: jump to addr if arg == 0
    Lt(Argument, Argument, Argument),  // 7: out = 1 if in1 < in2, else 0
    Eq(Argument, Argument, Argument),  // 8: out = 1 if in1 == in2, else 0
    Base(Argument),                    // 9: relative base += arg
    Halt,                              // 99: stop execution
}
219  
/// An instruction argument tagged with its parameter mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Argument {
    /// Mode 0: the value is a memory address to read from / write to.
    Addr(Int),
    /// Mode 1: the value is used directly; writing to it is an error.
    Immediate(Int),
    /// Mode 2: the value is an offset added to the VM's relative base.
    Relative(Int),
}
226  
227  #[derive(Debug, Error)]
228  pub enum Error {
229      #[error("Invalid opcode {code} at {ip}")]
230      InvalidOpcode { ip: Int, code: Int },
231      #[error("Cannot write to arg in immediate mode at {ip}")]
232      WriteImmediate { ip: Int },
233      #[error("Invalid arg mode {mode} at {ip} (arg {n})")]
234      InvalidArgMode { ip: Int, n: u32, mode: Int },
235      #[error("Executing cannot continue until more input is provided")]
236      NeedsInput,
237  }
238  
239  pub type Result<T, E = Error> = core::result::Result<T, E>;