Home | History | Annotate | Download | only in cloud
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_PLATFORM_GOOGLE_AUTH_PROVIDER_H_
     17 #define TENSORFLOW_CORE_PLATFORM_GOOGLE_AUTH_PROVIDER_H_
     18 
     19 #include <memory>
     20 #include "tensorflow/core/platform/cloud/auth_provider.h"
     21 #include "tensorflow/core/platform/cloud/oauth_client.h"
     22 #include "tensorflow/core/platform/mutex.h"
     23 #include "tensorflow/core/platform/thread_annotations.h"
     24 
     25 namespace tensorflow {
     26 
     27 /// Implementation based on Google Application Default Credentials.
     28 class GoogleAuthProvider : public AuthProvider {
     29  public:
     30   GoogleAuthProvider();
     31   explicit GoogleAuthProvider(
     32       std::unique_ptr<OAuthClient> oauth_client,
     33       std::unique_ptr<HttpRequest::Factory> http_request_factory, Env* env,
     34       int64 initial_retry_delay_usec);
     35   virtual ~GoogleAuthProvider() {}
     36 
     37   /// \brief Returns the short-term authentication bearer token.
     38   ///
     39   /// Safe for concurrent use by multiple threads.
     40   Status GetToken(string* token) override;
     41 
     42  private:
     43   /// \brief Gets the bearer token from files.
     44   ///
     45   /// Tries the file from $GOOGLE_APPLICATION_CREDENTIALS and the
     46   /// standard gcloud tool's location.
     47   Status GetTokenFromFiles() EXCLUSIVE_LOCKS_REQUIRED(mu_);
     48 
     49   /// Gets the bearer token from Google Compute Engine environment.
     50   Status GetTokenFromGce() EXCLUSIVE_LOCKS_REQUIRED(mu_);
     51 
     52   /// Gets the bearer token from the systen env variable, for testing purposes.
     53   Status GetTokenForTesting() EXCLUSIVE_LOCKS_REQUIRED(mu_);
     54 
     55   std::unique_ptr<OAuthClient> oauth_client_;
     56   std::unique_ptr<HttpRequest::Factory> http_request_factory_;
     57   Env* env_;
     58   mutex mu_;
     59   string current_token_ GUARDED_BY(mu_);
     60   uint64 expiration_timestamp_sec_ GUARDED_BY(mu_) = 0;
     61   // The initial delay for exponential backoffs when retrying failed calls.
     62   const int64 initial_retry_delay_usec_;
     63   TF_DISALLOW_COPY_AND_ASSIGN(GoogleAuthProvider);
     64 };
     65 
     66 }  // namespace tensorflow
     67 
     68 #endif  // TENSORFLOW_CORE_PLATFORM_GOOGLE_AUTH_PROVIDER_H_
     69