/ asyncio / stream.py
stream.py
  1  # SPDX-FileCopyrightText: 2019-2020 Damien P. George
  2  #
  3  # SPDX-License-Identifier: MIT
  4  #
  5  # MicroPython uasyncio module
  6  # MIT license; Copyright (c) 2019-2020 Damien P. George
  7  #
  8  # This code comes from MicroPython, and has not been run through black or pylint there.
  9  # Altering these files significantly would make merging difficult, so we will not use
 10  # pylint or black.
 11  # pylint: skip-file
 12  # fmt: off
 13  """
 14  Streams
 15  =======
 16  """
 17  
 18  from . import core
 19  
 20  
 21  class Stream:
 22      """This represents a TCP stream connection. To minimise code this class
 23      implements both a reader and a writer, and both ``StreamReader`` and
 24      ``StreamWriter`` alias to this class.
 25      """
 26  
 27      def __init__(self, s, e={}):
 28          self.s = s
 29          self.e = e
 30          self.out_buf = b""
 31  
 32      def get_extra_info(self, v):
 33          """Get extra information about the stream, given by *v*. The valid
 34          values for *v* are: ``peername``.
 35          """
 36  
 37          return self.e[v]
 38  
 39      async def __aenter__(self):
 40          return self
 41  
 42      async def __aexit__(self, exc_type, exc, tb):
 43          await self.close()
 44  
 45      def close(self):
 46          pass
 47  
 48      async def wait_closed(self):
 49          """Wait for the stream to close.
 50  
 51          This is a coroutine.
 52          """
 53  
 54          # TODO yield?
 55          self.s.close()
 56  
 57      async def read(self, n):
 58          """Read up to *n* bytes and return them.
 59  
 60          This is a coroutine.
 61          """
 62  
 63          core._io_queue.queue_read(self.s)
 64          await core.sleep(0)
 65          return self.s.read(n)
 66  
 67      async def readinto(self, buf):
 68          """Read up to n bytes into *buf* with n being equal to the length of *buf*
 69  
 70          Return the number of bytes read into *buf*
 71  
 72          This is a coroutine, and a MicroPython extension.
 73          """
 74  
 75          core._io_queue.queue_read(self.s)
 76          await core.sleep(0)
 77          return self.s.readinto(buf)
 78  
 79      async def readexactly(self, n):
 80          """Read exactly *n* bytes and return them as a bytes object.
 81  
 82          Raises an ``EOFError`` exception if the stream ends before reading
 83          *n* bytes.
 84  
 85          This is a coroutine.
 86          """
 87  
 88          r = b""
 89          while n:
 90              core._io_queue.queue_read(self.s)
 91              await core.sleep(0)
 92              r2 = self.s.read(n)
 93              if r2 is not None:
 94                  if not len(r2):
 95                      raise EOFError
 96                  r += r2
 97                  n -= len(r2)
 98          return r
 99  
100      async def readline(self):
101          """Read a line and return it.
102  
103          This is a coroutine.
104          """
105  
106          l = b""
107          while True:
108              core._io_queue.queue_read(self.s)
109              await core.sleep(0)
110              l2 = self.s.readline()  # may do multiple reads but won't block
111              l += l2
112              if not l2 or l[-1] == 10:  # \n (check l in case l2 is str)
113                  return l
114  
115      def write(self, buf):
116          """Accumulated *buf* to the output buffer. The data is only flushed when
117          `Stream.drain` is called. It is recommended to call `Stream.drain`
118          immediately after calling this function.
119          """
120  
121          self.out_buf += buf
122  
123      async def drain(self):
124          """Drain (write) all buffered output data out to the stream.
125  
126          This is a coroutine.
127          """
128  
129          mv = memoryview(self.out_buf)
130          off = 0
131          while off < len(mv):
132              yield core._io_queue.queue_write(self.s)
133              ret = self.s.write(mv[off:])
134              if ret is not None:
135                  off += ret
136          self.out_buf = b""
137  
138  
# To save code size a single class serves as both the reader and the writer,
# so both public names are bound to Stream.
StreamReader = StreamWriter = Stream
142  
143  
# Create a TCP stream connection to a remote host
async def open_connection(host, port):
    """Open a TCP connection to the given *host* and *port*. The *host* address will
    be resolved using `socket.getaddrinfo`, which is currently a blocking call.

    Returns a pair of streams: a reader and a writer stream. Will raise a socket-specific
    ``OSError`` if the host could not be resolved or if the connection could not be made.

    This is a coroutine.
    """

    from uerrno import EINPROGRESS
    import usocket as socket

    # Only the first resolved address is used.
    ai = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)[
        0
    ]  # TODO this is blocking!
    s = socket.socket(ai[0], ai[1], ai[2])
    s.setblocking(False)
    ss = Stream(s)
    try:
        s.connect(ai[-1])
    except OSError as er:
        # EINPROGRESS is tolerated: the socket is non-blocking, so the
        # connect may still be under way.  Any other error is a real failure.
        if er.errno != EINPROGRESS:
            # Release the socket before propagating so it is not leaked.
            s.close()
            raise
    # Queue the socket for writing and yield to the scheduler before handing
    # the stream back.
    core._io_queue.queue_write(s)
    await core.sleep(0)
    # The same object acts as both the reader and the writer.
    return ss, ss
