Home | History | Annotate | Download | only in scripts
      1 """Tests for oauth2l."""
      2 
      3 import json
      4 import os
      5 import sys
      6 
      7 import mock
      8 import oauth2client.client
      9 import six
     10 from six.moves import http_client
     11 import unittest2
     12 
     13 import apitools.base.py as apitools_base
     14 
     15 _OAUTH2L_MAIN_RUN = False
     16 
     17 if six.PY2:
     18     import gflags as flags
     19     from google.apputils import appcommands
     20     from apitools.scripts import oauth2l
     21     FLAGS = flags.FLAGS
     22 
     23 
     24 class _FakeResponse(object):
     25 
     26     def __init__(self, status_code, scopes=None):
     27         self.status_code = status_code
     28         if self.status_code == http_client.OK:
     29             self.content = json.dumps({'scope': ' '.join(scopes or [])})
     30         else:
     31             self.content = 'Error'
     32             self.info = str(http_client.responses[self.status_code])
     33             self.request_url = 'some-url'
     34 
     35 
     36 def _GetCommandOutput(t, command_name, command_argv):
     37     global _OAUTH2L_MAIN_RUN  # pylint: disable=global-statement
     38     if not _OAUTH2L_MAIN_RUN:
     39         oauth2l.main(None)
     40         _OAUTH2L_MAIN_RUN = True
     41     command = appcommands.GetCommandByName(command_name)
     42     if command is None:
     43         t.fail('Unknown command: %s' % command_name)
     44     orig_stdout = sys.stdout
     45     new_stdout = six.StringIO()
     46     try:
     47         sys.stdout = new_stdout
     48         command.CommandRun([command_name] + command_argv)
     49     finally:
     50         sys.stdout = orig_stdout
     51         FLAGS.Reset()
     52     new_stdout.seek(0)
     53     return new_stdout.getvalue().rstrip()
     54 
     55 
     56 @unittest2.skipIf(six.PY3, 'oauth2l unsupported in python3')
     57 class TestTest(unittest2.TestCase):
     58 
     59     def testOutput(self):
     60         self.assertRaises(AssertionError,
     61                           _GetCommandOutput, self, 'foo', [])
     62 
     63 
     64 @unittest2.skipIf(six.PY3, 'oauth2l unsupported in python3')
     65 class Oauth2lFormattingTest(unittest2.TestCase):
     66 
     67     def setUp(self):
     68         # Set up an access token to use
     69         self.access_token = 'ya29.abdefghijklmnopqrstuvwxyz'
     70         self.user_agent = 'oauth2l/1.0'
     71         self.credentials = oauth2client.client.AccessTokenCredentials(
     72             self.access_token, self.user_agent)
     73 
     74     def _Args(self, credentials_format):
     75         return ['--credentials_format=' + credentials_format, 'userinfo.email']
     76 
     77     def testFormatBare(self):
     78         with mock.patch.object(oauth2l, 'FetchCredentials',
     79                                return_value=self.credentials,
     80                                autospec=True) as mock_credentials:
     81             output = _GetCommandOutput(self, 'fetch', self._Args('bare'))
     82             self.assertEqual(self.access_token, output)
     83             self.assertEqual(1, mock_credentials.call_count)
     84 
     85     def testFormatHeader(self):
     86         with mock.patch.object(oauth2l, 'FetchCredentials',
     87                                return_value=self.credentials,
     88                                autospec=True) as mock_credentials:
     89             output = _GetCommandOutput(self, 'fetch', self._Args('header'))
     90             header = 'Authorization: Bearer %s' % self.access_token
     91             self.assertEqual(header, output)
     92             self.assertEqual(1, mock_credentials.call_count)
     93 
     94     def testHeaderCommand(self):
     95         with mock.patch.object(oauth2l, 'FetchCredentials',
     96                                return_value=self.credentials,
     97                                autospec=True) as mock_credentials:
     98             output = _GetCommandOutput(self, 'header', ['userinfo.email'])
     99             header = 'Authorization: Bearer %s' % self.access_token
    100             self.assertEqual(header, output)
    101             self.assertEqual(1, mock_credentials.call_count)
    102 
    103     def testFormatJson(self):
    104         with mock.patch.object(oauth2l, 'FetchCredentials',
    105                                return_value=self.credentials,
    106                                autospec=True) as mock_credentials:
    107             output = _GetCommandOutput(self, 'fetch', self._Args('json'))
    108             output_lines = [l.strip() for l in output.splitlines()]
    109             expected_lines = [
    110                 '"_class": "AccessTokenCredentials",',
    111                 '"access_token": "%s",' % self.access_token,
    112             ]
    113             for line in expected_lines:
    114                 self.assertIn(line, output_lines)
    115             self.assertEqual(1, mock_credentials.call_count)
    116 
    117     def testFormatJsonCompact(self):
    118         with mock.patch.object(oauth2l, 'FetchCredentials',
    119                                return_value=self.credentials,
    120                                autospec=True) as mock_credentials:
    121             output = _GetCommandOutput(self, 'fetch',
    122                                        self._Args('json_compact'))
    123             expected_clauses = [
    124                 '"_class":"AccessTokenCredentials",',
    125                 '"access_token":"%s",' % self.access_token,
    126             ]
    127             for clause in expected_clauses:
    128                 self.assertIn(clause, output)
    129             self.assertEqual(1, len(output.splitlines()))
    130             self.assertEqual(1, mock_credentials.call_count)
    131 
    132     def testFormatPretty(self):
    133         with mock.patch.object(oauth2l, 'FetchCredentials',
    134                                return_value=self.credentials,
    135                                autospec=True) as mock_credentials:
    136             output = _GetCommandOutput(self, 'fetch', self._Args('pretty'))
    137             expecteds = ['oauth2client.client.AccessTokenCredentials',
    138                          self.access_token]
    139             for expected in expecteds:
    140                 self.assertIn(expected, output)
    141             self.assertEqual(1, mock_credentials.call_count)
    142 
    143     def testFakeFormat(self):
    144         self.assertRaises(ValueError,
    145                           oauth2l._Format, 'xml', self.credentials)
    146 
    147 
    148 @unittest2.skipIf(six.PY3, 'oauth2l unsupported in python3')
    149 class TestFetch(unittest2.TestCase):
    150 
    151     def setUp(self):
    152         # Set up an access token to use
    153         self.access_token = 'ya29.abdefghijklmnopqrstuvwxyz'
    154         self.user_agent = 'oauth2l/1.0'
    155         self.credentials = oauth2client.client.AccessTokenCredentials(
    156             self.access_token, self.user_agent)
    157 
    158     def testNoScopes(self):
    159         output = _GetCommandOutput(self, 'fetch', [])
    160         self.assertEqual(
    161             'Exception raised in fetch operation: No scopes provided',
    162             output)
    163 
    164     def testScopes(self):
    165         expected_scopes = [
    166             'https://www.googleapis.com/auth/userinfo.email',
    167             'https://www.googleapis.com/auth/cloud-platform',
    168         ]
    169         with mock.patch.object(apitools_base, 'GetCredentials',
    170                                return_value=self.credentials,
    171                                autospec=True) as mock_fetch:
    172             with mock.patch.object(oauth2l, '_GetTokenScopes',
    173                                    return_value=expected_scopes,
    174                                    autospec=True) as mock_get_scopes:
    175                 output = _GetCommandOutput(
    176                     self, 'fetch', ['userinfo.email', 'cloud-platform'])
    177                 self.assertIn(self.access_token, output)
    178                 self.assertEqual(1, mock_fetch.call_count)
    179                 args, _ = mock_fetch.call_args
    180                 self.assertEqual(expected_scopes, args[-1])
    181                 self.assertEqual(1, mock_get_scopes.call_count)
    182                 self.assertEqual((self.access_token,),
    183                                  mock_get_scopes.call_args[0])
    184 
    185     def testCredentialsRefreshed(self):
    186         with mock.patch.object(apitools_base, 'GetCredentials',
    187                                return_value=self.credentials,
    188                                autospec=True) as mock_fetch:
    189             with mock.patch.object(oauth2l, '_ValidateToken',
    190                                    return_value=False,
    191                                    autospec=True) as mock_validate:
    192                 with mock.patch.object(self.credentials, 'refresh',
    193                                        return_value=None,
    194                                        autospec=True) as mock_refresh:
    195                     output = _GetCommandOutput(self, 'fetch',
    196                                                ['userinfo.email'])
    197                     self.assertIn(self.access_token, output)
    198                     self.assertEqual(1, mock_fetch.call_count)
    199                     self.assertEqual(1, mock_validate.call_count)
    200                     self.assertEqual(1, mock_refresh.call_count)
    201 
    202     def testDefaultClientInfo(self):
    203         with mock.patch.object(apitools_base, 'GetCredentials',
    204                                return_value=self.credentials,
    205                                autospec=True) as mock_fetch:
    206             with mock.patch.object(oauth2l, '_ValidateToken',
    207                                    return_value=True,
    208                                    autospec=True) as mock_validate:
    209                 output = _GetCommandOutput(self, 'fetch', ['userinfo.email'])
    210                 self.assertIn(self.access_token, output)
    211                 self.assertEqual(1, mock_fetch.call_count)
    212                 _, kwargs = mock_fetch.call_args
    213                 self.assertEqual(
    214                     '1042881264118.apps.googleusercontent.com',
    215                     kwargs['client_id'])
    216                 self.assertEqual(1, mock_validate.call_count)
    217 
    218     def testMissingClientSecrets(self):
    219         try:
    220             FLAGS.client_secrets = '/non/existent/file'
    221             self.assertRaises(
    222                 ValueError,
    223                 oauth2l.GetClientInfoFromFlags)
    224         finally:
    225             FLAGS.Reset()
    226 
    227     def testWrongClientSecretsFormat(self):
    228         client_secrets_path = os.path.join(
    229             os.path.dirname(__file__),
    230             'testdata/noninstalled_client_secrets.json')
    231         try:
    232             FLAGS.client_secrets = client_secrets_path
    233             self.assertRaises(
    234                 ValueError,
    235                 oauth2l.GetClientInfoFromFlags)
    236         finally:
    237             FLAGS.Reset()
    238 
    239     def testCustomClientInfo(self):
    240         client_secrets_path = os.path.join(
    241             os.path.dirname(__file__), 'testdata/fake_client_secrets.json')
    242         with mock.patch.object(apitools_base, 'GetCredentials',
    243                                return_value=self.credentials,
    244                                autospec=True) as mock_fetch:
    245             with mock.patch.object(oauth2l, '_ValidateToken',
    246                                    return_value=True,
    247                                    autospec=True) as mock_validate:
    248                 fetch_args = [
    249                     '--client_secrets=' + client_secrets_path,
    250                     'userinfo.email']
    251                 output = _GetCommandOutput(self, 'fetch', fetch_args)
    252                 self.assertIn(self.access_token, output)
    253                 self.assertEqual(1, mock_fetch.call_count)
    254                 _, kwargs = mock_fetch.call_args
    255                 self.assertEqual('144169.apps.googleusercontent.com',
    256                                  kwargs['client_id'])
    257                 self.assertEqual('awesomesecret',
    258                                  kwargs['client_secret'])
    259                 self.assertEqual(1, mock_validate.call_count)
    260 
    261 
    262 @unittest2.skipIf(six.PY3, 'oauth2l unsupported in python3')
    263 class TestOtherCommands(unittest2.TestCase):
    264 
    265     def setUp(self):
    266         # Set up an access token to use
    267         self.access_token = 'ya29.abdefghijklmnopqrstuvwxyz'
    268         self.user_agent = 'oauth2l/1.0'
    269         self.credentials = oauth2client.client.AccessTokenCredentials(
    270             self.access_token, self.user_agent)
    271 
    272     def testEmail(self):
    273         user_info = {'email': 'foo (at] example.com'}
    274         with mock.patch.object(apitools_base, 'GetUserinfo',
    275                                return_value=user_info,
    276                                autospec=True) as mock_get_userinfo:
    277             output = _GetCommandOutput(self, 'email', [self.access_token])
    278             self.assertEqual(user_info['email'], output)
    279             self.assertEqual(1, mock_get_userinfo.call_count)
    280             self.assertEqual(self.access_token,
    281                              mock_get_userinfo.call_args[0][0].access_token)
    282 
    283     def testNoEmail(self):
    284         with mock.patch.object(apitools_base, 'GetUserinfo',
    285                                return_value={},
    286                                autospec=True) as mock_get_userinfo:
    287             output = _GetCommandOutput(self, 'email', [self.access_token])
    288             self.assertEqual('', output)
    289             self.assertEqual(1, mock_get_userinfo.call_count)
    290 
    291     def testUserinfo(self):
    292         user_info = {'email': 'foo (at] example.com'}
    293         with mock.patch.object(apitools_base, 'GetUserinfo',
    294                                return_value=user_info,
    295                                autospec=True) as mock_get_userinfo:
    296             output = _GetCommandOutput(self, 'userinfo', [self.access_token])
    297             self.assertEqual(json.dumps(user_info, indent=4), output)
    298             self.assertEqual(1, mock_get_userinfo.call_count)
    299             self.assertEqual(self.access_token,
    300                              mock_get_userinfo.call_args[0][0].access_token)
    301 
    302     def testUserinfoCompact(self):
    303         user_info = {'email': 'foo (at] example.com'}
    304         with mock.patch.object(apitools_base, 'GetUserinfo',
    305                                return_value=user_info,
    306                                autospec=True) as mock_get_userinfo:
    307             output = _GetCommandOutput(
    308                 self, 'userinfo', ['--format=json_compact', self.access_token])
    309             self.assertEqual(json.dumps(user_info, separators=(',', ':')),
    310                              output)
    311             self.assertEqual(1, mock_get_userinfo.call_count)
    312             self.assertEqual(self.access_token,
    313                              mock_get_userinfo.call_args[0][0].access_token)
    314 
    315     def testScopes(self):
    316         scopes = [u'https://www.googleapis.com/auth/userinfo.email',
    317                   u'https://www.googleapis.com/auth/cloud-platform']
    318         response = _FakeResponse(http_client.OK, scopes=scopes)
    319         with mock.patch.object(apitools_base, 'MakeRequest',
    320                                return_value=response,
    321                                autospec=True) as mock_make_request:
    322             output = _GetCommandOutput(self, 'scopes', [self.access_token])
    323             self.assertEqual(sorted(scopes), output.splitlines())
    324             self.assertEqual(1, mock_make_request.call_count)
    325 
    326     def testValidate(self):
    327         scopes = [u'https://www.googleapis.com/auth/userinfo.email',
    328                   u'https://www.googleapis.com/auth/cloud-platform']
    329         response = _FakeResponse(http_client.OK, scopes=scopes)
    330         with mock.patch.object(apitools_base, 'MakeRequest',
    331                                return_value=response,
    332                                autospec=True) as mock_make_request:
    333             output = _GetCommandOutput(self, 'validate', [self.access_token])
    334             self.assertEqual('', output)
    335             self.assertEqual(1, mock_make_request.call_count)
    336 
    337     def testBadResponseCode(self):
    338         response = _FakeResponse(http_client.BAD_REQUEST)
    339         with mock.patch.object(apitools_base, 'MakeRequest',
    340                                return_value=response,
    341                                autospec=True) as mock_make_request:
    342             output = _GetCommandOutput(self, 'scopes', [self.access_token])
    343             self.assertEqual('', output)
    344             self.assertEqual(1, mock_make_request.call_count)
    345 
    346     def testUnexpectedResponseCode(self):
    347         response = _FakeResponse(http_client.INTERNAL_SERVER_ERROR)
    348         with mock.patch.object(apitools_base, 'MakeRequest',
    349                                return_value=response,
    350                                autospec=True) as mock_make_request:
    351             output = _GetCommandOutput(self, 'scopes', [self.access_token])
    352             self.assertIn(str(http_client.responses[response.status_code]),
    353                           output)
    354             self.assertIn('Exception raised in scopes operation: HttpError',
    355                           output)
    356             self.assertEqual(1, mock_make_request.call_count)
    357