/ src / lightspeed / protocol.gleam
protocol.gleam
//// Versioned protocol frame model and its pipe-delimited, backslash-escaped text encoding.
  2  
  3  import gleam/int
  4  import gleam/list
  5  import gleam/string
  6  
/// Protocol name carried in `Hello` frames; the decoder rejects a `Hello`
/// that names any other protocol.
pub const protocol_name = "lightspeed"

/// Protocol version carried in `Hello` frames; the decoder rejects a `Hello`
/// announcing any other version.
pub const protocol_version = 1
 10  
/// Protocol frame exchanged between client and server.
///
/// `encode` writes each variant as its lowercase tag followed by its fields
/// in declaration order, separated by `|`.
pub type Frame {
  /// Handshake frame naming the protocol and its integer version.
  Hello(protocol: String, version: Int)
  /// Event frame: a reference, an event name and a string payload.
  Event(ref: String, name: String, payload: String)
  /// Diff frame: a reference and an `html` string.
  Diff(ref: String, html: String)
  /// Acknowledgement frame carrying only a reference.
  Ack(ref: String)
  /// Failure frame: a reference and a textual reason.
  Failure(ref: String, reason: String)
}
 19  
/// Protocol decode errors.
pub type DecodeError {
  /// The payload was the empty string.
  EmptyFrame
  /// The first field is not one of the known frame tags; carries that field.
  UnknownFrameTag(String)
  /// A known tag arrived with the wrong number of fields (tag included).
  BadFieldCount(tag: String, expected: Int, actual: Int)
  /// The `hello` version field could not be parsed as an Int; carries the raw text.
  InvalidVersion(String)
  /// The `hello` version parsed but differs from `protocol_version`.
  UnsupportedVersion(Int)
  /// The `hello` protocol name differs from `protocol_name`.
  UnsupportedProtocol(String)
  /// The payload ended with an unterminated `\` escape.
  InvalidEscapeSequence
}
 30  
 31  /// Construct a protocol hello frame.
 32  pub fn hello() -> Frame {
 33    Hello(protocol: protocol_name, version: protocol_version)
 34  }
 35  
 36  /// Return the frame reference when one exists.
 37  pub fn ref(frame: Frame) -> String {
 38    case frame {
 39      Hello(_, _) -> ""
 40      Event(ref, _, _) -> ref
 41      Diff(ref, _) -> ref
 42      Ack(ref) -> ref
 43      Failure(ref, _) -> ref
 44    }
 45  }
 46  
 47  /// True when the frame is part of the current protocol.
 48  pub fn is_current_hello(frame: Frame) -> Bool {
 49    case frame {
 50      Hello(protocol, version) ->
 51        protocol == protocol_name && version == protocol_version
 52      _ -> False
 53    }
 54  }
 55  
 56  /// Encode a frame into a transport-safe textual format.
 57  pub fn encode(frame: Frame) -> String {
 58    case frame {
 59      Hello(protocol, version) ->
 60        join_fields(["hello", protocol, int.to_string(version)])
 61      Event(ref, name, payload) -> join_fields(["event", ref, name, payload])
 62      Diff(ref, html) -> join_fields(["diff", ref, html])
 63      Ack(ref) -> join_fields(["ack", ref])
 64      Failure(ref, reason) -> join_fields(["failure", ref, reason])
 65    }
 66  }
 67  
 68  /// Decode a textual frame payload.
 69  pub fn decode(payload: String) -> Result(Frame, DecodeError) {
 70    case payload {
 71      "" -> Error(EmptyFrame)
 72      _ ->
 73        case split_fields(payload) {
 74          Error(error) -> Error(error)
 75          Ok(fields) -> decode_fields(fields)
 76        }
 77    }
 78  }
 79  
 80  /// Convert decode errors to stable strings for logs and adapter errors.
 81  pub fn decode_error_to_string(error: DecodeError) -> String {
 82    case error {
 83      EmptyFrame -> "empty_frame"
 84      UnknownFrameTag(tag) -> "unknown_frame_tag:" <> tag
 85      BadFieldCount(tag, expected, actual) ->
 86        "bad_field_count:"
 87        <> tag
 88        <> ":"
 89        <> int.to_string(expected)
 90        <> ":"
 91        <> int.to_string(actual)
 92      InvalidVersion(version) -> "invalid_version:" <> version
 93      UnsupportedVersion(version) ->
 94        "unsupported_version:" <> int.to_string(version)
 95      UnsupportedProtocol(protocol) -> "unsupported_protocol:" <> protocol
 96      InvalidEscapeSequence -> "invalid_escape_sequence"
 97    }
 98  }
 99  
