/ tests / Framework.py
Framework.py
  1  ############################ Copyrights and license ############################
  2  #                                                                              #
  3  # Copyright 2012 Vincent Jacques <vincent@vincent-jacques.net>                 #
  4  # Copyright 2012 Zearin <zearin@gonk.net>                                      #
  5  # Copyright 2013 AKFish <akfish@gmail.com>                                     #
  6  # Copyright 2013 Vincent Jacques <vincent@vincent-jacques.net>                 #
  7  # Copyright 2014 Vincent Jacques <vincent@vincent-jacques.net>                 #
  8  # Copyright 2015 Uriel Corfa <uriel@corfa.fr>                                  #
  9  # Copyright 2016 Peter Buckley <dx-pbuckley@users.noreply.github.com>          #
 10  # Copyright 2017 Chris McBride <thehighlander@users.noreply.github.com>        #
 11  # Copyright 2017 Hugo <hugovk@users.noreply.github.com>                        #
 12  # Copyright 2017 Simon <spam@esemi.ru>                                         #
 13  # Copyright 2018 Jacopo Notarstefano <jacopo.notarstefano@gmail.com>           #
 14  # Copyright 2018 Laurent Mazuel <lmazuel@microsoft.com>                        #
 15  # Copyright 2018 Mike Miller <github@mikeage.net>                              #
 16  # Copyright 2018 Wan Liuyang <tsfdye@gmail.com>                                #
 17  # Copyright 2018 sfdye <tsfdye@gmail.com>                                      #
 18  #                                                                              #
 19  # This file is part of PyGithub.                                               #
 20  # http://pygithub.readthedocs.io/                                              #
 21  #                                                                              #
 22  # PyGithub is free software: you can redistribute it and/or modify it under    #
 23  # the terms of the GNU Lesser General Public License as published by the Free  #
 24  # Software Foundation, either version 3 of the License, or (at your option)    #
 25  # any later version.                                                           #
 26  #                                                                              #
 27  # PyGithub is distributed in the hope that it will be useful, but WITHOUT ANY  #
 28  # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS    #
 29  # FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more #
 30  # details.                                                                     #
 31  #                                                                              #
 32  # You should have received a copy of the GNU Lesser General Public License     #
 33  # along with PyGithub. If not, see <http://www.gnu.org/licenses/>.             #
 34  #                                                                              #
 35  ################################################################################
 36  
 37  import contextlib
 38  import io
 39  import json
 40  import os
 41  import traceback
 42  import unittest
 43  import warnings
 44  from typing import Optional
 45  
 46  import httpretty  # type: ignore
 47  from requests.structures import CaseInsensitiveDict
 48  from urllib3.util import Url  # type: ignore
 49  
 50  import github
 51  from github import Consts
 52  
