/ src / typechecker.rs
typechecker.rs
  1  use crate::compiler_types::{Map, Str};
  2  use crate::ir::*;
  3  
/// The ways in which a function can fail to typecheck.
///
/// Every variant carries the register at which the problem was detected.
#[derive(Clone, Debug)]
pub enum ErrorKind {
    /// We performed an integer operation on a non-integer.
    ///
    /// Raised for arithmetic/comparison operands, negation, and the condition
    /// of a conditional jump.
    NotInt(Register),
    /// We dereferenced a register of a non-pointer type.
    ///
    /// Raised for reads, writes, pointer offsets and field offsets.
    NotPointer(Register),
    /// We called a register of a non-function type.
    NotFunction(Register),
    /// We accessed a field of a non-struct type.
    ///
    /// The register holds a pointer, but its pointee is not a struct.
    NotStruct(Register),
    /// We accessed a non-existent field of a struct.
    ///
    /// Carries the pointer register and the name of the missing field.
    NoFieldNamed(Register, Str),
    /// The register has one type but we expected another.
    ///
    /// Carries the register and the type the checker required it to have.
    Expected(Register, Ty),
}
 19  
/// A typechecking failure: the name of the function that was being checked,
/// paired with a description of what went wrong.
type Error = (Str, ErrorKind);
/// Result alias used throughout this module; the success type defaults to `()`.
type Result<T = ()> = std::result::Result<T, Error>;

/// The type recorded for each register of a function.
type Tys = Map<Register, Ty>;
/// Borrowed table mapping a function name to its
/// `(parameter types, return types)` pair.
// NOTE: This can be changed to take 2 lifetime parameters.
type FunctionTys<'a> = &'a Map<Str, (Vec<Ty>, Vec<Ty>)>;
 26  
