"""Tests for oauth2l."""

import json
import os
import sys

import mock
import oauth2client.client
import six
from six.moves import http_client
import unittest2

import apitools.base.py as apitools_base

# Set to True once oauth2l.main() has been called by _GetCommandOutput, so
# that call happens at most once per process.
_OAUTH2L_MAIN_RUN = False

if six.PY2:
    import gflags as flags
    from google.apputils import appcommands
    from apitools.scripts import oauth2l
    FLAGS = flags.FLAGS


class _FakeResponse(object):

    """Minimal stand-in for the response returned by MakeRequest.

    An OK (200) response carries a JSON body whose 'scope' value is the given
    scopes joined by spaces; any other status carries the body 'Error'.
    """

    def __init__(self, status_code, scopes=None):
        self.status_code = status_code
        if self.status_code == http_client.OK:
            self.content = json.dumps({'scope': ' '.join(scopes or [])})
        else:
            self.content = 'Error'
        self.info = str(http_client.responses[self.status_code])
        self.request_url = 'some-url'


def _GetCommandOutput(t, command_name, command_argv):
    """Run an oauth2l command and return what it printed to stdout.

    Args:
      t: The running TestCase; used to fail the test on an unknown command.
      command_name: Name of the appcommands command to look up and run.
      command_argv: Arguments passed to the command after its name.

    Returns:
      The captured stdout with trailing whitespace stripped.
    """
    global _OAUTH2L_MAIN_RUN  # pylint: disable=global-statement
    if not _OAUTH2L_MAIN_RUN:
        # presumably registers the commands looked up below; only done once.
        oauth2l.main(None)
        _OAUTH2L_MAIN_RUN = True
    command = appcommands.GetCommandByName(command_name)
    if command is None:
        t.fail('Unknown command: %s' % command_name)
    orig_stdout = sys.stdout
    new_stdout = six.StringIO()
    try:
        sys.stdout = new_stdout
        command.CommandRun([command_name] + command_argv)
    finally:
        # Restore stdout and flag values even if the command raises, so one
        # test cannot leak state into the next.
        sys.stdout = orig_stdout
        FLAGS.Reset()
    return new_stdout.getvalue().rstrip()


class _CredentialsTestCase(unittest2.TestCase):

    """Base class that gives every test a fake AccessTokenCredentials.

    Defines no tests itself; subclasses read self.access_token,
    self.user_agent and self.credentials.
    """

    def setUp(self):
        super(_CredentialsTestCase, self).setUp()
        # Set up an access token to use
        self.access_token = 'ya29.abdefghijklmnopqrstuvwxyz'
        self.user_agent = 'oauth2l/1.0'
        self.credentials = oauth2client.client.AccessTokenCredentials(
            self.access_token, self.user_agent)


@unittest2.skipIf(six.PY3, 'oauth2l unsupported in python3')
class TestTest(unittest2.TestCase):

    """Sanity checks for the _GetCommandOutput helper itself."""

    def testOutput(self):
        # t.fail() raises AssertionError for an unregistered command name.
        self.assertRaises(AssertionError,
                          _GetCommandOutput, self, 'foo', [])


@unittest2.skipIf(six.PY3, 'oauth2l unsupported in python3')
class Oauth2lFormattingTest(_CredentialsTestCase):

    """Tests for the output formats of the fetch and header commands."""

    def _PatchFetchCredentials(self):
        """Return a patcher replacing oauth2l.FetchCredentials.

        The patched function returns self.credentials.
        """
        return mock.patch.object(oauth2l, 'FetchCredentials',
                                 return_value=self.credentials,
                                 autospec=True)

    def _Args(self, credentials_format):
        return ['--credentials_format=' + credentials_format, 'userinfo.email']

    def testFormatBare(self):
        with self._PatchFetchCredentials() as mock_credentials:
            output = _GetCommandOutput(self, 'fetch', self._Args('bare'))
            self.assertEqual(self.access_token, output)
            self.assertEqual(1, mock_credentials.call_count)

    def testFormatHeader(self):
        with self._PatchFetchCredentials() as mock_credentials:
            output = _GetCommandOutput(self, 'fetch', self._Args('header'))
            header = 'Authorization: Bearer %s' % self.access_token
            self.assertEqual(header, output)
            self.assertEqual(1, mock_credentials.call_count)

    def testHeaderCommand(self):
        with self._PatchFetchCredentials() as mock_credentials:
            output = _GetCommandOutput(self, 'header', ['userinfo.email'])
            header = 'Authorization: Bearer %s' % self.access_token
            self.assertEqual(header, output)
            self.assertEqual(1, mock_credentials.call_count)

    def testFormatJson(self):
        with self._PatchFetchCredentials() as mock_credentials:
            output = _GetCommandOutput(self, 'fetch', self._Args('json'))
            output_lines = [l.strip() for l in output.splitlines()]
            expected_lines = [
                '"_class": "AccessTokenCredentials",',
                '"access_token": "%s",' % self.access_token,
            ]
            for line in expected_lines:
                self.assertIn(line, output_lines)
            self.assertEqual(1, mock_credentials.call_count)

    def testFormatJsonCompact(self):
        with self._PatchFetchCredentials() as mock_credentials:
            output = _GetCommandOutput(self, 'fetch',
                                       self._Args('json_compact'))
            expected_clauses = [
                '"_class":"AccessTokenCredentials",',
                '"access_token":"%s",' % self.access_token,
            ]
            for clause in expected_clauses:
                self.assertIn(clause, output)
            # Compact JSON must fit on a single line.
            self.assertEqual(1, len(output.splitlines()))
            self.assertEqual(1, mock_credentials.call_count)

    def testFormatPretty(self):
        with self._PatchFetchCredentials() as mock_credentials:
            output = _GetCommandOutput(self, 'fetch', self._Args('pretty'))
            expecteds = ['oauth2client.client.AccessTokenCredentials',
                         self.access_token]
            for expected in expecteds:
                self.assertIn(expected, output)
            self.assertEqual(1, mock_credentials.call_count)

    def testFakeFormat(self):
        self.assertRaises(ValueError,
                          oauth2l._Format, 'xml', self.credentials)


