Bond's TCP Library  1.0
Bond's TCP Client/Server Library
tcpsocket.cpp
1 #include "tcpsocket.h"
2 #include <algorithm>
3 #include <string.h>
4 #include <fcntl.h>
5 #include <unistd.h>
6 #include <netdb.h>
7 #include <sys/socket.h>
8 #include <sys/ioctl.h>
9 #include <arpa/inet.h>
10 #include <netinet/ip.h>
11 #include <netinet/tcp.h>
12 
13 namespace tcp {
14 
/// Stream that receives every log, warning and error line produced by the
/// library. It starts out sharing the buffer of clog.
ostream logstream(clog.rdbuf());

/// Redirects library output to the buffer used by *os.
/// A null pointer leaves the current destination untouched.
void setLogStream(ostream *os)
{
    if (os == nullptr) {
        return;
    }
    logstream.rdbuf(os->rdbuf());
}

/// Writes "Error: <msg>", then a newline and a flush.
void error(string msg)
{
    logstream << "Error: " << msg << endl;
}

/// Writes "Error: <label>: <msg>", then a newline and a flush.
void error(string label, string msg)
{
    logstream << "Error: " << label << ": " << msg << endl;
}

/// Writes "Warning: <msg>", then a newline and a flush.
void warning(string msg)
{
    logstream << "Warning: " << msg << endl;
}

/// Writes "Warning: <label>: <msg>", then a newline and a flush.
void warning(string label, string msg)
{
    logstream << "Warning: " << label << ": " << msg << endl;
}

/// Writes <msg> unprefixed, then a newline and a flush.
void log(string msg)
{
    logstream << msg << endl;
}

/// Writes "<label>: <msg>", then a newline and a flush.
void log(string label, string msg)
{
    logstream << label << ": " << msg << endl;
}
24 
// Body of the EPoll constructor. The index at the end of this listing places
// EPoll::EPoll() at tcpsocket.cpp:25; that signature line is absent from this
// excerpt.
// Creates the epoll instance and stores its descriptor in handle_.
26 {
27  handle_ = epoll_create1(0);
28  if (handle_ == -1) {
// Creation failure is only logged; handle_ is left at -1.
29  error("epoll_create1",strerror(errno));
30  }
31 }
32 
// Body of the EPoll destructor. The index at the end of this listing places
// EPoll::~EPoll() at tcpsocket.cpp:33; that signature line is absent from this
// excerpt.
// Forgets every registered Socket pointer, then closes the epoll descriptor.
// NOTE(review): the guard is `handle_ > 0`, whereas the constructor treats -1
// as the failure value, so a descriptor equal to 0 would never be closed.
// Confirm whether `>= 0` was intended.
34 {
35  sockets.clear();
36  if (handle_ > 0) {
37  ::close(handle_);
38  }
39 }
40 
41 bool EPoll::add(Socket& socket, int events)
42 {
43  struct epoll_event ev;
44  bool result = false;
45  ev.events = events;
46  ev.data.fd = socket.socket_;
47  if (epoll_ctl(handle_,EPOLL_CTL_ADD,socket.socket_,&ev) != -1) {
48  sockets[socket.socket_] = &socket;
49  result = true;
50  }
51  return result;
52 }
53 
54 bool EPoll::update(Socket& socket, int events)
55 {
56  bool result;
57  struct epoll_event ev;
58  ev.events = events;
59  ev.data.fd = socket.socket_;
60  result = (epoll_ctl(handle_,EPOLL_CTL_MOD,socket.socket_,&ev) != -1);
61  return result;
62 }
63 
64 bool EPoll::remove(Socket& socket)
65 {
66  bool result = false;
67  if (epoll_ctl(handle_,EPOLL_CTL_DEL,socket.socket_,NULL) != -1) {
68  sockets.erase(socket.socket_);
69  result = true;
70  }
71  return result;
72 }
73 
74 void EPoll::poll(int timeout)
75 {
76  int nfds = epoll_wait(handle_,events,MAX_EVENTS,timeout);
77  if (nfds == -1) {
78  if (errno != EINTR)
79  error("epoll_wait",strerror(errno));
80  } else {
81  for (int n = 0; n < nfds; ++n) {
82  handleEvents(events[n].events,events[n].data.fd);
83  }
84  }
85 }
86 
87 void EPoll::handleEvents(uint32_t events, int fd)
88 {
89  Socket* socket = sockets[fd];
90  if (socket != nullptr) {
91  socket->handleEvents(events);
92  }
93 }
94 
95 /* Socket */
96 
97 Socket::Socket(EPoll &epoll, const int domain, const int socket, const bool blocking, const int events) : epoll_(epoll), events_(events), domain_(domain), socket_(socket)
98 {
99  if ((domain != AF_INET) && (domain != AF_INET6)) {
100  error("Socket","Only IPv4 and IPv6 are supported.");
101  return;
102  }
103  if (socket < 0) {
104  error("Socket","Socket parameter is < 0");
105  return;
106  }
107  mtx.lock();
108  if (socket == 0) {
109  socket_ = ::socket(domain,SOCK_STREAM,0);
110  if (socket_ == -1) {
111  error("socket", strerror(errno));
112  }
113  }
114  int flags = fcntl(socket_,F_GETFL,0);
115  if (flags == -1) {
116  error("fcntl (get)",strerror(errno));
117  } else {
118  if (!blocking) {
119  flags |= O_NONBLOCK;
120  } else {
121  flags = flags & ~O_NONBLOCK;
122  }
123  if (fcntl(socket_,F_SETFL,flags) == -1) {
124  error("fcntl (set)",strerror(errno));
125  }
126  }
127 
128  if (!epoll_.add(*this,events)) {
129  error("Unable to add socket to epoll");
130  }
131 
132  mtx.unlock();
133 }
134 
// Body of the Socket destructor. The index at the end of this listing places
// Socket::~Socket() at tcpsocket.cpp:135; that signature line is absent from
// this excerpt.
// When a descriptor is still held (socket_ > 0) it is, under mtx, removed from
// epoll and closed; a close() failure is logged. socket_ is then reset to 0.
136 {
137  if (socket_ > 0) {
138  mtx.lock();
139  epoll_.remove(*this);
140  if (::close(socket_) == -1) {
141  error("close",strerror(errno));
142  }
143  socket_ = 0;
144  mtx.unlock();
145  }
146 }
147 
148 bool Socket::setEvents(int events)
149 {
150  mtx.lock();
151  bool result = false;
152  if (events != events_) {
153  if (epoll_.update(*this,events)) {
154  events_ = events;
155  result = true;
156  }
157  } else {
158  result = true;
159  }
160  mtx.unlock();
161  return result;
162 }
163 
// Body of Socket::disconnect(). The index at the end of this listing places it
// at tcpsocket.cpp:164; the line carrying the signature and opening brace is
// absent from this excerpt.
// A CONNECTED socket is shut down in both directions while mtx is held;
// disconnected() is then called after the lock has been released.
165  mtx.lock();
166  if (state_ == SocketState::CONNECTED) {
167  ::shutdown(socket_,SHUT_RDWR);
168  }
169  mtx.unlock();
170  disconnected();
171 }
172 
// Body of Socket::disconnected(). The index at the end of this listing places
// it at tcpsocket.cpp:173; that signature line is absent from this excerpt.
// Under mtx: a socket that is not yet DISCONNECTED has its descriptor closed,
// socket_ reset to 0 and state_ set to DISCONNECTED, and "Disconnected" is
// logged. A socket that is already DISCONNECTED only produces a warning.
// NOTE(review): the descriptor is closed here without a visible
// epoll_.remove(*this), so the EPoll::sockets entry for it appears to stay
// behind. Confirm whether that is intended.
174 {
175  mtx.lock();
176  if (state_ != SocketState::DISCONNECTED) {
177  ::close(socket_);
178  socket_ = 0;
179  state_ = SocketState::DISCONNECTED;
180  log("Disconnected");
181  } else {
182  warning("Already disconnected");
183  }
184  mtx.unlock();
185 }
186 
187 /* DataSocket */
188 
// Body of DataSocket::disconnect(). The index at the end of this listing
// places it at tcpsocket.cpp:189; that signature line is absent from this
// excerpt.
// Under mtx: when an SSL session exists on a CONNECTED socket it is shut down
// and deleted, ssl_ is nulled, and any pending SSL errors are printed.
// NOTE(review): listing line 199 is also absent, so a statement between the
// unlock and the closing brace may be missing here (presumably the call to
// the base-class disconnect). Verify against the real file.
190 {
191  mtx.lock();
192  if (ssl_ && (state_ == SocketState::CONNECTED)) {
193  ssl_->shutdown();
194  delete ssl_;
195  ssl_ = nullptr;
196  printSSLErrors();
197  }
198  mtx.unlock();
200 }
201 
// Body of DataSocket::disconnected(). The index at the end of this listing
// places it at tcpsocket.cpp:202; that signature line is absent from this
// excerpt.
// Under mtx: any SSL session is deleted without a shutdown handshake, ssl_ is
// nulled, and pending SSL errors are printed.
// NOTE(review): listing line 211 is also absent, so a statement between the
// unlock and the closing brace may be missing here (presumably the call to
// the base-class disconnected). Verify against the real file.
203 {
204  mtx.lock(); // Do I need to use lock here?
205  if (ssl_) {
206  delete ssl_;
207  ssl_ = nullptr;
208  printSSLErrors();
209  }
210  mtx.unlock();
212 }
213 
// Body of DataSocket::readToInputBuffer(). The index at the end of this
// listing places it at tcpsocket.cpp:214; that signature line is absent from
// this excerpt.
// Reads from the socket through read_() in chunks of up to 256 bytes and
// appends every byte to inputBuffer, until a read yields no bytes.
// NOTE(review): read_() returns size_t but the result is stored in an int.
// A failed recv() (-1) reaches this code as a wrapped size_t value, and the
// loop appears to rely on the narrowing back to a non-positive int to stop.
// Confirm.
215 {
216  uint8_t buffer[256];
217  int size;
218  do {
219  size = read_(&buffer[0],256);
220  for (int i=0;i<size;i++) {
221  inputBuffer.push_back(buffer[i]);
222  }
223  } while (size > 0);
224 }
225 
// Body of DataSocket::sendOutputBuffer(). The index at the end of this listing
// places it at tcpsocket.cpp:226; that signature line is absent from this
// excerpt.
// Drains outputBuffer into a temporary heap buffer, hands it to write_(), and
// pushes whatever was not written back onto the front of outputBuffer.
// EPOLLOUT is requested through canSend() only when the write was partial.
227 {
228  mtx.lock();
229  size_t size = outputBuffer.size();
// NOTE(review): this early return leaves mtx locked (no matching unlock).
230  if (size == 0) return;
// NOTE(review): the malloc() result is not checked before use.
231  uint8_t *buffer = (uint8_t*)malloc(size);
232  for (size_t i=0;i<size;i++) {
233  buffer[i] = outputBuffer.at(0);
234  outputBuffer.pop_front();
235  }
236  size_t res = write_(buffer,size);
237  canSend(res!=size);
238  if (res != size) {
// NOTE(review): push_front in ascending index order re-queues the unsent
// bytes in reverse order. Also, if write_() reports a send() failure as a
// wrapped size_t, this loop does not run and the drained bytes are dropped.
// Confirm both.
239  for (size_t i=res;i<size;i++) {
240  outputBuffer.push_front(buffer[i]);
241  }
242  }
243  free(buffer);
244  mtx.unlock();
245 }
246 
247 void DataSocket::canSend(bool value)
248 {
249  int events = EPOLLIN | EPOLLRDHUP;
250  if (value)
251  events |= EPOLLOUT;
252  setEvents(events);
253 }
254 
// Dispatches one epoll event mask for this socket. Events are ignored unless
// the socket is CONNECTED.
//   EPOLLRDHUP : treated as a disconnect; no other flag is examined.
//   EPOLLIN    : under mtx, dataAvailable() is invoked and EPOLLOUT interest
//                is set according to whether outputBuffer still holds data.
//   EPOLLOUT   : under mtx, EPOLLOUT interest is re-evaluated from
//                outputBuffer.
// NOTE(review): listing lines 263, 266 and 275 are absent from this excerpt,
// so statements inside the EPOLLIN and EPOLLOUT branches may be missing
// (presumably the calls that fill inputBuffer and flush outputBuffer).
// Verify against the real file before changing this function.
255 void DataSocket::handleEvents(uint32_t events)
256 {
257  if (state_ == SocketState::CONNECTED) {
258  if (events & EPOLLRDHUP) {
259  disconnected();
260  } else {
261  if (events & EPOLLIN) {
262  mtx.lock();
264  dataAvailable();
265  if (outputBuffer.size() > 0U) {
267  canSend(outputBuffer.size() > 0U);
268  } else {
269  canSend(false);
270  }
271  mtx.unlock();
272  }
273  if (events & EPOLLOUT) {
274  mtx.lock();
276  canSend(outputBuffer.size() > 0U);
277  mtx.unlock();
278  }
279  }
280  }
281 }
282 
283 size_t DataSocket::read_(void *buffer, size_t size)
284 {
285  if (state_ == SocketState::CONNECTED) {
286  size_t result;
287  if (ssl_) {
288  result = ssl_->read(buffer,size);
289  } else {
290  result = ::recv(socket(),buffer,size,0);
291  }
292  return result;
293  } else {
294  return 0;
295  }
296 }
297 
298 size_t DataSocket::write_(const void *buffer, size_t size)
299 {
300  if (state_ == SocketState::CONNECTED) {
301  size_t result;
302  if (ssl_) {
303  result = ssl_->write(buffer,size);
304  } else {
305  result = ::send(socket(),buffer,size,MSG_NOSIGNAL);
306  }
307  return result;
308  } else {
309  return 0;
310  }
311 }
312 
313 size_t DataSocket::read(void *buffer, size_t size)
314 {
315  size_t result = 0;
316  if (size) {
317  mtx.lock();
318  result = max<size_t>(size,inputBuffer.size());
319  if (result > 0) {
320  for (size_t i=0;i<result;++i) {
321  ((uint8_t*)buffer)[i] = inputBuffer.at(0);
322  inputBuffer.pop_front();
323  }
324  }
325  mtx.unlock();
326  }
327  return result;
328 }
329 
330 size_t DataSocket::write(const void *buffer, size_t size)
331 {
332  size_t result = 0U;
333  if (size) {
334  mtx.lock();
335  try {
336  for (size_t i=0;i<size;++i) {
337  outputBuffer.push_back(((uint8_t*)buffer)[i]);
338  ++result;
339  }
340  canSend(true);
341  } catch (const std::bad_alloc&) {
342  mtx.unlock();
343  }
344  }
345  return result;
346 }
347 
// Body of DataSocket::createSSL(SSLContext *context). The index at the end of
// this listing places it at tcpsocket.cpp:348; that signature line is absent
// from this excerpt.
// Returns a new SSL object built from this socket and *context, or nullptr
// when no context is supplied. The object is allocated with `new`; nothing in
// this body releases it.
349 {
350  if (context) {
351  return new SSL(*this,*context);
352  } else {
353  return nullptr;
354  }
355 }
356 
/// Tries to determine which address family to use for a host/port pair.
///
/// The host is first parsed as a numeric address (AI_NUMERICHOST, no service);
/// if that yields no family, a regular lookup of host and port is attempted
/// with AI_CANONNAME.
///
/// @param host        Host name or numeric address.
/// @param port        Service name or port number; used only by the second
///                    lookup.
/// @param def_domain  Value returned when neither lookup yields a family.
/// @return The ai_family of the first result of whichever lookup succeeded,
///         or def_domain.
///
/// Each successful getaddrinfo() result list is now released with
/// freeaddrinfo(); previously both lists were leaked.
int getDomainFromHostAndPort(const char* host, const char* port, int def_domain)
{
    struct addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_UNSPEC;

    int domain = AF_UNSPEC;
    struct addrinfo* info = nullptr;

    // Pass 1: numeric address only.
    hints.ai_flags = AI_NUMERICHOST;
    if (getaddrinfo(host, nullptr, &hints, &info) == 0 && info != nullptr) {
        domain = info->ai_family;
        freeaddrinfo(info);
        info = nullptr;
    }

    // Pass 2: full lookup, only when pass 1 did not settle the family.
    if (domain == AF_UNSPEC) {
        hints.ai_flags = AI_CANONNAME;
        if (getaddrinfo(host, port, &hints, &info) == 0 && info != nullptr) {
            domain = info->ai_family;
            freeaddrinfo(info);
            info = nullptr;
        }
    }

    return (domain == AF_UNSPEC) ? def_domain : domain;
}
386 
387 } // namespace tcp
tcp::printSSLErrors
void printSSLErrors()
This method logs openSSL errors to cerr.
Definition: tcpssl.cpp:56
tcp::Socket::handleEvents
virtual void handleEvents(uint32_t events)=0
Called when the socket receives an epoll event.
tcp::setLogStream
void setLogStream(ostream *os)
Set the output stream used by the library for log, warning and error messages.
Definition: tcpsocket.cpp:17
tcp::DataSocket::createSSL
virtual SSL * createSSL(SSLContext *context)
Factory method for returning an SSL object.
Definition: tcpsocket.cpp:348
tcp::Socket::disconnected
virtual void disconnected()
Called when a connection is disconnected due to a network error.
Definition: tcpsocket.cpp:173
tcp::error
void error(string msg)
Send an error message to the log stream.
Definition: tcpsocket.cpp:18
tcp::Socket::mtx
recursive_mutex mtx
The mutex used to provide exclusive access to the socket.
Definition: tcpsocket.h:141
tcp::getDomainFromHostAndPort
int getDomainFromHostAndPort(const char *host, const char *port, int def_domain)
Tries to determine which address family to use from a host and port string.
Definition: tcpsocket.cpp:357
tcpsocket.h
Shared base classes for tcpclient.h and tcpserver.h.
tcp::SSL::write
size_t write(const void *buffer, size_t size)
Encrypts and writes SSL socket data.
Definition: tcpssl.cpp:613
tcp::warning
void warning(string msg)
Send a warning message to the log stream.
Definition: tcpsocket.cpp:20
tcp::SSL
Encapsulates an SSL connection data structure.
Definition: tcpssl.h:109
tcp::DataSocket::readToInputBuffer
void readToInputBuffer()
Reads all available data from the socket into inputBuffer.
Definition: tcpsocket.cpp:214
tcp::Socket::state_
SocketState state_
Descendant classes can manipulate the socket state directly.
Definition: tcpsocket.h:147
tcp::DataSocket::sendOutputBuffer
void sendOutputBuffer()
Writes all available data from the outputBuffer to the socket.
Definition: tcpsocket.cpp:226
tcp::DataSocket::disconnected
void disconnected() override
Called when a connection is disconnected.
Definition: tcpsocket.cpp:202
tcp::Socket::~Socket
~Socket()
Closes and destroys the socket.
Definition: tcpsocket.cpp:135
tcp::Socket::Socket
Socket(EPoll &epoll, const int domain=AF_INET, const int socket=0, const bool blocking=false, const int events=(EPOLLIN|EPOLLRDHUP))
Construct a blocking or non-blocking socket handle that responds to certain epoll events.
Definition: tcpsocket.cpp:97
tcp::DataSocket::disconnect
void disconnect() override
Shuts down any SSL connection gracefully.
Definition: tcpsocket.cpp:189
tcp::SSL::shutdown
void shutdown()
Closes the SSL connection gracefully.
Definition: tcpssl.cpp:641
tcp::SSLContext
Encapsulates an openSSL SSL_CTX record.
Definition: tcpssl.h:41
tcp::DataSocket::read
size_t read(void *buffer, size_t size)
Reads up to size bytes from inputBuffer into buffer.
Definition: tcpsocket.cpp:313
tcp::Socket::socket
int socket() const
Return the linux socket handle.
Definition: tcpsocket.h:115
tcp::EPoll::EPoll
EPoll()
Constructor.
Definition: tcpsocket.cpp:25
tcp::DataSocket::ssl_
SSL * ssl_
Exposes the underlying SSL record used for openSSL calls to descendant classes.
Definition: tcpsocket.h:213
tcp::EPoll
Encapsulates the EPoll interface.
Definition: tcpsocket.h:72
tcp::SSL::read
size_t read(void *buffer, size_t size)
Reads and decrypts SSL socket data.
Definition: tcpssl.cpp:594
tcp::Socket::disconnect
virtual void disconnect()
Shuts down the socket gracefully.
Definition: tcpsocket.cpp:164
tcp::DataSocket::canSend
void canSend(bool value)
Sets the epoll event flags.
Definition: tcpsocket.cpp:247
tcp::EPoll::poll
void poll(int timeout)
Call poll() regularly to respond to network events.
Definition: tcpsocket.cpp:74
tcp::DataSocket::handleEvents
void handleEvents(uint32_t events) override
Called by the EPoll class when the listening socket receives an epoll event.
Definition: tcpsocket.cpp:255
tcp::Socket::setEvents
bool setEvents(int events)
Changes which epoll events the socket listens for.
Definition: tcpsocket.cpp:148
tcp::Socket::domain
int domain() const
Return the socket domain (AF_INET or AF_INET6)
Definition: tcpsocket.h:118
tcp::DataSocket::dataAvailable
virtual void dataAvailable()=0
Called whenever new data is appended to the inputBuffer.
tcp::log
void log(string msg)
Send a log message to the log stream.
Definition: tcpsocket.cpp:22
tcp::EPoll::~EPoll
~EPoll()
Destructor.
Definition: tcpsocket.cpp:33
tcp
A tcp client/server library for linux that supports openSSL and EPoll.
Definition: tcpclient.cpp:5
tcp::Socket
Encapsulates a socket handle that is capable of receiving epoll events.
Definition: tcpsocket.h:96
tcp::DataSocket::write
size_t write(const void *buffer, size_t size)
Writes the contents of buffer to the outputBuffer.
Definition: tcpsocket.cpp:330