stream.py
# SPDX-FileCopyrightText: 2019-2020 Damien P. George
#
# SPDX-License-Identifier: MIT
#
# MicroPython uasyncio module
# MIT license; Copyright (c) 2019-2020 Damien P. George
#
# This code comes from MicroPython, and has not been run through black or pylint there.
# Altering these files significantly would make merging difficult, so we will not use
# pylint or black.
# pylint: skip-file
# fmt: off
"""
Streams
=======
"""

from . import core


class Stream:
    """This represents a TCP stream connection. To minimise code this class
    implements both a reader and a writer, and both ``StreamReader`` and
    ``StreamWriter`` alias to this class.
    """

    def __init__(self, s, e=None):
        # *s* is the underlying socket-like object (the callers in this module
        # put it in non-blocking mode before wrapping it). *e* maps extra-info
        # keys to values for get_extra_info(); a fresh dict is created per
        # stream so instances never share one mutable default.
        self.s = s
        self.e = {} if e is None else e
        self.out_buf = b""

    def get_extra_info(self, v):
        """Get extra information about the stream, given by *v*. The valid
        values for *v* are: ``peername``.

        Raises ``KeyError`` if *v* was not supplied when the stream was created.
        """

        return self.e[v]

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # close() is a plain function returning None, so it must not be
        # awaited; the underlying socket is actually closed by wait_closed().
        self.close()
        await self.wait_closed()

    def close(self):
        """Close the stream.

        This is currently a no-op; the underlying socket is closed by
        `Stream.wait_closed`.
        """
        pass

    async def wait_closed(self):
        """Wait for the stream to close. This closes the underlying socket.

        This is a coroutine.
        """

        # TODO yield?
        self.s.close()

    async def read(self, n):
        """Read up to *n* bytes and return them.

        This is a coroutine.
        """

        # Register the socket with the I/O queue for reading, then yield to
        # the scheduler before performing the read.
        core._io_queue.queue_read(self.s)
        await core.sleep(0)
        return self.s.read(n)

    async def readinto(self, buf):
        """Read up to n bytes into *buf* with n being equal to the length of *buf*.

        Return the number of bytes read into *buf*.

        This is a coroutine, and a MicroPython extension.
        """

        core._io_queue.queue_read(self.s)
        await core.sleep(0)
        return self.s.readinto(buf)

    async def readexactly(self, n):
        """Read exactly *n* bytes and return them as a bytes object.

        Raises an ``EOFError`` exception if the stream ends before reading
        *n* bytes.

        This is a coroutine.
        """

        r = b""
        while n:
            core._io_queue.queue_read(self.s)
            await core.sleep(0)
            r2 = self.s.read(n)
            # None means no data was available yet; try again.
            if r2 is not None:
                # An empty read means the peer closed the connection.
                if not len(r2):
                    raise EOFError
                r += r2
                n -= len(r2)
        return r

    async def readline(self):
        """Read a line and return it.

        This is a coroutine.
        """

        l = b""
        while True:
            core._io_queue.queue_read(self.s)
            await core.sleep(0)
            l2 = self.s.readline()  # may do multiple reads but won't block
            l += l2
            if not l2 or l[-1] == 10:  # \n (check l in case l2 is str)
                return l

    def write(self, buf):
        """Append *buf* to the output buffer. The data is only flushed when
        `Stream.drain` is called. It is recommended to call `Stream.drain`
        immediately after calling this function.
        """

        self.out_buf += buf

    async def drain(self):
        """Drain (write) all buffered output data out to the stream.

        This is a coroutine.
        """

        # The memoryview pins the buffer as it is now; write() rebinds
        # self.out_buf to a new object rather than mutating this one.
        mv = memoryview(self.out_buf)
        off = 0
        while off < len(mv):
            # Register for writability and suspend (uasyncio-style yield).
            yield core._io_queue.queue_write(self.s)
            ret = self.s.write(mv[off:])
            # None means nothing could be written this time round.
            if ret is not None:
                off += ret
        # Another task may have called write() while this one was suspended;
        # that data was appended after the drained prefix, so keep it rather
        # than discarding the whole buffer.
        self.out_buf = self.out_buf[off:]


