/ adafruit_requests.py
adafruit_requests.py
  1  # The MIT License (MIT)
  2  #
  3  # Copyright (c) 2019 ladyada for Adafruit Industries
  4  # Copyright (c) 2020 Scott Shawcroft for Adafruit Industries
  5  #
  6  # Permission is hereby granted, free of charge, to any person obtaining a copy
  7  # of this software and associated documentation files (the "Software"), to deal
  8  # in the Software without restriction, including without limitation the rights
  9  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 10  # copies of the Software, and to permit persons to whom the Software is
 11  # furnished to do so, subject to the following conditions:
 12  #
 13  # The above copyright notice and this permission notice shall be included in
 14  # all copies or substantial portions of the Software.
 15  #
 16  # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 17  # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 18  # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 19  # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 20  # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 21  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 22  # THE SOFTWARE.
 23  """
 24  `adafruit_requests`
 25  ================================================================================
 26  
 27  A requests-like library for web interfacing
 28  
 29  
 30  * Author(s): ladyada, Paul Sokolovsky, Scott Shawcroft
 31  
 32  Implementation Notes
 33  --------------------
 34  
 35  Adapted from https://github.com/micropython/micropython-lib/tree/master/urequests
 36  
 37  micropython-lib consists of multiple modules from different sources and
 38  authors. Each module comes under its own licensing terms. Short name of
 39  a license can be found in a file within a module directory (usually
 40  metadata.txt or setup.py). Complete text of each license used is provided
 41  at https://github.com/micropython/micropython-lib/blob/master/LICENSE
 42  
 43  author='Paul Sokolovsky'
 44  license='MIT'
 45  
 46  **Software and Dependencies:**
 47  
 48  * Adafruit CircuitPython firmware for the supported boards:
 49    https://github.com/adafruit/circuitpython/releases
 50  
 51  """
 52  
# Module version and source repository metadata. The "0.0.0-auto.0" value looks
# like a placeholder that release tooling substitutes -- TODO confirm.
__version__ = "0.0.0-auto.0"
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_Requests.git"
 55  
 56  import errno
 57  
 58  # CircuitPython 6.0 does not have the bytearray.split method.
 59  # This function emulates buf.split(needle)[0], which is the functionality
 60  # required.
 61  def _buffer_split0(buf, needle):
 62      index = buf.find(needle)
 63      if index == -1:
 64          return buf
 65      return buf[:index]
 66  
 67  
 68  class _RawResponse:
 69      def __init__(self, response):
 70          self._response = response
 71  
 72      def read(self, size=-1):
 73          """Read as much as available or up to size and return it in a byte string.
 74  
 75          Do NOT use this unless you really need to. Reusing memory with `readinto` is much better.
 76          """
 77          if size == -1:
 78              return self._response.content
 79          return self._response.socket.recv(size)
 80  
 81      def readinto(self, buf):
 82          """Read as much as available into buf or until it is full. Returns the number of bytes read
 83          into buf."""
 84          return self._response._readinto(buf)  # pylint: disable=protected-access
 85  
 86  
 87  class _SendFailed(Exception):
 88      """Custom exception to abort sending a request."""
 89  
 90  
 91  class Response:
 92      """The response from a request, contains all the headers/content"""
 93  
 94      # pylint: disable=too-many-instance-attributes
 95  
 96      encoding = None
 97  
 98      def __init__(self, sock, session=None):
 99          self.socket = sock