# RSA private key embedded as a test fixture. BasicTestCase.setUp passes it to
# github.Auth.AppAuth(123456, APP_PRIVATE_KEY) when replaying recorded data (not in record mode,
# where the key comes from GithubCredentials instead).
APP_PRIVATE_KEY = """
-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQC+5ePolLv6VcWLp2f17g6r6vHl+eoLuodOOfUl8JK+MVmvXbPa
xDy0SS0pQhwTOMtB0VdSt++elklDCadeokhEoGDQp411o+kiOhzLxfakp/kewf4U
HJnu4M/A2nHmxXVe2lzYnZvZHX5BM4SJo5PGdr0Ue2JtSXoAtYr6qE9maQIDAQAB
AoGAFhOJ7sy8jG+837Clcihso+8QuHLVYTPaD+7d7dxLbBlS8NfaQ9Nr3cGUqm/N
xV9NCjiGa7d/y4w/vrPwGh6UUsA+CvndwDgBd0S3WgIdWvAvHM8wKgNh/GBLLzhT
Bg9BouRUzcT1MjAnkGkWqqCAgN7WrCSUMLt57TNleNWfX90CQQDjvVKTT3pOiavD
3YcLxwkyeGd0VMvKiS4nV0XXJ97cGXs2GpOGXldstDTnF5AnB6PbukdFLHpsx4sW
Hft3LRWnAkEA1pY15ke08wX6DZVXy7zuQ2izTrWSGySn7B41pn55dlKpttjHeutA
3BEQKTFvMhBCphr8qST7Wf1SR9FgO0tFbwJAEhHji2yy96hUyKW7IWQZhrem/cP8
p4Va9CQolnnDZRNgg1p4eiDiLu3dhLiJ547joXuWTBbLX/Y1Qvv+B+a74QJBAMCW
O3WbMZlS6eK6//rIa4ZwN00SxDg8I8FUM45jwBsjgVGrKQz2ilV3sutlhIiH82kk
m1Iq8LMJGYl/LkDJA10CQBV1C+Xu3ukknr7C4A/4lDCa6Xb27cr1HanY7i89A+Ab
eatdM6f/XVqWp8uPT9RggUV9TjppJobYGT2WrWJMkYw=
-----END RSA PRIVATE KEY-----
"""
 70  
 71  
 72  def readLine(file_):
 73      line = file_.readline()
 74      if isinstance(line, bytes):
 75          line = line.decode("utf-8")
 76      return line.strip()
 77  
 78  
 79  class FakeHttpResponse:
 80      def __init__(self, status, headers, output):
 81          self.status = status
 82          self.__headers = headers
 83          self.__output = output
 84  
 85      def getheaders(self):
 86          return self.__headers
 87  
 88      def read(self):
 89          return self.__output
 90  
 91  
 92  def fixAuthorizationHeader(headers):
 93      if "Authorization" in headers:
 94          if headers["Authorization"].endswith("ZmFrZV9sb2dpbjpmYWtlX3Bhc3N3b3Jk"):
 95              # This special case is here to test the real Authorization header
 96              # sent by PyGithub. It would have avoided issue https://github.com/jacquev6/PyGithub/issues/153
 97              # because we would have seen that Python 3 was not generating the same
 98              # header as Python 2
 99              pass
100          elif headers["Authorization"].startswith("token "):
101              headers["Authorization"] = "token private_token_removed"
102          elif headers["Authorization"].startswith("Basic "):
103              headers["Authorization"] = "Basic login_and_password_removed"
104          elif headers["Authorization"].startswith("Bearer "):
105              headers["Authorization"] = "Bearer jwt_removed"
106  
107  
class RecordingConnection:
    """Connection wrapper that performs real requests and logs every exchange to a text file.

    Subclasses set ``_realConnection`` to the connection class that does the actual work.
    Log layout, one value per line:

    * for each request: protocol, verb, host, port, url, anonymized headers, body (CR/LF removed);
    * for each response: status, list of headers, body;
    * on ``close``: an empty line.
    """

    def __init__(self, file, protocol, host, port, *args, **kwds):
        # Every write is a str, so the destination must be a text-mode file.
        assert isinstance(file, io.TextIOBase)
        self.__file = file
        self.__protocol, self.__host, self.__port = protocol, host, port
        self.__cnx = self._realConnection(host, port, *args, **kwds)

    def request(self, verb, url, input, headers):
        """Forward the request to the real connection, then record it with credentials masked."""
        self.__cnx.request(verb, url, input, headers)
        # `headers` is the very dict that "requests" will send; the record is written before the
        # actual request happens. So only a copy may be anonymized, otherwise the real request is
        # altered. See:
        # https://github.com/PyGithub/PyGithub/pull/664#issuecomment-389964369
        # https://github.com/PyGithub/PyGithub/issues/822
        # The dict is dict[str, str], so a shallow copy is enough.
        scrubbed = headers.copy()
        fixAuthorizationHeader(scrubbed)
        for field in (self.__protocol, verb, self.__host, self.__port, url, scrubbed):
            self.__writeLine(field)
        self.__writeLine(str(input).replace("\n", "").replace("\r", ""))

    def getresponse(self):
        """Read the real response, record it, and return an in-memory copy of it."""
        real = self.__cnx.getresponse()
        status = real.status
        headers = real.getheaders()
        output = real.read()

        for field in (status, list(headers), output):
            self.__writeLine(field)

        return FakeHttpResponse(status, headers, output)

    def close(self):
        """Write the empty separator line and close the real connection."""
        self.__writeLine("")
        return self.__cnx.close()

    def __writeLine(self, line):
        # str() of the value followed by a newline: one record value per line.
        self.__file.write(str(line) + "\n")
