//! codegen_fox32.rs — code generation backend targeting the fox32 architecture.
use crate::compiler_types::{Map, Set, Str};
use crate::ir::*;
use crate::ir_liveness::{self, FunctionLiveness};

/// Operand size of a value as emitted in fox32 assembly.
/// Only scalar sizes exist here; structs have no single `Size`
/// (see `of_ty_or`, which reports their total byte width instead).
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum Size {
    /// 8 bits
    Byte,
    /// 32 bits
    Word,
}

impl Size {
    /// Size of a scalar type. Panics if `ty` is a struct, which must never
    /// reach codegen as a direct register value.
    fn of_ty(ty: &Ty) -> Self {
        Self::of_ty_or(ty)
            .unwrap_or_else(|_| unreachable!("struct type {ty} encountered during codegen"))
    }
    /// `Ok(size)` for scalar types; for a struct, `Err(total byte size)`
    /// computed as the sum of its field sizes (no padding is added).
    fn of_ty_or(ty: &Ty) -> Result<Self, u32> {
        match ty {
            Ty::Int(IntKind::U8) => Ok(Self::Byte),
            Ty::Int(IntKind::Usize) | Ty::Pointer(_) | Ty::Function(..) => Ok(Self::Word),
            Ty::Struct(fields) => Err(fields.iter().map(|(_, ty)| Self::of_in_bytes(ty)).sum()),
        }
    }
    /// Size of the pointee of a pointer type; panics on non-pointers.
    /// Used for loads/stores through pointers (`Read`/`Write`/`PtrOffset`).
    fn of_inner(ty: &Ty) -> Self {
        match ty {
            Ty::Pointer(inner) => Self::of_ty(inner),
            _ => unreachable!("accessing inner type of non-pointer type `{ty}`"),
        }
    }
    const fn in_bytes(self) -> u32 {
        match self {
            Self::Byte => 1,
            Self::Word => 4,
        }
    }
    /// Byte width of any type, scalar or struct.
    fn of_in_bytes(ty: &Ty) -> u32 {
        // Err already carries the struct's byte size, so both arms are bytes.
        Self::of_ty_or(ty).map_or_else(|x| x, Self::in_bytes)
    }
}

impl std::fmt::Display for Size {
    /// Renders the assembler size suffix appended to mnemonics
    /// (e.g. `movz.8` vs plain `movz`).
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let dot_size = match self {
            Self::Byte => ".8",
            Self::Word => "", // implicitly ".32" to the assembler
        };
        write!(f, "{dot_size}")
    }
}

/// Number of physical registers available to the allocator (r0..r30).
const NUM_REGISTERS: usize = 31;
/// r31 is held back from allocation and used as scratch (see `PtrOffset`).
const TEMP_REG: StoreLoc = StoreLoc::Register(31);

/// Where a virtual (IR) register lives at runtime: a physical register, or a
/// compile-time constant folded directly into the operand text.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum StoreLoc {
    Register(u8),
    Constant(Str),
    // Stack(u32),
}

impl StoreLoc {
    /// Renders this location as an assembly operand ("r5" or the constant text).
    // NOTE(review): `foo`/`bar` are placeholder names; currently their bodies
    // are identical, so the "stricter" contract of `bar` is not enforced yet.
    pub fn foo(&self) -> Str {
        match self {
            Self::Register(i) => format!("r{i}").into(),
            Self::Constant(c) => c.clone(),
        }
    }
    // stricter, must be syntactically dereferenceable
    // Callers wrap the result in `[...]` (see `Read`/`Write` emission).
    pub fn bar(&self) -> Str {
        match self {
            Self::Register(i) => format!("r{i}").into(),
            Self::Constant(c) => c.clone(),
        }
    }
}

/// Undirected interference graph over IR registers: an edge (a, b) means the
/// two registers are live at the same time and must not share a location.
/// Each edge is stored once, keyed by the smaller register (see `minmax`).
#[derive(Clone, Debug, Default)]
struct LivenessGraph {
    regs: Map<Register, Set<Register>>,
}

