Home | History | Annotate | Download | only in test
      1 // Copyright (c) 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "base/test/test_process_killer_win.h"
      6 
      7 #include <windows.h>
      8 #include <winternl.h>
      9 
     10 #include <algorithm>
     11 
     12 #include "base/logging.h"
     13 #include "base/process/kill.h"
     14 #include "base/process/process_iterator.h"
     15 #include "base/strings/string_util.h"
     16 #include "base/win/scoped_handle.h"
     17 
     18 namespace {
     19 
     20 typedef LONG WINAPI
     21 NtQueryInformationProcess(
     22   IN HANDLE ProcessHandle,
     23   IN PROCESSINFOCLASS ProcessInformationClass,
     24   OUT PVOID ProcessInformation,
     25   IN ULONG ProcessInformationLength,
     26   OUT PULONG ReturnLength OPTIONAL
     27 );
     28 
     29 // Get the function pointer to NtQueryInformationProcess in NTDLL.DLL
     30 static bool GetQIP(NtQueryInformationProcess** qip_func_ptr) {
     31   static NtQueryInformationProcess* qip_func =
     32       reinterpret_cast<NtQueryInformationProcess*>(
     33           GetProcAddress(GetModuleHandle(L"ntdll.dll"),
     34           "NtQueryInformationProcess"));
     35   DCHECK(qip_func) << "Could not get pointer to NtQueryInformationProcess.";
     36   *qip_func_ptr = qip_func;
     37   return qip_func != NULL;
     38 }
     39 
     40 // Get the command line of a process
     41 bool GetCommandLineForProcess(uint32 process_id, base::string16* cmd_line) {
     42   DCHECK(process_id != 0);
     43   DCHECK(cmd_line);
     44 
     45   // Open the process
     46   base::win::ScopedHandle process_handle(::OpenProcess(
     47       PROCESS_QUERY_INFORMATION | PROCESS_VM_READ,
     48       false,
     49       process_id));
     50   if (!process_handle) {
     51     DLOG(ERROR) << "Failed to open process " << process_id << ", last error = "
     52                 << GetLastError();
     53   }
     54 
     55   // Obtain Process Environment Block
     56   NtQueryInformationProcess* qip_func = NULL;
     57   if (process_handle) {
     58     GetQIP(&qip_func);
     59   }
     60 
     61   // Read the address of the process params from the peb.
     62   DWORD process_params_address = 0;
     63   if (qip_func) {
     64     PROCESS_BASIC_INFORMATION info = { 0 };
     65     // NtQueryInformationProcess returns an NTSTATUS for whom negative values
     66     // are negative. Just check for that instead of pulling in DDK macros.
     67     if ((qip_func(process_handle.Get(),
     68                   ProcessBasicInformation,
     69                   &info,
     70                   sizeof(info),
     71                   NULL)) < 0) {
     72       DLOG(ERROR) << "Failed to invoke NtQueryProcessInformation, last error = "
     73                   << GetLastError();
     74     } else {
     75       BYTE* peb = reinterpret_cast<BYTE*>(info.PebBaseAddress);
     76 
     77       // The process command line parameters are (or were once) located at
     78       // the base address of the PEB + 0x10 for 32 bit processes. 64 bit
     79       // processes have a different PEB struct as per
     80       // http://msdn.microsoft.com/en-us/library/aa813706(VS.85).aspx.
     81       // TODO(robertshield): See about doing something about this.
     82       SIZE_T bytes_read = 0;
     83       if (!::ReadProcessMemory(process_handle.Get(),
     84                                peb + 0x10,
     85                                &process_params_address,
     86                                sizeof(process_params_address),
     87                                &bytes_read)) {
     88         DLOG(ERROR) << "Failed to read process params address, last error = "
     89                     << GetLastError();
     90       }
     91     }
     92   }
     93 
     94   // Copy all the process parameters into a buffer.
     95   bool success = false;
     96   base::string16 buffer;
     97   if (process_params_address) {
     98     SIZE_T bytes_read;
     99     RTL_USER_PROCESS_PARAMETERS params = { 0 };
    100     if (!::ReadProcessMemory(process_handle.Get(),
    101                              reinterpret_cast<void*>(process_params_address),
    102                              &params,
    103                              sizeof(params),
    104                              &bytes_read)) {
    105       DLOG(ERROR) << "Failed to read RTL_USER_PROCESS_PARAMETERS, "
    106                   << "last error = " << GetLastError();
    107     } else {
    108       // Read the command line parameter
    109       const int max_cmd_line_len = std::min(
    110           static_cast<int>(params.CommandLine.MaximumLength),
    111           4096);
    112       buffer.resize(max_cmd_line_len + 1);
    113       if (!::ReadProcessMemory(process_handle.Get(),
    114                                params.CommandLine.Buffer,
    115                                &buffer[0],
    116                                max_cmd_line_len,
    117                                &bytes_read)) {
    118         DLOG(ERROR) << "Failed to copy process command line, "
    119                     << "last error = " << GetLastError();
    120       } else {
    121         *cmd_line = buffer;
    122         success = true;
    123       }
    124     }
    125   }
    126 
    127   return success;
    128 }
    129 
    130 // Used to filter processes by process ID.
    131 class ArgumentFilter : public base::ProcessFilter {
    132  public:
    133   explicit ArgumentFilter(const base::string16& argument)
    134       : argument_to_find_(argument) {}
    135 
    136   // Returns true to indicate set-inclusion and false otherwise.  This method
    137   // should not have side-effects and should be idempotent.
    138   virtual bool Includes(const base::ProcessEntry& entry) const {
    139     bool found = false;
    140     base::string16 command_line;
    141     if (GetCommandLineForProcess(entry.pid(), &command_line)) {
    142       base::string16::const_iterator it =
    143           std::search(command_line.begin(),
    144                       command_line.end(),
    145                       argument_to_find_.begin(),
    146                       argument_to_find_.end(),
    147                       base::CaseInsensitiveCompareASCII<wchar_t>());
    148       found = (it != command_line.end());
    149     }
    150     return found;
    151   }
    152 
    153  protected:
    154   base::string16 argument_to_find_;
    155 };
    156 
    157 }  // namespace
    158 
    159 namespace base {
    160 
    161 bool KillAllNamedProcessesWithArgument(const string16& process_name,
    162                                        const string16& argument) {
    163   ArgumentFilter argument_filter(argument);
    164   return base::KillProcesses(process_name, 0, &argument_filter);
    165 }
    166 
    167 }  // namespace base
    168