ir_desugar.rs
1 //! An IR stage deconstructing struct values into its individual fields. 2 3 // TODO: Desugar function types with struct arguments/returns. 4 5 use crate::compiler_types::{Map, Set, Str}; 6 use crate::ir::*; 7 8 type StructFields = Map<Register, (Ty, Vec<Str>)>; 9 10 fn make_struct_fields(fields: &[(Str, Ty)], next_register: &mut u128) -> StructFields { 11 fn visit( 12 this: &mut StructFields, 13 prefix: &[Str], 14 fields: &[(Str, Ty)], 15 next_register: &mut u128, 16 ) { 17 for (field, ty) in fields { 18 let mut path = prefix.to_vec(); 19 path.push(field.clone()); 20 match ty { 21 // Types that don't need desugaring 22 Ty::Int(_) | Ty::Pointer(_) | Ty::Function { .. } => { 23 let r = Register(*next_register); 24 *next_register += 1; 25 this.insert(r, (ty.clone(), path)); 26 } 27 Ty::Struct(child_fields) => { 28 visit(this, &path, child_fields, next_register); 29 } 30 } 31 } 32 } 33 let mut this = StructFields::new(); 34 visit(&mut this, &[], fields, next_register); 35 this 36 } 37 38 pub fn desugar_program(program: &mut Program) { 39 let Program { 40 functions, 41 function_tys, 42 } = program; 43 for f in functions.values_mut() { 44 desugar_function(f); 45 } 46 for (params, returns) in function_tys.values_mut() { 47 desugar_ty_vec(params); 48 desugar_ty_vec(returns); 49 } 50 } 51 52 pub fn desugar_function(f: &mut Function) { 53 let struct_regs: Map<Register, StructFields> = f 54 .tys 55 .iter() 56 .filter_map(|(&r, ty)| match ty { 57 Ty::Struct(fields) => Some((r, make_struct_fields(fields, &mut f.next_register))), 58 Ty::Int(_) | Ty::Pointer(_) | Ty::Function { .. } => None, 59 }) 60 .collect(); 61 for (r, fields) in &struct_regs { 62 eprintln!("{r}:"); 63 for (r, (ty, accesses)) in fields { 64 eprintln!(" {r}: {accesses:?} {ty}"); 65 } 66 } 67 let Function { 68 parameters, 69 blocks, 70 tys, 71 spans, 72 cfg: _, // does not contain registers 73 next_register, 74 } = f; 75 76 desugar_vec(&struct_regs, parameters); 77 for block in blocks.values_mut() { 78 desugar_block(&struct_regs, block, tys, next_register); 79 } 80 for (r, fields) in &struct_regs { 81 // NOTE: `desugar_block` relies on getting the type of `r` 82 tys.remove(r); 83 let span = spans.remove(r).unwrap(); 84 for (&field_r, (field_ty, _)) in fields { 85 tys.insert(field_r, field_ty.clone()); 86 spans.insert(field_r, span.clone()); 87 } 88 } 89 for ty in tys.values_mut() { 90 desugar_ty(ty); 91 } 92 } 93 94 fn desugar_block( 95 struct_regs: &Map<Register, StructFields>, 96 block: &mut Block, 97 tys: &mut Map<Register, Ty>, 98 next_register: &mut u128, 99 ) { 100 let Block { 101 insts, 102 exit, 103 defined_regs, 104 used_regs, 105 } = block; 106 for (r, fields) in struct_regs { 107 let desugar_set = |set: &mut Set<Register>| { 108 set.remove(r); 109 set.extend(fields.keys()); 110 }; 111 desugar_set(defined_regs); 112 desugar_set(used_regs); 113 } 114 // sanity check function: it would be a type error for this register to be a struct 115 let do_not_visit = |r: Register| { 116 assert!( 117 !struct_regs.contains_key(&r), 118 "found struct register {r} in condition" 119 ); 120 }; 121 match exit { 122 Exit::Jump(_) => {} 123 Exit::CondJump(cond, _, _) => match cond { 124 Condition::NonZero(r) => do_not_visit(*r), 125 }, 126 Exit::Return(regs) => desugar_vec(struct_regs, regs), 127 } 128 let mut i = 0; 129 while let Some(inst) = insts.get_mut(i) { 130 use StoreKind as Sk; 131 i += 1; 132 match inst { 133 Inst::Nop => {} 134 Inst::Call { 135 callee, 136 returns, 137 args, 138 } => { 139 do_not_visit(*callee); 140 desugar_vec(struct_regs, returns); 141 desugar_vec(struct_regs, args); 142 } 143 &mut Inst::Write(dst, src) => { 144 let Some(fields) = struct_regs.get(&src) else { 145 continue; 146 }; 147 i -= 1; 148 insts.remove(i); 149 for (&r, (_, accesses)) in fields { 150 let mut ptr = dst; 151 let mut ty = tys[&src].clone(); 152 for access in accesses { 153 ty = { 154 let Ty::Struct(fields) = ty else { 155 unreachable!(); 156 }; 157 fields 158 .into_iter() 159 .find_map(|(name, ty)| (&name == access).then_some(ty)) 160 .unwrap() 161 }; 162 let new_ptr = Register(*next_register); 163 *next_register += 1; 164 tys.insert(new_ptr, Ty::Pointer(Box::new(ty.clone()))); 165 insts.insert( 166 i, 167 Inst::Store(new_ptr, Sk::FieldOffset(ptr, access.clone())), 168 ); 169 i += 1; 170 ptr = new_ptr; 171 } 172 insts.insert(i, Inst::Write(ptr, r)); 173 i += 1; 174 } 175 } 176 &mut Inst::Store(r, ref mut sk) => { 177 if let Sk::StackAlloc(ty) = sk { 178 eprintln!("before: {ty}"); 179 desugar_ty(ty); 180 eprintln!("after: {ty}"); 181 } 182 let Some(fields) = struct_regs.get(&r) else { 183 continue; 184 }; 185 match sk { 186 // we could `do_not_visit` all of these, but that would be annoying 187 Sk::Int(..) 188 | Sk::IntCast(..) 189 | Sk::PtrOffset(..) 190 | Sk::FieldOffset(..) 191 | Sk::StackAlloc(_) 192 | Sk::Function(_) 193 | Sk::UnaryOp(UnaryOp::Neg, _) 194 | Sk::BinOp(BinOp::Add | BinOp::Mul | BinOp::Sub | BinOp::CmpLe, _, _) => { 195 unreachable!("illegal op on struct during destructuring: {inst:?}") 196 } 197 Sk::Copy(copied) => { 198 let copied_fields = &struct_regs[copied]; 199 i -= 1; 200 insts.remove(i); 201 let rs_to = fields.iter().map(|(&r, _)| r); 202 let rs_from = copied_fields.iter().map(|(&r, _)| r); 203 for (r_to, r_from) in rs_to.zip(rs_from) { 204 insts.insert(i, Inst::Store(r_to, Sk::Copy(r_from))); 205 i += 1; 206 } 207 } 208 Sk::Phi(_) => { 209 i -= 1; 210 // we need this to avoid double borrowing 211 let Inst::Store(_, Sk::Phi(preds)) = insts.remove(i) else { 212 unreachable!() 213 }; 214 for (&r, (_, name)) in fields { 215 let field_preds = preds 216 .iter() 217 .map(|(&k, v)| { 218 ( 219 k, 220 struct_regs[v] 221 .iter() 222 .find_map(|(&r2, (_, name2))| { 223 (name == name2).then_some(r2) 224 }) 225 .unwrap(), 226 ) 227 }) 228 .collect(); 229 insts.insert(i, Inst::Store(r, Sk::Phi(field_preds))); 230 i += 1; 231 } 232 } 233 &mut Sk::Read(src) => { 234 i -= 1; 235 insts.remove(i); 236 for (&r2, (_, accesses)) in fields { 237 let mut ptr = src; 238 let mut ty = tys[&r].clone(); 239 for access in accesses { 240 ty = { 241 let Ty::Struct(fields) = ty else { 242 unreachable!(); 243 }; 244 fields 245 .into_iter() 246 .find_map(|(name, ty)| (&name == access).then_some(ty)) 247 .unwrap() 248 }; 249 let new_ptr = Register(*next_register); 250 *next_register += 1; 251 tys.insert(new_ptr, Ty::Pointer(Box::new(ty.clone()))); 252 insts.insert( 253 i, 254 Inst::Store(new_ptr, Sk::FieldOffset(ptr, access.clone())), 255 ); 256 i += 1; 257 ptr = new_ptr; 258 } 259 insts.insert(i, Inst::Store(r2, Sk::Read(ptr))); 260 i += 1; 261 } 262 } 263 } 264 } 265 } 266 } 267 } 268 269 fn desugar_ty(ty: &mut Ty) { 270 match ty { 271 Ty::Int(_) => {} 272 Ty::Pointer(ty) => desugar_ty(ty), 273 Ty::Function(params, returns) => { 274 desugar_ty_vec(params); 275 desugar_ty_vec(returns); 276 } 277 Ty::Struct(fields) => { 278 for (_, ty) in fields { 279 desugar_ty(ty); 280 } 281 } 282 } 283 } 284 285 fn desugar_ty_vec(tys: &mut Vec<Ty>) { 286 let mut i = 0; 287 while let Some(ty) = tys.get_mut(i) { 288 i += 1; 289 match ty { 290 Ty::Int(_) => {} 291 Ty::Pointer(ty) => desugar_ty(ty), 292 Ty::Function(params, returns) => { 293 desugar_ty_vec(params); 294 desugar_ty_vec(returns); 295 } 296 Ty::Struct(_) => { 297 i -= 1; 298 let old_i = i; 299 let Ty::Struct(fields) = tys.remove(i) else { 300 unreachable!(); 301 }; 302 for (_, ty) in fields { 303 tys.insert(i, ty); 304 i += 1; 305 } 306 i = old_i; // yeah this sucks 307 } 308 } 309 } 310 } 311 312 fn desugar_vec(struct_regs: &Map<Register, StructFields>, regs: &mut Vec<Register>) { 313 // we could probably write some unsafe code here if this becomes a bottleneck 314 let mut i = 0; 315 while i < regs.len() { 316 if let Some(fields) = struct_regs.get(®s[i]) { 317 regs.remove(i); 318 for ® in fields.keys() { 319 regs.insert(i, reg); 320 i += 1; 321 } 322 } else { 323 i += 1; 324 } 325 } 326 }