/ src / ir_builder.rs
ir_builder.rs
  1  use crate::ast;
  2  use crate::compiler_types::{Map, Name, Span, Spanned, Str};
  3  use crate::ir::*;
  4  
  5  /// This trait defines a helper method for transforming a `T` into an `Option<T>` with a postfix syntax.
  6  trait ToSome {
  7      fn some(self) -> Option<Self>
  8      where
  9          Self: Sized;
 10  
 11      fn some_if(self, condition: bool) -> Option<Self>
 12      where
 13          Self: Sized;
 14  }
 15  
 16  impl<T> ToSome for T {
 17      fn some(self) -> Option<Self> {
 18          Some(self)
 19      }
 20  
 21      fn some_if(self, condition: bool) -> Option<Self> {
 22          condition.then_some(self)
 23      }
 24  }
 25  
/// The ways lowering the AST into IR can fail.
pub enum ErrorKind {
    /// A name could not be resolved. Carries a label for what was looked up
    /// (this file uses "type" and "variable") and the name that was missing.
    NotFound(&'static str, Str),
    /// Two items share a name. Carries a label (this file uses "field" and "type")
    /// and the span of the earlier definition, when it has one.
    NameConflict(&'static str, Option<Span>),
    /// An expression or block had to produce a value but did not. The span is that
    /// of the enclosing construct that required the value.
    DoesNotYield(Span),
    /// The target of an assignment resolved to a constant instead of a variable.
    CantAssignToConstant,
    /// An integer literal carried a suffix other than `usize` or `u8`.
    UnknownIntLiteralSuffix,
    /// An `as` cast named a target type that is not an integer type.
    CantCastToTy(Ty),
    // Never constructed in this file, hence the lint allowance.
    // NOTE(review): presumably a placeholder for unimplemented features — confirm.
    #[allow(dead_code)]
    Todo(&'static str),
}
 36  
/// An [`ErrorKind`] paired with the source span it should be reported at.
type Error = Spanned<ErrorKind>;
/// Result alias used throughout IR building; the error type is always [`Error`].
type Result<T> = std::result::Result<T, Error>;
 39  
/// Mutable state used while lowering a single function body into IR.
#[derive(Clone, Debug)]
struct IrBuilder<'a> {
    /// Registers that receive the function's arguments, in declaration order.
    parameters: Vec<Register>,
    /// Instructions accumulated for the block currently being built.
    current_block: Vec<Inst>,
    /// Id under which `current_block` is stored once it is finished.
    current_block_id: BlockId,
    /// Next block id that has not been reserved yet.
    next_block_id: usize,
    /// Finished blocks, keyed by their id.
    blocks: Map<BlockId, Block>,
    /// Type recorded for every register created so far.
    tys: Map<Register, Ty>,
    /// Source span recorded for every register created so far.
    spans: Map<Register, Span>,
    /// Stack of lexical scopes, innermost last. Each maps a variable name to the
    /// register produced by that variable's stack allocation.
    scopes: Vec<Map<Str, Register>>,
    /// Number the next freshly created register will receive.
    next_reg_id: u128,
    /// Signature (parameter types, optional return type) of every known function.
    function_tys: &'a Map<Str, (Vec<Ty>, Option<Ty>)>,
    /// Named types that type expressions may refer to.
    defined_tys: &'a DefinedTys,
}
 54  
/// Result of lowering a syntactic place expression: either a plain value or a
/// pointer through which the value can be read and written.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum MaybeVar {
    /// A register containing the value we're accessing.
    Constant(Register),
    /// A register containing a pointer to the value we're accessing.
    Variable(Register),
}
 62  
/// Table of named types that type expressions can be resolved against.
#[derive(Clone, Debug)]
struct DefinedTys {
    /// Maps a type name to its IR type and the span of its definition.
    /// The built-in types registered by `build` carry no span.
    tys: Map<Str, (Ty, Option<Span>)>,
}
 67  
 68  impl DefinedTys {
 69      fn build_ty(&self, ty: &ast::Ty) -> Result<Ty> {
 70          use ast::TyKind as T;
 71          match &ty.kind {
 72              T::Name(name) => match self.tys.get(&name.kind) {
 73                  Some((ty, _)) => Ok(ty.clone()),
 74                  None => Err(Error {
 75                      kind: ErrorKind::NotFound("type", name.kind.clone()),
 76                      span: ty.span.clone(),
 77                  }),
 78              },
 79              T::Pointer(inner) => self.build_ty(inner).map(Box::new).map(Ty::Pointer),
 80              T::Function(params, returns) => {
 81                  let params = params
 82                      .iter()
 83                      .map(|t| self.build_ty(t))
 84                      .collect::<Result<_>>()?;
 85                  let returns = returns
 86                      .iter()
 87                      .map(|t| self.build_ty(t))
 88                      .collect::<Result<_>>()?;
 89                  Ok(Ty::Function(params, returns))
 90              }
 91          }
 92      }
 93  }
 94  
 95  impl<'a> IrBuilder<'a> {
 96      fn new(function_tys: &'a Map<Str, (Vec<Ty>, Option<Ty>)>, defined_tys: &'a DefinedTys) -> Self {
 97          Self {
 98              parameters: vec![],
 99              current_block: vec![],
100              current_block_id: BlockId::ENTRY,
101              next_block_id: BlockId::ENTRY.0 + 1,
102              blocks: Map::new(),
103              tys: Map::new(),
104              spans: Map::new(),
105              scopes: vec![Map::new()],
106              next_reg_id: 0,
107              function_tys,
108              defined_tys,
109          }
110      }
111  
112      fn build_function(
113          mut self,
114          ast::Function {
115              name: _,
116              parameters,
117              returns,
118              body,
119          }: &ast::Function,
120      ) -> Result<Function> {
121          for (p_name, p_ty) in parameters {
122              let ir_ty = self.build_ty(p_ty)?;
123              let reg = self.new_reg(ir_ty.clone(), p_name.span.clone());
124              self.parameters.push(reg);
125              // Currently, we assume all variables are stack allocated, so we copy the argument to a stack allocation.
126              let var_reg = self.new_var(p_name.clone(), ir_ty);
127              self.push_write(var_reg, reg);
128          }
129          let return_regs = if let Some(returns) = returns.as_ref() {
130              vec![self.build_block_unvoid(body, returns.span.clone())?]
131          } else {
132              self.build_block(body, false)?;
133              vec![]
134          };
135          self.switch_to_new_block(Exit::Return(return_regs), BlockId::DUMMY);
136          assert_eq!(self.scopes.len(), 1);
137          Ok(Function::new(
138              self.parameters,
139              self.blocks,
140              self.tys,
141              self.spans,
142              self.next_reg_id,
143          ))
144      }
145  
146      fn build_block_unvoid(&mut self, block: &ast::Block, outer: Span) -> Result<Register> {
147          let r = self.build_block(block, true)?;
148          r.ok_or_else(|| Error {
149              kind: ErrorKind::DoesNotYield(outer),
150              span: block.0.last().unwrap().span.clone(),
151          })
152      }
153  
154      fn build_block(
155          &mut self,
156          ast::Block(stmts): &ast::Block,
157          unvoid: bool,
158      ) -> Result<Option<Register>> {
159          self.enter_scope();
160          let mut last_stmt_return = None;
161          for (i, stmt) in stmts.iter().enumerate() {
162              let is_last = i == stmts.len() - 1;
163              last_stmt_return = self.build_stmt(stmt, unvoid && is_last)?;
164          }
165          self.exit_scope();
166          Ok(last_stmt_return)
167      }
168  
169      fn build_stmt(&mut self, stmt: &ast::Stmt, unvoid: bool) -> Result<Option<Register>> {
170          use ast::StmtKind as S;
171          let ast::Stmt { kind, span } = stmt;
172          let span = span.clone();
173          let reg = match kind {
174              S::Let(name, ty, body) => {
175                  let value_reg = self.build_expr_unvoid(body, span)?;
176                  let alloc_ty = match ty {
177                      Some(t) => self.build_ty(t)?,
178                      None => self.tys.get(&value_reg).unwrap().clone(),
179                  };
180                  let alloc_reg = self.new_var(name.clone(), alloc_ty);
181                  self.push_write(alloc_reg, value_reg);
182                  None
183              }
184              S::Expr(expr) => self.build_expr(expr, unvoid)?,
185          };
186          Ok(reg)
187      }
188  
189      fn build_expr_unvoid(&mut self, expr: &ast::Expr, outer: Span) -> Result<Register> {
190          let reg = self.build_expr(expr, true)?;
191          match reg {
192              Some(r) => Ok(r),
193              None => Err(Error {
194                  kind: ErrorKind::DoesNotYield(outer),
195                  span: expr.span.clone(),
196              }),
197          }
198      }
199  
200      fn build_expr(&mut self, expr: &ast::Expr, unvoid: bool) -> Result<Option<Register>> {
201          use ast::ExprKind as E;
202          use StoreKind as Sk;
203          let ast::Expr { kind, span } = expr;
204          let span = span.clone();
205          // let span2 = span.clone(); // maybe if i write enough of these, Rust 2024 will make it Copy
206          let reg = match kind {
207              E::Place(kind) => match self.build_place(kind, span.clone())? {
208                  MaybeVar::Variable(place_reg) => self
209                      .push_store(StoreKind::Read(place_reg), span)
210                      .some_if(unvoid),
211                  MaybeVar::Constant(value_reg) => Some(value_reg),
212              },
213              // NOTE: We're implicitly checking and evaluating the place expression first, but typechecking currently has to check the value expression first. Should we change our order here?
214              E::Assign(place, value) => {
215                  let MaybeVar::Variable(place_reg) =
216                      self.build_place(&place.kind, place.span.clone())?
217                  else {
218                      return Err(Error {
219                          kind: ErrorKind::CantAssignToConstant,
220                          span,
221                      });
222                  };
223                  let value_reg = self.build_expr_unvoid(value, span)?;
224                  self.push_write(place_reg, value_reg);
225                  None
226              }
227              E::Int(int, suffix) => {
228                  let int_ty = if let Some(suffix) = suffix {
229                      match suffix.kind.as_ref() {
230                          "usize" => IntKind::Usize,
231                          "u8" => IntKind::U8,
232                          _ => {
233                              return Err(Error {
234                                  kind: ErrorKind::UnknownIntLiteralSuffix,
235                                  span: suffix.span.clone(),
236                              });
237                          }
238                      }
239                  } else {
240                      IntKind::Usize
241                  };
242                  self.push_store(Sk::Int((*int).into(), int_ty), span).some()
243              }
244              E::UnaryOp(op, e) => {
245                  use ast::UnaryOpKind as A;
246                  use UnaryOp as B;
247                  match op.kind {
248                      A::Neg => {
249                          let reg = self.build_expr_unvoid(e, span.clone())?;
250                          self.push_store(Sk::UnaryOp(B::Neg, reg), span)
251                              .some_if(unvoid)
252                      }
253                      A::Ref => {
254                          let maybe_var = match &e.kind {
255                              E::Place(kind) => self.build_place(kind, span)?,
256                              _ => MaybeVar::Constant(self.build_expr_unvoid(e, span)?),
257                          };
258                          match maybe_var {
259                              MaybeVar::Variable(v) => Some(v),
260                              MaybeVar::Constant(c) => {
261                                  let r = self.push_store(
262                                      Sk::StackAlloc(self.tys.get(&c).unwrap().clone()),
263                                      e.span.clone(),
264                                  );
265                                  self.push_write(r, c);
266                                  Some(r)
267                              }
268                          }
269                      }
270                  }
271              }
272              E::BinOp(op, lhs, rhs) => {
273                  use ast::BinOpKind as A;
274                  use BinOp as B;
275                  let op_kind = match op.kind {
276                      A::Add => B::Add,
277                      A::Sub => B::Sub,
278                      A::Mul => B::Mul,
279                      A::CmpLe => B::CmpLe,
280                  };
281                  let lhs_reg = self.build_expr_unvoid(lhs, span.clone())?;
282                  let rhs_reg = self.build_expr_unvoid(rhs, span.clone())?;
283                  self.push_store(Sk::BinOp(op_kind, lhs_reg, rhs_reg), span)
284                      .some_if(unvoid)
285              }
286              E::As(value, ty) => {
287                  let value_reg = self.build_expr_unvoid(value, span.clone())?;
288                  let ir_ty = self.build_ty(ty)?;
289                  let kind = match ir_ty {
290                      Ty::Int(k) => k,
291                      t => {
292                          return Err(Error {
293                              kind: ErrorKind::CantCastToTy(t),
294                              // Should we annotate the span of the type or the entire `as` expression?
295                              span: ty.span.clone(),
296                          });
297                      }
298                  };
299                  self.push_store(Sk::IntCast(value_reg, kind), span).some()
300              }
301              // NOTE: When building Paren and Block, we forget their spans, which means subsequent error diagnostics will only ever point to the inner expression. Is this good or bad? We could change this by creating a Copy of the inner register, assigning the copy the outer span.
302              E::Paren(inner) => self.build_expr(inner.as_ref(), unvoid)?,
303              E::Block(b) => self.build_block(b, unvoid)?,
304              E::If(cond, then_body, else_body) => {
305                  let then_id = self.reserve_block_id();
306                  let end_id = self.reserve_block_id();
307                  let else_id = if else_body.is_some() {
308                      self.reserve_block_id()
309                  } else {
310                      end_id
311                  };
312  
313                  // evaluate condition, jump to either branch
314                  self.enter_scope();
315                  let cond_reg = self.build_expr_unvoid(cond, span.clone())?;
316                  self.switch_to_new_block(
317                      Exit::CondJump(Condition::NonZero(cond_reg), then_id, else_id),
318                      then_id,
319                  );
320  
321                  // evaluate true branch, jump to end
322                  let then_yield = self.build_block(then_body, unvoid)?;
323                  self.exit_scope();
324                  let then_id = self.current_block_id;
325                  self.switch_to_new_block(Exit::Jump(end_id), else_id);
326  
327                  // evaluate false branch, jump to end
328                  let else_yield = else_body
329                      .as_ref()
330                      .map(|e| {
331                          self.enter_scope();
332                          let else_yield = self.build_block(e, unvoid)?;
333                          self.exit_scope();
334                          let else_id = self.current_block_id;
335                          self.switch_to_new_block(Exit::Jump(end_id), end_id);
336                          Ok(else_yield.map(|e| (e, else_id)))
337                      })
338                      .transpose()?
339                      .flatten();
340  
341                  match (then_yield, else_yield) {
342                      (Some(a), Some((b, else_id))) => {
343                          let choices = [(then_id, a), (else_id, b)].into_iter().collect();
344                          self.push_store(StoreKind::Phi(choices), span).some()
345                      }
346                      _ => None,
347                  }
348              }
349              E::While(cond, body) => {
350                  // jump to condition evaluation
351                  let cond_id = self.reserve_block_id();
352                  let body_id = self.reserve_block_id();
353                  let end_id = self.reserve_block_id();
354                  self.switch_to_new_block(Exit::Jump(cond_id), cond_id);
355  
356                  // condition evaluation, jump to either inner body or end of expression
357                  self.enter_scope(); // with code like `while x is Some(y): ...`, `y` should be accessible from the body
358                  let cond_reg = self.build_expr_unvoid(cond, span)?;
359                  self.switch_to_new_block(
360                      Exit::CondJump(Condition::NonZero(cond_reg), body_id, end_id),
361                      body_id,
362                  );
363  
364                  // body evaluation, jump back to condition
365                  self.build_block(body, true)?;
366                  self.exit_scope();
367                  self.switch_to_new_block(Exit::Jump(cond_id), end_id);
368  
369                  // continue evaluation after while loop
370                  None
371              }
372              E::Call(callee, args) => {
373                  let callee = self.build_expr_unvoid(callee, span.clone())?;
374                  let returns: Vec<_> = match self.tys.get(&callee).unwrap() {
375                      Ty::Function(_, returns) => {
376                          assert!(matches!(returns.len(), 0 | 1));
377                          // somewhat silly clone because we need mutable access to `self` for `new_reg`.
378                          returns
379                              .clone()
380                              .into_iter()
381                              .map(|ty| self.new_reg(ty, span.clone()))
382                              .collect()
383                      }
384                      Ty::Int(_) | Ty::Pointer(_) | Ty::Struct(_) => {
385                          // dummy return
386                          vec![self.new_reg(Ty::Int(IntKind::Usize), span.clone())]
387                      }
388                  };
389                  let return_reg = returns.first().copied();
390                  let args = args
391                      .iter()
392                      .map(|arg| self.build_expr_unvoid(arg, span.clone()))
393                      .collect::<Result<_>>()?;
394                  self.push_inst(Inst::Call {
395                      callee,
396                      args,
397                      returns,
398                  });
399                  return_reg
400              }
401          };
402          // println!("unvoid {unvoid}\nexpr {expr:?}\nreg {reg:?}\n");
403          Ok(reg)
404      }
405      // This function returns a MaybeVar because not all syntactic place expressions are semantic place expressions. For example, we can't assign a value to a function. Different code paths we expect a place expression will have to properly handle these cases.
406      fn build_place(&mut self, kind: &ast::PlaceKind, span: Span) -> Result<MaybeVar> {
407          use ast::PlaceKind as Pk;
408          match kind {
409              Pk::Var(name) => self
410                  .get_var(name)
411                  .map(MaybeVar::Variable)
412                  .or_else(|| {
413                      self.function_tys
414                          .contains_key(&name.kind)
415                          .then(|| self.push_store(StoreKind::Function(name.kind.clone()), span))
416                          .map(MaybeVar::Constant)
417                  })
418                  .ok_or_else(|| Error {
419                      kind: ErrorKind::NotFound("variable", name.kind.clone()),
420                      span: name.span.clone(),
421                  }),
422              Pk::Deref(e, _) => self.build_expr_unvoid(e, span).map(MaybeVar::Variable),
423              Pk::Index(indexee, index, index_span) => {
424                  let indexee_reg = self.build_expr_unvoid(indexee, span.clone())?;
425                  let index_reg = self.build_expr_unvoid(index, index_span.clone())?;
426                  let indexed_reg =
427                      self.push_store(StoreKind::PtrOffset(indexee_reg, index_reg), span);
428                  Ok(MaybeVar::Variable(indexed_reg))
429              }
430              Pk::Field(struct_value, field) => {
431                  let struct_reg = self.build_expr_unvoid(struct_value, span.clone())?;
432                  let field_ptr_reg =
433                      self.push_store(StoreKind::FieldOffset(struct_reg, field.kind.clone()), span);
434                  Ok(MaybeVar::Variable(field_ptr_reg))
435              }
436          }
437      }
438  
439      fn enter_scope(&mut self) {
440          self.scopes.push(Map::new());
441      }
442  
443      fn exit_scope(&mut self) {
444          self.scopes.pop();
445      }
446  
447      fn build_ty(&self, ty: &ast::Ty) -> Result<Ty> {
448          self.defined_tys.build_ty(ty)
449      }
450  
451      fn new_var(&mut self, name: Name, ty: Ty) -> Register {
452          let reg = self.push_store(StoreKind::StackAlloc(ty), name.span);
453          self.scopes.last_mut().unwrap().insert(name.kind, reg);
454          reg
455      }
456  
457      fn get_var(&self, name: &Name) -> Option<Register> {
458          self.scopes
459              .iter()
460              .rev()
461              .find_map(|scope| scope.get(name.kind.as_ref()).copied())
462      }
463  
464      fn push_write(&mut self, dst: Register, src: Register) {
465          self.push_inst(Inst::Write(dst, src));
466      }
467  
468      fn push_store(&mut self, sk: StoreKind, span: Span) -> Register {
469          let ty = self.guess_ty(&sk);
470          let reg = self.new_reg(ty, span);
471          self.push_inst(Inst::Store(reg, sk));
472          reg
473      }
474      pub fn guess_ty(&self, sk: &StoreKind) -> Ty {
475          const DUMMY_TY: Ty = Ty::Int(IntKind::U8);
476          use StoreKind as Sk;
477          let t = |r| self.tys.get(r).unwrap().clone();
478          match sk {
479              &Sk::Int(_, kind) | &Sk::IntCast(_, kind) => Ty::Int(kind),
480              Sk::Phi(regs) => t(regs.first_key_value().expect("empty phi").1),
481              Sk::BinOp(_, lhs, _rhs) => t(lhs),
482              Sk::PtrOffset(ptr, _) => t(ptr),
483              Sk::FieldOffset(r, field) => {
484                  let Ty::Pointer(value) = t(r) else {
485                      println!("foo");
486                      return DUMMY_TY;
487                  };
488                  let Ty::Struct(fields) = value.as_ref() else {
489                      println!("foofoo");
490                      return DUMMY_TY;
491                  };
492                  fields
493                      .iter()
494                      .find_map(|(name, ty)| {
495                          (name == field).then(|| Ty::Pointer(Box::new(ty.clone())))
496                      })
497                      .unwrap_or(DUMMY_TY)
498              }
499              Sk::StackAlloc(ty) => Ty::Pointer(Box::new(ty.clone())),
500              Sk::Copy(r) | Sk::UnaryOp(UnaryOp::Neg, r) => t(r),
501              Sk::Read(r) => match self.tys.get(r).unwrap() {
502                  Ty::Pointer(inner) => inner.as_ref().clone(),
503                  Ty::Int(_) | Ty::Function(..) | Ty::Struct(_) => DUMMY_TY,
504              },
505              Sk::Function(name) => {
506                  let (args, returns) = self
507                      .function_tys
508                      .get(name)
509                      .expect("constructed function instruction to unknown function")
510                      .clone();
511                  Ty::Function(args, returns.into_iter().collect())
512              }
513          }
514      }
515  
516      fn push_inst(&mut self, inst: Inst) {
517          self.current_block.push(inst);
518      }
519  
520      fn new_reg(&mut self, ty: Ty, span: Span) -> Register {
521          let reg = Register(self.next_reg_id);
522          self.next_reg_id = self
523              .next_reg_id
524              .checked_add(1)
525              .expect("register allocation overflow");
526          self.tys.insert(reg, ty);
527          self.spans.insert(reg, span);
528          reg
529      }
530  
531      pub fn switch_to_new_block(&mut self, exit: Exit, id: BlockId) {
532          let insts = std::mem::take(&mut self.current_block);
533          let block = Block::new(insts, exit);
534          self.blocks.insert(self.current_block_id, block);
535          self.current_block_id = id;
536      }
537  
538      fn reserve_block_id(&mut self) -> BlockId {
539          let id = BlockId(self.next_block_id);
540          self.next_block_id = self
541              .next_block_id
542              .checked_add(1)
543              .expect("block allocation overflow");
544          id
545      }
546  }
547  
548  pub fn build(program: &ast::Program) -> Result<Program> {
549      use ast::DeclKind as D;
550      let mut function_tys = Map::new();
551      let mut defined_tys = DefinedTys {
552          tys: Map::from([
553              ("u8".into(), (Ty::Int(IntKind::U8), None)),
554              ("usize".into(), (Ty::Int(IntKind::Usize), None)),
555          ]),
556      };
557      // collect struct types
558      // currently we only accept forward declarations of structs (bad). fixing this in the general case requires a pervasive rewrite of the IR type system.
559      for ast::Decl { kind, span: _ } in &program.decls {
560          match kind {
561              D::Function(_) | D::ExternFunction(_) => {}
562              D::Struct(ast::Struct { name, fields }) => {
563                  let mut ir_fields: Vec<(Str, Ty)> = Vec::with_capacity(fields.len());
564                  let mut names = Map::new();
565                  for (field_name, field_ty) in fields {
566                      let ty = defined_tys.build_ty(field_ty)?;
567                      ir_fields.push((field_name.kind.clone(), ty));
568                      if let Some(previous_span) =
569                          names.insert(field_name.kind.clone(), field_name.span.clone())
570                      {
571                          return Err(Error {
572                              kind: ErrorKind::NameConflict("field", Some(previous_span)),
573                              span: field_name.span.clone(),
574                          });
575                      }
576                  }
577                  let ty = Ty::Struct(ir_fields);
578                  let name_span = name.span.clone();
579                  // check if another ty in the same scope has the same name
580                  let maybe_clash = defined_tys
581                      .tys
582                      .insert(name.kind.clone(), (ty, name_span.some()));
583                  if let Some((_, previous_span)) = maybe_clash {
584                      return Err(Error {
585                          kind: ErrorKind::NameConflict("type", previous_span),
586                          span: name.span.clone(),
587                      });
588                  }
589              }
590          }
591      }
592      for ast::Decl { kind, span: _ } in &program.decls {
593          match kind {
594              D::Function(ast::Function {
595                  name,
596                  parameters,
597                  returns,
598                  body: _,
599              }) => {
600                  function_tys.insert(
601                      name.kind.clone(),
602                      (
603                          parameters
604                              .iter()
605                              .map(|arg| defined_tys.build_ty(&arg.1))
606                              .collect::<Result<_>>()?,
607                          returns
608                              .as_ref()
609                              .map(|ret| defined_tys.build_ty(ret))
610                              .transpose()?,
611                      ),
612                  );
613              }
614              D::ExternFunction(ast::ExternFunction {
615                  name,
616                  parameters,
617                  returns,
618              }) => {
619                  function_tys.insert(
620                      name.kind.clone(),
621                      (
622                          parameters
623                              .iter()
624                              .map(|arg| defined_tys.build_ty(arg))
625                              .collect::<Result<_>>()?,
626                          returns
627                              .as_ref()
628                              .map(|ret| defined_tys.build_ty(ret))
629                              .transpose()?,
630                      ),
631                  );
632              }
633              D::Struct(_) => {}
634          }
635      }
636      let mut functions = Map::new();
637      for decl in &program.decls {
638          match &decl.kind {
639              D::Function(fn_decl) => {
640                  let builder = IrBuilder::new(&function_tys, &defined_tys);
641                  let function = builder.build_function(fn_decl)?;
642                  functions.insert(fn_decl.name.kind.clone(), function);
643              }
644              D::ExternFunction(_) | D::Struct(_) => {}
645          }
646      }
647      let function_tys = function_tys
648          .into_iter()
649          .map(|(name, (params, returns))| (name, (params, returns.into_iter().collect())))
650          .collect();
651      Ok(Program {
652          functions,
653          function_tys,
654      })
655  }