dvsni_test.py
  1  """Test for letsencrypt_apache.dvsni."""
  2  import unittest
  3  import shutil
  4  
  5  import mock
  6  
  7  from letsencrypt.plugins import common_test
  8  
  9  from letsencrypt_apache import obj
 10  from letsencrypt_apache.tests import util
 11  
 12  
 13  class DvsniPerformTest(util.ApacheTest):
 14      """Test the ApacheDVSNI challenge."""
 15  
 16      auth_key = common_test.DvsniTest.auth_key
 17      achalls = common_test.DvsniTest.achalls
 18  
 19      def setUp(self):  # pylint: disable=arguments-differ
 20          super(DvsniPerformTest, self).setUp()
 21  
 22          config = util.get_apache_configurator(
 23              self.config_path, self.config_dir, self.work_dir)
 24          config.config.dvsni_port = 443
 25  
 26          from letsencrypt_apache import dvsni
 27          self.sni = dvsni.ApacheDvsni(config)
 28  
 29      def tearDown(self):
 30          shutil.rmtree(self.temp_dir)
 31          shutil.rmtree(self.config_dir)
 32          shutil.rmtree(self.work_dir)
 33  
 34      def test_perform0(self):
 35          resp = self.sni.perform()
 36          self.assertEqual(len(resp), 0)
 37  
 38      @mock.patch("letsencrypt.le_util.exe_exists")
 39      @mock.patch("letsencrypt.le_util.run_script")
 40      def test_perform1(self, _, mock_exists):
 41          mock_register = mock.Mock()
 42          self.sni.configurator.reverter.register_undo_command = mock_register
 43  
 44          mock_exists.return_value = True
 45          self.sni.configurator.parser.update_runtime_variables = mock.Mock()
 46  
 47          achall = self.achalls[0]
 48          self.sni.add_chall(achall)
 49          response = self.achalls[0].gen_response(self.auth_key)
 50          mock_setup_cert = mock.MagicMock(return_value=response)
 51          # pylint: disable=protected-access
 52          self.sni._setup_challenge_cert = mock_setup_cert
 53  
 54          responses = self.sni.perform()
 55  
 56          # Make sure that register_undo_command was called into temp directory.
 57          self.assertEqual(True, mock_register.call_args[0][0])
 58  
 59          mock_setup_cert.assert_called_once_with(achall)
 60  
 61          # Check to make sure challenge config path is included in apache config.
 62          self.assertEqual(
 63              len(self.sni.configurator.parser.find_dir(
 64                  "Include", self.sni.challenge_conf)), 1)
 65          self.assertEqual(len(responses), 1)
 66          self.assertEqual(responses[0], response)
 67  
 68      def test_perform2(self):
 69          # Avoid load module
 70          self.sni.configurator.parser.modules.add("ssl_module")
 71  
 72          acme_responses = []
 73          for achall in self.achalls:
 74              self.sni.add_chall(achall)
 75              acme_responses.append(achall.gen_response(self.auth_key))
 76  
 77          mock_setup_cert = mock.MagicMock(side_effect=acme_responses)
 78          # pylint: disable=protected-access
 79          self.sni._setup_challenge_cert = mock_setup_cert
 80  
 81          sni_responses = self.sni.perform()
 82  
 83          self.assertEqual(mock_setup_cert.call_count, 2)
 84  
 85          # Make sure calls made to mocked function were correct
 86          self.assertEqual(
 87              mock_setup_cert.call_args_list[0], mock.call(self.achalls[0]))
 88          self.assertEqual(
 89              mock_setup_cert.call_args_list[1], mock.call(self.achalls[1]))
 90  
 91          self.assertEqual(
 92              len(self.sni.configurator.parser.find_dir(
 93                  "Include", self.sni.challenge_conf)),
 94              1)
 95          self.assertEqual(len(sni_responses), 2)
 96          for i in xrange(2):
 97              self.assertEqual(sni_responses[i], acme_responses[i])
 98  
 99      def test_mod_config(self):
100          z_domains = []
101          for achall in self.achalls:
102              self.sni.add_chall(achall)
103              z_domain = achall.gen_response(self.auth_key).z_domain
104              z_domains.append(set([z_domain]))
105  
106          self.sni._mod_config()  # pylint: disable=protected-access
107          self.sni.configurator.save()
108  
109          self.sni.configurator.parser.find_dir(
110              "Include", self.sni.challenge_conf)
111          vh_match = self.sni.configurator.aug.match(
112              "/files" + self.sni.challenge_conf + "//VirtualHost")
113  
114          vhs = []
115          for match in vh_match:
116              # pylint: disable=protected-access
117              vhs.append(self.sni.configurator._create_vhost(match))
118          self.assertEqual(len(vhs), 2)
119          for vhost in vhs:
120              self.assertEqual(vhost.addrs, set([obj.Addr.fromstring("*:443")]))
121              names = vhost.get_names()
122              self.assertTrue(names in z_domains)
123  
124      def test_get_dvsni_addrs_default(self):
125          self.sni.configurator.choose_vhost = mock.Mock(
126              return_value=obj.VirtualHost(
127                  "path", "aug_path", set([obj.Addr.fromstring("_default_:443")]),
128                  False, False)
129          )
130  
131          self.assertEqual(
132              set([obj.Addr.fromstring("*:443")]),
133              self.sni.get_dvsni_addrs(self.achalls[0]))
134  
135  
if __name__ == "__main__":
    # Allow running this test module directly (e.g. ``python dvsni_test.py``).
    unittest.main()  # pragma: no cover