Home | History | Annotate | Download | only in wgl
      1 /**************************************************************************
      2  *
      3  * Copyright 2009-2013 VMware, Inc.
      4  * All Rights Reserved.
      5  *
      6  * Permission is hereby granted, free of charge, to any person obtaining a
      7  * copy of this software and associated documentation files (the
      8  * "Software"), to deal in the Software without restriction, including
      9  * without limitation the rights to use, copy, modify, merge, publish,
     10  * distribute, sub license, and/or sell copies of the Software, and to
     11  * permit persons to whom the Software is furnished to do so, subject to
     12  * the following conditions:
     13  *
     14  * The above copyright notice and this permission notice (including the
     15  * next paragraph) shall be included in all copies or substantial portions
     16  * of the Software.
     17  *
     18  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
     19  * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
     20  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT.
     21  * IN NO EVENT SHALL VMWARE AND/OR ITS SUPPLIERS BE LIABLE FOR
     22  * ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
     23  * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
     24  * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
     25  *
     26  **************************************************************************/
     27 
     28 #include <windows.h>
     29 #include <tlhelp32.h>
     30 
     31 #include "pipe/p_compiler.h"
     32 #include "util/u_debug.h"
     33 #include "stw_tls.h"
     34 
/**
 * TLS slot whose per-thread value is that thread's struct stw_tls_data.
 * TLS_OUT_OF_INDEXES means the slot is not allocated (before stw_tls_init()
 * succeeds, or after stw_tls_cleanup()); every entry point checks for it.
 */
static DWORD tlsIndex = TLS_OUT_OF_INDEXES;


/**
 * Static mutex to protect the access to g_pendingTlsData global and
 * stw_tls_data::next member.
 *
 * NOTE(review): the positional initializer fills the CRITICAL_SECTION fields
 * in declaration order (presumably DebugInfo, LockCount, RecursionCount,
 * OwningThread, LockSemaphore, SpinCount) so that the lock can be used
 * without an InitializeCriticalSection() call, which this file never makes.
 * This relies on the undocumented layout and initial state of
 * CRITICAL_SECTION -- confirm it still holds on the targeted Windows versions.
 */
static CRITICAL_SECTION g_mutex = {
   (PCRITICAL_SECTION_DEBUG)-1, -1, 0, 0, 0, 0
};

/**
 * There is no way to invoke TlsSetValue for a different thread, so we
 * temporarily put the thread data for non-current threads here.
 *
 * Singly linked through stw_tls_data::next and guarded by g_mutex.  Nodes
 * are pushed during stw_tls_init() and removed either by the owning thread
 * (stw_tls_lookup_pending_data()) or in bulk by stw_tls_cleanup().
 */
static struct stw_tls_data *g_pendingTlsData = NULL;


/* Forward declarations: both helpers are used before their definitions. */
static struct stw_tls_data *
stw_tls_data_create(DWORD dwThreadId);

