Home | History | Annotate | Download | only in chrome_frame
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "chrome_frame/http_negotiate.h"
      6 
      7 #include <atlbase.h>
      8 #include <atlcom.h>
      9 #include <htiframe.h>
     10 
     11 #include "base/logging.h"
     12 #include "base/memory/scoped_ptr.h"
     13 #include "base/strings/string_util.h"
     14 #include "base/strings/stringprintf.h"
     15 #include "base/strings/utf_string_conversions.h"
     16 #include "chrome_frame/bho.h"
     17 #include "chrome_frame/exception_barrier.h"
     18 #include "chrome_frame/html_utils.h"
     19 #include "chrome_frame/urlmon_moniker.h"
     20 #include "chrome_frame/urlmon_url_request.h"
     21 #include "chrome_frame/utils.h"
     22 #include "chrome_frame/vtable_patch_manager.h"
     23 #include "net/http/http_response_headers.h"
     24 #include "net/http/http_util.h"
     25 
// When true (the default), BeginningTransaction rewrites the User-Agent header
// of outgoing requests; when false it leaves the delegate's headers untouched.
bool HttpNegotiatePatch::modify_user_agent_ = true;
// Lower-cased name of the X-UA-Compatible HTTP header. Not referenced in the
// code visible here; presumably used by other translation units — verify.
const char kUACompatibleHttpHeader[] = "x-ua-compatible";
// Lower-cased User-Agent header name, used as the lookup key for the header
// iterators below (ExcludeFieldFromHeaders requires a lower-case field).
const char kLowerCaseUserAgent[] = "user-agent";

// From the latest urlmon.h. Symbol name prepended with LOCAL_ to
// avoid conflict (and therefore build errors) for those building with
// a newer Windows SDK.
// TODO(robertshield): Remove this once we update our SDK version.
const int LOCAL_BINDSTATUS_SERVER_MIMETYPEAVAILABLE = 54;

// Vtable slot of IHttpNegotiate::BeginningTransaction: the first method after
// IUnknown's QueryInterface/AddRef/Release (slots 0-2).
static const int kHttpNegotiateBeginningTransactionIndex = 3;

// Patch table redirecting the BeginningTransaction slot to
// HttpNegotiatePatch::BeginningTransaction. The macros presumably define the
// IHttpNegotiate_PatchInfo table consumed by PatchHttpNegotiate/Uninitialize.
BEGIN_VTABLE_PATCHES(IHttpNegotiate)
  VTABLE_PATCH_ENTRY(kHttpNegotiateBeginningTransactionIndex,
                     HttpNegotiatePatch::BeginningTransaction)