# Stream can be used for both reading and writing to save code size
StreamReader = Stream
StreamWriter = Stream


# Create a TCP stream connection to a remote host
async def open_connection(host, port):
    """Open a TCP connection to the given *host* and *port*. The *host* address will
    be resolved using `socket.getaddrinfo`, which is currently a blocking call.

    Returns a pair of streams: a reader and a writer stream. Will raise a socket-specific
    ``OSError`` if the host could not be resolved or if the connection could not be made.

    This is a coroutine.
    """

    from uerrno import EINPROGRESS
    import usocket as socket

    ai = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)[0]  # TODO this is blocking!
    s = socket.socket(ai[0], ai[1], ai[2])
    s.setblocking(False)
    ss = Stream(s)
    try:
        s.connect(ai[-1])
    except OSError as er:
        # A non-blocking connect normally reports EINPROGRESS; anything else
        # is a real failure.
        if er.errno != EINPROGRESS:
            # Don't leak the socket when the connection fails outright.
            s.close()
            raise er
    # Wait for the socket to become writable, i.e. for the connect to finish.
    core._io_queue.queue_write(s)
    await core.sleep(0)
    return ss, ss


# Class representing a TCP stream server, can be closed and used in "async with"
class Server:
    """This represents the server class returned from `start_server`. It can be used in
    an ``async with`` statement to close the server upon exit.
    """

    # self.task: the accept-loop task, assigned by start_server().

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        self.close()
        await self.wait_closed()

    def close(self):
        """Close the server."""

        self.task.cancel()

    async def wait_closed(self):
        """Wait for the server to close.

        This is a coroutine.
        """

        await self.task

    async def _serve(self, s, cb):
        # Accept incoming connections
        while True:
            try:
                yield core._io_queue.queue_read(s)
            except core.CancelledError:
                # Shutdown server
                s.close()
                return
            try:
                s2, addr = s.accept()
            except Exception:
                # Ignore a failed accept (without swallowing KeyboardInterrupt
                # or SystemExit).
                continue
            s2.setblocking(False)
            s2s = Stream(s2, {"peername": addr})
            core.create_task(cb(s2s, s2s))


# Helper function to start a TCP stream server, running as a new task
# TODO could use an accept-callback on socket read activity instead of creating a task
async def start_server(cb, host, port, backlog=5):
    """Start a TCP server on the given *host* and *port*. The *cb* callback will be
    called with incoming, accepted connections, and be passed 2 arguments: reader
    and writer streams for the connection.

    Returns a `Server` object.

    This is a coroutine.
    """

    import usocket as socket

    # Create and bind server socket.
    host = socket.getaddrinfo(host, port)[0]  # TODO this is blocking!
    s = socket.socket()
    try:
        s.setblocking(False)
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(host[-1])
        s.listen(backlog)
    except OSError:
        # Release the socket if it could not be set up as a listener.
        s.close()
        raise

    # Create and return server object and task.
    srv = Server()
    srv.task = core.create_task(srv._serve(s, cb))
    return srv


################################################################################
# Legacy uasyncio compatibility


async def stream_awrite(self, buf, off=0, sz=-1):
    # Write buf[off:off + sz] (sz == -1 means "to the end") and drain it.
    if off != 0 or sz != -1:
        buf = memoryview(buf)
        if sz == -1:
            sz = len(buf)
        buf = buf[off : off + sz]
    self.write(buf)
    await self.drain()


Stream.aclose = Stream.wait_closed
Stream.awrite = stream_awrite
Stream.awritestr = stream_awrite  # TODO explicitly convert to bytes?