1 // Copyright 2014 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "remoting/protocol/secure_channel_factory.h" 6 7 #include "base/bind.h" 8 #include "net/socket/stream_socket.h" 9 #include "remoting/protocol/authenticator.h" 10 #include "remoting/protocol/channel_authenticator.h" 11 12 namespace remoting { 13 namespace protocol { 14 15 SecureChannelFactory::SecureChannelFactory( 16 StreamChannelFactory* channel_factory, 17 Authenticator* authenticator) 18 : channel_factory_(channel_factory), 19 authenticator_(authenticator) { 20 DCHECK_EQ(authenticator_->state(), Authenticator::ACCEPTED); 21 } 22 23 SecureChannelFactory::~SecureChannelFactory() { 24 // CancelChannelCreation() is expected to be called before destruction. 25 DCHECK(channel_authenticators_.empty()); 26 } 27 28 void SecureChannelFactory::CreateChannel( 29 const std::string& name, 30 const ChannelCreatedCallback& callback) { 31 DCHECK(!callback.is_null()); 32 channel_factory_->CreateChannel( 33 name, 34 base::Bind(&SecureChannelFactory::OnBaseChannelCreated, 35 base::Unretained(this), name, callback)); 36 } 37 38 void SecureChannelFactory::CancelChannelCreation( 39 const std::string& name) { 40 AuthenticatorMap::iterator it = channel_authenticators_.find(name); 41 if (it == channel_authenticators_.end()) { 42 channel_factory_->CancelChannelCreation(name); 43 } else { 44 delete it->second; 45 channel_authenticators_.erase(it); 46 } 47 } 48 49 void SecureChannelFactory::OnBaseChannelCreated( 50 const std::string& name, 51 const ChannelCreatedCallback& callback, 52 scoped_ptr<net::StreamSocket> socket) { 53 if (!socket) { 54 callback.Run(scoped_ptr<net::StreamSocket>()); 55 return; 56 } 57 58 ChannelAuthenticator* channel_authenticator = 59 authenticator_->CreateChannelAuthenticator().release(); 60 channel_authenticators_[name] = channel_authenticator; 61 channel_authenticator->SecureAndAuthenticate( 62 socket.Pass(), 63 base::Bind(&SecureChannelFactory::OnSecureChannelCreated, 64 base::Unretained(this), name, callback)); 65 } 66 67 void SecureChannelFactory::OnSecureChannelCreated( 68 const std::string& name, 69 const ChannelCreatedCallback& callback, 70 int error, 71 scoped_ptr<net::StreamSocket> socket) { 72 DCHECK((socket && error == net::OK) || (!socket && error != net::OK)); 73 74 AuthenticatorMap::iterator it = channel_authenticators_.find(name); 75 DCHECK(it != channel_authenticators_.end()); 76 delete it->second; 77 channel_authenticators_.erase(it); 78 79 callback.Run(socket.Pass()); 80 } 81 82 } // namespace protocol 83 } // namespace remoting 84