source: nscp/include/socket/connection.hpp @ 9bd40e2

0.4.10.4.2
Last change on this file since 9bd40e2 was 9bd40e2, checked in by Michael Medin <michael@…>, 13 months ago
  • Refactored server internals to be more uniform (This is the first step to adding more protocols like NRDP and unit tests for check_nt)
  • Property mode set to 100644
File size: 7.7 KB
Line 
1#pragma once
2
3#include <boost/asio.hpp>
4#include <boost/array.hpp>
5#include <boost/noncopyable.hpp>
6#include <boost/shared_ptr.hpp>
7#include <boost/enable_shared_from_this.hpp>
8#ifdef USE_SSL
9#include <boost/asio/ssl/context.hpp>
10#endif
11#include "handler.hpp"
12#include "parser.hpp"
13
14namespace socket_helpers {
15        namespace server {
16
17                using boost::asio::ip::tcp;
18                static const bool debug_trace = true;
19
20                //
21                // The socket statemachine:
22                // This is the idea not how it currently looks
23                // SOCKET  | SSL-SOCKET | PROTOCOL   | RETURN
24                //         | connect    |            |
25                // connect | handhake   | on_connect | true = allow, false = disallow
26                // ...
27                //         |            | wants_data | true = yes, read chunk, false = no, start sending
28                // recv    | recv       | on_read    | true = read more, false = done reading
29                //         |            | has_data   | true = yes, send chunk, false = no, stop sending
30                // recv    | recv       | on_read    | true = read more, false = done reading
31                // ...
32                //         |            | is_done    | true = is done, disconnect, false (read/write loop)
33
34
35                template<class protocol_type, std::size_t N>
36                class connection : public boost::enable_shared_from_this<connection<protocol_type, N> >, private boost::noncopyable {
37                public:
38                        connection(boost::asio::io_service& io_service, boost::shared_ptr<protocol_type> protocol)
39                                : strand_(io_service)
40                                , timer_(io_service)
41                                , socket_(io_service)
42                                , protocol_(protocol)
43                        {
44                        }
45                        virtual ~connection() {
46                        }
47
48                        inline void trace(std::string msg) const {
49                                if (debug_trace)
50                                        protocol_->log_debug(__FILE__, __LINE__, msg);
51                        }
52
53                        virtual boost::asio::ip::tcp::socket& socket() {
54                                return socket_;
55                        }
56
57                        //////////////////////////////////////////////////////////////////////////
58                        // High level connection start/stop
59                        virtual void start() {
60                                trace("start()");
61                                if (protocol_->on_connect()) {
62                                        set_timeout(protocol_->get_info().timeout);
63                                        do_process();
64                                } else {
65                                        stop();
66                                }
67                        }
68
69                        virtual void stop() {
70                                trace("stop()");
71                                cancel_timer();
72                                boost::system::error_code ignored_ec;
73                                socket().shutdown(boost::asio::ip::tcp::socket::shutdown_both, ignored_ec);
74                        }
75
76                        //////////////////////////////////////////////////////////////////////////
77                        // Timeout related functions
78                        virtual void set_timeout(int seconds) {
79                                timer_.expires_from_now(boost::posix_time::seconds(seconds));
80                                timer_.async_wait(boost::bind(&connection::timeout, shared_from_this(), boost::asio::placeholders::error)); 
81                        }
82
83                        virtual void cancel_timer() {
84                                timer_.cancel();
85                        }
86
87                        virtual void timeout(const boost::system::error_code& e) {
88                                if (e != boost::asio::error::operation_aborted) {
89                                        trace("timeout()");
90                                        boost::system::error_code ignored_ec;
91                                        socket().shutdown(boost::asio::ip::tcp::socket::shutdown_both, ignored_ec);
92                                }
93                        }
94
95                        //////////////////////////////////////////////////////////////////////////
96                        // Socket state machine (assumed all sockets are simple connect-read-write-disconnect
97                        virtual void start_read_request() {
98                                trace("start_read_request()");
99                                socket_.async_read_some(
100                                        boost::asio::buffer(buffer_), strand_.wrap(
101                                        boost::bind(&connection::handle_read_request, shared_from_this(), boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)
102                                        ));
103                        }
104
105                        void do_process() {
106                                if (protocol_->wants_data()) {
107                                        start_read_request();
108                                } else if (protocol_->has_data()) {
109                                        //std::vector<boost::asio::const_buffer> buffers;
110                                        //buffers.push_back();
111                                        start_write_request(buf(protocol_->get_outbound()));
112                                } else {
113                                        stop();
114                                }
115                        }
116
117                        virtual void handle_read_request(const boost::system::error_code& e, std::size_t bytes_transferred) {
118                                trace("handle_read_request(" + strEx::s::itos((int)bytes_transferred) + ")");
119                                if (!e) {
120                                        if (protocol_->on_read(buffer_.begin(), buffer_.begin() + bytes_transferred)) {
121                                                do_process();
122                                        } else {
123                                                stop();
124                                        }
125                                } else {
126                                        protocol_->log_error(__FILE__, __LINE__, "Failed to read data");
127                                }
128                        }
129
130                        virtual void start_write_request(const boost::asio::const_buffer& response) {
131                                std::size_t s1 = boost::asio::buffer_size(response);
132                                trace("start_write_request(" + strEx::s::itos((int)s1) + ")");
133                                boost::asio::async_write(socket_, boost::asio::const_buffers_1(response), strand_.wrap(
134                                        boost::bind(&connection::handle_write_response, shared_from_this(),boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)
135                                        ));
136                        }
137
138                        virtual void handle_write_response(const boost::system::error_code& e, std::size_t bytes_transferred) {
139                                trace("handle_write_response(" + strEx::s::itos((int)bytes_transferred) + ")");
140                                if (!e) {
141                                        protocol_->on_write();
142                                        do_process();
143                                } else {
144                                        protocol_->log_error(__FILE__, __LINE__, "Failed to send data");
145                                }
146                        }
147
148                protected:
149                        //////////////////////////////////////////////////////////////////////////
150                        // Internal functions and data
151
152                        boost::asio::const_buffer buf(const typename protocol_type::outbound_buffer_type s) {
153                                buffers_.push_back(s);
154                                return boost::asio::buffer(buffers_.back());
155                        }
156
157
158                        boost::asio::io_service::strand strand_;
159                        boost::array<char, N> buffer_;
160                        boost::asio::deadline_timer timer_;
161                        std::list<typename protocol_type::outbound_buffer_type> buffers_;
162                        boost::asio::ip::tcp::socket socket_;
163                        boost::shared_ptr<protocol_type> protocol_;
164                        std::string module_;
165                };
166
167#ifdef USE_SSL
168                template<class protocol_type, std::size_t N>
169                class ssl_connection : private boost::noncopyable, public connection<protocol_type, N> {
170                        typedef connection<protocol_type, N> parent_type;
171                        typedef ssl_connection<protocol_type, N> my_type;
172                public:
173                        ssl_connection(boost::asio::io_service& io_service, boost::asio::ssl::context &context, boost::shared_ptr<protocol_type> protocol)
174                                : connection<protocol_type, N>(io_service, protocol)
175                                , ssl_socket_(io_service, context)
176                        {
177                        }
178                        virtual ~ssl_connection() {
179                        }
180
181
182                        virtual boost::asio::ip::tcp::socket& socket() {
183                                return ssl_socket_.next_layer();
184                        }
185
186
187                        virtual void start() {
188                                trace("ssl::start_read_request()");
189                                boost::shared_ptr<my_type> self = boost::shared_dynamic_cast<my_type>(shared_from_this());
190                                //boost::shared_ptr<ssl_connection<protocol_type, N>> self = static_cast<boost::shared_ptr<ssl_connection<protocol_type, N> > >(shared_from_this());
191                                ssl_socket_.async_handshake(boost::asio::ssl::stream_base::server,strand_.wrap(
192                                        boost::bind(&ssl_connection::handle_handshake, self, boost::asio::placeholders::error)
193                                        ));
194                        }
195                       
196                        virtual void handle_handshake(const boost::system::error_code& error) {
197                                if (!error)
198                                        parent_type::start();
199                                else {
200                                        protocol_->log_error(__FILE__, __LINE__, "Failed to establish secure connection: " + error.message());
201                                }
202                        }
203
204                        virtual void start_read_request() {
205                                ssl_socket_.async_read_some(
206                                        boost::asio::buffer(buffer_),
207                                        strand_.wrap(
208                                        boost::bind(&connection::handle_read_request, shared_from_this(), boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)
209                                        )
210                                        );
211                        }
212
213                        virtual void start_write_request(const std::vector<boost::asio::const_buffer>& response) {
214                                trace("ssl::start_write_request(" + strEx::s::itos((int)response.size()) + ")");
215                                boost::asio::async_write(ssl_socket_, response,
216                                        strand_.wrap(
217                                        boost::bind(&connection::handle_write_response, shared_from_this(), boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)
218                                        )
219                                        );
220                        }
221
222                protected:
223                        typedef boost::asio::ssl::stream<boost::asio::ip::tcp::socket> ssl_socket;
224                        ssl_socket ssl_socket_;
225                };
226#endif
227        } // namespace server
228} // namespace socket_helpers
Note: See TracBrowser for help on using the repository browser.