/ acme / acme / standalone_test.py
standalone_test.py
  1  """Tests for acme.standalone."""
  2  import os
  3  import shutil
  4  import socket
  5  import threading
  6  import tempfile
  7  import time
  8  import unittest
  9  
 10  from six.moves import http_client  # pylint: disable=import-error
 11  from six.moves import socketserver  # pylint: disable=import-error
 12  
 13  import requests
 14  
 15  from acme import challenges
 16  from acme import crypto_util
 17  from acme import errors
 18  from acme import jose
 19  from acme import test_util
 20  
 21  
 22  class TLSServerTest(unittest.TestCase):
 23      """Tests for acme.standalone.TLSServer."""
 24  
 25      def test_bind(self):  # pylint: disable=no-self-use
 26          from acme.standalone import TLSServer
 27          server = TLSServer(
 28              ('', 0), socketserver.BaseRequestHandler, bind_and_activate=True)
 29          server.server_close()  # pylint: disable=no-member
 30  
 31  
 32  class ACMEServerMixinTest(unittest.TestCase):
 33      """Tests for acme.standalone.ACMEServerMixin."""
 34  
 35      def setUp(self):
 36          from acme.standalone import ACMEServerMixin
 37  
 38          class _MockHandler(socketserver.BaseRequestHandler):
 39              # pylint: disable=missing-docstring,no-member,no-init
 40  
 41              def handle(self):
 42                  self.request.sendall(b"DONE")
 43  
 44          class _MockServer(socketserver.TCPServer, ACMEServerMixin):
 45              def __init__(self, *args, **kwargs):
 46                  socketserver.TCPServer.__init__(self, *args, **kwargs)
 47                  ACMEServerMixin.__init__(self)
 48  
 49          self.server = _MockServer(("", 0), _MockHandler)
 50  
 51      def _busy_wait(self):  # pragma: no cover
 52          # This function is used to avoid race conditions in tests, but
 53          # not all of the functionality is always used, hence "no
 54          # cover"
 55          while True:
 56              sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 57              try:
 58                  # pylint: disable=no-member
 59                  sock.connect(self.server.socket.getsockname())
 60              except socket.error:
 61                  pass
 62              else:
 63                  sock.recv(4)  # wait until handle_request is actually called
 64                  break
 65              finally:
 66                  sock.close()
 67              time.sleep(1)
 68  
 69      def test_serve_shutdown(self):
 70          thread = threading.Thread(target=self.server.serve_forever2)
 71          thread.start()
 72          self._busy_wait()
 73          self.server.shutdown2()
 74  
 75      def test_shutdown2_not_running(self):
 76          self.server.shutdown2()
 77          self.server.shutdown2()
 78  
 79  
 80  class DVSNIServerTest(unittest.TestCase):
 81      """Test for acme.standalone.DVSNIServer."""
 82  
 83      def setUp(self):
 84          self.certs = {
 85              b'localhost': (test_util.load_pyopenssl_private_key('rsa512_key.pem'),
 86                             # pylint: disable=protected-access
 87                             test_util.load_cert('cert.pem')._wrapped),
 88          }
 89          from acme.standalone import DVSNIServer
 90          self.server = DVSNIServer(("", 0), certs=self.certs)
 91          # pylint: disable=no-member
 92          self.thread = threading.Thread(target=self.server.handle_request)
 93          self.thread.start()
 94  
 95      def tearDown(self):
 96          self.server.shutdown2()
 97          self.thread.join()
 98  
 99      def test_init(self):
