321 lines
8.4 KiB
C++
321 lines
8.4 KiB
C++
#include "./server.h"
|
|
#include "./handler.h"
|
|
#include "./net.h"
|
|
|
|
#include <event.h>
|
|
#include <functional>
|
|
#include <cstring>
|
|
#include <iostream>
|
|
#include <cassert>
|
|
#include <zconf.h>
|
|
#include <fcntl.h>
|
|
|
|
using namespace ts::dns;
|
|
|
|
/*
 * Destroying the server shuts it down first; all of the teardown work is
 * delegated to stop().
 */
DNSServer::~DNSServer() {
    stop();
}
|
|
|
|
bool DNSServer::start(const std::vector<sockaddr_storage> &bindings, std::string &error) {
|
|
size_t successful_binds{0};
|
|
|
|
std::lock_guard lock{this->bind_lock};
|
|
if(this->started) {
|
|
error = "server already bound";
|
|
return false;
|
|
}
|
|
|
|
this->event_base = event_base_new();
|
|
if(!this->event_base) {
|
|
error = "failed to spawn event base";
|
|
goto error_exit;
|
|
}
|
|
|
|
this->event_base_ticker = evtimer_new(this->event_base, [](auto, auto, void* server){
|
|
static_cast<DNSServer*>(server)->event_cb_timer();
|
|
}, this);
|
|
if(!this->event_base_ticker) {
|
|
error = "failed to spawn heartbeat event";
|
|
goto error_exit;
|
|
}
|
|
|
|
this->_bindings.resize(bindings.size());
|
|
for(size_t index = 0; index < bindings.size(); index++) {
|
|
auto binding = this->_bindings[index] = std::make_shared<DNSServerBinding>();
|
|
binding->self = binding;
|
|
binding->server = this;
|
|
memcpy(&binding->address, &bindings[index], sizeof(sockaddr_storage));
|
|
|
|
if(!this->bind(*binding, error))
|
|
binding->error = error;
|
|
else
|
|
successful_binds++;
|
|
}
|
|
|
|
if(!successful_binds) {
|
|
error = "failed to bind to any address";
|
|
goto error_exit;
|
|
}
|
|
|
|
{
|
|
timeval timeout{10, 0};
|
|
event_add(this->event_base_ticker, &timeout);
|
|
}
|
|
|
|
this->event_base_executor = std::thread(std::bind(&DNSServer::event_executor, this));
|
|
this->started = true;
|
|
return true;
|
|
|
|
error_exit:
|
|
if(this->event_base_ticker) {
|
|
event_del_block(this->event_base_ticker);
|
|
event_free(this->event_base_ticker);
|
|
this->event_base_ticker = nullptr;
|
|
}
|
|
|
|
if(this->event_base) {
|
|
event_base_free(this->event_base);
|
|
this->event_base = nullptr;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
void DNSServer::stop() {
|
|
std::unique_lock lock{this->bind_lock};
|
|
if(!this->started)
|
|
return;
|
|
|
|
this->started = false;
|
|
assert(this->event_base); //Must be set else the started flag was invalid
|
|
assert(this->event_base_ticker); //Must be set else the started flag was invalid
|
|
|
|
|
|
for(auto& binding : this->bindings())
|
|
this->unbind(*binding);
|
|
|
|
event_base_loopexit(this->event_base, nullptr);
|
|
{
|
|
timeval timeout{0, 0};
|
|
event_add(this->event_base_ticker, &timeout);
|
|
}
|
|
|
|
lock.unlock();
|
|
this->event_base_executor.join();
|
|
lock.lock();
|
|
|
|
event_free(this->event_base_ticker);
|
|
this->event_base_ticker = nullptr;
|
|
}
|
|
|
|
/*
 * Opens and configures the datagram socket of `binding` and registers its
 * read/write events at the servers event base.
 *
 * binding: the binding to set up; binding.address selects family and address.
 * error:   receives a description of the failure when false is returned.
 *
 * Returns true once the socket is bound, non blocking and its read event has been
 * added. On failure all events created here are freed, the socket is closed and
 * binding.socket is reset to 0 (the "no socket" value unbind() tests for).
 */
bool DNSServer::bind(DNSServerBinding &binding, std::string &error) {
    binding.socket = socket(binding.address.ss_family, SOCK_DGRAM, 0);
    if(binding.socket < 2) {
        /* The descriptors 0 and 1 are rejected as well; close them so they do not leak. */
        if(binding.socket >= 0)
            ::close(binding.socket);
        binding.socket = 0;

        error = "failed to create socket";
        return false;
    }

    /* Releases whatever has been set up so far. Always yields false so it can be returned directly. */
    const auto fail = [&binding] {
        if(binding.read_event) {
            event_del_noblock(binding.read_event);
            event_free(binding.read_event);
            binding.read_event = nullptr;
        }

        if(binding.write_event) {
            event_del_noblock(binding.write_event);
            event_free(binding.write_event);
            binding.write_event = nullptr;
        }

        ::close(binding.socket);
        binding.socket = 0;
        return false;
    };

    int enable{1};
    if(binding.address.ss_family == AF_INET6) {
        if(setsockopt(binding.socket, IPPROTO_IPV6, IPV6_RECVPKTINFO, &enable, sizeof(enable)) < 0) {
            error = "failed to enable packet info (v6) (" + std::to_string(errno) + "/" + strerror(errno) + ")";
            return fail();
        }
        if(setsockopt(binding.socket, IPPROTO_IPV6, IPV6_V6ONLY, &enable, sizeof(enable)) < 0) {
            error = "failed to enable ip v6 only (" + std::to_string(errno) + "/" + strerror(errno) + ")";
            return fail();
        }
    } else {
        if(setsockopt(binding.socket, IPPROTO_IP, IP_PKTINFO, &enable, sizeof(enable)) < 0) {
            error = "failed to enable packet info (" + std::to_string(errno) + "/" + strerror(errno) + ")";
            return fail();
        }
    }

    if(::bind(binding.socket, (const sockaddr*) &binding.address, sizeof(binding.address)) < 0) {
        error = "failed to bind: " + std::to_string(errno) + "/" + strerror(errno);
        return fail();
    }

    {
        /* Query the flags first; a failed F_GETFL must not be passed on to F_SETFL. */
        const int flags = fcntl(binding.socket, F_GETFL, 0);
        if(flags < 0 || fcntl(binding.socket, F_SETFL, flags | O_NONBLOCK) < 0) {
            error = "failed to enable noblock";
            return fail();
        }
    }

#ifdef WIN32
    {
        u_long enabled = 1; /* a non zero FIONBIO argument enables the non blocking mode */
        auto non_block_rs = ioctlsocket(binding.socket, FIONBIO, &enabled);
        if (non_block_rs != NO_ERROR) {
            error = "failed to enable nonblock";
            return fail();
        }
    }
#endif

    binding.read_event = event_new(this->event_base, binding.socket, EV_READ | EV_PERSIST, &DNSServer::event_cb_read, &binding);
    if(!binding.read_event) {
        error = "failed to create read event";
        return fail();
    }

    /* Not persistent: DNSServerBinding::send() adds it whenever a response has been queued. */
    binding.write_event = event_new(this->event_base, binding.socket, EV_WRITE, &DNSServer::event_cb_write, &binding);
    if(!binding.write_event) {
        error = "failed to create write event";
        return fail();
    }

    if(event_add(binding.read_event, nullptr) < 0) {
        error = "failed to add read event";
        return fail();
    }
    return true;
}
|
|
|
|
/*
 * Releases all resources of a binding: its read/write events, its socket, the
 * queue of pending responses and the self reference.
 * Calling it for an already unbound binding is a no-op.
 */
void DNSServer::unbind(DNSServerBinding &binding) {
    /*
     * Declared before any lock so it gets destroyed last: should `self` be the final
     * owner of the binding, the binding must not be destroyed while io_lock is held.
     */
    decltype(binding.self) self_reference{};

    decltype(binding.read_event) read_event{nullptr};
    decltype(binding.write_event) write_event{nullptr};
    {
        /* Detach the events first; send() refuses to queue once write_event is null. */
        std::lock_guard lock{binding.io_lock};
        read_event = binding.read_event;
        write_event = binding.write_event;
        binding.read_event = nullptr;
        binding.write_event = nullptr;
    }

    /*
     * event_del_block() waits for a currently running callback, and event_cb_write()
     * acquires io_lock itself. Waiting while holding io_lock could therefore
     * deadlock, so the events are deleted without the lock.
     */
    if(read_event) {
        event_del_block(read_event);
        event_free(read_event);
    }

    if(write_event) {
        event_del_block(write_event);
        event_free(write_event);
    }

    std::lock_guard lock{binding.io_lock};

    /* 0 marks "no socket"; a negative value is left behind by a failed socket() call. */
    if(binding.socket > 0) {
        ::shutdown(binding.socket, SHUT_RDWR);
        ::close(binding.socket);
    }
    binding.socket = 0;

    /* Drop every response which has not been sent yet. */
    for(auto entry = binding.write_buffer_head; entry; ) {
        auto next = entry->next;
        free(entry);
        entry = next;
    }
    binding.write_buffer_head = nullptr;
    binding.write_buffer_tail = nullptr;

    /* Leaves binding.self empty; the reference itself is released after the lock. */
    self_reference.swap(binding.self);
}
|
|
|
|
void DNSServerBinding::send(const sockaddr_storage &address, const void *payload, size_t size) {
|
|
std::lock_guard lock{this->io_lock};
|
|
if(!this->write_event)
|
|
return;
|
|
|
|
auto buffer = (char*) malloc(sizeof(BindingBuffer) + size);
|
|
auto bbuffer = (BindingBuffer*) buffer;
|
|
bbuffer->next = nullptr;
|
|
bbuffer->size = size;
|
|
memcpy(&bbuffer->target, &address, sizeof(address));
|
|
memcpy(buffer + sizeof(BindingBuffer), payload, size);
|
|
|
|
if(this->write_buffer_tail) {
|
|
assert(!this->write_buffer_tail->next);
|
|
this->write_buffer_tail->next = bbuffer;
|
|
this->write_buffer_tail = bbuffer;
|
|
} else {
|
|
assert(!this->write_buffer_head);
|
|
this->write_buffer_head = bbuffer;
|
|
this->write_buffer_tail = bbuffer;
|
|
}
|
|
|
|
event_add(this->write_event, nullptr);
|
|
}
|
|
|
|
void DNSServer::event_executor() {
|
|
do {
|
|
event_base_loop(this->event_base, EVLOOP_NO_EXIT_ON_EMPTY);
|
|
|
|
std::lock_guard lock{this->bind_lock};
|
|
if(!this->started)
|
|
return;
|
|
} while(true);
|
|
}
|
|
|
|
void DNSServer::event_cb_timer() {
|
|
std::lock_guard lock{this->bind_lock};
|
|
if(this->started) {
|
|
timeval timeout{10, 0};
|
|
event_add(this->event_base_ticker, &timeout);
|
|
}
|
|
}
|
|
|
|
/*
 * Read callback of a binding. Drains all datagrams which are currently pending on
 * the socket and passes each of them to the handler of the owning server.
 *
 * fd:          socket which became readable.
 * ptr_binding: the DNSServerBinding registered together with the event.
 */
void DNSServer::event_cb_read(evutil_socket_t fd, short, void *ptr_binding) {
    auto binding = static_cast<DNSServerBinding*>(ptr_binding);
    auto binding_ref = binding->self; /* own reference for the duration of the callback; empty once unbound */
    if(!binding_ref) return;

    sockaddr_storage source_address{};
    char buffer[1600]; /* IPv6 MTU is ~1.5k */

    while(true) { //TODO: Some kind of timeout
        socklen_t source_address_length = sizeof(source_address);
        ssize_t read_length = recvfrom(fd, buffer, sizeof(buffer), MSG_DONTWAIT, (struct sockaddr*) &source_address, &source_address_length);
        if(read_length < 0) {
            if(errno == EINTR)
                continue; /* interrupted before anything was received, just retry */

            if(errno == EAGAIN || errno == EWOULDBLOCK)
                break; /* nothing left to read */

            std::cerr << "Failed to receive data: " << errno << "/" << strerror(errno) << "\n";
            break; /* this should never happen! */
        }

        /* A zero length datagram is no error (errno would be stale); it just contains no message. */
        if(read_length == 0)
            continue;

        auto handler = binding->server->handler;
        if(handler)
            handler->handle_message(binding_ref, source_address, buffer, read_length);
        else
            std::cerr << "Dropping " << read_length << " bytes from " << net::to_string(source_address, true) << " because we've no handler\n";
    }
}
|
|
|
|
/*
 * Write callback of a binding. Sends the queued responses one by one until the
 * queue is empty or the socket can not take more data right now.
 *
 * fd:          socket which became writable.
 * ptr_binding: the DNSServerBinding registered together with the event.
 */
void DNSServer::event_cb_write(evutil_socket_t fd, short, void *ptr_binding) {
    auto binding = static_cast<DNSServerBinding*>(ptr_binding);
    auto binding_ref = binding->self; /* own reference for the duration of the callback; empty once unbound */
    if(!binding_ref) return;

    while(true) {
        DNSServerBinding::BindingBuffer* buffer{nullptr};
        {
            /* Only the queue manipulation is guarded; the send itself happens without the lock. */
            std::lock_guard lock{binding->io_lock};
            buffer = binding->write_buffer_head;
            if(!buffer)
                break;

            binding->write_buffer_head = buffer->next;
            if(!binding->write_buffer_head)
                binding->write_buffer_tail = nullptr;
        }

        const ssize_t code = sendto(fd, (char*) buffer + sizeof(DNSServerBinding::BindingBuffer), buffer->size, 0, (sockaddr*) &buffer->target, sizeof(buffer->target));
        if(code < 0) { /* zero is the regular result for an empty payload */
            const int error_code = errno;
            if(error_code == EAGAIN || error_code == EWOULDBLOCK) {
                /* The socket is full. Put the response back to the front and wait until it is writable again. */
                std::lock_guard lock{binding->io_lock};
                if(binding->write_event) {
                    buffer->next = binding->write_buffer_head;
                    binding->write_buffer_head = buffer;
                    if(!binding->write_buffer_tail)
                        binding->write_buffer_tail = buffer;

                    event_add(binding->write_event, nullptr);
                    return;
                }
                /* No write event any more (binding got unbound): the response is dropped below. */
            } else {
                std::cerr << "Failed to send DNS response to " << net::to_string(buffer->target, true) << ": " << error_code << "/" << strerror(error_code) << "\n";
            }
        }
        free(buffer);
    }
}