156  
157  
class RecordingHttpConnection(RecordingConnection):
    """RecordingConnection for the ``http`` protocol, backed by PyGithub's HTTP requests connection."""

    _realConnection = github.Requester.HTTPRequestsConnectionClass

    def __init__(self, file, *args, **kwds):
        # Remaining positional args start with host and port (see RecordingConnection.__init__).
        super(RecordingHttpConnection, self).__init__(file, "http", *args, **kwds)
163  
164  
class RecordingHttpsConnection(RecordingConnection):
    """RecordingConnection for the ``https`` protocol, backed by PyGithub's HTTPS requests connection."""

    _realConnection = github.Requester.HTTPSRequestsConnectionClass

    def __init__(self, file, *args, **kwds):
        # Remaining positional args start with host and port (see RecordingConnection.__init__).
        super(RecordingHttpsConnection, self).__init__(file, "https", *args, **kwds)
170  
171  
class ReplayingConnection:
    """Connection wrapper that answers requests from a replay-data file instead of the network.

    ``request`` registers ``__request_callback`` with httpretty for the target URL and then hands
    the request to the wrapped real connection. When ``getresponse`` makes that connection send
    it, httpretty intercepts the call and invokes the callback. The callback checks the request
    against the next one recorded in the file and answers with the recorded status, headers and
    body. Subclasses must set ``_realConnection`` to the connection class to wrap.
    """

    def __init__(self, file, protocol, host, port, *args, **kwds):
        self.__file = file
        self.__protocol = protocol
        self.__host = host
        self.__port = port
        # Headers of the most recently replayed response; put back on the response in getresponse().
        self.response_headers = CaseInsensitiveDict()

        self.__cnx = self._realConnection(host, port, *args, **kwds)

    def request(self, verb, url, input, headers):
        """Register the replay callback for this URL with httpretty, then forward the request."""
        full_url = Url(scheme=self.__protocol, host=self.__host, port=self.__port, path=url)

        httpretty.register_uri(verb, full_url.url, body=self.__request_callback)

        self.__cnx.request(verb, url, input, headers)

    def __readNextRequest(self, verb, url, input, headers):
        """Consume the next recorded request from the replay file and assert it matches this one.

        Query-string parameter order is ignored, and JSON bodies are compared as parsed values.
        Non-string bodies (e.g. asset uploads) are not compared, but their recorded line is still
        consumed.
        """
        # Anonymize a copy: `headers` is the caller's request-headers dict and must keep its real
        # Authorization value. RecordingConnection.request copies for the same reason.
        anonymous_headers = headers.copy()
        fixAuthorizationHeader(anonymous_headers)
        assert self.__protocol == readLine(self.__file)
        assert verb == readLine(self.__file)
        assert self.__host == readLine(self.__file)
        assert str(self.__port) == readLine(self.__file)
        assert self.__splitUrl(url) == self.__splitUrl(readLine(self.__file))
        # NOTE(review): eval() executes replay-file content. These files are test data under
        # ReplayData, written by RecordingConnection as str(dict); ast.literal_eval would likely
        # suffice. Confirm against existing replay files before switching.
        assert anonymous_headers == eval(readLine(self.__file))
        expectedInput = readLine(self.__file)
        if isinstance(input, str):
            trInput = input.replace("\n", "").replace("\r", "")
            if input.startswith("{"):
                assert expectedInput.startswith("{"), expectedInput
                assert json.loads(trInput) == json.loads(expectedInput)
            else:
                assert trInput == expectedInput
        # Any non-string input (e.g. an uploaded asset) is accepted without comparison.

    def __splitUrl(self, url):
        """Split *url* into path and sorted query parameters so parameter order does not matter."""
        splitedUrl = url.split("?")
        if len(splitedUrl) == 1:
            return splitedUrl
        assert len(splitedUrl) == 2
        base, qs = splitedUrl
        return (base, sorted(qs.split("&")))

    def __request_callback(self, request, uri, response_headers):
        """httpretty body callback: validate the outgoing request and return the recorded response.

        ``request`` and ``uri`` are unused; the request is checked using the values stored on the
        wrapped connection. Returns ``[status, headers, body]`` as httpretty expects.
        """
        self.__readNextRequest(self.__cnx.verb, self.__cnx.url, self.__cnx.input, self.__cnx.headers)

        status = int(readLine(self.__file))
        # NOTE(review): same eval() caveat as in __readNextRequest (recorded as str(list(headers))).
        self.response_headers = CaseInsensitiveDict(eval(readLine(self.__file)))
        output = bytearray(readLine(self.__file), "utf-8")
        # Consume the empty separator line that RecordingConnection.close() writes.
        readLine(self.__file)

        # make a copy of the headers and remove the ones that interfere with the response handling
        adding_headers = CaseInsensitiveDict(self.response_headers)
        adding_headers.pop("content-length", None)
        adding_headers.pop("transfer-encoding", None)
        adding_headers.pop("content-encoding", None)

        response_headers.update(adding_headers)
        return [status, response_headers, output]

    def getresponse(self):
        """Send the request through the real connection (intercepted by httpretty) and return it."""
        # call original connection, this will go all the way down to the python socket and will be intercepted by httpretty
        response = self.__cnx.getresponse()

        # restore original headers to the response
        response.headers = self.response_headers

        return response

    def close(self):
        """Close the wrapped real connection."""
        self.__cnx.close()