impl LivenessGraph {
    pub fn new() -> Self {
        Self::default()
    }
    /// Builds the graph by marking every pair of registers that appear
    /// together in any liveness set (block entry sets and per-instruction sets).
    pub fn from_function_liveness(live: &FunctionLiveness) -> Self {
        let mut this = Self::new();
        live.blocks
            .values()
            .flat_map(|block_live| std::iter::once(&block_live.start).chain(&block_live.insts))
            .for_each(|set| this.insert_set(set));
        this
    }
    /// Canonical (min, max) ordering for an edge; self-edges are a bug.
    pub fn minmax(a: Register, b: Register) -> (Register, Register) {
        let min = a.min(b);
        let max = a.max(b);
        assert!(min != max, "inserted same register {a} {b}");
        (min, max)
    }
    /// True if `a` and `b` interfere.
    pub fn get(&self, a: Register, b: Register) -> bool {
        let (min, max) = Self::minmax(a, b);
        self.regs.get(&min).map_or(false, |s| s.contains(&max))
    }
    pub fn insert(&mut self, a: Register, b: Register) {
        let (min, max) = Self::minmax(a, b);
        self.regs.entry(min).or_default().insert(max);
    }
    /// Inserts an edge for every unordered pair in `set` (all simultaneously live).
    pub fn insert_set(&mut self, set: &Set<Register>) {
        for (i, &min) in set.iter().enumerate() {
            for &max in set.iter().skip(i + 1) {
                self.insert(min, max);
            }
        }
    }
}

/// Result of register allocation for one function.
#[derive(Clone, Debug)]
struct RegAllocInfo {
    /// Assigned location for every IR register.
    pub regs: Map<Register, StoreLoc>,
    /// Stack-frame byte offset for each `StackAlloc`ed local.
    pub local_locs: Map<Register, u32>,
    /// Total bytes to reserve below rsp for locals.
    pub stack_size: u32,
    /// Liveness info, kept so codegen can compute caller-saved sets at calls.
    pub liveness: FunctionLiveness,
}

/// Greedy register allocation:
/// 1. Locals from `StackAlloc` get fixed stack slots (4 bytes each — a slot
///    holds the pointer-sized local; NOTE(review): assumes every local fits
///    in 4 bytes, confirm for struct locals).
/// 2. Registers defined by `Int`/`Function` stores become `Constant` operands
///    and need no physical register.
/// 3. Remaining registers are colored: each physical register is given to a
///    maximal set of pairwise non-interfering IR registers.
fn reg_alloc(f: &Function) -> RegAllocInfo {
    let liveness = ir_liveness::calculate_liveness(f);
    let live_graph = LivenessGraph::from_function_liveness(&liveness);
    let mut stack_size = 0;
    let local_locs: Map<Register, u32> = f
        .blocks
        .values()
        .flat_map(|b| &b.insts)
        .filter_map(|inst| match inst {
            Inst::Store(reg, StoreKind::StackAlloc(_ty)) => {
                let ret = Some((*reg, stack_size));
                stack_size += 4;
                ret
            }
            _ => None,
        })
        .collect();
    let mut regs: Map<Register, StoreLoc> = Map::new();
    // `open` = IR registers still awaiting a location.
    let mut open: Set<_> = f.tys.keys().copied().collect();
    for block in f.blocks.values() {
        for inst in &block.insts {
            let Inst::Store(r, sk) = inst else {
                continue;
            };
            let constant_str = match sk {
                // cast from i128 to u32 because fox32asm doesn't support negative int literals
                #[allow(clippy::cast_sign_loss)]
                &StoreKind::Int(i, kind) => match kind {
                    IntKind::Usize => (i as u32).to_string().into(),
                    IntKind::U8 => (i as u8).to_string().into(),
                },
                StoreKind::Function(name) => name.clone(),
                _ => continue,
            };
            open.remove(r);
            regs.insert(*r, StoreLoc::Constant(constant_str));
        }
    }
    let mut reg_counter = 0;
    while let Some(reg) = open.pop_first() {
        let store_loc = if reg_counter < NUM_REGISTERS {
            let x = StoreLoc::Register(reg_counter as u8);
            reg_counter += 1;
            x
        } else {
            todo!("stack spilling");
            // let x = StoreLoc::Stack(stack_size);
            // stack_size += 4;
            // x
        };
        regs.insert(reg, store_loc.clone());
        // Share this physical register with every remaining IR register that
        // interferes with none of the registers already sharing it.
        let mut shared = Set::new();
        shared.insert(reg);
        open.retain(|&fellow_reg| {
            if shared.iter().any(|&r| live_graph.get(r, fellow_reg)) {
                return true; // interferes — keep it open for a later color
            }
            shared.insert(fellow_reg);
            regs.insert(fellow_reg, store_loc.clone());
            false
        });
    }
    RegAllocInfo {
        regs,
        local_locs,
        stack_size,
        liveness,
    }
}

/// Appends `NAME:\n` to the output string.
macro_rules! write_label {
    ($dst:expr, $($arg:tt)*) => {{
        use ::std::fmt::Write;
        let w: &mut String = &mut $dst;
        write!(w, $($arg)*).unwrap();
        w.push_str(":\n");
    }}
}