/// Typechecking state for a single function; built by `visit_function`.
#[derive(Debug)]
struct TypeChecker<'a> {
    /// Signatures of the program's functions, consulted when a function is
    /// referenced by name.
    function_tys: FunctionTys<'a>,
    /// The types that the function being checked must return.
    return_tys: &'a [Ty],
    /// The type recorded for every register of the function being checked.
    tys: &'a Tys,
    /// Name of the function being checked; attached to every error produced.
    name: &'a str,
}
 34  
 35  impl<'a> TypeChecker<'a> {
 36      /*
 37      fn expect(&self, r: Register, ty: &'a Ty) -> Result {
 38          Self::expect_ty(self.tys.get_mut(&r).unwrap(), ty)
 39              .ok_or_else(|| self.err(ErrorKind::Expected(r, ty.clone())))
 40      }
 41      fn expect_ty(dst: &mut Ty, ty: &Ty) -> Option<()> {
 42          if dst == ty {
 43              Some(())
 44          } else {
 45              None
 46          }
 47      }
 48      */
 49      fn t(&self, r: Register) -> &'a Ty {
 50          self.tys.get(&r).expect("register with no type")
 51      }
 52      fn err(&self, r: Register, ty: Ty) -> Result {
 53          Err((self.name.into(), ErrorKind::Expected(r, ty)))
 54      }
 55      fn expect(&self, r: Register, ty: &'a Ty) -> Result {
 56          if self.t(r) == ty {
 57              Ok(())
 58          } else {
 59              self.err(r, ty.clone())
 60          }
 61      }
 62      fn int(&self, r: Register) -> Result<IntKind> {
 63          match self.t(r) {
 64              &Ty::Int(k) => Ok(k),
 65              Ty::Pointer(_) | Ty::Function(..) | Ty::Struct(_) => {
 66                  Err((self.name.into(), ErrorKind::NotInt(r)))
 67              }
 68          }
 69      }
 70      fn pointer(&self, r: Register) -> Result<&'a Ty> {
 71          match self.t(r) {
 72              Ty::Pointer(inner) => Ok(inner.as_ref()),
 73              Ty::Int(_) | Ty::Function(..) | Ty::Struct(_) => {
 74                  Err((self.name.into(), ErrorKind::NotPointer(r)))
 75              }
 76          }
 77      }
 78      fn infer_storekind(&self, sk: &StoreKind) -> Result<Ty> {
 79          use StoreKind as Sk;
 80          let ty = match *sk {
 81              Sk::Int(_, kind) | Sk::IntCast(_, kind) => Ty::Int(kind),
 82              Sk::Copy(r) => self.t(r).clone(),
 83              Sk::BinOp(op, lhs, rhs) => {
 84                  match op {
 85                      BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::CmpLe => (),
 86                  }
 87                  let lhs_int = self.int(lhs)?;
 88                  let rhs_int = self.int(rhs)?;
 89                  if lhs_int != rhs_int {
 90                      self.err(rhs, Ty::Int(lhs_int))?;
 91                  }
 92                  Ty::Int(lhs_int)
 93              }
 94              Sk::PtrOffset(lhs, rhs) => {
 95                  self.pointer(lhs)?;
 96                  self.expect(rhs, &Ty::Int(IntKind::Usize))?;
 97                  self.t(lhs).clone()
 98              }
 99              Sk::FieldOffset(r, ref field) => {
100                  let Ty::Struct(fields) = self.pointer(r)? else {
101                      return Err((self.name.into(), ErrorKind::NotStruct(r)));
102                  };
103                  fields
104                      .iter()
105                      .find_map(|(name, ty)| {
106                          (name == field).then(|| Ty::Pointer(Box::new(ty.clone())))
107                      })
108                      .ok_or_else(|| (self.name.into(), ErrorKind::NoFieldNamed(r, field.clone())))?
109              }
110              Sk::UnaryOp(UnaryOp::Neg, rhs) => {
111                  let kind = self.int(rhs)?;
112                  Ty::Int(kind)
113              }
114              Sk::StackAlloc(ref inner) => Ty::Pointer(Box::new(inner.clone())),
115              Sk::Read(src) => self.pointer(src)?.clone(),
116              Sk::Phi(ref rs) => {
117                  let mut rs = rs.values().copied();
118                  let ty = self.t(rs.next().expect("empty phi"));
119                  for r in rs {
120                      self.expect(r, ty)?;
121                  }
122                  ty.clone()
123              }
124              Sk::Function(ref name) => {
125                  let (params, returns) = self.function_tys.get(name.as_ref()).expect("function get");
126                  Ty::Function(params.clone(), returns.clone())
127              }
128          };
129          Ok(ty)
130      }
131      fn visit_inst(&self, inst: &Inst) -> Result {
132          match inst {
133              Inst::Store(r, sk) => {
134                  let expected = self.t(*r);
135                  let got = self.infer_storekind(sk)?;
136                  if *expected == got {
137                      Ok(())
138                  } else {
139                      Err((self.name.into(), ErrorKind::Expected(*r, got)))
140                  }
141              }
142              &Inst::Write(dst, src) => {
143                  let inner = self.pointer(dst)?;
144                  self.expect(src, inner)
145              }
146              Inst::Nop => Ok(()),
147              Inst::Call {
148                  callee,
149                  args,
150                  returns,
151              } => {
152                  match self.t(*callee) {
153                      Ty::Function(..) => {}
154                      Ty::Int(_) | Ty::Pointer(_) | Ty::Struct(_) => {
155                          return Err((self.name.into(), ErrorKind::NotFunction(*callee)))
156                      }
157                  }
158                  let arg_tys = args.iter().map(|&r| self.t(r).clone()).collect();
159                  let return_tys = returns.iter().map(|&r| self.t(r).clone()).collect();
160                  let fn_ty = Ty::Function(arg_tys, return_tys);
161                  self.expect(*callee, &fn_ty)
162              }
163          }
164      }
165      // fn visit_jump_loc(&self, loc: &JumpLocation) -> Result {
166      //     match loc {
167      //         JumpLocation::Block(_) => Ok(()),
168      //         JumpLocation::Return(regs) => {
169      //             if regs.len() != self.return_tys.len() {
170      //                 // The IR lowering phase will always produce functions with 0 or 1 returns, and it checks that all paths return the appropriate number of values. This code path will only run when typechecking transformed IR, namely after lowering IR types to machine-friendly types.
171      //                 todo!("proper error diagnostic for wrong number of returns");
172      //             }
173      //             regs.iter()
174      //                 .zip(self.return_tys)
175      //                 .try_for_each(|(&r, ty)| self.expect(r, ty))
176      //         }
177      //     }
178      // }
179      fn visit_block(&self, block: &Block) -> Result {
180          for inst in &block.insts {
181              self.visit_inst(inst)?;
182          }
183          match &block.exit {
184              Exit::Jump(_) => Ok(()),
185              Exit::CondJump(cond, _, _) => match cond {
186                  &Condition::NonZero(r) => self.int(r).map(|_| ()),
187              },
188              Exit::Return(regs) => {
189                  if regs.len() != self.return_tys.len() {
190                      // The IR lowering phase will always produce functions with 0 or 1 returns, and it checks that all paths return the appropriate number of values. This code path will only run when typechecking transformed IR, namely after lowering IR types to machine-friendly types.
191                      todo!("proper error diagnostic for wrong number of returns");
192                  }
193                  regs.iter()
194                      .zip(self.return_tys)
195                      .try_for_each(|(&r, ty)| self.expect(r, ty))
196              }
197          }
198      }
199      fn visit_function(f: &'a Function, name: &'a str, function_tys: FunctionTys<'a>) -> Result {
200          let return_tys = &function_tys.get(name).unwrap().1;
201          let this = Self {
202              function_tys,
203              return_tys,
204              tys: &f.tys,
205              name,
206          };
207          for i in f.cfg.dom_iter() {
208              let block = f.blocks.get(&i).unwrap();
209              this.visit_block(block)?;
210          }
211          Ok(())
212      }
213  }
214  
215  pub fn typecheck(program: &Program) -> Result {
216      for (fn_name, f) in &program.functions {
217          /*
218          println!("typechecking {fn_name}");
219          for (r, ty) in &f.tys {
220              println!("  {r} {ty}");
221          }
222          */
223          TypeChecker::visit_function(f, fn_name, &program.function_tys)?;
224      }
225      Ok(())
226  }