Home | History | Annotate | Download | only in src
      1 //
      2 // Copyright (C) Microsoft Corporation
      3 // All rights reserved.
      4 // Modified for native C++ WRL support by Gregory Morse
      5 //
      6 // Code in Details namespace is for internal usage within the library code
      7 //
      8 
      9 #ifndef _PLATFORM_AGILE_H_
     10 #define _PLATFORM_AGILE_H_
     11 
     12 #ifdef _MSC_VER
     13 #pragma once
     14 #endif  // _MSC_VER
     15 
     16 #include <algorithm>
     17 #include <wrl\client.h>
     18 
     19 template <typename T, bool TIsNotAgile> class Agile;
     20 
     21 template <typename T>
     22 struct UnwrapAgile
     23 {
     24     static const bool _IsAgile = false;
     25 };
     26 template <typename T>
     27 struct UnwrapAgile<Agile<T, false>>
     28 {
     29     static const bool _IsAgile = true;
     30 };
     31 template <typename T>
     32 struct UnwrapAgile<Agile<T, true>>
     33 {
     34     static const bool _IsAgile = true;
     35 };
     36 
     37 #define IS_AGILE(T) UnwrapAgile<T>::_IsAgile
     38 
     39 #define __is_winrt_agile(T) (std::is_same<T, HSTRING__>::value || std::is_base_of<Microsoft::WRL::FtmBase, T>::value || std::is_base_of<IAgileObject, T>::value) //derived from Microsoft::WRL::FtmBase or IAgileObject
     40 
     41 #define __is_win_interface(T) (std::is_base_of<IUnknown, T>::value || std::is_base_of<IInspectable, T>::value) //derived from IUnknown or IInspectable
     42 
     43 #define __is_win_class(T) (std::is_same<T, HSTRING__>::value || std::is_base_of<Microsoft::WRL::Details::RuntimeClassBase, T>::value) //derived from Microsoft::WRL::RuntimeClass or HSTRING
     44 
     45     namespace Details
     46     {
     47         IUnknown* __stdcall GetObjectContext();
     48         HRESULT __stdcall GetProxyImpl(IUnknown*, REFIID, IUnknown*, IUnknown**);
     49         HRESULT __stdcall ReleaseInContextImpl(IUnknown*, IUnknown*);
     50 
     51         template <typename T>
     52 #if _MSC_VER >= 1800
     53         __declspec(no_refcount) inline HRESULT GetProxy(T *ObjectIn, IUnknown *ContextCallBack, T **Proxy)
     54 #else
     55         inline HRESULT GetProxy(T *ObjectIn, IUnknown *ContextCallBack, T **Proxy)
     56 #endif
     57         {
     58 #if _MSC_VER >= 1800
     59             return GetProxyImpl(*reinterpret_cast<IUnknown**>(&ObjectIn), __uuidof(T*), ContextCallBack, reinterpret_cast<IUnknown**>(Proxy));
     60 #else
     61             return GetProxyImpl(*reinterpret_cast<IUnknown**>(&const_cast<T*>(ObjectIn)), __uuidof(T*), ContextCallBack, reinterpret_cast<IUnknown**>(Proxy));
     62 #endif
     63         }
     64 
     65         template <typename T>
     66         inline HRESULT ReleaseInContext(T *ObjectIn, IUnknown *ContextCallBack)
     67         {
     68             return ReleaseInContextImpl(ObjectIn, ContextCallBack);
     69         }
     70 
     71         template <typename T>
     72         class AgileHelper
     73         {
     74             __abi_IUnknown* _p;
     75             bool _release;
     76         public:
     77             AgileHelper(__abi_IUnknown* p, bool release = true) : _p(p), _release(release)
     78             {
     79             }
     80             AgileHelper(AgileHelper&& other) : _p(other._p), _release(other._release)
     81             {
     82                 _other._p = nullptr;
     83                 _other._release = true;
     84             }
     85             AgileHelper operator=(AgileHelper&& other)
     86             {
     87                 _p = other._p;
     88                 _release = other._release;
     89                 _other._p = nullptr;
     90                 _other._release = true;
     91                 return *this;
     92             }
     93 
     94             ~AgileHelper()
     95             {
     96                 if (_release && _p)
     97                 {
     98                     _p->__abi_Release();
     99                 }
    100             }
    101 
    102             __declspec(no_refcount) __declspec(no_release_return)
    103                 T* operator->()
    104             {
    105                     return reinterpret_cast<T*>(_p);
    106             }
    107 
    108             __declspec(no_refcount) __declspec(no_release_return)
    109                 operator T * ()
    110             {
    111                     return reinterpret_cast<T*>(_p);
    112             }
    113         private:
    114             AgileHelper(const AgileHelper&);
    115             AgileHelper operator=(const AgileHelper&);
    116         };
    117         template <typename T>
    118         struct __remove_hat
    119         {
    120             typedef T type;
    121         };
    122         template <typename T>
    123         struct __remove_hat<T*>
    124         {
    125             typedef T type;
    126         };
    127         template <typename T>
    128         struct AgileTypeHelper
    129         {
    130             typename typedef __remove_hat<T>::type type;
    131             typename typedef __remove_hat<T>::type* agileMemberType;
    132         };
    133     } // namespace Details
    134 
    135 #pragma warning(push)
    136 #pragma warning(disable: 4451) // Usage of ref class inside this context can lead to invalid marshaling of object across contexts
    137 
    138     template <
    139         typename T,
    140         bool TIsNotAgile = (__is_win_class(typename Details::AgileTypeHelper<T>::type) && !__is_winrt_agile(typename Details::AgileTypeHelper<T>::type)) ||
    141         __is_win_interface(typename Details::AgileTypeHelper<T>::type)
    142     >
    143     class Agile
    144     {
    145         static_assert(__is_win_class(typename Details::AgileTypeHelper<T>::type) || __is_win_interface(typename Details::AgileTypeHelper<T>::type), "Agile can only be used with ref class or interface class types");
    146         typename typedef Details::AgileTypeHelper<T>::agileMemberType TypeT;
    147         TypeT _object;
    148         ::Microsoft::WRL::ComPtr<IUnknown> _contextCallback;
    149         ULONG_PTR _contextToken;
    150 
    151 #if _MSC_VER >= 1800
    152         enum class AgileState
    153         {
    154             NonAgilePointer = 0,
    155             AgilePointer = 1,
    156             Unknown = 2
    157         };
    158         AgileState _agileState;
    159 #endif
    160 
    161         void CaptureContext()
    162         {
    163             _contextCallback = Details::GetObjectContext();
    164             __abi_ThrowIfFailed(CoGetContextToken(&_contextToken));
    165         }
    166 
    167         void SetObject(TypeT object)
    168         {
    169             // Capture context before setting the pointer
    170             // If context capture fails then nothing to cleanup
    171             Release();
    172             if (object != nullptr)
    173             {
    174                 ::Microsoft::WRL::ComPtr<IAgileObject> checkIfAgile;
    175                 HRESULT hr = reinterpret_cast<IUnknown*>(object)->QueryInterface(__uuidof(IAgileObject), &checkIfAgile);
    176                 // Don't Capture context if object is agile
    177                 if (hr != S_OK)
    178                 {
    179 #if _MSC_VER >= 1800
    180                     _agileState = AgileState::NonAgilePointer;
    181 #endif
    182                     CaptureContext();
    183                 }
    184 #if _MSC_VER >= 1800
    185                 else
    186                 {
    187                     _agileState = AgileState::AgilePointer;
    188                 }
    189 #endif
    190             }
    191             _object = object;
    192         }
    193 
    194     public:
    195         Agile() throw() : _object(nullptr), _contextToken(0)
    196 #if _MSC_VER >= 1800
    197             , _agileState(AgileState::Unknown)
    198 #endif
    199         {
    200         }
    201 
    202         Agile(nullptr_t) throw() : _object(nullptr), _contextToken(0)
    203 #if _MSC_VER >= 1800
    204             , _agileState(AgileState::Unknown)
    205 #endif
    206         {
    207         }
    208 
    209         explicit Agile(TypeT object) throw() : _object(nullptr), _contextToken(0)
    210 #if _MSC_VER >= 1800
    211             , _agileState(AgileState::Unknown)
    212 #endif
    213         {
    214             // Assumes that the source object is from the current context
    215             SetObject(object);
    216         }
    217 
    218         Agile(const Agile& object) throw() : _object(nullptr), _contextToken(0)
    219 #if _MSC_VER >= 1800
    220             , _agileState(AgileState::Unknown)
    221 #endif
    222         {
    223             // Get returns pointer valid for current context
    224             SetObject(object.Get());
    225         }
    226 
    227         Agile(Agile&& object) throw() : _object(nullptr), _contextToken(0)
    228 #if _MSC_VER >= 1800
    229             , _agileState(AgileState::Unknown)
    230 #endif
    231         {
    232             // Assumes that the source object is from the current context
    233             Swap(object);
    234         }
    235 
    236         ~Agile() throw()
    237         {
    238             Release();
    239         }
    240 
    241         TypeT Get() const
    242         {
    243             // Agile object, no proxy required
    244 #if _MSC_VER >= 1800
    245             if (_agileState == AgileState::AgilePointer || _object == nullptr)
    246 #else
    247             if (_contextToken == 0 || _contextCallback == nullptr || _object == nullptr)
    248 #endif
    249             {
    250                 return _object;
    251             }
    252 
    253             // Do the check for same context
    254             ULONG_PTR currentContextToken;
    255             __abi_ThrowIfFailed(CoGetContextToken(&currentContextToken));
    256             if (currentContextToken == _contextToken)
    257             {
    258                 return _object;
    259             }
    260 
    261 #if _MSC_VER >= 1800
    262             // Different context and holding on to a non agile object
    263             // Do the costly work of getting a proxy
    264             TypeT localObject;
    265             __abi_ThrowIfFailed(Details::GetProxy(_object, _contextCallback.Get(), &localObject));
    266 
    267             if (_agileState == AgileState::Unknown)
    268 #else
    269             // Object is agile if it implements IAgileObject
    270             // GetAddressOf captures the context with out knowing the type of object that it will hold
    271             if (_object != nullptr)
    272 #endif
    273             {
    274 #if _MSC_VER >= 1800
    275                 // Object is agile if it implements IAgileObject
    276                 // GetAddressOf captures the context with out knowing the type of object that it will hold
    277                 ::Microsoft::WRL::ComPtr<IAgileObject> checkIfAgile;
    278                 HRESULT hr = reinterpret_cast<IUnknown*>(localObject)->QueryInterface(__uuidof(IAgileObject), &checkIfAgile);
    279 #else
    280                 ::Microsoft::WRL::ComPtr<IAgileObject> checkIfAgile;
    281                 HRESULT hr = reinterpret_cast<IUnknown*>(_object)->QueryInterface(__uuidof(IAgileObject), &checkIfAgile);
    282 #endif
    283                 if (hr == S_OK)
    284                 {
    285                     auto pThis = const_cast<Agile*>(this);
    286 #if _MSC_VER >= 1800
    287                     pThis->_agileState = AgileState::AgilePointer;
    288 #endif
    289                     pThis->_contextToken = 0;
    290                     pThis->_contextCallback = nullptr;
    291                     return _object;
    292                 }
    293 #if _MSC_VER >= 1800
    294                 else
    295                 {
    296                     auto pThis = const_cast<Agile*>(this);
    297                     pThis->_agileState = AgileState::NonAgilePointer;
    298                 }
    299 #endif
    300             }
    301 
    302 #if _MSC_VER < 1800
    303             // Different context and holding on to a non agile object
    304             // Do the costly work of getting a proxy
    305             TypeT localObject;
    306             __abi_ThrowIfFailed(Details::GetProxy(_object, _contextCallback.Get(), &localObject));
    307 #endif
    308             return localObject;
    309         }
    310 
    311         TypeT* GetAddressOf() throw()
    312         {
    313             Release();
    314             CaptureContext();
    315             return &_object;
    316         }
    317 
    318         TypeT* GetAddressOfForInOut() throw()
    319         {
    320             CaptureContext();
    321             return &_object;
    322         }
    323 
    324         TypeT operator->() const throw()
    325         {
    326             return Get();
    327         }
    328 
    329         Agile& operator=(nullptr_t) throw()
    330         {
    331             Release();
    332             return *this;
    333         }
    334 
    335         Agile& operator=(TypeT object) throw()
    336         {
    337             Agile(object).Swap(*this);
    338             return *this;
    339         }
    340 
    341         Agile& operator=(Agile object) throw()
    342         {
    343             // parameter is by copy which gets pointer valid for current context
    344             object.Swap(*this);
    345             return *this;
    346         }
    347 
    348 #if _MSC_VER < 1800
    349         Agile& operator=(IUnknown* lp) throw()
    350         {
    351             // bump ref count
    352             ::Microsoft::WRL::ComPtr<IUnknown> spObject(lp);
    353 
    354             // put it into Platform Object
    355             Platform::Object object;
    356             *(IUnknown**)(&object) = spObject.Detach();
    357 
    358             SetObject(object);
    359             return *this;
    360         }
    361 #endif
    362 
    363         void Swap(Agile& object)
    364         {
    365             std::swap(_object, object._object);
    366             std::swap(_contextCallback, object._contextCallback);
    367             std::swap(_contextToken, object._contextToken);
    368 #if _MSC_VER >= 1800
    369             std::swap(_agileState, object._agileState);
    370 #endif
    371         }
    372 
    373         // Release the interface and set to NULL
    374         void Release() throw()
    375         {
    376             if (_object)
    377             {
    378                 // Cast to IInspectable (no QI)
    379                 IUnknown* pObject = *(IUnknown**)(&_object);
    380                 // Set * to null without release
    381                 *(IUnknown**)(&_object) = nullptr;
    382 
    383                 ULONG_PTR currentContextToken;
    384                 __abi_ThrowIfFailed(CoGetContextToken(&currentContextToken));
    385                 if (_contextToken == 0 || _contextCallback == nullptr || _contextToken == currentContextToken)
    386                 {
    387                     pObject->Release();
    388                 }
    389                 else
    390                 {
    391                     Details::ReleaseInContext(pObject, _contextCallback.Get());
    392                 }
    393                 _contextCallback = nullptr;
    394                 _contextToken = 0;
    395 #if _MSC_VER >= 1800
    396                 _agileState = AgileState::Unknown;
    397 #endif
    398             }
    399         }
    400 
    401         bool operator==(nullptr_t) const throw()
    402         {
    403             return _object == nullptr;
    404         }
    405 
    406         bool operator==(const Agile& other) const throw()
    407         {
    408             return _object == other._object && _contextToken == other._contextToken;
    409         }
    410 
    411         bool operator<(const Agile& other) const throw()
    412         {
    413             if (reinterpret_cast<void*>(_object) < reinterpret_cast<void*>(other._object))
    414             {
    415                 return true;
    416             }
    417 
    418             return _object == other._object && _contextToken < other._contextToken;
    419         }
    420     };
    421 
    422     template <typename T>
    423     class Agile<T, false>
    424     {
    425         static_assert(__is_win_class(typename Details::AgileTypeHelper<T>::type) || __is_win_interface(typename Details::AgileTypeHelper<T>::type), "Agile can only be used with ref class or interface class types");
    426         typename typedef Details::AgileTypeHelper<T>::agileMemberType TypeT;
    427         TypeT _object;
    428 
    429     public:
    430         Agile() throw() : _object(nullptr)
    431         {
    432         }
    433 
    434         Agile(nullptr_t) throw() : _object(nullptr)
    435         {
    436         }
    437 
    438         explicit Agile(TypeT object) throw() : _object(object)
    439         {
    440         }
    441 
    442         Agile(const Agile& object) throw() : _object(object._object)
    443         {
    444         }
    445 
    446         Agile(Agile&& object) throw() : _object(nullptr)
    447         {
    448             Swap(object);
    449         }
    450 
    451         ~Agile() throw()
    452         {
    453             Release();
    454         }
    455 
    456         TypeT Get() const
    457         {
    458             return _object;
    459         }
    460 
    461         TypeT* GetAddressOf() throw()
    462         {
    463             Release();
    464             return &_object;
    465         }
    466 
    467         TypeT* GetAddressOfForInOut() throw()
    468         {
    469             return &_object;
    470         }
    471 
    472         TypeT operator->() const throw()
    473         {
    474             return Get();
    475         }
    476 
    477         Agile& operator=(nullptr_t) throw()
    478         {
    479             Release();
    480             return *this;
    481         }
    482 
    483         Agile& operator=(TypeT object) throw()
    484         {
    485             if (_object != object)
    486             {
    487                 _object = object;
    488             }
    489             return *this;
    490         }
    491 
    492         Agile& operator=(Agile object) throw()
    493         {
    494             object.Swap(*this);
    495             return *this;
    496         }
    497 
    498 #if _MSC_VER < 1800
    499         Agile& operator=(IUnknown* lp) throw()
    500         {
    501             Release();
    502             // bump ref count
    503             ::Microsoft::WRL::ComPtr<IUnknown> spObject(lp);
    504 
    505             // put it into Platform Object
    506             Platform::Object object;
    507             *(IUnknown**)(&object) = spObject.Detach();
    508 
    509             _object = object;
    510             return *this;
    511         }
    512 #endif
    513 
    514         // Release the interface and set to NULL
    515         void Release() throw()
    516         {
    517             _object = nullptr;
    518         }
    519 
    520         void Swap(Agile& object)
    521         {
    522             std::swap(_object, object._object);
    523         }
    524 
    525         bool operator==(nullptr_t) const throw()
    526         {
    527             return _object == nullptr;
    528         }
    529 
    530         bool operator==(const Agile& other) const throw()
    531         {
    532             return _object == other._object;
    533         }
    534 
    535         bool operator<(const Agile& other) const throw()
    536         {
    537             return reinterpret_cast<void*>(_object) < reinterpret_cast<void*>(other._object);
    538         }
    539     };
    540 
    541 #pragma warning(pop)
    542 
    543     template<class U>
    544     bool operator==(nullptr_t, const Agile<U>& a) throw()
    545     {
    546         return a == nullptr;
    547     }
    548 
    549     template<class U>
    550     bool operator!=(const Agile<U>& a, nullptr_t) throw()
    551     {
    552         return !(a == nullptr);
    553     }
    554 
    555     template<class U>
    556     bool operator!=(nullptr_t, const Agile<U>& a) throw()
    557     {
    558         return !(a == nullptr);
    559     }
    560 
    561     template<class U>
    562     bool operator!=(const Agile<U>& a, const Agile<U>& b) throw()
    563     {
    564         return !(a == b);
    565     }
    566 
    567 
    568 #endif // _PLATFORM_AGILE_H_
    569