END_VTABLE_PATCHES()
     42 
     43 namespace {
     44 
     45 class SimpleBindStatusCallback : public CComObjectRootEx<CComSingleThreadModel>,
     46                                  public IBindStatusCallback {
     47  public:
     48   BEGIN_COM_MAP(SimpleBindStatusCallback)
     49     COM_INTERFACE_ENTRY(IBindStatusCallback)
     50   END_COM_MAP()
     51 
     52   // IBindStatusCallback implementation
     53   STDMETHOD(OnStartBinding)(DWORD reserved, IBinding* binding) {
     54     return E_NOTIMPL;
     55   }
     56 
     57   STDMETHOD(GetPriority)(LONG* priority) {
     58     return E_NOTIMPL;
     59   }
     60   STDMETHOD(OnLowResource)(DWORD reserved) {
     61     return E_NOTIMPL;
     62   }
     63 
     64   STDMETHOD(OnProgress)(ULONG progress, ULONG max_progress,
     65                         ULONG status_code, LPCWSTR status_text) {
     66     return E_NOTIMPL;
     67   }
     68   STDMETHOD(OnStopBinding)(HRESULT result, LPCWSTR error) {
     69     return E_NOTIMPL;
     70   }
     71 
     72   STDMETHOD(GetBindInfo)(DWORD* bind_flags, BINDINFO* bind_info) {
     73     return E_NOTIMPL;
     74   }
     75 
     76   STDMETHOD(OnDataAvailable)(DWORD flags, DWORD size, FORMATETC* formatetc,
     77     STGMEDIUM* storage) {
     78     return E_NOTIMPL;
     79   }
     80   STDMETHOD(OnObjectAvailable)(REFIID iid, IUnknown* object) {
     81     return E_NOTIMPL;
     82   }
     83 };
     84 
     85 // Returns the full user agent header from the HTTP header strings passed to
     86 // IHttpNegotiate::BeginningTransaction. Looks first in |additional_headers|
     87 // and if it can't be found there looks in |headers|.
     88 std::string GetUserAgentFromHeaders(LPCWSTR headers,
     89                                     LPCWSTR additional_headers) {
     90   using net::HttpUtil;
     91 
     92   std::string ascii_headers;
     93   if (additional_headers) {
     94     ascii_headers = WideToASCII(additional_headers);
     95   }
     96 
     97   // Extract "User-Agent" from |additional_headers| or |headers|.
     98   HttpUtil::HeadersIterator headers_iterator(ascii_headers.begin(),
     99                                              ascii_headers.end(), "\r\n");
    100   std::string user_agent_value;
    101   if (headers_iterator.AdvanceTo(kLowerCaseUserAgent)) {
    102     user_agent_value = headers_iterator.values();
    103   } else if (headers != NULL) {
    104     // See if there's a user-agent header specified in the original headers.
    105     std::string original_headers(WideToASCII(headers));
    106     HttpUtil::HeadersIterator original_it(original_headers.begin(),
    107         original_headers.end(), "\r\n");
    108     if (original_it.AdvanceTo(kLowerCaseUserAgent))
    109       user_agent_value = original_it.values();
    110   }
    111 
    112   return user_agent_value;
    113 }
    114 
    115 // Removes the named header |field| from a set of headers. |field| must be
    116 // lower-case.
    117 std::string ExcludeFieldFromHeaders(const std::string& old_headers,
    118                                     const char* field) {
    119   using net::HttpUtil;
    120   std::string new_headers;
    121   new_headers.reserve(old_headers.size());
    122   HttpUtil::HeadersIterator headers_iterator(old_headers.begin(),
    123                                              old_headers.end(), "\r\n");
    124   while (headers_iterator.GetNext()) {
    125     if (!LowerCaseEqualsASCII(headers_iterator.name_begin(),
    126                               headers_iterator.name_end(),
    127                               field)) {
    128       new_headers.append(headers_iterator.name_begin(),
    129                          headers_iterator.name_end());
    130       new_headers += ": ";
    131       new_headers.append(headers_iterator.values_begin(),
    132                          headers_iterator.values_end());
    133       new_headers += "\r\n";
    134     }
    135   }
    136 
    137   return new_headers;
    138 }
    139 
    140 std::string MutateCFUserAgentString(LPCWSTR headers,
    141                                     LPCWSTR additional_headers,
    142                                     bool add_user_agent) {
    143   std::string user_agent_value(GetUserAgentFromHeaders(headers,
    144                                                        additional_headers));
    145 
    146   // Use the default "User-Agent" if none was provided.
    147   if (user_agent_value.empty())
    148     user_agent_value = http_utils::GetDefaultUserAgent();
    149 
    150   // Now add chromeframe to it.
    151   user_agent_value = add_user_agent ?
    152       http_utils::AddChromeFrameToUserAgentValue(user_agent_value) :
    153       http_utils::RemoveChromeFrameFromUserAgentValue(user_agent_value);
    154 
    155   // Build a new set of additional headers, skipping the existing user agent
    156   // value if present.
    157   return ReplaceOrAddUserAgent(additional_headers, user_agent_value);
    158 }
    159 
    160 }  // end namespace
    161 
    162 
    163 std::string AppendCFUserAgentString(LPCWSTR headers,
    164                                     LPCWSTR additional_headers) {
    165   return MutateCFUserAgentString(headers, additional_headers, true);
    166 }
    167 
    168 
    169 // Looks for a user agent header found in |headers| or |additional_headers|
    170 // then returns |additional_headers| with a modified user agent header that does
    171 // not include the chromeframe token.
    172 std::string RemoveCFUserAgentString(LPCWSTR headers,
    173                                     LPCWSTR additional_headers) {
    174   return MutateCFUserAgentString(headers, additional_headers, false);
    175 }
    176 
    177 
    178 // Unconditionally adds the specified |user_agent_value| to the given set of
    179 // |headers|, removing any that were already there.
    180 std::string ReplaceOrAddUserAgent(LPCWSTR headers,
    181                                   const std::string& user_agent_value) {
    182   std::string new_headers;
    183   if (headers) {
    184     std::string ascii_headers(WideToASCII(headers));
    185     // Build new headers, skip the existing user agent value from
    186     // existing headers.
    187     new_headers = ExcludeFieldFromHeaders(ascii_headers, kLowerCaseUserAgent);
    188   }
    189   new_headers += "User-Agent: ";
    190   new_headers += user_agent_value;
    191   new_headers += "\r\n";
    192   return new_headers;
    193 }
    194 
    195 HttpNegotiatePatch::HttpNegotiatePatch() {
    196 }
    197 
    198 HttpNegotiatePatch::~HttpNegotiatePatch() {
    199 }
    200 
