Home | History | Annotate | Download | only in win
      1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #ifndef BASE_WIN_SCOPED_COMPTR_H_
      6 #define BASE_WIN_SCOPED_COMPTR_H_
      7 
      8 #include <unknwn.h>
      9 
     10 #include "base/logging.h"
     11 #include "base/memory/ref_counted.h"
     12 
     13 namespace base {
     14 namespace win {
     15 
     16 // A fairly minimalistic smart class for COM interface pointers.
     17 // Uses scoped_refptr for the basic smart pointer functionality
     18 // and adds a few IUnknown specific services.
     19 template <class Interface, const IID* interface_id = &__uuidof(Interface)>
     20 class ScopedComPtr : public scoped_refptr<Interface> {
     21  public:
     22   // Utility template to prevent users of ScopedComPtr from calling AddRef
     23   // and/or Release() without going through the ScopedComPtr class.
     24   class BlockIUnknownMethods : public Interface {
     25    private:
     26     STDMETHOD(QueryInterface)(REFIID iid, void** object) = 0;
     27     STDMETHOD_(ULONG, AddRef)() = 0;
     28     STDMETHOD_(ULONG, Release)() = 0;
     29   };
     30 
     31   typedef scoped_refptr<Interface> ParentClass;
     32 
     33   ScopedComPtr() {
     34   }
     35 
     36   explicit ScopedComPtr(Interface* p) : ParentClass(p) {
     37   }
     38 
     39   ScopedComPtr(const ScopedComPtr<Interface, interface_id>& p)
     40       : ParentClass(p) {
     41   }
     42 
     43   ~ScopedComPtr() {
     44     // We don't want the smart pointer class to be bigger than the pointer
     45     // it wraps.
     46     COMPILE_ASSERT(sizeof(ScopedComPtr<Interface, interface_id>) ==
     47                    sizeof(Interface*), ScopedComPtrSize);
     48   }
     49 
     50   // Explicit Release() of the held object.  Useful for reuse of the
     51   // ScopedComPtr instance.
     52   // Note that this function equates to IUnknown::Release and should not
     53   // be confused with e.g. scoped_ptr::release().
     54   void Release() {
     55     if (ptr_ != NULL) {
     56       ptr_->Release();
     57       ptr_ = NULL;
     58     }
     59   }
     60 
     61   // Sets the internal pointer to NULL and returns the held object without
     62   // releasing the reference.
     63   Interface* Detach() {
     64     Interface* p = ptr_;
     65     ptr_ = NULL;
     66     return p;
     67   }
     68 
     69   // Accepts an interface pointer that has already been addref-ed.
     70   void Attach(Interface* p) {
     71     DCHECK(!ptr_);
     72     ptr_ = p;
     73   }
     74 
     75   // Retrieves the pointer address.
     76   // Used to receive object pointers as out arguments (and take ownership).
     77   // The function DCHECKs on the current value being NULL.
     78   // Usage: Foo(p.Receive());
     79   Interface** Receive() {
     80     DCHECK(!ptr_) << "Object leak. Pointer must be NULL";
     81     return &ptr_;
     82   }
     83 
     84   // A convenience for whenever a void pointer is needed as an out argument.
     85   void** ReceiveVoid() {
     86     return reinterpret_cast<void**>(Receive());
     87   }
     88 
     89   template <class Query>
     90   HRESULT QueryInterface(Query** p) {
     91     DCHECK(p != NULL);
     92     DCHECK(ptr_ != NULL);
     93     // IUnknown already has a template version of QueryInterface
     94     // so the iid parameter is implicit here. The only thing this
     95     // function adds are the DCHECKs.
     96     return ptr_->QueryInterface(p);
     97   }
     98 
     99   // QI for times when the IID is not associated with the type.
    100   HRESULT QueryInterface(const IID& iid, void** obj) {
    101     DCHECK(obj != NULL);
    102     DCHECK(ptr_ != NULL);
    103     return ptr_->QueryInterface(iid, obj);
    104   }
    105 
    106   // Queries |other| for the interface this object wraps and returns the
    107   // error code from the other->QueryInterface operation.
    108   HRESULT QueryFrom(IUnknown* object) {
    109     DCHECK(object != NULL);
    110     return object->QueryInterface(Receive());
    111   }
    112 
    113   // Convenience wrapper around CoCreateInstance
    114   HRESULT CreateInstance(const CLSID& clsid, IUnknown* outer = NULL,
    115                          DWORD context = CLSCTX_ALL) {
    116     DCHECK(!ptr_);
    117     HRESULT hr = ::CoCreateInstance(clsid, outer, context, *interface_id,
    118                                     reinterpret_cast<void**>(&ptr_));
    119     return hr;
    120   }
    121 
    122   // Checks if the identity of |other| and this object is the same.
    123   bool IsSameObject(IUnknown* other) {
    124     if (!other && !ptr_)
    125       return true;
    126 
    127     if (!other || !ptr_)
    128       return false;
    129 
    130     ScopedComPtr<IUnknown> my_identity;
    131     QueryInterface(my_identity.Receive());
    132 
    133     ScopedComPtr<IUnknown> other_identity;
    134     other->QueryInterface(other_identity.Receive());
    135 
    136     return static_cast<IUnknown*>(my_identity) ==
    137            static_cast<IUnknown*>(other_identity);
    138   }
    139 
    140   // Provides direct access to the interface.
    141   // Here we use a well known trick to make sure we block access to
    142   // IUnknown methods so that something bad like this doesn't happen:
    143   //    ScopedComPtr<IUnknown> p(Foo());
    144   //    p->Release();
    145   //    ... later the destructor runs, which will Release() again.
    146   // and to get the benefit of the DCHECKs we add to QueryInterface.
    147   // There's still a way to call these methods if you absolutely must
    148   // by statically casting the ScopedComPtr instance to the wrapped interface
    149   // and then making the call... but generally that shouldn't be necessary.
    150   BlockIUnknownMethods* operator->() const {
    151     DCHECK(ptr_ != NULL);
    152     return reinterpret_cast<BlockIUnknownMethods*>(ptr_);
    153   }
    154 
    155   // Pull in operator=() from the parent class.
    156   using scoped_refptr<Interface>::operator=;
    157 
    158   // static methods
    159 
    160   static const IID& iid() {
    161     return *interface_id;
    162   }
    163 };
    164 
    165 }  // namespace win
    166 }  // namespace base
    167 
    168 #endif  // BASE_WIN_SCOPED_COMPTR_H_
    169