source: nscp/include/Socket.h @ 0687108

0.4.00.4.10.4.2stable
Last change on this file since 0687108 was 0687108, checked in by Michael Medin <michael@…>, 7 years ago

+ Added support for empty NRPE checking (i.e.. chec_nrpe without a -c argument)

  • Added error message when detected language is missing from counters.defs + Added Swedish locale to counters.defs (yes, I switched to Swedish XP on my computer :)
  • Fixed : (and possibly other problems) in counters when checking from check_nt (via NSCLient protocol) + Added CheckAllExcept? to CheckDrive? to check all except the specified drives.
  • Property mode set to 100644
File size: 13.9 KB
Line 
1#pragma once
2#include <Thread.h>
3#include <Mutex.h>
4#include <WinSock2.h>
5
6
7namespace simpleSocket {
8        class SocketException {
9        private:
10                std::string error_;
11        public:
12                SocketException(std::string error) : error_(error) {}
13                SocketException(std::string error, int errorCode) : error_(error) {
14                        error_ += strEx::itos(errorCode);
15                }
16                std::string getMessage() const {
17                        return error_;
18                }
19
20        };
21        class DataBuffer {
22        private:
23                char *buffer_;
24                unsigned int length_;
25        public:
26                DataBuffer() : buffer_(NULL), length_(0){
27                }
28                DataBuffer(const DataBuffer &other) {
29                        buffer_ = new char[other.getLength()];
30                        memcpy(buffer_, other.getBuffer(), other.getLength());
31                        length_ = other.getLength();
32                }
33                virtual ~DataBuffer() {
34                        delete [] buffer_;
35                        length_ = 0;
36                }
37                void append(const char* buffer, const unsigned int length) {
38                        char *tBuf = new char[length_+length+1];
39                        memcpy(tBuf, buffer_, length_);
40                        memcpy(&tBuf[length_], buffer, length);
41                        delete [] buffer_;
42                        buffer_ = tBuf;
43                        length_ += length;
44                }
45                const char * getBuffer() const {
46                        return buffer_;
47                }
48                unsigned int getLength() const {
49                        return length_;
50                }
51                void copyFrom(const char* buffer, const unsigned int length) {
52                        delete [] buffer_;
53                        buffer_ = new char[length+1];
54                        memcpy(buffer_, buffer, length);
55                        length_ = length;
56                }
57        };
58
59        class Socket {
60        protected:
61                SOCKET socket_;
62                sockaddr_in from_;
63
64        public:
65                Socket() : socket_(NULL) {
66                }
67                Socket(SOCKET socket) : socket_(socket) {
68                }
69                Socket(Socket &other) {
70                        socket_ = other.socket_;
71                        from_ = other.from_;
72                        other.socket_ = NULL;
73                }
74                virtual ~Socket() {
75                        if (socket_)
76                                closesocket(socket_);
77                        socket_ = NULL;
78                }
79                virtual SOCKET detach() {
80                        SOCKET s = socket_;
81                        socket_ = NULL;
82                        return s;
83                }
84                virtual void attach(SOCKET s) {
85                        assert(socket_ == NULL);
86                        socket_ = s;
87                }
88                virtual void shutdown(int how = SD_BOTH) {
89                        if (socket_)
90                                ::shutdown(socket_, how);
91                }
92
93                virtual void close() {
94                        if (socket_)
95                                closesocket(socket_);
96                        socket_ = NULL;
97                }
98                virtual void setNonBlock() {
99                        unsigned long NoBlock = 1;
100                        this->ioctlsocket(FIONBIO, &NoBlock);
101                }
102                static unsigned long inet_addr(std::string addr) {
103                        return ::inet_addr(addr.c_str());
104                }
105                static std::string getHostByName(std::string ip) {
106                        hostent* remoteHost;
107                        remoteHost = gethostbyname(ip.c_str());
108                        if (remoteHost == NULL)
109                                throw SocketException("gethostbyname failed for " + ip + ": ", ::WSAGetLastError());
110                        // @todo investigate it this is "correct" and dont use before!
111                        return inet_ntoa(*reinterpret_cast<in_addr*>(remoteHost->h_addr));
112                }
113                static std::string getHostByAddr(std::string ip) {
114                        hostent* remoteHost;
115                        remoteHost = gethostbyaddr(ip.c_str(), static_cast<int>(ip.length()), AF_INET);
116                        if (remoteHost == NULL)
117                                throw SocketException("gethostbyaddr failed for " + ip + ": ", ::WSAGetLastError());
118                        return remoteHost->h_name;
119                }
120                virtual void readAll(DataBuffer &buffer, unsigned int tmpBufferLength = 1024);
121
122                virtual void socket(int af, int type, int protocol ) {
123                        socket_ = ::socket(af, type, protocol);
124                        assert(socket_ != INVALID_SOCKET);
125                }
126                virtual void bind() {
127                        assert(socket_);
128                        int fromlen=sizeof(from_);
129                        if (::bind(socket_, (sockaddr*)&from_, fromlen) == SOCKET_ERROR)
130                                throw SocketException("bind failed: ", ::WSAGetLastError());
131                }
132                virtual void listen(int backlog = SOMAXCONN) {
133                        assert(socket_);
134                        if (::listen(socket_, backlog) == SOCKET_ERROR)
135                                throw SocketException("listen failed: ", ::WSAGetLastError());
136                }
137                virtual bool accept(Socket &client) {
138                        int fromlen=sizeof(client.from_);
139                        SOCKET s = ::accept(socket_, (sockaddr*)&client.from_, &fromlen);
140                        if(s == INVALID_SOCKET) {
141                                int err = ::WSAGetLastError();
142                                if (err == WSAEWOULDBLOCK)
143                                        return false;
144                                throw SocketException("accept failed: ", ::WSAGetLastError());
145                        }
146                        client.attach(s);
147                        return true;
148                }
149                virtual void setAddr(short family, u_long addr, u_short port) {
150                        from_.sin_family=family;
151                        from_.sin_addr.s_addr=addr;
152                        from_.sin_port=port;
153                }
154                virtual int send(const char * buf, unsigned int len, int flags = 0) {
155                        assert(socket_);
156                        return ::send(socket_, buf, len, flags);
157                }
158                int inline send(DataBuffer &buffer, int flags = 0) {
159                        return send(buffer.getBuffer(), buffer.getLength(), flags);
160                }
161                virtual void ioctlsocket(long cmd, u_long *argp) {
162                        assert(socket_);
163                        if (::ioctlsocket(socket_, cmd, argp) == SOCKET_ERROR)
164                                throw SocketException("ioctlsocket failed: ", ::WSAGetLastError());
165                }
166                virtual std::string getAddrString() {
167                        return inet_ntoa(from_.sin_addr);
168                }
169                virtual void printError(std::string file, int line, std::string error);
170        };
171
172        class ListenerHandler {
173        public:
174                virtual void onAccept(Socket *client) = 0;
175                virtual void onClose() = 0;
176        };
177
178
179        /**
180        * @ingroup NSClient++
181        * Socket responder class.
182        * This is a background thread that listens to the socket and executes incoming commands.
183        *
184        * @version 1.0
185        * first version
186        *
187        * @date 02-12-2005
188        *
189        * @author mickem
190        *
191        * @par license
192        * This code is absolutely free to use and modify. The code is provided "as is" with
193        * no expressed or implied warranty. The author accepts no liability if it causes
194        * any damage to your computer, causes your pet to fall ill, increases baldness
195        * or makes your car start emitting strange noises when you start it up.
196        * This code has no bugs, just undocumented features!
197        *
198        * @todo This is not very well written and should probably be reworked.
199        *
200        * @bug
201        *
202        */
203        template <class TListenerType = simpleSocket::Socket, class TSocketType = TListenerType>
204        class Listener : public TListenerType {
205        public:
206                typedef TListenerType tListener;
207                typedef TSocketType tSocket;
208        private:
209                struct simpleResponderBundle {
210                        bool terminated;
211                        HANDLE hThread;
212                        unsigned dwThreadID;
213                };
214                typedef std::list<simpleResponderBundle> socketResponses;
215                typedef TListenerType tBase;
216                class ListenerThread;
217                typedef Thread<ListenerThread> listenThreadManager;
218
219                u_short bindPort_;
220                u_long bindAddres_;
221                unsigned int listenQue_;
222                listenThreadManager threadManager_;
223                socketResponses responderList_;
224                MutexHandler responderMutex_;
225
226        public:
227                class ListenerThread {
228                private:
229                        typedef TListenerType tParentBase;
230                        typedef TSocketType tSocket;
231
232                        HANDLE hStopEvent_;
233                public:
234                        ListenerThread() : hStopEvent_(NULL) {}
235                        DWORD threadProc(LPVOID lpParameter);
236                        bool hasThread() const {
237                                return hStopEvent_ != NULL;
238                        }
239                        void exitThread(void) {
240                                assert(hStopEvent_ != NULL);
241                                if (!SetEvent(hStopEvent_))
242                                        throw new SocketException("SetEvent failed.");
243                        }
244                };
245        private:
246                ListenerHandler *pHandler_;
247
248        public:
249                Listener() : pHandler_(NULL), bindPort_(0), bindAddres_(INADDR_ANY), listenQue_(0), threadManager_("listenThreadManager") {};
250                virtual ~Listener() {
251                        if (responderList_.size() > 0) {
252                                MutexLock lock(responderMutex_);
253                                if (!lock.hasMutex()) {
254                                        printError(__FILE__, __LINE__, "Failed to get responder mutex (cannot terminate socket threads).");
255                                } else {
256                                        for (socketResponses::iterator it = responderList_.begin(); it != responderList_.end(); ++it) {
257                                                if (WaitForSingleObject( (*it).hThread, 1000) == WAIT_OBJECT_0) {
258                                                } else {
259                                                        if (!TerminateThread((*it).hThread, -1)) {
260                                                                printError(__FILE__, __LINE__, "We failed to terminate check thread.");
261                                                        } else {
262                                                                if (WaitForSingleObject( (*it).hThread, 5000) == WAIT_OBJECT_0) {
263                                                                        CloseHandle((*it).hThread);
264                                                                } else {
265                                                                        printError(__FILE__, __LINE__, "We failed to terminate check thread (wait timed out).");
266                                                                }
267                                                        }
268                                                }
269                                        }
270                                        responderList_.clear();
271                                }
272                        }
273                };
274/*
275                virtual void StartListener(int port) {
276                        bindPort_ = port;
277                        threadManager_.createThread(this);
278                }
279                */
280                bool hasListener() {
281                        try {
282                                if (threadManager_.hasActiveThread()) {
283                                        const ListenerThread *t = threadManager_.getThreadConst();
284                                        if (t!=NULL)
285                                                return t->hasThread();
286                                }
287                        } catch (ThreadException e) {
288                                printError(__FILE__, __LINE__, "Could not access listener thread!");
289                                return false;
290                        }
291                        return false;
292                }
293                virtual void StartListener(std::string host, int port, int queLength) {
294                        bindPort_ = port;
295                        if (!host.empty())
296                                bindAddres_ = TListenerType::inet_addr(host);
297                        if (bindAddres_ == INADDR_NONE)
298                                bindAddres_ = INADDR_ANY;
299                        listenQue_ = queLength;
300                        threadManager_.createThread(this);
301                }
302                virtual void StopListener() {
303                        try {
304                                if (threadManager_.hasActiveThread())
305                                        if (!threadManager_.exitThread()) {
306                                                tBase::close();
307                                                throw new SocketException("Could not terminate thread.");
308                                        }
309                        } catch (ThreadException e) {
310                                tBase::close();
311                                throw new SocketException("Could not terminate thread (got exception in thread).");
312                        }
313                        tBase::close();
314                }
315                void setHandler(ListenerHandler* pHandler) {
316                        pHandler_ = pHandler;
317                }
318                void removeHandler(ListenerHandler* pHandler) {
319                        if (pHandler != pHandler_)
320                                throw SocketException("Not a registered handler!");
321                        pHandler_ = NULL;
322                }
323                static unsigned __stdcall socketResponceProc(void* lpParameter);
324                struct srp_data {
325                        Listener *pCore;
326                        tSocket *client;
327                };
328                void addResponder(tSocket *client) {
329                        MutexLock lock(responderMutex_);
330                        if (!lock.hasMutex()) {
331                                printError(__FILE__, __LINE__, "Failed to get responder mutex.");
332                                return;
333                        }
334                        for (socketResponses::iterator it = responderList_.begin(); it != responderList_.end();) {
335                                if ( (*it).terminated) {
336                                        if (WaitForSingleObject( (*it).hThread, 500) == WAIT_OBJECT_0) {
337                                                CloseHandle((*it).hThread);
338                                                responderList_.erase(it++);
339                                        }
340                                } else
341                                        ++it;
342                        }
343                        simpleResponderBundle data;
344                        srp_data *lpData = new srp_data;
345                        lpData->pCore = this;
346                        lpData->client = client;
347
348                        data.hThread = reinterpret_cast<HANDLE>(::_beginthreadex( NULL, 0, &socketResponceProc, lpData, 0, &data.dwThreadID));
349                        data.terminated = false;
350                        responderList_.push_back(data);
351                }
352                bool removeResponder(DWORD dwThreadID) {
353                        MutexLock lock(responderMutex_);
354                        if (!lock.hasMutex()) {
355                                printError(__FILE__, __LINE__, "Failed to get responder mutex when trying to free thread.");
356                                return false;
357                        }
358                        for (socketResponses::iterator it = responderList_.begin(); it != responderList_.end(); ++it) {
359                                if ( (*it).dwThreadID == dwThreadID) {
360                                        (*it).terminated = true;
361                                        return true;
362                                }
363                        }
364                        return false;
365                }
366
367
368        private:
369                void onAccept(tSocket *client) {
370                        if (pHandler_)
371                                pHandler_->onAccept(client);
372                }
373                void onClose() {
374                        if (pHandler_)
375                                pHandler_->onClose();
376                }
377                virtual bool accept(tSocket &client) {
378                        return tBase::accept(client);
379                }
380        };
381
382        WSADATA WSAStartup(WORD wVersionRequested = 0x202);
383        void WSACleanup();
384
385}
386
387template <class TListenerType, class TSocketType>
388unsigned simpleSocket::Listener<TListenerType, TSocketType>::socketResponceProc(void* lpParameter)
389{
390        // @todo make sure this terminates after X seconds!
391
392        srp_data *data = reinterpret_cast<srp_data*>(lpParameter);
393        Listener *pCore = data->pCore;
394        tSocket *client = data->client;
395        delete data;
396        try {
397                pCore->onAccept(client);
398        } catch (SocketException e) {
399                pCore->printError(__FILE__, __LINE__, e.getMessage() + " killing socket...");
400        }
401        client->close();
402        delete client;
403        if (!pCore->removeResponder(GetCurrentThreadId())) {
404                pCore->printError(__FILE__, __LINE__, "Could not remove thread: " + strEx::itos(GetCurrentThreadId()));
405        }
406        _endthreadex(0);
407        return 0;
408}
409
410
411template <class TListenerType, class TSocketType>
412DWORD simpleSocket::Listener<TListenerType, TSocketType>::ListenerThread::threadProc(LPVOID lpParameter)
413{
414        Listener *core = reinterpret_cast<Listener*>(lpParameter);
415
416        hStopEvent_ = CreateEvent(NULL, TRUE, FALSE, NULL);
417        if (!hStopEvent_) {
418                core->printError(__FILE__, __LINE__, "Create StopEvent failed: " + strEx::itos(GetLastError()));
419                return 0;
420        }
421
422        try {
423                core->socket(AF_INET,SOCK_STREAM,0);
424                core->setAddr(AF_INET, core->bindAddres_, htons(core->bindPort_));
425                core->bind();
426                if (core->listenQue_ != 0)
427                        core->listen(core->listenQue_);
428                else
429                        core->listen();
430                core->setNonBlock();
431                while (!(WaitForSingleObject(hStopEvent_, 100) == WAIT_OBJECT_0)) {
432                        try {
433                                tSocket client;
434                                if (core->accept(client)) {
435                                        core->addResponder(new tSocket(client));
436                                }
437                        } catch (SocketException e) {
438                                core->printError(__FILE__, __LINE__, e.getMessage() + ", attempting to resume...");
439                        }
440                }
441        } catch (SocketException e) {
442                core->printError(__FILE__, __LINE__, e.getMessage());
443        }
444        core->shutdown(SD_BOTH);
445        core->close();
446        core->onClose();
447        HANDLE hTmp = hStopEvent_;
448        hStopEvent_ = NULL;
449        if (!CloseHandle(hTmp)) {
450                core->printError(__FILE__, __LINE__, "CloseHandle StopEvent failed: " + strEx::itos(GetLastError()));
451        }
452        return 0;
453}
454
455
456
457namespace socketHelpers {
458        class allowedHosts {
459        public:
460                typedef std::list<std::string> host_list;
461        private:
462                host_list allowedHosts_;
463                bool cachedAddresses_;
464        public:
465                allowedHosts() : cachedAddresses_(true) {}
466                void setAllowedHosts(host_list allowedHosts, bool cachedAddresses) {
467                        cachedAddresses_ = cachedAddresses;
468                        if ((!allowedHosts.empty()) && (allowedHosts.front() == "") )
469                                allowedHosts.pop_front();
470                        allowedHosts_ = allowedHosts;
471                        if (cachedAddresses_) {
472                                for (host_list::iterator it = allowedHosts_.begin();it!=allowedHosts_.end();++it) {
473                                        if (((*it).length() > 0) && (isalpha((*it)[0]))) {
474                                                std::string s = (*it);
475                                                try {
476                                                        *it = simpleSocket::Socket::getHostByName(s);
477                                                } catch (simpleSocket::SocketException e) {
478                                                        e;
479                                                }
480                                        }
481                                }
482                        }
483                }
484                bool inAllowedHosts(std::string s) {
485                        if (allowedHosts_.empty())
486                                return true;
487                        host_list::const_iterator cit;
488                        if (!cachedAddresses_) {
489                                for (host_list::iterator it = allowedHosts_.begin();it!=allowedHosts_.end();++it) {
490                                        if (((*it).length() > 0) && (isalpha((*it)[0]))) {
491                                                std::string s = (*it);
492                                                try {
493                                                        *it = simpleSocket::Socket::getHostByName(s);
494                                                } catch (simpleSocket::SocketException e) {
495                                                        e;
496                                                }
497                                        }
498                                }
499                        }
500                        for (cit = allowedHosts_.begin();cit!=allowedHosts_.end();++cit) {
501                                if ( (*cit) == s)
502                                        return true;
503                        }
504                        return false;
505                }
506        };
507}
Note: See TracBrowser for help on using the repository browser.