1 #!/usr/bin/python2.4 2 # 3 # Copyright 2014 Google Inc. All rights reserved. 4 # 5 # Licensed under the Apache License, Version 2.0 (the "License"); 6 # you may not use this file except in compliance with the License. 7 # You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 17 18 """Oauth2client tests 19 20 Unit tests for oauth2client. 21 """ 22 23 __author__ = 'jcgregorio (at] google.com (Joe Gregorio)' 24 25 import base64 26 import contextlib 27 import datetime 28 import json 29 import os 30 import sys 31 import time 32 import unittest 33 34 import mock 35 import six 36 from six.moves import urllib 37 38 from .http_mock import HttpMock 39 from .http_mock import HttpMockSequence 40 from oauth2client import GOOGLE_REVOKE_URI 41 from oauth2client import GOOGLE_TOKEN_URI 42 from oauth2client import client 43 from oauth2client.client import AccessTokenCredentials 44 from oauth2client.client import AccessTokenCredentialsError 45 from oauth2client.client import AccessTokenRefreshError 46 from oauth2client.client import ADC_HELP_MSG 47 from oauth2client.client import AssertionCredentials 48 from oauth2client.client import AUTHORIZED_USER 49 from oauth2client.client import Credentials 50 from oauth2client.client import DEFAULT_ENV_NAME 51 from oauth2client.client import ApplicationDefaultCredentialsError 52 from oauth2client.client import FlowExchangeError 53 from oauth2client.client import GoogleCredentials 54 from oauth2client.client import GOOGLE_APPLICATION_CREDENTIALS 55 from oauth2client.client import MemoryCache 56 from oauth2client.client import NonAsciiHeaderError 57 from oauth2client.client 
import OAuth2Credentials 58 from oauth2client.client import OAuth2WebServerFlow 59 from oauth2client.client import OOB_CALLBACK_URN 60 from oauth2client.client import REFRESH_STATUS_CODES 61 from oauth2client.client import SERVICE_ACCOUNT 62 from oauth2client.client import Storage 63 from oauth2client.client import TokenRevokeError 64 from oauth2client.client import VerifyJwtTokenError 65 from oauth2client.client import _extract_id_token 66 from oauth2client.client import _get_application_default_credential_from_file 67 from oauth2client.client import _get_environment 68 from oauth2client.client import _get_environment_variable_file 69 from oauth2client.client import _get_well_known_file 70 from oauth2client.client import _raise_exception_for_missing_fields 71 from oauth2client.client import _raise_exception_for_reading_json 72 from oauth2client.client import _update_query_params 73 from oauth2client.client import credentials_from_clientsecrets_and_code 74 from oauth2client.client import credentials_from_code 75 from oauth2client.client import flow_from_clientsecrets 76 from oauth2client.client import save_to_well_known_file 77 from oauth2client.clientsecrets import _loadfile 78 from oauth2client.service_account import _ServiceAccountCredentials 79 80 DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') 81 82 83 # TODO(craigcitro): This is duplicated from 84 # googleapiclient.test_discovery; consolidate these definitions. 
def assertUrisEqual(testcase, expected, actual):
    """Test that URIs are the same, up to reordering of query parameters.

    Args:
        testcase: unittest.TestCase used to issue the assertions.
        expected: str, the expected URI.
        actual: str, the URI under test.

    Scheme, netloc, path, params and fragment must match exactly. The query
    strings are parsed with parse_qs, so parameter order does not matter,
    but the order of repeated values for one parameter still does. Note that
    parse_qs drops blank values by default, so `a=` and an absent `a`
    compare equal.
    """
    expected = urllib.parse.urlparse(expected)
    actual = urllib.parse.urlparse(actual)
    testcase.assertEqual(expected.scheme, actual.scheme)
    testcase.assertEqual(expected.netloc, actual.netloc)
    testcase.assertEqual(expected.path, actual.path)
    testcase.assertEqual(expected.params, actual.params)
    testcase.assertEqual(expected.fragment, actual.fragment)
    # Compare the parsed query dicts as a whole so that a parameter present
    # on only one side fails the assertion (with a diff) rather than raising
    # KeyError from a direct lookup in the other dict.
    expected_query = urllib.parse.parse_qs(expected.query)
    actual_query = urllib.parse.parse_qs(actual.query)
    testcase.assertEqual(expected_query, actual_query)


def datafile(filename):
    """Return the path of `filename` inside the test data directory."""
    return os.path.join(DATA_DIR, filename)


def load_and_cache(existing_file, fakename, cache_mock):
    """Load a client secrets data file and store it in `cache_mock`.

    Args:
        existing_file: str, file name relative to the test data directory.
        fakename: str, the key under which the parsed result is cached.
        cache_mock: CacheMock whose `cache` dict receives the entry.

    The cached value is {client_type: client_info}, built from the pair
    returned by clientsecrets._loadfile().
    """
    client_type, client_info = _loadfile(datafile(existing_file))
    cache_mock.cache[fakename] = {client_type: client_info}


class CacheMock(object):
    """In-memory stand-in for a cache exposing get()/set().

    Entries live in the public `cache` dict; the `namespace` argument is
    accepted but ignored.
    """

    def __init__(self):
        self.cache = {}

    def get(self, key, namespace=''):
        # ignoring namespace for easier testing
        return self.cache.get(key, None)

    def set(self, key, value, namespace=''):
        # ignoring namespace for easier testing
        self.cache[key] = value


class CredentialsTests(unittest.TestCase):

    def test_to_from_json(self):
        """A base Credentials survives a to_json()/new_from_json() round trip."""
        credentials = Credentials()
        # Named `serialized` to avoid shadowing the module-level json import.
        serialized = credentials.to_json()
        restored = Credentials.new_from_json(serialized)
        self.assertTrue(isinstance(restored, Credentials))


class MockResponse(object):
    """Mock of the response returned by six.moves urllib.request.urlopen().

    Only info() is provided; it returns an object whose get(key, default)
    looks up `key` in the headers dict passed to the constructor.
    """

    def __init__(self, headers):
        self._headers = headers

    def info(self):
        class Info(object):
            def __init__(self, headers):
                self.headers = headers

            def get(self, key, default=None):
                return self.headers.get(key, default)

        return Info(self._headers)
