// typechecker.rs
1 use crate::compiler_types::{Map, Str}; 2 use crate::ir::*; 3 4 #[derive(Clone, Debug)] 5 pub enum ErrorKind { 6 /// We performed an integer operation on a non-integer. 7 NotInt(Register), 8 /// We dereferenced a register of a non-pointer type. 9 NotPointer(Register), 10 /// We called a register of a non-function type. 11 NotFunction(Register), 12 /// We accessed a field of a non-struct type. 13 NotStruct(Register), 14 /// We accessed a non-existent field of a struct. 15 NoFieldNamed(Register, Str), 16 /// The register has one type but we expected another. 17 Expected(Register, Ty), 18 } 19 20 type Error = (Str, ErrorKind); 21 type Result<T = ()> = std::result::Result<T, Error>; 22 23 type Tys = Map<Register, Ty>; 24 // NOTE: This can be changed to take 2 lifetime parameters. 25 type FunctionTys<'a> = &'a Map<Str, (Vec<Ty>, Vec<Ty>)>; 26 27 #[derive(Debug)] 28 struct TypeChecker<'a> { 29 function_tys: FunctionTys<'a>, 30 return_tys: &'a [Ty], 31 tys: &'a Tys, 32 name: &'a str, 33 } 34 35 impl<'a> TypeChecker<'a> { 36 /* 37 fn expect(&self, r: Register, ty: &'a Ty) -> Result { 38 Self::expect_ty(self.tys.get_mut(&r).unwrap(), ty) 39 .ok_or_else(|| self.err(ErrorKind::Expected(r, ty.clone()))) 40 } 41 fn expect_ty(dst: &mut Ty, ty: &Ty) -> Option<()> { 42 if dst == ty { 43 Some(()) 44 } else { 45 None 46 } 47 } 48 */ 49 fn t(&self, r: Register) -> &'a Ty { 50 self.tys.get(&r).expect("register with no type") 51 } 52 fn err(&self, r: Register, ty: Ty) -> Result { 53 Err((self.name.into(), ErrorKind::Expected(r, ty))) 54 } 55 fn expect(&self, r: Register, ty: &'a Ty) -> Result { 56 if self.t(r) == ty { 57 Ok(()) 58 } else { 59 self.err(r, ty.clone()) 60 } 61 } 62 fn int(&self, r: Register) -> Result<IntKind> { 63 match self.t(r) { 64 &Ty::Int(k) => Ok(k), 65 Ty::Pointer(_) | Ty::Function(..) 
| Ty::Struct(_) => { 66 Err((self.name.into(), ErrorKind::NotInt(r))) 67 } 68 } 69 } 70 fn pointer(&self, r: Register) -> Result<&'a Ty> { 71 match self.t(r) { 72 Ty::Pointer(inner) => Ok(inner.as_ref()), 73 Ty::Int(_) | Ty::Function(..) | Ty::Struct(_) => { 74 Err((self.name.into(), ErrorKind::NotPointer(r))) 75 } 76 } 77 } 78 fn infer_storekind(&self, sk: &StoreKind) -> Result<Ty> { 79 use StoreKind as Sk; 80 let ty = match *sk { 81 Sk::Int(_, kind) | Sk::IntCast(_, kind) => Ty::Int(kind), 82 Sk::Copy(r) => self.t(r).clone(), 83 Sk::BinOp(op, lhs, rhs) => { 84 match op { 85 BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::CmpLe => (), 86 } 87 let lhs_int = self.int(lhs)?; 88 let rhs_int = self.int(rhs)?; 89 if lhs_int != rhs_int { 90 self.err(rhs, Ty::Int(lhs_int))?; 91 } 92 Ty::Int(lhs_int) 93 } 94 Sk::PtrOffset(lhs, rhs) => { 95 self.pointer(lhs)?; 96 self.expect(rhs, &Ty::Int(IntKind::Usize))?; 97 self.t(lhs).clone() 98 } 99 Sk::FieldOffset(r, ref field) => { 100 let Ty::Struct(fields) = self.pointer(r)? else { 101 return Err((self.name.into(), ErrorKind::NotStruct(r))); 102 }; 103 fields 104 .iter() 105 .find_map(|(name, ty)| { 106 (name == field).then(|| Ty::Pointer(Box::new(ty.clone()))) 107 }) 108 .ok_or_else(|| (self.name.into(), ErrorKind::NoFieldNamed(r, field.clone())))? 
109 } 110 Sk::UnaryOp(UnaryOp::Neg, rhs) => { 111 let kind = self.int(rhs)?; 112 Ty::Int(kind) 113 } 114 Sk::StackAlloc(ref inner) => Ty::Pointer(Box::new(inner.clone())), 115 Sk::Read(src) => self.pointer(src)?.clone(), 116 Sk::Phi(ref rs) => { 117 let mut rs = rs.values().copied(); 118 let ty = self.t(rs.next().expect("empty phi")); 119 for r in rs { 120 self.expect(r, ty)?; 121 } 122 ty.clone() 123 } 124 Sk::Function(ref name) => { 125 let (params, returns) = self.function_tys.get(name.as_ref()).expect("function get"); 126 Ty::Function(params.clone(), returns.clone()) 127 } 128 }; 129 Ok(ty) 130 } 131 fn visit_inst(&self, inst: &Inst) -> Result { 132 match inst { 133 Inst::Store(r, sk) => { 134 let expected = self.t(*r); 135 let got = self.infer_storekind(sk)?; 136 if *expected == got { 137 Ok(()) 138 } else { 139 Err((self.name.into(), ErrorKind::Expected(*r, got))) 140 } 141 } 142 &Inst::Write(dst, src) => { 143 let inner = self.pointer(dst)?; 144 self.expect(src, inner) 145 } 146 Inst::Nop => Ok(()), 147 Inst::Call { 148 callee, 149 args, 150 returns, 151 } => { 152 match self.t(*callee) { 153 Ty::Function(..) => {} 154 Ty::Int(_) | Ty::Pointer(_) | Ty::Struct(_) => { 155 return Err((self.name.into(), ErrorKind::NotFunction(*callee))) 156 } 157 } 158 let arg_tys = args.iter().map(|&r| self.t(r).clone()).collect(); 159 let return_tys = returns.iter().map(|&r| self.t(r).clone()).collect(); 160 let fn_ty = Ty::Function(arg_tys, return_tys); 161 self.expect(*callee, &fn_ty) 162 } 163 } 164 } 165 // fn visit_jump_loc(&self, loc: &JumpLocation) -> Result { 166 // match loc { 167 // JumpLocation::Block(_) => Ok(()), 168 // JumpLocation::Return(regs) => { 169 // if regs.len() != self.return_tys.len() { 170 // // The IR lowering phase will always produce functions with 0 or 1 returns, and it checks that all paths return the appropriate number of values. 
This code path will only run when typechecking transformed IR, namely after lowering IR types to machine-friendly types. 171 // todo!("proper error diagnostic for wrong number of returns"); 172 // } 173 // regs.iter() 174 // .zip(self.return_tys) 175 // .try_for_each(|(&r, ty)| self.expect(r, ty)) 176 // } 177 // } 178 // } 179 fn visit_block(&self, block: &Block) -> Result { 180 for inst in &block.insts { 181 self.visit_inst(inst)?; 182 } 183 match &block.exit { 184 Exit::Jump(_) => Ok(()), 185 Exit::CondJump(cond, _, _) => match cond { 186 &Condition::NonZero(r) => self.int(r).map(|_| ()), 187 }, 188 Exit::Return(regs) => { 189 if regs.len() != self.return_tys.len() { 190 // The IR lowering phase will always produce functions with 0 or 1 returns, and it checks that all paths return the appropriate number of values. This code path will only run when typechecking transformed IR, namely after lowering IR types to machine-friendly types. 191 todo!("proper error diagnostic for wrong number of returns"); 192 } 193 regs.iter() 194 .zip(self.return_tys) 195 .try_for_each(|(&r, ty)| self.expect(r, ty)) 196 } 197 } 198 } 199 fn visit_function(f: &'a Function, name: &'a str, function_tys: FunctionTys<'a>) -> Result { 200 let return_tys = &function_tys.get(name).unwrap().1; 201 let this = Self { 202 function_tys, 203 return_tys, 204 tys: &f.tys, 205 name, 206 }; 207 for i in f.cfg.dom_iter() { 208 let block = f.blocks.get(&i).unwrap(); 209 this.visit_block(block)?; 210 } 211 Ok(()) 212 } 213 } 214 215 pub fn typecheck(program: &Program) -> Result { 216 for (fn_name, f) in &program.functions { 217 /* 218 println!("typechecking {fn_name}"); 219 for (r, ty) in &f.tys { 220 println!(" {r} {ty}"); 221 } 222 */ 223 TypeChecker::visit_function(f, fn_name, &program.function_tys)?; 224 } 225 Ok(()) 226 }