/// Appends one indented instruction line to the output string.
macro_rules! write_inst {
    ($dst:expr, $($arg:tt)*) => {{
        use ::std::fmt::Write;
        let w: &mut String = &mut $dst;
        w.push_str(" ");
        write!(w, $($arg)*).unwrap();
        w.push('\n');
    }}
}

/// Appends a `; ...` assembly comment line (currently unused).
#[allow(unused_macros)]
macro_rules! write_comment {
    ($dst:expr, $($arg:tt)*) => {{
        use ::std::fmt::Write;
        let w: &mut String = &mut $dst;
        w.push_str("; ");
        write!(w, $($arg)*).unwrap();
        w.push('\n');
    }}
}

/// Generates fox32 assembly for a whole program, one function after another,
/// separated by blank lines.
pub fn gen_program(ir: &Program) -> String {
    let mut code = String::new();
    for (i, (name, f)) in ir.functions.iter().enumerate() {
        if i != 0 {
            code.push('\n');
        }
        let fn_output = gen_function(f, name);
        code.push_str(&fn_output);
    }
    code
}

/// Generates fox32 assembly for a single function.
///
/// Calling convention as implemented here: on entry the stack holds the return
/// address on top, then the arguments; the prologue pops the return address
/// into `rfp`, pops each argument into its assigned register, and pushes the
/// return address back. On return, the return address is popped into `rfp`,
/// return values are pushed, and control transfers via `jmp rfp`.
pub fn gen_function(f: &Function, function_name: &str) -> String {
    let mut code = String::new();
    let RegAllocInfo {
        regs,
        local_locs,
        stack_size,
        liveness,
    } = reg_alloc(f);
    {
        // NOTE(review): debug dump of the allocation (location -> IR registers)
        // to stderr; consider removing or gating before release.
        let locs: Set<_> = regs.values().collect();
        for loc in locs {
            eprint!("{}:", loc.foo());
            for (r, r_loc) in &regs {
                if loc == r_loc {
                    eprint!(" {r}");
                }
            }
            eprintln!();
        }
    }
    write_label!(code, "{function_name}");
    if !f.parameters.is_empty() {
        // Stash the return address so arguments can be popped off the stack.
        write_inst!(code, "pop rfp");
        for arg in &f.parameters {
            match regs.get(arg).unwrap() {
                StoreLoc::Register(i) => write_inst!(code, "pop r{i}"),
                StoreLoc::Constant(_) => unreachable!(),
                // e @ StoreLoc::Stack(_) => todo!("function argument got assigned {e:?}"),
            }
        }
        write_inst!(code, "push rfp");
    }
    if stack_size != 0 {
        // Reserve the frame for StackAlloc locals.
        write_inst!(code, "sub rsp, {}", stack_size);
    }
    // `indices` = blocks not yet emitted; blocks are laid out so that an
    // unemitted jump target can fall through without an explicit `jmp`.
    let mut indices: Set<BlockId> = f.blocks.keys().copied().collect();
    let mut i = BlockId::ENTRY;
    loop {
        use StoreKind as Sk;
        assert!(indices.remove(&i));
        let block = f.blocks.get(&i).unwrap();
        write_label!(code, "{function_name}_{}", i.0);
        for (inst_i, inst) in block.insts.iter().enumerate() {
            match inst {
                Inst::Store(r, sk) => {
                    let size = Size::of_ty(&f.tys[r]);
                    let reg = regs.get(r).unwrap();
                    // Constants were folded into operands; nothing to emit.
                    if matches!(reg, StoreLoc::Constant(_)) {
                        continue;
                    }
                    match sk {
                        Sk::StackAlloc(_) => {
                            // Materialize the local's address: rsp + slot offset.
                            write_inst!(code, "mov {}, rsp", reg.foo());
                            let stack_offset = *local_locs.get(r).unwrap();
                            if stack_offset != 0 {
                                write_inst!(code, "add {}, {}", reg.foo(), stack_offset);
                            }
                        }
                        Sk::Copy(src) => {
                            let src_reg = regs.get(src).unwrap();
                            // Skip no-op moves between registers sharing a location.
                            if reg != src_reg {
                                write_inst!(code, "movz{size} {}, {}", reg.foo(), src_reg.foo());
                            }
                        }
                        Sk::Read(src) => {
                            // Load through pointer; width comes from the pointee type.
                            let inner_size = Size::of_inner(&f.tys[src]);
                            let src_reg = regs.get(src).unwrap();
                            write_inst!(
                                code,
                                "movz{inner_size} {}, [{}]",
                                reg.foo(),
                                src_reg.bar()
                            );
                        }
                        Sk::UnaryOp(op, inner) => {
                            let inner_reg = regs.get(inner).unwrap();
                            let op_mnemonic = match op {
                                UnaryOp::Neg => "neg",
                            };
                            // fox32 ops are in-place: copy operand in first if needed.
                            if reg != inner_reg {
                                write_inst!(code, "mov {}, {}", reg.foo(), inner_reg.foo());
                            }
                            write_inst!(code, "{}{size} {}", op_mnemonic, reg.foo());
                        }
                        Sk::BinOp(op, lhs, rhs) => {
                            let lhs_reg = regs.get(lhs).unwrap();
                            let rhs_reg = regs.get(rhs).unwrap();
                            // Two emission strategies, chosen per operator below:
                            // in-place arithmetic vs. compare-then-conditional-set.
                            let arithmetic = |mnemonic| {
                                Box::new(move |code: &mut String| {
                                    if reg != lhs_reg {
                                        write_inst!(*code, "mov {}, {}", reg.foo(), lhs_reg.foo());
                                    }
                                    write_inst!(
                                        *code,
                                        "{mnemonic}{size} {}, {}",
                                        reg.foo(),
                                        rhs_reg.foo(),
                                    );
                                }) as Box<dyn Fn(&mut String)>
                            };
                            let comparison = |condition| {
                                Box::new(move |code: &mut String| {
                                    write_inst!(
                                        *code,
                                        "cmp{size} {}, {}",
                                        lhs_reg.foo(),
                                        rhs_reg.foo()
                                    );
                                    // NOTE: This `mov` comes after the comparison because `reg` might be the same as `lhs_reg` or `rhs_reg` and we don't want to overwrite the value before the comparison.
                                    write_inst!(*code, "mov {}, 0", reg.foo());
                                    write_inst!(*code, "{condition} mov {}, 1", reg.foo());
                                })
                            };
                            let compile = match op {
                                BinOp::Add => arithmetic("add"),
                                BinOp::Sub => arithmetic("sub"),
                                BinOp::Mul => arithmetic("mul"),
                                BinOp::CmpLe => comparison("iflteq"),
                            };
                            compile(&mut code);
                        }
                        &Sk::IntCast(inner, _kind) => {
                            let inner_reg = &regs[&inner];
                            // this relies on how we store smaller-than-word types in registers
                            write_inst!(code, "mov {}, {}", reg.foo(), inner_reg.foo());
                        }
                        Sk::PtrOffset(lhs, rhs) => {
                            // reg = lhs + rhs * sizeof(pointee), via the scratch register.
                            let stride = Size::of_inner(&f.tys[lhs]).in_bytes();
                            let lhs_reg = regs.get(lhs).unwrap();
                            let rhs_reg = regs.get(rhs).unwrap();
                            if reg != lhs_reg {
                                write_inst!(code, "mov {}, {}", reg.foo(), lhs_reg.foo());
                            }
                            write_inst!(code, "mov {}, {stride}", TEMP_REG.foo());
                            write_inst!(code, "mul {}, {}", TEMP_REG.foo(), rhs_reg.foo());
                            write_inst!(code, "add {}, {}", reg.foo(), TEMP_REG.foo());
                        }
                        Sk::FieldOffset(ptr, accessed_field) => {
                            let Ty::Pointer(value) = &f.tys[ptr] else {
                                unreachable!("field offset");
                            };
                            let Ty::Struct(fields) = value.as_ref() else {
                                unreachable!("field offset");
                            };
                            // Byte offset = sum of sizes of the fields preceding
                            // the accessed one (no padding/alignment applied).
                            let mut offset: u32 = 0;
                            for (field_name, field_ty) in fields {
                                if field_name == accessed_field {
                                    break;
                                }
                                offset += Size::of_in_bytes(field_ty);
                            }
                            let ptr_reg = &regs[ptr];
                            if reg != ptr_reg {
                                write_inst!(code, "mov {}, {}", reg.foo(), ptr_reg.foo());
                            }
                            write_inst!(code, "add {}, {offset}", reg.foo());
                        }
                        Sk::Int(..) | Sk::Function(_) => unreachable!(
                            "register store should have been optimized as a constant literal"
                        ),
                        // Phis are handled at the predecessor's exit (see merge_phis).
                        Sk::Phi(_) => (),
                        // _ => write_comment!(code, "TODO: {inst:?}"),
                    }
                }
                Inst::Write(dst, src) => {
                    // Store through pointer; width comes from the pointee type.
                    let inner_size = Size::of_inner(&f.tys[dst]);
                    let dst_reg = regs.get(dst).unwrap();
                    let src_reg = regs.get(src).unwrap();
                    write_inst!(
                        code,
                        "movz{inner_size} [{}], {}",
                        dst_reg.bar(),
                        src_reg.foo()
                    );
                }
                Inst::Call {
                    callee,
                    returns,
                    args,
                } => {
                    // Caller-saved set: every location live at this instruction,
                    // except those about to be overwritten by the call's returns.
                    let saved: Set<_> = {
                        let mut saved: Set<_> = liveness.blocks[&i].insts[inst_i]
                            .iter()
                            .map(|r| regs.get(r).unwrap())
                            .collect();
                        for r in returns {
                            saved.remove(regs.get(r).unwrap());
                        }
                        saved
                    };
                    // Save in reverse so the later restore (in forward order) matches.
                    for &r in saved.iter().rev() {
                        match r {
                            StoreLoc::Register(_) => write_inst!(code, "push {}", r.foo()),
                            StoreLoc::Constant(_) => {}
                        }
                    }
                    // Args pushed in reverse so the callee pops them in order.
                    for r in args.iter().rev() {
                        let reg = regs.get(r).unwrap();
                        write_inst!(code, "push {}", reg.foo());
                    }
                    let callee_reg = regs.get(callee).unwrap();
                    write_inst!(code, "call {}", callee_reg.foo());
                    for r in returns {
                        let reg = regs.get(r).unwrap();
                        write_inst!(code, "pop {}", reg.foo());
                    }
                    for r in saved {
                        match r {
                            StoreLoc::Register(_) => write_inst!(code, "pop {}", r.foo()),
                            StoreLoc::Constant(_) => {}
                        }
                    }
                }
                Inst::Nop => write_inst!(code, "nop"),
            }
        }
        // Emit the moves implied by phi nodes in the jump target, each prefixed
        // with a condition ("", "ifnz ", "ifz ") so they only run on the taken edge.
        let merge_phis = |code: &mut String, jump_index, prefix| {
            let jump_block = f.blocks.get(&jump_index).unwrap();
            for inst in &jump_block.insts {
                let Inst::Store(dst, Sk::Phi(srcs)) = inst else {
                    continue;
                };
                let src = srcs.get(&i).unwrap();
                let dst_reg = regs.get(dst).unwrap().foo();
                let src_reg = regs.get(src).unwrap().foo();
                write_inst!(*code, "{prefix}mov {dst_reg}, {src_reg}");
            }
        };
        // `next_i` = a block we can fall through to (emit next, no jump needed).
        let next_i = match &block.exit {
            Exit::Jump(loc) => {
                merge_phis(&mut code, *loc, "");
                if indices.contains(loc) {
                    Some(loc)
                } else {
                    write_inst!(code, "jmp {function_name}_{}", loc.0);
                    None
                }
            }
            Exit::CondJump(cond, branch_true, branch_false) => {
                match cond {
                    Condition::NonZero(r) => {
                        let reg = regs.get(r).unwrap();
                        write_inst!(code, "cmp {}, 0", reg.foo());
                    }
                }
                let next_true = {
                    merge_phis(&mut code, *branch_true, "ifnz ");
                    if indices.contains(branch_true) {
                        Some(branch_true)
                    } else {
                        write_inst!(code, "ifnz jmp {function_name}_{}", branch_true.0);
                        None
                    }
                };
                let next_false = {
                    merge_phis(&mut code, *branch_false, "ifz ");
                    // Only one branch can be the fallthrough; true wins if both qualify.
                    if next_true.is_none() && indices.contains(branch_false) {
                        Some(branch_false)
                    } else {
                        write_inst!(code, "ifz jmp {function_name}_{}", branch_false.0);
                        None
                    }
                };
                next_true.or(next_false)
            }
            Exit::Return(returns) => {
                if stack_size != 0 {
                    // Pop the local frame before leaving.
                    write_inst!(code, "add rsp, {stack_size}");
                }
                if returns.is_empty() {
                    write_inst!(code, "ret");
                } else {
                    // Return address out, return values on the stack, jump back.
                    write_inst!(code, "pop rfp");
                    for r in returns {
                        let r_reg = regs.get(r).unwrap().foo();
                        write_inst!(code, "push {r_reg}");
                    }
                    write_inst!(code, "jmp rfp");
                }
                None
            }
        };
        // obviously bad 2 lines of code
        // Prefer the fallthrough block; otherwise pick any remaining block,
        // or stop once every block has been emitted.
        if next_i.is_some() && indices.contains(next_i.unwrap()) {
            i = *next_i.unwrap();
        } else {
            match indices.iter().next() {
                Some(&next_i) => i = next_i,
                None => break,
            }
        }
    }
    code
}