client_test.py
"""Tests for acme.client."""
import datetime
import json
import unittest

from six.moves import http_client  # pylint: disable=import-error

import mock
import requests

from acme import challenges
from acme import errors
from acme import jose
from acme import jws as acme_jws
from acme import messages
from acme import messages_test
from acme import test_util


CERT_DER = test_util.load_vector('cert.der')
KEY = jose.JWKRSA.load(test_util.load_vector('rsa512_key.pem'))
KEY2 = jose.JWKRSA.load(test_util.load_vector('rsa256_key.pem'))


class ClientTest(unittest.TestCase):
    """Tests for acme.client.Client."""
    # pylint: disable=too-many-instance-attributes,too-many-public-methods

    def setUp(self):
        """Create a Client on top of a mocked network and shared fixtures.

        Both ``self.net.get`` and ``self.net.post`` return the same
        ``self.response`` mock, so each test tweaks that single object
        (status code, headers, links, JSON body) before calling the client.
        """
        self.response = mock.MagicMock(
            ok=True, status_code=http_client.OK, headers={}, links={})
        self.net = mock.MagicMock()
        self.net.post.return_value = self.response
        self.net.get.return_value = self.response

        self.directory = messages.Directory({
            messages.NewRegistration: 'https://www.letsencrypt-demo.org/acme/new-reg',
            messages.Revocation: 'https://www.letsencrypt-demo.org/acme/revoke-cert',
        })

        from acme.client import Client
        self.client = Client(
            directory=self.directory, key=KEY, alg=jose.RS256, net=self.net)

        self.identifier = messages.Identifier(
            typ=messages.IDENTIFIER_FQDN, value='example.com')

        # Registration
        self.contact = ('mailto:cert-admin@example.com', 'tel:+12025551212')
        reg = messages.Registration(
            contact=self.contact, key=KEY.public_key())
        self.new_reg = messages.NewRegistration(**dict(reg))
        self.regr = messages.RegistrationResource(
            body=reg, uri='https://www.letsencrypt-demo.org/acme/reg/1',
            new_authzr_uri='https://www.letsencrypt-demo.org/acme/new-reg',
            terms_of_service='https://www.letsencrypt-demo.org/tos')

        # Authorization
        authzr_uri = 'https://www.letsencrypt-demo.org/acme/authz/1'
        challb = messages.ChallengeBody(
            uri=(authzr_uri + '/1'), status=messages.STATUS_VALID,
            chall=challenges.DNS(token=jose.b64decode(
                'evaGxfADs6pSRb2LAv9IZf17Dt3juxGJ-PCt92wr-oA')))
        self.challr = messages.ChallengeResource(
            body=challb, authzr_uri=authzr_uri)
        self.authz = messages.Authorization(
            identifier=messages.Identifier(
                typ=messages.IDENTIFIER_FQDN, value='example.com'),
            challenges=(challb,), combinations=None)
        self.authzr = messages.AuthorizationResource(
            body=self.authz, uri=authzr_uri,
            new_cert_uri='https://www.letsencrypt-demo.org/acme/new-cert')

        # Request issuance
        self.certr = messages.CertificateResource(
            body=messages_test.CERT, authzrs=(self.authzr,),
            uri='https://www.letsencrypt-demo.org/acme/cert/1',
            cert_chain_uri='https://www.letsencrypt-demo.org/ca')

    def test_init_downloads_directory(self):
        """A string ``directory`` makes the constructor GET that URI."""
        uri = 'http://www.letsencrypt-demo.org/directory'
        from acme.client import Client
        self.client = Client(
            directory=uri, key=KEY, alg=jose.RS256, net=self.net)
        self.net.get.assert_called_once_with(uri)

    def test_register(self):
        """register() returns the expected resource; a changed key raises."""
        # "Instance of 'Field' has no to_json/update member" bug:
        # pylint: disable=no-member
        self.response.status_code = http_client.CREATED
        self.response.json.return_value = self.regr.body.to_json()
        self.response.headers['Location'] = self.regr.uri
        self.response.links.update({
            'next': {'url': self.regr.new_authzr_uri},
            'terms-of-service': {'url': self.regr.terms_of_service},
        })

        self.assertEqual(self.regr, self.client.register(self.new_reg))
        # TODO: test POST call arguments

        # TODO: split here and separate test
        reg_wrong_key = self.regr.body.update(key=KEY2.public_key())
        self.response.json.return_value = reg_wrong_key.to_json()
        self.assertRaises(
            errors.UnexpectedUpdate, self.client.register, self.new_reg)

    def test_register_missing_next(self):
        """register() raises ClientError when the response has no links."""
        self.response.status_code = http_client.CREATED
        self.assertRaises(
            errors.ClientError, self.client.register, self.new_reg)

    def test_update_registration(self):
        """update_registration() echoes the resource; changed body raises."""
        # "Instance of 'Field' has no to_json/update member" bug:
        # pylint: disable=no-member
        self.response.headers['Location'] = self.regr.uri
        self.response.json.return_value = self.regr.body.to_json()
        self.assertEqual(self.regr, self.client.update_registration(self.regr))
        # TODO: test POST call arguments

        # TODO: split here and separate test
        self.response.json.return_value = self.regr.body.update(
            contact=()).to_json()
        self.assertRaises(
            errors.UnexpectedUpdate, self.client.update_registration, self.regr)

    def test_query_registration(self):
        """query_registration() returns the resource built from the body."""
        self.response.json.return_value = self.regr.body.to_json()
        self.assertEqual(self.regr, self.client.query_registration(self.regr))

    def test_agree_to_tos(self):
        """agree_to_tos() sends an update whose agreement is the ToS URI."""
        self.client.update_registration = mock.Mock()
        self.client.agree_to_tos(self.regr)
        regr = self.client.update_registration.call_args[0][0]
        self.assertEqual(self.regr.terms_of_service, regr.body.agreement)

    def test_request_challenges(self):
        """request_challenges() succeeds; a changed identifier raises."""
        self.response.status_code = http_client.CREATED
        self.response.headers['Location'] = self.authzr.uri
        self.response.json.return_value = self.authz.to_json()
        self.response.links = {
            'next': {'url': self.authzr.new_cert_uri},
        }

        self.client.request_challenges(self.identifier, self.authzr.uri)
        # TODO: test POST call arguments

        # TODO: split here and separate test
        self.response.json.return_value = self.authz.update(
            identifier=self.identifier.update(value='foo')).to_json()
        self.assertRaises(
            errors.UnexpectedUpdate, self.client.request_challenges,
            self.identifier, self.authzr.uri)

    def test_request_challenges_missing_next(self):
        """request_challenges() raises ClientError without any links."""
        self.response.status_code = http_client.CREATED
        # Pass a URI string, exactly as test_request_challenges does.
        self.assertRaises(
            errors.ClientError, self.client.request_challenges,
            self.identifier, self.authzr.uri)

    def test_request_domain_challenges(self):
        """request_domain_challenges() returns request_challenges()'s result."""
        self.client.request_challenges = mock.MagicMock()
        self.assertEqual(
            self.client.request_challenges(self.identifier),
            self.client.request_domain_challenges('example.com', self.regr))

    def test_answer_challenge(self):
        """answer_challenge() succeeds; a challenge with another URI raises."""
        self.response.links['up'] = {'url': self.challr.authzr_uri}
        self.response.json.return_value = self.challr.body.to_json()

        chall_response = challenges.DNSResponse(validation=None)

        self.client.answer_challenge(self.challr.body, chall_response)

        # TODO: split here and separate test
        self.assertRaises(errors.UnexpectedUpdate, self.client.answer_challenge,
                          self.challr.body.update(uri='foo'), chall_response)

    def test_answer_challenge_missing_next(self):
        """answer_challenge() raises ClientError when no links are present."""
        self.assertRaises(
            errors.ClientError, self.client.answer_challenge,
            self.challr.body, challenges.DNSResponse(validation=None))

    def test_retry_after_date(self):
        """A Retry-After HTTP date yields the corresponding datetime."""
        self.response.headers['Retry-After'] = 'Fri, 31 Dec 1999 23:59:59 GMT'
        self.assertEqual(
            datetime.datetime(1999, 12, 31, 23, 59, 59),
            self.client.retry_after(response=self.response, default=10))

    @mock.patch('acme.client.datetime')
    def test_retry_after_invalid(self, dt_mock):
        """An unparsable Retry-After yields now + ``default`` seconds."""
        dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27)
        dt_mock.timedelta = datetime.timedelta

        self.response.headers['Retry-After'] = 'foooo'
        self.assertEqual(
            datetime.datetime(2015, 3, 27, 0, 0, 10),
            self.client.retry_after(response=self.response, default=10))

    @mock.patch('acme.client.datetime')
    def test_retry_after_seconds(self, dt_mock):
        """A numeric Retry-After yields now + that many seconds."""
        dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27)
        dt_mock.timedelta = datetime.timedelta

        self.response.headers['Retry-After'] = '50'
        self.assertEqual(
            datetime.datetime(2015, 3, 27, 0, 0, 50),
            self.client.retry_after(response=self.response, default=10))

    @mock.patch('acme.client.datetime')
    def test_retry_after_missing(self, dt_mock):
        """A missing Retry-After header yields now + ``default`` seconds."""
        dt_mock.datetime.now.return_value = datetime.datetime(2015, 3, 27)
        dt_mock.timedelta = datetime.timedelta

        self.assertEqual(
            datetime.datetime(2015, 3, 27, 0, 0, 10),
            self.client.retry_after(response=self.response, default=10))

    def test_poll(self):
        """poll() returns (authzr, response); a changed identifier raises."""
        self.response.json.return_value = self.authzr.body.to_json()
        self.assertEqual((self.authzr, self.response),
                         self.client.poll(self.authzr))

        # TODO: split here and separate test
        self.response.json.return_value = self.authz.update(
            identifier=self.identifier.update(value='foo')).to_json()
        self.assertRaises(
            errors.UnexpectedUpdate, self.client.poll, self.authzr)

    def test_request_issuance(self):
        """request_issuance() returns the expected CertificateResource."""
        self.response.content = CERT_DER
        self.response.headers['Location'] = self.certr.uri
        self.response.links['up'] = {'url': self.certr.cert_chain_uri}
        self.assertEqual(self.certr, self.client.request_issuance(
            messages_test.CSR, (self.authzr,)))
        # TODO: check POST args

    def test_request_issuance_missing_up(self):
        """Without an 'up' link the resource has ``cert_chain_uri=None``."""
        self.response.content = CERT_DER
        self.response.headers['Location'] = self.certr.uri
        self.assertEqual(
            self.certr.update(cert_chain_uri=None),
            self.client.request_issuance(messages_test.CSR, (self.authzr,)))

    def test_request_issuance_missing_location(self):
        """request_issuance() raises ClientError without a Location header."""
        self.assertRaises(
            errors.ClientError, self.client.request_issuance,
            messages_test.CSR, (self.authzr,))

    @mock.patch('acme.client.datetime')
    @mock.patch('acme.client.time')
    def test_poll_and_request_issuance(self, time_mock, dt_mock):
        """Check polling order and timing against a simulated clock.

        ``time.sleep`` and ``datetime.now`` in ``acme.client`` are replaced
        by stubs that advance/read ``clock.dt``, and ``poll``,
        ``retry_after`` and ``request_issuance`` are stubbed on the client,
        so the recorded poll start times are deterministic.
        """
        # clock.dt | pylint: disable=no-member
        clock = mock.MagicMock(dt=datetime.datetime(2015, 3, 27))

        def sleep(seconds):
            """increment clock"""
            clock.dt += datetime.timedelta(seconds=seconds)
        time_mock.sleep.side_effect = sleep

        def now():
            """return current clock value"""
            return clock.dt
        dt_mock.datetime.now.side_effect = now
        dt_mock.timedelta = datetime.timedelta

        def poll(authzr):  # pylint: disable=missing-docstring
            # record poll start time based on the current clock value
            authzr.times.append(clock.dt)

            # suppose it takes 2 seconds for server to produce the
            # result, increment clock
            clock.dt += datetime.timedelta(seconds=2)

            if not authzr.retries:  # no more retries
                done = mock.MagicMock(uri=authzr.uri, times=authzr.times)
                done.body.status = messages.STATUS_VALID
                return done, []

            # response (2nd result tuple element) is reduced to only
            # Retry-After header contents represented as integer
            # seconds; authzr.retries is a list of Retry-After
            # headers, head(retries) is peeled off as a current
            # Retry-After header, and tail(retries) is persisted for
            # later poll() calls
            return (mock.MagicMock(retries=authzr.retries[1:],
                                   uri=authzr.uri + '.', times=authzr.times),
                    authzr.retries[0])
        self.client.poll = mock.MagicMock(side_effect=poll)

        mintime = 7

        def retry_after(response, default):  # pylint: disable=missing-docstring
            # check that poll_and_request_issuance correctly passes mintime
            self.assertEqual(default, mintime)
            return clock.dt + datetime.timedelta(seconds=response)
        self.client.retry_after = mock.MagicMock(side_effect=retry_after)

        def request_issuance(csr, authzrs):  # pylint: disable=missing-docstring
            return csr, authzrs
        self.client.request_issuance = mock.MagicMock(
            side_effect=request_issuance)

        csr = mock.MagicMock()
        authzrs = (
            mock.MagicMock(uri='a', times=[], retries=(8, 20, 30)),
            mock.MagicMock(uri='b', times=[], retries=(5,)),
        )

        cert, updated_authzrs = self.client.poll_and_request_issuance(
            csr, authzrs, mintime=mintime)
        self.assertTrue(cert[0] is csr)
        self.assertTrue(cert[1] is updated_authzrs)
        # the poll() stub appends one '.' to the URI per non-final poll
        self.assertEqual(updated_authzrs[0].uri, 'a...')
        self.assertEqual(updated_authzrs[1].uri, 'b.')
        self.assertEqual(updated_authzrs[0].times, [
            datetime.datetime(2015, 3, 27),
            # a is scheduled for 10, but b is polling [9..11), so it
            # will be picked up as soon as b is finished, without
            # additional sleeping
            datetime.datetime(2015, 3, 27, 0, 0, 11),
            datetime.datetime(2015, 3, 27, 0, 0, 33),
            datetime.datetime(2015, 3, 27, 0, 1, 5),
        ])
        self.assertEqual(updated_authzrs[1].times, [
            datetime.datetime(2015, 3, 27, 0, 0, 2),
            datetime.datetime(2015, 3, 27, 0, 0, 9),
        ])
        self.assertEqual(clock.dt, datetime.datetime(2015, 3, 27, 0, 1, 7))

    def test_check_cert(self):
        """check_cert() returns the resource; a changed Location raises."""
        self.response.headers['Location'] = self.certr.uri
        self.response.content = CERT_DER
        self.assertEqual(self.certr.update(body=messages_test.CERT),
                         self.client.check_cert(self.certr))

        # TODO: split here and separate test
        self.response.headers['Location'] = 'foo'
        self.assertRaises(
            errors.UnexpectedUpdate, self.client.check_cert, self.certr)

    def test_check_cert_missing_location(self):
        """check_cert() raises ClientError without a Location header."""
        self.response.content = CERT_DER
        self.assertRaises(
            errors.ClientError, self.client.check_cert, self.certr)

    def test_refresh(self):
        """refresh() returns whatever check_cert() returns."""
        self.client.check_cert = mock.MagicMock()
        self.assertEqual(
            self.client.check_cert(self.certr), self.client.refresh(self.certr))

    def test_fetch_chain_no_up_link(self):
        """fetch_chain() returns [] when ``cert_chain_uri`` is None."""
        self.assertEqual([], self.client.fetch_chain(self.certr.update(
            cert_chain_uri=None)))

    def test_fetch_chain_single(self):
        """A chain response without an 'up' link yields one certificate."""
        # pylint: disable=protected-access
        self.client._get_cert = mock.MagicMock()
        self.client._get_cert.return_value = (
            mock.MagicMock(links={}), "certificate")
        self.assertEqual([self.client._get_cert(self.certr.cert_chain_uri)[1]],
                         self.client.fetch_chain(self.certr))

    def test_fetch_chain_max(self):
        """A chain of exactly ``max_length`` certificates is returned."""
        # pylint: disable=protected-access
        up_response = mock.MagicMock(links={'up': {'url': 'http://cert'}})
        noup_response = mock.MagicMock(links={})
        self.client._get_cert = mock.MagicMock()
        self.client._get_cert.side_effect = [
            (up_response, "cert")] * 9 + [(noup_response, "last_cert")]
        chain = self.client.fetch_chain(self.certr, max_length=10)
        self.assertEqual(chain, ["cert"] * 9 + ["last_cert"])

    def test_fetch_chain_too_many(self):  # recursive
        """fetch_chain() raises Error when every response has an 'up' link."""
        # pylint: disable=protected-access
        response = mock.MagicMock(links={'up': {'url': 'http://cert'}})
        self.client._get_cert = mock.MagicMock()
        self.client._get_cert.return_value = (response, "certificate")
        self.assertRaises(errors.Error, self.client.fetch_chain, self.certr)

    def test_revoke(self):
        """revoke() POSTs once to the directory's revocation URI."""
        self.client.revoke(self.certr.body)
        self.net.post.assert_called_once_with(
            self.directory[messages.Revocation], mock.ANY, content_type=None)

    def test_revoke_bad_status_raises_error(self):
        """revoke() raises ClientError on a METHOD_NOT_ALLOWED response."""
        self.response.status_code = http_client.METHOD_NOT_ALLOWED
        # Pass the certificate body, exactly as test_revoke does.
        self.assertRaises(
            errors.ClientError, self.client.revoke, self.certr.body)


