source: nscp/include/socket/client.hpp @ af05fa1

0.4.10.4.2
Last change on this file since af05fa1 was af05fa1, checked in by Michael Medin <michael@…>, 12 months ago
  • Refactored (same as with server) client internals to be more uniform (Again, next up are NSClient Server and NRDP)
  • Property mode set to 100644
File size: 9.4 KB
Line 
1#pragma once
2
3#include <boost/shared_ptr.hpp>
4
5#include <socket/socket_helpers.hpp>
6#include <iostream>
7
8using boost::asio::ip::tcp;
9
10namespace socket_helpers {
11        namespace client {
12
13                static const bool debug_trace = true;
14
15                template<class protocol_type>
16                class connection : public boost::enable_shared_from_this<connection<protocol_type> >, private boost::noncopyable {
17                private:
18                        tcp::socket socket_;
19                        protocol_type protocol_;
20                        boost::asio::io_service &io_service_;
21                        boost::asio::deadline_timer timer_;
22                        boost::posix_time::time_duration timeout_;
23                        boost::optional<boost::system::error_code> timer_result_;
24                        boost::optional<bool> data_result_;
25                        boost::shared_ptr<typename protocol_type::client_handler> handler_;
26
27                public:
28                        connection(boost::asio::io_service &io_service, boost::posix_time::time_duration timeout, boost::shared_ptr<typename protocol_type::client_handler> handler)
29                                : io_service_(io_service)
30                                , socket_(io_service)
31                                , timer_(io_service)
32                                , timeout_(timeout)
33                                , handler_(handler)
34                                , protocol_(handler)
35                        {}
36
37                        virtual ~connection() {
38                                stop_timer();
39                                close();
40                        }
41
42                        typedef boost::asio::basic_socket<tcp,boost::asio::stream_socket_service<tcp> >  basic_socket_type;
43
44                        //////////////////////////////////////////////////////////////////////////
45                        // Time related functions
46                        //
47                        void start_timer() {
48                                timer_result_.reset();
49                                timer_.expires_from_now(timeout_);
50                                timer_.async_wait(boost::bind(&connection::on_timeout, shared_from_this(),  boost::asio::placeholders::error));
51                        }
52                        void stop_timer() {
53                                timer_.cancel();
54                        }
55                        virtual void on_timeout(boost::system::error_code ec) {
56                                trace("on_timeout(" + ec.message() + ")");
57                                if (!ec) {
58                                        timer_result_.reset(ec);
59                                }
60                        }
61
62                        //////////////////////////////////////////////////////////////////////////
63                        // External API functions
64                        //
65                        virtual void connect(std::string host, std::string port) {
66                                tcp::resolver resolver(io_service_);
67                                tcp::resolver::query query(host, port);
68
69                                tcp::resolver::iterator endpoint_iterator = resolver.resolve(query);
70                                tcp::resolver::iterator end;
71
72                                boost::system::error_code error = boost::asio::error::host_not_found;
73                                while (error && endpoint_iterator != end) {
74                                        tcp::resolver::endpoint_type ep = *endpoint_iterator;
75                                        get_socket().close();
76                                        get_socket().lowest_layer().connect(*endpoint_iterator++, error);
77                                }
78                                if (error)
79                                        throw boost::system::system_error(error);
80                                protocol_.on_connect();
81                        }
82
83                        virtual typename protocol_type::response_type process_request(typename protocol_type::request_type &packet) {
84                                start_timer();
85                                data_result_.reset();
86                                protocol_.prepare_request(packet);
87                                do_process();
88                                if (!wait()) {
89                                        stop_timer();
90                                        close();
91                                        return protocol_.get_timeout_response();
92                                }
93                                stop_timer();
94                                return protocol_.get_response();
95                        }
96
97                        virtual void shutdown() {
98                                trace("shutdown()");
99                                boost::system::error_code ignored_ec;
100                                if (get_socket().is_open())
101                                        get_socket().shutdown(boost::asio::ip::tcp::socket::shutdown_both, ignored_ec);
102                        };
103
104
105                        virtual void close() {
106                                trace("close()");
107                                if (!get_socket().is_open())
108                                        return;
109                                get_socket().shutdown(boost::asio::ip::tcp::socket::shutdown_both);
110                                get_socket().close();
111                        }
112
113                        //////////////////////////////////////////////////////////////////////////
114                        // Internal socket functions
115                        //
116                        virtual void do_process() {
117                                trace("do_process()");
118                                if (protocol_.wants_data()) {
119                                        start_read_request(boost::asio::buffer(protocol_.get_inbound()));
120                                } else if (protocol_.has_data()) {
121                                        start_write_request(boost::asio::buffer(protocol_.get_outbound()));
122                                } else {
123                                        trace("do_process(done)");
124                                        data_result_.reset(true);
125                                }
126                        }
127
128                        virtual void start_read_request(boost::asio::mutable_buffers_1 &buffer) {
129                                trace("start_read_request()");
130                                async_read(socket_, buffer,
131                                        boost::bind(&connection::handle_read_request, shared_from_this(), boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)
132                                );
133                        }
134
135                        virtual void handle_read_request(const boost::system::error_code& e, std::size_t bytes_transferred) {
136                                trace("handle_read_request(" + strEx::s::itos((int)bytes_transferred) + ")");
137                                if (!e) {
138                                        protocol_.on_read(bytes_transferred);
139                                        do_process();
140                                } else {
141                                        handler_->log_error(__FILE__, __LINE__, "Failed to read data: " + e.message());
142                                }
143                        }
144
145                        virtual void start_write_request(boost::asio::mutable_buffers_1 &buffer) {
146                                trace("start_write_request()");
147                                async_write(socket_, buffer,
148                                        boost::bind(&connection::handle_write_request, shared_from_this(), boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)
149                                        );
150                        }
151
152                        virtual void handle_write_request(const boost::system::error_code& e, std::size_t bytes_transferred) {
153                                trace("handle_write_request(" + strEx::s::itos((int)bytes_transferred) + ")");
154                                if (!e) {
155                                        protocol_.on_write(bytes_transferred);
156                                        do_process();
157                                } else {
158                                        handler_->log_error(__FILE__, __LINE__, "Failed to send data: " + e.message());
159                                }
160                        }
161
162                        virtual bool wait() {
163                                trace("wait()");
164                                io_service_.reset();
165                                while (io_service_.run_one()) {
166                                        if (data_result_) {
167                                                trace("data_result()");
168                                                return true;
169                                        }
170                                        else if (timer_result_) {
171                                                trace("timer_result()");
172                                                return false;
173                                        }
174                                }
175                                return false;
176                        }
177                        virtual basic_socket_type& get_socket() {
178                                return socket_;
179                        }
180
181                        //////////////////////////////////////////////////////////////////////////
182                        // Internal helper functions
183                        //
184                        inline void trace(std::string msg) const {
185                                if (debug_trace)
186                                        handler_->log_debug(__FILE__, __LINE__, msg);
187                        }
188                };
189
190#ifdef USE_SSL
191                template<class protocol_type>
192                class ssl_connection : public connection<protocol_type> {
193                private:
194                        typedef connection<protocol_type> connection_type;
195                        boost::asio::ssl::stream<tcp::socket> ssl_socket_;
196
197                public:
198                        ssl_connection(boost::asio::io_service &io_service, boost::asio::ssl::context &context, boost::posix_time::time_duration timeout, boost::shared_ptr<typename protocol_type::client_handler> handler)
199                                : connection_type(io_service, timeout, handler)
200                                , ssl_socket_(io_service, context)
201                        {}
202                        virtual ~ssl_connection() {
203                        }
204
205
206                        virtual void connect(std::string host, std::string port) {
207                                connection_type::connect(host, port);
208                                ssl_socket_.handshake(boost::asio::ssl::stream_base::client);
209                        }
210
211                        virtual void start_read_request(boost::asio::mutable_buffers_1 &buffer) {
212                                trace("ssl::start_read_request()");
213                                async_read(ssl_socket_, buffer,
214                                        boost::bind(&connection::handle_read_request, shared_from_this(), boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)
215                                        );
216                        }
217
218                        virtual void start_write_request(boost::asio::mutable_buffers_1 &buffer) {
219                                trace("ssl::start_write_request()");
220                                async_write(ssl_socket_, buffer,
221                                        boost::bind(&connection::handle_write_request, shared_from_this(), boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)
222                                        );
223                        }
224                        virtual basic_socket_type& get_socket() {
225                                return ssl_socket_.lowest_layer();
226                        }
227                };
228#endif
229
230                template<class protocol_type>
231                class client {
232                        boost::shared_ptr<connection<protocol_type> > connection_;
233                        boost::asio::io_service io_service_;
234                        boost::shared_ptr<typename protocol_type::client_handler> handler_;
235
236                        typedef connection<protocol_type> connection_type;
237#ifdef USE_SSL
238                        boost::asio::ssl::context context_;
239                        typedef ssl_connection<protocol_type> ssl_connection_type;
240#endif
241
242                public:
243                        client(typename boost::shared_ptr<typename protocol_type::client_handler> handler)
244                                : handler_(handler)
245#ifdef USE_SSL
246                                , context_(io_service_, boost::asio::ssl::context::sslv23)
247#endif
248                        {
249                        }
250
251                        void connect() {
252                                connection_.reset(create_connection());
253                                connection_->connect(handler_->get_host(), handler_->get_port());
254                        }
255
256                        typename connection_type* create_connection() {
257#ifdef USE_SSL
258                                if (handler_->use_ssl()) {
259                                        connection_type* ptr = new ssl_connection_type(io_service_, context_, handler_->get_timeout(), handler_);
260                                        handler_->setup_ssl(context_);
261                                        return ptr;
262                                }
263#endif
264                                return new connection_type(io_service_, handler_->get_timeout(), handler_);
265                        }
266
267                        typename protocol_type::response_type process_request(typename protocol_type::request_type &packet) {
268                                return connection_->process_request(packet);
269                        }
270                        void shutdown() {
271                                connection_->shutdown();
272                        };
273
274                };
275
276                struct client_handler : private boost::noncopyable {
277
278                        std::string dh_key_;
279                        std::string host_;
280                        std::string port_;
281                        long timeout_;
282                        bool ssl_;
283
284                        client_handler(std::string host, std::string port, long timeout, bool ssl, std::string dh_key)
285                                : host_(host)
286                                , port_(port)
287                                , timeout_(timeout)
288                                , ssl_(ssl)
289                                , dh_key_(dh_key)
290                        {}
291
292                        bool use_ssl() { return ssl_; }
293                        std::string get_host() { return host_; }
294                        std::string get_port() { return port_; }
295                        boost::posix_time::time_duration get_timeout() { return boost::posix_time::seconds(timeout_); }
296#ifdef USE_SSL
297                        void setup_ssl(boost::asio::ssl::context &context) {
298                                SSL_CTX_set_cipher_list(context.impl(), "ADH");
299                                context.use_tmp_dh_file(dh_key_);
300                                context.set_verify_mode(boost::asio::ssl::context::verify_none);
301                        }
302#endif
303
304                        virtual void log_debug(std::string file, int line, std::string msg) const = 0;
305                        virtual void log_error(std::string file, int line, std::string msg) const = 0;
306
307                };
308       
309        }
310}
Note: See TracBrowser for help on using the repository browser.