static struct stw_tls_data *
stw_tls_lookup_pending_data(DWORD dwThreadId);
     58 
     59 
     60 boolean
     61 stw_tls_init(void)
     62 {
     63    tlsIndex = TlsAlloc();
     64    if (tlsIndex == TLS_OUT_OF_INDEXES) {
     65       return FALSE;
     66    }
     67 
     68    /*
     69     * DllMain is called with DLL_THREAD_ATTACH only for threads created after
     70     * the DLL is loaded by the process.  So enumerate and add our hook to all
     71     * previously existing threads.
     72     *
     73     * XXX: Except for the current thread since it there is an explicit
     74     * stw_tls_init_thread() call for it later on.
     75     */
     76    if (1) {
     77       DWORD dwCurrentProcessId = GetCurrentProcessId();
     78       DWORD dwCurrentThreadId = GetCurrentThreadId();
     79       HANDLE hSnapshot = CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, dwCurrentProcessId);
     80       if (hSnapshot != INVALID_HANDLE_VALUE) {
     81          THREADENTRY32 te;
     82          te.dwSize = sizeof te;
     83          if (Thread32First(hSnapshot, &te)) {
     84             do {
     85                if (te.dwSize >= FIELD_OFFSET(THREADENTRY32, th32OwnerProcessID) +
     86                                 sizeof te.th32OwnerProcessID) {
     87                   if (te.th32OwnerProcessID == dwCurrentProcessId) {
     88                      if (te.th32ThreadID != dwCurrentThreadId) {
     89                         struct stw_tls_data *data;
     90                         data = stw_tls_data_create(te.th32ThreadID);
     91                         if (data) {
     92                            EnterCriticalSection(&g_mutex);
     93                            data->next = g_pendingTlsData;
     94                            g_pendingTlsData = data;
     95                            LeaveCriticalSection(&g_mutex);
     96                         }
     97                      }
     98                   }
     99                }
    100                te.dwSize = sizeof te;
    101             } while (Thread32Next(hSnapshot, &te));
    102          }
    103          CloseHandle(hSnapshot);
    104       }
    105    }
    106 
    107    return TRUE;
    108 }
    109 
    110 
    111 /**
    112  * Install windows hook for a given thread (not necessarily the current one).
    113  */
    114 static struct stw_tls_data *
    115 stw_tls_data_create(DWORD dwThreadId)
    116 {
    117    struct stw_tls_data *data;
    118 
    119    if (0) {
    120       debug_printf("%s(0x%04lx)\n", __FUNCTION__, dwThreadId);
    121    }
    122 
    123    data = calloc(1, sizeof *data);
    124    if (!data) {
    125       goto no_data;
    126    }
    127 
    128    data->dwThreadId = dwThreadId;
    129 
    130    data->hCallWndProcHook = SetWindowsHookEx(WH_CALLWNDPROC,
    131                                              stw_call_window_proc,
    132                                              NULL,
    133                                              dwThreadId);
    134    if (data->hCallWndProcHook == NULL) {
    135       goto no_hook;
    136    }
    137 
    138    return data;
    139 
    140 no_hook:
    141    free(data);
    142 no_data:
    143    return NULL;
    144 }
    145 
    146 /**
    147  * Destroy the per-thread data/hook.
    148  *
    149  * It is important to remove all hooks when unloading our DLL, otherwise our
    150  * hook function might be called after it is no longer there.
    151  */
    152 static void
    153 stw_tls_data_destroy(struct stw_tls_data *data)
    154 {
    155    assert(data);
    156    if (!data) {
    157       return;
    158    }
    159 
    160    if (0) {
    161       debug_printf("%s(0x%04lx)\n", __FUNCTION__, data->dwThreadId);
    162    }
    163 
    164    if (data->hCallWndProcHook) {
    165       UnhookWindowsHookEx(data->hCallWndProcHook);
    166       data->hCallWndProcHook = NULL;
    167    }
    168 
    169    free(data);
    170 }
    171 
    172 boolean
    173 stw_tls_init_thread(void)
    174 {
    175    struct stw_tls_data *data;
    176 
    177    if (tlsIndex == TLS_OUT_OF_INDEXES) {
    178       return FALSE;
    179    }
    180 
    181    data = stw_tls_data_create(GetCurrentThreadId());
    182    if (!data) {
    183       return FALSE;
    184    }
    185 
    186    TlsSetValue(tlsIndex, data);
    187 
    188    return TRUE;
    189 }
    190 
    191 void
    192 stw_tls_cleanup_thread(void)
    193 {
    194    struct stw_tls_data *data;
    195 
    196    if (tlsIndex == TLS_OUT_OF_INDEXES) {
    197       return;
    198    }
    199 
    200    data = (struct stw_tls_data *) TlsGetValue(tlsIndex);
    201    if (data) {
    202       TlsSetValue(tlsIndex, NULL);
    203    } else {
    204       /* See if there this thread's data in on the pending list */
    205       data = stw_tls_lookup_pending_data(GetCurrentThreadId());
    206    }
    207 
    208    if (data) {
    209       stw_tls_data_destroy(data);
    210    }
    211 }
    212 
    213 void
    214 stw_tls_cleanup(void)
    215 {
    216    if (tlsIndex != TLS_OUT_OF_INDEXES) {
    217       /*
    218        * Destroy all items in g_pendingTlsData linked list.
    219        */
    220       EnterCriticalSection(&g_mutex);
    221       while (g_pendingTlsData) {
    222          struct stw_tls_data * data = g_pendingTlsData;
    223          g_pendingTlsData = data->next;
    224          stw_tls_data_destroy(data);
    225       }
    226       LeaveCriticalSection(&g_mutex);
    227 
    228       TlsFree(tlsIndex);
    229       tlsIndex = TLS_OUT_OF_INDEXES;
    230    }
    231 }
    232 
    233 /*
    234  * Search for the current thread in the g_pendingTlsData linked list.
    235  *
    236  * It will remove and return the node on success, or return NULL on failure.
    237  */
    238 static struct stw_tls_data *
    239 stw_tls_lookup_pending_data(DWORD dwThreadId)
    240 {
    241    struct stw_tls_data ** p_data;
    242    struct stw_tls_data *data = NULL;
    243 
    244    EnterCriticalSection(&g_mutex);
    245    for (p_data = &g_pendingTlsData; *p_data; p_data = &(*p_data)->next) {
    246       if ((*p_data)->dwThreadId == dwThreadId) {
    247          data = *p_data;
    248 
    249 	 /*
    250 	  * Unlink the node.
    251 	  */
    252          *p_data = data->next;
    253          data->next = NULL;
    254 
    255 	 break;
    256       }
    257    }
    258    LeaveCriticalSection(&g_mutex);
    259 
    260    return data;
    261 }
    262 
    263 struct stw_tls_data *
    264 stw_tls_get_data(void)
    265 {
    266    struct stw_tls_data *data;
    267 
    268    if (tlsIndex == TLS_OUT_OF_INDEXES) {
    269       return NULL;
    270    }
    271 
    272    data = (struct stw_tls_data *) TlsGetValue(tlsIndex);
    273    if (!data) {
    274       DWORD dwCurrentThreadId = GetCurrentThreadId();
    275 
    276       /*
    277        * Search for the current thread in the g_pendingTlsData linked list.
    278        */
    279       data = stw_tls_lookup_pending_data(dwCurrentThreadId);
    280 
    281       if (!data) {
    282          /*
    283           * This should be impossible now.
    284           */
    285 	 assert(!"Failed to find thread data for thread id");
    286 
    287          /*
    288           * DllMain is called with DLL_THREAD_ATTACH only by threads created
    289           * after the DLL is loaded by the process
    290           */
    291          data = stw_tls_data_create(dwCurrentThreadId);
    292          if (!data) {
    293             return NULL;
    294          }
    295       }
    296 
    297       TlsSetValue(tlsIndex, data);
    298    }
    299 
    300    assert(data);
    301    assert(data->dwThreadId = GetCurrentThreadId());
    302    assert(data->next == NULL);
    303 
    304    return data;
    305 }
    306