class ClientNetworkTest(unittest.TestCase):
    """Tests for acme.client.ClientNetwork."""

    def setUp(self):
        """Create a ClientNetwork and a reusable OK response mock."""
        self.verify_ssl = mock.MagicMock()
        self.wrap_in_jws = mock.MagicMock(return_value=mock.sentinel.wrapped)

        from acme.client import ClientNetwork
        self.net = ClientNetwork(
            key=KEY, alg=jose.RS256, verify_ssl=self.verify_ssl)

        self.response = mock.MagicMock(ok=True, status_code=http_client.OK)
        self.response.headers = {}
        self.response.links = {}

    def test_init(self):
        """The constructor stores ``verify_ssl`` unchanged."""
        self.assertTrue(self.net.verify_ssl is self.verify_ssl)

    def test_wrap_in_jws(self):
        """_wrap_in_jws() output carries the JSON payload and the nonce."""
        class MockJSONDeSerializable(jose.JSONDeSerializable):
            # pylint: disable=missing-docstring
            def __init__(self, value):
                self.value = value

            def to_partial_json(self):
                return {'foo': self.value}

            @classmethod
            def from_json(cls, value):
                pass  # pragma: no cover

        # pylint: disable=protected-access
        jws_dump = self.net._wrap_in_jws(
            MockJSONDeSerializable('foo'), nonce=b'Tg')
        jws = acme_jws.JWS.json_loads(jws_dump)
        self.assertEqual(json.loads(jws.payload.decode()), {'foo': 'foo'})
        self.assertEqual(jws.signature.combined.nonce, b'Tg')

    def test_check_response_not_ok_jobj_no_error(self):
        """A not-OK response with an empty JSON object raises ClientError."""
        self.response.ok = False
        self.response.json.return_value = {}
        # pylint: disable=protected-access
        self.assertRaises(
            errors.ClientError, self.net._check_response, self.response)

    def test_check_response_not_ok_jobj_error(self):
        """A not-OK response with a JSON error body raises messages.Error."""
        self.response.ok = False
        self.response.json.return_value = messages.Error(
            detail='foo', typ='serverInternal', title='some title').to_json()
        # pylint: disable=protected-access
        self.assertRaises(
            messages.Error, self.net._check_response, self.response)

    def test_check_response_not_ok_no_jobj(self):
        """A not-OK response whose body is not JSON raises ClientError."""
        self.response.ok = False
        self.response.json.side_effect = ValueError
        # pylint: disable=protected-access
        self.assertRaises(
            errors.ClientError, self.net._check_response, self.response)

    def test_check_response_ok_no_jobj_ct_required(self):
        """An OK non-JSON response raises when JSON content is required."""
        self.response.json.side_effect = ValueError
        for response_ct in [self.net.JSON_CONTENT_TYPE, 'foo']:
            self.response.headers['Content-Type'] = response_ct
            # pylint: disable=protected-access
            self.assertRaises(
                errors.ClientError, self.net._check_response, self.response,
                content_type=self.net.JSON_CONTENT_TYPE)

    def test_check_response_ok_no_jobj_no_ct(self):
        """An OK non-JSON response is returned with the default arguments."""
        self.response.json.side_effect = ValueError
        for response_ct in [self.net.JSON_CONTENT_TYPE, 'foo']:
            self.response.headers['Content-Type'] = response_ct
            # pylint: disable=protected-access,no-value-for-parameter
            self.assertEqual(
                self.response, self.net._check_response(self.response))

    def test_check_response_jobj(self):
        """An OK response with a JSON body is returned unchanged."""
        self.response.json.return_value = {}
        for response_ct in [self.net.JSON_CONTENT_TYPE, 'foo']:
            self.response.headers['Content-Type'] = response_ct
            # pylint: disable=protected-access,no-value-for-parameter
            self.assertEqual(
                self.response, self.net._check_response(self.response))

    @mock.patch('acme.client.requests')
    def test_send_request(self, mock_requests):
        """_send_request() forwards its arguments to requests.request()."""
        mock_requests.request.return_value = self.response
        # pylint: disable=protected-access
        self.assertEqual(self.response, self.net._send_request(
            'HEAD', 'url', 'foo', bar='baz'))
        mock_requests.request.assert_called_once_with(
            'HEAD', 'url', 'foo', verify=mock.ANY, bar='baz')

    @mock.patch('acme.client.requests')
    def test_send_request_verify_ssl(self, mock_requests):
        """_send_request() passes ``verify_ssl`` as the ``verify`` kwarg."""
        # pylint: disable=protected-access
        for verify in True, False:
            mock_requests.request.reset_mock()
            mock_requests.request.return_value = self.response
            self.net.verify_ssl = verify
            self.assertEqual(
                self.response, self.net._send_request('GET', 'url'))
            mock_requests.request.assert_called_once_with(
                'GET', 'url', verify=verify)

    @mock.patch('acme.client.requests')
    def test_requests_error_passthrough(self, mock_requests):
        """RequestException from requests propagates out of _send_request()."""
        mock_requests.exceptions = requests.exceptions
        mock_requests.request.side_effect = requests.exceptions.RequestException
        # pylint: disable=protected-access
        self.assertRaises(requests.exceptions.RequestException,
                          self.net._send_request, 'GET', 'uri')


