/ src / news.nim
news.nim
  1  import std/[
  2    base64, deques, httpcore, nativesockets, net, oids, random, sha1, streams,
  3    strformat, strtabs, strutils, uri]
  4  
when not declaredInScope(newsUseChronos):
  # Backend selection. Unless the includer already declared
  # `newsUseChronos`, this module uses std/asyncdispatch.
  # chronos support is second class: to use this library from a
  # chronos-based project, include this file instead of importing it:
  #   const newsUseChronos = true
  #   include news
  const newsUseChronos = false
 11  
type
  WebSocketError* = object of CatchableError ## Base error raised by this module.
  WebSocketClosedError* = object of WebSocketError ## Raised when the connection is (or became) closed.
 15  
 16  when newsUseChronos:
 17    import chronos, chronos/streams/[asyncstream, tlsstream]
 18  
 19    type Transport = object
 20      transp: StreamTransport
 21      reader: AsyncStreamReader
 22      writer: AsyncStreamWriter
 23  
 24    proc send(s: Transport, data: string) {.async.} =
 25      # echo "sending: ", data.len
 26      if s.writer == nil:
 27        raise newException(WebSocketClosedError, "WebSocket is closed")
 28      await s.writer.write(data)
 29  
 30    proc recv(s: Transport, len: int): Future[string] {.async.} =
 31      var res = newString(len)
 32      if len != 0:
 33        # echo "receiving: ", len
 34        if s.reader == nil:
 35          raise newException(WebSocketClosedError, "WebSocket is closed")
 36        await s.reader.readExactly(addr res[0], len)
 37      return res
 38  
 39    proc isClosed(transp: Transport): bool {.inline.} =
 40      (transp.reader == nil and transp.writer == nil) or
 41      (transp.reader.closed or transp.writer.closed)
 42  
 43    proc close(transp: var Transport) =
 44      if transp.reader != nil:
 45        transp.reader.close()
 46        transp.reader = nil
 47      if transp.writer != nil:
 48        transp.writer.close()
 49        transp.writer = nil
 50      transp.transp.close()
 51      transp.transp = nil
 52  
 53    proc closeWait(transp: var Transport): Future[void] =
 54      if transp.reader != nil:
 55        transp.reader.close()
 56        transp.reader = nil
 57      if transp.writer != nil:
 58        transp.writer.close()
 59        transp.writer = nil
 60      let t = transp.transp
 61      transp.transp = nil
 62      t.closeWait()
 63  
 64  else:
 65    import std/[asyncdispatch, asynchttpserver, asyncnet]
 66    type Transport = AsyncSocket
 67  
 68  const CRLF = "\c\l"
 69  const GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
 70  
