dvsni_test.py
"""Test for letsencrypt_apache.dvsni."""
import unittest
import shutil

import mock

from letsencrypt.plugins import common_test

from letsencrypt_apache import obj
from letsencrypt_apache.tests import util


class DvsniPerformTest(util.ApacheTest):
    """Test the ApacheDvsni challenge authenticator."""

    auth_key = common_test.DvsniTest.auth_key
    achalls = common_test.DvsniTest.achalls

    def setUp(self):  # pylint: disable=arguments-differ
        super(DvsniPerformTest, self).setUp()

        config = util.get_apache_configurator(
            self.config_path, self.config_dir, self.work_dir)
        config.config.dvsni_port = 443

        from letsencrypt_apache import dvsni
        self.sni = dvsni.ApacheDvsni(config)

    def tearDown(self):
        # temp_dir/config_dir/work_dir are not set in this class; presumably
        # they are created by util.ApacheTest.setUp -- confirm there.
        shutil.rmtree(self.temp_dir)
        shutil.rmtree(self.config_dir)
        shutil.rmtree(self.work_dir)

    def _assert_challenge_conf_included_once(self):
        """Assert the Apache config has exactly one Include of challenge_conf."""
        self.assertEqual(
            len(self.sni.configurator.parser.find_dir(
                "Include", self.sni.challenge_conf)), 1)

    def test_perform0(self):
        """perform() with no challenges added yields no responses."""
        resp = self.sni.perform()
        self.assertEqual(len(resp), 0)

    @mock.patch("letsencrypt.le_util.exe_exists")
    @mock.patch("letsencrypt.le_util.run_script")
    def test_perform1(self, _, mock_exists):
        """perform() with a single challenge returns its response.

        mock.patch decorators apply bottom-up, so the run_script mock is the
        first (ignored) argument and the exe_exists mock the second.
        """
        mock_register = mock.Mock()
        self.sni.configurator.reverter.register_undo_command = mock_register

        mock_exists.return_value = True
        # Stub out runtime-variable discovery; presumably it would otherwise
        # shell out to the Apache binary.
        self.sni.configurator.parser.update_runtime_variables = mock.Mock()

        achall = self.achalls[0]
        self.sni.add_chall(achall)
        response = self.achalls[0].gen_response(self.auth_key)
        mock_setup_cert = mock.MagicMock(return_value=response)
        # pylint: disable=protected-access
        self.sni._setup_challenge_cert = mock_setup_cert

        responses = self.sni.perform()

        # register_undo_command must receive True as its first positional
        # argument (presumably the reverter's "temporary" flag -- confirm
        # against Reverter.register_undo_command).
        self.assertEqual(True, mock_register.call_args[0][0])

        mock_setup_cert.assert_called_once_with(achall)

        # The challenge config must be Included in the Apache config.
        self._assert_challenge_conf_included_once()
        self.assertEqual(len(responses), 1)
        self.assertEqual(responses[0], response)

    def test_perform2(self):
        """perform() with two challenges returns both responses in order."""
        # Mark ssl_module as already present so perform() does not try to
        # load it (the original note: "Avoid load module").
        self.sni.configurator.parser.modules.add("ssl_module")

        acme_responses = []
        for achall in self.achalls:
            self.sni.add_chall(achall)
            acme_responses.append(achall.gen_response(self.auth_key))

        # side_effect hands out one prepared response per call, in order.
        mock_setup_cert = mock.MagicMock(side_effect=acme_responses)
        # pylint: disable=protected-access
        self.sni._setup_challenge_cert = mock_setup_cert

        sni_responses = self.sni.perform()

        self.assertEqual(mock_setup_cert.call_count, 2)

        # Make sure calls made to mocked function were correct
        self.assertEqual(
            mock_setup_cert.call_args_list[0], mock.call(self.achalls[0]))
        self.assertEqual(
            mock_setup_cert.call_args_list[1], mock.call(self.achalls[1]))

        self._assert_challenge_conf_included_once()
        self.assertEqual(len(sni_responses), 2)
        # zip (not the Python-2-only xrange) keeps this portable; lengths
        # were asserted above, so no pair is silently dropped.
        for sni_response, acme_response in zip(sni_responses, acme_responses):
            self.assertEqual(sni_response, acme_response)

    def test_mod_config(self):
        """_mod_config() writes one *:443 vhost per challenge z_domain."""
        z_domains = []
        for achall in self.achalls:
            self.sni.add_chall(achall)
            z_domain = achall.gen_response(self.auth_key).z_domain
            z_domains.append(set([z_domain]))

        self.sni._mod_config()  # pylint: disable=protected-access
        self.sni.configurator.save()

        # Previously this lookup's result was discarded, so nothing was
        # checked; assert the Include that perform() tests also rely on.
        self._assert_challenge_conf_included_once()
        vh_match = self.sni.configurator.aug.match(
            "/files" + self.sni.challenge_conf + "//VirtualHost")

        vhs = []
        for match in vh_match:
            # pylint: disable=protected-access
            vhs.append(self.sni.configurator._create_vhost(match))
        self.assertEqual(len(vhs), 2)

        expected_addrs = set([obj.Addr.fromstring("*:443")])
        for vhost in vhs:
            self.assertEqual(vhost.addrs, expected_addrs)
            names = vhost.get_names()
            self.assertTrue(names in z_domains)

    def test_get_dvsni_addrs_default(self):
        """A _default_:443 vhost maps to the *:443 DVSNI address."""
        self.sni.configurator.choose_vhost = mock.Mock(
            return_value=obj.VirtualHost(
                "path", "aug_path", set([obj.Addr.fromstring("_default_:443")]),
                False, False)
        )

        self.assertEqual(
            set([obj.Addr.fromstring("*:443")]),
            self.sni.get_dvsni_addrs(self.achalls[0]))


if __name__ == "__main__":
    unittest.main()  # pragma: no cover