100          self.encoding = "utf-8"
101          self._cached = None
102          self._headers = {}
103  
104          # _start_index and _receive_buffer are used when parsing headers.
105          # _receive_buffer will grow by 32 bytes everytime it is too small.
106          self._received_length = 0
107          self._receive_buffer = bytearray(32)
108          self._remaining = None
109          self._chunked = False
110  
111          self._backwards_compatible = not hasattr(sock, "recv_into")
112  
113          http = self._readto(b" ")
114          if not http:
115              if session:
116                  session._close_socket(self.socket)
117              else:
118                  self.socket.close()
119              raise RuntimeError("Unable to read HTTP response.")
120          self.status_code = int(bytes(self._readto(b" ")))
121          self.reason = self._readto(b"\r\n")
122          self._parse_headers()
123          self._raw = None
124          self._session = session
125  
126      def __enter__(self):
127          return self
128  
129      def __exit__(self, exc_type, exc_value, traceback):
130          self.close()
131  
132      def _recv_into(self, buf, size=0):
133          if self._backwards_compatible:
134              size = len(buf) if size == 0 else size
135              b = self.socket.recv(size)
136              read_size = len(b)
137              buf[:read_size] = b
138              return read_size
139          return self.socket.recv_into(buf, size)
140  
141      @staticmethod
142      def _find(buf, needle, start, end):
143          if hasattr(buf, "find"):
144              return buf.find(needle, start, end)
145          result = -1
146          i = start
147          while i < end:
148              j = 0
149              while j < len(needle) and i + j < end and buf[i + j] == needle[j]:
150                  j += 1
151              if j == len(needle):
152                  result = i
153                  break
154              i += 1
155  
156          return result
157  
158      def _readto(self, first, second=b""):
159          buf = self._receive_buffer
160          end = self._received_length
161          while True:
162              firsti = self._find(buf, first, 0, end)
163              secondi = -1
164              if second:
165                  secondi = self._find(buf, second, 0, end)
166  
167              i = -1
168              needle_len = 0
169              if firsti >= 0:
170                  i = firsti
171                  needle_len = len(first)
172              if secondi >= 0 and (firsti < 0 or secondi < firsti):
173                  i = secondi
174                  needle_len = len(second)
175              if i >= 0:
176                  result = buf[:i]
177                  new_start = i + needle_len
178  
179                  if i + needle_len <= end:
180                      new_end = end - new_start
181                      buf[:new_end] = buf[new_start:end]
182                      self._received_length = new_end
183                  return result
184  
185              # Not found so load more.
186  
187              # If our buffer is full, then make it bigger to load more.
188              if end == len(buf):
189                  new_size = len(buf) + 32
190                  new_buf = bytearray(new_size)
191                  new_buf[: len(buf)] = buf
192                  buf = new_buf
193                  self._receive_buffer = buf
194  
195              read = self._recv_into(memoryview(buf)[end:])
196              if read == 0:
197                  self._received_length = 0
198                  return buf[:end]
199              end += read
200  
201          return b""
202  
203      def _read_from_buffer(self, buf=None, nbytes=None):
204          if self._received_length == 0:
205              return 0
206          read = self._received_length
207          if nbytes < read:
208              read = nbytes
209          membuf = memoryview(self._receive_buffer)
210          if buf:
211              buf[:read] = membuf[:read]
212          if read < self._received_length:
213              new_end = self._received_length - read
214              self._receive_buffer[:new_end] = membuf[read : self._received_length]
215              self._received_length = new_end
216          else:
217              self._received_length = 0
218          return read
219  
220      def _readinto(self, buf):
221          if not self.socket:
222              raise RuntimeError(
223                  "Newer Response closed this one. Use Responses immediately."
224              )
225  
226          if not self._remaining:
227              # Consume the chunk header if need be.
228              if self._chunked:
229                  # Consume trailing \r\n for chunks 2+
230                  if self._remaining == 0:
231                      self._throw_away(2)
232                  chunk_header = _buffer_split0(self._readto(b"\r\n"), b";")
233                  http_chunk_size = int(bytes(chunk_header), 16)
234                  if http_chunk_size == 0:
235                      self._chunked = False
236                      self._parse_headers()
237                      return 0
238                  self._remaining = http_chunk_size
239              else:
240                  return 0
241  
242          nbytes = len(buf)
243          if nbytes > self._remaining:
244              nbytes = self._remaining
245  
246          read = self._read_from_buffer(buf, nbytes)
247          if read == 0:
248              read = self._recv_into(buf, nbytes)
249          self._remaining -= read
250  
251          return read
252  
253      def _throw_away(self, nbytes):
254          nbytes -= self._read_from_buffer(nbytes=nbytes)
255  
256          buf = self._receive_buffer
257          for _ in range(nbytes // len(buf)):
258              self._recv_into(buf)
259          remaining = nbytes % len(buf)
260          if remaining:
261              self._recv_into(buf, remaining)
262  
263      def close(self):
264          """Drain the remaining ESP socket buffers. We assume we already got what we wanted."""
265          if not self.socket:
266              return
267          # Make sure we've read all of our response.
268          if self._cached is None:
269              if self._remaining and self._remaining > 0:
270                  self._throw_away(self._remaining)
271              elif self._chunked:
272                  while True:
273                      chunk_header = self._readto(b"\r\n").split(b";", 1)[0]
274                      chunk_size = int(bytes(chunk_header), 16)
275                      if chunk_size == 0:
276                          break
277                      self._throw_away(chunk_size + 2)
278                  self._parse_headers()
279          if self._session:
280              self._session._free_socket(self.socket)  # pylint: disable=protected-access
281          else:
282              self.socket.close()
283          self.socket = None
284  
285      def _parse_headers(self):
286          """
287          Parses the header portion of an HTTP request/response from the socket.
288          Expects first line of HTTP request/response to have been read already.
289          """
290          while True:
291              title = self._readto(b": ", b"\r\n")
292              if not title:
293                  break
294  
295              content = self._readto(b"\r\n")
296              if title and content:
297                  title = str(title, "utf-8")
298                  content = str(content, "utf-8")
299                  # Check len first so we can skip the .lower allocation most of the time.
300                  if (
301                      len(title) == len("content-length")
302                      and title.lower() == "content-length"
303                  ):
304                      self._remaining = int(content)
305                  if (
306                      len(title) == len("transfer-encoding")
307                      and title.lower() == "transfer-encoding"
308                  ):
309                      self._chunked = content.lower() == "chunked"
310                  self._headers[title] = content
311  
312      @property
313      def headers(self):
314          """
315          The response headers. Does not include headers from the trailer until
316          the content has been read.
317          """
318          return self._headers
319  
320      @property
321      def content(self):
322          """The HTTP content direct from the socket, as bytes"""
323          if self._cached is not None:
324              if isinstance(self._cached, bytes):
325                  return self._cached
326              raise RuntimeError("Cannot access content after getting text or json")
327  
328          self._cached = b"".join(self.iter_content(chunk_size=32))
329          return self._cached
330  
331      @property
332      def text(self):
333          """The HTTP content, encoded into a string according to the HTTP
334          header encoding"""
335          if self._cached is not None:
336              if isinstance(self._cached, str):
337                  return self._cached
338              raise RuntimeError("Cannot access text after getting content or json")
339          self._cached = str(self.content, self.encoding)
340          return self._cached
341  
342      def json(self):
343          """The HTTP content, parsed into a json dictionary"""
344          # pylint: disable=import-outside-toplevel
345          import json
346  
347          # The cached JSON will be a list or dictionary.
348          if self._cached:
349              if isinstance(self._cached, (list, dict)):
350                  return self._cached
351              raise RuntimeError("Cannot access json after getting text or content")
352          if not self._raw:
353              self._raw = _RawResponse(self)
354  
355          try:
356              obj = json.load(self._raw)
357          except OSError:
358              # <5.3.1 doesn't piecemeal load json from any object with readinto so load the whole
359              # string.
360              obj = json.loads(self._raw.read())
361          if not self._cached:
362              self._cached = obj
363          self.close()
364          return obj
365  
366      def iter_content(self, chunk_size=1, decode_unicode=False):
367          """An iterator that will stream data by only reading 'chunk_size'
368          bytes and yielding them, when we can't buffer the whole datastream"""
369          if decode_unicode:
370              raise NotImplementedError("Unicode not supported")
371  
372          b = bytearray(chunk_size)
373          while True:
374              size = self._readinto(b)
375              if size == 0:
376                  break
377              if size < chunk_size:
378                  chunk = bytes(memoryview(b)[:size])
379              else:
380                  chunk = bytes(b)
381              yield chunk
382          self.close()
383  
384  
class Session:
    """HTTP session that shares sockets and ssl context.

    Sockets are keyed by ``(host, port, proto)``. When a `Response` is closed,
    its socket is marked free so later requests to the same host can reuse it.
    """

    def __init__(self, socket_pool, ssl_context=None):
        self._socket_pool = socket_pool
        self._ssl_context = ssl_context
        # Hang onto open sockets so that we can reuse them.
        self._open_sockets = {}
        # Maps each open socket to True when it is idle and may be reused.
        self._socket_free = {}
        self._last_response = None

    def _free_socket(self, socket):
        """Mark ``socket`` idle so it can be reused, or closed when resources run short."""
        if socket not in self._open_sockets.values():
            raise RuntimeError("Socket not from session")
        self._socket_free[socket] = True

    def _close_socket(self, sock):
        """Close ``sock`` and drop it from the session's bookkeeping."""
        sock.close()
        del self._socket_free[sock]
        key = None
        for k in self._open_sockets:
            if self._open_sockets[k] == sock:
                key = k
                break
        if key is not None:
            del self._open_sockets[key]

    def _free_sockets(self):
        """Close every idle socket."""
        free_sockets = [sock for sock in self._socket_free if self._socket_free[sock]]
        for sock in free_sockets:
            self._close_socket(sock)

    def _get_socket(self, host, port, proto, *, timeout=1):
        """Return a connected socket for ``(host, port, proto)``.

        An idle socket for the same key is reused when one exists.

        :raises RuntimeError: if https is requested without an ssl_context, or
            if no socket can be created and connected.
        """
        key = (host, port, proto)
        if key in self._open_sockets:
            sock = self._open_sockets[key]
            if self._socket_free[sock]:
                self._socket_free[sock] = False
                return sock
        if proto == "https:" and not self._ssl_context:
            raise RuntimeError(
                "ssl_context must be set before using adafruit_requests for https"
            )
        addr_info = self._socket_pool.getaddrinfo(
            host, port, 0, self._socket_pool.SOCK_STREAM
        )[0]
        retry_count = 0
        sock = None
        while retry_count < 5 and sock is None:
            if retry_count > 0:
                # Retrying only helps if closing idle sockets frees resources.
                # Test the free flags (values). The (sock, flag) pairs from
                # items() are always truthy.
                if any(self._socket_free.values()):
                    self._free_sockets()
                else:
                    raise RuntimeError("Sending request failed")
            retry_count += 1

            try:
                sock = self._socket_pool.socket(
                    addr_info[0], addr_info[1], addr_info[2]
                )
            except OSError:
                continue

            connect_host = addr_info[-1][0]
            if proto == "https:":
                sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
                connect_host = host
            sock.settimeout(timeout)  # socket read timeout

            try:
                sock.connect((connect_host, port))
            except (MemoryError, OSError):
                sock.close()
                sock = None

        if sock is None:
            raise RuntimeError("Repeated socket failures")

        self._open_sockets[key] = sock
        self._socket_free[sock] = False
        return sock

    @staticmethod
    def _send(socket, data):
        """Send all of ``data``, raising `_SendFailed` if the socket accepts nothing."""
        total_sent = 0
        while total_sent < len(data):
            # ESP32SPI sockets raise a RuntimeError when unable to send.
            try:
                sent = socket.send(data[total_sent:])
            except RuntimeError:
                sent = 0
            if sent is None:
                sent = len(data)
            if sent == 0:
                raise _SendFailed()
            total_sent += sent

    def _send_request(self, socket, host, method, path, headers, data, json):
        """Write the request line, headers and optional body to ``socket``.

        ``data`` may be a dict (sent form-encoded), ``str``, ``bytes`` or
        ``bytearray``. ``json`` is serialised and sent as application/json.
        """
        # pylint: disable=too-many-arguments
        self._send(socket, bytes(method, "utf-8"))
        self._send(socket, b" /")
        self._send(socket, bytes(path, "utf-8"))
        self._send(socket, b" HTTP/1.1\r\n")
        if "Host" not in headers:
            self._send(socket, b"Host: ")
            self._send(socket, bytes(host, "utf-8"))
            self._send(socket, b"\r\n")
        if "User-Agent" not in headers:
            self._send(socket, b"User-Agent: Adafruit CircuitPython\r\n")
        # Iterate over keys to avoid tuple alloc
        for k in headers:
            self._send(socket, k.encode())
            self._send(socket, b": ")
            self._send(socket, headers[k].encode())
            self._send(socket, b"\r\n")
        if json is not None:
            assert data is None
            # pylint: disable=import-outside-toplevel
            try:
                import json as json_module
            except ImportError:
                import ujson as json_module
            data = json_module.dumps(json)
            self._send(socket, b"Content-Type: application/json\r\n")
        if data:
            if isinstance(data, dict):
                self._send(
                    socket, b"Content-Type: application/x-www-form-urlencoded\r\n"
                )
                # NOTE(review): keys and values are not percent-encoded.
                form = data
                data = "&".join("{}={}".format(k, form[k]) for k in form)
            # Encode before measuring. Content-Length counts bytes, and a str
            # containing non-ASCII characters is longer once UTF-8 encoded.
            if isinstance(data, str):
                data = bytes(data, "utf-8")
            elif isinstance(data, bytearray):
                data = bytes(data)
            self._send(socket, b"Content-Length: %d\r\n" % len(data))
        self._send(socket, b"\r\n")
        if data:
            self._send(socket, data)

    # pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals
    def request(
        self, method, url, data=None, json=None, headers=None, stream=False, timeout=60
    ):
        """Perform an HTTP request to the given url which we will parse to determine
        whether to use SSL ('https://') or not. We can also send some provided 'data'
        or a json dictionary which we will stringify. 'headers' is optional HTTP headers
        sent along. 'stream' will determine if we buffer everything, or whether to only
        read only when requested
        """
        if not headers:
            headers = {}

        try:
            proto, dummy, host, path = url.split("/", 3)
            # replace spaces in path
            path = path.replace(" ", "%20")
        except ValueError:
            proto, dummy, host = url.split("/", 2)
            path = ""
        if proto == "http:":
            port = 80
        elif proto == "https:":
            port = 443
        else:
            raise ValueError("Unsupported protocol: " + proto)

        if ":" in host:
            host, port = host.split(":", 1)
            port = int(port)

        # Only one response per session may be live; drain the previous one
        # so its socket can be reused.
        if self._last_response:
            self._last_response.close()
            self._last_response = None

        # We may fail to send the request if the socket we got is closed already. So, try a second
        # time in that case.
        retry_count = 0
        while retry_count < 2:
            retry_count += 1
            socket = self._get_socket(host, port, proto, timeout=timeout)
            try:
                self._send_request(socket, host, method, path, headers, data, json)
                break
            except _SendFailed:
                self._close_socket(socket)
                if retry_count > 1:
                    raise

        resp = Response(socket, self)  # our response
        if "location" in resp.headers and 300 <= resp.status_code <= 399:
            raise NotImplementedError("Redirects not yet supported")

        self._last_response = resp
        return resp

    def head(self, url, **kw):
        """Send HTTP HEAD request"""
        return self.request("HEAD", url, **kw)

    def get(self, url, **kw):
        """Send HTTP GET request"""
        return self.request("GET", url, **kw)

    def post(self, url, **kw):
        """Send HTTP POST request"""
        return self.request("POST", url, **kw)

    def put(self, url, **kw):
        """Send HTTP PUT request"""
        return self.request("PUT", url, **kw)

    def patch(self, url, **kw):
        """Send HTTP PATCH request"""
        return self.request("PATCH", url, **kw)

    def delete(self, url, **kw):
        """Send HTTP DELETE request"""
        return self.request("DELETE", url, **kw)
611  
612  
# Backwards compatible API:

# Module-level Session used by the legacy module functions below. Within this
# file it is only assigned by set_socket(); until then it is None.
_default_session = None  # pylint: disable=invalid-name
616  
617  
class _FakeSSLSocket:
    """Socket wrapper returned by the legacy fake SSL context.

    ``settimeout``, ``send``, ``recv`` and ``close`` are the wrapped socket's
    own bound methods. ``connect`` additionally passes the TLS mode.
    """

    def __init__(self, socket, tls_mode):
        self._socket = socket
        self._mode = tls_mode
        # Expose the pass-through methods directly instead of wrapping them.
        for name in ("settimeout", "send", "recv", "close"):
            setattr(self, name, getattr(socket, name))

    def connect(self, address):
        """connect wrapper to add non-standard mode parameter"""
        try:
            outcome = self._socket.connect(address, self._mode)
        except RuntimeError as error:
            # Present failures as OSError, which Session._get_socket handles.
            raise OSError(errno.ENOMEM) from error
        return outcome
633  
634  
class _FakeSSLContext:
    """Stand-in SSL context used by the legacy `set_socket` API.

    Wrapping a socket attaches the interface's ``TLS_MODE`` to it.
    """

    def __init__(self, iface):
        self._iface = iface

    def wrap_socket(self, socket, server_hostname=None):
        """Wrap ``socket`` so that its connect() passes the interface's TLS mode."""
        # pylint: disable=unused-argument
        factory = _FakeSSLSocket
        return factory(socket, self._iface.TLS_MODE)
643  
644  
def set_socket(sock, iface=None):
    """Legacy API for setting the socket and network interface. Use a `Session` instead.

    :param sock: socket-pool-like object handed to the new default `Session`.
    :param iface: optional network interface. When given, it is attached to
        ``sock`` and supplies the TLS mode used for ``https:`` requests.
    """
    global _default_session  # pylint: disable=global-statement,invalid-name
    # Only build the fake SSL context when there is an interface to read the
    # TLS mode from. Otherwise ssl_context stays unset, and an https request
    # fails with Session's clear "ssl_context must be set" error instead of an
    # AttributeError on None.TLS_MODE.
    ssl_context = _FakeSSLContext(iface) if iface else None
    _default_session = Session(sock, ssl_context)
    if iface:
        sock.set_interface(iface)
651  
652  
def request(method, url, data=None, json=None, headers=None, stream=False, timeout=1):
    """Send HTTP request through the session configured by `set_socket`.

    Returns the `Response`, like the other module-level helpers do.
    """
    # pylint: disable=too-many-arguments
    return _default_session.request(
        method,
        url,
        data=data,
        json=json,
        headers=headers,
        stream=stream,
        timeout=timeout,
    )
665  
666  
def head(url, **kw):
    """Send an HTTP HEAD request through the module's default session."""
    session = _default_session
    return session.request("HEAD", url, **kw)
670  
671  
def get(url, **kw):
    """Send an HTTP GET request through the module's default session."""
    session = _default_session
    return session.request("GET", url, **kw)
675  
676  
def post(url, **kw):
    """Send an HTTP POST request through the module's default session."""
    session = _default_session
    return session.request("POST", url, **kw)
680  
681  
def put(url, **kw):
    """Send an HTTP PUT request through the module's default session."""
    session = _default_session
    return session.request("PUT", url, **kw)
685  
686  
def patch(url, **kw):
    """Send an HTTP PATCH request through the module's default session."""
    session = _default_session
    return session.request("PATCH", url, **kw)
690  
691  
def delete(url, **kw):
    """Send an HTTP DELETE request through the module's default session."""
    session = _default_session
    return session.request("DELETE", url, **kw)