type
  Opcode* = enum
    ## 4 bits. Defines the interpretation of the "Payload data".
    Cont = 0x0 ## denotes a continuation frame
    Text = 0x1 ## denotes a text frame
    Binary = 0x2 ## denotes a binary frame
    # 3-7 are reserved for further non-control frames
    Close = 0x8 ## denotes a connection close
    Ping = 0x9 ## denotes a ping
    Pong = 0xa ## denotes a pong
    # B-F are reserved for further control frames

  #[
   Frame layout (RFC 6455 section 5.2):

   0                   1                   2                   3
   0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
  +-+-+-+-+-------+-+-------------+-------------------------------+
  |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
  |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
  |N|V|V|V|       |S|             |   (if payload len==126/127)   |
  | |1|2|3|       |K|             |                               |
  +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
  |     Extended payload length continued, if payload len == 127  |
  + - - - - - - - - - - - - - - - +-------------------------------+
  |                               |Masking-key, if MASK set to 1  |
  +-------------------------------+-------------------------------+
  | Masking-key (continued)       |          Payload Data         |
  +-------------------------------- - - - - - - - - - - - - - - - +
  :                     Payload Data continued ...                :
  + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
  |                     Payload Data continued ...                |
  +---------------------------------------------------------------+
  ]#
  Frame* = tuple
    fin: bool ## Indicates that this is the final fragment in a message.
    rsv1: bool ## MUST be 0 unless an extension defining its meaning was negotiated
    rsv2: bool ## same rule as rsv1
    rsv3: bool ## same rule as rsv1
    opcode: Opcode ## Defines the interpretation of the "Payload data".
    mask: bool ## Defines whether the "Payload data" is masked.
    data: string ## Payload data (unmasked once received)

  Packet* = object ## A received message or control packet.
    case kind*: Opcode
    of Text, Binary:
      data*: string ## full message payload (continuation frames joined)
    else:
      # control packets carry no payload here
      discard

  ReadyState* = enum
    Connecting = 0 ## The connection is not yet open.
    Open = 1 ## The connection is open and ready to communicate.
    Closing = 2 ## The connection is in the process of closing.
    Closed = 3 ## The connection is closed or couldn't be opened.

  WebSocket* = ref object ## State of one client- or server-side connection.
    transp*: Transport ## AsyncSocket, or the chronos stream wrapper
    version*: int ## client's Sec-WebSocket-Version (set on the server side)
    key*: string ## client's Sec-WebSocket-Key (set on the server side)
    protocol*: string ## requested Sec-WebSocket-Protocol, "" if none
    readyState*: ReadyState
    maskFrames*: bool ## mask outgoing frames (enabled for client connections)
    sendFut: Future[void] ## send currently being written, nil when idle
    sendQueue: Deque[tuple[text: string, opcode: Opcode, fut: Future[void]]] ## sends waiting for `sendFut`, FIFO