@unittest2.skipIf(six.PY3, 'oauth2l unsupported in python3')
class TestFetch(_CredentialsTestCase):

    """Tests for the fetch command and client-info flag handling."""

    def _PatchGetCredentials(self):
        """Return a patcher replacing apitools_base.GetCredentials.

        The patched function returns self.credentials.
        """
        return mock.patch.object(apitools_base, 'GetCredentials',
                                 return_value=self.credentials,
                                 autospec=True)

    def testNoScopes(self):
        output = _GetCommandOutput(self, 'fetch', [])
        self.assertEqual(
            'Exception raised in fetch operation: No scopes provided',
            output)

    def testScopes(self):
        expected_scopes = [
            'https://www.googleapis.com/auth/userinfo.email',
            'https://www.googleapis.com/auth/cloud-platform',
        ]
        with self._PatchGetCredentials() as mock_fetch:
            with mock.patch.object(oauth2l, '_GetTokenScopes',
                                   return_value=expected_scopes,
                                   autospec=True) as mock_get_scopes:
                output = _GetCommandOutput(
                    self, 'fetch', ['userinfo.email', 'cloud-platform'])
                self.assertIn(self.access_token, output)
                self.assertEqual(1, mock_fetch.call_count)
                # Short scope names must be expanded to full URLs.
                args, _ = mock_fetch.call_args
                self.assertEqual(expected_scopes, args[-1])
                self.assertEqual(1, mock_get_scopes.call_count)
                self.assertEqual((self.access_token,),
                                 mock_get_scopes.call_args[0])

    def testCredentialsRefreshed(self):
        with self._PatchGetCredentials() as mock_fetch:
            with mock.patch.object(oauth2l, '_ValidateToken',
                                   return_value=False,
                                   autospec=True) as mock_validate:
                # An invalid token (validate -> False) should trigger one
                # refresh of the credentials.
                with mock.patch.object(self.credentials, 'refresh',
                                       return_value=None,
                                       autospec=True) as mock_refresh:
                    output = _GetCommandOutput(self, 'fetch',
                                               ['userinfo.email'])
                    self.assertIn(self.access_token, output)
                    self.assertEqual(1, mock_fetch.call_count)
                    self.assertEqual(1, mock_validate.call_count)
                    self.assertEqual(1, mock_refresh.call_count)

    def testDefaultClientInfo(self):
        with self._PatchGetCredentials() as mock_fetch:
            with mock.patch.object(oauth2l, '_ValidateToken',
                                   return_value=True,
                                   autospec=True) as mock_validate:
                output = _GetCommandOutput(self, 'fetch', ['userinfo.email'])
                self.assertIn(self.access_token, output)
                self.assertEqual(1, mock_fetch.call_count)
                _, kwargs = mock_fetch.call_args
                self.assertEqual(
                    '1042881264118.apps.googleusercontent.com',
                    kwargs['client_id'])
                self.assertEqual(1, mock_validate.call_count)

    def testMissingClientSecrets(self):
        try:
            FLAGS.client_secrets = '/non/existent/file'
            self.assertRaises(
                ValueError,
                oauth2l.GetClientInfoFromFlags)
        finally:
            FLAGS.Reset()

    def testWrongClientSecretsFormat(self):
        client_secrets_path = os.path.join(
            os.path.dirname(__file__),
            'testdata/noninstalled_client_secrets.json')
        try:
            FLAGS.client_secrets = client_secrets_path
            self.assertRaises(
                ValueError,
                oauth2l.GetClientInfoFromFlags)
        finally:
            FLAGS.Reset()

    def testCustomClientInfo(self):
        client_secrets_path = os.path.join(
            os.path.dirname(__file__), 'testdata/fake_client_secrets.json')
        with self._PatchGetCredentials() as mock_fetch:
            with mock.patch.object(oauth2l, '_ValidateToken',
                                   return_value=True,
                                   autospec=True) as mock_validate:
                fetch_args = [
                    '--client_secrets=' + client_secrets_path,
                    'userinfo.email']
                output = _GetCommandOutput(self, 'fetch', fetch_args)
                self.assertIn(self.access_token, output)
                self.assertEqual(1, mock_fetch.call_count)
                _, kwargs = mock_fetch.call_args
                self.assertEqual('144169.apps.googleusercontent.com',
                                 kwargs['client_id'])
                self.assertEqual('awesomesecret',
                                 kwargs['client_secret'])
                self.assertEqual(1, mock_validate.call_count)