// static
// Installs the IHttpNegotiate::BeginningTransaction vtable patch. A temporary
// async bind context is created, its "_BSCB_Holder_" object parameter is
// fetched, and that object is handed to PatchHttpNegotiate.
// Returns true if the interface was already patched; otherwise returns
// SUCCEEDED(hr) for the last step attempted.
bool HttpNegotiatePatch::Initialize() {
  if (IS_PATCHED(IHttpNegotiate)) {
    DLOG(WARNING) << __FUNCTION__ << " called more than once.";
    return true;
  }
  // Use our SimpleBindStatusCallback class as we need a temporary object that
  // implements IBindStatusCallback.
  // NOTE: |request| lives on the stack, so every COM reference to it (through
  // |bind_ctx| / |bscb_holder|) must be dropped before this function returns.
  CComObjectStackEx<SimpleBindStatusCallback> request;
  base::win::ScopedComPtr<IBindCtx> bind_ctx;
  HRESULT hr = CreateAsyncBindCtx(0, &request, NULL, bind_ctx.Receive());
  DCHECK(SUCCEEDED(hr)) << "CreateAsyncBindCtx";
  if (bind_ctx) {
    base::win::ScopedComPtr<IUnknown> bscb_holder;
    // The GetObjectParam HRESULT is not checked; a NULL |bscb_holder| is
    // treated as the failure signal instead.
    bind_ctx->GetObjectParam(L"_BSCB_Holder_", bscb_holder.Receive());
    if (bscb_holder) {
      hr = PatchHttpNegotiate(bscb_holder);
    } else {
      NOTREACHED() << "Failed to get _BSCB_Holder_";
      hr = E_UNEXPECTED;
    }
    // Release the bind context explicitly while still inside this scope;
    // presumably so it drops its hold on the stack-allocated |request|
    // before that object is destroyed — confirm.
    bind_ctx.Release();
  }

  return SUCCEEDED(hr);
}
    227 
    228 // static
    229 void HttpNegotiatePatch::Uninitialize() {
    230   vtable_patch::UnpatchInterfaceMethods(IHttpNegotiate_PatchInfo);
    231 }
    232 
    233 // static
    234 HRESULT HttpNegotiatePatch::PatchHttpNegotiate(IUnknown* to_patch) {
    235   DCHECK(to_patch);
    236   DCHECK_IS_NOT_PATCHED(IHttpNegotiate);
    237 
    238   base::win::ScopedComPtr<IHttpNegotiate> http;
    239   HRESULT hr = http.QueryFrom(to_patch);
    240   if (FAILED(hr)) {
    241     hr = DoQueryService(IID_IHttpNegotiate, to_patch, http.Receive());
    242   }
    243 
    244   if (http) {
    245     hr = vtable_patch::PatchInterfaceMethods(http, IHttpNegotiate_PatchInfo);
    246     DLOG_IF(ERROR, FAILED(hr))
    247         << base::StringPrintf("HttpNegotiate patch failed 0x%08X", hr);
    248   } else {
    249     DLOG(WARNING)
    250         << base::StringPrintf("IHttpNegotiate not supported 0x%08X", hr);
    251   }
    252   return hr;
    253 }
    254 
    255 // static
    256 HRESULT HttpNegotiatePatch::BeginningTransaction(
    257     IHttpNegotiate_BeginningTransaction_Fn original, IHttpNegotiate* me,
    258     LPCWSTR url, LPCWSTR headers, DWORD reserved, LPWSTR* additional_headers) {
    259   DVLOG(1) << __FUNCTION__ << " " << url << " headers:\n" << headers;
    260 
    261   HRESULT hr = original(me, url, headers, reserved, additional_headers);
    262 
    263   if (FAILED(hr)) {
    264     DLOG(WARNING) << __FUNCTION__ << " Delegate returned an error";
    265     return hr;
    266   }
    267   if (modify_user_agent_) {
    268     std::string updated_headers;
    269 
    270     if (IsGcfDefaultRenderer() &&
    271         RendererTypeForUrl(url) == RENDERER_TYPE_CHROME_DEFAULT_RENDERER) {
    272       // Replace the user-agent header with Chrome's.
    273       updated_headers = ReplaceOrAddUserAgent(*additional_headers,
    274                                               http_utils::GetChromeUserAgent());
    275     } else if (ShouldRemoveUAForUrl(url)) {
    276       updated_headers = RemoveCFUserAgentString(headers, *additional_headers);
    277     } else {
    278       updated_headers = AppendCFUserAgentString(headers, *additional_headers);
    279     }
    280 
    281     *additional_headers = reinterpret_cast<wchar_t*>(::CoTaskMemRealloc(
    282         *additional_headers,
    283         (updated_headers.length() + 1) * sizeof(wchar_t)));
    284     lstrcpyW(*additional_headers, ASCIIToWide(updated_headers).c_str());
    285   } else {
    286     // TODO(erikwright): Remove the user agent if it is present (i.e., because
    287     // of PostPlatform setting in the registry).
    288   }
    289   return S_OK;
    290 }
    291