1 // Copyright (c) 2009 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 // See "SSPI Sample Application" at 6 // http://msdn.microsoft.com/en-us/library/aa918273.aspx 7 8 #include "net/http/http_auth_sspi_win.h" 9 10 #include "base/base64.h" 11 #include "base/logging.h" 12 #include "base/string_util.h" 13 #include "net/base/net_errors.h" 14 #include "net/base/net_util.h" 15 #include "net/http/http_auth.h" 16 17 namespace net { 18 19 HttpAuthSSPI::HttpAuthSSPI(const std::string& scheme, 20 SEC_WCHAR* security_package) 21 : scheme_(scheme), 22 security_package_(security_package), 23 max_token_length_(0) { 24 SecInvalidateHandle(&cred_); 25 SecInvalidateHandle(&ctxt_); 26 } 27 28 HttpAuthSSPI::~HttpAuthSSPI() { 29 ResetSecurityContext(); 30 if (SecIsValidHandle(&cred_)) { 31 FreeCredentialsHandle(&cred_); 32 SecInvalidateHandle(&cred_); 33 } 34 } 35 36 bool HttpAuthSSPI::NeedsIdentity() const { 37 return decoded_server_auth_token_.empty(); 38 } 39 40 bool HttpAuthSSPI::IsFinalRound() const { 41 return !decoded_server_auth_token_.empty(); 42 } 43 44 void HttpAuthSSPI::ResetSecurityContext() { 45 if (SecIsValidHandle(&ctxt_)) { 46 DeleteSecurityContext(&ctxt_); 47 SecInvalidateHandle(&ctxt_); 48 } 49 } 50 51 bool HttpAuthSSPI::ParseChallenge(std::string::const_iterator challenge_begin, 52 std::string::const_iterator challenge_end) { 53 // Verify the challenge's auth-scheme. 54 HttpAuth::ChallengeTokenizer challenge_tok(challenge_begin, challenge_end); 55 if (!challenge_tok.valid() || 56 !LowerCaseEqualsASCII(challenge_tok.scheme(), 57 StringToLowerASCII(scheme_).c_str())) 58 return false; 59 // Extract the auth-data. We can't use challenge_tok.GetNext() because 60 // auth-data is base64-encoded and may contain '=' padding at the end, 61 // which would be mistaken for a name=value pair. 62 challenge_begin += scheme_.length(); // Skip over scheme name. 63 HttpUtil::TrimLWS(&challenge_begin, &challenge_end); 64 std::string encoded_auth_token(challenge_begin, challenge_end); 65 int encoded_length = encoded_auth_token.length(); 66 // Strip off any padding. 67 // (See https://bugzilla.mozilla.org/show_bug.cgi?id=230351.) 68 // 69 // Our base64 decoder requires that the length be a multiple of 4. 70 while (encoded_length > 0 && encoded_length % 4 != 0 && 71 encoded_auth_token[encoded_length - 1] == '=') 72 encoded_length--; 73 encoded_auth_token.erase(encoded_length); 74 75 std::string decoded_auth_token; 76 bool rv = base::Base64Decode(encoded_auth_token, &decoded_auth_token); 77 if (rv) { 78 decoded_server_auth_token_ = decoded_auth_token; 79 } 80 return rv; 81 } 82 83 int HttpAuthSSPI::GenerateCredentials(const std::wstring& username, 84 const std::wstring& password, 85 const GURL& origin, 86 const HttpRequestInfo* request, 87 const ProxyInfo* proxy, 88 std::string* out_credentials) { 89 // |username| may be in the form "DOMAIN\user". Parse it into the two 90 // components. 91 std::wstring domain; 92 std::wstring user; 93 SplitDomainAndUser(username, &domain, &user); 94 95 // Initial challenge. 96 if (!IsFinalRound()) { 97 int rv = OnFirstRound(domain, user, password); 98 if (rv != OK) 99 return rv; 100 } 101 102 void* out_buf; 103 int out_buf_len; 104 int rv = GetNextSecurityToken( 105 origin, 106 static_cast<void *>(const_cast<char *>( 107 decoded_server_auth_token_.c_str())), 108 decoded_server_auth_token_.length(), 109 &out_buf, 110 &out_buf_len); 111 if (rv != OK) 112 return rv; 113 114 // Base64 encode data in output buffer and prepend the scheme. 115 std::string encode_input(static_cast<char*>(out_buf), out_buf_len); 116 std::string encode_output; 117 bool ok = base::Base64Encode(encode_input, &encode_output); 118 // OK, we are done with |out_buf| 119 free(out_buf); 120 if (!ok) 121 return rv; 122 *out_credentials = scheme_ + " " + encode_output; 123 return OK; 124 } 125 126 int HttpAuthSSPI::OnFirstRound(const std::wstring& domain, 127 const std::wstring& user, 128 const std::wstring& password) { 129 int rv = DetermineMaxTokenLength(security_package_, &max_token_length_); 130 if (rv != OK) { 131 return rv; 132 } 133 rv = AcquireCredentials(security_package_, domain, user, password, &cred_); 134 return rv; 135 } 136 137 int HttpAuthSSPI::GetNextSecurityToken( 138 const GURL& origin, 139 const void * in_token, 140 int in_token_len, 141 void** out_token, 142 int* out_token_len) { 143 SECURITY_STATUS status; 144 TimeStamp expiry; 145 146 DWORD ctxt_attr; 147 CtxtHandle* ctxt_ptr; 148 SecBufferDesc in_buffer_desc, out_buffer_desc; 149 SecBufferDesc* in_buffer_desc_ptr; 150 SecBuffer in_buffer, out_buffer; 151 152 if (in_token_len > 0) { 153 // Prepare input buffer. 154 in_buffer_desc.ulVersion = SECBUFFER_VERSION; 155 in_buffer_desc.cBuffers = 1; 156 in_buffer_desc.pBuffers = &in_buffer; 157 in_buffer.BufferType = SECBUFFER_TOKEN; 158 in_buffer.cbBuffer = in_token_len; 159 in_buffer.pvBuffer = const_cast<void*>(in_token); 160 ctxt_ptr = &ctxt_; 161 in_buffer_desc_ptr = &in_buffer_desc; 162 } else { 163 // If there is no input token, then we are starting a new authentication 164 // sequence. If we have already initialized our security context, then 165 // we're incorrectly reusing the auth handler for a new sequence. 166 if (SecIsValidHandle(&ctxt_)) { 167 LOG(ERROR) << "Cannot restart authentication sequence"; 168 return ERR_UNEXPECTED; 169 } 170 ctxt_ptr = NULL; 171 in_buffer_desc_ptr = NULL; 172 } 173 174 // Prepare output buffer. 175 out_buffer_desc.ulVersion = SECBUFFER_VERSION; 176 out_buffer_desc.cBuffers = 1; 177 out_buffer_desc.pBuffers = &out_buffer; 178 out_buffer.BufferType = SECBUFFER_TOKEN; 179 out_buffer.cbBuffer = max_token_length_; 180 out_buffer.pvBuffer = malloc(out_buffer.cbBuffer); 181 if (!out_buffer.pvBuffer) 182 return ERR_OUT_OF_MEMORY; 183 184 // The service principal name of the destination server. See 185 // http://msdn.microsoft.com/en-us/library/ms677949%28VS.85%29.aspx 186 std::wstring target(L"HTTP/"); 187 target.append(ASCIIToWide(GetHostAndPort(origin))); 188 wchar_t* target_name = const_cast<wchar_t*>(target.c_str()); 189 190 // This returns a token that is passed to the remote server. 191 status = InitializeSecurityContext(&cred_, // phCredential 192 ctxt_ptr, // phContext 193 target_name, // pszTargetName 194 0, // fContextReq 195 0, // Reserved1 (must be 0) 196 SECURITY_NATIVE_DREP, // TargetDataRep 197 in_buffer_desc_ptr, // pInput 198 0, // Reserved2 (must be 0) 199 &ctxt_, // phNewContext 200 &out_buffer_desc, // pOutput 201 &ctxt_attr, // pfContextAttr 202 &expiry); // ptsExpiry 203 // On success, the function returns SEC_I_CONTINUE_NEEDED on the first call 204 // and SEC_E_OK on the second call. On failure, the function returns an 205 // error code. 206 if (status != SEC_I_CONTINUE_NEEDED && status != SEC_E_OK) { 207 LOG(ERROR) << "InitializeSecurityContext failed: " << status; 208 ResetSecurityContext(); 209 free(out_buffer.pvBuffer); 210 return ERR_UNEXPECTED; // TODO(wtc): map error code. 211 } 212 if (!out_buffer.cbBuffer) { 213 free(out_buffer.pvBuffer); 214 out_buffer.pvBuffer = NULL; 215 } 216 *out_token = out_buffer.pvBuffer; 217 *out_token_len = out_buffer.cbBuffer; 218 return OK; 219 } 220 221 void SplitDomainAndUser(const std::wstring& combined, 222 std::wstring* domain, 223 std::wstring* user) { 224 size_t backslash_idx = combined.find(L'\\'); 225 if (backslash_idx == std::wstring::npos) { 226 domain->clear(); 227 *user = combined; 228 } else { 229 *domain = combined.substr(0, backslash_idx); 230 *user = combined.substr(backslash_idx + 1); 231 } 232 } 233 234 int DetermineMaxTokenLength(const std::wstring& package, 235 ULONG* max_token_length) { 236 PSecPkgInfo pkg_info; 237 SECURITY_STATUS status = QuerySecurityPackageInfo( 238 const_cast<wchar_t *>(package.c_str()), &pkg_info); 239 if (status != SEC_E_OK) { 240 LOG(ERROR) << "Security package " << package << " not found"; 241 return ERR_UNEXPECTED; 242 } 243 *max_token_length = pkg_info->cbMaxToken; 244 FreeContextBuffer(pkg_info); 245 return OK; 246 } 247 248 int AcquireCredentials(const SEC_WCHAR* package, 249 const std::wstring& domain, 250 const std::wstring& user, 251 const std::wstring& password, 252 CredHandle* cred) { 253 SEC_WINNT_AUTH_IDENTITY identity; 254 identity.Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE; 255 identity.User = 256 reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(user.c_str())); 257 identity.UserLength = user.size(); 258 identity.Domain = 259 reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(domain.c_str())); 260 identity.DomainLength = domain.size(); 261 identity.Password = 262 reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(password.c_str())); 263 identity.PasswordLength = password.size(); 264 265 TimeStamp expiry; 266 267 // Pass the username/password to get the credentials handle. 268 // Note: If the 5th argument is NULL, it uses the default cached credentials 269 // for the logged in user, which can be used for single sign-on. 270 SECURITY_STATUS status = AcquireCredentialsHandle( 271 NULL, // pszPrincipal 272 const_cast<SEC_WCHAR*>(package), // pszPackage 273 SECPKG_CRED_OUTBOUND, // fCredentialUse 274 NULL, // pvLogonID 275 &identity, // pAuthData 276 NULL, // pGetKeyFn (not used) 277 NULL, // pvGetKeyArgument (not used) 278 cred, // phCredential 279 &expiry); // ptsExpiry 280 281 if (status != SEC_E_OK) 282 return ERR_UNEXPECTED; 283 return OK; 284 } 285 286 } // namespace net 287