/ src / eval.rs
eval.rs
  1  use std::{
  2      borrow::Cow,
  3      collections::HashMap,
  4      fmt,
  5      io::{self, stdin},
  6      rc::Rc,
  7  };
  8  
  9  use crate::ast::Ast;
 10  
 11  // #[derive(Debug)]
 12  // pub struct InternalCtx {
 13  //     pub value: *const u8,
 14  //     pub clone: fn(*const u8) -> *const u8,
 15  //     pub drop: fn(*const u8) -> (),
 16  // }
 17  
 18  // impl Clone for InternalCtx {
 19  //     fn clone(&self) -> Self {
 20  //         Self {
 21  //             value: (self.clone)(self.value),
 22  //             clone: self.clone,
 23  //             drop: self.drop,
 24  //         }
 25  //     }
 26  // }
 27  
 28  // impl Drop for InternalCtx {
 29  //     fn drop(&mut self) {
 30  //         (self.drop)(self.value);
 31  //     }
 32  // }
 33  
 34  // impl Default for InternalCtx {
 35  //     fn default() -> Self {
 36  //         Self {
 37  //             value: null_mut(),
 38  //             clone: |_| null_mut(),
 39  //             drop: |_| {},
 40  //         }
 41  //     }
 42  // }
 43  
 44  pub type InternalHandler = Rc<dyn Fn(&[Rc<Val>]) -> Result<Rc<Val>, EvalErr>>;
 45  
 46  #[derive(Clone)]
 47  pub enum Val {
 48      Unit,
 49      Str(String),
 50      Int(i64),
 51      IO(IOVal),
 52      Lambda(Vec<Rc<Val>>, Rc<Term>),
 53      Internal(&'static str, InternalHandler),
 54  }
 55  
 56  impl fmt::Debug for Val {
 57      fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 58          match self {
 59              Self::Unit => write!(f, "Unit"),
 60              Self::Str(arg0) => f.debug_tuple("Str").field(arg0).finish(),
 61              Self::Int(arg0) => f.debug_tuple("Int").field(arg0).finish(),
 62              Self::IO(arg0) => f.debug_tuple("IO").field(arg0).finish(),
 63              Self::Lambda(arg0, arg1) => f.debug_tuple("Lambda").field(arg0).field(arg1).finish(),
 64              Self::Internal(name, _) => f.debug_tuple("Internal").field(name).finish(),
 65          }
 66      }
 67  }
 68  
 69  #[derive(Debug, Clone)]
 70  pub enum IOVal {
 71      Print(Rc<Val>),
 72      PrintLn(Rc<Val>),
 73      Input,
 74      Bind(Rc<Val>, Rc<Val>),
 75  }
 76  
 77  pub trait IOImpl {
 78      fn print(&self, s: &str) -> Result<(), EvalErr>;
 79      fn println(&self, s: &str) -> Result<(), EvalErr>;
 80      fn input(&self) -> Result<String, EvalErr>;
 81  }
 82  
 83  pub struct DefaultIOImpl;
 84  
 85  impl IOImpl for DefaultIOImpl {
 86      fn print(&self, s: &str) -> Result<(), EvalErr> {
 87          print!("{s}");
 88          Ok(())
 89      }
 90  
 91      fn println(&self, s: &str) -> Result<(), EvalErr> {
 92          println!("{s}");
 93          Ok(())
 94      }
 95  
 96      fn input(&self) -> Result<String, EvalErr> {
 97          let mut res = String::new();
 98          stdin().read_line(&mut res)?;
 99          while res.ends_with("\n") {
100              res.pop();
101          }
102          Ok(res)
103      }
104  }
105  
106  impl IOVal {
107      pub fn exec(&self, io_impl: &impl IOImpl) -> Result<Rc<Val>, EvalErr> {
108          match self {
109              IOVal::Print(val) => match &**val {
110                  Val::Str(s) => {
111                      io_impl.print(s)?;
112                      Ok(Rc::new(Val::Unit))
113                  }
114                  _ => Err(EvalErr::WrongType(
115                      "print must be called with string or int",
116                  )),
117              },
118              IOVal::PrintLn(val) => match &**val {
119                  Val::Str(s) => {
120                      io_impl.println(s)?;
121                      Ok(Rc::new(Val::Unit))
122                  }
123                  _ => Err(EvalErr::WrongType(
124                      "println must be called with string or int",
125                  )),
126              },
127              IOVal::Input => Ok(Rc::new(Val::Str(io_impl.input()?))),
128              IOVal::Bind(lhs, rhs) => {
129                  let lhs = match &**lhs {
130                      Val::IO(ioval) => ioval.exec(io_impl),
131                      _ => Err(EvalErr::WrongType("first argument of bind must be IO")),
132                  }?;
133                  let rhs = match &**rhs {
134                      Val::Lambda(fctx, term) => {
135                          let mut fctx = fctx.clone();
136                          fctx.push(lhs);
137                          let rhs = term.eval_with_ctx(&mut fctx);
138                          fctx.pop();
139                          rhs
140                      }
141                      _ => Err(EvalErr::WrongType(
142                          "second argument of bind must be a function",
143                      )),
144                  }?;
145                  match &*rhs {
146                      Val::IO(ioval) => ioval.exec(io_impl),
147                      _ => Err(EvalErr::WrongType(
148                          "second argument of bind must be return IO",
149                      )),
150                  }
151              }
152          }
153      }
154  }
155  
156  pub struct Globals {
157      names: Vec<String>,
158      vals: Vec<Rc<Val>>,
159  }
160  
161  impl Globals {
162      pub fn empty() -> Self {
163          Self {
164              names: Vec::new(),
165              vals: Vec::new(),
166          }
167      }
168  
169      pub fn new_with_io() -> Self {
170          let mut globals = Self::empty();
171  
172          globals.register(
173              "println",
174              Rc::new(Val::Internal(
175                  "println",
176                  Rc::new(|ctx| {
177                      let arg = ctx.last().ok_or(EvalErr::InvalidState)?;
178                      Ok(Rc::new(Val::IO(IOVal::PrintLn(arg.clone()))))
179                  }),
180              )),
181          );
182          globals.register(
183              "print",
184              Rc::new(Val::Internal(
185                  "print",
186                  Rc::new(|ctx| {
187                      let arg = ctx.last().ok_or(EvalErr::InvalidState)?;
188                      Ok(Rc::new(Val::IO(IOVal::Print(arg.clone()))))
189                  }),
190              )),
191          );
192          globals.register("input", Rc::new(Val::IO(IOVal::Input)));
193  
194          globals.register(
195              "bind",
196              Rc::new(Val::Internal(
197                  "bind",
198                  Rc::new(|lctx| {
199                      let lhs = lctx.last().ok_or(EvalErr::InvalidState)?.clone();
200                      Ok(Rc::new(Val::Internal(
201                          "pbind",
202                          Rc::new(move |rctx| {
203                              let rhs = rctx.last().ok_or(EvalErr::InvalidState)?;
204                              Ok(Rc::new(Val::IO(IOVal::Bind(lhs.clone(), rhs.clone()))))
205                          }),
206                      )))
207                  }),
208              )),
209          );
210  
211          globals
212      }
213  
214      pub fn register(&mut self, name: impl Into<String>, val: Rc<Val>) {
215          self.names.push(name.into());
216          self.vals.push(val);
217      }
218  }
219  
220  #[derive(Debug, Clone)]
221  pub enum Term {
222      Value(Rc<Val>),
223      Variable(u16),
224      Lambda(Rc<Term>),
225      Application(Rc<Term>, Rc<Term>),
226  }
227  
/// Error raised while compiling or evaluating a `Term`.
#[derive(Debug)]
pub enum EvalErr {
    /// An identifier with no binding in scope.
    UnknownVariable(String),
    /// A value of the wrong kind was used; the text describes the expectation.
    WrongType(&'static str),
    /// An internal invariant was violated (e.g. an index outside the context).
    InvalidState,
    /// Failure reported by the underlying I/O.
    IOError(io::Error),
}

impl From<io::Error> for EvalErr {
    /// Wraps an I/O failure so `?` can be used on `io::Result` values.
    fn from(source: io::Error) -> EvalErr {
        EvalErr::IOError(source)
    }
}
241  
242  impl Term {
243      fn marshal_ast(ast: Ast, vars: &mut Vec<String>) -> Result<Self, EvalErr> {
244          match ast {
245              Ast::UnitLit => Ok(Self::Value(Rc::new(Val::Unit))),
246              Ast::StrLit(s) => Ok(Self::Value(Rc::new(Val::Str(s)))),
247              Ast::IntLit(i) => Ok(Self::Value(Rc::new(Val::Int(i)))),
248              Ast::Ident(x) => {
249                  for (i, y) in vars.iter().rev().enumerate() {
250                      if &x == y {
251                          return Ok(Self::Variable(i as u16));
252                      }
253                  }
254                  Err(EvalErr::UnknownVariable(x))
255              }
256              Ast::Lambda(var, expr) => {
257                  vars.push(var);
258                  let res = Self::marshal_ast(*expr, vars)?;
259                  vars.pop();
260                  Ok(Self::Lambda(Rc::new(res)))
261              }
262              Ast::Application(fun, arg) => {
263                  let fun = Self::marshal_ast(*fun, vars)?;
264                  let arg = Self::marshal_ast(*arg, vars)?;
265                  Ok(Self::Application(Rc::new(fun), Rc::new(arg)))
266              }
267          }
268      }
269  
270      pub fn from_ast_with_globals(expr: Ast, globals: &Globals) -> Result<Self, EvalErr> {
271          Self::marshal_ast(expr, &mut globals.names.clone())
272      }
273  
274      // pub fn eval(self) -> Result<Self, EvalErr> {
275      //     match self {
276      //         Expr::Value(val) => Ok(Expr::Value(val)),
277      //         Expr::IO(ioterm) => Ok(Expr::IO(match ioterm {
278      //             IOVal::Print(val) => IOVal::Print(val),
279      //             IOVal::Input => IOVal::Input,
280      //             IOVal::Bind(lhs, rhs) => IOVal::Bind(Box::new(lhs.eval()?), Box::new(rhs.eval()?)),
281      //         })),
282      //         Expr::Variable(i) => Ok(Expr::Variable(i)),
283      //         Expr::Lambda(terms, term) => Ok(Expr::Lambda(terms, Box::new(term.eval()?))),
284      //         Expr::Application(fun, arg) => Ok(Expr::Application(
285      //             Box::new(fun.eval()?),
286      //             Box::new(arg.eval()?),
287      //         )),
288      //     }
289      // }
290  
291      pub fn eval_with_ctx(&self, ctx: &mut Vec<Rc<Val>>) -> Result<Rc<Val>, EvalErr> {
292          match self {
293              Term::Value(val) => Ok(val.clone()),
294              Term::Variable(i) => ctx
295                  .get(ctx.len() - 1 - (*i as usize))
296                  .cloned()
297                  .ok_or(EvalErr::InvalidState),
298              Term::Lambda(body) => Ok(Rc::new(Val::Lambda(ctx.to_vec(), body.clone()))),
299              Term::Application(fun, arg) => {
300                  let fun = fun.eval_with_ctx(ctx)?;
301                  let arg = arg.eval_with_ctx(ctx)?;
302                  match Rc::unwrap_or_clone(fun) {
303                      Val::Lambda(mut fctx, body) => {
304                          fctx.push(arg);
305                          let res = body.eval_with_ctx(&mut fctx);
306                          fctx.pop();
307                          res
308                      }
309                      Val::Internal(_, f) => {
310                          ctx.push(arg);
311                          let res = f(ctx);
312                          ctx.pop();
313                          res
314                      }
315                      _ => Err(EvalErr::WrongType("trying to call a non-function")),
316                  }
317              }
318          }
319      }
320  
321      pub fn eval(&self, globals: &Globals) -> Result<Rc<Val>, EvalErr> {
322          self.eval_with_ctx(&mut globals.vals.clone())
323      }
324  }
325  
326  // impl TryFrom<Expr> for Term {
327  //     type Error = EvalErr;
328  
329  //     fn try_from(expr: Expr) -> Result<Self, Self::Error> {
330  //         Self::marshal_expr(expr, &mut Vec::new())
331  //     }
332  // }
333  
334  type InternalCtx<'e> = Option<Box<Value<'e>>>;
335  
336  type InternalFn<'e> = fn(InternalCtx<'e>, Value<'e>) -> Result<Value<'e>, EvalError<'e>>;
337  
338  #[derive(Debug, Clone)]
339  pub enum Value<'e> {
340      Unit,
341      Str(Cow<'e, str>),
342      Int(i64),
343      IO(Box<Value<'e>>),
344      Function(Rc<Context<'e>>, String, &'e Ast),
345      Internal(InternalCtx<'e>, InternalFn<'e>),
346  }
347  
348  // impl<'e> Value<'e> {
349  //     pub fn internal_of<T: Clone>(
350  //         ctx: Rc<T>,
351  //         f: fn(&T, Value<'e>) -> Result<Value<'e>, EvalError<'e>>,
352  //     ) -> Self {
353  //         let ctx = InternalCtx {
354  //             value: Rc::into_raw(ctx) as *const u8,
355  //             clone: |ctx| unsafe {
356  //                 Rc::increment_strong_count(ctx as *const T);
357  //                 ctx
358  //             },
359  //             drop: |ctx| unsafe {
360  //                 Rc::decrement_strong_count(ctx);
361  //             },
362  //         };
363  //         Self::Internal(ctx, |ctx, v| f(unsafe { &*(ctx.value as *const T) }, v))
364  //     }
365  // }
366  
367  // macro_rules! value_internal_of {
368  //     ($ctx:expr, $f:expr) => {{
369  //         let ctx = $crate::eval::InternalCtx {
370  //             value: Rc::into_raw($ctx) as *const u8,
371  //             clone: |ctx| unsafe {
372  //                 Rc::increment_strong_count(ctx as *const _);
373  //                 ctx
374  //             },
375  //             drop: |ctx| unsafe {
376  //                 Rc::decrement_strong_count(ctx);
377  //             },
378  //         };
379  //         $crate::eval::Value::Internal(ctx, |ctx, v| $f(unsafe { &*(ctx.value as *const _) }, v))
380  //     }};
381  // }
382  // pub(crate) use value_internal_of;
383  
384  #[derive(Debug, Clone)]
385  pub struct Context<'e> {
386      parent: Option<Rc<Context<'e>>>,
387      vars: HashMap<String, Value<'e>>,
388  }
389  
390  impl<'e> Context<'e> {
391      pub fn register(&mut self, name: impl Into<String>, f: InternalFn<'e>) -> Result<(), ()> {
392          if self
393              .vars
394              .insert(name.into(), Value::Internal(Default::default(), f))
395              .is_none()
396          {
397              Ok(())
398          } else {
399              Err(())
400          }
401      }
402  
403      pub fn with_value(self: Rc<Self>, v: String, val: Value<'e>) -> Rc<Context<'e>> {
404          Rc::new(Context {
405              parent: Some(self),
406              vars: [(v, val)].into(),
407          })
408      }
409  
410      pub fn get(&self, v: &str) -> Option<Value<'e>> {
411          self.vars
412              .get(v)
413              .cloned()
414              .or_else(|| self.parent.as_ref().and_then(|p| p.get(v)))
415      }
416  }
417  
418  impl<'e> Default for Context<'e> {
419      fn default() -> Self {
420          Self {
421              parent: None,
422              vars: Default::default(),
423          }
424      }
425  }
426  
427  #[derive(Debug)]
428  pub enum EvalError<'e> {
429      UnknownVariable(String),
430      CallingNonFunction(Value<'e>),
431      InvalidState,
432      InvalidArgument,
433      IOError(io::Error),
434  }
435  
436  impl<'e> From<io::Error> for EvalError<'e> {
437      fn from(err: io::Error) -> Self {
438          Self::IOError(err)
439      }
440  }
441  
442  pub fn eval<'e>(e: &'e Ast, ctx: Rc<Context<'e>>) -> Result<Value<'e>, EvalError<'e>> {
443      match e {
444          Ast::UnitLit => Ok(Value::Unit),
445          Ast::StrLit(s) => Ok(Value::Str(Cow::Borrowed(s))),
446          Ast::IntLit(i) => Ok(Value::Int(*i)),
447          Ast::Ident(v) => ctx
448              .get(v)
449              .ok_or_else(|| EvalError::UnknownVariable(v.clone())),
450          Ast::Lambda(var, expr) => Ok(Value::Function(ctx.clone(), var.clone(), expr)),
451          Ast::Application(fun, arg) => {
452              let fun = eval(fun, ctx.clone())?;
453              let arg = eval(arg, ctx.clone())?;
454              match fun {
455                  Value::Function(ctx, var, body) => eval(body, ctx.with_value(var, arg)),
456                  Value::Internal(ctx, f) => f(ctx, arg),
457                  _ => Err(EvalError::CallingNonFunction(fun)),
458              }
459          }
460      }
461  }