class ClientNetworkWithMockedResponseTest(unittest.TestCase):
    """Tests for acme.client.ClientNetwork which mock out response."""
    # pylint: disable=too-many-instance-attributes

    def setUp(self):
        """Replace the network's private helpers with controllable stubs.

        ``_send_request`` becomes a stub that attaches the next entry of
        ``self.available_nonces`` (taken with ``list.pop()``, i.e. from
        the END of the list) as the replay-nonce header of the shared
        response; once the list is empty the response has no headers.
        """
        from acme.client import ClientNetwork
        self.net = ClientNetwork(key=None, alg=None)

        self.response = mock.MagicMock(ok=True, status_code=http_client.OK)
        self.response.headers = {}
        self.response.links = {}
        self.checked_response = mock.MagicMock()
        self.obj = mock.MagicMock()
        self.wrapped_obj = mock.MagicMock()
        self.content_type = mock.sentinel.content_type

        self.all_nonces = [jose.b64encode(b'Nonce'), jose.b64encode(b'Nonce2')]
        self.available_nonces = self.all_nonces[:]

        def send_request(*args, **kwargs):
            # pylint: disable=unused-argument,missing-docstring
            if self.available_nonces:
                self.response.headers = {
                    self.net.REPLAY_NONCE_HEADER:
                    self.available_nonces.pop().decode()}
            else:
                self.response.headers = {}
            return self.response

        # pylint: disable=protected-access
        self.net._send_request = self.send_request = mock.MagicMock(
            side_effect=send_request)
        self.net._check_response = self.check_response
        self.net._wrap_in_jws = mock.MagicMock(return_value=self.wrapped_obj)

    def check_response(self, response, content_type):
        """Stand-in for ``_check_response`` that validates its arguments.

        Asserts it received the shared response and the sentinel content
        type, then returns ``self.checked_response``.
        """
        # pylint: disable=missing-docstring
        self.assertEqual(self.response, response)
        self.assertEqual(self.content_type, content_type)
        return self.checked_response

    def test_head(self):
        """head() sends a HEAD request and returns the raw response."""
        self.assertEqual(self.response, self.net.head('url', 'foo', bar='baz'))
        self.send_request.assert_called_once_with(
            'HEAD', 'url', 'foo', bar='baz')

    def test_get(self):
        """get() sends a GET request and returns the checked response."""
        self.assertEqual(self.checked_response, self.net.get(
            'url', content_type=self.content_type, bar='baz'))
        self.send_request.assert_called_once_with('GET', 'url', bar='baz')

    def test_post(self):
        """post() consumes nonces in order and raises when none remain."""
        # pylint: disable=protected-access
        self.assertEqual(self.checked_response, self.net.post(
            'uri', self.obj, content_type=self.content_type))
        self.net._wrap_in_jws.assert_called_once_with(
            self.obj, jose.b64decode(self.all_nonces.pop()))

        # unittest assertion instead of a bare ``assert`` (stripped by -O)
        self.assertFalse(self.available_nonces)
        self.assertRaises(errors.MissingNonce, self.net.post,
                          'uri', self.obj, content_type=self.content_type)
        self.net._wrap_in_jws.assert_called_with(
            self.obj, jose.b64decode(self.all_nonces.pop()))

    def test_post_wrong_initial_nonce(self):  # HEAD
        """post() raises BadNonce when the first nonce served is invalid."""
        # Nonces are popped from the end, so the LAST element is served to
        # the first request post() makes (presumably the HEAD nonce fetch).
        self.available_nonces = [jose.b64encode(b'good'), b'f']
        self.assertRaises(errors.BadNonce, self.net.post, 'uri',
                          self.obj, content_type=self.content_type)

    def test_post_wrong_post_response_nonce(self):
        """post() raises BadNonce when the second nonce served is invalid."""
        # Popped from the end: the good nonce is served first, the bad one
        # on the following request (presumably the POST response).
        self.available_nonces = [b'f', jose.b64encode(b'good')]
        self.assertRaises(errors.BadNonce, self.net.post, 'uri',
                          self.obj, content_type=self.content_type)

    def test_head_get_post_error_passthrough(self):
        """RequestException from the transport propagates unchanged."""
        self.send_request.side_effect = requests.exceptions.RequestException
        for method in self.net.head, self.net.get:
            # head()/get() take the URL as their first argument
            self.assertRaises(
                requests.exceptions.RequestException, method, 'uri')
        self.assertRaises(requests.exceptions.RequestException,
                          self.net.post, 'uri', obj=self.obj)


if __name__ == '__main__':
    unittest.main()  # pragma: no cover