Home | History | Annotate | Download | only in wrl
      1 /**
      2  * This file has no copyright assigned and is placed in the Public Domain.
      3  * This file is part of the mingw-w64 runtime package.
      4  * No warranty is given; refer to the file DISCLAIMER.PD within this package.
      5  */
      6 
      7 #ifndef _WRL_CLIENT_H_
      8 #define _WRL_CLIENT_H_
      9 
     10 #include <stddef.h>
     11 #include <unknwn.h>
     12 /* #include <weakreference.h> */
     13 #include <roapi.h>
     14 
     15 /* #include <wrl/def.h> */
     16 #include <wrl/internal.h>
     17 
     18 namespace Microsoft {
     19     namespace WRL {
     20         namespace Details {
     21             template <typename T> class ComPtrRefBase {
     22             protected:
     23                 T* ptr_;
     24 
     25             public:
     26                 typedef typename T::InterfaceType InterfaceType;
     27 
     28 #ifndef __WRL_CLASSIC_COM__
     29                 operator IInspectable**() const throw()  {
     30                     static_assert(__is_base_of(IInspectable, InterfaceType), "Invalid cast");
     31                     return reinterpret_cast<IInspectable**>(ptr_->ReleaseAndGetAddressOf());
     32                 }
     33 #endif
     34 
     35                 operator IUnknown**() const throw() {
     36                     static_assert(__is_base_of(IUnknown, InterfaceType), "Invalid cast");
     37                     return reinterpret_cast<IUnknown**>(ptr_->ReleaseAndGetAddressOf());
     38                 }
     39             };
     40 
     41             template <typename T> class ComPtrRef : public Details::ComPtrRefBase<T> {
     42             public:
     43                 ComPtrRef(T *ptr) throw() {
     44                     ComPtrRefBase<T>::ptr_ = ptr;
     45                 }
     46 
     47                 operator void**() const throw() {
     48                     return reinterpret_cast<void**>(ComPtrRefBase<T>::ptr_->ReleaseAndGetAddressOf());
     49                 }
     50 
     51                 operator T*() throw() {
     52                     *ComPtrRefBase<T>::ptr_ = nullptr;
     53                     return ComPtrRefBase<T>::ptr_;
     54                 }
     55 
     56                 operator typename ComPtrRefBase<T>::InterfaceType**() throw() {
     57                     return ComPtrRefBase<T>::ptr_->ReleaseAndGetAddressOf();
     58                 }
     59 
     60                 typename ComPtrRefBase<T>::InterfaceType *operator*() throw() {
     61                     return ComPtrRefBase<T>::ptr_->Get();
     62                 }
     63 
     64                 typename ComPtrRefBase<T>::InterfaceType *const *GetAddressOf() const throw() {
     65                     return ComPtrRefBase<T>::ptr_->GetAddressOf();
     66                 }
     67 
     68                 typename ComPtrRefBase<T>::InterfaceType **ReleaseAndGetAddressOf() throw() {
     69                     return ComPtrRefBase<T>::ptr_->ReleaseAndGetAddressOf();
     70                 }
     71             };
     72 
     73         }
     74 
     75         template<typename T> class ComPtr {
     76         public:
     77             typedef T InterfaceType;
     78 
     79             ComPtr() throw() : ptr_(nullptr) {}
     80             ComPtr(decltype(nullptr)) throw() : ptr_(nullptr) {}
     81 
     82             template<class U> ComPtr(U *other) throw() : ptr_(other) {
     83                 InternalAddRef();
     84             }
     85 
     86             ComPtr(const ComPtr &other) throw() : ptr_(other.ptr_) {
     87                 InternalAddRef();
     88             }
     89 
     90             template<class U>
     91             ComPtr(const ComPtr<U> &other) throw() : ptr_(other.ptr_) {
     92                 InternalAddRef();
     93             }
     94 
     95             ComPtr(ComPtr &&other) throw() : ptr_(nullptr) {
     96                 if(this != reinterpret_cast<ComPtr*>(&reinterpret_cast<unsigned char&>(other)))
     97                     Swap(other);
     98             }
     99 
    100             template<class U>
    101             ComPtr(ComPtr<U>&& other) throw() : ptr_(other.ptr_) {
    102                 other.ptr_ = nullptr;
    103             }
    104 
    105             ~ComPtr() throw() {
    106                 InternalRelease();
    107             }
    108 
    109             ComPtr &operator=(decltype(nullptr)) throw() {
    110                 InternalRelease();
    111                 return *this;
    112             }
    113 
    114             ComPtr &operator=(InterfaceType *other) throw() {
    115                 if (ptr_ != other) {
    116                     InternalRelease();
    117                     ptr_ = other;
    118                     InternalAddRef();
    119                 }
    120                 return *this;
    121             }
    122 
    123             template<typename U>
    124             ComPtr &operator=(U *other) throw()  {
    125                 if (ptr_ != other) {
    126                     InternalRelease();
    127                     ptr_ = other;
    128                     InternalAddRef();
    129                 }
    130                 return *this;
    131             }
    132 
    133             ComPtr& operator=(const ComPtr &other) throw() {
    134                 if (ptr_ != other.ptr_)
    135                     ComPtr(other).Swap(*this);
    136                 return *this;
    137             }
    138 
    139             template<class U>
    140             ComPtr &operator=(const ComPtr<U> &other) throw() {
    141                 ComPtr(other).Swap(*this);
    142                 return *this;
    143             }
    144 
    145             ComPtr& operator=(ComPtr &&other) throw() {
    146                 ComPtr(other).Swap(*this);
    147                 return *this;
    148             }
    149 
    150             template<class U>
    151             ComPtr& operator=(ComPtr<U> &&other) throw() {
    152                 ComPtr(other).Swap(*this);
    153                 return *this;
    154             }
    155 
    156             void Swap(ComPtr &&r) throw() {
    157                 InterfaceType *tmp = ptr_;
    158                 ptr_ = r.ptr_;
    159                 r.ptr_ = tmp;
    160             }
    161 
    162             void Swap(ComPtr &r) throw() {
    163                 InterfaceType *tmp = ptr_;
    164                 ptr_ = r.ptr_;
    165                 r.ptr_ = tmp;
    166             }
    167 
    168             operator Details::BoolType() const throw() {
    169                 return Get() != nullptr ? &Details::BoolStruct::Member : nullptr;
    170             }
    171 
    172             InterfaceType *Get() const throw()  {
    173                 return ptr_;
    174             }
    175 
    176             InterfaceType *operator->() const throw() {
    177                 return ptr_;
    178             }
    179 
    180             Details::ComPtrRef<ComPtr<T>> operator&() throw()  {
    181                 return Details::ComPtrRef<ComPtr<T>>(this);
    182             }
    183 
    184             const Details::ComPtrRef<const ComPtr<T>> operator&() const throw() {
    185                 return Details::ComPtrRef<const ComPtr<T>>(this);
    186             }
    187 
    188             InterfaceType *const *GetAddressOf() const throw() {
    189                 return &ptr_;
    190             }
    191 
    192             InterfaceType **GetAddressOf() throw() {
    193                 return &ptr_;
    194             }
    195 
    196             InterfaceType **ReleaseAndGetAddressOf() throw() {
    197                 InternalRelease();
    198                 return &ptr_;
    199             }
    200 
    201             InterfaceType *Detach() throw() {
    202                 T* ptr = ptr_;
    203                 ptr_ = nullptr;
    204                 return ptr;
    205             }
    206 
    207             void Attach(InterfaceType *other) throw() {
    208                 if (ptr_ != other) {
    209                     InternalRelease();
    210                     ptr_ = other;
    211                     InternalAddRef();
    212                 }
    213             }
    214 
    215             unsigned long Reset() {
    216                 return InternalRelease();
    217             }
    218 
    219             HRESULT CopyTo(InterfaceType **ptr) const throw() {
    220                 InternalAddRef();
    221                 *ptr = ptr_;
    222                 return S_OK;
    223             }
    224 
    225             HRESULT CopyTo(REFIID riid, void **ptr) const throw() {
    226                 return ptr_->QueryInterface(riid, ptr);
    227             }
    228 
    229             template<typename U>
    230             HRESULT CopyTo(U **ptr) const throw() {
    231                 return ptr_->QueryInterface(__uuidof(U), reinterpret_cast<void**>(ptr));
    232             }
    233 
    234             template<typename U>
    235             HRESULT As(Details::ComPtrRef<ComPtr<U>> p) const throw() {
    236                 return ptr_->QueryInterface(__uuidof(U), p);
    237             }
    238 
    239             template<typename U>
    240             HRESULT As(ComPtr<U> *p) const throw() {
    241                 return ptr_->QueryInterface(__uuidof(U), reinterpret_cast<void**>(p->ReleaseAndGetAddressOf()));
    242             }
    243 
    244             HRESULT AsIID(REFIID riid, ComPtr<IUnknown> *p) const throw() {
    245                 return ptr_->QueryInterface(riid, reinterpret_cast<void**>(p->ReleaseAndGetAddressOf()));
    246             }
    247 
    248             /*
    249             HRESULT AsWeak(WeakRef *pWeakRef) const throw() {
    250                 return ::Microsoft::WRL::AsWeak(ptr_, pWeakRef);
    251             }
    252             */
    253         protected:
    254             InterfaceType *ptr_;
    255 
    256             void InternalAddRef() const throw() {
    257                 if(ptr_)
    258                     ptr_->AddRef();
    259             }
    260 
    261             unsigned long InternalRelease() throw() {
    262                 InterfaceType *tmp = ptr_;
    263                 if(!tmp)
    264                     return 0;
    265                 ptr_ = nullptr;
    266                 return tmp->Release();
    267             }
    268         };
    269     }
    270 }
    271 
    272 template<typename T>
    273 void **IID_PPV_ARGS_Helper(::Microsoft::WRL::Details::ComPtrRef<T> pp) throw() {
    274     static_assert(__is_base_of(IUnknown, typename T::InterfaceType), "Expected COM interface");
    275     return pp;
    276 }
    277 
    278 namespace Windows {
    279     namespace Foundation {
    280         template<typename T>
    281         inline HRESULT ActivateInstance(HSTRING classid, ::Microsoft::WRL::Details::ComPtrRef<T> instance) throw() {
    282             return ActivateInstance(classid, instance.ReleaseAndGetAddressOf());
    283         }
    284 
    285         template<typename T>
    286         inline HRESULT GetActivationFactory(HSTRING classid, ::Microsoft::WRL::Details::ComPtrRef<T> factory) throw() {
    287             return RoGetActivationFactory(classid, IID_INS_ARGS(factory.ReleaseAndGetAddressOf()));
    288         }
    289     }
    290 }
    291 
    292 namespace ABI {
    293     namespace Windows {
    294         namespace Foundation {
    295             template<typename T>
    296             inline HRESULT ActivateInstance(HSTRING classid, ::Microsoft::WRL::Details::ComPtrRef<T> instance) throw() {
    297                 return ActivateInstance(classid, instance.ReleaseAndGetAddressOf());
    298             }
    299 
    300             template<typename T>
    301             inline HRESULT GetActivationFactory(HSTRING classid, ::Microsoft::WRL::Details::ComPtrRef<T> factory) throw() {
    302                 return RoGetActivationFactory(classid, IID_INS_ARGS(factory.ReleaseAndGetAddressOf()));
    303             }
    304         }
    305     }
    306 }
    307 
    308 #endif
    309