bin/app/src/expr/mod.rs
mod.rs
  1  /* This file is part of DarkFi (https://dark.fi)
  2   *
  3   * Copyright (C) 2020-2025 Dyne.org foundation
  4   *
  5   * This program is free software: you can redistribute it and/or modify
  6   * it under the terms of the GNU Affero General Public License as
  7   * published by the Free Software Foundation, either version 3 of the
  8   * License, or (at your option) any later version.
  9   *
 10   * This program is distributed in the hope that it will be useful,
 11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
 12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 13   * GNU Affero General Public License for more details.
 14   *
 15   * You should have received a copy of the GNU Affero General Public License
 16   * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 17   */
 18  
 19  use crate::{
 20      error::{Error, Result},
 21      //prop::{Property, PropertySubType, PropertyType, PropertySExprValue},
 22  };
 23  use darkfi_serial::{
 24      async_trait, Decodable, Encodable, FutAsyncWriteExt, ReadExt, SerialDecodable, SerialEncodable,
 25  };
 26  use std::io::{Read, Write};
 27  
 28  mod compile;
 29  pub use compile::Compiler;
 30  
 31  pub fn const_f32(x: f32) -> SExprCode {
 32      vec![Op::ConstFloat32(x)]
 33  }
 34  pub fn load_var<S: Into<String>>(var: S) -> SExprCode {
 35      vec![Op::LoadVar(var.into())]
 36  }
 37  