245  
246  
class ReplayingHttpConnection(ReplayingConnection):
    """ReplayingConnection for the ``http`` protocol, wrapping PyGithub's HTTP requests connection."""

    _realConnection = github.Requester.HTTPRequestsConnectionClass

    def __init__(self, file, *args, **kwds):
        # Remaining positional args start with host and port (see ReplayingConnection.__init__).
        super(ReplayingHttpConnection, self).__init__(file, "http", *args, **kwds)
252  
253  
class ReplayingHttpsConnection(ReplayingConnection):
    """ReplayingConnection for the ``https`` protocol, wrapping PyGithub's HTTPS requests connection."""

    _realConnection = github.Requester.HTTPSRequestsConnectionClass

    def __init__(self, file, *args, **kwds):
        # Remaining positional args start with host and port (see ReplayingConnection.__init__).
        super(ReplayingHttpsConnection, self).__init__(file, "https", *args, **kwds)
259  
260  
class BasicTestCase(unittest.TestCase):
    """Base test case that replays (or, in record mode, records) PyGithub's HTTP traffic.

    Replay data lives in ``replayDataFolder`` as ``<TestClass>.<function>.txt``.

    * Replay mode: PyGithub's connection classes are swapped for replaying connections, and
      httpretty is enabled with real network connections disallowed.
    * Record mode: they are swapped for recording connections, and credentials are read from the
      ``GithubCredentials`` module.
    """

    # Mode switches, turned on by the module-level activate*() helpers.
    recordMode = False
    tokenAuthMode = False
    jwtAuthMode = False
    # Client settings forwarded to github.Github by TestCase.get_github
    # (retry/pool_size can be set via enableRetry()/setPoolSize()).
    per_page = Consts.DEFAULT_PER_PAGE
    retry = None
    pool_size = None
    seconds_between_requests: Optional[float] = None
    seconds_between_writes: Optional[float] = None
    replayDataFolder = os.path.join(os.path.dirname(__file__), "ReplayData")

    def setUp(self):
        """Inject recording or replaying connection classes and create the auth objects tests use."""
        super().setUp()
        self.__fileName = ""
        self.__file = None
        if (
            self.recordMode
        ):  # pragma no cover (Branch useful only when recording new tests, not used during automated tests)
            github.Requester.Requester.injectConnectionClasses(
                lambda ignored, *args, **kwds: RecordingHttpConnection(self.__openFile("w"), *args, **kwds),
                lambda ignored, *args, **kwds: RecordingHttpsConnection(self.__openFile("w"), *args, **kwds),
            )
            import GithubCredentials  # type: ignore

            # Each auth object is None unless its credentials are configured.
            self.login = (
                github.Auth.Login(GithubCredentials.login, GithubCredentials.password)
                if GithubCredentials.login and GithubCredentials.password
                else None
            )
            self.oauth_token = (
                github.Auth.Token(GithubCredentials.oauth_token) if GithubCredentials.oauth_token else None
            )
            self.jwt = github.Auth.AppAuthToken(GithubCredentials.jwt) if GithubCredentials.jwt else None
            self.app_auth = (
                github.Auth.AppAuth(GithubCredentials.app_id, GithubCredentials.app_private_key)
                if GithubCredentials.app_id and GithubCredentials.app_private_key
                else None
            )
        else:
            github.Requester.Requester.injectConnectionClasses(
                lambda ignored, *args, **kwds: ReplayingHttpConnection(self.__openFile("r"), *args, **kwds),
                lambda ignored, *args, **kwds: ReplayingHttpsConnection(self.__openFile("r"), *args, **kwds),
            )
            # Dummy credentials. The Authorization headers they produce are anonymized by
            # fixAuthorizationHeader before being compared with the recorded ones.
            self.login = github.Auth.Login("login", "password")
            self.oauth_token = github.Auth.Token("oauth_token")
            self.jwt = github.Auth.AppAuthToken("jwt")
            self.app_auth = github.Auth.AppAuth(123456, APP_PRIVATE_KEY)

            httpretty.enable(allow_net_connect=False)

    @property
    def thisTestFailed(self) -> bool:
        """Whether the currently running test has already reported an error or failure."""
        if hasattr(self._outcome, "errors"):  # type: ignore
            # Python 3.4 - 3.10
            result = self.defaultTestResult()
            self._feedErrorsToResult(result, self._outcome.errors)  # type: ignore
            return not all(test != self for test, _ in result.errors + result.failures)
        # Python 3.11+: `_excinfo` is not a standard unittest attribute (presumably provided by
        # the pytest runner). Treat its absence as "not failed" instead of raising AttributeError,
        # and always return a real bool.
        return bool(getattr(self._outcome.result, "_excinfo", None))  # type: ignore

    def tearDown(self):
        """Disable httpretty, close the replay file and restore PyGithub's connection classes.

        Leftover replay data is only asserted on when the test itself passed, presumably so a
        failing test does not get a second, misleading error.
        """
        super().tearDown()
        httpretty.disable()
        httpretty.reset()

        self.__closeReplayFileIfNeeded(silent=self.thisTestFailed)
        github.Requester.Requester.resetConnectionClasses()

    def assertWarning(self, warning, expected):
        """Assert that *warning* captured exactly one DeprecationWarning with message *expected*."""
        self.assertWarnings(warning, expected)

    def assertWarnings(self, warning, *expecteds):
        """Assert that *warning* captured exactly the given DeprecationWarning messages, in order.

        *warning* must expose a ``warnings`` list of ``warnings.WarningMessage`` objects
        (e.g. the context object of ``assertWarns``).
        """
        actual = [(type(message), type(message.message), message.message.args) for message in warning.warnings]
        expected = [(warnings.WarningMessage, DeprecationWarning, (text,)) for text in expecteds]
        self.assertSequenceEqual(actual, expected)

    @contextlib.contextmanager
    def ignoreWarning(self, category=Warning, module=""):
        """Context manager that suppresses warnings of *category* raised from *module*."""
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=category, module=module)
            yield

    def __openFile(self, mode):
        """Return the replay-data file for the test on the call stack, opening it in *mode* if needed.

        The name is ``<TestClass>.<function>.txt``, where *function* is the innermost ``test*``,
        ``setUp`` or ``tearDown`` frame on the stack (``test`` itself is skipped). Switching to a
        different file first closes the current one (see ``__closeReplayFileIfNeeded``).

        Raises:
            RuntimeError: if no such frame is on the call stack.
        """
        fileName = None
        # extract_stack() lists frames oldest first, so the last match is the innermost one.
        for frame in traceback.extract_stack():
            functionName = frame.name
            if functionName.startswith("test") or functionName == "setUp" or functionName == "tearDown":
                if functionName != "test":  # because in class Hook(Framework.TestCase), method testTest calls Hook.test
                    fileName = os.path.join(
                        self.replayDataFolder,
                        f"{self.__class__.__name__}.{functionName}.txt",
                    )
        if fileName is None:
            raise RuntimeError(
                "Cannot name the replay-data file: no test*, setUp or tearDown function on the call stack"
            )
        if fileName != self.__fileName:
            self.__closeReplayFileIfNeeded()
            self.__fileName = fileName
            self.__file = open(self.__fileName, mode, encoding="utf-8")
        return self.__file

    def __closeReplayFileIfNeeded(self, silent=False):
        """Close the current replay-data file, if one is open.

        In replay mode, unless *silent*, first assert that nothing is left to replay (the next
        line must be empty); the file name is used as the failure message.
        """
        if self.__file is not None:
            if (
                not self.recordMode and not silent
            ):  # pragma no branch (Branch useful only when recording new tests, not used during automated tests)
                self.assertEqual(readLine(self.__file), "", self.__fileName)
            self.__file.close()

    def assertListKeyEqual(self, elements, key, expectedKeys):
        """Assert that mapping *key* over all of *elements* gives exactly *expectedKeys*."""
        realKeys = [key(element) for element in elements]
        self.assertEqual(realKeys, expectedKeys)

    def assertListKeyBegin(self, elements, key, expectedKeys):
        """Assert that mapping *key* over the first ``len(expectedKeys)`` elements gives *expectedKeys*."""
        realKeys = [key(element) for element in elements[: len(expectedKeys)]]
        self.assertEqual(realKeys, expectedKeys)
