Home | History | Annotate | Download | only in win
      1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #ifndef BASE_WIN_SCOPED_COMPTR_H_
      6 #define BASE_WIN_SCOPED_COMPTR_H_
      7 #pragma once
      8 
      9 #include <unknwn.h>
     10 
     11 #include "base/logging.h"
     12 #include "base/memory/ref_counted.h"
     13 
     14 namespace base {
     15 namespace win {
     16 
     17 // A fairly minimalistic smart class for COM interface pointers.
     18 // Uses scoped_refptr for the basic smart pointer functionality
     19 // and adds a few IUnknown specific services.
     20 template <class Interface, const IID* interface_id = &__uuidof(Interface)>
     21 class ScopedComPtr : public scoped_refptr<Interface> {
     22  public:
     23   // Utility template to prevent users of ScopedComPtr from calling AddRef
     24   // and/or Release() without going through the ScopedComPtr class.
     25   class BlockIUnknownMethods : public Interface {
     26    private:
     27     STDMETHOD(QueryInterface)(REFIID iid, void** object) = 0;
     28     STDMETHOD_(ULONG, AddRef)() = 0;
     29     STDMETHOD_(ULONG, Release)() = 0;
     30   };
     31 
     32   typedef scoped_refptr<Interface> ParentClass;
     33 
     34   ScopedComPtr() {
     35   }
     36 
     37   explicit ScopedComPtr(Interface* p) : ParentClass(p) {
     38   }
     39 
     40   ScopedComPtr(const ScopedComPtr<Interface, interface_id>& p)
     41       : ParentClass(p) {
     42   }
     43 
     44   ~ScopedComPtr() {
     45     // We don't want the smart pointer class to be bigger than the pointer
     46     // it wraps.
     47     COMPILE_ASSERT(sizeof(ScopedComPtr<Interface, interface_id>) ==
     48                    sizeof(Interface*), ScopedComPtrSize);
     49   }
     50 
     51   // Explicit Release() of the held object.  Useful for reuse of the
     52   // ScopedComPtr instance.
     53   // Note that this function equates to IUnknown::Release and should not
     54   // be confused with e.g. scoped_ptr::release().
     55   void Release() {
     56     if (ptr_ != NULL) {
     57       ptr_->Release();
     58       ptr_ = NULL;
     59     }
     60   }
     61 
     62   // Sets the internal pointer to NULL and returns the held object without
     63   // releasing the reference.
     64   Interface* Detach() {
     65     Interface* p = ptr_;
     66     ptr_ = NULL;
     67     return p;
     68   }
     69 
     70   // Accepts an interface pointer that has already been addref-ed.
     71   void Attach(Interface* p) {
     72     DCHECK(!ptr_);
     73     ptr_ = p;
     74   }
     75 
     76   // Retrieves the pointer address.
     77   // Used to receive object pointers as out arguments (and take ownership).
     78   // The function DCHECKs on the current value being NULL.
     79   // Usage: Foo(p.Receive());
     80   Interface** Receive() {
     81     DCHECK(!ptr_) << "Object leak. Pointer must be NULL";
     82     return &ptr_;
     83   }
     84 
     85   // A convenience for whenever a void pointer is needed as an out argument.
     86   void** ReceiveVoid() {
     87     return reinterpret_cast<void**>(Receive());
     88   }
     89 
     90   template <class Query>
     91   HRESULT QueryInterface(Query** p) {
     92     DCHECK(p != NULL);
     93     DCHECK(ptr_ != NULL);
     94     // IUnknown already has a template version of QueryInterface
     95     // so the iid parameter is implicit here. The only thing this
     96     // function adds are the DCHECKs.
     97     return ptr_->QueryInterface(p);
     98   }
     99 
    100   // QI for times when the IID is not associated with the type.
    101   HRESULT QueryInterface(const IID& iid, void** obj) {
    102     DCHECK(obj != NULL);
    103     DCHECK(ptr_ != NULL);
    104     return ptr_->QueryInterface(iid, obj);
    105   }
    106 
    107   // Queries |other| for the interface this object wraps and returns the
    108   // error code from the other->QueryInterface operation.
    109   HRESULT QueryFrom(IUnknown* object) {
    110     DCHECK(object != NULL);
    111     return object->QueryInterface(Receive());
    112   }
    113 
    114   // Convenience wrapper around CoCreateInstance
    115   HRESULT CreateInstance(const CLSID& clsid, IUnknown* outer = NULL,
    116                          DWORD context = CLSCTX_ALL) {
    117     DCHECK(!ptr_);
    118     HRESULT hr = ::CoCreateInstance(clsid, outer, context, *interface_id,
    119                                     reinterpret_cast<void**>(&ptr_));
    120     return hr;
    121   }
    122 
    123   // Checks if the identity of |other| and this object is the same.
    124   bool IsSameObject(IUnknown* other) {
    125     if (!other && !ptr_)
    126       return true;
    127 
    128     if (!other || !ptr_)
    129       return false;
    130 
    131     ScopedComPtr<IUnknown> my_identity;
    132     QueryInterface(my_identity.Receive());
    133 
    134     ScopedComPtr<IUnknown> other_identity;
    135     other->QueryInterface(other_identity.Receive());
    136 
    137     return static_cast<IUnknown*>(my_identity) ==
    138            static_cast<IUnknown*>(other_identity);
    139   }
    140 
    141   // Provides direct access to the interface.
    142   // Here we use a well known trick to make sure we block access to
    143   // IUknown methods so that something bad like this doesn't happen:
    144   //    ScopedComPtr<IUnknown> p(Foo());
    145   //    p->Release();
    146   //    ... later the destructor runs, which will Release() again.
    147   // and to get the benefit of the DCHECKs we add to QueryInterface.
    148   // There's still a way to call these methods if you absolutely must
    149   // by statically casting the ScopedComPtr instance to the wrapped interface
    150   // and then making the call... but generally that shouldn't be necessary.
    151   BlockIUnknownMethods* operator->() const {
    152     DCHECK(ptr_ != NULL);
    153     return reinterpret_cast<BlockIUnknownMethods*>(ptr_);
    154   }
    155 
    156   // Pull in operator=() from the parent class.
    157   using scoped_refptr<Interface>::operator=;
    158 
    159   // static methods
    160 
    161   static const IID& iid() {
    162     return *interface_id;
    163   }
    164 };
    165 
    166 }  // namespace win
    167 }  // namespace base
    168 
    169 #endif  // BASE_WIN_SCOPED_COMPTR_H_
    170