/ src / codegen_fox32.rs
codegen_fox32.rs
  1  use crate::compiler_types::{Map, Set, Str};
  2  use crate::ir::*;
  3  use crate::ir_liveness::{self, FunctionLiveness};
  4  
  5  #[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
  6  enum Size {
  7      /// 8 bits
  8      Byte,
  9      /// 32 bits
 10      Word,
 11  }
 12  
 13  impl Size {
 14      fn of_ty(ty: &Ty) -> Self {
 15          Self::of_ty_or(ty)
 16              .unwrap_or_else(|_| unreachable!("struct type {ty} encountered during codegen"))
 17      }
 18      fn of_ty_or(ty: &Ty) -> Result<Self, u32> {
 19          match ty {
 20              Ty::Int(IntKind::U8) => Ok(Self::Byte),
 21              Ty::Int(IntKind::Usize) | Ty::Pointer(_) | Ty::Function(..) => Ok(Self::Word),
 22              Ty::Struct(fields) => Err(fields.iter().map(|(_, ty)| Self::of_in_bytes(ty)).sum()),
 23          }
 24      }
 25      fn of_inner(ty: &Ty) -> Self {
 26          match ty {
 27              Ty::Pointer(inner) => Self::of_ty(inner),
 28              _ => unreachable!("accessing inner type of non-pointer type `{ty}`"),
 29          }
 30      }
 31      const fn in_bytes(self) -> u32 {
 32          match self {
 33              Self::Byte => 1,
 34              Self::Word => 4,
 35          }
 36      }
 37      fn of_in_bytes(ty: &Ty) -> u32 {
 38          Self::of_ty_or(ty).map_or_else(|x| x, Self::in_bytes)
 39      }
 40  }
 41  
 42  impl std::fmt::Display for Size {
 43      fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
 44          let dot_size = match self {
 45              Self::Byte => ".8",
 46              Self::Word => "", // implicitly ".32" to the assembler
 47          };
 48          write!(f, "{dot_size}")
 49      }
 50  }
 51  
 52  const NUM_REGISTERS: usize = 31;
 53  const TEMP_REG: StoreLoc = StoreLoc::Register(31);
 54  
 55  #[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
 56  enum StoreLoc {
 57      Register(u8),
 58      Constant(Str),
 59      // Stack(u32),
 60  }
 61  
 62  impl StoreLoc {
 63      pub fn foo(&self) -> Str {
 64          match self {
 65              Self::Register(i) => format!("r{i}").into(),
 66              Self::Constant(c) => c.clone(),
 67          }
 68      }
 69      // stricter, must be syntactically dereferenceable
 70      pub fn bar(&self) -> Str {
 71          match self {
 72              Self::Register(i) => format!("r{i}").into(),
 73              Self::Constant(c) => c.clone(),
 74          }
 75      }
 76  }
 77  
 78  #[derive(Clone, Debug, Default)]
 79  struct LivenessGraph {
 80      regs: Map<Register, Set<Register>>,
 81  }
 82  
 83  impl LivenessGraph {
 84      pub fn new() -> Self {
 85          Self::default()
 86      }
 87      pub fn from_function_liveness(live: &FunctionLiveness) -> Self {
 88          let mut this = Self::new();
 89          live.blocks
 90              .values()
 91              .flat_map(|block_live| std::iter::once(&block_live.start).chain(&block_live.insts))
 92              .for_each(|set| this.insert_set(set));
 93          this
 94      }
 95      pub fn minmax(a: Register, b: Register) -> (Register, Register) {
 96          let min = a.min(b);
 97          let max = a.max(b);
 98          assert!(min != max, "inserted same register {a} {b}");
 99          (min, max)
100      }
101      pub fn get(&self, a: Register, b: Register) -> bool {
102          let (min, max) = Self::minmax(a, b);
103          self.regs.get(&min).map_or(false, |s| s.contains(&max))
104      }
105      pub fn insert(&mut self, a: Register, b: Register) {
106          let (min, max) = Self::minmax(a, b);
107          self.regs.entry(min).or_default().insert(max);
108      }
109      pub fn insert_set(&mut self, set: &Set<Register>) {
110          for (i, &min) in set.iter().enumerate() {
111              for &max in set.iter().skip(i + 1) {
112                  self.insert(min, max);
113              }
114          }
115      }
116  }
117  
118  #[derive(Clone, Debug)]
119  struct RegAllocInfo {
120      pub regs: Map<Register, StoreLoc>,
121      pub local_locs: Map<Register, u32>,
122      pub stack_size: u32,
123      pub liveness: FunctionLiveness,
124  }
125  
126  fn reg_alloc(f: &Function) -> RegAllocInfo {
127      let liveness = ir_liveness::calculate_liveness(f);
128      let live_graph = LivenessGraph::from_function_liveness(&liveness);
129      let mut stack_size = 0;
130      let local_locs: Map<Register, u32> = f
131          .blocks
132          .values()
133          .flat_map(|b| &b.insts)
134          .filter_map(|inst| match inst {
135              Inst::Store(reg, StoreKind::StackAlloc(_ty)) => {
136                  let ret = Some((*reg, stack_size));
137                  stack_size += 4;
138                  ret
139              }
140              _ => None,
141          })
142          .collect();
143      let mut regs: Map<Register, StoreLoc> = Map::new();
144      let mut open: Set<_> = f.tys.keys().copied().collect();
145      for block in f.blocks.values() {
146          for inst in &block.insts {
147              let Inst::Store(r, sk) = inst else {
148                  continue;
149              };
150              let constant_str = match sk {
151                  // cast from i128 to u32 because fox32asm doesn't support negative int literals
152                  #[allow(clippy::cast_sign_loss)]
153                  &StoreKind::Int(i, kind) => match kind {
154                      IntKind::Usize => (i as u32).to_string().into(),
155                      IntKind::U8 => (i as u8).to_string().into(),
156                  },
157                  StoreKind::Function(name) => name.clone(),
158                  _ => continue,
159              };
160              open.remove(r);
161              regs.insert(*r, StoreLoc::Constant(constant_str));
162          }
163      }
164      let mut reg_counter = 0;
165      while let Some(reg) = open.pop_first() {
166          let store_loc = if reg_counter < NUM_REGISTERS {
167              let x = StoreLoc::Register(reg_counter as u8);
168              reg_counter += 1;
169              x
170          } else {
171              todo!("stack spilling");
172              // let x = StoreLoc::Stack(stack_size);
173              // stack_size += 4;
174              // x
175          };
176          regs.insert(reg, store_loc.clone());
177          let mut shared = Set::new();
178          shared.insert(reg);
179          open.retain(|&fellow_reg| {
180              if shared.iter().any(|&r| live_graph.get(r, fellow_reg)) {
181                  return true;
182              }
183              shared.insert(fellow_reg);
184              regs.insert(fellow_reg, store_loc.clone());
185              false
186          });
187      }
188      RegAllocInfo {
189          regs,
190          local_locs,
191          stack_size,
192          liveness,
193      }
194  }
195  
/// Appends a label definition to a `String`: the formatted text followed by
/// `:` and a newline.
macro_rules! write_label {
    ($dst:expr, $($arg:tt)*) => {{
        use ::std::fmt::Write;
        let out: &mut String = &mut $dst;
        writeln!(out, "{}:", format_args!($($arg)*)).unwrap();
    }}
}
204  
/// Appends one instruction line to a `String`: four spaces of indentation,
/// the formatted text, then a newline.
macro_rules! write_inst {
    ($dst:expr, $($arg:tt)*) => {{
        use ::std::fmt::Write;
        let out: &mut String = &mut $dst;
        out.push_str("    ");
        writeln!(out, $($arg)*).unwrap();
    }}
}
214  
/// Appends a comment line to a `String`: `; `, the formatted text, then a
/// newline.
// Currently only useful while debugging the code generator, hence the allow.
#[allow(unused_macros)]
macro_rules! write_comment {
    ($dst:expr, $($arg:tt)*) => {{
        use ::std::fmt::Write;
        let out: &mut String = &mut $dst;
        out.push_str("; ");
        writeln!(out, $($arg)*).unwrap();
    }}
}
225  
226  pub fn gen_program(ir: &Program) -> String {
227      let mut code = String::new();
228      for (i, (name, f)) in ir.functions.iter().enumerate() {
229          if i != 0 {
230              code.push('\n');
231          }
232          let fn_output = gen_function(f, name);
233          code.push_str(&fn_output);
234      }
235      code
236  }
237  
238  pub fn gen_function(f: &Function, function_name: &str) -> String {
239      let mut code = String::new();
240      let RegAllocInfo {
241          regs,
242          local_locs,
243          stack_size,
244          liveness,
245      } = reg_alloc(f);
246      {
247          let locs: Set<_> = regs.values().collect();
248          for loc in locs {
249              eprint!("{}:", loc.foo());
250              for (r, r_loc) in &regs {
251                  if loc == r_loc {
252                      eprint!(" {r}");
253                  }
254              }
255              eprintln!();
256          }
257      }
258      write_label!(code, "{function_name}");
259      if !f.parameters.is_empty() {
260          write_inst!(code, "pop rfp");
261          for arg in &f.parameters {
262              match regs.get(arg).unwrap() {
263                  StoreLoc::Register(i) => write_inst!(code, "pop r{i}"),
264                  StoreLoc::Constant(_) => unreachable!(),
265                  // e @ StoreLoc::Stack(_) => todo!("function argument got assigned {e:?}"),
266              }
267          }
268          write_inst!(code, "push rfp");
269      }
270      if stack_size != 0 {
271          write_inst!(code, "sub rsp, {}", stack_size);
272      }
273      let mut indices: Set<BlockId> = f.blocks.keys().copied().collect();
274      let mut i = BlockId::ENTRY;
275      loop {
276          use StoreKind as Sk;
277          assert!(indices.remove(&i));
278          let block = f.blocks.get(&i).unwrap();
279          write_label!(code, "{function_name}_{}", i.0);
280          for (inst_i, inst) in block.insts.iter().enumerate() {
281              match inst {
282                  Inst::Store(r, sk) => {
283                      let size = Size::of_ty(&f.tys[r]);
284                      let reg = regs.get(r).unwrap();
285                      if matches!(reg, StoreLoc::Constant(_)) {
286                          continue;
287                      }
288                      match sk {
289                          Sk::StackAlloc(_) => {
290                              write_inst!(code, "mov {}, rsp", reg.foo());
291                              let stack_offset = *local_locs.get(r).unwrap();
292                              if stack_offset != 0 {
293                                  write_inst!(code, "add {}, {}", reg.foo(), stack_offset);
294                              }
295                          }
296                          Sk::Copy(src) => {
297                              let src_reg = regs.get(src).unwrap();
298                              if reg != src_reg {
299                                  write_inst!(code, "movz{size} {}, {}", reg.foo(), src_reg.foo());
300                              }
301                          }
302                          Sk::Read(src) => {
303                              let inner_size = Size::of_inner(&f.tys[src]);
304                              let src_reg = regs.get(src).unwrap();
305                              write_inst!(
306                                  code,
307                                  "movz{inner_size} {}, [{}]",
308                                  reg.foo(),
309                                  src_reg.bar()
310                              );
311                          }
312                          Sk::UnaryOp(op, inner) => {
313                              let inner_reg = regs.get(inner).unwrap();
314                              let op_mnemonic = match op {
315                                  UnaryOp::Neg => "neg",
316                              };
317                              if reg != inner_reg {
318                                  write_inst!(code, "mov {}, {}", reg.foo(), inner_reg.foo());
319                              }
320                              write_inst!(code, "{}{size} {}", op_mnemonic, reg.foo());
321                          }
322                          Sk::BinOp(op, lhs, rhs) => {
323                              let lhs_reg = regs.get(lhs).unwrap();
324                              let rhs_reg = regs.get(rhs).unwrap();
325                              let arithmetic = |mnemonic| {
326                                  Box::new(move |code: &mut String| {
327                                      if reg != lhs_reg {
328                                          write_inst!(*code, "mov {}, {}", reg.foo(), lhs_reg.foo());
329                                      }
330                                      write_inst!(
331                                          *code,
332                                          "{mnemonic}{size} {}, {}",
333                                          reg.foo(),
334                                          rhs_reg.foo(),
335                                      );
336                                  }) as Box<dyn Fn(&mut String)>
337                              };
338                              let comparison = |condition| {
339                                  Box::new(move |code: &mut String| {
340                                      write_inst!(
341                                          *code,
342                                          "cmp{size} {}, {}",
343                                          lhs_reg.foo(),
344                                          rhs_reg.foo()
345                                      );
346                                      // NOTE: This `mov` comes after the comparison because `reg` might be the same as `lhs_reg` or `rhs_reg` and we don't want to overwrite the value before the comparison.
347                                      write_inst!(*code, "mov {}, 0", reg.foo());
348                                      write_inst!(*code, "{condition} mov {}, 1", reg.foo());
349                                  })
350                              };
351                              let compile = match op {
352                                  BinOp::Add => arithmetic("add"),
353                                  BinOp::Sub => arithmetic("sub"),
354                                  BinOp::Mul => arithmetic("mul"),
355                                  BinOp::CmpLe => comparison("iflteq"),
356                              };
357                              compile(&mut code);
358                          }
359                          &Sk::IntCast(inner, _kind) => {
360                              let inner_reg = &regs[&inner];
361                              // this relies on how we store smaller-than-word types in registers
362                              write_inst!(code, "mov {}, {}", reg.foo(), inner_reg.foo());
363                          }
364                          Sk::PtrOffset(lhs, rhs) => {
365                              let stride = Size::of_inner(&f.tys[lhs]).in_bytes();
366                              let lhs_reg = regs.get(lhs).unwrap();
367                              let rhs_reg = regs.get(rhs).unwrap();
368                              if reg != lhs_reg {
369                                  write_inst!(code, "mov {}, {}", reg.foo(), lhs_reg.foo());
370                              }
371                              write_inst!(code, "mov {}, {stride}", TEMP_REG.foo());
372                              write_inst!(code, "mul {}, {}", TEMP_REG.foo(), rhs_reg.foo());
373                              write_inst!(code, "add {}, {}", reg.foo(), TEMP_REG.foo());
374                          }
375                          Sk::FieldOffset(ptr, accessed_field) => {
376                              let Ty::Pointer(value) = &f.tys[ptr] else {
377                                  unreachable!("field offset");
378                              };
379                              let Ty::Struct(fields) = value.as_ref() else {
380                                  unreachable!("field offset");
381                              };
382                              let mut offset: u32 = 0;
383                              for (field_name, field_ty) in fields {
384                                  if field_name == accessed_field {
385                                      break;
386                                  }
387                                  offset += Size::of_in_bytes(field_ty);
388                              }
389                              let ptr_reg = &regs[ptr];
390                              if reg != ptr_reg {
391                                  write_inst!(code, "mov {}, {}", reg.foo(), ptr_reg.foo());
392                              }
393                              write_inst!(code, "add {}, {offset}", reg.foo());
394                          }
395                          Sk::Int(..) | Sk::Function(_) => unreachable!(
396                              "register store should have been optimized as a constant literal"
397                          ),
398                          Sk::Phi(_) => (),
399                          // _ => write_comment!(code, "TODO: {inst:?}"),
400                      }
401                  }
402                  Inst::Write(dst, src) => {
403                      let inner_size = Size::of_inner(&f.tys[dst]);
404                      let dst_reg = regs.get(dst).unwrap();
405                      let src_reg = regs.get(src).unwrap();
406                      write_inst!(
407                          code,
408                          "movz{inner_size} [{}], {}",
409                          dst_reg.bar(),
410                          src_reg.foo()
411                      );
412                  }
413                  Inst::Call {
414                      callee,
415                      returns,
416                      args,
417                  } => {
418                      let saved: Set<_> = {
419                          let mut saved: Set<_> = liveness.blocks[&i].insts[inst_i]
420                              .iter()
421                              .map(|r| regs.get(r).unwrap())
422                              .collect();
423                          for r in returns {
424                              saved.remove(regs.get(r).unwrap());
425                          }
426                          saved
427                      };
428                      for &r in saved.iter().rev() {
429                          match r {
430                              StoreLoc::Register(_) => write_inst!(code, "push {}", r.foo()),
431                              StoreLoc::Constant(_) => {}
432                          }
433                      }
434                      for r in args.iter().rev() {
435                          let reg = regs.get(r).unwrap();
436                          write_inst!(code, "push {}", reg.foo());
437                      }
438                      let callee_reg = regs.get(callee).unwrap();
439                      write_inst!(code, "call {}", callee_reg.foo());
440                      for r in returns {
441                          let reg = regs.get(r).unwrap();
442                          write_inst!(code, "pop {}", reg.foo());
443                      }
444                      for r in saved {
445                          match r {
446                              StoreLoc::Register(_) => write_inst!(code, "pop {}", r.foo()),
447                              StoreLoc::Constant(_) => {}
448                          }
449                      }
450                  }
451                  Inst::Nop => write_inst!(code, "nop"),
452              }
453          }
454          let merge_phis = |code: &mut String, jump_index, prefix| {
455              let jump_block = f.blocks.get(&jump_index).unwrap();
456              for inst in &jump_block.insts {
457                  let Inst::Store(dst, Sk::Phi(srcs)) = inst else {
458                      continue;
459                  };
460                  let src = srcs.get(&i).unwrap();
461                  let dst_reg = regs.get(dst).unwrap().foo();
462                  let src_reg = regs.get(src).unwrap().foo();
463                  write_inst!(*code, "{prefix}mov {dst_reg}, {src_reg}");
464              }
465          };
466          let next_i = match &block.exit {
467              Exit::Jump(loc) => {
468                  merge_phis(&mut code, *loc, "");
469                  if indices.contains(loc) {
470                      Some(loc)
471                  } else {
472                      write_inst!(code, "jmp {function_name}_{}", loc.0);
473                      None
474                  }
475              }
476              Exit::CondJump(cond, branch_true, branch_false) => {
477                  match cond {
478                      Condition::NonZero(r) => {
479                          let reg = regs.get(r).unwrap();
480                          write_inst!(code, "cmp {}, 0", reg.foo());
481                      }
482                  }
483                  let next_true = {
484                      merge_phis(&mut code, *branch_true, "ifnz ");
485                      if indices.contains(branch_true) {
486                          Some(branch_true)
487                      } else {
488                          write_inst!(code, "ifnz jmp {function_name}_{}", branch_true.0);
489                          None
490                      }
491                  };
492                  let next_false = {
493                      merge_phis(&mut code, *branch_false, "ifz ");
494                      if next_true.is_none() && indices.contains(branch_false) {
495                          Some(branch_false)
496                      } else {
497                          write_inst!(code, "ifz jmp {function_name}_{}", branch_false.0);
498                          None
499                      }
500                  };
501                  next_true.or(next_false)
502              }
503              Exit::Return(returns) => {
504                  if stack_size != 0 {
505                      write_inst!(code, "add rsp, {stack_size}");
506                  }
507                  if returns.is_empty() {
508                      write_inst!(code, "ret");
509                  } else {
510                      write_inst!(code, "pop rfp");
511                      for r in returns {
512                          let r_reg = regs.get(r).unwrap().foo();
513                          write_inst!(code, "push {r_reg}");
514                      }
515                      write_inst!(code, "jmp rfp");
516                  }
517                  None
518              }
519          };
520          // obviously bad 2 lines of code
521          if next_i.is_some() && indices.contains(next_i.unwrap()) {
522              i = *next_i.unwrap();
523          } else {
524              match indices.iter().next() {
525                  Some(&next_i) => i = next_i,
526                  None => break,
527              }
528          }
529      }
530      code
531  }