/// A runtime value produced by evaluating an [`Op`] in an [`SExprMachine`].
///
/// NOTE(review): the derived `SerialEncodable`/`SerialDecodable` encoding presumably depends
/// on variant order, so do not reorder or insert variants — confirm against `darkfi_serial`.
#[derive(Clone, Debug, PartialEq, SerialEncodable, SerialDecodable)]
pub enum SExprVal {
    /// Absence of a value; produced by an empty program, `Op::Null` and `Op::StoreVar`.
    Null,
    /// Boolean; produced by `Op::IsEqual`/`Op::LessThan` and required by `Op::IfElse`.
    Bool(bool),
    /// Unsigned 32-bit integer.
    Uint32(u32),
    /// 32-bit float. `Uint32` operands are coerced to this in mixed arithmetic and `Op::Div`.
    Float32(f32),
    /// Owned string. Not accepted by the numeric operations (they fail with a type error).
    Str(String),
}
 46  
 47  impl SExprVal {
 48      #[allow(dead_code)]
 49      fn is_null(&self) -> bool {
 50          match self {
 51              Self::Null => true,
 52              _ => false,
 53          }
 54      }
 55  
 56      #[allow(dead_code)]
 57      fn is_bool(&self) -> bool {
 58          match self {
 59              Self::Bool(_) => true,
 60              _ => false,
 61          }
 62      }
 63  
 64      fn is_u32(&self) -> bool {
 65          match self {
 66              Self::Uint32(_) => true,
 67              _ => false,
 68          }
 69      }
 70  
 71      #[allow(dead_code)]
 72      fn is_f32(&self) -> bool {
 73          match self {
 74              Self::Float32(_) => true,
 75              _ => false,
 76          }
 77      }
 78  
 79      #[allow(dead_code)]
 80      fn is_str(&self) -> bool {
 81          match self {
 82              Self::Str(_) => true,
 83              _ => false,
 84          }
 85      }
 86  
 87      #[allow(dead_code)]
 88      fn as_bool(&self) -> Result<bool> {
 89          match self {
 90              Self::Bool(v) => Ok(*v),
 91              _ => Err(Error::PropertyWrongType),
 92          }
 93      }
 94  
 95      pub fn as_u32(&self) -> Result<u32> {
 96          match self {
 97              Self::Uint32(v) => Ok(*v),
 98              _ => Err(Error::PropertyWrongType),
 99          }
100      }
101  
102      pub fn as_f32(&self) -> Result<f32> {
103          match self {
104              Self::Float32(v) => Ok(*v),
105              _ => Err(Error::PropertyWrongType),
106          }
107      }
108  
109      #[allow(dead_code)]
110      fn as_str(&self) -> Result<String> {
111          match self {
112              Self::Str(v) => Ok(v.clone()),
113              _ => Err(Error::PropertyWrongType),
114          }
115      }
116  
117      pub fn coerce_f32(&self) -> Result<f32> {
118          match self {
119              Self::Uint32(v) => Ok(*v as f32),
120              Self::Float32(v) => Ok(*v),
121              _ => Err(Error::PropertyWrongType),
122          }
123      }
124  }
125  
/// Interpreter state for running an [`SExprCode`] program.
#[derive(Debug, PartialEq)]
pub struct SExprMachine<'a> {
    /// Variables visible to the program as `(name, value)` pairs. `LoadVar` searches this
    /// list by name and `StoreVar` writes into it.
    pub globals: Vec<(String, SExprVal)>,
    /// The statements to run; `call` returns the value of the last one.
    pub stmts: &'a SExprCode,
}
131  
/// A program: a sequence of statements run in order by `SExprMachine::call`, whose result is
/// the value of the last statement.
pub type SExprCode = Vec<Op>;
134  
/// One node of an expression tree evaluated by `SExprMachine`.
///
/// Serialized with an explicit one-byte tag per variant by the `Encodable`/`Decodable` impls
/// in this file: `Null` = 0 through `IfElse` = 16, in declaration order.
#[derive(Clone, Debug, PartialEq)]
pub enum Op {
    /// Evaluates to `SExprVal::Null`.
    Null,
    /// `lhs + rhs`. `Uint32` operands combine as integers where possible; otherwise both
    /// operands are coerced to `f32`.
    Add((Box<Op>, Box<Op>)),
    /// `lhs - rhs`, with the same typing rules as `Add`.
    Sub((Box<Op>, Box<Op>)),
    /// `lhs * rhs`, with the same typing rules as `Add`.
    Mul((Box<Op>, Box<Op>)),
    /// `lhs / rhs`; both operands are always coerced to `f32`, so the result is `Float32`.
    Div((Box<Op>, Box<Op>)),
    /// Boolean literal.
    ConstBool(bool),
    /// Unsigned integer literal.
    ConstUint32(u32),
    /// Float literal.
    ConstFloat32(f32),
    /// String literal.
    ConstStr(String),
    /// Reads a variable from the machine's globals; fails if no such name exists.
    LoadVar(String),
    /// Evaluates the expression, stores it under the name, and evaluates to `Null`.
    StoreVar((String, Box<Op>)),
    /// Smaller of two numbers (`Uint32` if both operands are, `Float32` otherwise).
    Min((Box<Op>, Box<Op>)),
    /// Larger of two numbers (`Uint32` if both operands are, `Float32` otherwise).
    Max((Box<Op>, Box<Op>)),
    /// Numeric equality as `Bool`: exact for two `Uint32`, within `f32::EPSILON` otherwise.
    IsEqual((Box<Op>, Box<Op>)),
    /// Numeric `lhs < rhs` as `Bool`.
    LessThan((Box<Op>, Box<Op>)),
    /// Converts a `Float32` to `Uint32` with an `as` cast; `Uint32` passes through unchanged.
    Float32ToUint32(Box<Op>),
    /// `(condition, then-branch, else-branch)`. The condition must evaluate to `Bool`; the
    /// chosen branch runs as its own program over a copy of the globals, and its last value
    /// is the result.
    IfElse((Box<Op>, SExprCode, SExprCode)),
}
155  
156  impl<'a> SExprMachine<'a> {
157      pub fn call(&mut self) -> Result<SExprVal> {
158          if self.stmts.is_empty() {
159              return Ok(SExprVal::Null)
160          }
161          for i in 0..(self.stmts.len() - 1) {
162              self.eval(&self.stmts[i])?;
163          }
164          self.eval(self.stmts.last().unwrap())
165      }
166  
167      fn eval(&mut self, op: &Op) -> Result<SExprVal> {
168          match op {
169              Op::Null => Ok(SExprVal::Null),
170              Op::Add((lhs, rhs)) => self.add(lhs, rhs),
171              Op::Sub((lhs, rhs)) => self.sub(lhs, rhs),
172              Op::Mul((lhs, rhs)) => self.mul(lhs, rhs),
173              Op::Div((lhs, rhs)) => self.div(lhs, rhs),
174              Op::ConstBool(val) => Ok(SExprVal::Bool(*val)),
175              Op::ConstUint32(val) => Ok(SExprVal::Uint32(*val)),
176              Op::ConstFloat32(val) => Ok(SExprVal::Float32(*val)),
177              Op::ConstStr(val) => Ok(SExprVal::Str(val.clone())),
178              Op::LoadVar(var) => self.load_var(var),
179              Op::StoreVar((var, val)) => self.store_var(var, val),
180              Op::Min((lhs, rhs)) => self.min(lhs, rhs),
181              Op::Max((lhs, rhs)) => self.max(lhs, rhs),
182              Op::IsEqual((lhs, rhs)) => self.is_equal(lhs, rhs),
183              Op::LessThan((lhs, rhs)) => self.less_than(lhs, rhs),
184              Op::Float32ToUint32(val) => self.float32_to_uint32(val),
185              Op::IfElse((cond, if_val, else_val)) => self.if_else(cond, if_val, else_val),
186          }
187      }
188  
189      fn add(&mut self, lhs: &Op, rhs: &Op) -> Result<SExprVal> {
190          let lhs = self.eval(lhs)?;
191          let rhs = self.eval(rhs)?;
192  
193          if lhs.is_u32() && rhs.is_u32() {
194              return Ok(SExprVal::Uint32(lhs.as_u32().unwrap() + rhs.as_u32().unwrap()))
195          }
196  
197          let lhs = lhs.coerce_f32()?;
198          let rhs = rhs.coerce_f32()?;
199  
200          Ok(SExprVal::Float32(lhs + rhs))
201      }
202      fn sub(&mut self, lhs: &Op, rhs: &Op) -> Result<SExprVal> {
203          let lhs = self.eval(lhs)?;
204          let rhs = self.eval(rhs)?;
205  
206          if lhs.is_u32() && rhs.is_u32() {
207              return Ok(SExprVal::Uint32(lhs.as_u32().unwrap() - rhs.as_u32().unwrap()))
208          }
209  
210          let lhs = lhs.coerce_f32()?;
211          let rhs = rhs.coerce_f32()?;
212  
213          Ok(SExprVal::Float32(lhs - rhs))
214      }
215      fn mul(&mut self, lhs: &Op, rhs: &Op) -> Result<SExprVal> {
216          let lhs = self.eval(lhs)?;
217          let rhs = self.eval(rhs)?;
218  
219          if lhs.is_u32() && rhs.is_u32() {
220              return Ok(SExprVal::Uint32(lhs.as_u32().unwrap() * rhs.as_u32().unwrap()))
221          }
222  
223          let lhs = lhs.coerce_f32()?;
224          let rhs = rhs.coerce_f32()?;
225  
226          Ok(SExprVal::Float32(lhs * rhs))
227      }
228      fn div(&mut self, lhs: &Op, rhs: &Op) -> Result<SExprVal> {
229          let lhs = self.eval(lhs)?;
230          let rhs = self.eval(rhs)?;
231  
232          // Always coerce
233  
234          let lhs = lhs.coerce_f32()?;
235          let rhs = rhs.coerce_f32()?;
236  
237          Ok(SExprVal::Float32(lhs / rhs))
238      }
239      fn load_var(&self, var: &str) -> Result<SExprVal> {
240          for (name, val) in &self.globals {
241              if name == var {
242                  return Ok(val.clone())
243              }
244          }
245          Err(Error::SExprGlobalNotFound)
246      }
247      fn store_var(&mut self, var: &str, val: &Op) -> Result<SExprVal> {
248          let val = self.eval(val)?;
249          self.globals.push((var.to_string(), val));
250          Ok(SExprVal::Null)
251      }
252      fn min(&mut self, lhs: &Op, rhs: &Op) -> Result<SExprVal> {
253          let lhs = self.eval(lhs)?;
254          let rhs = self.eval(rhs)?;
255  
256          if lhs.is_u32() && rhs.is_u32() {
257              let lhs = lhs.as_u32().unwrap();
258              let rhs = rhs.as_u32().unwrap();
259              let min = if lhs < rhs { lhs } else { rhs };
260              return Ok(SExprVal::Uint32(min))
261          }
262  
263          let lhs = lhs.coerce_f32()?;
264          let rhs = rhs.coerce_f32()?;
265          let min = if lhs < rhs { lhs } else { rhs };
266  
267          Ok(SExprVal::Float32(min))
268      }
269      fn max(&mut self, lhs: &Op, rhs: &Op) -> Result<SExprVal> {
270          let lhs = self.eval(lhs)?;
271          let rhs = self.eval(rhs)?;
272  
273          if lhs.is_u32() && rhs.is_u32() {
274              let lhs = lhs.as_u32().unwrap();
275              let rhs = rhs.as_u32().unwrap();
276              let max = if lhs > rhs { lhs } else { rhs };
277              return Ok(SExprVal::Uint32(max))
278          }
279  
280          let lhs = lhs.coerce_f32()?;
281          let rhs = rhs.coerce_f32()?;
282          let max = if lhs > rhs { lhs } else { rhs };
283  
284          Ok(SExprVal::Float32(max))
285      }
286      fn is_equal(&mut self, lhs: &Op, rhs: &Op) -> Result<SExprVal> {
287          let lhs = self.eval(lhs)?;
288          let rhs = self.eval(rhs)?;
289  
290          if lhs.is_u32() && rhs.is_u32() {
291              return Ok(SExprVal::Bool(lhs.as_u32().unwrap() == rhs.as_u32().unwrap()))
292          }
293  
294          let lhs = lhs.coerce_f32()?;
295          let rhs = rhs.coerce_f32()?;
296          let is_equal = (lhs - rhs).abs() < f32::EPSILON;
297  
298          Ok(SExprVal::Bool(is_equal))
299      }
300      fn less_than(&mut self, lhs: &Op, rhs: &Op) -> Result<SExprVal> {
301          let lhs = self.eval(lhs)?;
302          let rhs = self.eval(rhs)?;
303  
304          if lhs.is_u32() && rhs.is_u32() {
305              return Ok(SExprVal::Bool(lhs.as_u32().unwrap() < rhs.as_u32().unwrap()))
306          }
307  
308          let lhs = lhs.coerce_f32()?;
309          let rhs = rhs.coerce_f32()?;
310  
311          Ok(SExprVal::Bool(lhs < rhs))
312      }
313      fn float32_to_uint32(&mut self, val: &Op) -> Result<SExprVal> {
314          let val = self.eval(val)?;
315          if val.is_u32() {
316              return Ok(SExprVal::Uint32(val.as_u32()?))
317          }
318          Ok(SExprVal::Uint32(val.as_f32()? as u32))
319      }
320      fn if_else(&mut self, cond: &Op, if_val: &SExprCode, else_val: &SExprCode) -> Result<SExprVal> {
321          let cond = self.eval(cond)?;
322          let cond = cond.as_bool()?;
323  
324          if cond {
325              let mut machine = SExprMachine { globals: self.globals.clone(), stmts: if_val };
326              machine.call()
327          } else {
328              let mut machine = SExprMachine { globals: self.globals.clone(), stmts: else_val };
329              machine.call()
330          }
331      }
332  }
333  
334  impl Encodable for Op {
335      fn encode<S: Write>(&self, s: &mut S) -> std::result::Result<usize, std::io::Error> {
336          let mut len = 0;
337          match self {
338              Self::Null => {
339                  len += 0u8.encode(s)?;
340              }
341              Self::Add((lhs, rhs)) => {
342                  len += 1u8.encode(s)?;
343                  len += lhs.encode(s)?;
344                  len += rhs.encode(s)?;
345              }
346              Self::Sub((lhs, rhs)) => {
347                  len += 2u8.encode(s)?;
348                  len += lhs.encode(s)?;
349                  len += rhs.encode(s)?;
350              }
351              Self::Mul((lhs, rhs)) => {
352                  len += 3u8.encode(s)?;
353                  len += lhs.encode(s)?;
354                  len += rhs.encode(s)?;
355              }
356              Self::Div((lhs, rhs)) => {
357                  len += 4u8.encode(s)?;
358                  len += lhs.encode(s)?;
359                  len += rhs.encode(s)?;
360              }
361              Self::ConstBool(val) => {
362                  len += 5u8.encode(s)?;
363                  len += val.encode(s)?;
364              }
365              Self::ConstUint32(val) => {
366                  len += 6u8.encode(s)?;
367                  len += val.encode(s)?;
368              }
369              Self::ConstFloat32(val) => {
370                  len += 7u8.encode(s)?;
371                  len += val.encode(s)?;
372              }
373              Self::ConstStr(val) => {
374                  len += 8u8.encode(s)?;
375                  len += val.encode(s)?;
376              }
377              Self::LoadVar(var) => {
378                  len += 9u8.encode(s)?;
379                  len += var.encode(s)?;
380              }
381              Self::StoreVar((var, val)) => {
382                  len += 10u8.encode(s)?;
383                  len += var.encode(s)?;
384                  len += val.encode(s)?;
385              }
386              Self::Min((lhs, rhs)) => {
387                  len += 11u8.encode(s)?;
388                  len += lhs.encode(s)?;
389                  len += rhs.encode(s)?;
390              }
391              Self::Max((lhs, rhs)) => {
392                  len += 12u8.encode(s)?;
393                  len += lhs.encode(s)?;
394                  len += rhs.encode(s)?;
395              }
396              Self::IsEqual((lhs, rhs)) => {
397                  len += 13u8.encode(s)?;
398                  len += lhs.encode(s)?;
399                  len += rhs.encode(s)?;
400              }
401              Self::LessThan((lhs, rhs)) => {
402                  len += 14u8.encode(s)?;
403                  len += lhs.encode(s)?;
404                  len += rhs.encode(s)?;
405              }
406              Self::Float32ToUint32(val) => {
407                  len += 15u8.encode(s)?;
408                  len += val.encode(s)?;
409              }
410              Self::IfElse((cond, if_val, else_val)) => {
411                  len += 16u8.encode(s)?;
412                  len += cond.encode(s)?;
413                  len += if_val.encode(s)?;
414                  len += else_val.encode(s)?;
415              }
416          }
417          Ok(len)
418      }
419  }
420  
421  impl Decodable for Op {
422      fn decode<D: Read>(d: &mut D) -> std::result::Result<Self, std::io::Error> {
423          let op_type = d.read_u8()?;
424          let self_ = match op_type {
425              0 => Self::Null,
426              1 => Self::Add((Box::new(Self::decode(d)?), Box::new(Self::decode(d)?))),
427              2 => Self::Sub((Box::new(Self::decode(d)?), Box::new(Self::decode(d)?))),
428              3 => Self::Mul((Box::new(Self::decode(d)?), Box::new(Self::decode(d)?))),
429              4 => Self::Div((Box::new(Self::decode(d)?), Box::new(Self::decode(d)?))),
430              5 => Self::ConstBool(d.read_bool()?),
431              6 => Self::ConstUint32(d.read_u32()?),
432              7 => Self::ConstFloat32(d.read_f32()?),
433              8 => Self::ConstStr(String::decode(d)?),
434              9 => Self::LoadVar(String::decode(d)?),
435              10 => Self::StoreVar((String::decode(d)?, Box::new(Self::decode(d)?))),
436              11 => Self::Min((Box::new(Self::decode(d)?), Box::new(Self::decode(d)?))),
437              12 => Self::Max((Box::new(Self::decode(d)?), Box::new(Self::decode(d)?))),
438              13 => Self::IsEqual((Box::new(Self::decode(d)?), Box::new(Self::decode(d)?))),
439              14 => Self::LessThan((Box::new(Self::decode(d)?), Box::new(Self::decode(d)?))),
440              15 => Self::Float32ToUint32(Box::new(Self::decode(d)?)),
441              16 => Self::IfElse((
442                  Box::new(Self::decode(d)?),
443                  Decodable::decode(d)?,
444                  Decodable::decode(d)?,
445              )),
446              _ => return Err(std::io::Error::new(std::io::ErrorKind::Other, "Invalid Op type")),
447          };
448          Ok(self_)
449      }
450  }
451  
#[cfg(test)]
mod tests {
    use super::*;
    use darkfi_serial::{deserialize, serialize};