/// Dispatch an already-split field list to the decoder for its leading tag.
///
/// An unrecognised tag yields `UnknownFrameTag`; an empty list yields
/// `EmptyFrame`.
fn decode_fields(fields: List(String)) -> Result(Frame, DecodeError) {
  case fields {
    [] -> Error(EmptyFrame)
    ["hello", ..] -> decode_hello(fields)
    ["event", ..] -> decode_event(fields)
    ["diff", ..] -> decode_diff(fields)
    ["ack", ..] -> decode_ack(fields)
    ["failure", ..] -> decode_failure(fields)
    [unknown, ..] -> Error(UnknownFrameTag(unknown))
  }
}
114  
/// Decode `hello|<protocol>|<version>`, accepting only the current protocol.
///
/// The checks run in this order:
/// 1. The version text must parse as an Int (`InvalidVersion`).
/// 2. The protocol must equal `protocol_name` (`UnsupportedProtocol`).
/// 3. The version must equal `protocol_version` (`UnsupportedVersion`).
fn decode_hello(fields: List(String)) -> Result(Frame, DecodeError) {
  case fields {
    ["hello", name, raw_version] ->
      case int.parse(raw_version) {
        Error(_) -> Error(InvalidVersion(raw_version))
        Ok(number) ->
          case name == protocol_name, number == protocol_version {
            False, _ -> Error(UnsupportedProtocol(name))
            True, False -> Error(UnsupportedVersion(number))
            True, True -> Ok(Hello(protocol: name, version: number))
          }
      }
    _ -> bad_field_count("hello", 3, fields)
  }
}
133  
/// Decode `event|<ref>|<name>|<payload>`; any other shape is a field-count error.
fn decode_event(fields: List(String)) -> Result(Frame, DecodeError) {
  case fields {
    ["event", event_ref, event_name, event_payload] ->
      Event(ref: event_ref, name: event_name, payload: event_payload)
      |> Ok
    _ -> bad_field_count("event", 4, fields)
  }
}
141  
/// Decode `diff|<ref>|<html>`; any other shape is a field-count error.
fn decode_diff(fields: List(String)) -> Result(Frame, DecodeError) {
  case fields {
    ["diff", diff_ref, markup] -> Diff(ref: diff_ref, html: markup) |> Ok
    _ -> bad_field_count("diff", 3, fields)
  }
}
148  
/// Decode `ack|<ref>`; any other shape is a field-count error.
fn decode_ack(fields: List(String)) -> Result(Frame, DecodeError) {
  case fields {
    ["ack", acked] -> Ack(ref: acked) |> Ok
    _ -> bad_field_count("ack", 2, fields)
  }
}
155  
/// Decode `failure|<ref>|<reason>`; any other shape is a field-count error.
fn decode_failure(fields: List(String)) -> Result(Frame, DecodeError) {
  case fields {
    ["failure", failed_ref, why] -> Failure(ref: failed_ref, reason: why) |> Ok
    _ -> bad_field_count("failure", 3, fields)
  }
}
162  
/// Build a `BadFieldCount` error for `tag`.
///
/// The error records `expected` and the number of fields actually received
/// (the tag itself included).
fn bad_field_count(
  tag: String,
  expected: Int,
  fields: List(String),
) -> Result(Frame, DecodeError) {
  let received = list.length(fields)
  BadFieldCount(tag: tag, expected: expected, actual: received)
  |> Error
}
170  
/// Escape every field and glue the results together with `|` separators.
///
/// An empty list encodes to the empty string.
fn join_fields(fields: List(String)) -> String {
  fields
  |> list.map(escape_field)
  |> string.join("|")
}
184  
/// Escape one field so it can be embedded in an encoded frame.
///
/// Backslashes become `\\` and pipes become `\|`. The scan is per Unicode
/// codepoint, matching `split_fields`, so a `\` or `|` is escaped even when a
/// combining mark follows it in the same grapheme cluster.
fn escape_field(value: String) -> String {
  escape_chars(codepoint_strings(value), "")
}

/// Append `chars` to `acc`, prefixing each `\` and `|` with a backslash.
fn escape_chars(chars: List(String), acc: String) -> String {
  case chars {
    [] -> acc
    [char, ..rest] ->
      case char {
        "\\" -> escape_chars(rest, acc <> "\\\\")
        "|" -> escape_chars(rest, acc <> "\\|")
        _ -> escape_chars(rest, acc <> char)
      }
  }
}

/// Split an encoded payload on unescaped `|` separators, unescaping fields.
///
/// Works per codepoint rather than per grapheme cluster. With grapheme
/// splitting, a separator followed by a field that starts with an extending
/// codepoint (e.g. U+0301) fuses into a single cluster such as "|\u{0301}",
/// and the separator is silently lost.
///
/// Returns `Error(InvalidEscapeSequence)` when the payload ends with an
/// unterminated `\`. An empty payload yields a single empty field.
fn split_fields(payload: String) -> Result(List(String), DecodeError) {
  split_chars(codepoint_strings(payload), "", [], False)
}

/// Break `value` into a list of single-codepoint strings.
fn codepoint_strings(value: String) -> List(String) {
  value
  |> string.to_utf_codepoints
  |> list.map(fn(codepoint) { string.from_utf_codepoints([codepoint]) })
}

/// Scan `chars`, building the field currently being read in `current`.
///
/// Completed fields are kept in `fields_rev` in reverse order. `escaped` is
/// True when the previous character was an unconsumed `\`, in which case the
/// next character is taken literally.
fn split_chars(
  chars: List(String),
  current: String,
  fields_rev: List(String),
  escaped: Bool,
) -> Result(List(String), DecodeError) {
  case chars {
    [] ->
      case escaped {
        // A trailing lone backslash has nothing left to escape.
        True -> Error(InvalidEscapeSequence)
        False -> Ok(list.reverse([current, ..fields_rev]))
      }

    [char, ..rest] ->
      case escaped {
        True -> split_chars(rest, current <> char, fields_rev, False)

        False ->
          case char {
            "\\" -> split_chars(rest, current, fields_rev, True)
            "|" -> split_chars(rest, "", [current, ..fields_rev], False)
            _ -> split_chars(rest, current <> char, fields_rev, False)
          }
      }
  }
}