Sockets: Use epoll on Linux

This commit is contained in:
Stenzek 2024-07-21 15:35:56 +10:00
parent 78800870bc
commit ad374ef5e2
No known key found for this signature in database
2 changed files with 123 additions and 18 deletions

View file

@ -42,6 +42,10 @@ using nfds_t = ULONG;
#include <sys/un.h>
#include <unistd.h>
#ifdef __linux__
#include <sys/epoll.h>
#endif
#define ioctlsocket ioctl
#define closesocket close
#define WSAEWOULDBLOCK EAGAIN
@ -227,16 +231,42 @@ SocketMultiplexer::~SocketMultiplexer()
{
CloseAll();
#ifdef __linux__
if (m_epoll_fd >= 0)
close(m_epoll_fd);
#else
if (m_poll_array)
std::free(m_poll_array);
#endif
}
std::unique_ptr<SocketMultiplexer> SocketMultiplexer::Create(Error* error)
{
if (!PlatformMisc::InitializeSocketSupport(error))
return {};
std::unique_ptr<SocketMultiplexer> ret;
if (PlatformMisc::InitializeSocketSupport(error))
{
ret = std::unique_ptr<SocketMultiplexer>(new SocketMultiplexer());
if (!ret->Initialize(error))
ret.reset();
}
return ret;
}
bool SocketMultiplexer::Initialize(Error* error)
{
#ifdef __linux__
m_epoll_fd = epoll_create1(0);
if (m_epoll_fd < 0)
{
Error::SetErrno(error, "epoll_create1() failed: ", errno);
return false;
}
return std::unique_ptr<SocketMultiplexer>(new SocketMultiplexer());
return true;
#else
return true;
#endif
}
std::shared_ptr<ListenSocket> SocketMultiplexer::InternalCreateListenSocket(const SocketAddress& address,
@ -325,8 +355,13 @@ std::shared_ptr<StreamSocket> SocketMultiplexer::InternalConnectStreamSocket(con
void SocketMultiplexer::AddOpenSocket(std::shared_ptr<BaseSocket> socket)
{
std::unique_lock lock(m_open_sockets_lock);
#ifdef __linux__
struct epoll_event ev = {.events = 0u, .data = {.fd = socket->GetDescriptor()}};
if (epoll_ctl(m_epoll_fd, EPOLL_CTL_ADD, socket->GetDescriptor(), &ev) != 0) [[unlikely]]
ERROR_LOG("epoll_ctl() to add socket failed: {}", Error::CreateErrno(errno).GetDescription());
#endif
std::unique_lock lock(m_open_sockets_lock);
DebugAssert(m_open_sockets.find(socket->GetDescriptor()) == m_open_sockets.end());
m_open_sockets.emplace(socket->GetDescriptor(), std::move(socket));
}
@ -339,27 +374,29 @@ void SocketMultiplexer::AddClientSocket(std::shared_ptr<BaseSocket> socket)
void SocketMultiplexer::RemoveOpenSocket(BaseSocket* socket)
{
std::unique_lock lock(m_open_sockets_lock);
const auto iter = m_open_sockets.find(socket->GetDescriptor());
Assert(iter != m_open_sockets.end());
m_open_sockets.erase(iter);
#ifdef __linux__
if (epoll_ctl(m_epoll_fd, EPOLL_CTL_DEL, socket->GetDescriptor(), nullptr) != 0) [[unlikely]]
ERROR_LOG("epoll_ctl() to remove socket failed: {}", Error::CreateErrno(errno).GetDescription());
#else
#ifdef _DEBUG
{
std::unique_lock lock(m_poll_array_lock);
for (size_t i = 0; i < m_poll_array_active_size; i++)
{
pollfd& pfd = m_poll_array[i];
DebugAssert(pfd.fd != socket->GetDescriptor());
}
}
#endif
std::unique_lock lock(m_open_sockets_lock);
const auto iter = m_open_sockets.find(socket->GetDescriptor());
Assert(iter != m_open_sockets.end());
m_open_sockets.erase(iter);
// Update size.
size_t new_active_size = 0;
for (size_t i = 0; i < m_poll_array_active_size; i++)
new_active_size = (m_poll_array[i].fd != INVALID_SOCKET) ? (i + 1) : new_active_size;
m_poll_array_active_size = new_active_size;
#endif
}
void SocketMultiplexer::RemoveClientSocket(BaseSocket* socket)
@ -400,6 +437,11 @@ void SocketMultiplexer::CloseAll()
void SocketMultiplexer::SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events)
{
#ifdef __linux__
struct epoll_event ev = {.events = events, .data = {.fd = descriptor}};
if (epoll_ctl(m_epoll_fd, EPOLL_CTL_MOD, descriptor, &ev) != 0) [[unlikely]]
ERROR_LOG("epoll_ctl() for events 0x{:x} failed: {}", events, Error::CreateErrno(errno).GetDescription());
#else
std::unique_lock lock(m_poll_array_lock);
size_t free_slot = m_poll_array_active_size;
for (size_t i = 0; i < m_poll_array_active_size; i++)
@ -440,10 +482,64 @@ void SocketMultiplexer::SetNotificationMask(BaseSocket* socket, SocketDescriptor
m_poll_array[free_slot] = {.fd = descriptor, .events = static_cast<short>(events), .revents = 0};
m_poll_array_active_size = free_slot + 1;
#endif
}
bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
{
#ifdef __linux__
constexpr int MAX_EVENTS = 128;
struct epoll_event events[MAX_EVENTS];
const int nevents = epoll_wait(m_epoll_fd, events, MAX_EVENTS, static_cast<int>(milliseconds));
if (nevents <= 0)
return false;
// find sockets that triggered, we use an array here so we can avoid holding the lock, and if a socket disconnects
using PendingSocketPair = std::pair<std::shared_ptr<BaseSocket>, u32>;
PendingSocketPair* triggered_sockets =
reinterpret_cast<PendingSocketPair*>(alloca(sizeof(PendingSocketPair) * static_cast<size_t>(nevents)));
size_t num_triggered_sockets = 0;
{
std::unique_lock open_lock(m_open_sockets_lock);
for (int i = 0; i < nevents; i++)
{
const epoll_event& ev = events[i];
const auto iter = m_open_sockets.find(ev.data.fd);
if (iter == m_open_sockets.end()) [[unlikely]]
{
ERROR_LOG("Attempting to look up unknown socket {}, this should never happen.", ev.data.fd);
continue;
}
// we add a reference here in case the read kills it with a write pending, or something like that
new (&triggered_sockets[num_triggered_sockets++]) PendingSocketPair(iter->second->shared_from_this(), ev.events);
}
}
// fire events
for (size_t i = 0; i < num_triggered_sockets; i++)
{
PendingSocketPair& psp = triggered_sockets[i];
// fire events
if (psp.second & (EPOLLRDHUP | EPOLLHUP | EPOLLERR))
{
psp.first->OnHangupEvent();
}
else
{
if (psp.second & EPOLLIN)
psp.first->OnReadEvent();
if (psp.second & EPOLLOUT)
psp.first->OnWriteEvent();
}
psp.first.~shared_ptr();
}
return true;
#else
std::unique_lock lock(m_poll_array_lock);
if (m_poll_array_active_size == 0)
return false;
@ -454,7 +550,8 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
// find sockets that triggered, we use an array here so we can avoid holding the lock, and if a socket disconnects
using PendingSocketPair = std::pair<std::shared_ptr<BaseSocket>, u32>;
PendingSocketPair* triggered_sockets = reinterpret_cast<PendingSocketPair*>(alloca(sizeof(PendingSocketPair) * res));
PendingSocketPair* triggered_sockets =
reinterpret_cast<PendingSocketPair*>(alloca(sizeof(PendingSocketPair) * static_cast<size_t>(res)));
size_t num_triggered_sockets = 0;
{
std::unique_lock open_lock(m_open_sockets_lock);
@ -467,7 +564,7 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
const auto iter = m_open_sockets.find(pfd.fd);
if (iter == m_open_sockets.end()) [[unlikely]]
{
ERROR_LOG("Attempting to look up known socket {}, this should never happen.", pfd.fd);
ERROR_LOG("Attempting to look up unknown socket {}, this should never happen.", pfd.fd);
continue;
}
@ -481,7 +578,7 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
lock.unlock();
// fire events
for (u32 i = 0; i < num_triggered_sockets; i++)
for (size_t i = 0; i < num_triggered_sockets; i++)
{
PendingSocketPair& psp = triggered_sockets[i];
@ -502,6 +599,7 @@ bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
}
return true;
#endif
}
ListenSocket::ListenSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor,

View file

@ -135,6 +135,9 @@ private:
// Hide the constructor.
SocketMultiplexer();
// Initialization.
bool Initialize(Error* error);
// Tracking of open sockets.
void AddOpenSocket(std::shared_ptr<BaseSocket> socket);
void AddClientSocket(std::shared_ptr<BaseSocket> socket);
@ -148,10 +151,14 @@ private:
// We store the fd in the struct to avoid the cache miss reading the object.
using SocketMap = std::unordered_map<SocketDescriptor, std::shared_ptr<BaseSocket>>;
#ifdef __linux__
int m_epoll_fd = -1;
#else
std::mutex m_poll_array_lock;
pollfd* m_poll_array = nullptr;
size_t m_poll_array_active_size = 0;
size_t m_poll_array_max_size = 0;
#endif
std::mutex m_open_sockets_lock;
SocketMap m_open_sockets;