/ src / ir_desugar.rs
ir_desugar.rs
//! An IR stage deconstructing struct values into their individual fields.
  2  
  3  // TODO: Desugar function types with struct arguments/returns.
  4  
  5  use crate::compiler_types::{Map, Set, Str};
  6  use crate::ir::*;
  7  
  8  type StructFields = Map<Register, (Ty, Vec<Str>)>;
  9  
 10  fn make_struct_fields(fields: &[(Str, Ty)], next_register: &mut u128) -> StructFields {
 11      fn visit(
 12          this: &mut StructFields,
 13          prefix: &[Str],
 14          fields: &[(Str, Ty)],
 15          next_register: &mut u128,
 16      ) {
 17          for (field, ty) in fields {
 18              let mut path = prefix.to_vec();
 19              path.push(field.clone());
 20              match ty {
 21                  // Types that don't need desugaring
 22                  Ty::Int(_) | Ty::Pointer(_) | Ty::Function { .. } => {
 23                      let r = Register(*next_register);
 24                      *next_register += 1;
 25                      this.insert(r, (ty.clone(), path));
 26                  }
 27                  Ty::Struct(child_fields) => {
 28                      visit(this, &path, child_fields, next_register);
 29                  }
 30              }
 31          }
 32      }
 33      let mut this = StructFields::new();
 34      visit(&mut this, &[], fields, next_register);
 35      this
 36  }
 37  
 38  pub fn desugar_program(program: &mut Program) {
 39      let Program {
 40          functions,
 41          function_tys,
 42      } = program;
 43      for f in functions.values_mut() {
 44          desugar_function(f);
 45      }
 46      for (params, returns) in function_tys.values_mut() {
 47          desugar_ty_vec(params);
 48          desugar_ty_vec(returns);
 49      }
 50  }
 51  
 52  pub fn desugar_function(f: &mut Function) {
 53      let struct_regs: Map<Register, StructFields> = f
 54          .tys
 55          .iter()
 56          .filter_map(|(&r, ty)| match ty {
 57              Ty::Struct(fields) => Some((r, make_struct_fields(fields, &mut f.next_register))),
 58              Ty::Int(_) | Ty::Pointer(_) | Ty::Function { .. } => None,
 59          })
 60          .collect();
 61      for (r, fields) in &struct_regs {
 62          eprintln!("{r}:");
 63          for (r, (ty, accesses)) in fields {
 64              eprintln!("    {r}: {accesses:?} {ty}");
 65          }
 66      }
 67      let Function {
 68          parameters,
 69          blocks,
 70          tys,
 71          spans,
 72          cfg: _, // does not contain registers
 73          next_register,
 74      } = f;
 75  
 76      desugar_vec(&struct_regs, parameters);
 77      for block in blocks.values_mut() {
 78          desugar_block(&struct_regs, block, tys, next_register);
 79      }
 80      for (r, fields) in &struct_regs {
 81          // NOTE: `desugar_block` relies on getting the type of `r`
 82          tys.remove(r);
 83          let span = spans.remove(r).unwrap();
 84          for (&field_r, (field_ty, _)) in fields {
 85              tys.insert(field_r, field_ty.clone());
 86              spans.insert(field_r, span.clone());
 87          }
 88      }
 89      for ty in tys.values_mut() {
 90          desugar_ty(ty);
 91      }
 92  }
 93  
 94  fn desugar_block(
 95      struct_regs: &Map<Register, StructFields>,
 96      block: &mut Block,
 97      tys: &mut Map<Register, Ty>,
 98      next_register: &mut u128,
 99  ) {
100      let Block {
101          insts,
102          exit,
103          defined_regs,
104          used_regs,
105      } = block;
106      for (r, fields) in struct_regs {
107          let desugar_set = |set: &mut Set<Register>| {
108              set.remove(r);
109              set.extend(fields.keys());
110          };
111          desugar_set(defined_regs);
112          desugar_set(used_regs);
113      }
114      // sanity check function: it would be a type error for this register to be a struct
115      let do_not_visit = |r: Register| {
116          assert!(
117              !struct_regs.contains_key(&r),
118              "found struct register {r} in condition"
119          );
120      };
121      match exit {
122          Exit::Jump(_) => {}
123          Exit::CondJump(cond, _, _) => match cond {
124              Condition::NonZero(r) => do_not_visit(*r),
125          },
126          Exit::Return(regs) => desugar_vec(struct_regs, regs),
127      }
128      let mut i = 0;
129      while let Some(inst) = insts.get_mut(i) {
130          use StoreKind as Sk;
131          i += 1;
132          match inst {
133              Inst::Nop => {}
134              Inst::Call {
135                  callee,
136                  returns,
137                  args,
138              } => {
139                  do_not_visit(*callee);
140                  desugar_vec(struct_regs, returns);
141                  desugar_vec(struct_regs, args);
142              }
143              &mut Inst::Write(dst, src) => {
144                  let Some(fields) = struct_regs.get(&src) else {
145                      continue;
146                  };
147                  i -= 1;
148                  insts.remove(i);
149                  for (&r, (_, accesses)) in fields {
150                      let mut ptr = dst;
151                      let mut ty = tys[&src].clone();
152                      for access in accesses {
153                          ty = {
154                              let Ty::Struct(fields) = ty else {
155                                  unreachable!();
156                              };
157                              fields
158                                  .into_iter()
159                                  .find_map(|(name, ty)| (&name == access).then_some(ty))
160                                  .unwrap()
161                          };
162                          let new_ptr = Register(*next_register);
163                          *next_register += 1;
164                          tys.insert(new_ptr, Ty::Pointer(Box::new(ty.clone())));
165                          insts.insert(
166                              i,
167                              Inst::Store(new_ptr, Sk::FieldOffset(ptr, access.clone())),
168                          );
169                          i += 1;
170                          ptr = new_ptr;
171                      }
172                      insts.insert(i, Inst::Write(ptr, r));
173                      i += 1;
174                  }
175              }
176              &mut Inst::Store(r, ref mut sk) => {
177                  if let Sk::StackAlloc(ty) = sk {
178                      eprintln!("before: {ty}");
179                      desugar_ty(ty);
180                      eprintln!("after:  {ty}");
181                  }
182                  let Some(fields) = struct_regs.get(&r) else {
183                      continue;
184                  };
185                  match sk {
186                      // we could `do_not_visit` all of these, but that would be annoying
187                      Sk::Int(..)
188                      | Sk::IntCast(..)
189                      | Sk::PtrOffset(..)
190                      | Sk::FieldOffset(..)
191                      | Sk::StackAlloc(_)
192                      | Sk::Function(_)
193                      | Sk::UnaryOp(UnaryOp::Neg, _)
194                      | Sk::BinOp(BinOp::Add | BinOp::Mul | BinOp::Sub | BinOp::CmpLe, _, _) => {
195                          unreachable!("illegal op on struct during destructuring: {inst:?}")
196                      }
197                      Sk::Copy(copied) => {
198                          let copied_fields = &struct_regs[copied];
199                          i -= 1;
200                          insts.remove(i);
201                          let rs_to = fields.iter().map(|(&r, _)| r);
202                          let rs_from = copied_fields.iter().map(|(&r, _)| r);
203                          for (r_to, r_from) in rs_to.zip(rs_from) {
204                              insts.insert(i, Inst::Store(r_to, Sk::Copy(r_from)));
205                              i += 1;
206                          }
207                      }
208                      Sk::Phi(_) => {
209                          i -= 1;
210                          // we need this to avoid double borrowing
211                          let Inst::Store(_, Sk::Phi(preds)) = insts.remove(i) else {
212                              unreachable!()
213                          };
214                          for (&r, (_, name)) in fields {
215                              let field_preds = preds
216                                  .iter()
217                                  .map(|(&k, v)| {
218                                      (
219                                          k,
220                                          struct_regs[v]
221                                              .iter()
222                                              .find_map(|(&r2, (_, name2))| {
223                                                  (name == name2).then_some(r2)
224                                              })
225                                              .unwrap(),
226                                      )
227                                  })
228                                  .collect();
229                              insts.insert(i, Inst::Store(r, Sk::Phi(field_preds)));
230                              i += 1;
231                          }
232                      }
233                      &mut Sk::Read(src) => {
234                          i -= 1;
235                          insts.remove(i);
236                          for (&r2, (_, accesses)) in fields {
237                              let mut ptr = src;
238                              let mut ty = tys[&r].clone();
239                              for access in accesses {
240                                  ty = {
241                                      let Ty::Struct(fields) = ty else {
242                                          unreachable!();
243                                      };
244                                      fields
245                                          .into_iter()
246                                          .find_map(|(name, ty)| (&name == access).then_some(ty))
247                                          .unwrap()
248                                  };
249                                  let new_ptr = Register(*next_register);
250                                  *next_register += 1;
251                                  tys.insert(new_ptr, Ty::Pointer(Box::new(ty.clone())));
252                                  insts.insert(
253                                      i,
254                                      Inst::Store(new_ptr, Sk::FieldOffset(ptr, access.clone())),
255                                  );
256                                  i += 1;
257                                  ptr = new_ptr;
258                              }
259                              insts.insert(i, Inst::Store(r2, Sk::Read(ptr)));
260                              i += 1;
261                          }
262                      }
263                  }
264              }
265          }
266      }
267  }
268  
269  fn desugar_ty(ty: &mut Ty) {
270      match ty {
271          Ty::Int(_) => {}
272          Ty::Pointer(ty) => desugar_ty(ty),
273          Ty::Function(params, returns) => {
274              desugar_ty_vec(params);
275              desugar_ty_vec(returns);
276          }
277          Ty::Struct(fields) => {
278              for (_, ty) in fields {
279                  desugar_ty(ty);
280              }
281          }
282      }
283  }
284  
285  fn desugar_ty_vec(tys: &mut Vec<Ty>) {
286      let mut i = 0;
287      while let Some(ty) = tys.get_mut(i) {
288          i += 1;
289          match ty {
290              Ty::Int(_) => {}
291              Ty::Pointer(ty) => desugar_ty(ty),
292              Ty::Function(params, returns) => {
293                  desugar_ty_vec(params);
294                  desugar_ty_vec(returns);
295              }
296              Ty::Struct(_) => {
297                  i -= 1;
298                  let old_i = i;
299                  let Ty::Struct(fields) = tys.remove(i) else {
300                      unreachable!();
301                  };
302                  for (_, ty) in fields {
303                      tys.insert(i, ty);
304                      i += 1;
305                  }
306                  i = old_i; // yeah this sucks
307              }
308          }
309      }
310  }
311  
312  fn desugar_vec(struct_regs: &Map<Register, StructFields>, regs: &mut Vec<Register>) {
313      // we could probably write some unsafe code here if this becomes a bottleneck
314      let mut i = 0;
315      while i < regs.len() {
316          if let Some(fields) = struct_regs.get(&regs[i]) {
317              regs.remove(i);
318              for &reg in fields.keys() {
319                  regs.insert(i, reg);
320                  i += 1;
321              }
322          } else {
323              i += 1;
324          }
325      }
326  }