@unittest2.skipIf(six.PY3, 'oauth2l unsupported in python3')
class TestOtherCommands(_CredentialsTestCase):

    """Tests for the email, userinfo, scopes and validate commands."""

    def testEmail(self):
        user_info = {'email': 'foo@example.com'}
        with mock.patch.object(apitools_base, 'GetUserinfo',
                               return_value=user_info,
                               autospec=True) as mock_get_userinfo:
            output = _GetCommandOutput(self, 'email', [self.access_token])
            self.assertEqual(user_info['email'], output)
            self.assertEqual(1, mock_get_userinfo.call_count)
            self.assertEqual(self.access_token,
                             mock_get_userinfo.call_args[0][0].access_token)

    def testNoEmail(self):
        with mock.patch.object(apitools_base, 'GetUserinfo',
                               return_value={},
                               autospec=True) as mock_get_userinfo:
            output = _GetCommandOutput(self, 'email', [self.access_token])
            self.assertEqual('', output)
            self.assertEqual(1, mock_get_userinfo.call_count)

    def testUserinfo(self):
        user_info = {'email': 'foo@example.com'}
        with mock.patch.object(apitools_base, 'GetUserinfo',
                               return_value=user_info,
                               autospec=True) as mock_get_userinfo:
            output = _GetCommandOutput(self, 'userinfo', [self.access_token])
            self.assertEqual(json.dumps(user_info, indent=4), output)
            self.assertEqual(1, mock_get_userinfo.call_count)
            self.assertEqual(self.access_token,
                             mock_get_userinfo.call_args[0][0].access_token)

    def testUserinfoCompact(self):
        user_info = {'email': 'foo@example.com'}
        with mock.patch.object(apitools_base, 'GetUserinfo',
                               return_value=user_info,
                               autospec=True) as mock_get_userinfo:
            output = _GetCommandOutput(
                self, 'userinfo', ['--format=json_compact', self.access_token])
            self.assertEqual(json.dumps(user_info, separators=(',', ':')),
                             output)
            self.assertEqual(1, mock_get_userinfo.call_count)
            self.assertEqual(self.access_token,
                             mock_get_userinfo.call_args[0][0].access_token)

    def testScopes(self):
        scopes = [u'https://www.googleapis.com/auth/userinfo.email',
                  u'https://www.googleapis.com/auth/cloud-platform']
        response = _FakeResponse(http_client.OK, scopes=scopes)
        with mock.patch.object(apitools_base, 'MakeRequest',
                               return_value=response,
                               autospec=True) as mock_make_request:
            output = _GetCommandOutput(self, 'scopes', [self.access_token])
            self.assertEqual(sorted(scopes), output.splitlines())
            self.assertEqual(1, mock_make_request.call_count)

    def testValidate(self):
        scopes = [u'https://www.googleapis.com/auth/userinfo.email',
                  u'https://www.googleapis.com/auth/cloud-platform']
        response = _FakeResponse(http_client.OK, scopes=scopes)
        with mock.patch.object(apitools_base, 'MakeRequest',
                               return_value=response,
                               autospec=True) as mock_make_request:
            output = _GetCommandOutput(self, 'validate', [self.access_token])
            self.assertEqual('', output)
            self.assertEqual(1, mock_make_request.call_count)

    def testBadResponseCode(self):
        response = _FakeResponse(http_client.BAD_REQUEST)
        with mock.patch.object(apitools_base, 'MakeRequest',
                               return_value=response,
                               autospec=True) as mock_make_request:
            output = _GetCommandOutput(self, 'scopes', [self.access_token])
            self.assertEqual('', output)
            self.assertEqual(1, mock_make_request.call_count)

    def testUnexpectedResponseCode(self):
        response = _FakeResponse(http_client.INTERNAL_SERVER_ERROR)
        with mock.patch.object(apitools_base, 'MakeRequest',
                               return_value=response,
                               autospec=True) as mock_make_request:
            output = _GetCommandOutput(self, 'scopes', [self.access_token])
            self.assertIn(str(http_client.responses[response.status_code]),
                          output)
            self.assertIn('Exception raised in scopes operation: HttpError',
                          output)
            self.assertEqual(1, mock_make_request.call_count)