diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index d2a1fa52f..9f1fda96d 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -2,6 +2,8 @@ add_library(common align.h assert.cpp assert.h + binary_span_reader_writer.cpp + binary_span_reader_writer.h bitfield.h bitutils.h build_timestamp.h diff --git a/src/common/binary_span_reader_writer.cpp b/src/common/binary_span_reader_writer.cpp new file mode 100644 index 000000000..e1011da00 --- /dev/null +++ b/src/common/binary_span_reader_writer.cpp @@ -0,0 +1,110 @@ +// SPDX-FileCopyrightText: 2024 Connor McLaughlin +// SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0) + +#include "binary_span_reader_writer.h" +#include "small_string.h" + +BinarySpanReader::BinarySpanReader() = default; + +BinarySpanReader::BinarySpanReader(std::span buf) : m_buf(buf) +{ +} + +bool BinarySpanReader::PeekCString(std::string_view* dst) +{ + size_t pos = m_pos; + size_t size = 0; + while (pos < m_buf.size()) + { + if (m_buf[pos] == 0) + break; + + pos++; + size++; + } + + if (pos == m_buf.size()) + return false; + + *dst = std::string_view(reinterpret_cast(&m_buf[m_pos]), size); + return true; +} + +bool BinarySpanReader::ReadCString(std::string* dst) +{ + std::string_view sv; + if (!PeekCString(&sv)) + return false; + + dst->assign(sv); + m_pos += sv.size() + 1; + return true; +} + +bool BinarySpanReader::ReadCString(std::string_view* dst) +{ + if (!PeekCString(dst)) + return false; + + m_pos += dst->size() + 1; + return true; +} + +bool BinarySpanReader::ReadCString(SmallStringBase* dst) +{ + std::string_view sv; + if (!PeekCString(&sv)) + return false; + + dst->assign(sv); + m_pos += sv.size() + 1; + return true; +} + +std::string_view BinarySpanReader::ReadCString() +{ + std::string_view ret; + if (PeekCString(&ret)) + m_pos += ret.size() + 1; + return ret; +} + +bool BinarySpanReader::PeekCString(std::string* dst) +{ + std::string_view sv; + if (!PeekCString(&sv)) + return false; + + dst->assign(sv); + return true; +} + +bool BinarySpanReader::PeekCString(SmallStringBase* dst) +{ + std::string_view sv; + if (!PeekCString(&sv)) + return false; + + dst->assign(sv); + m_pos += sv.size() + 1; + return true; +} + +BinarySpanWriter::BinarySpanWriter() = default; + +BinarySpanWriter::BinarySpanWriter(std::span buf) : m_buf(buf) +{ +} + +bool BinarySpanWriter::WriteCString(std::string_view val) +{ + if ((m_pos + val.size() + 1) > m_buf.size()) [[unlikely]] + return false; + + if (!val.empty()) + std::memcpy(&m_buf[m_pos], val.data(), val.size()); + + m_buf[m_pos + val.size()] = 0; + m_pos += val.size() + 1; + return true; +} diff --git a/src/common/binary_span_reader_writer.h b/src/common/binary_span_reader_writer.h new file mode 100644 index 000000000..6a21c5aed --- /dev/null +++ b/src/common/binary_span_reader_writer.h @@ -0,0 +1,135 @@ +// SPDX-FileCopyrightText: 2024 Connor McLaughlin +// SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0) + +#include "types.h" + +#include +#include +#include +#include + +class SmallStringBase; + +class BinarySpanReader +{ +public: + BinarySpanReader(); + BinarySpanReader(std::span buf); + + ALWAYS_INLINE const std::span& GetSpan() const { return m_buf; } + ALWAYS_INLINE bool IsValid() const { return !m_buf.empty(); } + ALWAYS_INLINE bool CheckRemaining(size_t size) { return ((m_pos + size) <= m_buf.size()); } + ALWAYS_INLINE size_t GetBufferRemaining() const { return (m_buf.size() - m_pos); } + ALWAYS_INLINE size_t GetBufferConsumed() const { return m_pos; } + + // clang-format off + template ALWAYS_INLINE bool ReadT(T* dst) { return Read(dst, sizeof(T)); } + ALWAYS_INLINE bool ReadU8(u8* dst) { return ReadT(dst); } + ALWAYS_INLINE bool ReadU16(u16* dst) { return ReadT(dst); } + ALWAYS_INLINE bool ReadU32(u32* dst) { return ReadT(dst); } + ALWAYS_INLINE bool ReadU64(u64* dst) { return ReadT(dst); } + ALWAYS_INLINE bool ReadFloat(float* dst) { return ReadT(dst); } + bool ReadCString(std::string* dst); + bool ReadCString(std::string_view* dst); + bool ReadCString(SmallStringBase* dst); + + template ALWAYS_INLINE T ReadT() { T ret; if (!Read(&ret, sizeof(ret))) [[unlikely]] { ret = {}; } return ret; } + ALWAYS_INLINE u8 ReadU8() { return ReadT(); } + ALWAYS_INLINE u16 ReadU16() { return ReadT(); } + ALWAYS_INLINE u32 ReadU32() { return ReadT(); } + ALWAYS_INLINE u64 ReadU64() { return ReadT(); } + ALWAYS_INLINE float ReadFloat() { return ReadT(); } + std::string_view ReadCString(); + + template ALWAYS_INLINE bool PeekT(T* dst) { return Peek(dst, sizeof(T)); } + ALWAYS_INLINE bool PeekU8(u8* dst) { return PeekT(dst); } + ALWAYS_INLINE bool PeekU16(u16* dst) { return PeekT(dst); } + ALWAYS_INLINE bool PeekU32(u32* dst) { return PeekT(dst); } + ALWAYS_INLINE bool PeekU64(u64* dst) { return PeekT(dst); } + ALWAYS_INLINE bool PeekFloat(float* dst) { return PeekT(dst); } + bool PeekCString(std::string* dst); + bool PeekCString(std::string_view* dst); + bool PeekCString(SmallStringBase* dst); + + ALWAYS_INLINE BinarySpanReader& operator>>(u8& val) { val = ReadT(); return *this; } + ALWAYS_INLINE BinarySpanReader& operator>>(u16& val) { val = ReadT(); return *this; } + ALWAYS_INLINE BinarySpanReader& operator>>(u32& val) { val = ReadT(); return *this; } + ALWAYS_INLINE BinarySpanReader& operator>>(u64& val) { val = ReadT(); return *this; } + ALWAYS_INLINE BinarySpanReader& operator>>(float& val) { val = ReadT(); return *this; } + ALWAYS_INLINE BinarySpanReader& operator>>(std::string_view val) { val = ReadCString(); return *this; } + // clang-format on + +private: + ALWAYS_INLINE bool Read(void* buf, size_t size) + { + if ((m_pos + size) < m_buf.size()) [[likely]] + { + std::memcpy(buf, &m_buf[m_pos], size); + m_pos += size; + return true; + } + + return false; + } + + ALWAYS_INLINE bool Peek(void* buf, size_t size) + { + if ((m_pos + size) < m_buf.size()) [[likely]] + { + std::memcpy(buf, &m_buf[m_pos], size); + return true; + } + + return false; + } + +private: + std::span m_buf; + size_t m_pos = 0; +}; + +class BinarySpanWriter +{ +public: + BinarySpanWriter(); + BinarySpanWriter(std::span buf); + + ALWAYS_INLINE const std::span& GetSpan() const { return m_buf; } + ALWAYS_INLINE bool IsValid() const { return !m_buf.empty(); } + ALWAYS_INLINE size_t GetBufferRemaining() const { return (m_buf.size() - m_pos); } + ALWAYS_INLINE size_t GetBufferWritten() const { return m_pos; } + + // clang-format off + template ALWAYS_INLINE bool WriteT(T dst) { return Write(&dst, sizeof(T)); } + ALWAYS_INLINE bool WriteU8(u8 val) { return WriteT(val); } + ALWAYS_INLINE bool WriteU16(u16 val) { return WriteT(val); } + ALWAYS_INLINE bool WriteU32(u32 val) { return WriteT(val); } + ALWAYS_INLINE bool WriteU64(u64 val) { return WriteT(val); } + ALWAYS_INLINE bool WriteFloat(float val) { return WriteT(val); } + bool WriteCString(std::string_view val); + + ALWAYS_INLINE BinarySpanWriter& operator<<(u8 val) { WriteU8(val); return *this; } + ALWAYS_INLINE BinarySpanWriter& operator<<(u16 val) { WriteU16(val); return *this; } + ALWAYS_INLINE BinarySpanWriter& operator<<(u32 val) { WriteU32(val); return *this; } + ALWAYS_INLINE BinarySpanWriter& operator<<(u64 val) { WriteU64(val); return *this; } + ALWAYS_INLINE BinarySpanWriter& operator<<(float val) { WriteFloat(val); return *this; } + ALWAYS_INLINE BinarySpanWriter& operator<<(std::string_view val) { WriteCString(val); return *this; } + // clang-format on + +private: + ALWAYS_INLINE bool Write(void* buf, size_t size) + { + if ((m_pos + size) < m_buf.size()) [[likely]] + { + std::memcpy(&m_buf[m_pos], buf, size); + m_pos += size; + return true; + } + + return false; + } + +private: + std::span m_buf; + size_t m_pos = 0; +}; diff --git a/src/common/common.vcxproj b/src/common/common.vcxproj index 67f9138ee..9d0feb068 100644 --- a/src/common/common.vcxproj +++ b/src/common/common.vcxproj @@ -34,6 +34,7 @@ + @@ -60,6 +61,7 @@ + diff --git a/src/common/common.vcxproj.filters b/src/common/common.vcxproj.filters index bfb354dc2..a36015e9f 100644 --- a/src/common/common.vcxproj.filters +++ b/src/common/common.vcxproj.filters @@ -45,6 +45,7 @@ thirdparty + @@ -72,6 +73,7 @@ thirdparty + diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index 275c36e62..720002de3 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -63,6 +63,8 @@ add_library(util shadergen.h shiftjis.cpp shiftjis.h + sockets.cpp + sockets.h state_wrapper.cpp state_wrapper.h wav_writer.cpp diff --git a/src/util/platform_misc.h b/src/util/platform_misc.h index 176e20035..0366df5f5 100644 --- a/src/util/platform_misc.h +++ b/src/util/platform_misc.h @@ -5,7 +5,10 @@ #include +class Error; + namespace PlatformMisc { +bool InitializeSocketSupport(Error* error); void SuspendScreensaver(); void ResumeScreensaver(); diff --git a/src/util/platform_misc_mac.mm b/src/util/platform_misc_mac.mm index 855cf029a..ee4c91dc7 100644 --- a/src/util/platform_misc_mac.mm +++ b/src/util/platform_misc_mac.mm @@ -24,6 +24,11 @@ Log_SetChannel(PlatformMisc); static IOPMAssertionID s_prevent_idle_assertion = kIOPMNullAssertionID; +bool PlatformMisc::InitializeSocketSupport(Error* error) +{ + return true; +} + static bool SetScreensaverInhibitMacOS(bool inhibit) { if (inhibit) diff --git a/src/util/platform_misc_unix.cpp b/src/util/platform_misc_unix.cpp index 77706d114..26be5afc2 100644 --- a/src/util/platform_misc_unix.cpp +++ b/src/util/platform_misc_unix.cpp @@ -16,6 +16,11 @@ Log_SetChannel(PlatformMisc); +bool PlatformMisc::InitializeSocketSupport(Error* error) +{ + return true; +} + static bool SetScreensaverInhibitDBus(const bool inhibit_requested, const char* program_name, const char* reason) { static dbus_uint32_t s_cookie; diff --git a/src/util/platform_misc_win32.cpp b/src/util/platform_misc_win32.cpp index 6990d11f2..2eb8ae4a1 100644 --- a/src/util/platform_misc_win32.cpp +++ b/src/util/platform_misc_win32.cpp @@ -3,6 +3,7 @@ #include "platform_misc.h" +#include "common/error.h" #include "common/file_system.h" #include "common/log.h" #include "common/small_string.h" @@ -13,10 +14,33 @@ #include #include "common/windows_headers.h" +#include #include Log_SetChannel(PlatformMisc); +static bool s_screensaver_suspended = false; +static bool s_winsock_initialized = false; +static std::once_flag s_winsock_initializer; + +bool PlatformMisc::InitializeSocketSupport(Error* error) +{ + std::call_once(s_winsock_initializer, [](Error* error) { + WSADATA wsa = {}; + if (WSAStartup(MAKEWORD(2, 2), &wsa) != 0) + { + Error::SetSocket(error, "WSAStartup() failed: ", WSAGetLastError()); + return false; + } + + s_winsock_initialized = true; + std::atexit([]() { WSACleanup(); }); + return true; + }, error); + + return s_winsock_initialized; +} + static bool SetScreensaverInhibitWin32(bool inhibit) { if (SetThreadExecutionState(ES_CONTINUOUS | (inhibit ? (ES_DISPLAY_REQUIRED | ES_SYSTEM_REQUIRED) : 0)) == NULL) @@ -28,8 +52,6 @@ static bool SetScreensaverInhibitWin32(bool inhibit) return true; } -static bool s_screensaver_suspended; - void PlatformMisc::SuspendScreensaver() { if (s_screensaver_suspended) diff --git a/src/util/sockets.cpp b/src/util/sockets.cpp new file mode 100644 index 000000000..96546f08d --- /dev/null +++ b/src/util/sockets.cpp @@ -0,0 +1,940 @@ +// SPDX-FileCopyrightText: 2015-2024 Connor McLaughlin +// SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0) + +#include "sockets.h" +#include "platform_misc.h" + +#include "common/assert.h" +#include "common/log.h" + +#include +#include +#include + +#ifndef __APPLE__ +#include // alloca +#else +#include +#endif + +#ifdef _WIN32 + +#include "common/windows_headers.h" + +#include +#include + +#define SIZE_CAST(x) static_cast(x) +using ssize_t = int; +using nfds_t = ULONG; + +#else + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define ioctlsocket ioctl +#define closesocket close +#define WSAEWOULDBLOCK EAGAIN +#define WSAGetLastError() errno +#define WSAPoll poll +#define SIZE_CAST(x) x + +#define SOCKET_ERROR -1 +#define INVALID_SOCKET -1 +#define SD_BOTH SHUT_RDWR +#endif + +Log_SetChannel(Sockets); + +static bool SetNonBlocking(SocketDescriptor sd, Error* error) +{ + // switch to nonblocking mode + unsigned long value = 1; + if (ioctlsocket(sd, FIONBIO, &value) < 0) + { + Error::SetSocket(error, "ioctlsocket() failed: ", WSAGetLastError()); + return false; + } + + return true; +} + +void SocketAddress::SetFromSockaddr(const void* sa, size_t length) +{ + m_length = std::min(static_cast(length), static_cast(sizeof(m_data))); + std::memcpy(m_data, sa, m_length); + if (m_length < sizeof(m_data)) + std::memset(m_data + m_length, 0, sizeof(m_data) - m_length); +} + +std::optional SocketAddress::Parse(Type type, const char* address, u32 port, Error* error) +{ + std::optional ret = SocketAddress(); + + switch (type) + { + case Type::IPv4: + { + sockaddr_in* sain = reinterpret_cast(ret->m_data); + std::memset(sain, 0, sizeof(sockaddr_in)); + sain->sin_family = AF_INET; + sain->sin_port = htons(static_cast(port)); + int res = inet_pton(AF_INET, address, &sain->sin_addr); + if (res == 1) + { + ret->m_length = sizeof(sockaddr_in); + } + else + { + Error::SetSocket(error, "inet_pton() failed: ", WSAGetLastError()); + ret.reset(); + } + } + break; + + case Type::IPv6: + { + sockaddr_in6* sain6 = reinterpret_cast(ret->m_data); + std::memset(sain6, 0, sizeof(sockaddr_in6)); + sain6->sin6_family = AF_INET; + sain6->sin6_port = htons(static_cast(port)); + int res = inet_pton(AF_INET6, address, &sain6->sin6_addr); + if (res == 1) + { + ret->m_length = sizeof(sockaddr_in6); + } + else + { + Error::SetSocket(error, "inet_pton() failed: ", WSAGetLastError()); + ret.reset(); + } + } + break; + +#ifndef _WIN32 + case Type::Unix: + { + sockaddr_un* sun = reinterpret_cast(ret->m_data); + std::memset(sun, 0, sizeof(sockaddr_un)); + sun->sun_family = AF_UNIX; + + const size_t len = std::strlen(address); + if ((len + 1) <= std::size(sun->sun_path)) + { + std::memcpy(sun->sun_path, address, len); + ret->m_length = sizeof(sockaddr_un); + } + else + { + Error::SetStringFmt(error, "Path length {} exceeds {} bytes.", len, std::size(sun->sun_path)); + ret.reset(); + } + } + break; +#endif + + default: + Error::SetStringView(error, "Unknown address type."); + ret.reset(); + break; + } + + return ret; +} + +SmallString SocketAddress::ToString() const +{ + SmallString ret; + + const sockaddr* sa = reinterpret_cast(m_data); + switch (sa->sa_family) + { + case AF_INET: + { + ret.clear(); + ret.reserve(128); + const char* res = + inet_ntop(AF_INET, &reinterpret_cast(m_data)->sin_addr, ret.data(), ret.buffer_size()); + if (res == nullptr) + ret.assign(""); + else + ret.update_size(); + + ret.append_format(":{}", static_cast(ntohs(reinterpret_cast(m_data)->sin_port))); + } + break; + + case AF_INET6: + { + ret.clear(); + ret.reserve(128); + ret.append('['); + const char* res = inet_ntop(AF_INET6, &reinterpret_cast(m_data)->sin6_addr, ret.data() + 1, + ret.buffer_size() - 1); + if (res == nullptr) + ret.assign(""); + else + ret.update_size(); + + ret.append_format("]:{}", static_cast(ntohs(reinterpret_cast(m_data)->sin6_port))); + } + break; + +#ifndef _WIN32 + case AF_UNIX: + { + ret.assign(reinterpret_cast(m_data)->sun_path); + } + break; +#endif + + default: + { + ret.assign(""); + break; + } + } + + return ret; +} + +BaseSocket::BaseSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor) + : m_multiplexer(multiplexer), m_descriptor(descriptor) +{ +} + +BaseSocket::~BaseSocket() = default; + +SocketMultiplexer::SocketMultiplexer() = default; + +SocketMultiplexer::~SocketMultiplexer() +{ + CloseAll(); + + if (m_poll_array) + std::free(m_poll_array); +} + +std::unique_ptr SocketMultiplexer::Create(Error* error) +{ + if (!PlatformMisc::InitializeSocketSupport(error)) + return {}; + + return std::unique_ptr(new SocketMultiplexer()); +} + +std::shared_ptr SocketMultiplexer::InternalCreateListenSocket(const SocketAddress& address, + CreateStreamSocketCallback callback, + Error* error) +{ + // create and bind socket + const sockaddr* sa = reinterpret_cast(address.GetData()); + SocketDescriptor descriptor = socket(sa->sa_family, SOCK_STREAM, StreamSocket::GetSocketProtocolForAddress(address)); + if (descriptor == INVALID_SOCKET) + { + Error::SetSocket(error, "socket() failed: ", WSAGetLastError()); + return {}; + } + + const int reuseaddr_enable = 1; + if (setsockopt(descriptor, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&reuseaddr_enable), + sizeof(reuseaddr_enable)) < 0) + { + WARNING_LOG("Failed to set SO_REUSEADDR: {}", Error::CreateSocket(WSAGetLastError()).GetDescription()); + } + + if (bind(descriptor, sa, address.GetLength()) < 0) + { + Error::SetSocket(error, "bind() failed: ", WSAGetLastError()); + closesocket(descriptor); + return {}; + } + + if (listen(descriptor, 5) < 0) + { + Error::SetSocket(error, "listen() failed: ", WSAGetLastError()); + closesocket(descriptor); + return {}; + } + + if (!SetNonBlocking(descriptor, error)) + { + closesocket(descriptor); + return {}; + } + + // create listensocket + std::shared_ptr ret = std::make_shared(*this, descriptor, callback); + + // add to list, register for reads + AddOpenSocket(std::static_pointer_cast(ret)); + SetNotificationMask(ret.get(), descriptor, POLLIN); + return ret; +} + +std::shared_ptr SocketMultiplexer::InternalConnectStreamSocket(const SocketAddress& address, + CreateStreamSocketCallback callback, + Error* error) +{ + // create and bind socket + const sockaddr* sa = reinterpret_cast(address.GetData()); + SocketDescriptor descriptor = socket(sa->sa_family, SOCK_STREAM, StreamSocket::GetSocketProtocolForAddress(address)); + if (descriptor == INVALID_SOCKET) + { + Error::SetSocket(error, "socket() failed: ", WSAGetLastError()); + return {}; + } + + if (connect(descriptor, sa, address.GetLength()) < 0) + { + Error::SetSocket(error, "connect() failed: ", WSAGetLastError()); + closesocket(descriptor); + return {}; + } + + if (!SetNonBlocking(descriptor, error)) + { + closesocket(descriptor); + return {}; + } + + // create stream socket + std::shared_ptr csocket = callback(*this, descriptor); + csocket->InitialSetup(); + if (!csocket->IsConnected()) + csocket.reset(); + + return csocket; +} + +void SocketMultiplexer::AddOpenSocket(std::shared_ptr socket) +{ + std::unique_lock lock(m_open_sockets_lock); + + DebugAssert(m_open_sockets.find(socket->GetDescriptor()) == m_open_sockets.end()); + m_open_sockets.emplace(socket->GetDescriptor(), socket); +} + +void SocketMultiplexer::RemoveOpenSocket(BaseSocket* socket) +{ +#ifdef _DEBUG + { + std::unique_lock lock(m_poll_array_lock); + for (size_t i = 0; i < m_poll_array_active_size; i++) + { + pollfd& pfd = m_poll_array[i]; + DebugAssert(pfd.fd != socket->GetDescriptor()); + } + } +#endif + + std::unique_lock lock(m_open_sockets_lock); + const auto iter = m_open_sockets.find(socket->GetDescriptor()); + Assert(iter != m_open_sockets.end()); + m_open_sockets.erase(iter); + + // Update size. + size_t new_active_size = 0; + for (size_t i = 0; i < m_poll_array_active_size; i++) + new_active_size = (m_poll_array[i].fd != INVALID_SOCKET) ? (i + 1) : new_active_size; + m_poll_array_active_size = new_active_size; +} + +bool SocketMultiplexer::HasAnyOpenSockets() +{ + std::unique_lock lock(m_open_sockets_lock); + return !m_open_sockets.empty(); +} + +void SocketMultiplexer::CloseAll() +{ + std::unique_lock lock(m_open_sockets_lock); + + while (!m_open_sockets.empty()) + { + std::shared_ptr socket = m_open_sockets.begin()->second; + lock.unlock(); + socket->Close(); + lock.lock(); + } +} + +void SocketMultiplexer::SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events) +{ + std::unique_lock lock(m_poll_array_lock); + size_t free_slot = m_poll_array_active_size; + for (size_t i = 0; i < m_poll_array_active_size; i++) + { + pollfd& pfd = m_poll_array[i]; + if (pfd.fd != descriptor) + { + free_slot = (pfd.fd < 0 && free_slot != m_poll_array_active_size) ? i : free_slot; + continue; + } + + // unbinding? + if (events != 0) + pfd.events = static_cast(events); + else + pfd.fd = INVALID_SOCKET; + + return; + } + + // don't create entries for null masks + if (events == 0) + return; + + // need to grow the array? + if (free_slot == m_poll_array_max_size) + { + const size_t new_size = std::max(free_slot + 1, free_slot * 2); + pollfd* new_array = static_cast(std::realloc(m_poll_array, sizeof(pollfd) * new_size)); + if (!new_array) + Panic("Memory allocation failed."); + + for (size_t i = m_poll_array_max_size; i < new_size; i++) + new_array[i] = {.fd = INVALID_SOCKET, .events = 0, .revents = 0}; + m_poll_array = new_array; + m_poll_array_max_size = new_size; + } + + m_poll_array[free_slot] = {.fd = descriptor, .events = static_cast(events), .revents = 0}; + m_poll_array_active_size = free_slot + 1; +} + +bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds) +{ + std::unique_lock lock(m_poll_array_lock); + if (m_poll_array_active_size == 0) + return false; + + const int res = WSAPoll(m_poll_array, static_cast(m_poll_array_active_size), milliseconds); + if (res <= 0) + return false; + + // find sockets that triggered, we use an array here so we can avoid holding the lock, and if a socket disconnects + using PendingSocketPair = std::pair, u32>; + PendingSocketPair* triggered_sockets = reinterpret_cast(alloca(sizeof(PendingSocketPair) * res)); + size_t num_triggered_sockets = 0; + { + std::unique_lock open_lock(m_open_sockets_lock); + for (size_t i = 0; i < m_poll_array_active_size; i++) + { + const pollfd& pfd = m_poll_array[i]; + if (pfd.revents == 0) + continue; + + const auto iter = m_open_sockets.find(pfd.fd); + if (iter == m_open_sockets.end()) [[unlikely]] + { + ERROR_LOG("Attempting to look up known socket {}, this should never happen.", pfd.fd); + continue; + } + + // we add a reference here in case the read kills it with a write pending, or something like that + new (&triggered_sockets[num_triggered_sockets++]) + PendingSocketPair(iter->second->shared_from_this(), pfd.revents); + } + } + + // release lock so connections etc can acquire it + lock.unlock(); + + // fire events + for (u32 i = 0; i < num_triggered_sockets; i++) + { + PendingSocketPair& psp = triggered_sockets[i]; + + // fire events + if (psp.second & (POLLIN | POLLHUP | POLLERR)) + psp.first->OnReadEvent(); + if (psp.second & POLLOUT) + psp.first->OnWriteEvent(); + + psp.first.~shared_ptr(); + } + + return true; +} + +ListenSocket::ListenSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor, + SocketMultiplexer::CreateStreamSocketCallback accept_callback) + : BaseSocket(multiplexer, descriptor), m_accept_callback(accept_callback) +{ + // get local address + sockaddr_storage sa; + socklen_t salen = sizeof(sa); + if (getsockname(m_descriptor, reinterpret_cast(&sa), &salen) == 0) + m_local_address.SetFromSockaddr(&sa, salen); +} + +ListenSocket::~ListenSocket() +{ + DebugAssert(m_descriptor == INVALID_SOCKET); +} + +void ListenSocket::Close() +{ + if (m_descriptor < 0) + return; + + m_multiplexer.SetNotificationMask(this, m_descriptor, 0); + m_multiplexer.RemoveOpenSocket(this); + closesocket(m_descriptor); + m_descriptor = INVALID_SOCKET; +} + +void ListenSocket::OnReadEvent() +{ + // connection incoming + sockaddr_storage sa; + socklen_t salen = sizeof(sa); + SocketDescriptor new_descriptor = accept(m_descriptor, reinterpret_cast(&sa), &salen); + if (new_descriptor == INVALID_SOCKET) + { + ERROR_LOG("accept() returned {}", Error::CreateSocket(WSAGetLastError()).GetDescription()); + return; + } + + Error error; + if (!SetNonBlocking(new_descriptor, &error)) + { + ERROR_LOG("Failed to set just-connected socket to nonblocking: {}", error.GetDescription()); + closesocket(new_descriptor); + return; + } + + // create socket, we release our own reference. + std::shared_ptr client = m_accept_callback(m_multiplexer, new_descriptor); + if (!client) + { + closesocket(new_descriptor); + return; + } + + m_num_connections_accepted++; + client->InitialSetup(); +} + +void ListenSocket::OnWriteEvent() +{ +} + +StreamSocket::StreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor) + : BaseSocket(multiplexer, descriptor) +{ + // get local address + sockaddr_storage sa; + socklen_t salen = sizeof(sa); + if (getsockname(m_descriptor, reinterpret_cast(&sa), &salen) == 0) + m_local_address.SetFromSockaddr(&sa, salen); + + // get remote address + salen = sizeof(sockaddr_storage); + if (getpeername(m_descriptor, reinterpret_cast(&sa), &salen) == 0) + m_remote_address.SetFromSockaddr(&sa, salen); +} + +StreamSocket::~StreamSocket() +{ + DebugAssert(m_descriptor == INVALID_SOCKET); +} + +u32 StreamSocket::GetSocketProtocolForAddress(const SocketAddress& sa) +{ + const sockaddr* ssa = reinterpret_cast(sa.GetData()); + return (ssa->sa_family == AF_INET || ssa->sa_family == AF_INET6) ? IPPROTO_TCP : 0; +} + +void StreamSocket::InitialSetup() +{ + // register for notifications + m_multiplexer.AddOpenSocket(shared_from_this()); + m_multiplexer.SetNotificationMask(this, m_descriptor, POLLIN); + + // trigger connected notification + std::unique_lock lock(m_lock); + OnConnected(); +} + +size_t StreamSocket::Read(void* buffer, size_t buffer_size) +{ + std::unique_lock lock(m_lock); + if (!m_connected) + return 0; + + // try a read + const ssize_t len = recv(m_descriptor, static_cast(buffer), SIZE_CAST(buffer_size), 0); + if (len <= 0) + { + // Check for EAGAIN + if (len < 0 && WSAGetLastError() == WSAEWOULDBLOCK) + { + // Not an error. Just means no data is available. + return 0; + } + + // error + CloseWithError(); + return 0; + } + + return len; +} + +size_t StreamSocket::Write(const void* buffer, size_t buffer_size) +{ + std::unique_lock lock(m_lock); + if (!m_connected) + return 0; + + // try a write + const ssize_t len = send(m_descriptor, static_cast(buffer), SIZE_CAST(buffer_size), 0); + if (len <= 0) + { + // Check for EAGAIN + if (len < 0 && WSAGetLastError() == WSAEWOULDBLOCK) + { + // Not an error. Just means no data is available. + return 0; + } + + // error + CloseWithError(); + return 0; + } + + return len; +} + +size_t StreamSocket::WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers) +{ + std::unique_lock lock(m_lock); + if (!m_connected || num_buffers == 0) + return 0; + +#ifdef _WIN32 + + WSABUF* bufs = static_cast(alloca(sizeof(WSABUF) * num_buffers)); + for (size_t i = 0; i < num_buffers; i++) + { + bufs[i].buf = (CHAR*)buffers[i]; + bufs[i].len = (ULONG)buffer_lengths[i]; + } + + DWORD bytesSent = 0; + if (WSASend(m_descriptor, bufs, (DWORD)num_buffers, &bytesSent, 0, nullptr, nullptr) == SOCKET_ERROR) + { + if (WSAGetLastError() != WSAEWOULDBLOCK) + { + // Socket error. + CloseWithError(); + return 0; + } + } + + return static_cast(bytesSent); + +#else // _WIN32 + + iovec* bufs = static_cast(alloca(sizeof(iovec) * num_buffers)); + for (size_t i = 0; i < num_buffers; i++) + { + bufs[i].iov_base = (void*)buffers[i]; + bufs[i].iov_len = buffer_lengths[i]; + } + + ssize_t res = writev(m_descriptor, bufs, num_buffers); + if (res < 0) + { + if (errno != EAGAIN) + { + // Socket error. + CloseWithError(); + return 0; + } + + res = 0; + } + + return static_cast(res); + +#endif +} + +void StreamSocket::Close() +{ + std::unique_lock lock(m_lock); + if (!m_connected) + return; + + m_multiplexer.SetNotificationMask(this, m_descriptor, 0); + m_multiplexer.RemoveOpenSocket(this); + shutdown(m_descriptor, SD_BOTH); + closesocket(m_descriptor); + m_descriptor = INVALID_SOCKET; + m_connected = false; + + OnDisconnected(Error::CreateString("Connection explicitly closed.")); +} + +void StreamSocket::CloseWithError() +{ + std::unique_lock lock(m_lock); + DebugAssert(m_connected); + + Error error; + const int error_code = WSAGetLastError(); + if (error_code == 0) + error.SetStringView("Connection closed by peer."); + else + error.SetSocket(error_code); + + m_multiplexer.SetNotificationMask(this, m_descriptor, 0); + m_multiplexer.RemoveOpenSocket(this); + closesocket(m_descriptor); + m_descriptor = INVALID_SOCKET; + m_connected = false; + + OnDisconnected(error); +} + +void StreamSocket::OnReadEvent() +{ + // forward through + std::unique_lock lock(m_lock); + if (m_connected) + OnRead(); +} + +void StreamSocket::OnWriteEvent() +{ + // shouldn't be called +} + +BufferedStreamSocket::BufferedStreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor, + size_t receive_buffer_size /* = 16384 */, + size_t send_buffer_size /* = 16384 */) + : StreamSocket(multiplexer, descriptor), m_receive_buffer(receive_buffer_size), m_send_buffer(send_buffer_size) +{ +} + +BufferedStreamSocket::~BufferedStreamSocket() +{ +} + +std::unique_lock BufferedStreamSocket::GetLock() +{ + return std::unique_lock(m_lock); +} + +std::span BufferedStreamSocket::AcquireReadBuffer() const +{ + return std::span(m_receive_buffer.data() + m_receive_buffer_offset, m_receive_buffer_size); +} + +void BufferedStreamSocket::ReleaseReadBuffer(size_t bytes_consumed) +{ + DebugAssert(bytes_consumed <= m_receive_buffer_size); + m_receive_buffer_offset += static_cast(bytes_consumed); + m_receive_buffer_size -= static_cast(bytes_consumed); + + // Anything left? If not, reset offset. + m_receive_buffer_offset = (m_receive_buffer_size == 0) ? 0 : m_receive_buffer_offset; +} + +std::span BufferedStreamSocket::AcquireWriteBuffer(size_t wanted_bytes, bool allow_smaller /* = false */) +{ + // If to get the desired space, we need to move backwards, do so. + if ((m_send_buffer_offset + m_send_buffer_size + wanted_bytes) > m_send_buffer.size()) + { + if ((m_send_buffer_size + wanted_bytes) > m_send_buffer.size() && !allow_smaller) + { + // Not enough space. + return {}; + } + + // Shuffle buffer backwards. + std::memmove(m_send_buffer.data(), m_send_buffer.data() + m_send_buffer_offset, m_send_buffer_size); + m_send_buffer_offset = 0; + } + + DebugAssert((m_send_buffer_offset + m_send_buffer_size + wanted_bytes) <= m_send_buffer.size()); + return std::span(m_send_buffer.data() + m_send_buffer_offset + m_send_buffer_size, + m_send_buffer.size() - m_send_buffer_offset - m_send_buffer_size); +} + +void BufferedStreamSocket::ReleaseWriteBuffer(size_t bytes_written, bool commit /* = true */) +{ + DebugAssert((m_send_buffer_offset + m_send_buffer_size + bytes_written) <= m_send_buffer.size()); + m_send_buffer_size += static_cast(bytes_written); + + // Send as much as we can. + if (commit && m_send_buffer_size > 0) + { + const ssize_t res = send(m_descriptor, reinterpret_cast(m_send_buffer.data() + m_send_buffer_offset), + SIZE_CAST(m_send_buffer_size), 0); + if (res < 0 && WSAGetLastError() != WSAEWOULDBLOCK) + { + CloseWithError(); + return; + } + + m_send_buffer_offset += static_cast(res); + m_send_buffer_size -= static_cast(res); + if (m_send_buffer_size == 0) + { + m_send_buffer_offset = 0; + } + else + { + // Register for writes to finish it off. + m_multiplexer.SetNotificationMask(this, m_descriptor, POLLIN | POLLOUT); + } + } +} + +size_t BufferedStreamSocket::Read(void* buffer, size_t buffer_size) +{ + // Read from receive buffer. + const std::span rdbuf = AcquireReadBuffer(); + if (rdbuf.empty()) + return 0; + + const size_t bytes_to_read = std::min(rdbuf.size(), buffer_size); + std::memcpy(buffer, rdbuf.data(), bytes_to_read); + ReleaseReadBuffer(bytes_to_read); + return bytes_to_read; +} + +size_t BufferedStreamSocket::Write(const void* buffer, size_t buffer_size) +{ + // Read from receive buffer. + const std::span wrbuf = AcquireWriteBuffer(buffer_size, true); + if (wrbuf.empty()) + return 0; + + const size_t bytes_to_write = std::min(wrbuf.size(), buffer_size); + std::memcpy(wrbuf.data(), buffer, bytes_to_write); + ReleaseWriteBuffer(bytes_to_write); + return bytes_to_write; +} + +size_t BufferedStreamSocket::WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers) +{ + if (!m_connected || num_buffers == 0) + return 0; + + size_t total_size = 0; + for (size_t i = 0; i < num_buffers; i++) + total_size += buffer_lengths[i]; + + const std::span wrbuf = AcquireWriteBuffer(total_size, true); + if (wrbuf.empty()) + return 0; + + size_t written_bytes = 0; + for (size_t i = 0; i < num_buffers; i++) + { + const size_t bytes_to_write = std::min(wrbuf.size() - written_bytes, buffer_lengths[i]); + if (bytes_to_write == 0) + break; + + std::memcpy(&wrbuf[written_bytes], buffers[i], bytes_to_write); + written_bytes += buffer_lengths[i]; + } + + return written_bytes; +} + +void BufferedStreamSocket::OnReadEvent() +{ + std::unique_lock lock(m_lock); + if (!m_connected) + return; + + // Pull as many bytes as possible into the read buffer. + for (;;) + { + const size_t buffer_space = m_receive_buffer.size() - m_receive_buffer_offset - m_receive_buffer_size; + if (buffer_space == 0) [[unlikely]] + { + // If we're here again, it means OnRead() didn't consume the data, and we overflowed. + ERROR_LOG("Receive buffer overflow, dropping client {}.", GetRemoteAddress().ToString()); + CloseWithError(); + return; + } + + const ssize_t res = recv( + m_descriptor, reinterpret_cast(m_receive_buffer.data() + m_receive_buffer_offset + m_receive_buffer_size), + SIZE_CAST(buffer_space), 0); + if (res <= 0 && WSAGetLastError() != WSAEWOULDBLOCK) + { + CloseWithError(); + return; + } + + m_receive_buffer_size += static_cast(res); + OnRead(); + + // Are we at the end? + if ((m_receive_buffer_offset + m_receive_buffer_size) == m_receive_buffer.size()) + { + // Try to claw back some of the buffer, and try reading again. + if (m_receive_buffer_offset > 0) + { + std::memmove(m_receive_buffer.data(), m_receive_buffer.data() + m_receive_buffer_offset, m_receive_buffer_size); + m_receive_buffer_offset = 0; + continue; + } + } + + break; + } +} + +void BufferedStreamSocket::OnWriteEvent() +{ + std::unique_lock lock(m_lock); + if (!m_connected) + return; + + // Send as much as we can. + if (m_send_buffer_size > 0) + { + const ssize_t res = send(m_descriptor, reinterpret_cast(m_send_buffer.data() + m_send_buffer_offset), + SIZE_CAST(m_send_buffer_size), 0); + if (res < 0 && WSAGetLastError() != WSAEWOULDBLOCK) + { + CloseWithError(); + return; + } + + m_send_buffer_offset += static_cast(res); + m_send_buffer_size -= static_cast(res); + if (m_send_buffer_size == 0) + m_send_buffer_offset = 0; + } + + OnWrite(); + + if (m_send_buffer_size == 0) + { + // Are we done? Switch back to reads only. + m_multiplexer.SetNotificationMask(this, m_descriptor, POLLIN); + } +} + +void BufferedStreamSocket::OnWrite() +{ +} diff --git a/src/util/sockets.h b/src/util/sockets.h new file mode 100644 index 000000000..f9b4db71a --- /dev/null +++ b/src/util/sockets.h @@ -0,0 +1,268 @@ +// SPDX-FileCopyrightText: 2015-2024 Connor McLaughlin +// SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0) + +#pragma once + +#include "common/error.h" +#include "common/heap_array.h" +#include "common/small_string.h" +#include "common/threading.h" +#include "common/types.h" + +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +using SocketDescriptor = uintptr_t; +#else +using SocketDescriptor = int; +#endif + +struct pollfd; + +class BaseSocket; +class ListenSocket; +class StreamSocket; +class BufferedStreamSocket; +class SocketMultiplexer; + +struct SocketAddress final +{ + enum class Type + { + Unknown, + IPv4, + IPv6, + Unix, + }; + + // accessors + const void* GetData() const { return m_data; } + u32 GetLength() const { return m_length; } + + // parse interface + static std::optional Parse(Type type, const char* address, u32 port, Error* error); + + // resolve interface + static std::optional Resolve(const char* address, u32 port, Error* error); + + // to string interface + SmallString ToString() const; + + // initializers + void SetFromSockaddr(const void* sa, size_t length); + +private: + u8 m_data[128] = {}; + u32 m_length = 0; +}; + +class BaseSocket : public std::enable_shared_from_this +{ + friend SocketMultiplexer; + +public: + BaseSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor); + virtual ~BaseSocket(); + + ALWAYS_INLINE SocketDescriptor GetDescriptor() const { return m_descriptor; } + + virtual void Close() = 0; + +protected: + virtual void OnReadEvent() = 0; + virtual void OnWriteEvent() = 0; + + SocketMultiplexer& m_multiplexer; + SocketDescriptor m_descriptor; +}; + +class SocketMultiplexer final +{ + // TODO: Re-introduce worker threads. + +public: + typedef std::shared_ptr (*CreateStreamSocketCallback)(SocketMultiplexer& multiplexer, + SocketDescriptor descriptor); + friend BaseSocket; + friend ListenSocket; + friend StreamSocket; + friend BufferedStreamSocket; + +public: + ~SocketMultiplexer(); + + // Factory method. + static std::unique_ptr Create(Error* error); + + // Public interface + template + std::shared_ptr CreateListenSocket(const SocketAddress& address, Error* error); + template + std::shared_ptr ConnectStreamSocket(const SocketAddress& address, Error* error); + + // Returns true if any sockets are currently registered. + bool HasAnyOpenSockets(); + + // Close all sockets on this multiplexer. + void CloseAll(); + + // Poll for events. Returns false if there are no sockets registered. + bool PollEventsWithTimeout(u32 milliseconds); + +protected: + // Internal interface + std::shared_ptr InternalCreateListenSocket(const SocketAddress& address, + CreateStreamSocketCallback callback, Error* error); + std::shared_ptr InternalConnectStreamSocket(const SocketAddress& address, + CreateStreamSocketCallback callback, Error* error); + +private: + // Hide the constructor. + SocketMultiplexer(); + + // Tracking of open sockets. + void AddOpenSocket(std::shared_ptr socket); + void RemoveOpenSocket(BaseSocket* socket); + + // Register for notifications + void SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events); + +private: + // We store the fd in the struct to avoid the cache miss reading the object. + using SocketMap = std::unordered_map>; + + std::mutex m_poll_array_lock; + pollfd* m_poll_array = nullptr; + size_t m_poll_array_active_size = 0; + size_t m_poll_array_max_size = 0; + + std::mutex m_open_sockets_lock; + SocketMap m_open_sockets; +}; + +template +std::shared_ptr SocketMultiplexer::CreateListenSocket(const SocketAddress& address, Error* error) +{ + const CreateStreamSocketCallback callback = [](SocketMultiplexer& multiplexer, + SocketDescriptor descriptor) -> std::shared_ptr { + return std::static_pointer_cast(std::make_shared(multiplexer, descriptor)); + }; + return InternalCreateListenSocket(address, callback, error); +} + +template +std::shared_ptr SocketMultiplexer::ConnectStreamSocket(const SocketAddress& address, Error* error) +{ + const CreateStreamSocketCallback callback = [](SocketMultiplexer& multiplexer, + SocketDescriptor descriptor) -> std::shared_ptr { + return std::static_pointer_cast(std::make_shared(multiplexer, descriptor)); + }; + return std::static_pointer_cast(InternalConnectStreamSocket(address, callback, error)); +} + +class ListenSocket final : public BaseSocket +{ + friend SocketMultiplexer; + +public: + ListenSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor, + SocketMultiplexer::CreateStreamSocketCallback accept_callback); + virtual ~ListenSocket() override; + + const SocketAddress* GetLocalAddress() const { return &m_local_address; } + u32 GetConnectionsAccepted() const { return m_num_connections_accepted; } + + void Close() override final; + +protected: + void OnReadEvent() override final; + void OnWriteEvent() override final; + +private: + SocketMultiplexer::CreateStreamSocketCallback m_accept_callback; + SocketAddress m_local_address = {}; + u32 m_num_connections_accepted = 0; +}; + +class StreamSocket : public BaseSocket +{ +public: + StreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor); + virtual ~StreamSocket() override; + + static u32 GetSocketProtocolForAddress(const SocketAddress& sa); + + virtual void Close() override final; + + // Accessors + const SocketAddress& GetLocalAddress() const { return m_local_address; } + const SocketAddress& GetRemoteAddress() const { return m_remote_address; } + bool IsConnected() const { return m_connected; } + + // Read/write + size_t Read(void* buffer, size_t buffer_size); + size_t Write(const void* buffer, size_t buffer_size); + size_t WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers); + +protected: + virtual void OnConnected() = 0; + virtual void OnDisconnected(const Error& error) = 0; + virtual void OnRead() = 0; + + virtual void OnReadEvent() override; + virtual void OnWriteEvent() override; + + void CloseWithError(); + +private: + void InitialSetup(); + + SocketAddress m_local_address = {}; + SocketAddress m_remote_address = {}; + std::recursive_mutex m_lock; + bool m_connected = true; + + // Ugly, but needed in order to call the events. + friend SocketMultiplexer; + friend ListenSocket; + friend BufferedStreamSocket; +}; + +class BufferedStreamSocket : public StreamSocket +{ +public: + BufferedStreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor, size_t receive_buffer_size = 16384, + size_t send_buffer_size = 16384); + virtual ~BufferedStreamSocket() override; + + // Must hold the lock when not part of OnRead(). + std::unique_lock GetLock(); + std::span AcquireReadBuffer() const; + void ReleaseReadBuffer(size_t bytes_consumed); + std::span AcquireWriteBuffer(size_t wanted_bytes, bool allow_smaller = false); + void ReleaseWriteBuffer(size_t bytes_written, bool commit = true); + + // Hide StreamSocket read/write methods. + size_t Read(void* buffer, size_t buffer_size); + size_t Write(const void* buffer, size_t buffer_size); + size_t WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers); + +protected: + void OnReadEvent() override final; + void OnWriteEvent() override final; + virtual void OnWrite(); + +private: + std::vector m_receive_buffer; + size_t m_receive_buffer_offset = 0; + size_t m_receive_buffer_size = 0; + + std::vector m_send_buffer; + size_t m_send_buffer_offset = 0; + size_t m_send_buffer_size = 0; +}; diff --git a/src/util/util.vcxproj b/src/util/util.vcxproj index a5b344708..8e1b6a651 100644 --- a/src/util/util.vcxproj +++ b/src/util/util.vcxproj @@ -87,6 +87,7 @@ + @@ -204,6 +205,7 @@ + diff --git a/src/util/util.vcxproj.filters b/src/util/util.vcxproj.filters index 344d85c2c..d94e15a5a 100644 --- a/src/util/util.vcxproj.filters +++ b/src/util/util.vcxproj.filters @@ -72,6 +72,7 @@ + @@ -153,6 +154,7 @@ +