1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "chrome_frame/crash_reporting/nt_loader.h" 6 7 #include <tlhelp32.h> 8 #include <winnt.h> 9 10 #include "base/at_exit.h" 11 #include "base/bind.h" 12 #include "base/bind_helpers.h" 13 #include "base/environment.h" 14 #include "base/memory/scoped_ptr.h" 15 #include "base/message_loop/message_loop.h" 16 #include "base/strings/string_util.h" 17 #include "base/strings/utf_string_conversions.h" 18 #include "base/sys_info.h" 19 #include "base/threading/thread.h" 20 #include "base/win/scoped_handle.h" 21 #include "chrome_frame/crash_reporting/crash_dll.h" 22 #include "gtest/gtest.h" 23 24 namespace { 25 void AssertIsCriticalSection(CRITICAL_SECTION* critsec) { 26 // Assert on some of the internals of the debug info if it has one. 27 RTL_CRITICAL_SECTION_DEBUG* debug = critsec->DebugInfo; 28 if (debug) { 29 ASSERT_EQ(RTL_CRITSECT_TYPE, debug->Type); 30 ASSERT_EQ(critsec, debug->CriticalSection); 31 } 32 33 // TODO(siggi): assert on the semaphore handle & object type? 34 } 35 36 class ScopedEnterCriticalSection { 37 public: 38 explicit ScopedEnterCriticalSection(CRITICAL_SECTION* critsec) 39 : critsec_(critsec) { 40 ::EnterCriticalSection(critsec_); 41 } 42 43 ~ScopedEnterCriticalSection() { 44 ::LeaveCriticalSection(critsec_); 45 } 46 47 private: 48 CRITICAL_SECTION* critsec_; 49 }; 50 51 std::wstring FromUnicodeString(const UNICODE_STRING& str) { 52 return std::wstring(str.Buffer, str.Length / sizeof(str.Buffer[0])); 53 } 54 55 } // namespace 56 57 using namespace nt_loader; 58 59 TEST(NtLoader, OwnsCriticalSection) { 60 // Use of Thread requires an atexit manager. 61 base::AtExitManager at_exit; 62 63 CRITICAL_SECTION cs = {}; 64 ::InitializeCriticalSection(&cs); 65 EXPECT_FALSE(OwnsCriticalSection(&cs)); 66 67 // Enter the critsec and assert we own it. 68 { 69 ScopedEnterCriticalSection lock1(&cs); 70 71 EXPECT_TRUE(OwnsCriticalSection(&cs)); 72 73 // Re-enter the critsec and assert we own it. 74 ScopedEnterCriticalSection lock2(&cs); 75 76 EXPECT_TRUE(OwnsCriticalSection(&cs)); 77 } 78 79 // Should no longer own it. 80 EXPECT_FALSE(OwnsCriticalSection(&cs)); 81 82 // Make another thread grab it. 83 base::Thread other("Other threads"); 84 ASSERT_TRUE(other.Start()); 85 other.message_loop()->PostTask( 86 FROM_HERE, base::Bind(::EnterCriticalSection, &cs)); 87 88 base::win::ScopedHandle event(::CreateEvent(NULL, FALSE, FALSE, NULL)); 89 other.message_loop()->PostTask( 90 FROM_HERE, base::Bind(base::IgnoreResult(::SetEvent), event.Get())); 91 92 ASSERT_EQ(WAIT_OBJECT_0, ::WaitForSingleObject(event.Get(), INFINITE)); 93 94 // We still shouldn't own it - the other thread does. 95 EXPECT_FALSE(OwnsCriticalSection(&cs)); 96 // And we shouldn't be able to enter it. 97 EXPECT_EQ(0, ::TryEnterCriticalSection(&cs)); 98 99 // Make the other thread release it. 100 other.message_loop()->PostTask( 101 FROM_HERE, base::Bind(::LeaveCriticalSection, &cs)); 102 103 other.Stop(); 104 105 ::DeleteCriticalSection(&cs); 106 } 107 108 TEST(NtLoader, GetLoaderLock) { 109 CRITICAL_SECTION* loader_lock = GetLoaderLock(); 110 111 AssertIsCriticalSection(loader_lock); 112 113 // We should be able to enter and leave the loader's lock without trouble. 114 EnterCriticalSection(loader_lock); 115 LeaveCriticalSection(loader_lock); 116 } 117 118 TEST(NtLoader, OwnsLoaderLock) { 119 CRITICAL_SECTION* loader_lock = GetLoaderLock(); 120 121 EXPECT_FALSE(OwnsLoaderLock()); 122 EnterCriticalSection(loader_lock); 123 EXPECT_TRUE(OwnsLoaderLock()); 124 LeaveCriticalSection(loader_lock); 125 EXPECT_FALSE(OwnsLoaderLock()); 126 } 127 128 TEST(NtLoader, GetLoaderEntry) { 129 // Get all modules in the current process. 130 base::win::ScopedHandle snap( 131 ::CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, ::GetCurrentProcessId())); 132 EXPECT_TRUE(snap.Get() != NULL); 133 134 // Walk them, while checking we get an entry for each, and that it 135 // contains sane information. 136 MODULEENTRY32 module = { sizeof(module) }; 137 ASSERT_TRUE(::Module32First(snap.Get(), &module)); 138 do { 139 ScopedEnterCriticalSection lock(GetLoaderLock()); 140 141 nt_loader::LDR_DATA_TABLE_ENTRY* entry = 142 nt_loader::GetLoaderEntry(module.hModule); 143 ASSERT_TRUE(entry != NULL); 144 EXPECT_EQ(module.hModule, reinterpret_cast<HMODULE>(entry->DllBase)); 145 EXPECT_STREQ(module.szModule, 146 FromUnicodeString(entry->BaseDllName).c_str()); 147 EXPECT_STREQ(module.szExePath, 148 FromUnicodeString(entry->FullDllName).c_str()); 149 150 ULONG flags = entry->Flags; 151 152 // All entries should have this flag set. 153 EXPECT_TRUE(flags & LDRP_ENTRY_PROCESSED); 154 155 if (0 == (flags & LDRP_IMAGE_DLL)) { 156 // TODO(siggi): write a test to assert this holds true for loading 157 // non-DLL, e.g. exe image files. 158 // Dlls have the LDRP_IMAGE_DLL flag set, any module that doesn't 159 // have that flag has to be the main executable. 160 EXPECT_TRUE(module.hModule == ::GetModuleHandle(NULL)); 161 } else { 162 // Since we're not currently loading any modules, all loaded 163 // modules should either have the LDRP_PROCESS_ATTACH_CALLED, 164 // or a NULL entrypoint. 165 if (entry->EntryPoint == NULL) { 166 EXPECT_FALSE(flags & LDRP_PROCESS_ATTACH_CALLED); 167 } else { 168 // Shimeng.dll is an exception to the above, it's loaded 169 // in a special way, see e.g. http://www.alex-ionescu.com/?p=41 170 // for details. 171 bool is_shimeng = LowerCaseEqualsASCII( 172 FromUnicodeString(entry->BaseDllName), "shimeng.dll"); 173 174 EXPECT_TRUE(is_shimeng || (flags & LDRP_PROCESS_ATTACH_CALLED)); 175 } 176 } 177 } while (::Module32Next(snap.Get(), &module)); 178 } 179 180 namespace { 181 182 typedef void (*ExceptionFunction)(EXCEPTION_POINTERS* ex_ptrs); 183 184 class NtLoaderTest: public testing::Test { 185 public: 186 NtLoaderTest() : veh_id_(NULL), exception_function_(NULL) { 187 EXPECT_EQ(NULL, current_); 188 current_ = this; 189 } 190 191 ~NtLoaderTest() { 192 EXPECT_TRUE(this == current_); 193 current_ = NULL; 194 } 195 196 void SetUp() { 197 veh_id_ = ::AddVectoredExceptionHandler(FALSE, &ExceptionHandler); 198 EXPECT_TRUE(veh_id_ != NULL); 199 200 // Clear the crash DLL environment. 201 scoped_ptr<base::Environment> env(base::Environment::Create()); 202 env->UnSetVar(WideToASCII(kCrashOnLoadMode).c_str()); 203 env->UnSetVar(WideToASCII(kCrashOnUnloadMode).c_str()); 204 } 205 206 void TearDown() { 207 if (veh_id_ != NULL) 208 EXPECT_NE(0, ::RemoveVectoredExceptionHandler(veh_id_)); 209 210 // Clear the crash DLL environment. 211 scoped_ptr<base::Environment> env(base::Environment::Create()); 212 env->UnSetVar(WideToASCII(kCrashOnLoadMode).c_str()); 213 env->UnSetVar(WideToASCII(kCrashOnUnloadMode).c_str()); 214 } 215 216 void set_exception_function(ExceptionFunction func) { 217 exception_function_ = func; 218 } 219 220 private: 221 static LONG NTAPI ExceptionHandler(EXCEPTION_POINTERS* ex_ptrs){ 222 // Dispatch the exception to any exception function, 223 // but only on the main thread. 224 if (main_thread_ == ::GetCurrentThreadId() && 225 current_ != NULL && 226 current_->exception_function_ != NULL) 227 current_->exception_function_(ex_ptrs); 228 229 return ExceptionContinueSearch; 230 } 231 232 void* veh_id_; 233 ExceptionFunction exception_function_; 234 235 static NtLoaderTest* current_; 236 static DWORD main_thread_; 237 }; 238 239 NtLoaderTest* NtLoaderTest::current_ = NULL; 240 DWORD NtLoaderTest::main_thread_ = ::GetCurrentThreadId(); 241 242 } // namespace 243 244 static int exceptions_handled = 0; 245 static void OnCrashDuringLoadLibrary(EXCEPTION_POINTERS* ex_ptrs) { 246 ASSERT_EQ(STATUS_ACCESS_VIOLATION, ex_ptrs->ExceptionRecord->ExceptionCode); 247 ASSERT_EQ(2, ex_ptrs->ExceptionRecord->NumberParameters); 248 ASSERT_EQ(EXCEPTION_WRITE_FAULT, 249 ex_ptrs->ExceptionRecord->ExceptionInformation[0]); 250 ASSERT_EQ(kCrashAddress, 251 ex_ptrs->ExceptionRecord->ExceptionInformation[1]); 252 253 // Bump the exceptions count. 254 exceptions_handled++; 255 256 EXPECT_TRUE(OwnsLoaderLock()); 257 258 HMODULE crash_dll = ::GetModuleHandle(kCrashDllName); 259 ASSERT_TRUE(crash_dll != NULL); 260 261 nt_loader::LDR_DATA_TABLE_ENTRY* entry = GetLoaderEntry(crash_dll); 262 ASSERT_TRUE(entry != NULL); 263 ASSERT_EQ(0, entry->Flags & LDRP_PROCESS_ATTACH_CALLED); 264 } 265 266 TEST_F(NtLoaderTest, CrashOnLoadLibrary) { 267 exceptions_handled = 0; 268 set_exception_function(OnCrashDuringLoadLibrary); 269 270 // Setup to crash on load. 271 scoped_ptr<base::Environment> env(base::Environment::Create()); 272 env->SetVar(WideToASCII(kCrashOnLoadMode).c_str(), "1"); 273 274 // And load it. 275 HMODULE module = ::LoadLibrary(kCrashDllName); 276 DWORD err = ::GetLastError(); 277 EXPECT_EQ(NULL, module); 278 EXPECT_EQ(ERROR_NOACCESS, err); 279 EXPECT_EQ(1, exceptions_handled); 280 281 if (module != NULL) 282 ::FreeLibrary(module); 283 } 284 285 static void OnCrashDuringUnloadLibrary(EXCEPTION_POINTERS* ex_ptrs) { 286 ASSERT_EQ(STATUS_ACCESS_VIOLATION, ex_ptrs->ExceptionRecord->ExceptionCode); 287 ASSERT_EQ(2, ex_ptrs->ExceptionRecord->NumberParameters); 288 ASSERT_EQ(EXCEPTION_WRITE_FAULT, 289 ex_ptrs->ExceptionRecord->ExceptionInformation[0]); 290 ASSERT_EQ(kCrashAddress, 291 ex_ptrs->ExceptionRecord->ExceptionInformation[1]); 292 293 // Bump the exceptions count. 294 exceptions_handled++; 295 296 EXPECT_TRUE(OwnsLoaderLock()); 297 298 HMODULE crash_dll = ::GetModuleHandle(kCrashDllName); 299 ASSERT_TRUE(crash_dll == NULL); 300 301 nt_loader::LDR_DATA_TABLE_ENTRY* entry = GetLoaderEntry(crash_dll); 302 ASSERT_TRUE(entry == NULL); 303 } 304 305 TEST_F(NtLoaderTest, CrashOnUnloadLibrary) { 306 // Setup to crash on unload. 307 scoped_ptr<base::Environment> env(base::Environment::Create()); 308 env->SetVar(WideToASCII(kCrashOnUnloadMode).c_str(), "1"); 309 310 // And load it. 311 HMODULE module = ::LoadLibrary(kCrashDllName); 312 EXPECT_TRUE(module != NULL); 313 314 exceptions_handled = 0; 315 set_exception_function(OnCrashDuringUnloadLibrary); 316 317 // We should crash during unload. 318 if (module != NULL) 319 ::FreeLibrary(module); 320 321 EXPECT_EQ(1, exceptions_handled); 322 } 323