Home | History | Annotate | Download | only in bindings
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #ifndef MOJO_PUBLIC_CPP_BINDINGS_BINDING_SET_H_
      6 #define MOJO_PUBLIC_CPP_BINDINGS_BINDING_SET_H_
      7 
      8 #include <string>
      9 #include <utility>
     10 
     11 #include "base/bind.h"
     12 #include "base/callback.h"
     13 #include "base/macros.h"
     14 #include "base/memory/ptr_util.h"
     15 #include "mojo/public/cpp/bindings/binding.h"
     16 #include "mojo/public/cpp/bindings/connection_error_callback.h"
     17 #include "mojo/public/cpp/bindings/interface_ptr.h"
     18 #include "mojo/public/cpp/bindings/interface_request.h"
     19 #include "mojo/public/cpp/bindings/message.h"
     20 
     21 namespace mojo {
     22 
     23 template <typename BindingType>
     24 struct BindingSetTraits;
     25 
     26 template <typename Interface, typename ImplRefTraits>
     27 struct BindingSetTraits<Binding<Interface, ImplRefTraits>> {
     28   using ProxyType = InterfacePtr<Interface>;
     29   using RequestType = InterfaceRequest<Interface>;
     30   using BindingType = Binding<Interface, ImplRefTraits>;
     31   using ImplPointerType = typename BindingType::ImplPointerType;
     32 
     33   static RequestType MakeRequest(ProxyType* proxy) {
     34     return mojo::MakeRequest(proxy);
     35   }
     36 };
     37 
     38 using BindingId = size_t;
     39 
     40 template <typename ContextType>
     41 struct BindingSetContextTraits {
     42   using Type = ContextType;
     43 
     44   static constexpr bool SupportsContext() { return true; }
     45 };
     46 
     47 template <>
     48 struct BindingSetContextTraits<void> {
     49   // NOTE: This choice of Type only matters insofar as it affects the size of
     50   // the |context_| field of a BindingSetBase::Entry with void context. The
     51   // context value is never used in this case.
     52   using Type = bool;
     53 
     54   static constexpr bool SupportsContext() { return false; }
     55 };
     56 
     57 // Generic definition used for BindingSet and AssociatedBindingSet to own a
     58 // collection of bindings which point to the same implementation.
     59 //
     60 // If |ContextType| is non-void, then every added binding must include a context
     61 // value of that type, and |dispatch_context()| will return that value during
     62 // the extent of any message dispatch targeting that specific binding.
     63 template <typename Interface, typename BindingType, typename ContextType>
     64 class BindingSetBase {
     65  public:
     66   using ContextTraits = BindingSetContextTraits<ContextType>;
     67   using Context = typename ContextTraits::Type;
     68   using PreDispatchCallback = base::Callback<void(const Context&)>;
     69   using Traits = BindingSetTraits<BindingType>;
     70   using ProxyType = typename Traits::ProxyType;
     71   using RequestType = typename Traits::RequestType;
     72   using ImplPointerType = typename Traits::ImplPointerType;
     73 
     74   BindingSetBase() : weak_ptr_factory_(this) {}
     75 
     76   void set_connection_error_handler(base::RepeatingClosure error_handler) {
     77     error_handler_ = std::move(error_handler);
     78     error_with_reason_handler_.Reset();
     79   }
     80 
     81   void set_connection_error_with_reason_handler(
     82       RepeatingConnectionErrorWithReasonCallback error_handler) {
     83     error_with_reason_handler_ = std::move(error_handler);
     84     error_handler_.Reset();
     85   }
     86 
     87   // Sets a callback to be invoked immediately before dispatching any message or
     88   // error received by any of the bindings in the set. This may only be used
     89   // with a non-void |ContextType|.
     90   void set_pre_dispatch_handler(const PreDispatchCallback& handler) {
     91     static_assert(ContextTraits::SupportsContext(),
     92                   "Pre-dispatch handler usage requires non-void context type.");
     93     pre_dispatch_handler_ = handler;
     94   }
     95 
     96   // Adds a new binding to the set which binds |request| to |impl| with no
     97   // additional context.
     98   BindingId AddBinding(ImplPointerType impl, RequestType request) {
     99     static_assert(!ContextTraits::SupportsContext(),
    100                   "Context value required for non-void context type.");
    101     return AddBindingImpl(std::move(impl), std::move(request), false);
    102   }
    103 
    104   // Adds a new binding associated with |context|.
    105   BindingId AddBinding(ImplPointerType impl,
    106                        RequestType request,
    107                        Context context) {
    108     static_assert(ContextTraits::SupportsContext(),
    109                   "Context value unsupported for void context type.");
    110     return AddBindingImpl(std::move(impl), std::move(request),
    111                           std::move(context));
    112   }
    113 
    114   // Removes a binding from the set. Note that this is safe to call even if the
    115   // binding corresponding to |id| has already been removed.
    116   //
    117   // Returns |true| if the binding was removed and |false| if it didn't exist.
    118   bool RemoveBinding(BindingId id) {
    119     auto it = bindings_.find(id);
    120     if (it == bindings_.end())
    121       return false;
    122     bindings_.erase(it);
    123     return true;
    124   }
    125 
    126   // Swaps the interface implementation with a different one, to allow tests
    127   // to modify behavior.
    128   //
    129   // Returns the existing interface implementation to the caller.
    130   ImplPointerType SwapImplForTesting(BindingId id, ImplPointerType new_impl) {
    131     auto it = bindings_.find(id);
    132     if (it == bindings_.end())
    133       return nullptr;
    134 
    135     return it->second->SwapImplForTesting(new_impl);
    136   }
    137 
    138   void CloseAllBindings() { bindings_.clear(); }
    139 
    140   bool empty() const { return bindings_.empty(); }
    141 
    142   size_t size() const { return bindings_.size(); }
    143 
    144   // Implementations may call this when processing a dispatched message or
    145   // error. During the extent of message or error dispatch, this will return the
    146   // context associated with the specific binding which received the message or
    147   // error. Use AddBinding() to associated a context with a specific binding.
    148   const Context& dispatch_context() const {
    149     static_assert(ContextTraits::SupportsContext(),
    150                   "dispatch_context() requires non-void context type.");
    151     DCHECK(dispatch_context_);
    152     return *dispatch_context_;
    153   }
    154 
    155   // Implementations may call this when processing a dispatched message or
    156   // error. During the extent of message or error dispatch, this will return the
    157   // BindingId of the specific binding which received the message or error.
    158   BindingId dispatch_binding() const {
    159     DCHECK(dispatch_context_);
    160     return dispatch_binding_;
    161   }
    162 
    163   // Reports the currently dispatching Message as bad and closes the binding the
    164   // message was received from. Note that this is only legal to call from
    165   // directly within the stack frame of a message dispatch. If you need to do
    166   // asynchronous work before you can determine the legitimacy of a message, use
    167   // GetBadMessageCallback() and retain its result until you're ready to invoke
    168   // or discard it.
    169   void ReportBadMessage(const std::string& error) {
    170     GetBadMessageCallback().Run(error);
    171   }
    172 
    173   // Acquires a callback which may be run to report the currently dispatching
    174   // Message as bad and close the binding the message was received from. Note
    175   // that this is only legal to call from directly within the stack frame of a
    176   // message dispatch, but the returned callback may be called exactly once any
    177   // time thereafter as long as the binding set itself hasn't been destroyed yet
    178   // to report the message as bad. This may only be called once per message.
    179   // The returned callback must be called on the BindingSet's own sequence.
    180   ReportBadMessageCallback GetBadMessageCallback() {
    181     DCHECK(dispatch_context_);
    182     return base::BindOnce(
    183         [](ReportBadMessageCallback error_callback,
    184            base::WeakPtr<BindingSetBase> binding_set, BindingId binding_id,
    185            const std::string& error) {
    186           std::move(error_callback).Run(error);
    187           if (binding_set)
    188             binding_set->RemoveBinding(binding_id);
    189         },
    190         mojo::GetBadMessageCallback(), weak_ptr_factory_.GetWeakPtr(),
    191         dispatch_binding());
    192   }
    193 
    194   void FlushForTesting() {
    195     DCHECK(!is_flushing_);
    196     is_flushing_ = true;
    197     for (auto& binding : bindings_)
    198       if (binding.second)
    199         binding.second->FlushForTesting();
    200     is_flushing_ = false;
    201     // Clean up any bindings that were destroyed.
    202     for (auto it = bindings_.begin(); it != bindings_.end();) {
    203       if (!it->second)
    204         it = bindings_.erase(it);
    205       else
    206         ++it;
    207     }
    208   }
    209 
    210  private:
    211   friend class Entry;
    212 
    213   class Entry {
    214    public:
    215     Entry(ImplPointerType impl,
    216           RequestType request,
    217           BindingSetBase* binding_set,
    218           BindingId binding_id,
    219           Context context)
    220         : binding_(std::move(impl), std::move(request)),
    221           binding_set_(binding_set),
    222           binding_id_(binding_id),
    223           context_(std::move(context)) {
    224       binding_.AddFilter(std::make_unique<DispatchFilter>(this));
    225       binding_.set_connection_error_with_reason_handler(
    226           base::BindOnce(&Entry::OnConnectionError, base::Unretained(this)));
    227     }
    228 
    229     void FlushForTesting() { binding_.FlushForTesting(); }
    230 
    231     ImplPointerType SwapImplForTesting(ImplPointerType new_impl) {
    232       return binding_.SwapImplForTesting(new_impl);
    233     }
    234 
    235    private:
    236     class DispatchFilter : public MessageReceiver {
    237      public:
    238       explicit DispatchFilter(Entry* entry) : entry_(entry) {}
    239       ~DispatchFilter() override {}
    240 
    241      private:
    242       // MessageReceiver:
    243       bool Accept(Message* message) override {
    244         entry_->WillDispatch();
    245         return true;
    246       }
    247 
    248       Entry* entry_;
    249 
    250       DISALLOW_COPY_AND_ASSIGN(DispatchFilter);
    251     };
    252 
    253     void WillDispatch() {
    254       binding_set_->SetDispatchContext(&context_, binding_id_);
    255     }
    256 
    257     void OnConnectionError(uint32_t custom_reason,
    258                            const std::string& description) {
    259       WillDispatch();
    260       binding_set_->OnConnectionError(binding_id_, custom_reason, description);
    261     }
    262 
    263     BindingType binding_;
    264     BindingSetBase* const binding_set_;
    265     const BindingId binding_id_;
    266     Context const context_;
    267 
    268     DISALLOW_COPY_AND_ASSIGN(Entry);
    269   };
    270 
    271   void SetDispatchContext(const Context* context, BindingId binding_id) {
    272     dispatch_context_ = context;
    273     dispatch_binding_ = binding_id;
    274     if (!pre_dispatch_handler_.is_null())
    275       pre_dispatch_handler_.Run(*context);
    276   }
    277 
    278   BindingId AddBindingImpl(ImplPointerType impl,
    279                            RequestType request,
    280                            Context context) {
    281     BindingId id = next_binding_id_++;
    282     DCHECK_GE(next_binding_id_, 0u);
    283     auto entry = std::make_unique<Entry>(std::move(impl), std::move(request),
    284                                          this, id, std::move(context));
    285     bindings_.insert(std::make_pair(id, std::move(entry)));
    286     return id;
    287   }
    288 
    289   void OnConnectionError(BindingId id,
    290                          uint32_t custom_reason,
    291                          const std::string& description) {
    292     auto it = bindings_.find(id);
    293     DCHECK(it != bindings_.end());
    294 
    295     // We keep the Entry alive throughout error dispatch.
    296     std::unique_ptr<Entry> entry = std::move(it->second);
    297     if (!is_flushing_)
    298       bindings_.erase(it);
    299 
    300     if (error_handler_) {
    301       error_handler_.Run();
    302     } else if (error_with_reason_handler_) {
    303       error_with_reason_handler_.Run(custom_reason, description);
    304     }
    305   }
    306 
    307   base::RepeatingClosure error_handler_;
    308   RepeatingConnectionErrorWithReasonCallback error_with_reason_handler_;
    309   PreDispatchCallback pre_dispatch_handler_;
    310   BindingId next_binding_id_ = 0;
    311   std::map<BindingId, std::unique_ptr<Entry>> bindings_;
    312   bool is_flushing_ = false;
    313   const Context* dispatch_context_ = nullptr;
    314   BindingId dispatch_binding_;
    315   base::WeakPtrFactory<BindingSetBase> weak_ptr_factory_;
    316 
    317   DISALLOW_COPY_AND_ASSIGN(BindingSetBase);
    318 };
    319 
    320 template <typename Interface, typename ContextType = void>
    321 using BindingSet = BindingSetBase<Interface, Binding<Interface>, ContextType>;
    322 
    323 }  // namespace mojo
    324 
    325 #endif  // MOJO_PUBLIC_CPP_BINDINGS_BINDING_SET_H_
    326