172  
173  
# Class representing a TCP stream server, can be closed and used in "async with"
class Server:
    """This represents the server class returned from `start_server`.  It can be used in
    an ``async with`` statement to close the server upon exit.

    The accept loop runs in the task stored in the ``task`` attribute, which
    is assigned by `start_server` after construction.
    """

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        self.close()
        await self.wait_closed()

    def close(self):
        """Close the server.

        This cancels the accept-loop task; the listening socket is closed by
        that task when it handles the cancellation.
        """

        self.task.cancel()

    async def wait_closed(self):
        """Wait for the server to close.

        This is a coroutine.
        """

        await self.task

    async def _serve(self, s, cb):
        """Accept connections on listening socket *s* until cancelled.

        For every accepted connection a new task is created running
        ``cb(reader, writer)``, where both arguments are the same `Stream`.
        """

        # Accept incoming connections
        while True:
            try:
                # Same queue-then-sleep pattern as the Stream methods.  A
                # bare ``yield`` here would turn this coroutine into an
                # async generator, which cannot be run as a task.
                core._io_queue.queue_read(s)
                await core.sleep(0)
            except core.CancelledError:
                # Shutdown server
                s.close()
                return
            try:
                s2, addr = s.accept()
            except Exception:
                # Ignore a failed accept.  Only ordinary errors are
                # swallowed, so KeyboardInterrupt/SystemExit still propagate.
                continue
            s2.setblocking(False)
            s2s = Stream(s2, {"peername": addr})
            core.create_task(cb(s2s, s2s))
216              core.create_task(cb(s2s, s2s))
217  
218  
# Helper function to start a TCP stream server, running as a new task
# TODO could use an accept-callback on socket read activity instead of creating a task
async def start_server(cb, host, port, backlog=5):
    """Start a TCP server on the given *host* and *port*. The *cb* callback will be
    called with incoming, accepted connections, and be passed 2 arguments: reader
    writer streams for the connection.

    *backlog* is passed to the listening socket's ``listen`` call.

    Returns a `Server` object. Errors raised while setting up the listening
    socket propagate to the caller after the socket has been closed.

    This is a coroutine.
    """

    import usocket as socket

    # Create and bind server socket.
    ai = socket.getaddrinfo(host, port)[0]  # TODO this is blocking!
    # Use the resolved address family so the socket matches the bind address.
    s = socket.socket(ai[0])
    try:
        s.setblocking(False)
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(ai[-1])
        s.listen(backlog)
    except OSError:
        # Do not leak the socket when configuring, binding or listening fails.
        s.close()
        raise

    # Create and return server object and task.
    srv = Server()
    srv.task = core.create_task(srv._serve(s, cb))
    return srv
245  
246  
247  ################################################################################
248  # Legacy uasyncio compatibility
249  
250  
async def stream_awrite(self, buf, off=0, sz=-1):
    """Write *buf*, or a slice of it, to the stream and drain it.

    With the default *off* and *sz* the buffer is passed on untouched.
    Otherwise *sz* bytes starting at *off* are written, where an *sz* of -1
    stands for the full length of *buf*.

    This is a coroutine.
    """

    wants_slice = not (off == 0 and sz == -1)
    if wants_slice:
        # Slice through a memoryview so the selected range is not copied.
        view = memoryview(buf)
        count = len(view) if sz == -1 else sz
        buf = view[off : off + count]
    self.write(buf)
    await self.drain()
259  
260  
# Legacy method names kept for older uasyncio code.  Both write aliases share
# one implementation.
# TODO explicitly convert to bytes? (for awritestr)
Stream.awrite = Stream.awritestr = stream_awrite
Stream.aclose = Stream.wait_closed