Home | History | Annotate | Download | only in chrome_frame
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "chrome_frame/dll_redirector.h"
      6 
      7 #include <aclapi.h>
      8 #include <atlbase.h>
      9 #include <atlsecurity.h>
     10 #include <sddl.h>
     11 
     12 #include "base/file_version_info.h"
     13 #include "base/files/file_path.h"
     14 #include "base/logging.h"
     15 #include "base/memory/shared_memory.h"
     16 #include "base/path_service.h"
     17 #include "base/strings/string_util.h"
     18 #include "base/strings/utf_string_conversions.h"
     19 #include "base/version.h"
     20 #include "base/win/windows_version.h"
     21 #include "chrome_frame/utils.h"
     22 
     23 const wchar_t kSharedMemoryName[] = L"ChromeFrameVersionBeacon_";
     24 const uint32 kSharedMemorySize = 128;
     25 const uint32 kSharedMemoryLockTimeoutMs = 1000;
     26 
     27 // static
     28 DllRedirector::DllRedirector() : first_module_handle_(NULL) {
     29   // TODO(robertshield): Allow for overrides to be taken from the environment.
     30   std::wstring beacon_name(kSharedMemoryName);
     31   beacon_name += GetHostProcessName(false);
     32   shared_memory_.reset(new base::SharedMemory(beacon_name));
     33   shared_memory_name_ = WideToUTF8(beacon_name);
     34 }
     35 
     36 DllRedirector::DllRedirector(const char* shared_memory_name)
     37     : shared_memory_name_(shared_memory_name), first_module_handle_(NULL) {
     38   shared_memory_.reset(new base::SharedMemory(ASCIIToWide(shared_memory_name)));
     39 }
     40 
     41 DllRedirector::~DllRedirector() {
     42   if (first_module_handle_) {
     43     if (first_module_handle_ != reinterpret_cast<HMODULE>(&__ImageBase)) {
     44       FreeLibrary(first_module_handle_);
     45     } else {
     46       NOTREACHED() << "Error, DllRedirector attempting to free self.";
     47     }
     48 
     49     first_module_handle_ = NULL;
     50   }
     51   UnregisterAsFirstCFModule();
     52 }
     53 
     54 // static
     55 DllRedirector* DllRedirector::GetInstance() {
     56   return Singleton<DllRedirector>::get();
     57 }
     58 
     59 bool DllRedirector::BuildSecurityAttributesForLock(
     60     CSecurityAttributes* sec_attr) {
     61   DCHECK(sec_attr);
     62   if (base::win::GetVersion() < base::win::VERSION_VISTA) {
     63     // Don't bother with changing ACLs on pre-vista.
     64     return false;
     65   }
     66 
     67   bool success = false;
     68 
     69   // Fill out the rest of the security descriptor from the process token.
     70   CAccessToken token;
     71   if (token.GetProcessToken(TOKEN_QUERY)) {
     72     CSecurityDesc security_desc;
     73     // Set the SACL from an SDDL string that allows access to low-integrity
     74     // processes. See http://msdn.microsoft.com/en-us/library/bb625958.aspx.
     75     if (security_desc.FromString(L"S:(ML;;NW;;;LW)")) {
     76       CSid sid_owner;
     77       if (token.GetOwner(&sid_owner)) {
     78         security_desc.SetOwner(sid_owner);
     79       } else {
     80         NOTREACHED() << "Could not get owner.";
     81       }
     82       CSid sid_group;
     83       if (token.GetPrimaryGroup(&sid_group)) {
     84         security_desc.SetGroup(sid_group);
     85       } else {
     86         NOTREACHED() << "Could not get group.";
     87       }
     88       CDacl dacl;
     89       if (token.GetDefaultDacl(&dacl)) {
     90         // Add an access control entry mask for the current user.
     91         // This is what grants this user access from lower integrity levels.
     92         CSid sid_user;
     93         if (token.GetUser(&sid_user)) {
     94           success = dacl.AddAllowedAce(sid_user, MUTEX_ALL_ACCESS);
     95           security_desc.SetDacl(dacl);
     96           sec_attr->Set(security_desc);
     97         }
     98       }
     99     }
    100   }
    101 
    102   return success;
    103 }
    104 
    105 bool DllRedirector::SetFileMappingToReadOnly(base::SharedMemoryHandle mapping) {
    106   bool success = false;
    107 
    108   CAccessToken token;
    109   if (token.GetProcessToken(TOKEN_QUERY)) {
    110     CSid sid_user;
    111     if (token.GetUser(&sid_user)) {
    112       CDacl dacl;
    113       dacl.AddAllowedAce(sid_user, STANDARD_RIGHTS_READ | FILE_MAP_READ);
    114       success = AtlSetDacl(mapping, SE_KERNEL_OBJECT, dacl);
    115     }
    116   }
    117 
    118   return success;
    119 }
    120 
    121 
    122 bool DllRedirector::RegisterAsFirstCFModule() {
    123   DCHECK(first_module_handle_ == NULL);
    124 
    125   // Build our own file version outside of the lock:
    126   scoped_ptr<Version> our_version(GetCurrentModuleVersion());
    127 
    128   // We sadly can't use the autolock here since we want to have a timeout.
    129   // Be careful not to return while holding the lock. Also, attempt to do as
    130   // little as possible while under this lock.
    131 
    132   bool lock_acquired = false;
    133   CSecurityAttributes sec_attr;
    134   if (base::win::GetVersion() >= base::win::VERSION_VISTA &&
    135       BuildSecurityAttributesForLock(&sec_attr)) {
    136     // On vista and above, we need to explicitly allow low integrity access
    137     // to our objects. On XP, we don't bother.
    138     lock_acquired = shared_memory_->Lock(kSharedMemoryLockTimeoutMs, &sec_attr);
    139   } else {
    140     lock_acquired = shared_memory_->Lock(kSharedMemoryLockTimeoutMs, NULL);
    141   }
    142 
    143   if (!lock_acquired) {
    144     // We couldn't get the lock in a reasonable amount of time, so fall
    145     // back to loading our current version. We return true to indicate that the
    146     // caller should not attempt to delegate to an already loaded version.
    147     dll_version_.swap(our_version);
    148     return true;
    149   }
    150 
    151   bool created_beacon = true;
    152   bool result = shared_memory_->CreateNamed(shared_memory_name_.c_str(),
    153                                             false,  // open_existing
    154                                             kSharedMemorySize);
    155 
    156   if (result) {
    157     // We created the beacon, now we need to mutate the security attributes
    158     // on the shared memory to allow read-only access and let low-integrity
    159     // processes open it. This will fail on FAT32 file systems.
    160     if (!SetFileMappingToReadOnly(shared_memory_->handle())) {
    161       DLOG(ERROR) << "Failed to set file mapping permissions.";
    162     }
    163   } else {
    164     created_beacon = false;
    165 
    166     // We failed to create the shared memory segment, suggesting it may already
    167     // exist: try to create it read-only.
    168     result = shared_memory_->Open(shared_memory_name_.c_str(),
    169                                   true /* read_only */);
    170   }
    171 
    172   if (result) {
    173     // Map in the whole thing.
    174     result = shared_memory_->Map(0);
    175     DCHECK(shared_memory_->memory());
    176 
    177     if (result) {
    178       // Either write our own version number or read it in if it was already
    179       // present in the shared memory section.
    180       if (created_beacon) {
    181         dll_version_.swap(our_version);
    182 
    183         lstrcpynA(reinterpret_cast<char*>(shared_memory_->memory()),
    184                   dll_version_->GetString().c_str(),
    185                   std::min(kSharedMemorySize,
    186                            dll_version_->GetString().length() + 1));
    187       } else {
    188         char buffer[kSharedMemorySize] = {0};
    189         memcpy(buffer, shared_memory_->memory(), kSharedMemorySize - 1);
    190         dll_version_.reset(new Version(buffer));
    191 
    192         if (!dll_version_->IsValid() ||
    193             dll_version_->Equals(*our_version.get())) {
    194           // If we either couldn't parse a valid version out of the shared
    195           // memory or we did parse a version and it is the same as our own,
    196           // then pretend we're first in to avoid trying to load any other DLLs.
    197           dll_version_.reset(our_version.release());
    198           created_beacon = true;
    199         }
    200       }
    201     } else {
    202       NOTREACHED() << "Failed to map in version beacon.";
    203     }
    204   } else {
    205     NOTREACHED() << "Could not create file mapping for version beacon, gle: "
    206                  << ::GetLastError();
    207   }
    208 
    209   // Matching Unlock.
    210   shared_memory_->Unlock();
    211 
    212   return created_beacon;
    213 }
    214 
    215 void DllRedirector::UnregisterAsFirstCFModule() {
    216   if (base::SharedMemory::IsHandleValid(shared_memory_->handle())) {
    217     bool lock_acquired = shared_memory_->Lock(kSharedMemoryLockTimeoutMs, NULL);
    218     if (lock_acquired) {
    219       // Free our handles. The last closed handle SHOULD result in it being
    220       // deleted.
    221       shared_memory_->Close();
    222       shared_memory_->Unlock();
    223     }
    224   }
    225 }
    226 
    227 LPFNGETCLASSOBJECT DllRedirector::GetDllGetClassObjectPtr() {
    228   HMODULE first_module_handle = GetFirstModule();
    229 
    230   LPFNGETCLASSOBJECT proc_ptr = NULL;
    231   if (first_module_handle) {
    232     proc_ptr = reinterpret_cast<LPFNGETCLASSOBJECT>(
    233         GetProcAddress(first_module_handle, "DllGetClassObject"));
    234     DPLOG_IF(ERROR, !proc_ptr) << "DllRedirector: Could not get address of "
    235                                   "DllGetClassObject from first loaded module.";
    236   }
    237 
    238   return proc_ptr;
    239 }
    240 
    241 Version* DllRedirector::GetCurrentModuleVersion() {
    242   scoped_ptr<FileVersionInfo> file_version_info(
    243       FileVersionInfo::CreateFileVersionInfoForCurrentModule());
    244   DCHECK(file_version_info.get());
    245 
    246   scoped_ptr<Version> current_version;
    247   if (file_version_info.get()) {
    248      current_version.reset(
    249          new Version(WideToASCII(file_version_info->file_version())));
    250     DCHECK(current_version->IsValid());
    251   }
    252 
    253   return current_version.release();
    254 }
    255 
    256 HMODULE DllRedirector::GetFirstModule() {
    257   DCHECK(dll_version_.get())
    258       << "Error: Did you call RegisterAsFirstCFModule() first?";
    259 
    260   if (first_module_handle_ == NULL) {
    261     first_module_handle_ = LoadVersionedModule(dll_version_.get());
    262   }
    263 
    264   if (first_module_handle_ == reinterpret_cast<HMODULE>(&__ImageBase)) {
    265     NOTREACHED() << "Should not be loading own version.";
    266     first_module_handle_ = NULL;
    267   }
    268 
    269   return first_module_handle_;
    270 }
    271 
    272 HMODULE DllRedirector::LoadVersionedModule(Version* version) {
    273   DCHECK(version);
    274 
    275   HMODULE hmodule = NULL;
    276   wchar_t system_buffer[MAX_PATH];
    277   HMODULE this_module = reinterpret_cast<HMODULE>(&__ImageBase);
    278   system_buffer[0] = 0;
    279   if (GetModuleFileName(this_module, system_buffer,
    280                         arraysize(system_buffer)) != 0) {
    281     base::FilePath module_path(system_buffer);
    282 
    283     // For a module located in
    284     // Foo\XXXXXXXXX\<module>.dll, load
    285     // Foo\<version>\<module>.dll:
    286     base::FilePath module_name = module_path.BaseName();
    287     module_path = module_path.DirName()
    288                              .DirName()
    289                              .Append(ASCIIToWide(version->GetString()))
    290                              .Append(module_name);
    291 
    292     hmodule = LoadLibrary(module_path.value().c_str());
    293     if (hmodule == NULL) {
    294       DPLOG(ERROR) << "Could not load reported module version "
    295                    << version->GetString();
    296     }
    297   } else {
    298     DPLOG(FATAL) << "Failed to get module file name";
    299   }
    300   return hmodule;
    301 }
    302