100          # pylint: disable=protected-access
101          self.assertFalse(self.server._stopped)
102  
103      def test_dvsni(self):
104          host, port = self.server.socket.getsockname()[:2]
105          cert = crypto_util.probe_sni(b'localhost', host=host, port=port)
106          self.assertEqual(jose.ComparableX509(cert),
107                           jose.ComparableX509(self.certs[b'localhost'][1]))
108  
109  
class SimpleHTTPServerTest(unittest.TestCase):
    """Tests for acme.standalone.SimpleHTTPServer."""

    def setUp(self):
        self.account_key = jose.JWK.load(
            test_util.load_vector('rsa1024_key.pem'))
        # The server is handed this very set; tests add resources to it.
        self.resources = set()

        from acme.standalone import SimpleHTTPServer
        self.server = SimpleHTTPServer(('', 0), resources=self.resources)

        # pylint: disable=no-member
        sockname = self.server.socket.getsockname()
        self.port = sockname[1]
        # Serve one request in the background; tearDown shuts down and joins.
        self.thread = threading.Thread(target=self.server.handle_request)
        self.thread.start()

    def tearDown(self):
        self.server.shutdown2()
        self.thread.join()

    def test_index(self):
        """The root path serves the solver's banner text."""
        index_url = 'http://localhost:{0}'.format(self.port)
        reply = requests.get(index_url, verify=False)
        self.assertEqual(
            reply.text, 'ACME client standalone challenge solver')
        self.assertTrue(reply.ok)

    def test_404(self):
        """Unknown paths yield HTTP 404."""
        missing_url = 'http://localhost:{0}/foo'.format(self.port)
        reply = requests.get(missing_url, verify=False)
        self.assertEqual(reply.status_code, http_client.NOT_FOUND)

    def _test_simple_http(self, add):
        """Build a SimpleHTTP resource and run simple_verify against it.

        :param bool add: whether to publish the resource on the server
            (by adding it to ``self.resources``) before verifying.
        :returns: whatever ``simple_verify`` returns.
        """
        chall = challenges.SimpleHTTP(token=(b'x' * 16))
        response = challenges.SimpleHTTPResponse(tls=False)

        from acme.standalone import SimpleHTTPRequestHandler
        validation = response.gen_validation(chall, self.account_key)
        resource = SimpleHTTPRequestHandler.SimpleHTTPResource(
            chall=chall, response=response, validation=validation)

        if add:
            self.resources.add(resource)

        public_key = self.account_key.public_key()
        return resource.response.simple_verify(
            resource.chall, 'localhost', public_key, port=self.port)

    def test_simple_http_found(self):
        """Verification succeeds once the resource is published."""
        verified = self._test_simple_http(add=True)
        self.assertTrue(verified)

    def test_simple_http_not_found(self):
        """Verification fails when the resource was never published."""
        verified = self._test_simple_http(add=False)
        self.assertFalse(verified)
162  
class TestSimpleDVSNIServer(unittest.TestCase):
    """Tests for acme.standalone.simple_dvsni_server."""

    # Upper bound (seconds) on how long tearDown waits for the server
    # thread. The server handles a single request (forever=False); if
    # test_it fails before any probe reaches it, an unbounded join would
    # block tearDown, and with it the whole test run, forever.
    _JOIN_TIMEOUT = 5

    def setUp(self):
        # mirror ../examples/standalone: <cwd>/localhost/{cert,key}.pem
        self.test_cwd = tempfile.mkdtemp()
        localhost_dir = os.path.join(self.test_cwd, 'localhost')
        os.makedirs(localhost_dir)
        shutil.copy(test_util.vector_path('cert.pem'), localhost_dir)
        shutil.copy(test_util.vector_path('rsa512_key.pem'),
                    os.path.join(localhost_dir, 'key.pem'))

        from acme.standalone import simple_dvsni_server
        self.port = 1234
        self.thread = threading.Thread(target=simple_dvsni_server, kwargs={
            'cli_args': ('xxx', '--port', str(self.port)),
            'forever': False,
        })
        # Daemon thread: a server still blocked waiting for its request
        # must not keep the interpreter alive after the run finishes.
        self.thread.daemon = True
        # Presumably simple_dvsni_server discovers host directories relative
        # to the current working directory, hence the chdir before start.
        self.old_cwd = os.getcwd()
        os.chdir(self.test_cwd)
        self.thread.start()

    def tearDown(self):
        os.chdir(self.old_cwd)
        self.thread.join(self._JOIN_TIMEOUT)
        shutil.rmtree(self.test_cwd)

    def test_it(self):
        """The standalone server serves the certificate from disk via SNI."""
        max_attempts = 5
        for attempt in range(1, max_attempts + 1):
            try:
                cert = crypto_util.probe_sni(
                    b'localhost', b'0.0.0.0', self.port)
            except errors.Error:
                # The server thread may not be listening yet; give up only
                # after the final attempt.
                self.assertTrue(attempt < max_attempts, "Timeout!")
                time.sleep(1)  # wait until thread starts
            else:
                self.assertEqual(jose.ComparableX509(cert),
                                 test_util.load_cert('cert.pem'))
                break
203  
204  
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()  # pragma: no cover