Home | History | Annotate | Download | only in http
      1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 // See "SSPI Sample Application" at
      6 // http://msdn.microsoft.com/en-us/library/aa918273.aspx
      7 
      8 #include "net/http/http_auth_sspi_win.h"
      9 
     10 #include "base/base64.h"
     11 #include "base/logging.h"
     12 #include "base/string_util.h"
     13 #include "net/base/net_errors.h"
     14 #include "net/base/net_util.h"
     15 #include "net/http/http_auth.h"
     16 
     17 namespace net {
     18 
     19 HttpAuthSSPI::HttpAuthSSPI(const std::string& scheme,
     20                            SEC_WCHAR* security_package)
     21     : scheme_(scheme),
     22       security_package_(security_package),
     23       max_token_length_(0) {
     24   SecInvalidateHandle(&cred_);
     25   SecInvalidateHandle(&ctxt_);
     26 }
     27 
     28 HttpAuthSSPI::~HttpAuthSSPI() {
     29   ResetSecurityContext();
     30   if (SecIsValidHandle(&cred_)) {
     31     FreeCredentialsHandle(&cred_);
     32     SecInvalidateHandle(&cred_);
     33   }
     34 }
     35 
     36 bool HttpAuthSSPI::NeedsIdentity() const {
     37   return decoded_server_auth_token_.empty();
     38 }
     39 
     40 bool HttpAuthSSPI::IsFinalRound() const {
     41   return !decoded_server_auth_token_.empty();
     42 }
     43 
     44 void HttpAuthSSPI::ResetSecurityContext() {
     45   if (SecIsValidHandle(&ctxt_)) {
     46     DeleteSecurityContext(&ctxt_);
     47     SecInvalidateHandle(&ctxt_);
     48   }
     49 }
     50 
     51 bool HttpAuthSSPI::ParseChallenge(std::string::const_iterator challenge_begin,
     52                                   std::string::const_iterator challenge_end) {
     53   // Verify the challenge's auth-scheme.
     54   HttpAuth::ChallengeTokenizer challenge_tok(challenge_begin, challenge_end);
     55   if (!challenge_tok.valid() ||
     56       !LowerCaseEqualsASCII(challenge_tok.scheme(),
     57                             StringToLowerASCII(scheme_).c_str()))
     58     return false;
     59   // Extract the auth-data.  We can't use challenge_tok.GetNext() because
     60   // auth-data is base64-encoded and may contain '=' padding at the end,
     61   // which would be mistaken for a name=value pair.
     62   challenge_begin += scheme_.length();  // Skip over scheme name.
     63   HttpUtil::TrimLWS(&challenge_begin, &challenge_end);
     64   std::string encoded_auth_token(challenge_begin, challenge_end);
     65   int encoded_length = encoded_auth_token.length();
     66   // Strip off any padding.
     67   // (See https://bugzilla.mozilla.org/show_bug.cgi?id=230351.)
     68   //
     69   // Our base64 decoder requires that the length be a multiple of 4.
     70   while (encoded_length > 0 && encoded_length % 4 != 0 &&
     71          encoded_auth_token[encoded_length - 1] == '=')
     72     encoded_length--;
     73   encoded_auth_token.erase(encoded_length);
     74 
     75   std::string decoded_auth_token;
     76   bool rv = base::Base64Decode(encoded_auth_token, &decoded_auth_token);
     77   if (rv) {
     78     decoded_server_auth_token_ = decoded_auth_token;
     79   }
     80   return rv;
     81 }
     82 
     83 int HttpAuthSSPI::GenerateCredentials(const std::wstring& username,
     84                                       const std::wstring& password,
     85                                       const GURL& origin,
     86                                       const HttpRequestInfo* request,
     87                                       const ProxyInfo* proxy,
     88                                       std::string* out_credentials) {
     89   // |username| may be in the form "DOMAIN\user".  Parse it into the two
     90   // components.
     91   std::wstring domain;
     92   std::wstring user;
     93   SplitDomainAndUser(username, &domain, &user);
     94 
     95   // Initial challenge.
     96   if (!IsFinalRound()) {
     97     int rv = OnFirstRound(domain, user, password);
     98     if (rv != OK)
     99       return rv;
    100   }
    101 
    102   void* out_buf;
    103   int out_buf_len;
    104   int rv = GetNextSecurityToken(
    105       origin,
    106       static_cast<void *>(const_cast<char *>(
    107           decoded_server_auth_token_.c_str())),
    108       decoded_server_auth_token_.length(),
    109       &out_buf,
    110       &out_buf_len);
    111   if (rv != OK)
    112     return rv;
    113 
    114   // Base64 encode data in output buffer and prepend the scheme.
    115   std::string encode_input(static_cast<char*>(out_buf), out_buf_len);
    116   std::string encode_output;
    117   bool ok = base::Base64Encode(encode_input, &encode_output);
    118   // OK, we are done with |out_buf|
    119   free(out_buf);
    120   if (!ok)
    121     return rv;
    122   *out_credentials = scheme_ + " " + encode_output;
    123   return OK;
    124 }
    125 
    126 int HttpAuthSSPI::OnFirstRound(const std::wstring& domain,
    127                                const std::wstring& user,
    128                                const std::wstring& password) {
    129   int rv = DetermineMaxTokenLength(security_package_, &max_token_length_);
    130   if (rv != OK) {
    131     return rv;
    132   }
    133   rv = AcquireCredentials(security_package_, domain, user, password, &cred_);
    134   return rv;
    135 }
    136 
    137 int HttpAuthSSPI::GetNextSecurityToken(
    138     const GURL& origin,
    139     const void * in_token,
    140     int in_token_len,
    141     void** out_token,
    142     int* out_token_len) {
    143   SECURITY_STATUS status;
    144   TimeStamp expiry;
    145 
    146   DWORD ctxt_attr;
    147   CtxtHandle* ctxt_ptr;
    148   SecBufferDesc in_buffer_desc, out_buffer_desc;
    149   SecBufferDesc* in_buffer_desc_ptr;
    150   SecBuffer in_buffer, out_buffer;
    151 
    152   if (in_token_len > 0) {
    153     // Prepare input buffer.
    154     in_buffer_desc.ulVersion = SECBUFFER_VERSION;
    155     in_buffer_desc.cBuffers = 1;
    156     in_buffer_desc.pBuffers = &in_buffer;
    157     in_buffer.BufferType = SECBUFFER_TOKEN;
    158     in_buffer.cbBuffer = in_token_len;
    159     in_buffer.pvBuffer = const_cast<void*>(in_token);
    160     ctxt_ptr = &ctxt_;
    161     in_buffer_desc_ptr = &in_buffer_desc;
    162   } else {
    163     // If there is no input token, then we are starting a new authentication
    164     // sequence.  If we have already initialized our security context, then
    165     // we're incorrectly reusing the auth handler for a new sequence.
    166     if (SecIsValidHandle(&ctxt_)) {
    167       LOG(ERROR) << "Cannot restart authentication sequence";
    168       return ERR_UNEXPECTED;
    169     }
    170     ctxt_ptr = NULL;
    171     in_buffer_desc_ptr = NULL;
    172   }
    173 
    174   // Prepare output buffer.
    175   out_buffer_desc.ulVersion = SECBUFFER_VERSION;
    176   out_buffer_desc.cBuffers = 1;
    177   out_buffer_desc.pBuffers = &out_buffer;
    178   out_buffer.BufferType = SECBUFFER_TOKEN;
    179   out_buffer.cbBuffer = max_token_length_;
    180   out_buffer.pvBuffer = malloc(out_buffer.cbBuffer);
    181   if (!out_buffer.pvBuffer)
    182     return ERR_OUT_OF_MEMORY;
    183 
    184   // The service principal name of the destination server.  See
    185   // http://msdn.microsoft.com/en-us/library/ms677949%28VS.85%29.aspx
    186   std::wstring target(L"HTTP/");
    187   target.append(ASCIIToWide(GetHostAndPort(origin)));
    188   wchar_t* target_name = const_cast<wchar_t*>(target.c_str());
    189 
    190   // This returns a token that is passed to the remote server.
    191   status = InitializeSecurityContext(&cred_,  // phCredential
    192                                      ctxt_ptr,  // phContext
    193                                      target_name,  // pszTargetName
    194                                      0,  // fContextReq
    195                                      0,  // Reserved1 (must be 0)
    196                                      SECURITY_NATIVE_DREP,  // TargetDataRep
    197                                      in_buffer_desc_ptr,  // pInput
    198                                      0,  // Reserved2 (must be 0)
    199                                      &ctxt_,  // phNewContext
    200                                      &out_buffer_desc,  // pOutput
    201                                      &ctxt_attr,  // pfContextAttr
    202                                      &expiry);  // ptsExpiry
    203   // On success, the function returns SEC_I_CONTINUE_NEEDED on the first call
    204   // and SEC_E_OK on the second call.  On failure, the function returns an
    205   // error code.
    206   if (status != SEC_I_CONTINUE_NEEDED && status != SEC_E_OK) {
    207     LOG(ERROR) << "InitializeSecurityContext failed: " << status;
    208     ResetSecurityContext();
    209     free(out_buffer.pvBuffer);
    210     return ERR_UNEXPECTED;  // TODO(wtc): map error code.
    211   }
    212   if (!out_buffer.cbBuffer) {
    213     free(out_buffer.pvBuffer);
    214     out_buffer.pvBuffer = NULL;
    215   }
    216   *out_token = out_buffer.pvBuffer;
    217   *out_token_len = out_buffer.cbBuffer;
    218   return OK;
    219 }
    220 
    221 void SplitDomainAndUser(const std::wstring& combined,
    222                         std::wstring* domain,
    223                         std::wstring* user) {
    224   size_t backslash_idx = combined.find(L'\\');
    225   if (backslash_idx == std::wstring::npos) {
    226     domain->clear();
    227     *user = combined;
    228   } else {
    229     *domain = combined.substr(0, backslash_idx);
    230     *user = combined.substr(backslash_idx + 1);
    231   }
    232 }
    233 
    234 int DetermineMaxTokenLength(const std::wstring& package,
    235                             ULONG* max_token_length) {
    236   PSecPkgInfo pkg_info;
    237   SECURITY_STATUS status = QuerySecurityPackageInfo(
    238       const_cast<wchar_t *>(package.c_str()), &pkg_info);
    239   if (status != SEC_E_OK) {
    240     LOG(ERROR) << "Security package " << package << " not found";
    241     return ERR_UNEXPECTED;
    242   }
    243   *max_token_length = pkg_info->cbMaxToken;
    244   FreeContextBuffer(pkg_info);
    245   return OK;
    246 }
    247 
    248 int AcquireCredentials(const SEC_WCHAR* package,
    249                        const std::wstring& domain,
    250                        const std::wstring& user,
    251                        const std::wstring& password,
    252                        CredHandle* cred) {
    253   SEC_WINNT_AUTH_IDENTITY identity;
    254   identity.Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE;
    255   identity.User =
    256       reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(user.c_str()));
    257   identity.UserLength = user.size();
    258   identity.Domain =
    259       reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(domain.c_str()));
    260   identity.DomainLength = domain.size();
    261   identity.Password =
    262       reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(password.c_str()));
    263   identity.PasswordLength = password.size();
    264 
    265   TimeStamp expiry;
    266 
    267   // Pass the username/password to get the credentials handle.
    268   // Note: If the 5th argument is NULL, it uses the default cached credentials
    269   // for the logged in user, which can be used for single sign-on.
    270   SECURITY_STATUS status = AcquireCredentialsHandle(
    271       NULL,  // pszPrincipal
    272       const_cast<SEC_WCHAR*>(package),  // pszPackage
    273       SECPKG_CRED_OUTBOUND,  // fCredentialUse
    274       NULL,  // pvLogonID
    275       &identity,  // pAuthData
    276       NULL,  // pGetKeyFn (not used)
    277       NULL,  // pvGetKeyArgument (not used)
    278       cred,  // phCredential
    279       &expiry);  // ptsExpiry
    280 
    281   if (status != SEC_E_OK)
    282     return ERR_UNEXPECTED;
    283   return OK;
    284 }
    285 
    286 }  // namespace net
    287