    /// Builds the expression `5 + sw / 2`, shared by the evaluation and round-trip tests.
    fn five_plus_half_sw() -> Op {
        let half_sw =
            Op::Div((Box::new(Op::LoadVar("sw".to_string())), Box::new(Op::ConstUint32(2))));
        Op::Add((Box::new(Op::ConstUint32(5)), Box::new(half_sw)))
    }

    #[test]
    fn seval() {
        let program: SExprCode = vec![five_plus_half_sw()];
        let globals = vec![
            ("sw".to_string(), SExprVal::Uint32(110u32)),
            ("sh".to_string(), SExprVal::Uint32(4u32)),
        ];
        let mut machine = SExprMachine { globals, stmts: &program };
        // `Div` always yields a float, so 5 + 110 / 2 comes back as Float32(60.0).
        assert_eq!(machine.call().unwrap(), SExprVal::Float32(60.));
    }

    #[test]
    fn encdec_code() {
        let code = five_plus_half_sw();
        let bytes = serialize(&code);
        let decoded: Op = deserialize(&bytes).unwrap();
        assert_eq!(code, decoded);
    }

    #[test]
    fn if_store() {
        // s = if false { 4 } else { 110 }; s
        let pick = Op::IfElse((
            Box::new(Op::ConstBool(false)),
            vec![Op::ConstUint32(4)],
            vec![Op::ConstUint32(110)],
        ));
        let code = vec![
            Op::StoreVar(("s".to_string(), Box::new(pick))),
            Op::LoadVar("s".to_string()),
        ];
        let mut machine = SExprMachine { globals: Vec::new(), stmts: &code };
        assert_eq!(machine.call().unwrap(), SExprVal::Uint32(110));
    }
}