Home | History | Annotate | Download | only in cloud
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/platform/cloud/google_auth_provider.h"
     17 #include <stdlib.h>
     18 #include "tensorflow/core/lib/core/status_test_util.h"
     19 #include "tensorflow/core/lib/io/path.h"
     20 #include "tensorflow/core/platform/cloud/http_request_fake.h"
     21 #include "tensorflow/core/platform/test.h"
     22 
     23 namespace tensorflow {
     24 
     25 namespace {
     26 
     27 constexpr char kTestData[] = "core/platform/cloud/testdata/";
     28 
     29 class FakeEnv : public EnvWrapper {
     30  public:
     31   FakeEnv() : EnvWrapper(Env::Default()) {}
     32 
     33   uint64 NowSeconds() override { return now; }
     34   uint64 now = 10000;
     35 };
     36 
     37 class FakeOAuthClient : public OAuthClient {
     38  public:
     39   Status GetTokenFromServiceAccountJson(
     40       Json::Value json, StringPiece oauth_server_uri, StringPiece scope,
     41       string* token, uint64* expiration_timestamp_sec) override {
     42     provided_credentials_json = json;
     43     *token = return_token;
     44     *expiration_timestamp_sec = return_expiration_timestamp;
     45     return Status::OK();
     46   }
     47 
     48   /// Retrieves a bearer token using a refresh token.
     49   Status GetTokenFromRefreshTokenJson(
     50       Json::Value json, StringPiece oauth_server_uri, string* token,
     51       uint64* expiration_timestamp_sec) override {
     52     provided_credentials_json = json;
     53     *token = return_token;
     54     *expiration_timestamp_sec = return_expiration_timestamp;
     55     return Status::OK();
     56   }
     57 
     58   string return_token;
     59   uint64 return_expiration_timestamp;
     60   Json::Value provided_credentials_json;
     61 };
     62 
     63 }  // namespace
     64 
     65 class GoogleAuthProviderTest : public ::testing::Test {
     66  protected:
     67   void SetUp() override { ClearEnvVars(); }
     68 
     69   void TearDown() override { ClearEnvVars(); }
     70 
     71   void ClearEnvVars() {
     72     unsetenv("GOOGLE_APPLICATION_CREDENTIALS");
     73     unsetenv("CLOUDSDK_CONFIG");
     74     unsetenv("GOOGLE_AUTH_TOKEN_FOR_TESTING");
     75   }
     76 };
     77 
     78 TEST_F(GoogleAuthProviderTest, EnvironmentVariable_Caching) {
     79   setenv("GOOGLE_APPLICATION_CREDENTIALS",
     80          io::JoinPath(
     81              io::JoinPath(testing::TensorFlowSrcRoot(), kTestData).c_str(),
     82              "service_account_credentials.json")
     83              .c_str(),
     84          1);
     85   setenv("CLOUDSDK_CONFIG",
     86          io::JoinPath(testing::TensorFlowSrcRoot(), kTestData).c_str(),
     87          1);  // Will not be used.
     88 
     89   auto oauth_client = new FakeOAuthClient;
     90   std::vector<HttpRequest*> requests;
     91 
     92   FakeEnv env;
     93   GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
     94                               std::unique_ptr<HttpRequest::Factory>(
     95                                   new FakeHttpRequestFactory(&requests)),
     96                               &env, 0);
     97   oauth_client->return_token = "fake-token";
     98   oauth_client->return_expiration_timestamp = env.NowSeconds() + 3600;
     99 
    100   string token;
    101   TF_EXPECT_OK(provider.GetToken(&token));
    102   EXPECT_EQ("fake-token", token);
    103   EXPECT_EQ("fake_key_id",
    104             oauth_client->provided_credentials_json.get("private_key_id", "")
    105                 .asString());
    106 
    107   // Check that the token is re-used if not expired.
    108   oauth_client->return_token = "new-fake-token";
    109   env.now += 3000;
    110   TF_EXPECT_OK(provider.GetToken(&token));
    111   EXPECT_EQ("fake-token", token);
    112 
    113   // Check that the token is re-generated when almost expired.
    114   env.now += 598;  // 2 seconds before expiration
    115   TF_EXPECT_OK(provider.GetToken(&token));
    116   EXPECT_EQ("new-fake-token", token);
    117 }
    118 
    119 TEST_F(GoogleAuthProviderTest, GCloudRefreshToken) {
    120   setenv("CLOUDSDK_CONFIG",
    121          io::JoinPath(testing::TensorFlowSrcRoot(), kTestData).c_str(), 1);
    122 
    123   auto oauth_client = new FakeOAuthClient;
    124   std::vector<HttpRequest*> requests;
    125 
    126   FakeEnv env;
    127   GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
    128                               std::unique_ptr<HttpRequest::Factory>(
    129                                   new FakeHttpRequestFactory(&requests)),
    130                               &env, 0);
    131   oauth_client->return_token = "fake-token";
    132   oauth_client->return_expiration_timestamp = env.NowSeconds() + 3600;
    133 
    134   string token;
    135   TF_EXPECT_OK(provider.GetToken(&token));
    136   EXPECT_EQ("fake-token", token);
    137   EXPECT_EQ("fake-refresh-token",
    138             oauth_client->provided_credentials_json.get("refresh_token", "")
    139                 .asString());
    140 }
    141 
    142 TEST_F(GoogleAuthProviderTest, RunningOnGCE) {
    143   auto oauth_client = new FakeOAuthClient;
    144   std::vector<HttpRequest*> requests(
    145       {new FakeHttpRequest(
    146            "Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
    147            "/default/token\n"
    148            "Header Metadata-Flavor: Google\n",
    149            R"(
    150           {
    151             "access_token":"fake-gce-token",
    152             "expires_in": 3920,
    153             "token_type":"Bearer"
    154           })"),
    155        // The first token refresh request fails and will be retried.
    156        new FakeHttpRequest(
    157            "Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
    158            "/default/token\n"
    159            "Header Metadata-Flavor: Google\n",
    160            "", errors::Unavailable("503"), 503),
    161        new FakeHttpRequest(
    162            "Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
    163            "/default/token\n"
    164            "Header Metadata-Flavor: Google\n",
    165            R"(
    166               {
    167                 "access_token":"new-fake-gce-token",
    168                 "expires_in": 3920,
    169                 "token_type":"Bearer"
    170               })")});
    171 
    172   FakeEnv env;
    173   GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
    174                               std::unique_ptr<HttpRequest::Factory>(
    175                                   new FakeHttpRequestFactory(&requests)),
    176                               &env, 0);
    177 
    178   string token;
    179   TF_EXPECT_OK(provider.GetToken(&token));
    180   EXPECT_EQ("fake-gce-token", token);
    181 
    182   // Check that the token is re-used if not expired.
    183   env.now += 3700;
    184   TF_EXPECT_OK(provider.GetToken(&token));
    185   EXPECT_EQ("fake-gce-token", token);
    186 
    187   // Check that the token is re-generated when almost expired.
    188   env.now += 598;  // 2 seconds before expiration
    189   TF_EXPECT_OK(provider.GetToken(&token));
    190   EXPECT_EQ("new-fake-gce-token", token);
    191 }
    192 
    193 TEST_F(GoogleAuthProviderTest, OverrideForTesting) {
    194   setenv("GOOGLE_AUTH_TOKEN_FOR_TESTING", "tokenForTesting", 1);
    195 
    196   auto oauth_client = new FakeOAuthClient;
    197   std::vector<HttpRequest*> empty_requests;
    198   FakeEnv env;
    199   GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
    200                               std::unique_ptr<HttpRequest::Factory>(
    201                                   new FakeHttpRequestFactory(&empty_requests)),
    202                               &env, 0);
    203 
    204   string token;
    205   TF_EXPECT_OK(provider.GetToken(&token));
    206   EXPECT_EQ("tokenForTesting", token);
    207 }
    208 
    209 TEST_F(GoogleAuthProviderTest, NothingAvailable) {
    210   auto oauth_client = new FakeOAuthClient;
    211 
    212   std::vector<HttpRequest*> requests({new FakeHttpRequest(
    213       "Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
    214       "/default/token\n"
    215       "Header Metadata-Flavor: Google\n",
    216       "", errors::NotFound("404"), 404)});
    217 
    218   FakeEnv env;
    219   GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
    220                               std::unique_ptr<HttpRequest::Factory>(
    221                                   new FakeHttpRequestFactory(&requests)),
    222                               &env, 0);
    223 
    224   string token;
    225   TF_EXPECT_OK(provider.GetToken(&token));
    226   EXPECT_EQ("", token);
    227 }
    228 
    229 }  // namespace tensorflow
    230