1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/platform/cloud/google_auth_provider.h" 17 #include <stdlib.h> 18 #include "tensorflow/core/lib/core/status_test_util.h" 19 #include "tensorflow/core/lib/io/path.h" 20 #include "tensorflow/core/platform/cloud/http_request_fake.h" 21 #include "tensorflow/core/platform/test.h" 22 23 namespace tensorflow { 24 25 namespace { 26 27 constexpr char kTestData[] = "core/platform/cloud/testdata/"; 28 29 class FakeEnv : public EnvWrapper { 30 public: 31 FakeEnv() : EnvWrapper(Env::Default()) {} 32 33 uint64 NowSeconds() override { return now; } 34 uint64 now = 10000; 35 }; 36 37 class FakeOAuthClient : public OAuthClient { 38 public: 39 Status GetTokenFromServiceAccountJson( 40 Json::Value json, StringPiece oauth_server_uri, StringPiece scope, 41 string* token, uint64* expiration_timestamp_sec) override { 42 provided_credentials_json = json; 43 *token = return_token; 44 *expiration_timestamp_sec = return_expiration_timestamp; 45 return Status::OK(); 46 } 47 48 /// Retrieves a bearer token using a refresh token. 49 Status GetTokenFromRefreshTokenJson( 50 Json::Value json, StringPiece oauth_server_uri, string* token, 51 uint64* expiration_timestamp_sec) override { 52 provided_credentials_json = json; 53 *token = return_token; 54 *expiration_timestamp_sec = return_expiration_timestamp; 55 return Status::OK(); 56 } 57 58 string return_token; 59 uint64 return_expiration_timestamp; 60 Json::Value provided_credentials_json; 61 }; 62 63 } // namespace 64 65 class GoogleAuthProviderTest : public ::testing::Test { 66 protected: 67 void SetUp() override { ClearEnvVars(); } 68 69 void TearDown() override { ClearEnvVars(); } 70 71 void ClearEnvVars() { 72 unsetenv("GOOGLE_APPLICATION_CREDENTIALS"); 73 unsetenv("CLOUDSDK_CONFIG"); 74 unsetenv("GOOGLE_AUTH_TOKEN_FOR_TESTING"); 75 } 76 }; 77 78 TEST_F(GoogleAuthProviderTest, EnvironmentVariable_Caching) { 79 setenv("GOOGLE_APPLICATION_CREDENTIALS", 80 io::JoinPath( 81 io::JoinPath(testing::TensorFlowSrcRoot(), kTestData).c_str(), 82 "service_account_credentials.json") 83 .c_str(), 84 1); 85 setenv("CLOUDSDK_CONFIG", 86 io::JoinPath(testing::TensorFlowSrcRoot(), kTestData).c_str(), 87 1); // Will not be used. 88 89 auto oauth_client = new FakeOAuthClient; 90 std::vector<HttpRequest*> requests; 91 92 FakeEnv env; 93 GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client), 94 std::unique_ptr<HttpRequest::Factory>( 95 new FakeHttpRequestFactory(&requests)), 96 &env, 0); 97 oauth_client->return_token = "fake-token"; 98 oauth_client->return_expiration_timestamp = env.NowSeconds() + 3600; 99 100 string token; 101 TF_EXPECT_OK(provider.GetToken(&token)); 102 EXPECT_EQ("fake-token", token); 103 EXPECT_EQ("fake_key_id", 104 oauth_client->provided_credentials_json.get("private_key_id", "") 105 .asString()); 106 107 // Check that the token is re-used if not expired. 108 oauth_client->return_token = "new-fake-token"; 109 env.now += 3000; 110 TF_EXPECT_OK(provider.GetToken(&token)); 111 EXPECT_EQ("fake-token", token); 112 113 // Check that the token is re-generated when almost expired. 114 env.now += 598; // 2 seconds before expiration 115 TF_EXPECT_OK(provider.GetToken(&token)); 116 EXPECT_EQ("new-fake-token", token); 117 } 118 119 TEST_F(GoogleAuthProviderTest, GCloudRefreshToken) { 120 setenv("CLOUDSDK_CONFIG", 121 io::JoinPath(testing::TensorFlowSrcRoot(), kTestData).c_str(), 1); 122 123 auto oauth_client = new FakeOAuthClient; 124 std::vector<HttpRequest*> requests; 125 126 FakeEnv env; 127 GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client), 128 std::unique_ptr<HttpRequest::Factory>( 129 new FakeHttpRequestFactory(&requests)), 130 &env, 0); 131 oauth_client->return_token = "fake-token"; 132 oauth_client->return_expiration_timestamp = env.NowSeconds() + 3600; 133 134 string token; 135 TF_EXPECT_OK(provider.GetToken(&token)); 136 EXPECT_EQ("fake-token", token); 137 EXPECT_EQ("fake-refresh-token", 138 oauth_client->provided_credentials_json.get("refresh_token", "") 139 .asString()); 140 } 141 142 TEST_F(GoogleAuthProviderTest, RunningOnGCE) { 143 auto oauth_client = new FakeOAuthClient; 144 std::vector<HttpRequest*> requests( 145 {new FakeHttpRequest( 146 "Uri: http://metadata/computeMetadata/v1/instance/service-accounts" 147 "/default/token\n" 148 "Header Metadata-Flavor: Google\n", 149 R"( 150 { 151 "access_token":"fake-gce-token", 152 "expires_in": 3920, 153 "token_type":"Bearer" 154 })"), 155 // The first token refresh request fails and will be retried. 156 new FakeHttpRequest( 157 "Uri: http://metadata/computeMetadata/v1/instance/service-accounts" 158 "/default/token\n" 159 "Header Metadata-Flavor: Google\n", 160 "", errors::Unavailable("503"), 503), 161 new FakeHttpRequest( 162 "Uri: http://metadata/computeMetadata/v1/instance/service-accounts" 163 "/default/token\n" 164 "Header Metadata-Flavor: Google\n", 165 R"( 166 { 167 "access_token":"new-fake-gce-token", 168 "expires_in": 3920, 169 "token_type":"Bearer" 170 })")}); 171 172 FakeEnv env; 173 GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client), 174 std::unique_ptr<HttpRequest::Factory>( 175 new FakeHttpRequestFactory(&requests)), 176 &env, 0); 177 178 string token; 179 TF_EXPECT_OK(provider.GetToken(&token)); 180 EXPECT_EQ("fake-gce-token", token); 181 182 // Check that the token is re-used if not expired. 183 env.now += 3700; 184 TF_EXPECT_OK(provider.GetToken(&token)); 185 EXPECT_EQ("fake-gce-token", token); 186 187 // Check that the token is re-generated when almost expired. 188 env.now += 598; // 2 seconds before expiration 189 TF_EXPECT_OK(provider.GetToken(&token)); 190 EXPECT_EQ("new-fake-gce-token", token); 191 } 192 193 TEST_F(GoogleAuthProviderTest, OverrideForTesting) { 194 setenv("GOOGLE_AUTH_TOKEN_FOR_TESTING", "tokenForTesting", 1); 195 196 auto oauth_client = new FakeOAuthClient; 197 std::vector<HttpRequest*> empty_requests; 198 FakeEnv env; 199 GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client), 200 std::unique_ptr<HttpRequest::Factory>( 201 new FakeHttpRequestFactory(&empty_requests)), 202 &env, 0); 203 204 string token; 205 TF_EXPECT_OK(provider.GetToken(&token)); 206 EXPECT_EQ("tokenForTesting", token); 207 } 208 209 TEST_F(GoogleAuthProviderTest, NothingAvailable) { 210 auto oauth_client = new FakeOAuthClient; 211 212 std::vector<HttpRequest*> requests({new FakeHttpRequest( 213 "Uri: http://metadata/computeMetadata/v1/instance/service-accounts" 214 "/default/token\n" 215 "Header Metadata-Flavor: Google\n", 216 "", errors::NotFound("404"), 404)}); 217 218 FakeEnv env; 219 GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client), 220 std::unique_ptr<HttpRequest::Factory>( 221 new FakeHttpRequestFactory(&requests)), 222 &env, 0); 223 224 string token; 225 TF_EXPECT_OK(provider.GetToken(&token)); 226 EXPECT_EQ("", token); 227 } 228 229 } // namespace tensorflow 230