374  
375  
class TestCase(BasicTestCase):
    """Record/replay test case that also provides a configured ``github.Github`` client as ``self.g``.

    ``setUp`` enables PyGithub's check-after-init and debug flags and registers ``doCheckFrame``
    (via ``getFrameChecker``) as the requester's check callback.
    """

    def doCheckFrame(self, obj, frame):
        """Assert that ``obj._headers`` equals ``frame[2]``.

        Two "nothing recorded" pairings are accepted without comparison: ``obj._headers == {}``
        with ``frame is None``, and ``obj._headers is None`` with ``frame == {}``.
        """
        if obj._headers == {} and frame is None:
            return
        if obj._headers is None and frame == {}:
            return
        self.assertEqual(obj._headers, frame[2])

    def getFrameChecker(self):
        """Return a ``(requester, obj, frame)`` callback that ignores *requester* and runs doCheckFrame."""
        return lambda requester, obj, frame: self.doCheckFrame(obj, frame)

    def setUp(self):
        """Set up record/replay, enable frame debugging and build ``self.g``."""
        super().setUp()

        # Set up frame debugging
        github.GithubObject.GithubObject.setCheckAfterInitFlag(True)
        github.Requester.Requester.setDebugFlag(True)
        github.Requester.Requester.setOnCheckMe(self.getFrameChecker())

        self.g = self.get_github(self.retry, self.pool_size)

    def get_github(self, retry, pool_size):
        """Build a Github client using the auth selected by the class-level mode flags.

        Token auth wins over JWT auth, and login auth is the fallback. Page size and request
        throttling come from the class attributes; *retry* and *pool_size* are passed through.
        """
        if self.tokenAuthMode:
            auth = self.oauth_token
        elif self.jwtAuthMode:
            auth = self.jwt
        else:
            auth = self.login
        return github.Github(
            auth=auth,
            per_page=self.per_page,
            retry=retry,
            pool_size=pool_size,
            seconds_between_requests=self.seconds_between_requests,
            seconds_between_writes=self.seconds_between_writes,
        )
425  
426  
def activateRecordMode():  # pragma no cover (Function useful only when recording new tests, not used during automated tests)
    """Make every BasicTestCase record real traffic to its replay file instead of replaying it."""
    setattr(BasicTestCase, "recordMode", True)
429  
430  
def activateTokenAuthMode():  # pragma no cover (Function useful only when recording new tests, not used during automated tests)
    """Make TestCase clients authenticate with the OAuth token instead of login/password."""
    setattr(BasicTestCase, "tokenAuthMode", True)
433  
434  
def activateJWTAuthMode():  # pragma no cover (Function useful only when recording new tests, not used during automated tests)
    """Make TestCase clients authenticate with the app JWT (used unless token mode is also on)."""
    setattr(BasicTestCase, "jwtAuthMode", True)
437  
438  
def enableRetry(retry):
    """Set the *retry* value that TestCase.setUp passes to github.Github for every test."""
    setattr(BasicTestCase, "retry", retry)
441  
442  
def setPoolSize(pool_size):
    """Set the *pool_size* value that TestCase.setUp passes to github.Github for every test."""
    setattr(BasicTestCase, "pool_size", pool_size)