147 148 149 @contextlib.contextmanager 150 def mock_module_import(module): 151 """Place a dummy objects in sys.modules to mock an import test.""" 152 parts = module.split('.') 153 entries = ['.'.join(parts[:i+1]) for i in range(len(parts))] 154 for entry in entries: 155 sys.modules[entry] = object() 156 157 try: 158 yield 159 160 finally: 161 for entry in entries: 162 del sys.modules[entry] 163 164 165 class GoogleCredentialsTests(unittest.TestCase): 166 167 def setUp(self): 168 self.env_server_software = os.environ.get('SERVER_SOFTWARE', None) 169 self.env_google_application_credentials = ( 170 os.environ.get(GOOGLE_APPLICATION_CREDENTIALS, None)) 171 self.env_appdata = os.environ.get('APPDATA', None) 172 self.os_name = os.name 173 from oauth2client import client 174 client.SETTINGS.env_name = None 175 176 def tearDown(self): 177 self.reset_env('SERVER_SOFTWARE', self.env_server_software) 178 self.reset_env(GOOGLE_APPLICATION_CREDENTIALS, 179 self.env_google_application_credentials) 180 self.reset_env('APPDATA', self.env_appdata) 181 os.name = self.os_name 182 183 def reset_env(self, env, value): 184 """Set the environment variable 'env' to 'value'.""" 185 if value is not None: 186 os.environ[env] = value 187 else: 188 os.environ.pop(env, '') 189 190 def validate_service_account_credentials(self, credentials): 191 self.assertTrue(isinstance(credentials, _ServiceAccountCredentials)) 192 self.assertEqual('123', credentials._service_account_id) 193 self.assertEqual('dummy (at] google.com', credentials._service_account_email) 194 self.assertEqual('ABCDEF', credentials._private_key_id) 195 self.assertEqual('', credentials._scopes) 196 197 def validate_google_credentials(self, credentials): 198 self.assertTrue(isinstance(credentials, GoogleCredentials)) 199 self.assertEqual(None, credentials.access_token) 200 self.assertEqual('123', credentials.client_id) 201 self.assertEqual('secret', credentials.client_secret) 202 self.assertEqual('alabalaportocala', 
credentials.refresh_token) 203 self.assertEqual(None, credentials.token_expiry) 204 self.assertEqual(GOOGLE_TOKEN_URI, credentials.token_uri) 205 self.assertEqual('Python client library', credentials.user_agent) 206 207 def get_a_google_credentials_object(self): 208 return GoogleCredentials(None, None, None, None, None, None, None, None) 209 210 def test_create_scoped_required(self): 211 self.assertFalse( 212 self.get_a_google_credentials_object().create_scoped_required()) 213 214 def test_create_scoped(self): 215 credentials = self.get_a_google_credentials_object() 216 self.assertEqual(credentials, credentials.create_scoped(None)) 217 self.assertEqual(credentials, 218 credentials.create_scoped(['dummy_scope'])) 219 220 def test_get_environment_gae_production(self): 221 with mock_module_import('google.appengine'): 222 os.environ['SERVER_SOFTWARE'] = 'Google App Engine/XYZ' 223 self.assertEqual('GAE_PRODUCTION', _get_environment()) 224 225 def test_get_environment_gae_local(self): 226 with mock_module_import('google.appengine'): 227 os.environ['SERVER_SOFTWARE'] = 'Development/XYZ' 228 self.assertEqual('GAE_LOCAL', _get_environment()) 229 230 def test_get_environment_gce_production(self): 231 os.environ['SERVER_SOFTWARE'] = '' 232 response = MockResponse({'Metadata-Flavor': 'Google'}) 233 with mock.patch.object(urllib.request, 'urlopen', 234 return_value=response, 235 autospec=True) as urlopen: 236 self.assertEqual('GCE_PRODUCTION', _get_environment()) 237 urlopen.assert_called_once_with( 238 'http://169.254.169.254/', timeout=1) 239 240 def test_get_environment_unknown(self): 241 os.environ['SERVER_SOFTWARE'] = '' 242 with mock.patch.object(urllib.request, 'urlopen', 243 return_value=MockResponse({}), 244 autospec=True) as urlopen: 245 self.assertEqual(DEFAULT_ENV_NAME, _get_environment()) 246 urlopen.assert_called_once_with( 247 'http://169.254.169.254/', timeout=1) 248 249 def test_get_environment_variable_file(self): 250 environment_variable_file = datafile( 251 
os.path.join('gcloud', 'application_default_credentials.json')) 252 os.environ[GOOGLE_APPLICATION_CREDENTIALS] = environment_variable_file 253 self.assertEqual(environment_variable_file, 254 _get_environment_variable_file()) 255 256 def test_get_environment_variable_file_error(self): 257 nonexistent_file = datafile('nonexistent') 258 os.environ[GOOGLE_APPLICATION_CREDENTIALS] = nonexistent_file 259 # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+ 260 try: 261 _get_environment_variable_file() 262 self.fail(nonexistent_file + ' should not exist.') 263 except ApplicationDefaultCredentialsError as error: 264 self.assertEqual('File ' + nonexistent_file + 265 ' (pointed by ' + GOOGLE_APPLICATION_CREDENTIALS + 266 ' environment variable) does not exist!', 267 str(error)) 268 269 def test_get_well_known_file_on_windows(self): 270 ORIGINAL_ISDIR = os.path.isdir 271 try: 272 os.path.isdir = lambda path: True 273 well_known_file = datafile( 274 os.path.join(client._CLOUDSDK_CONFIG_DIRECTORY, 275 'application_default_credentials.json')) 276 os.name = 'nt' 277 os.environ['APPDATA'] = DATA_DIR 278 self.assertEqual(well_known_file, _get_well_known_file()) 279 finally: 280 os.path.isdir = ORIGINAL_ISDIR 281 282 def test_get_well_known_file_with_custom_config_dir(self): 283 ORIGINAL_ENVIRON = os.environ 284 ORIGINAL_ISDIR = os.path.isdir 285 CUSTOM_DIR = 'CUSTOM_DIR' 286 EXPECTED_FILE = os.path.join(CUSTOM_DIR, 287 'application_default_credentials.json') 288 try: 289 os.environ = {client._CLOUDSDK_CONFIG_ENV_VAR: CUSTOM_DIR} 290 os.path.isdir = lambda path: True 291 well_known_file = _get_well_known_file() 292 self.assertEqual(well_known_file, EXPECTED_FILE) 293 finally: 294 os.environ = ORIGINAL_ENVIRON 295 os.path.isdir = ORIGINAL_ISDIR 296 297 def test_get_application_default_credential_from_file_service_account(self): 298 credentials_file = datafile( 299 os.path.join('gcloud', 'application_default_credentials.json')) 300 credentials = 
_get_application_default_credential_from_file( 301 credentials_file) 302 self.validate_service_account_credentials(credentials) 303 304 def test_save_to_well_known_file_service_account(self): 305 credential_file = datafile( 306 os.path.join('gcloud', 'application_default_credentials.json')) 307 credentials = _get_application_default_credential_from_file( 308 credential_file) 309 temp_credential_file = datafile( 310 os.path.join('gcloud', 'temp_well_known_file_service_account.json')) 311 save_to_well_known_file(credentials, temp_credential_file) 312 with open(temp_credential_file) as f: 313 d = json.load(f) 314 self.assertEqual('service_account', d['type']) 315 self.assertEqual('123', d['client_id']) 316 self.assertEqual('dummy (at] google.com', d['client_email']) 317 self.assertEqual('ABCDEF', d['private_key_id']) 318 os.remove(temp_credential_file) 319 320 def test_save_well_known_file_with_non_existent_config_dir(self): 321 credential_file = datafile( 322 os.path.join('gcloud', 'application_default_credentials.json')) 323 credentials = _get_application_default_credential_from_file( 324 credential_file) 325 ORIGINAL_ISDIR = os.path.isdir 326 try: 327 os.path.isdir = lambda path: False 328 self.assertRaises(OSError, save_to_well_known_file, credentials) 329 finally: 330 os.path.isdir = ORIGINAL_ISDIR 331 332 def test_get_application_default_credential_from_file_authorized_user(self): 333 credentials_file = datafile( 334 os.path.join('gcloud', 335 'application_default_credentials_authorized_user.json')) 336 credentials = _get_application_default_credential_from_file( 337 credentials_file) 338 self.validate_google_credentials(credentials) 339 340 def test_save_to_well_known_file_authorized_user(self): 341 credentials_file = datafile( 342 os.path.join('gcloud', 343 'application_default_credentials_authorized_user.json')) 344 credentials = _get_application_default_credential_from_file( 345 credentials_file) 346 temp_credential_file = datafile( 347 
os.path.join('gcloud', 'temp_well_known_file_authorized_user.json')) 348 save_to_well_known_file(credentials, temp_credential_file) 349 with open(temp_credential_file) as f: 350 d = json.load(f) 351 self.assertEqual('authorized_user', d['type']) 352 self.assertEqual('123', d['client_id']) 353 self.assertEqual('secret', d['client_secret']) 354 self.assertEqual('alabalaportocala', d['refresh_token']) 355 os.remove(temp_credential_file) 356 357 def test_get_application_default_credential_from_malformed_file_1(self): 358 credentials_file = datafile( 359 os.path.join('gcloud', 360 'application_default_credentials_malformed_1.json')) 361 # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+ 362 try: 363 _get_application_default_credential_from_file(credentials_file) 364 self.fail('An exception was expected!') 365 except ApplicationDefaultCredentialsError as error: 366 self.assertEqual("'type' field should be defined " 367 "(and have one of the '" + AUTHORIZED_USER + 368 "' or '" + SERVICE_ACCOUNT + "' values)", 369 str(error)) 370 371 def test_get_application_default_credential_from_malformed_file_2(self): 372 credentials_file = datafile( 373 os.path.join('gcloud', 374 'application_default_credentials_malformed_2.json')) 375 # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+ 376 try: 377 _get_application_default_credential_from_file(credentials_file) 378 self.fail('An exception was expected!') 379 except ApplicationDefaultCredentialsError as error: 380 self.assertEqual('The following field(s) must be defined: private_key_id', 381 str(error)) 382 383 def test_get_application_default_credential_from_malformed_file_3(self): 384 credentials_file = datafile( 385 os.path.join('gcloud', 386 'application_default_credentials_malformed_3.json')) 387 self.assertRaises(ValueError, _get_application_default_credential_from_file, 388 credentials_file) 389 390 def test_raise_exception_for_missing_fields(self): 391 missing_fields = ['first', 
'second', 'third'] 392 # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+ 393 try: 394 _raise_exception_for_missing_fields(missing_fields) 395 self.fail('An exception was expected!') 396 except ApplicationDefaultCredentialsError as error: 397 self.assertEqual('The following field(s) must be defined: ' + 398 ', '.join(missing_fields), 399 str(error)) 400 401 def test_raise_exception_for_reading_json(self): 402 credential_file = 'any_file' 403 extra_help = ' be good' 404 error = ApplicationDefaultCredentialsError('stuff happens') 405 # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+ 406 try: 407 _raise_exception_for_reading_json(credential_file, extra_help, error) 408 self.fail('An exception was expected!') 409 except ApplicationDefaultCredentialsError as ex: 410 self.assertEqual('An error was encountered while reading ' 411 'json file: '+ credential_file + 412 extra_help + ': ' + str(error), 413 str(ex)) 414 415 def test_get_application_default_from_environment_variable_service_account( 416 self): 417 os.environ['SERVER_SOFTWARE'] = '' 418 environment_variable_file = datafile( 419 os.path.join('gcloud', 'application_default_credentials.json')) 420 os.environ[GOOGLE_APPLICATION_CREDENTIALS] = environment_variable_file 421 self.validate_service_account_credentials( 422 GoogleCredentials.get_application_default()) 423 424 def test_env_name(self): 425 from oauth2client import client 426 self.assertEqual(None, client.SETTINGS.env_name) 427 self.test_get_application_default_from_environment_variable_service_account() 428 self.assertEqual(DEFAULT_ENV_NAME, client.SETTINGS.env_name) 429 430 def test_get_application_default_from_environment_variable_authorized_user( 431 self): 432 os.environ['SERVER_SOFTWARE'] = '' 433 environment_variable_file = datafile( 434 os.path.join('gcloud', 435 'application_default_credentials_authorized_user.json')) 436 os.environ[GOOGLE_APPLICATION_CREDENTIALS] = environment_variable_file 437 
self.validate_google_credentials( 438 GoogleCredentials.get_application_default()) 439 440 def test_get_application_default_from_environment_variable_malformed_file( 441 self): 442 os.environ['SERVER_SOFTWARE'] = '' 443 environment_variable_file = datafile( 444 os.path.join('gcloud', 445 'application_default_credentials_malformed_3.json')) 446 os.environ[GOOGLE_APPLICATION_CREDENTIALS] = environment_variable_file 447 # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+ 448 try: 449 GoogleCredentials.get_application_default() 450 self.fail('An exception was expected!') 451 except ApplicationDefaultCredentialsError as error: 452 self.assertTrue(str(error).startswith( 453 'An error was encountered while reading json file: ' + 454 environment_variable_file + ' (pointed to by ' + 455 GOOGLE_APPLICATION_CREDENTIALS + ' environment variable):')) 456 457 def test_get_application_default_environment_not_set_up(self): 458 # It is normal for this test to fail if run inside 459 # a Google Compute Engine VM or after 'gcloud auth login' command 460 # has been executed on a non Windows machine. 
461 os.environ['SERVER_SOFTWARE'] = '' 462 os.environ[GOOGLE_APPLICATION_CREDENTIALS] = '' 463 os.environ['APPDATA'] = '' 464 # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+ 465 VALID_CONFIG_DIR = client._CLOUDSDK_CONFIG_DIRECTORY 466 ORIGINAL_ISDIR = os.path.isdir 467 try: 468 os.path.isdir = lambda path: True 469 client._CLOUDSDK_CONFIG_DIRECTORY = 'BOGUS_CONFIG_DIR' 470 GoogleCredentials.get_application_default() 471 self.fail('An exception was expected!') 472 except ApplicationDefaultCredentialsError as error: 473 self.assertEqual(ADC_HELP_MSG, str(error)) 474 finally: 475 os.path.isdir = ORIGINAL_ISDIR 476 client._CLOUDSDK_CONFIG_DIRECTORY = VALID_CONFIG_DIR 477 478 def test_from_stream_service_account(self): 479 credentials_file = datafile( 480 os.path.join('gcloud', 'application_default_credentials.json')) 481 credentials = ( 482 self.get_a_google_credentials_object().from_stream(credentials_file)) 483 self.validate_service_account_credentials(credentials) 484 485 def test_from_stream_authorized_user(self): 486 credentials_file = datafile( 487 os.path.join('gcloud', 488 'application_default_credentials_authorized_user.json')) 489 credentials = ( 490 self.get_a_google_credentials_object().from_stream(credentials_file)) 491 self.validate_google_credentials(credentials) 492 493 def test_from_stream_malformed_file_1(self): 494 credentials_file = datafile( 495 os.path.join('gcloud', 496 'application_default_credentials_malformed_1.json')) 497 # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+ 498 try: 499 self.get_a_google_credentials_object().from_stream(credentials_file) 500 self.fail('An exception was expected!') 501 except ApplicationDefaultCredentialsError as error: 502 self.assertEqual("An error was encountered while reading json file: " + 503 credentials_file + 504 " (provided as parameter to the from_stream() method): " 505 "'type' field should be defined (and have one of the '" + 506 AUTHORIZED_USER + "' 
or '" + SERVICE_ACCOUNT + 507 "' values)", 508 str(error)) 509 510 def test_from_stream_malformed_file_2(self): 511 credentials_file = datafile( 512 os.path.join('gcloud', 513 'application_default_credentials_malformed_2.json')) 514 # we can't use self.assertRaisesRegexp() because it is only in Python 2.7+ 515 try: 516 self.get_a_google_credentials_object().from_stream(credentials_file) 517 self.fail('An exception was expected!') 518 except ApplicationDefaultCredentialsError as error: 519 self.assertEqual('An error was encountered while reading json file: ' + 520 credentials_file + 521 ' (provided as parameter to the from_stream() method): ' 522 'The following field(s) must be defined: ' 523 'private_key_id', 524 str(error)) 525 526 def test_from_stream_malformed_file_3(self): 527 credentials_file = datafile( 528 os.path.join('gcloud', 529 'application_default_credentials_malformed_3.json')) 530 self.assertRaises( 531 ApplicationDefaultCredentialsError, 532 self.get_a_google_credentials_object().from_stream, credentials_file) 533 534 535 class DummyDeleteStorage(Storage): 536 delete_called = False 537 538 def locked_delete(self): 539 self.delete_called = True 540 541 542 def _token_revoke_test_helper(testcase, status, revoke_raise, 543 valid_bool_value, token_attr): 544 current_store = getattr(testcase.credentials, 'store', None) 545 546 dummy_store = DummyDeleteStorage() 547 testcase.credentials.set_store(dummy_store) 548 549 actual_do_revoke = testcase.credentials._do_revoke 550 testcase.token_from_revoke = None 551 def do_revoke_stub(http_request, token): 552 testcase.token_from_revoke = token 553 return actual_do_revoke(http_request, token) 554 testcase.credentials._do_revoke = do_revoke_stub 555 556 http = HttpMock(headers={'status': status}) 557 if revoke_raise: 558 testcase.assertRaises(TokenRevokeError, testcase.credentials.revoke, http) 559 else: 560 testcase.credentials.revoke(http) 561 562 testcase.assertEqual(getattr(testcase.credentials, token_attr), 
563 testcase.token_from_revoke) 564 testcase.assertEqual(valid_bool_value, testcase.credentials.invalid) 565 testcase.assertEqual(valid_bool_value, dummy_store.delete_called) 566 567 testcase.credentials.set_store(current_store) 568 569 570 class BasicCredentialsTests(unittest.TestCase): 571 572 def setUp(self): 573 access_token = 'foo' 574 client_id = 'some_client_id' 575 client_secret = 'cOuDdkfjxxnv+' 576 refresh_token = '1/0/a.df219fjls0' 577 token_expiry = datetime.datetime.utcnow() 578 user_agent = 'refresh_checker/1.0' 579 self.credentials = OAuth2Credentials( 580 access_token, client_id, client_secret, 581 refresh_token, token_expiry, GOOGLE_TOKEN_URI, 582 user_agent, revoke_uri=GOOGLE_REVOKE_URI) 583 584 def test_token_refresh_success(self): 585 for status_code in REFRESH_STATUS_CODES: 586 token_response = {'access_token': '1/3w', 'expires_in': 3600} 587 http = HttpMockSequence([ 588 ({'status': status_code}, b''), 589 ({'status': '200'}, json.dumps(token_response).encode('utf-8')), 590 ({'status': '200'}, 'echo_request_headers'), 591 ]) 592 http = self.credentials.authorize(http) 593 resp, content = http.request('http://example.com') 594 self.assertEqual(b'Bearer 1/3w', content[b'Authorization']) 595 self.assertFalse(self.credentials.access_token_expired) 596 self.assertEqual(token_response, self.credentials.token_response) 597 598 def test_token_refresh_failure(self): 599 for status_code in REFRESH_STATUS_CODES: 600 http = HttpMockSequence([ 601 ({'status': status_code}, b''), 602 ({'status': '400'}, b'{"error":"access_denied"}'), 603 ]) 604 http = self.credentials.authorize(http) 605 try: 606 http.request('http://example.com') 607 self.fail('should raise AccessTokenRefreshError exception') 608 except AccessTokenRefreshError: 609 pass 610 self.assertTrue(self.credentials.access_token_expired) 611 self.assertEqual(None, self.credentials.token_response) 612 613 def test_token_revoke_success(self): 614 _token_revoke_test_helper( 615 self, '200', 
revoke_raise=False, 616 valid_bool_value=True, token_attr='refresh_token') 617 618 def test_token_revoke_failure(self): 619 _token_revoke_test_helper( 620 self, '400', revoke_raise=True, 621 valid_bool_value=False, token_attr='refresh_token') 622 623 def test_token_revoke_fallback(self): 624 original_credentials = self.credentials.to_json() 625 self.credentials.refresh_token = None 626 _token_revoke_test_helper( 627 self, '200', revoke_raise=False, 628 valid_bool_value=True, token_attr='access_token') 629 self.credentials = self.credentials.from_json(original_credentials) 630 631 def test_non_401_error_response(self): 632 http = HttpMockSequence([ 633 ({'status': '400'}, b''), 634 ]) 635 http = self.credentials.authorize(http) 636 resp, content = http.request('http://example.com') 637 self.assertEqual(400, resp.status) 638 self.assertEqual(None, self.credentials.token_response) 639 640 def test_to_from_json(self): 641 json = self.credentials.to_json() 642 instance = OAuth2Credentials.from_json(json) 643 self.assertEqual(OAuth2Credentials, type(instance)) 644 instance.token_expiry = None 645 self.credentials.token_expiry = None 646 647 self.assertEqual(instance.__dict__, self.credentials.__dict__) 648 649 def test_from_json_token_expiry(self): 650 data = json.loads(self.credentials.to_json()) 651 data['token_expiry'] = None 652 instance = OAuth2Credentials.from_json(json.dumps(data)) 653 self.assertTrue(isinstance(instance, OAuth2Credentials)) 654 655 def test_unicode_header_checks(self): 656 access_token = u'foo' 657 client_id = u'some_client_id' 658 client_secret = u'cOuDdkfjxxnv+' 659 refresh_token = u'1/0/a.df219fjls0' 660 token_expiry = str(datetime.datetime.utcnow()) 661 token_uri = str(GOOGLE_TOKEN_URI) 662 revoke_uri = str(GOOGLE_REVOKE_URI) 663 user_agent = u'refresh_checker/1.0' 664 credentials = OAuth2Credentials(access_token, client_id, client_secret, 665 refresh_token, token_expiry, token_uri, 666 user_agent, revoke_uri=revoke_uri) 667 668 # First, test 
that we correctly encode basic objects, making sure 669 # to include a bytes object. Note that oauth2client will normalize 670 # everything to bytes, no matter what python version we're in. 671 http = credentials.authorize(HttpMock(headers={'status': '200'})) 672 headers = {u'foo': 3, b'bar': True, 'baz': b'abc'} 673 cleaned_headers = {b'foo': b'3', b'bar': b'True', b'baz': b'abc'} 674 http.request(u'http://example.com', method=u'GET', headers=headers) 675 for k, v in cleaned_headers.items(): 676 self.assertTrue(k in http.headers) 677 self.assertEqual(v, http.headers[k]) 678 679 # Next, test that we do fail on unicode. 680 unicode_str = six.unichr(40960) + 'abcd' 681 self.assertRaises( 682 NonAsciiHeaderError, 683 http.request, 684 u'http://example.com', method=u'GET', headers={u'foo': unicode_str}) 685 686 def test_no_unicode_in_request_params(self): 687 access_token = u'foo' 688 client_id = u'some_client_id' 689 client_secret = u'cOuDdkfjxxnv+' 690 refresh_token = u'1/0/a.df219fjls0' 691 token_expiry = str(datetime.datetime.utcnow()) 692 token_uri = str(GOOGLE_TOKEN_URI) 693 revoke_uri = str(GOOGLE_REVOKE_URI) 694 user_agent = u'refresh_checker/1.0' 695 credentials = OAuth2Credentials(access_token, client_id, client_secret, 696 refresh_token, token_expiry, token_uri, 697 user_agent, revoke_uri=revoke_uri) 698 699 http = HttpMock(headers={'status': '200'}) 700 http = credentials.authorize(http) 701 http.request(u'http://example.com', method=u'GET', headers={u'foo': u'bar'}) 702 for k, v in six.iteritems(http.headers): 703 self.assertEqual(six.binary_type, type(k)) 704 self.assertEqual(six.binary_type, type(v)) 705 706 # Test again with unicode strings that can't simply be converted to ASCII. 
707 try: 708 http.request( 709 u'http://example.com', method=u'GET', headers={u'foo': u'\N{COMET}'}) 710 self.fail('Expected exception to be raised.') 711 except NonAsciiHeaderError: 712 pass 713 714 self.credentials.token_response = 'foobar' 715 instance = OAuth2Credentials.from_json(self.credentials.to_json()) 716 self.assertEqual('foobar', instance.token_response) 717 718 def test_get_access_token(self): 719 S = 2 # number of seconds in which the token expires 720 token_response_first = {'access_token': 'first_token', 'expires_in': S} 721 token_response_second = {'access_token': 'second_token', 'expires_in': S} 722 http = HttpMockSequence([ 723 ({'status': '200'}, json.dumps(token_response_first).encode('utf-8')), 724 ({'status': '200'}, json.dumps(token_response_second).encode('utf-8')), 725 ]) 726 727 token = self.credentials.get_access_token(http=http) 728 self.assertEqual('first_token', token.access_token) 729 self.assertEqual(S - 1, token.expires_in) 730 self.assertFalse(self.credentials.access_token_expired) 731 self.assertEqual(token_response_first, self.credentials.token_response) 732 733 token = self.credentials.get_access_token(http=http) 734 self.assertEqual('first_token', token.access_token) 735 self.assertEqual(S - 1, token.expires_in) 736 self.assertFalse(self.credentials.access_token_expired) 737 self.assertEqual(token_response_first, self.credentials.token_response) 738 739 time.sleep(S + 0.5) # some margin to avoid flakiness 740 self.assertTrue(self.credentials.access_token_expired) 741 742 token = self.credentials.get_access_token(http=http) 743 self.assertEqual('second_token', token.access_token) 744 self.assertEqual(S - 1, token.expires_in) 745 self.assertFalse(self.credentials.access_token_expired) 746 self.assertEqual(token_response_second, self.credentials.token_response) 747 748 749 class AccessTokenCredentialsTests(unittest.TestCase): 750 751 def setUp(self): 752 access_token = 'foo' 753 user_agent = 'refresh_checker/1.0' 754 
self.credentials = AccessTokenCredentials(access_token, user_agent, 755 revoke_uri=GOOGLE_REVOKE_URI) 756 757 def test_token_refresh_success(self): 758 for status_code in REFRESH_STATUS_CODES: 759 http = HttpMockSequence([ 760 ({'status': status_code}, b''), 761 ]) 762 http = self.credentials.authorize(http) 763 try: 764 resp, content = http.request('http://example.com') 765 self.fail('should throw exception if token expires') 766 except AccessTokenCredentialsError: 767 pass 768 except Exception: 769 self.fail('should only throw AccessTokenCredentialsError') 770 771 def test_token_revoke_success(self): 772 _token_revoke_test_helper( 773 self, '200', revoke_raise=False, 774 valid_bool_value=True, token_attr='access_token') 775 776 def test_token_revoke_failure(self): 777 _token_revoke_test_helper( 778 self, '400', revoke_raise=True, 779 valid_bool_value=False, token_attr='access_token') 780 781 def test_non_401_error_response(self): 782 http = HttpMockSequence([ 783 ({'status': '400'}, b''), 784 ]) 785 http = self.credentials.authorize(http) 786 resp, content = http.request('http://example.com') 787 self.assertEqual(400, resp.status) 788 789 def test_auth_header_sent(self): 790 http = HttpMockSequence([ 791 ({'status': '200'}, 'echo_request_headers'), 792 ]) 793 http = self.credentials.authorize(http) 794 resp, content = http.request('http://example.com') 795 self.assertEqual(b'Bearer foo', content[b'Authorization']) 796 797 798 class TestAssertionCredentials(unittest.TestCase): 799 assertion_text = 'This is the assertion' 800 assertion_type = 'http://www.google.com/assertionType' 801 802 class AssertionCredentialsTestImpl(AssertionCredentials): 803 804 def _generate_assertion(self): 805 return TestAssertionCredentials.assertion_text 806 807 def setUp(self): 808 user_agent = 'fun/2.0' 809 self.credentials = self.AssertionCredentialsTestImpl(self.assertion_type, 810 user_agent=user_agent) 811 812 def test_assertion_body(self): 813 body = urllib.parse.parse_qs( 814 
self.credentials._generate_refresh_request_body()) 815 self.assertEqual(self.assertion_text, body['assertion'][0]) 816 self.assertEqual('urn:ietf:params:oauth:grant-type:jwt-bearer', 817 body['grant_type'][0]) 818 819 def test_assertion_refresh(self): 820 http = HttpMockSequence([ 821 ({'status': '200'}, b'{"access_token":"1/3w"}'), 822 ({'status': '200'}, 'echo_request_headers'), 823 ]) 824 http = self.credentials.authorize(http) 825 resp, content = http.request('http://example.com') 826 self.assertEqual(b'Bearer 1/3w', content[b'Authorization']) 827 828 def test_token_revoke_success(self): 829 _token_revoke_test_helper( 830 self, '200', revoke_raise=False, 831 valid_bool_value=True, token_attr='access_token') 832 833 def test_token_revoke_failure(self): 834 _token_revoke_test_helper( 835 self, '400', revoke_raise=True, 836 valid_bool_value=False, token_attr='access_token') 837 838 839 class UpdateQueryParamsTest(unittest.TestCase): 840 def test_update_query_params_no_params(self): 841 uri = 'http://www.google.com' 842 updated = _update_query_params(uri, {'a': 'b'}) 843 self.assertEqual(updated, uri + '?a=b') 844 845 def test_update_query_params_existing_params(self): 846 uri = 'http://www.google.com?x=y' 847 updated = _update_query_params(uri, {'a': 'b', 'c': 'd&'}) 848 hardcoded_update = uri + '&a=b&c=d%26' 849 assertUrisEqual(self, updated, hardcoded_update) 850 851 852 class ExtractIdTokenTest(unittest.TestCase): 853 """Tests _extract_id_token().""" 854 855 def test_extract_success(self): 856 body = {'foo': 'bar'} 857 body_json = json.dumps(body).encode('ascii') 858 payload = base64.urlsafe_b64encode(body_json).strip(b'=') 859 jwt = b'stuff.' + payload + b'.signature' 860 861 extracted = _extract_id_token(jwt) 862 self.assertEqual(extracted, body) 863 864 def test_extract_failure(self): 865 body = {'foo': 'bar'} 866 body_json = json.dumps(body).encode('ascii') 867 payload = base64.urlsafe_b64encode(body_json).strip(b'=') 868 jwt = b'stuff.' 
+ payload 869 870 self.assertRaises(VerifyJwtTokenError, _extract_id_token, jwt) 871 872 873 class OAuth2WebServerFlowTest(unittest.TestCase): 874 875 def setUp(self): 876 self.flow = OAuth2WebServerFlow( 877 client_id='client_id+1', 878 client_secret='secret+1', 879 scope='foo', 880 redirect_uri=OOB_CALLBACK_URN, 881 user_agent='unittest-sample/1.0', 882 revoke_uri='dummy_revoke_uri', 883 ) 884 885 def test_construct_authorize_url(self): 886 authorize_url = self.flow.step1_get_authorize_url() 887 888 parsed = urllib.parse.urlparse(authorize_url) 889 q = urllib.parse.parse_qs(parsed[4]) 890 self.assertEqual('client_id+1', q['client_id'][0]) 891 self.assertEqual('code', q['response_type'][0]) 892 self.assertEqual('foo', q['scope'][0]) 893 self.assertEqual(OOB_CALLBACK_URN, q['redirect_uri'][0]) 894 self.assertEqual('offline', q['access_type'][0]) 895 896 def test_override_flow_via_kwargs(self): 897 """Passing kwargs to override defaults.""" 898 flow = OAuth2WebServerFlow( 899 client_id='client_id+1', 900 client_secret='secret+1', 901 scope='foo', 902 redirect_uri=OOB_CALLBACK_URN, 903 user_agent='unittest-sample/1.0', 904 access_type='online', 905 response_type='token' 906 ) 907 authorize_url = flow.step1_get_authorize_url() 908 909 parsed = urllib.parse.urlparse(authorize_url) 910 q = urllib.parse.parse_qs(parsed[4]) 911 self.assertEqual('client_id+1', q['client_id'][0]) 912 self.assertEqual('token', q['response_type'][0]) 913 self.assertEqual('foo', q['scope'][0]) 914 self.assertEqual(OOB_CALLBACK_URN, q['redirect_uri'][0]) 915 self.assertEqual('online', q['access_type'][0]) 916 917 def test_exchange_failure(self): 918 http = HttpMockSequence([ 919 ({'status': '400'}, b'{"error":"invalid_request"}'), 920 ]) 921 922 try: 923 credentials = self.flow.step2_exchange('some random code', http=http) 924 self.fail('should raise exception if exchange doesn\'t get 200') 925 except FlowExchangeError: 926 pass 927 928 def test_urlencoded_exchange_failure(self): 929 http = 
HttpMockSequence([ 930 ({'status': '400'}, b'error=invalid_request'), 931 ]) 932 933 try: 934 credentials = self.flow.step2_exchange('some random code', http=http) 935 self.fail('should raise exception if exchange doesn\'t get 200') 936 except FlowExchangeError as e: 937 self.assertEqual('invalid_request', str(e)) 938 939 def test_exchange_failure_with_json_error(self): 940 # Some providers have 'error' attribute as a JSON object 941 # in place of regular string. 942 # This test makes sure no strange object-to-string coversion 943 # exceptions are being raised instead of FlowExchangeError. 944 http = HttpMockSequence([ 945 ({'status': '400'}, 946 b""" {"error": { 947 "type": "OAuthException", 948 "message": "Error validating verification code."} }"""), 949 ]) 950 951 try: 952 credentials = self.flow.step2_exchange('some random code', http=http) 953 self.fail('should raise exception if exchange doesn\'t get 200') 954 except FlowExchangeError as e: 955 pass 956 957 def test_exchange_success(self): 958 http = HttpMockSequence([ 959 ({'status': '200'}, 960 b"""{ "access_token":"SlAV32hkKG", 961 "expires_in":3600, 962 "refresh_token":"8xLOxBtZp8" }"""), 963 ]) 964 965 credentials = self.flow.step2_exchange('some random code', http=http) 966 self.assertEqual('SlAV32hkKG', credentials.access_token) 967 self.assertNotEqual(None, credentials.token_expiry) 968 self.assertEqual('8xLOxBtZp8', credentials.refresh_token) 969 self.assertEqual('dummy_revoke_uri', credentials.revoke_uri) 970 971 def test_exchange_dictlike(self): 972 class FakeDict(object): 973 def __init__(self, d): 974 self.d = d 975 976 def __getitem__(self, name): 977 return self.d[name] 978 979 def __contains__(self, name): 980 return name in self.d 981 982 code = 'some random code' 983 not_a_dict = FakeDict({'code': code}) 984 payload = (b'{' 985 b' "access_token":"SlAV32hkKG",' 986 b' "expires_in":3600,' 987 b' "refresh_token":"8xLOxBtZp8"' 988 b'}') 989 http = HttpMockSequence([({'status': '200'}, 
payload),]) 990 991 credentials = self.flow.step2_exchange(not_a_dict, http=http) 992 self.assertEqual('SlAV32hkKG', credentials.access_token) 993 self.assertNotEqual(None, credentials.token_expiry) 994 self.assertEqual('8xLOxBtZp8', credentials.refresh_token) 995 self.assertEqual('dummy_revoke_uri', credentials.revoke_uri) 996 request_code = urllib.parse.parse_qs(http.requests[0]['body'])['code'][0] 997 self.assertEqual(code, request_code) 998 999 def test_urlencoded_exchange_success(self): 1000 http = HttpMockSequence([ 1001 ({'status': '200'}, b'access_token=SlAV32hkKG&expires_in=3600'), 1002 ]) 1003 1004 credentials = self.flow.step2_exchange('some random code', http=http) 1005 self.assertEqual('SlAV32hkKG', credentials.access_token) 1006 self.assertNotEqual(None, credentials.token_expiry) 1007 1008 def test_urlencoded_expires_param(self): 1009 http = HttpMockSequence([ 1010 # Note the 'expires=3600' where you'd normally 1011 # have if named 'expires_in' 1012 ({'status': '200'}, b'access_token=SlAV32hkKG&expires=3600'), 1013 ]) 1014 1015 credentials = self.flow.step2_exchange('some random code', http=http) 1016 self.assertNotEqual(None, credentials.token_expiry) 1017 1018 def test_exchange_no_expires_in(self): 1019 http = HttpMockSequence([ 1020 ({'status': '200'}, b"""{ "access_token":"SlAV32hkKG", 1021 "refresh_token":"8xLOxBtZp8" }"""), 1022 ]) 1023 1024 credentials = self.flow.step2_exchange('some random code', http=http) 1025 self.assertEqual(None, credentials.token_expiry) 1026 1027 def test_urlencoded_exchange_no_expires_in(self): 1028 http = HttpMockSequence([ 1029 # This might be redundant but just to make sure 1030 # urlencoded access_token gets parsed correctly 1031 ({'status': '200'}, b'access_token=SlAV32hkKG'), 1032 ]) 1033 1034 credentials = self.flow.step2_exchange('some random code', http=http) 1035 self.assertEqual(None, credentials.token_expiry) 1036 1037 def test_exchange_fails_if_no_code(self): 1038 http = HttpMockSequence([ 1039 
({'status': '200'}, b"""{ "access_token":"SlAV32hkKG", 1040 "refresh_token":"8xLOxBtZp8" }"""), 1041 ]) 1042 1043 code = {'error': 'thou shall not pass'} 1044 try: 1045 credentials = self.flow.step2_exchange(code, http=http) 1046 self.fail('should raise exception if no code in dictionary.') 1047 except FlowExchangeError as e: 1048 self.assertTrue('shall not pass' in str(e)) 1049 1050 def test_exchange_id_token_fail(self): 1051 http = HttpMockSequence([ 1052 ({'status': '200'}, b"""{ "access_token":"SlAV32hkKG", 1053 "refresh_token":"8xLOxBtZp8", 1054 "id_token": "stuff.payload"}"""), 1055 ]) 1056 1057 self.assertRaises(VerifyJwtTokenError, self.flow.step2_exchange, 1058 'some random code', http=http) 1059 1060 def test_exchange_id_token(self): 1061 body = {'foo': 'bar'} 1062 body_json = json.dumps(body).encode('ascii') 1063 payload = base64.urlsafe_b64encode(body_json).strip(b'=') 1064 jwt = (base64.urlsafe_b64encode(b'stuff') + b'.' + payload + b'.' + 1065 base64.urlsafe_b64encode(b'signature')) 1066 1067 http = HttpMockSequence([ 1068 ({'status': '200'}, ("""{ "access_token":"SlAV32hkKG", 1069 "refresh_token":"8xLOxBtZp8", 1070 "id_token": "%s"}""" % jwt).encode('utf-8')), 1071 ]) 1072 1073 credentials = self.flow.step2_exchange('some random code', http=http) 1074 self.assertEqual(credentials.id_token, body) 1075 1076 1077 class FlowFromCachedClientsecrets(unittest.TestCase): 1078 1079 def test_flow_from_clientsecrets_cached(self): 1080 cache_mock = CacheMock() 1081 load_and_cache('client_secrets.json', 'some_secrets', cache_mock) 1082 1083 flow = flow_from_clientsecrets( 1084 'some_secrets', '', redirect_uri='oob', cache=cache_mock) 1085 self.assertEqual('foo_client_secret', flow.client_secret) 1086 1087 1088 class CredentialsFromCodeTests(unittest.TestCase): 1089 def setUp(self): 1090 self.client_id = 'client_id_abc' 1091 self.client_secret = 'secret_use_code' 1092 self.scope = 'foo' 1093 self.code = '12345abcde' 1094 self.redirect_uri = 'postmessage' 1095 1096 
def test_exchange_code_for_token(self): 1097 token = 'asdfghjkl' 1098 payload = json.dumps({'access_token': token, 'expires_in': 3600}) 1099 http = HttpMockSequence([ 1100 ({'status': '200'}, payload.encode('utf-8')), 1101 ]) 1102 credentials = credentials_from_code(self.client_id, self.client_secret, 1103 self.scope, self.code, redirect_uri=self.redirect_uri, 1104 http=http) 1105 self.assertEqual(credentials.access_token, token) 1106 self.assertNotEqual(None, credentials.token_expiry) 1107 1108 def test_exchange_code_for_token_fail(self): 1109 http = HttpMockSequence([ 1110 ({'status': '400'}, b'{"error":"invalid_request"}'), 1111 ]) 1112 1113 try: 1114 credentials = credentials_from_code(self.client_id, self.client_secret, 1115 self.scope, self.code, redirect_uri=self.redirect_uri, 1116 http=http) 1117 self.fail('should raise exception if exchange doesn\'t get 200') 1118 except FlowExchangeError: 1119 pass 1120 1121 def test_exchange_code_and_file_for_token(self): 1122 http = HttpMockSequence([ 1123 ({'status': '200'}, 1124 b"""{ "access_token":"asdfghjkl", 1125 "expires_in":3600 }"""), 1126 ]) 1127 credentials = credentials_from_clientsecrets_and_code( 1128 datafile('client_secrets.json'), self.scope, 1129 self.code, http=http) 1130 self.assertEqual(credentials.access_token, 'asdfghjkl') 1131 self.assertNotEqual(None, credentials.token_expiry) 1132 1133 def test_exchange_code_and_cached_file_for_token(self): 1134 http = HttpMockSequence([ 1135 ({'status': '200'}, b'{ "access_token":"asdfghjkl"}'), 1136 ]) 1137 cache_mock = CacheMock() 1138 load_and_cache('client_secrets.json', 'some_secrets', cache_mock) 1139 1140 credentials = credentials_from_clientsecrets_and_code( 1141 'some_secrets', self.scope, 1142 self.code, http=http, cache=cache_mock) 1143 self.assertEqual(credentials.access_token, 'asdfghjkl') 1144 1145 def test_exchange_code_and_file_for_token_fail(self): 1146 http = HttpMockSequence([ 1147 ({'status': '400'}, b'{"error":"invalid_request"}'), 1148 ]) 
1149 1150 try: 1151 credentials = credentials_from_clientsecrets_and_code( 1152 datafile('client_secrets.json'), self.scope, 1153 self.code, http=http) 1154 self.fail('should raise exception if exchange doesn\'t get 200') 1155 except FlowExchangeError: 1156 pass 1157 1158 1159 class MemoryCacheTests(unittest.TestCase): 1160 1161 def test_get_set_delete(self): 1162 m = MemoryCache() 1163 self.assertEqual(None, m.get('foo')) 1164 self.assertEqual(None, m.delete('foo')) 1165 m.set('foo', 'bar') 1166 self.assertEqual('bar', m.get('foo')) 1167 m.delete('foo') 1168 self.assertEqual(None, m.get('foo')) 1169 1170 1171 class Test__save_private_file(unittest.TestCase): 1172 1173 def _save_helper(self, filename): 1174 contents = [] 1175 contents_str = '[]' 1176 client._save_private_file(filename, contents) 1177 with open(filename, 'r') as f: 1178 stored_contents = f.read() 1179 self.assertEqual(stored_contents, contents_str) 1180 1181 stat_mode = os.stat(filename).st_mode 1182 # Octal 777, only last 3 positions matter for permissions mask. 1183 stat_mode &= 0o777 1184 self.assertEqual(stat_mode, 0o600) 1185 1186 def test_new(self): 1187 import tempfile 1188 filename = tempfile.mktemp() 1189 self.assertFalse(os.path.exists(filename)) 1190 self._save_helper(filename) 1191 1192 def test_existing(self): 1193 import tempfile 1194 filename = tempfile.mktemp() 1195 with open(filename, 'w') as f: 1196 f.write('a bunch of nonsense longer than []') 1197 self.assertTrue(os.path.exists(filename)) 1198 self._save_helper(filename) 1199 1200 1201 if __name__ == '__main__': 1202 unittest.main() 1203