Home | History | Annotate | Download | only in win
      1 /*
      2  * Copyright (C) 2007 Apple Inc. All rights reserved.
      3  *
      4  * Redistribution and use in source and binary forms, with or without
      5  * modification, are permitted provided that the following conditions
      6  * are met:
      7  * 1. Redistributions of source code must retain the above copyright
      8  *    notice, this list of conditions and the following disclaimer.
      9  * 2. Redistributions in binary form must reproduce the above copyright
     10  *    notice, this list of conditions and the following disclaimer in the
     11  *    documentation and/or other materials provided with the distribution.
     12  *
     13  * THIS SOFTWARE IS PROVIDED BY APPLE COMPUTER, INC. ``AS IS'' AND ANY
     14  * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
     15  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
     16  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL APPLE COMPUTER, INC. OR
     17  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
     18  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
     19  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
     20  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
     21  * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     23  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     24  */
     25 
     26 #ifndef COMPtr_h
     27 #define COMPtr_h
     28 
     29 #ifndef NOMINMAX
     30 #define NOMINMAX
     31 #endif
     32 
     33 #include <guiddef.h>
     34 #include <unknwn.h>
     35 #include <WTF/Assertions.h>
     36 #include <WTF/HashTraits.h>
     37 
     38 typedef long HRESULT;
     39 
     40 // FIXME: Should we put this into the WebCore namespace and use "using" on it
     41 // as we do with things in WTF?
     42 
     43 enum AdoptCOMTag { AdoptCOM };
     44 enum QueryTag { Query };
     45 enum CreateTag { Create };
     46 
     47 template <typename T> class COMPtr {
     48 public:
     49     COMPtr() : m_ptr(0) { }
     50     COMPtr(T* ptr) : m_ptr(ptr) { if (m_ptr) m_ptr->AddRef(); }
     51     COMPtr(AdoptCOMTag, T* ptr) : m_ptr(ptr) { }
     52     COMPtr(const COMPtr& o) : m_ptr(o.m_ptr) { if (T* ptr = m_ptr) ptr->AddRef(); }
     53 
     54     COMPtr(QueryTag, IUnknown* ptr) : m_ptr(copyQueryInterfaceRef(ptr)) { }
     55     template <typename U> COMPtr(QueryTag, const COMPtr<U>& ptr) : m_ptr(copyQueryInterfaceRef(ptr.get())) { }
     56 
     57     COMPtr(CreateTag, const IID& clsid) : m_ptr(createInstance(clsid)) { }
     58 
     59     // Hash table deleted values, which are only constructed and never copied or destroyed.
     60     COMPtr(WTF::HashTableDeletedValueType) : m_ptr(hashTableDeletedValue()) { }
     61     bool isHashTableDeletedValue() const { return m_ptr == hashTableDeletedValue(); }
     62 
     63     ~COMPtr() { if (m_ptr) m_ptr->Release(); }
     64 
     65     T* get() const { return m_ptr; }
     66     T* releaseRef() { T* tmp = m_ptr; m_ptr = 0; return tmp; }
     67 
     68     T& operator*() const { return *m_ptr; }
     69     T* operator->() const { return m_ptr; }
     70 
     71     T** operator&() { ASSERT(!m_ptr); return &m_ptr; }
     72 
     73     bool operator!() const { return !m_ptr; }
     74 
     75     // This conversion operator allows implicit conversion to bool but not to other integer types.
     76     typedef T* (COMPtr::*UnspecifiedBoolType)() const;
     77     operator UnspecifiedBoolType() const { return m_ptr ? &COMPtr::get : 0; }
     78 
     79     COMPtr& operator=(const COMPtr&);
     80     COMPtr& operator=(T*);
     81     template <typename U> COMPtr& operator=(const COMPtr<U>&);
     82 
     83     void query(IUnknown* ptr) { adoptRef(copyQueryInterfaceRef(ptr)); }
     84     template <typename U> void query(const COMPtr<U>& ptr) { query(ptr.get()); }
     85 
     86     void create(const IID& clsid) { adoptRef(createInstance(clsid)); }
     87 
     88     template <typename U> HRESULT copyRefTo(U**);
     89     void adoptRef(T*);
     90 
     91 private:
     92     static T* copyQueryInterfaceRef(IUnknown*);
     93     static T* createInstance(const IID& clsid);
     94     static T* hashTableDeletedValue() { return reinterpret_cast<T*>(-1); }
     95 
     96     T* m_ptr;
     97 };
     98 
     99 template <typename T> inline T* COMPtr<T>::createInstance(const IID& clsid)
    100 {
    101     T* result;
    102     if (FAILED(CoCreateInstance(clsid, 0, CLSCTX_ALL, __uuidof(result), reinterpret_cast<void**>(&result))))
    103         return 0;
    104     return result;
    105 }
    106 
    107 template <typename T> inline T* COMPtr<T>::copyQueryInterfaceRef(IUnknown* ptr)
    108 {
    109     if (!ptr)
    110         return 0;
    111     T* result;
    112     if (FAILED(ptr->QueryInterface(&result)))
    113         return 0;
    114     return result;
    115 }
    116 
    117 template <typename T> template <typename U> inline HRESULT COMPtr<T>::copyRefTo(U** ptr)
    118 {
    119     if (!ptr)
    120         return E_POINTER;
    121     *ptr = m_ptr;
    122     if (m_ptr)
    123         m_ptr->AddRef();
    124     return S_OK;
    125 }
    126 
    127 template <typename T> inline void COMPtr<T>::adoptRef(T *ptr)
    128 {
    129     if (m_ptr)
    130         m_ptr->Release();
    131     m_ptr = ptr;
    132 }
    133 
    134 template <typename T> inline COMPtr<T>& COMPtr<T>::operator=(const COMPtr<T>& o)
    135 {
    136     T* optr = o.get();
    137     if (optr)
    138         optr->AddRef();
    139     T* ptr = m_ptr;
    140     m_ptr = optr;
    141     if (ptr)
    142         ptr->Release();
    143     return *this;
    144 }
    145 
    146 template <typename T> template <typename U> inline COMPtr<T>& COMPtr<T>::operator=(const COMPtr<U>& o)
    147 {
    148     T* optr = o.get();
    149     if (optr)
    150         optr->AddRef();
    151     T* ptr = m_ptr;
    152     m_ptr = optr;
    153     if (ptr)
    154         ptr->Release();
    155     return *this;
    156 }
    157 
    158 template <typename T> inline COMPtr<T>& COMPtr<T>::operator=(T* optr)
    159 {
    160     if (optr)
    161         optr->AddRef();
    162     T* ptr = m_ptr;
    163     m_ptr = optr;
    164     if (ptr)
    165         ptr->Release();
    166     return *this;
    167 }
    168 
    169 template <typename T, typename U> inline bool operator==(const COMPtr<T>& a, const COMPtr<U>& b)
    170 {
    171     return a.get() == b.get();
    172 }
    173 
    174 template <typename T, typename U> inline bool operator==(const COMPtr<T>& a, U* b)
    175 {
    176     return a.get() == b;
    177 }
    178 
    179 template <typename T, typename U> inline bool operator==(T* a, const COMPtr<U>& b)
    180 {
    181     return a == b.get();
    182 }
    183 
    184 template <typename T, typename U> inline bool operator!=(const COMPtr<T>& a, const COMPtr<U>& b)
    185 {
    186     return a.get() != b.get();
    187 }
    188 
    189 template <typename T, typename U> inline bool operator!=(const COMPtr<T>& a, U* b)
    190 {
    191     return a.get() != b;
    192 }
    193 
    194 template <typename T, typename U> inline bool operator!=(T* a, const COMPtr<U>& b)
    195 {
    196     return a != b.get();
    197 }
    198 
    199 namespace WTF {
    200 
    201     template<typename P> struct HashTraits<COMPtr<P> > : GenericHashTraits<COMPtr<P> > {
    202         static const bool emptyValueIsZero = true;
    203         static void constructDeletedValue(COMPtr<P>& slot) { slot.releaseRef(); *&slot = reinterpret_cast<P*>(-1); }
    204         static bool isDeletedValue(const COMPtr<P>& value) { return value == reinterpret_cast<P*>(-1); }
    205     };
    206 
    207     template<typename P> struct PtrHash<COMPtr<P> > : PtrHash<P*> {
    208         using PtrHash<P*>::hash;
    209         static unsigned hash(const COMPtr<P>& key) { return hash(key.get()); }
    210         using PtrHash<P*>::equal;
    211         static bool equal(const COMPtr<P>& a, const COMPtr<P>& b) { return a == b; }
    212         static bool equal(P* a, const COMPtr<P>& b) { return a == b; }
    213         static bool equal(const COMPtr<P>& a, P* b) { return a == b; }
    214     };
    215 
    216     template<typename P> struct DefaultHash<COMPtr<P> > { typedef PtrHash<COMPtr<P> > Hash; };
    217 }
    218 
    219 #endif
    220