134  
template `[]`(value: uint8, index: int): bool =
  ## Tests bit `index` of `value`, counting from the most significant bit:
  ## index 0 is 0x80 and index 7 is 0x01 (the order of the RFC 6455 diagram).
  ((value shr (7 - index)) and 1'u8) == 1'u8
138  
139  
proc nibbleFromChar(c: char): int =
  ## Maps a hexadecimal digit (upper- or lowercase) to its value 0..15.
  ## Any other character maps to the sentinel value 255.
  result = 255
  if c in {'0'..'9'}:
    result = ord(c) - ord('0')
  elif c in {'a'..'f'}:
    result = ord(c) - ord('a') + 10
  elif c in {'A'..'F'}:
    result = ord(c) - ord('A') + 10
147  
148  
proc nibbleToChar(value: int): char =
  ## Converts a nibble value to its lowercase hex digit:
  ## 0 -> `0`, 9 -> `9`, 10 -> `a`, 15 -> `f`.
  if value in 0..9:
    char(ord('0') + value)
  else:
    char(ord('a') + value - 10)
154  
155  
proc decodeBase16*(str: string): string =
  ## Decodes a base16 (hexadecimal) string into raw bytes. Digits may be
  ## upper- or lowercase; each pair of digits yields one byte, and a
  ## trailing unpaired digit is ignored.
  ##
  ## Raises `ValueError` if a digit is not a valid hex character.
  result = newString(str.len div 2)
  for i in 0 ..< result.len:
    let
      hi = nibbleFromChar(str[2 * i])
      lo = nibbleFromChar(str[2 * i + 1])
    # nibbleFromChar signals an invalid digit with an out-of-range value;
    # without this check `chr` raised a RangeDefect (or produced a garbage
    # byte with range checks disabled).
    if hi notin 0..15 or lo notin 0..15:
      raise newException(ValueError,
        "invalid hex digit in: " & str[2 * i .. 2 * i + 1])
    result[i] = chr((hi shl 4) or lo)
163  
164  
proc encodeBase16*(str: string): string =
  ## Encodes every byte of `str` as two lowercase hex digits, high nibble
  ## first, so the result is exactly twice as long as the input.
  const hexDigits = "0123456789abcdef"
  result = newStringOfCap(str.len * 2)
  for ch in str:
    let b = ord(ch)
    result.add hexDigits[b shr 4]
    result.add hexDigits[b and 0x0f]
171  
172  
proc genMaskKey*(): array[4, char] =
  ## Returns a 4-byte frame masking key; each byte is `rand(255)` drawn
  ## from the std/random module-level generator.
  ## NOTE(review): RFC 6455 §10.3 wants unpredictable masking keys; this is
  ## only as unpredictable as the generator's seed -- consider std/sysrand.
  for slot in result.mitems:
    slot = char(rand(255))
176  
when not defined(ssl):
  # Stand-in so signatures mentioning SSLContext compile without -d:ssl.
  type SSLContext = ref object
var defaultSslContext {.threadvar.}: SSLContext

proc getDefaultSslContext(): SSLContext =
  ## Returns this thread's lazily created default client TLS context.
  ## Without -d:ssl there is no TLS support and the result is always nil.
  ##
  ## NOTE(review): the context is created with `verifyMode = CVerifyNone`,
  ## so server certificates are not verified. Callers needing verification
  ## should pass their own `sslContext` to `newWebSocket`.
  result = defaultSslContext
  when defined(ssl):
    if result.isNil:
      result = newContext(protVersion = protTLSv1, verifyMode = CVerifyNone)
      if result.isNil:
        raise newException(WebSocketError, "Unable to initialize SSL context.")
      defaultSslContext = result
188  
proc close*(ws: WebSocket) =
  ## Marks `ws` as `Closed` and closes its transport if it is still open.
  ## Calling it again is a no-op apart from re-setting `readyState`.
  ws.readyState = Closed
  when not newsUseChronos:
    # With asyncdispatch the transport is a nilable AsyncSocket. It is still
    # nil when a handshake fails early (bad URL scheme, missing request
    # headers), and `isClosed` would dereference it, raising a
    # NilAccessDefect that masks the original error.
    if ws.transp.isNil:
      return
  if not ws.transp.isClosed:
    ws.transp.close()
194  
when newsUseChronos:
  proc closeWait*(ws: WebSocket) {.async.} =
    ## Marks `ws` as `Closed` and waits until its transport has closed.
    ws.readyState = Closed
    if not ws.transp.isClosed:
      await ws.transp.closeWait()

else:
  proc newWebSocket*(req: Request): Future[WebSocket] {.async.} =
    ## Upgrades an asynchttpserver request to a server-side WebSocket by
    ## writing the handshake response to `req.client` (RFC 6455 §4.2.2).
    ##
    ## Raises `WebSocketError` when Sec-WebSocket-Version or
    ## Sec-WebSocket-Key is missing. Other failures (a non-numeric version,
    ## a failed send) propagate unchanged. If the handshake fails after the
    ## client socket was taken over, that socket is closed.
    var ws = WebSocket()

    try:
      for name in ["sec-webSocket-version", "sec-webSocket-key"]:
        if not req.headers.hasKey(name):
          raise newException(WebSocketError, "Missing " & name & " header")
      ws.version = parseInt(req.headers["sec-webSocket-version"])
      ws.key = req.headers["sec-webSocket-key"].strip()
      if req.headers.hasKey("sec-webSocket-protocol"):
        ws.protocol = req.headers["sec-webSocket-protocol"].strip()

      # Sec-WebSocket-Accept = base64(SHA-1(key & GUID))
      let sh = secureHash(ws.key & GUID)
      let acceptKey = base64.encode(decodeBase16($sh))

      var response = "HTTP/1.1 101 Web Socket Protocol Handshake" & CRLF
      response.add("Sec-WebSocket-Accept: " & acceptKey & CRLF)
      response.add("Connection: Upgrade" & CRLF)
      response.add("Upgrade: websocket" & CRLF)
      if ws.protocol.len > 0:
        response.add("Sec-WebSocket-Protocol: " & ws.protocol & CRLF)
      response.add CRLF

      ws.transp = req.client
      await ws.transp.send(response)
      ws.readyState = Open
    finally:
      # `ws.transp` is still nil when we failed before taking over the
      # client socket; closing it would raise NilAccessDefect and hide the
      # real error.
      if ws.readyState != Open and not ws.transp.isNil:
        close(ws)

    return ws
233  
proc validateServerResponse(resp, secKey: string): string =
  ## Validates the server's opening-handshake response (RFC 6455 §4.1).
  ##
  ## `resp` is the raw response head (status line and headers); `secKey` is
  ## the base64 Sec-WebSocket-Key that was sent. Returns "" when the
  ## response is acceptable. Otherwise it returns a human-readable reason;
  ## for a non-101 status, that is the rest of the status line.
  let respLines = resp.splitLines()
  block statusCode:
    const httpVersionStr = "HTTP/1.1 "
    let httpVersionPos = respLines[0].find(httpVersionStr)
    if httpVersionPos == -1:
      return "HTTP version not specified"
    let i = httpVersionPos + httpVersionStr.len
    if respLines[0].len <= i + 2:
      return "Request too short"
    let v = respLines[0][i .. i + 2]
    if v != "101":
      return respLines[0][i ..< respLines[0].len]

  proc hasToken(values: seq[string], token: string): bool =
    ## True if any comma-separated header value equals `token`, ignoring
    ## ASCII case and surrounding whitespace.
    for v in values:
      if cmpIgnoreCase(v.strip(), token) == 0:
        return true

  # Sec-WebSocket-Accept must equal base64(SHA-1(key & GUID)).
  let expectedAccept = base64.encode(decodeBase16($secureHash(secKey & GUID)))

  var validatedHeaders: array[3, bool]
  for i in 1 ..< respLines.len:
    let h = parseHeader(respLines[i])
    if cmpIgnoreCase(h.key, "Upgrade") == 0:
      # an empty value makes hasToken false instead of an IndexDefect
      if not hasToken(h.value, "websocket"):
        return "Upgrade header is invalid"
      validatedHeaders[0] = true

    elif cmpIgnoreCase(h.key, "Connection") == 0:
      # Connection is a token list: "keep-alive, Upgrade" is valid, so the
      # "upgrade" token need not come first.
      if not hasToken(h.value, "upgrade"):
        return "Connection header is invalid"
      validatedHeaders[1] = true

    elif cmpIgnoreCase(h.key, "Sec-WebSocket-Accept") == 0:
      # base64 is case-sensitive, so the comparison must be exact.
      if h.value.len != 1 or h.value[0].strip() != expectedAccept:
        return "Secret key invalid"
      validatedHeaders[2] = true

  if not validatedHeaders[0]: return "Missing Upgrade header"
  if not validatedHeaders[1]: return "Missing Connection header"
  if not validatedHeaders[2]: return "Missing Sec-WebSocket-Accept header"
270  
proc newWebSocket*(url: string, headers: StringTableRef = nil,
                   sslContext: SSLContext = getDefaultSslContext()): Future[WebSocket] {.async.} =
  ## Connects to `url` (scheme `ws` or `wss`) and performs the client
  ## opening handshake (RFC 6455 §4.1).
  ##
  ## `headers` are sent in addition to the handshake headers. With the
  ## asyncdispatch backend, `sslContext` is used for `wss` (needs -d:ssl).
  ## Raises `WebSocketError` in these cases:
  ## * the scheme is unsupported;
  ## * SSL support is missing;
  ## * the connection closes mid-handshake;
  ## * the server response is invalid.
  ##
  ## On any failure, a transport that was already opened is closed.
  var ws = WebSocket()

  try:
    let uri = parseUri(url)
    var port = Port(80)
    case uri.scheme
      of "wss":
        port = Port(443)
      of "ws":
        discard
      else:
        raise newException(WebSocketError, &"Scheme {uri.scheme} not supported yet.")
    if uri.port.len > 0:
      port = Port(parseInt(uri.port))

    when newsUseChronos:
      let tr = await connect(resolveTAddress(uri.hostname, port)[0])
      ws.transp.transp = tr
      ws.transp.reader = newAsyncStreamReader(tr)
      ws.transp.writer = newAsyncStreamWriter(tr)

      if uri.scheme == "wss":
        let s = newTLSClientAsyncStream(ws.transp.reader, ws.transp.writer, serverName = uri.hostname)
        ws.transp.reader = s.reader
        ws.transp.writer = s.writer

    else:
      ws.transp = newAsyncSocket()
      if uri.scheme == "wss":
        when defined(ssl):
          sslContext.wrapSocket(ws.transp)
        else:
          raise newException(WebSocketError, "SSL support is not available. Compile with -d:ssl to enable.")
      await ws.transp.connect(uri.hostname, port)

    # Request target = path (default "/") plus the query. The default must
    # be applied before the query is appended, otherwise "ws://host?x=1"
    # would produce the invalid target "?x=1".
    var urlPath = if uri.path.len > 0: uri.path else: "/"
    if uri.query.len > 0:
      urlPath.add("?" & uri.query)
    let
      secKey = ($genOid())[^16..^1]
      secKeyEncoded = encode(secKey)
    let requestLine = &"GET {urlPath} HTTP/1.1"
    let predefinedHeaders = [
      &"Host: {uri.hostname}:{$port}",
      "Connection: Upgrade",
      "Upgrade: websocket",
      "Sec-WebSocket-Version: 13",
      &"Sec-WebSocket-Key: {secKeyEncoded}"
    ]

    var customHeaders = ""
    if not headers.isNil:
      for k, v in headers:
        customHeaders &= &"{k}: {v}{CRLF}"
    let hello = requestLine & CRLF &
                customHeaders &
                predefinedHeaders.join(CRLF) &
                static(CRLF & CRLF)

    await ws.transp.send(hello)

    # Read the response head byte by byte up to the blank line.
    var output = ""
    while not output.endsWith(static(CRLF & CRLF)):
      let chunk = await ws.transp.recv(1)
      if chunk.len == 0:
        # asyncnet's recv returns "" once the peer has closed; without this
        # check the loop would spin forever.
        raise newException(WebSocketError,
          "WebSocket connection error: connection closed during handshake")
      output.add chunk

    let error = validateServerResponse(output, secKeyEncoded)
    if error.len > 0:
      raise newException(WebSocketError, "WebSocket connection error: " & error)

    ws.readyState = Open
    ws.maskFrames = true
  finally:
    if ws.readyState != Open:
      when newsUseChronos:
        close(ws)
      else:
        # The socket is still nil if we failed before creating it (e.g.
        # unsupported scheme); closing it would mask the real error.
        if not ws.transp.isNil:
          close(ws)

  return ws
352  
proc encodeFrame*(f: Frame): string =
  ## Encodes a frame into its wire format.
  ## See https://tools.ietf.org/html/rfc6455#section-5.2
  ##
  ## FIN and RSV1-3 are taken from `f`; RSV bits were previously dropped.
  ## The payload length uses the shortest of the 7-bit, 7+16-bit and
  ## 7+64-bit forms, big-endian. When `f.mask` is set, a fresh key from
  ## `genMaskKey` follows the length and the payload is XOR-masked with it.
  let dataLen = f.data.len
  # header is at most 2 + 8 bytes, plus 4 for the masking key
  result = newStringOfCap(dataLen + 14)

  # byte 0: FIN, RSV1-3, opcode
  var b0 = f.opcode.uint8 and 0x0f
  if f.fin: b0 = b0 or 0x80
  if f.rsv1: b0 = b0 or 0x40
  if f.rsv2: b0 = b0 or 0x20
  if f.rsv3: b0 = b0 or 0x10
  result.add char(b0)

  # byte 1: MASK bit and 7-bit length (or the 126/127 extension markers)
  let maskBit = if f.mask: 0x80'u8 else: 0'u8
  if dataLen <= 125:
    result.add char(maskBit or dataLen.uint8)
  elif dataLen <= 0xffff:
    result.add char(maskBit or 126'u8)
    result.add char((dataLen shr 8) and 0xff)
    result.add char(dataLen and 0xff)
  else:
    result.add char(maskBit or 127'u8)
    for shift in countdown(56, 0, 8):
      result.add char((dataLen shr shift) and 0xff)

  if f.mask:
    let maskKey = genMaskKey()
    for c in maskKey:
      result.add c
    for i, c in f.data:
      result.add char(c.uint8 xor maskKey[i mod 4].uint8)
  else:
    result.add f.data
411  
proc doSend(ws: WebSocket, text: string, opcode: Opcode): Future[void] {.async.} =
  ## Encodes `text` as one FIN frame of kind `opcode` (masked when
  ## `ws.maskFrames` is set) and writes it to the transport.
  ##
  ## Raises `WebSocketClosedError` if the transport is or becomes closed;
  ## any other failure is wrapped in `WebSocketError`.
  try:
    let frame = encodeFrame((
      fin: true,
      rsv1: false,
      rsv2: false,
      rsv3: false,
      opcode: opcode,
      mask: ws.maskFrames,
      data: text
    ))
    # send stuff in 1 megabyte chunks to prevent IOErrors
    # with really large packets
    const maxSize = 1024*1024
    var i = 0
    while i < frame.len:
      let data = frame[i ..< min(frame.len, i + maxSize)]
      if ws.transp.isClosed:
        raise newException(WebSocketClosedError, "Socket closed")
      await ws.transp.send(data)
      i += maxSize
      if i < frame.len:
        # Yield to the event loop only between chunks. Sleeping after the
        # last chunk delayed every message by >= 1 ms; since sends are
        # serialized, that capped throughput.
        await sleepAsync(1)
  except CatchableError as e:
    if ws.transp.isClosed:
      ws.readyState = Closed
      raise newException(WebSocketClosedError, "Socket closed")
    else:
      raise newException(WebSocketError,
                         &"Could not send packet because of [{e.name}]: {e.msg}")
442  
proc continueSending(ws: WebSocket) =
  ## Starts the oldest queued send, if any. When it finishes, the queued
  ## caller's future is settled with the outcome, the send slot is released
  ## and the next queued send is started. The queue is drained one message
  ## at a time.
  if ws.sendQueue.len == 0:
    return

  let (text, opcode, pending) = ws.sendQueue.popFirst()
  ws.sendFut = ws.doSend(text, opcode)

  proc settle() =
    # forward the outcome of the real write to the queued caller's future
    let sent = ws.sendFut
    if sent.failed:
      pending.fail(sent.error)
    else:
      pending.complete()
    ws.sendFut = nil
    ws.continueSending()

  when newsUseChronos:
    proc handleSent(future: pointer) =
      settle()
  else:
    proc handleSent() =
      settle()

  ws.sendFut.addCallback(handleSent)
469  
proc send*(ws: WebSocket, text: string, opcode = Opcode.Text): Future[void] =
  ## Sends `text` as a single frame of kind `opcode`.
  ##
  ## Sends are serialized. If another send is in flight, this one is queued;
  ## the returned future settles when `continueSending` eventually writes it.
  ## Otherwise the write starts immediately.
  if not ws.sendFut.isNil:
    result = newFuture[void]("send")
    ws.sendQueue.addLast (text: text, opcode: opcode, fut: result)
  else:
    ws.sendFut = ws.doSend(text, opcode)

    # once this write settles, free the slot and start the next queued send
    proc releaseSlot() =
      ws.sendFut = nil
      ws.continueSending()

    when newsUseChronos:
      proc handleSent(future: pointer) =
        releaseSlot()
    else:
      proc handleSent() =
        releaseSlot()

    ws.sendFut.addCallback(handleSent)
    result = ws.sendFut
491  
proc send*(ws: WebSocket, packet: Packet): Future[void] =
  ## Sends `packet`: its payload for Text/Binary packets, an empty payload
  ## for every other (control) kind.
  let payload =
    case packet.kind
    of Text, Binary: packet.data
    else: ""
  ws.send(payload, packet.kind)
497  
proc recvExactly(ws: WebSocket, size: int,
                 closedMsg = "Socket closed"): Future[string] {.async.} =
  ## Reads exactly `size` bytes from `ws.transp`. If the transport raises,
  ## `ws` is closed and the error re-raised. On a short read, `ws` is closed
  ## and `WebSocketClosedError(closedMsg)` is raised.
  let data = try:
    await ws.transp.recv(size)
  except CatchableError as err:
    close ws
    raise err
  if data.len != size:
    close ws
    raise newException(WebSocketClosedError, closedMsg)
  return data

proc recvFrame(ws: WebSocket): Future[Frame] {.async.} =
  ## Gets a frame from the WebSocket
  ## See https://tools.ietf.org/html/rfc6455#section-5.2
  ##
  ## If the transport is already closed, sets `readyState` to `Closed` and
  ## returns a default (empty, non-FIN `Cont`) frame. Otherwise it closes
  ## `ws` and raises in these cases:
  ## * `WebSocketError` for a reserved opcode, set RSV bits, or a 64-bit
  ##   length whose value does not fit in `int` (includes the MSB-set case);
  ## * `WebSocketClosedError` when the connection ends mid-frame.
  if ws.transp.isClosed:
    ws.readyState = Closed
    return result

  let header = await ws.recvExactly(2, "socket closed")
  let b0 = header[0].uint8
  let b1 = header[1].uint8

  # read the flags and fin from the header
  result.fin  = b0[0]
  result.rsv1 = b0[1]
  result.rsv2 = b0[2]
  result.rsv3 = b0[3]

  # 0x3-0x7 and 0xB-0xF are reserved. Opcode has holes, so comparing with
  # high(Opcode) alone let 3..7 through as invalid enum values.
  let opcodeVal = b0 and 0x0f
  if opcodeVal notin {0x0'u8, 0x1, 0x2, 0x8, 0x9, 0xa}:
    close ws
    raise newException(WebSocketError, "Server did not respond with a valid WebSocket frame")
  result.opcode = Opcode(opcodeVal)

  # if any of the rsv are set close the socket
  if result.rsv1 or result.rsv2 or result.rsv3:
    close ws
    raise newException(WebSocketError, "WebSocket Potocol missmatch")

  # Payload length can be 7 bits, 7+16 bits, or 7+64 bits (big-endian)
  var finalLen = int(b1 and 0x7f)
  if finalLen == 0x7e:
    let ext = await ws.recvExactly(2)
    finalLen = (ext[0].int shl 8) or ext[1].int
  elif finalLen == 0x7f:
    # Decode all 8 bytes; previously only the low 4 were read, which
    # silently truncated lengths >= 4 GiB.
    let ext = await ws.recvExactly(8)
    var len64 = 0'u64
    for c in ext:
      len64 = (len64 shl 8) or c.uint64
    # RFC 6455 §5.2: the most significant bit MUST be 0. This also rejects
    # lengths this platform's `int` cannot hold.
    if len64 > high(int).uint64:
      close ws
      raise newException(WebSocketError, "WebSocket frame length is invalid")
    finalLen = int(len64)

  # do we need to apply mask?
  result.mask = (b1 and 0x80) == 0x80
  var maskKey = ""
  if result.mask:
    maskKey = await ws.recvExactly(4)

  # read the data
  result.data = await ws.recvExactly(finalLen)

  if result.mask:
    # apply mask if we need too
    for i in 0 ..< result.data.len:
      result.data[i] = (result.data[i].uint8 xor maskKey[i mod 4].uint8).char
603  
template sendControl(ws: WebSocket, opcode: Opcode): Future[void] =
  ## Sends a control frame of kind `opcode` with an empty payload.
  ws.send("", opcode)

proc sendPing*(ws: WebSocket): Future[void] {.async.} =
  ## Sends a Ping frame with an empty payload.
  await ws.sendControl(Opcode.Ping)

proc sendPong(ws: WebSocket): Future[void] {.async.} =
  ## Sends a Pong frame with an empty payload.
  await ws.sendControl(Opcode.Pong)

proc sendClose(ws: WebSocket): Future[void] {.async.} =
  ## Sends a Close frame with an empty payload (no status code).
  await ws.sendControl(Opcode.Close)

proc shutdown*(ws: WebSocket): Future[void] {.async.} =
  ## Starts the closing handshake: marks `ws` as `Closing`, then sends a
  ## Close frame. The transport stays open here; `receivePacket` closes it
  ## when the peer's Close frame arrives.
  ws.readyState = Closing
  await ws.sendClose()
617  
proc receivePacket*(ws: WebSocket): Future[Packet] {.async.} =
  ## Waits for the next packet and handles protocol control frames.
  ##
  ## * Text/Binary: continuation frames are joined into one packet. Ping and
  ##   Pong frames arriving between fragments (allowed by RFC 6455 §5.4)
  ##   are handled in place instead of aborting the message.
  ## * Ping: answered with a Pong carrying the same payload
  ##   (RFC 6455 §5.5.3); a `Ping` packet is returned.
  ## * Pong: returned as-is.
  ## * Close: answered with a Close frame unless we started the close
  ##   (`readyState == Closing`); `ws` is then marked `Closed` and its
  ##   transport closed.
  ##
  ## `WebSocketError`s propagate. Any other failure yields a `Close` packet
  ## if the transport is closed, and otherwise a `WebSocketError`.
  try:
    var frame = await ws.recvFrame()
    result = Packet(kind: frame.opcode)
    if frame.opcode == Text or frame.opcode == Binary:
      result.data = frame.data
      var done = frame.fin
      while not done:
        if ws.readyState == Closed:
          # recvFrame returns an empty non-FIN frame once the transport is
          # closed; without this check the loop would never terminate.
          raise newException(WebSocketClosedError, "Socket closed")
        frame = await ws.recvFrame()
        if frame.opcode == Cont:
          result.data.add frame.data
          done = frame.fin
        elif frame.opcode == Ping:
          await ws.send(frame.data, Opcode.Pong)
        elif frame.opcode == Pong:
          discard
        else:
          close ws
          raise newException(WebSocketError, "Socket did not get continue frame")
      return

    if frame.opcode == Ping:
      # the Pong must echo the Ping's application data
      await ws.send(frame.data, Opcode.Pong)

    elif frame.opcode == Pong:
      return

    elif frame.opcode == Close:
      if ws.readyState != Closing:
        await ws.sendClose()
      ws.readyState = Closed
      if not ws.transp.isClosed:
        ws.transp.close()

  except WebSocketError as e:
    raise e
  except CatchableError as e:
    if ws.transp.isClosed:
      ws.readyState = Closed
      result = Packet(kind: Close)
    else:
      raise newException(WebSocketError,
                         &"Could not receive packet because of [{e.name}]: {e.msg}")
656  
657  proc receiveString*(ws: WebSocket): Future[string] {.async.} =
658    var receivedString = false
659    while not (receivedString or ws.readyState == Closed):
660      let packet = await ws.receivePacket()
661      case packet.kind
662      of Text, Binary:
663        receivedString = true
664        result = packet.data
665      of Close:
666        result = ""
667      else:
668        discard