mirror of
https://github.com/RetroDECK/Duckstation.git
synced 2025-01-20 15:25:38 +00:00
Util: Add socket helper classes
This commit is contained in:
parent
9eb18449a7
commit
4e905a63ec
|
@ -2,6 +2,8 @@ add_library(common
|
||||||
align.h
|
align.h
|
||||||
assert.cpp
|
assert.cpp
|
||||||
assert.h
|
assert.h
|
||||||
|
binary_span_reader_writer.cpp
|
||||||
|
binary_span_reader_writer.h
|
||||||
bitfield.h
|
bitfield.h
|
||||||
bitutils.h
|
bitutils.h
|
||||||
build_timestamp.h
|
build_timestamp.h
|
||||||
|
|
110
src/common/binary_span_reader_writer.cpp
Normal file
110
src/common/binary_span_reader_writer.cpp
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
// SPDX-FileCopyrightText: 2024 Connor McLaughlin <stenzek@gmail.com>
|
||||||
|
// SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0)
|
||||||
|
|
||||||
|
#include "binary_span_reader_writer.h"
|
||||||
|
#include "small_string.h"
|
||||||
|
|
||||||
|
BinarySpanReader::BinarySpanReader() = default;
|
||||||
|
|
||||||
|
BinarySpanReader::BinarySpanReader(std::span<const u8> buf) : m_buf(buf)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
bool BinarySpanReader::PeekCString(std::string_view* dst)
|
||||||
|
{
|
||||||
|
size_t pos = m_pos;
|
||||||
|
size_t size = 0;
|
||||||
|
while (pos < m_buf.size())
|
||||||
|
{
|
||||||
|
if (m_buf[pos] == 0)
|
||||||
|
break;
|
||||||
|
|
||||||
|
pos++;
|
||||||
|
size++;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pos == m_buf.size())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
*dst = std::string_view(reinterpret_cast<const char*>(&m_buf[m_pos]), size);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool BinarySpanReader::ReadCString(std::string* dst)
|
||||||
|
{
|
||||||
|
std::string_view sv;
|
||||||
|
if (!PeekCString(&sv))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
dst->assign(sv);
|
||||||
|
m_pos += sv.size() + 1;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool BinarySpanReader::ReadCString(std::string_view* dst)
|
||||||
|
{
|
||||||
|
if (!PeekCString(dst))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
m_pos += dst->size() + 1;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool BinarySpanReader::ReadCString(SmallStringBase* dst)
|
||||||
|
{
|
||||||
|
std::string_view sv;
|
||||||
|
if (!PeekCString(&sv))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
dst->assign(sv);
|
||||||
|
m_pos += sv.size() + 1;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string_view BinarySpanReader::ReadCString()
|
||||||
|
{
|
||||||
|
std::string_view ret;
|
||||||
|
if (PeekCString(&ret))
|
||||||
|
m_pos += ret.size() + 1;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool BinarySpanReader::PeekCString(std::string* dst)
|
||||||
|
{
|
||||||
|
std::string_view sv;
|
||||||
|
if (!PeekCString(&sv))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
dst->assign(sv);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool BinarySpanReader::PeekCString(SmallStringBase* dst)
|
||||||
|
{
|
||||||
|
std::string_view sv;
|
||||||
|
if (!PeekCString(&sv))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
dst->assign(sv);
|
||||||
|
m_pos += sv.size() + 1;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
BinarySpanWriter::BinarySpanWriter() = default;
|
||||||
|
|
||||||
|
BinarySpanWriter::BinarySpanWriter(std::span<u8> buf) : m_buf(buf)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
bool BinarySpanWriter::WriteCString(std::string_view val)
|
||||||
|
{
|
||||||
|
if ((m_pos + val.size() + 1) > m_buf.size()) [[unlikely]]
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (!val.empty())
|
||||||
|
std::memcpy(&m_buf[m_pos], val.data(), val.size());
|
||||||
|
|
||||||
|
m_buf[m_pos + val.size()] = 0;
|
||||||
|
m_pos += val.size() + 1;
|
||||||
|
return true;
|
||||||
|
}
|
135
src/common/binary_span_reader_writer.h
Normal file
135
src/common/binary_span_reader_writer.h
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
// SPDX-FileCopyrightText: 2024 Connor McLaughlin <stenzek@gmail.com>
|
||||||
|
// SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0)
|
||||||
|
|
||||||
|
#include "types.h"
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
#include <span>
|
||||||
|
#include <string>
|
||||||
|
#include <string_view>
|
||||||
|
|
||||||
|
class SmallStringBase;
|
||||||
|
|
||||||
|
class BinarySpanReader
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
BinarySpanReader();
|
||||||
|
BinarySpanReader(std::span<const u8> buf);
|
||||||
|
|
||||||
|
ALWAYS_INLINE const std::span<const u8>& GetSpan() const { return m_buf; }
|
||||||
|
ALWAYS_INLINE bool IsValid() const { return !m_buf.empty(); }
|
||||||
|
ALWAYS_INLINE bool CheckRemaining(size_t size) { return ((m_pos + size) <= m_buf.size()); }
|
||||||
|
ALWAYS_INLINE size_t GetBufferRemaining() const { return (m_buf.size() - m_pos); }
|
||||||
|
ALWAYS_INLINE size_t GetBufferConsumed() const { return m_pos; }
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
template<typename T> ALWAYS_INLINE bool ReadT(T* dst) { return Read(dst, sizeof(T)); }
|
||||||
|
ALWAYS_INLINE bool ReadU8(u8* dst) { return ReadT(dst); }
|
||||||
|
ALWAYS_INLINE bool ReadU16(u16* dst) { return ReadT(dst); }
|
||||||
|
ALWAYS_INLINE bool ReadU32(u32* dst) { return ReadT(dst); }
|
||||||
|
ALWAYS_INLINE bool ReadU64(u64* dst) { return ReadT(dst); }
|
||||||
|
ALWAYS_INLINE bool ReadFloat(float* dst) { return ReadT(dst); }
|
||||||
|
bool ReadCString(std::string* dst);
|
||||||
|
bool ReadCString(std::string_view* dst);
|
||||||
|
bool ReadCString(SmallStringBase* dst);
|
||||||
|
|
||||||
|
template<typename T> ALWAYS_INLINE T ReadT() { T ret; if (!Read(&ret, sizeof(ret))) [[unlikely]] { ret = {}; } return ret; }
|
||||||
|
ALWAYS_INLINE u8 ReadU8() { return ReadT<u8>(); }
|
||||||
|
ALWAYS_INLINE u16 ReadU16() { return ReadT<u16>(); }
|
||||||
|
ALWAYS_INLINE u32 ReadU32() { return ReadT<u32>(); }
|
||||||
|
ALWAYS_INLINE u64 ReadU64() { return ReadT<u64>(); }
|
||||||
|
ALWAYS_INLINE float ReadFloat() { return ReadT<float>(); }
|
||||||
|
std::string_view ReadCString();
|
||||||
|
|
||||||
|
template<typename T> ALWAYS_INLINE bool PeekT(T* dst) { return Peek(dst, sizeof(T)); }
|
||||||
|
ALWAYS_INLINE bool PeekU8(u8* dst) { return PeekT(dst); }
|
||||||
|
ALWAYS_INLINE bool PeekU16(u16* dst) { return PeekT(dst); }
|
||||||
|
ALWAYS_INLINE bool PeekU32(u32* dst) { return PeekT(dst); }
|
||||||
|
ALWAYS_INLINE bool PeekU64(u64* dst) { return PeekT(dst); }
|
||||||
|
ALWAYS_INLINE bool PeekFloat(float* dst) { return PeekT(dst); }
|
||||||
|
bool PeekCString(std::string* dst);
|
||||||
|
bool PeekCString(std::string_view* dst);
|
||||||
|
bool PeekCString(SmallStringBase* dst);
|
||||||
|
|
||||||
|
ALWAYS_INLINE BinarySpanReader& operator>>(u8& val) { val = ReadT<u8>(); return *this; }
|
||||||
|
ALWAYS_INLINE BinarySpanReader& operator>>(u16& val) { val = ReadT<u16>(); return *this; }
|
||||||
|
ALWAYS_INLINE BinarySpanReader& operator>>(u32& val) { val = ReadT<u32>(); return *this; }
|
||||||
|
ALWAYS_INLINE BinarySpanReader& operator>>(u64& val) { val = ReadT<u64>(); return *this; }
|
||||||
|
ALWAYS_INLINE BinarySpanReader& operator>>(float& val) { val = ReadT<float>(); return *this; }
|
||||||
|
ALWAYS_INLINE BinarySpanReader& operator>>(std::string_view val) { val = ReadCString(); return *this; }
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
private:
|
||||||
|
ALWAYS_INLINE bool Read(void* buf, size_t size)
|
||||||
|
{
|
||||||
|
if ((m_pos + size) < m_buf.size()) [[likely]]
|
||||||
|
{
|
||||||
|
std::memcpy(buf, &m_buf[m_pos], size);
|
||||||
|
m_pos += size;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
ALWAYS_INLINE bool Peek(void* buf, size_t size)
|
||||||
|
{
|
||||||
|
if ((m_pos + size) < m_buf.size()) [[likely]]
|
||||||
|
{
|
||||||
|
std::memcpy(buf, &m_buf[m_pos], size);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::span<const u8> m_buf;
|
||||||
|
size_t m_pos = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class BinarySpanWriter
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
BinarySpanWriter();
|
||||||
|
BinarySpanWriter(std::span<u8> buf);
|
||||||
|
|
||||||
|
ALWAYS_INLINE const std::span<u8>& GetSpan() const { return m_buf; }
|
||||||
|
ALWAYS_INLINE bool IsValid() const { return !m_buf.empty(); }
|
||||||
|
ALWAYS_INLINE size_t GetBufferRemaining() const { return (m_buf.size() - m_pos); }
|
||||||
|
ALWAYS_INLINE size_t GetBufferWritten() const { return m_pos; }
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
template<typename T> ALWAYS_INLINE bool WriteT(T dst) { return Write(&dst, sizeof(T)); }
|
||||||
|
ALWAYS_INLINE bool WriteU8(u8 val) { return WriteT(val); }
|
||||||
|
ALWAYS_INLINE bool WriteU16(u16 val) { return WriteT(val); }
|
||||||
|
ALWAYS_INLINE bool WriteU32(u32 val) { return WriteT(val); }
|
||||||
|
ALWAYS_INLINE bool WriteU64(u64 val) { return WriteT(val); }
|
||||||
|
ALWAYS_INLINE bool WriteFloat(float val) { return WriteT(val); }
|
||||||
|
bool WriteCString(std::string_view val);
|
||||||
|
|
||||||
|
ALWAYS_INLINE BinarySpanWriter& operator<<(u8 val) { WriteU8(val); return *this; }
|
||||||
|
ALWAYS_INLINE BinarySpanWriter& operator<<(u16 val) { WriteU16(val); return *this; }
|
||||||
|
ALWAYS_INLINE BinarySpanWriter& operator<<(u32 val) { WriteU32(val); return *this; }
|
||||||
|
ALWAYS_INLINE BinarySpanWriter& operator<<(u64 val) { WriteU64(val); return *this; }
|
||||||
|
ALWAYS_INLINE BinarySpanWriter& operator<<(float val) { WriteFloat(val); return *this; }
|
||||||
|
ALWAYS_INLINE BinarySpanWriter& operator<<(std::string_view val) { WriteCString(val); return *this; }
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
private:
|
||||||
|
ALWAYS_INLINE bool Write(void* buf, size_t size)
|
||||||
|
{
|
||||||
|
if ((m_pos + size) < m_buf.size()) [[likely]]
|
||||||
|
{
|
||||||
|
std::memcpy(&m_buf[m_pos], buf, size);
|
||||||
|
m_pos += size;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::span<u8> m_buf;
|
||||||
|
size_t m_pos = 0;
|
||||||
|
};
|
|
@ -34,6 +34,7 @@
|
||||||
<ClInclude Include="sha1_digest.h" />
|
<ClInclude Include="sha1_digest.h" />
|
||||||
<ClInclude Include="small_string.h" />
|
<ClInclude Include="small_string.h" />
|
||||||
<ClInclude Include="heterogeneous_containers.h" />
|
<ClInclude Include="heterogeneous_containers.h" />
|
||||||
|
<ClInclude Include="binary_span_reader_writer.h" />
|
||||||
<ClInclude Include="string_util.h" />
|
<ClInclude Include="string_util.h" />
|
||||||
<ClInclude Include="thirdparty\SmallVector.h" />
|
<ClInclude Include="thirdparty\SmallVector.h" />
|
||||||
<ClInclude Include="thirdparty\StackWalker.h" />
|
<ClInclude Include="thirdparty\StackWalker.h" />
|
||||||
|
@ -60,6 +61,7 @@
|
||||||
<ClCompile Include="progress_callback.cpp" />
|
<ClCompile Include="progress_callback.cpp" />
|
||||||
<ClCompile Include="sha1_digest.cpp" />
|
<ClCompile Include="sha1_digest.cpp" />
|
||||||
<ClCompile Include="small_string.cpp" />
|
<ClCompile Include="small_string.cpp" />
|
||||||
|
<ClCompile Include="binary_span_reader_writer.cpp" />
|
||||||
<ClCompile Include="string_util.cpp" />
|
<ClCompile Include="string_util.cpp" />
|
||||||
<ClCompile Include="thirdparty\SmallVector.cpp" />
|
<ClCompile Include="thirdparty\SmallVector.cpp" />
|
||||||
<ClCompile Include="thirdparty\StackWalker.cpp" />
|
<ClCompile Include="thirdparty\StackWalker.cpp" />
|
||||||
|
|
|
@ -45,6 +45,7 @@
|
||||||
<Filter>thirdparty</Filter>
|
<Filter>thirdparty</Filter>
|
||||||
</ClInclude>
|
</ClInclude>
|
||||||
<ClInclude Include="dynamic_library.h" />
|
<ClInclude Include="dynamic_library.h" />
|
||||||
|
<ClInclude Include="binary_span_reader_writer.h" />
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
<ClCompile Include="small_string.cpp" />
|
<ClCompile Include="small_string.cpp" />
|
||||||
|
@ -72,6 +73,7 @@
|
||||||
<Filter>thirdparty</Filter>
|
<Filter>thirdparty</Filter>
|
||||||
</ClCompile>
|
</ClCompile>
|
||||||
<ClCompile Include="dynamic_library.cpp" />
|
<ClCompile Include="dynamic_library.cpp" />
|
||||||
|
<ClCompile Include="binary_span_reader_writer.cpp" />
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
<Natvis Include="bitfield.natvis" />
|
<Natvis Include="bitfield.natvis" />
|
||||||
|
|
|
@ -63,6 +63,8 @@ add_library(util
|
||||||
shadergen.h
|
shadergen.h
|
||||||
shiftjis.cpp
|
shiftjis.cpp
|
||||||
shiftjis.h
|
shiftjis.h
|
||||||
|
sockets.cpp
|
||||||
|
sockets.h
|
||||||
state_wrapper.cpp
|
state_wrapper.cpp
|
||||||
state_wrapper.h
|
state_wrapper.h
|
||||||
wav_writer.cpp
|
wav_writer.cpp
|
||||||
|
|
|
@ -5,7 +5,10 @@
|
||||||
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
|
||||||
|
class Error;
|
||||||
|
|
||||||
namespace PlatformMisc {
|
namespace PlatformMisc {
|
||||||
|
bool InitializeSocketSupport(Error* error);
|
||||||
void SuspendScreensaver();
|
void SuspendScreensaver();
|
||||||
void ResumeScreensaver();
|
void ResumeScreensaver();
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,11 @@ Log_SetChannel(PlatformMisc);
|
||||||
|
|
||||||
static IOPMAssertionID s_prevent_idle_assertion = kIOPMNullAssertionID;
|
static IOPMAssertionID s_prevent_idle_assertion = kIOPMNullAssertionID;
|
||||||
|
|
||||||
|
bool PlatformMisc::InitializeSocketSupport(Error* error)
|
||||||
|
{
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
static bool SetScreensaverInhibitMacOS(bool inhibit)
|
static bool SetScreensaverInhibitMacOS(bool inhibit)
|
||||||
{
|
{
|
||||||
if (inhibit)
|
if (inhibit)
|
||||||
|
|
|
@ -16,6 +16,11 @@
|
||||||
|
|
||||||
Log_SetChannel(PlatformMisc);
|
Log_SetChannel(PlatformMisc);
|
||||||
|
|
||||||
|
bool PlatformMisc::InitializeSocketSupport(Error* error)
|
||||||
|
{
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
static bool SetScreensaverInhibitDBus(const bool inhibit_requested, const char* program_name, const char* reason)
|
static bool SetScreensaverInhibitDBus(const bool inhibit_requested, const char* program_name, const char* reason)
|
||||||
{
|
{
|
||||||
static dbus_uint32_t s_cookie;
|
static dbus_uint32_t s_cookie;
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
#include "platform_misc.h"
|
#include "platform_misc.h"
|
||||||
|
|
||||||
|
#include "common/error.h"
|
||||||
#include "common/file_system.h"
|
#include "common/file_system.h"
|
||||||
#include "common/log.h"
|
#include "common/log.h"
|
||||||
#include "common/small_string.h"
|
#include "common/small_string.h"
|
||||||
|
@ -13,10 +14,33 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "common/windows_headers.h"
|
#include "common/windows_headers.h"
|
||||||
|
#include <WinSock2.h>
|
||||||
#include <mmsystem.h>
|
#include <mmsystem.h>
|
||||||
|
|
||||||
Log_SetChannel(PlatformMisc);
|
Log_SetChannel(PlatformMisc);
|
||||||
|
|
||||||
|
static bool s_screensaver_suspended = false;
|
||||||
|
static bool s_winsock_initialized = false;
|
||||||
|
static std::once_flag s_winsock_initializer;
|
||||||
|
|
||||||
|
bool PlatformMisc::InitializeSocketSupport(Error* error)
|
||||||
|
{
|
||||||
|
std::call_once(s_winsock_initializer, [](Error* error) {
|
||||||
|
WSADATA wsa = {};
|
||||||
|
if (WSAStartup(MAKEWORD(2, 2), &wsa) != 0)
|
||||||
|
{
|
||||||
|
Error::SetSocket(error, "WSAStartup() failed: ", WSAGetLastError());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
s_winsock_initialized = true;
|
||||||
|
std::atexit([]() { WSACleanup(); });
|
||||||
|
return true;
|
||||||
|
}, error);
|
||||||
|
|
||||||
|
return s_winsock_initialized;
|
||||||
|
}
|
||||||
|
|
||||||
static bool SetScreensaverInhibitWin32(bool inhibit)
|
static bool SetScreensaverInhibitWin32(bool inhibit)
|
||||||
{
|
{
|
||||||
if (SetThreadExecutionState(ES_CONTINUOUS | (inhibit ? (ES_DISPLAY_REQUIRED | ES_SYSTEM_REQUIRED) : 0)) == NULL)
|
if (SetThreadExecutionState(ES_CONTINUOUS | (inhibit ? (ES_DISPLAY_REQUIRED | ES_SYSTEM_REQUIRED) : 0)) == NULL)
|
||||||
|
@ -28,8 +52,6 @@ static bool SetScreensaverInhibitWin32(bool inhibit)
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool s_screensaver_suspended;
|
|
||||||
|
|
||||||
void PlatformMisc::SuspendScreensaver()
|
void PlatformMisc::SuspendScreensaver()
|
||||||
{
|
{
|
||||||
if (s_screensaver_suspended)
|
if (s_screensaver_suspended)
|
||||||
|
|
940
src/util/sockets.cpp
Normal file
940
src/util/sockets.cpp
Normal file
|
@ -0,0 +1,940 @@
|
||||||
|
// SPDX-FileCopyrightText: 2015-2024 Connor McLaughlin <stenzek@gmail.com>
|
||||||
|
// SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0)
|
||||||
|
|
||||||
|
#include "sockets.h"
|
||||||
|
#include "platform_misc.h"
|
||||||
|
|
||||||
|
#include "common/assert.h"
|
||||||
|
#include "common/log.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cstring>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#ifndef __APPLE__
|
||||||
|
#include <malloc.h> // alloca
|
||||||
|
#else
|
||||||
|
#include <alloca.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
|
||||||
|
#include "common/windows_headers.h"
|
||||||
|
|
||||||
|
#include <WS2tcpip.h>
|
||||||
|
#include <WinSock2.h>
|
||||||
|
|
||||||
|
#define SIZE_CAST(x) static_cast<int>(x)
|
||||||
|
using ssize_t = int;
|
||||||
|
using nfds_t = ULONG;
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
#include <arpa/inet.h>
|
||||||
|
#include <errno.h>
|
||||||
|
#include <netinet/in.h>
|
||||||
|
#include <poll.h>
|
||||||
|
#include <sys/ioctl.h>
|
||||||
|
#include <sys/socket.h>
|
||||||
|
#include <sys/types.h>
|
||||||
|
#include <sys/uio.h>
|
||||||
|
#include <sys/un.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
|
||||||
|
#define ioctlsocket ioctl
|
||||||
|
#define closesocket close
|
||||||
|
#define WSAEWOULDBLOCK EAGAIN
|
||||||
|
#define WSAGetLastError() errno
|
||||||
|
#define WSAPoll poll
|
||||||
|
#define SIZE_CAST(x) x
|
||||||
|
|
||||||
|
#define SOCKET_ERROR -1
|
||||||
|
#define INVALID_SOCKET -1
|
||||||
|
#define SD_BOTH SHUT_RDWR
|
||||||
|
#endif
|
||||||
|
|
||||||
|
Log_SetChannel(Sockets);
|
||||||
|
|
||||||
|
static bool SetNonBlocking(SocketDescriptor sd, Error* error)
|
||||||
|
{
|
||||||
|
// switch to nonblocking mode
|
||||||
|
unsigned long value = 1;
|
||||||
|
if (ioctlsocket(sd, FIONBIO, &value) < 0)
|
||||||
|
{
|
||||||
|
Error::SetSocket(error, "ioctlsocket() failed: ", WSAGetLastError());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SocketAddress::SetFromSockaddr(const void* sa, size_t length)
|
||||||
|
{
|
||||||
|
m_length = std::min(static_cast<u32>(length), static_cast<u32>(sizeof(m_data)));
|
||||||
|
std::memcpy(m_data, sa, m_length);
|
||||||
|
if (m_length < sizeof(m_data))
|
||||||
|
std::memset(m_data + m_length, 0, sizeof(m_data) - m_length);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<SocketAddress> SocketAddress::Parse(Type type, const char* address, u32 port, Error* error)
|
||||||
|
{
|
||||||
|
std::optional<SocketAddress> ret = SocketAddress();
|
||||||
|
|
||||||
|
switch (type)
|
||||||
|
{
|
||||||
|
case Type::IPv4:
|
||||||
|
{
|
||||||
|
sockaddr_in* sain = reinterpret_cast<sockaddr_in*>(ret->m_data);
|
||||||
|
std::memset(sain, 0, sizeof(sockaddr_in));
|
||||||
|
sain->sin_family = AF_INET;
|
||||||
|
sain->sin_port = htons(static_cast<u16>(port));
|
||||||
|
int res = inet_pton(AF_INET, address, &sain->sin_addr);
|
||||||
|
if (res == 1)
|
||||||
|
{
|
||||||
|
ret->m_length = sizeof(sockaddr_in);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
Error::SetSocket(error, "inet_pton() failed: ", WSAGetLastError());
|
||||||
|
ret.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case Type::IPv6:
|
||||||
|
{
|
||||||
|
sockaddr_in6* sain6 = reinterpret_cast<sockaddr_in6*>(ret->m_data);
|
||||||
|
std::memset(sain6, 0, sizeof(sockaddr_in6));
|
||||||
|
sain6->sin6_family = AF_INET;
|
||||||
|
sain6->sin6_port = htons(static_cast<u16>(port));
|
||||||
|
int res = inet_pton(AF_INET6, address, &sain6->sin6_addr);
|
||||||
|
if (res == 1)
|
||||||
|
{
|
||||||
|
ret->m_length = sizeof(sockaddr_in6);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
Error::SetSocket(error, "inet_pton() failed: ", WSAGetLastError());
|
||||||
|
ret.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
#ifndef _WIN32
|
||||||
|
case Type::Unix:
|
||||||
|
{
|
||||||
|
sockaddr_un* sun = reinterpret_cast<sockaddr_un*>(ret->m_data);
|
||||||
|
std::memset(sun, 0, sizeof(sockaddr_un));
|
||||||
|
sun->sun_family = AF_UNIX;
|
||||||
|
|
||||||
|
const size_t len = std::strlen(address);
|
||||||
|
if ((len + 1) <= std::size(sun->sun_path))
|
||||||
|
{
|
||||||
|
std::memcpy(sun->sun_path, address, len);
|
||||||
|
ret->m_length = sizeof(sockaddr_un);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
Error::SetStringFmt(error, "Path length {} exceeds {} bytes.", len, std::size(sun->sun_path));
|
||||||
|
ret.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
default:
|
||||||
|
Error::SetStringView(error, "Unknown address type.");
|
||||||
|
ret.reset();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallString SocketAddress::ToString() const
|
||||||
|
{
|
||||||
|
SmallString ret;
|
||||||
|
|
||||||
|
const sockaddr* sa = reinterpret_cast<const sockaddr*>(m_data);
|
||||||
|
switch (sa->sa_family)
|
||||||
|
{
|
||||||
|
case AF_INET:
|
||||||
|
{
|
||||||
|
ret.clear();
|
||||||
|
ret.reserve(128);
|
||||||
|
const char* res =
|
||||||
|
inet_ntop(AF_INET, &reinterpret_cast<const sockaddr_in*>(m_data)->sin_addr, ret.data(), ret.buffer_size());
|
||||||
|
if (res == nullptr)
|
||||||
|
ret.assign("<unknown>");
|
||||||
|
else
|
||||||
|
ret.update_size();
|
||||||
|
|
||||||
|
ret.append_format(":{}", static_cast<u32>(ntohs(reinterpret_cast<const sockaddr_in*>(m_data)->sin_port)));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case AF_INET6:
|
||||||
|
{
|
||||||
|
ret.clear();
|
||||||
|
ret.reserve(128);
|
||||||
|
ret.append('[');
|
||||||
|
const char* res = inet_ntop(AF_INET6, &reinterpret_cast<const sockaddr_in6*>(m_data)->sin6_addr, ret.data() + 1,
|
||||||
|
ret.buffer_size() - 1);
|
||||||
|
if (res == nullptr)
|
||||||
|
ret.assign("<unknown>");
|
||||||
|
else
|
||||||
|
ret.update_size();
|
||||||
|
|
||||||
|
ret.append_format("]:{}", static_cast<u32>(ntohs(reinterpret_cast<const sockaddr_in6*>(m_data)->sin6_port)));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
#ifndef _WIN32
|
||||||
|
case AF_UNIX:
|
||||||
|
{
|
||||||
|
ret.assign(reinterpret_cast<const sockaddr_un*>(m_data)->sun_path);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
ret.assign("<unknown>");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
BaseSocket::BaseSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor)
|
||||||
|
: m_multiplexer(multiplexer), m_descriptor(descriptor)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
BaseSocket::~BaseSocket() = default;
|
||||||
|
|
||||||
|
SocketMultiplexer::SocketMultiplexer() = default;
|
||||||
|
|
||||||
|
SocketMultiplexer::~SocketMultiplexer()
|
||||||
|
{
|
||||||
|
CloseAll();
|
||||||
|
|
||||||
|
if (m_poll_array)
|
||||||
|
std::free(m_poll_array);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<SocketMultiplexer> SocketMultiplexer::Create(Error* error)
|
||||||
|
{
|
||||||
|
if (!PlatformMisc::InitializeSocketSupport(error))
|
||||||
|
return {};
|
||||||
|
|
||||||
|
return std::unique_ptr<SocketMultiplexer>(new SocketMultiplexer());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ListenSocket> SocketMultiplexer::InternalCreateListenSocket(const SocketAddress& address,
|
||||||
|
CreateStreamSocketCallback callback,
|
||||||
|
Error* error)
|
||||||
|
{
|
||||||
|
// create and bind socket
|
||||||
|
const sockaddr* sa = reinterpret_cast<const sockaddr*>(address.GetData());
|
||||||
|
SocketDescriptor descriptor = socket(sa->sa_family, SOCK_STREAM, StreamSocket::GetSocketProtocolForAddress(address));
|
||||||
|
if (descriptor == INVALID_SOCKET)
|
||||||
|
{
|
||||||
|
Error::SetSocket(error, "socket() failed: ", WSAGetLastError());
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
const int reuseaddr_enable = 1;
|
||||||
|
if (setsockopt(descriptor, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char*>(&reuseaddr_enable),
|
||||||
|
sizeof(reuseaddr_enable)) < 0)
|
||||||
|
{
|
||||||
|
WARNING_LOG("Failed to set SO_REUSEADDR: {}", Error::CreateSocket(WSAGetLastError()).GetDescription());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bind(descriptor, sa, address.GetLength()) < 0)
|
||||||
|
{
|
||||||
|
Error::SetSocket(error, "bind() failed: ", WSAGetLastError());
|
||||||
|
closesocket(descriptor);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (listen(descriptor, 5) < 0)
|
||||||
|
{
|
||||||
|
Error::SetSocket(error, "listen() failed: ", WSAGetLastError());
|
||||||
|
closesocket(descriptor);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!SetNonBlocking(descriptor, error))
|
||||||
|
{
|
||||||
|
closesocket(descriptor);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
// create listensocket
|
||||||
|
std::shared_ptr<ListenSocket> ret = std::make_shared<ListenSocket>(*this, descriptor, callback);
|
||||||
|
|
||||||
|
// add to list, register for reads
|
||||||
|
AddOpenSocket(std::static_pointer_cast<BaseSocket>(ret));
|
||||||
|
SetNotificationMask(ret.get(), descriptor, POLLIN);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<StreamSocket> SocketMultiplexer::InternalConnectStreamSocket(const SocketAddress& address,
|
||||||
|
CreateStreamSocketCallback callback,
|
||||||
|
Error* error)
|
||||||
|
{
|
||||||
|
// create and bind socket
|
||||||
|
const sockaddr* sa = reinterpret_cast<const sockaddr*>(address.GetData());
|
||||||
|
SocketDescriptor descriptor = socket(sa->sa_family, SOCK_STREAM, StreamSocket::GetSocketProtocolForAddress(address));
|
||||||
|
if (descriptor == INVALID_SOCKET)
|
||||||
|
{
|
||||||
|
Error::SetSocket(error, "socket() failed: ", WSAGetLastError());
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (connect(descriptor, sa, address.GetLength()) < 0)
|
||||||
|
{
|
||||||
|
Error::SetSocket(error, "connect() failed: ", WSAGetLastError());
|
||||||
|
closesocket(descriptor);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!SetNonBlocking(descriptor, error))
|
||||||
|
{
|
||||||
|
closesocket(descriptor);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
// create stream socket
|
||||||
|
std::shared_ptr<StreamSocket> csocket = callback(*this, descriptor);
|
||||||
|
csocket->InitialSetup();
|
||||||
|
if (!csocket->IsConnected())
|
||||||
|
csocket.reset();
|
||||||
|
|
||||||
|
return csocket;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SocketMultiplexer::AddOpenSocket(std::shared_ptr<BaseSocket> socket)
|
||||||
|
{
|
||||||
|
std::unique_lock lock(m_open_sockets_lock);
|
||||||
|
|
||||||
|
DebugAssert(m_open_sockets.find(socket->GetDescriptor()) == m_open_sockets.end());
|
||||||
|
m_open_sockets.emplace(socket->GetDescriptor(), socket);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SocketMultiplexer::RemoveOpenSocket(BaseSocket* socket)
|
||||||
|
{
|
||||||
|
#ifdef _DEBUG
|
||||||
|
{
|
||||||
|
std::unique_lock lock(m_poll_array_lock);
|
||||||
|
for (size_t i = 0; i < m_poll_array_active_size; i++)
|
||||||
|
{
|
||||||
|
pollfd& pfd = m_poll_array[i];
|
||||||
|
DebugAssert(pfd.fd != socket->GetDescriptor());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
std::unique_lock lock(m_open_sockets_lock);
|
||||||
|
const auto iter = m_open_sockets.find(socket->GetDescriptor());
|
||||||
|
Assert(iter != m_open_sockets.end());
|
||||||
|
m_open_sockets.erase(iter);
|
||||||
|
|
||||||
|
// Update size.
|
||||||
|
size_t new_active_size = 0;
|
||||||
|
for (size_t i = 0; i < m_poll_array_active_size; i++)
|
||||||
|
new_active_size = (m_poll_array[i].fd != INVALID_SOCKET) ? (i + 1) : new_active_size;
|
||||||
|
m_poll_array_active_size = new_active_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SocketMultiplexer::HasAnyOpenSockets()
|
||||||
|
{
|
||||||
|
std::unique_lock lock(m_open_sockets_lock);
|
||||||
|
return !m_open_sockets.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
void SocketMultiplexer::CloseAll()
|
||||||
|
{
|
||||||
|
std::unique_lock lock(m_open_sockets_lock);
|
||||||
|
|
||||||
|
while (!m_open_sockets.empty())
|
||||||
|
{
|
||||||
|
std::shared_ptr<BaseSocket> socket = m_open_sockets.begin()->second;
|
||||||
|
lock.unlock();
|
||||||
|
socket->Close();
|
||||||
|
lock.lock();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void SocketMultiplexer::SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events)
|
||||||
|
{
|
||||||
|
std::unique_lock lock(m_poll_array_lock);
|
||||||
|
size_t free_slot = m_poll_array_active_size;
|
||||||
|
for (size_t i = 0; i < m_poll_array_active_size; i++)
|
||||||
|
{
|
||||||
|
pollfd& pfd = m_poll_array[i];
|
||||||
|
if (pfd.fd != descriptor)
|
||||||
|
{
|
||||||
|
free_slot = (pfd.fd < 0 && free_slot != m_poll_array_active_size) ? i : free_slot;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// unbinding?
|
||||||
|
if (events != 0)
|
||||||
|
pfd.events = static_cast<short>(events);
|
||||||
|
else
|
||||||
|
pfd.fd = INVALID_SOCKET;
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// don't create entries for null masks
|
||||||
|
if (events == 0)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// need to grow the array?
|
||||||
|
if (free_slot == m_poll_array_max_size)
|
||||||
|
{
|
||||||
|
const size_t new_size = std::max(free_slot + 1, free_slot * 2);
|
||||||
|
pollfd* new_array = static_cast<pollfd*>(std::realloc(m_poll_array, sizeof(pollfd) * new_size));
|
||||||
|
if (!new_array)
|
||||||
|
Panic("Memory allocation failed.");
|
||||||
|
|
||||||
|
for (size_t i = m_poll_array_max_size; i < new_size; i++)
|
||||||
|
new_array[i] = {.fd = INVALID_SOCKET, .events = 0, .revents = 0};
|
||||||
|
m_poll_array = new_array;
|
||||||
|
m_poll_array_max_size = new_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
m_poll_array[free_slot] = {.fd = descriptor, .events = static_cast<short>(events), .revents = 0};
|
||||||
|
m_poll_array_active_size = free_slot + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SocketMultiplexer::PollEventsWithTimeout(u32 milliseconds)
|
||||||
|
{
|
||||||
|
std::unique_lock lock(m_poll_array_lock);
|
||||||
|
if (m_poll_array_active_size == 0)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
const int res = WSAPoll(m_poll_array, static_cast<nfds_t>(m_poll_array_active_size), milliseconds);
|
||||||
|
if (res <= 0)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// find sockets that triggered, we use an array here so we can avoid holding the lock, and if a socket disconnects
|
||||||
|
using PendingSocketPair = std::pair<std::shared_ptr<BaseSocket>, u32>;
|
||||||
|
PendingSocketPair* triggered_sockets = reinterpret_cast<PendingSocketPair*>(alloca(sizeof(PendingSocketPair) * res));
|
||||||
|
size_t num_triggered_sockets = 0;
|
||||||
|
{
|
||||||
|
std::unique_lock open_lock(m_open_sockets_lock);
|
||||||
|
for (size_t i = 0; i < m_poll_array_active_size; i++)
|
||||||
|
{
|
||||||
|
const pollfd& pfd = m_poll_array[i];
|
||||||
|
if (pfd.revents == 0)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
const auto iter = m_open_sockets.find(pfd.fd);
|
||||||
|
if (iter == m_open_sockets.end()) [[unlikely]]
|
||||||
|
{
|
||||||
|
ERROR_LOG("Attempting to look up known socket {}, this should never happen.", pfd.fd);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// we add a reference here in case the read kills it with a write pending, or something like that
|
||||||
|
new (&triggered_sockets[num_triggered_sockets++])
|
||||||
|
PendingSocketPair(iter->second->shared_from_this(), pfd.revents);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// release lock so connections etc can acquire it
|
||||||
|
lock.unlock();
|
||||||
|
|
||||||
|
// fire events
|
||||||
|
for (u32 i = 0; i < num_triggered_sockets; i++)
|
||||||
|
{
|
||||||
|
PendingSocketPair& psp = triggered_sockets[i];
|
||||||
|
|
||||||
|
// fire events
|
||||||
|
if (psp.second & (POLLIN | POLLHUP | POLLERR))
|
||||||
|
psp.first->OnReadEvent();
|
||||||
|
if (psp.second & POLLOUT)
|
||||||
|
psp.first->OnWriteEvent();
|
||||||
|
|
||||||
|
psp.first.~shared_ptr();
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
ListenSocket::ListenSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor,
|
||||||
|
SocketMultiplexer::CreateStreamSocketCallback accept_callback)
|
||||||
|
: BaseSocket(multiplexer, descriptor), m_accept_callback(accept_callback)
|
||||||
|
{
|
||||||
|
// get local address
|
||||||
|
sockaddr_storage sa;
|
||||||
|
socklen_t salen = sizeof(sa);
|
||||||
|
if (getsockname(m_descriptor, reinterpret_cast<sockaddr*>(&sa), &salen) == 0)
|
||||||
|
m_local_address.SetFromSockaddr(&sa, salen);
|
||||||
|
}
|
||||||
|
|
||||||
|
ListenSocket::~ListenSocket()
|
||||||
|
{
|
||||||
|
DebugAssert(m_descriptor == INVALID_SOCKET);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ListenSocket::Close()
|
||||||
|
{
|
||||||
|
if (m_descriptor < 0)
|
||||||
|
return;
|
||||||
|
|
||||||
|
m_multiplexer.SetNotificationMask(this, m_descriptor, 0);
|
||||||
|
m_multiplexer.RemoveOpenSocket(this);
|
||||||
|
closesocket(m_descriptor);
|
||||||
|
m_descriptor = INVALID_SOCKET;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ListenSocket::OnReadEvent()
|
||||||
|
{
|
||||||
|
// connection incoming
|
||||||
|
sockaddr_storage sa;
|
||||||
|
socklen_t salen = sizeof(sa);
|
||||||
|
SocketDescriptor new_descriptor = accept(m_descriptor, reinterpret_cast<sockaddr*>(&sa), &salen);
|
||||||
|
if (new_descriptor == INVALID_SOCKET)
|
||||||
|
{
|
||||||
|
ERROR_LOG("accept() returned {}", Error::CreateSocket(WSAGetLastError()).GetDescription());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Error error;
|
||||||
|
if (!SetNonBlocking(new_descriptor, &error))
|
||||||
|
{
|
||||||
|
ERROR_LOG("Failed to set just-connected socket to nonblocking: {}", error.GetDescription());
|
||||||
|
closesocket(new_descriptor);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// create socket, we release our own reference.
|
||||||
|
std::shared_ptr<StreamSocket> client = m_accept_callback(m_multiplexer, new_descriptor);
|
||||||
|
if (!client)
|
||||||
|
{
|
||||||
|
closesocket(new_descriptor);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
m_num_connections_accepted++;
|
||||||
|
client->InitialSetup();
|
||||||
|
}
|
||||||
|
|
||||||
|
void ListenSocket::OnWriteEvent()
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
StreamSocket::StreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor)
|
||||||
|
: BaseSocket(multiplexer, descriptor)
|
||||||
|
{
|
||||||
|
// get local address
|
||||||
|
sockaddr_storage sa;
|
||||||
|
socklen_t salen = sizeof(sa);
|
||||||
|
if (getsockname(m_descriptor, reinterpret_cast<sockaddr*>(&sa), &salen) == 0)
|
||||||
|
m_local_address.SetFromSockaddr(&sa, salen);
|
||||||
|
|
||||||
|
// get remote address
|
||||||
|
salen = sizeof(sockaddr_storage);
|
||||||
|
if (getpeername(m_descriptor, reinterpret_cast<sockaddr*>(&sa), &salen) == 0)
|
||||||
|
m_remote_address.SetFromSockaddr(&sa, salen);
|
||||||
|
}
|
||||||
|
|
||||||
|
StreamSocket::~StreamSocket()
|
||||||
|
{
|
||||||
|
DebugAssert(m_descriptor == INVALID_SOCKET);
|
||||||
|
}
|
||||||
|
|
||||||
|
u32 StreamSocket::GetSocketProtocolForAddress(const SocketAddress& sa)
|
||||||
|
{
|
||||||
|
const sockaddr* ssa = reinterpret_cast<const sockaddr*>(sa.GetData());
|
||||||
|
return (ssa->sa_family == AF_INET || ssa->sa_family == AF_INET6) ? IPPROTO_TCP : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void StreamSocket::InitialSetup()
|
||||||
|
{
|
||||||
|
// register for notifications
|
||||||
|
m_multiplexer.AddOpenSocket(shared_from_this());
|
||||||
|
m_multiplexer.SetNotificationMask(this, m_descriptor, POLLIN);
|
||||||
|
|
||||||
|
// trigger connected notification
|
||||||
|
std::unique_lock lock(m_lock);
|
||||||
|
OnConnected();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t StreamSocket::Read(void* buffer, size_t buffer_size)
|
||||||
|
{
|
||||||
|
std::unique_lock lock(m_lock);
|
||||||
|
if (!m_connected)
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
// try a read
|
||||||
|
const ssize_t len = recv(m_descriptor, static_cast<char*>(buffer), SIZE_CAST(buffer_size), 0);
|
||||||
|
if (len <= 0)
|
||||||
|
{
|
||||||
|
// Check for EAGAIN
|
||||||
|
if (len < 0 && WSAGetLastError() == WSAEWOULDBLOCK)
|
||||||
|
{
|
||||||
|
// Not an error. Just means no data is available.
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// error
|
||||||
|
CloseWithError();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return len;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t StreamSocket::Write(const void* buffer, size_t buffer_size)
|
||||||
|
{
|
||||||
|
std::unique_lock lock(m_lock);
|
||||||
|
if (!m_connected)
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
// try a write
|
||||||
|
const ssize_t len = send(m_descriptor, static_cast<const char*>(buffer), SIZE_CAST(buffer_size), 0);
|
||||||
|
if (len <= 0)
|
||||||
|
{
|
||||||
|
// Check for EAGAIN
|
||||||
|
if (len < 0 && WSAGetLastError() == WSAEWOULDBLOCK)
|
||||||
|
{
|
||||||
|
// Not an error. Just means no data is available.
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// error
|
||||||
|
CloseWithError();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return len;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t StreamSocket::WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers)
|
||||||
|
{
|
||||||
|
std::unique_lock lock(m_lock);
|
||||||
|
if (!m_connected || num_buffers == 0)
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
|
||||||
|
WSABUF* bufs = static_cast<WSABUF*>(alloca(sizeof(WSABUF) * num_buffers));
|
||||||
|
for (size_t i = 0; i < num_buffers; i++)
|
||||||
|
{
|
||||||
|
bufs[i].buf = (CHAR*)buffers[i];
|
||||||
|
bufs[i].len = (ULONG)buffer_lengths[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
DWORD bytesSent = 0;
|
||||||
|
if (WSASend(m_descriptor, bufs, (DWORD)num_buffers, &bytesSent, 0, nullptr, nullptr) == SOCKET_ERROR)
|
||||||
|
{
|
||||||
|
if (WSAGetLastError() != WSAEWOULDBLOCK)
|
||||||
|
{
|
||||||
|
// Socket error.
|
||||||
|
CloseWithError();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return static_cast<size_t>(bytesSent);
|
||||||
|
|
||||||
|
#else // _WIN32
|
||||||
|
|
||||||
|
iovec* bufs = static_cast<iovec*>(alloca(sizeof(iovec) * num_buffers));
|
||||||
|
for (size_t i = 0; i < num_buffers; i++)
|
||||||
|
{
|
||||||
|
bufs[i].iov_base = (void*)buffers[i];
|
||||||
|
bufs[i].iov_len = buffer_lengths[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
ssize_t res = writev(m_descriptor, bufs, num_buffers);
|
||||||
|
if (res < 0)
|
||||||
|
{
|
||||||
|
if (errno != EAGAIN)
|
||||||
|
{
|
||||||
|
// Socket error.
|
||||||
|
CloseWithError();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
res = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return static_cast<size_t>(res);
|
||||||
|
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void StreamSocket::Close()
|
||||||
|
{
|
||||||
|
std::unique_lock lock(m_lock);
|
||||||
|
if (!m_connected)
|
||||||
|
return;
|
||||||
|
|
||||||
|
m_multiplexer.SetNotificationMask(this, m_descriptor, 0);
|
||||||
|
m_multiplexer.RemoveOpenSocket(this);
|
||||||
|
shutdown(m_descriptor, SD_BOTH);
|
||||||
|
closesocket(m_descriptor);
|
||||||
|
m_descriptor = INVALID_SOCKET;
|
||||||
|
m_connected = false;
|
||||||
|
|
||||||
|
OnDisconnected(Error::CreateString("Connection explicitly closed."));
|
||||||
|
}
|
||||||
|
|
||||||
|
void StreamSocket::CloseWithError()
|
||||||
|
{
|
||||||
|
std::unique_lock lock(m_lock);
|
||||||
|
DebugAssert(m_connected);
|
||||||
|
|
||||||
|
Error error;
|
||||||
|
const int error_code = WSAGetLastError();
|
||||||
|
if (error_code == 0)
|
||||||
|
error.SetStringView("Connection closed by peer.");
|
||||||
|
else
|
||||||
|
error.SetSocket(error_code);
|
||||||
|
|
||||||
|
m_multiplexer.SetNotificationMask(this, m_descriptor, 0);
|
||||||
|
m_multiplexer.RemoveOpenSocket(this);
|
||||||
|
closesocket(m_descriptor);
|
||||||
|
m_descriptor = INVALID_SOCKET;
|
||||||
|
m_connected = false;
|
||||||
|
|
||||||
|
OnDisconnected(error);
|
||||||
|
}
|
||||||
|
|
||||||
|
void StreamSocket::OnReadEvent()
|
||||||
|
{
|
||||||
|
// forward through
|
||||||
|
std::unique_lock lock(m_lock);
|
||||||
|
if (m_connected)
|
||||||
|
OnRead();
|
||||||
|
}
|
||||||
|
|
||||||
|
void StreamSocket::OnWriteEvent()
|
||||||
|
{
|
||||||
|
// shouldn't be called
|
||||||
|
}
|
||||||
|
|
||||||
|
BufferedStreamSocket::BufferedStreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor,
|
||||||
|
size_t receive_buffer_size /* = 16384 */,
|
||||||
|
size_t send_buffer_size /* = 16384 */)
|
||||||
|
: StreamSocket(multiplexer, descriptor), m_receive_buffer(receive_buffer_size), m_send_buffer(send_buffer_size)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
BufferedStreamSocket::~BufferedStreamSocket()
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_lock<std::recursive_mutex> BufferedStreamSocket::GetLock()
|
||||||
|
{
|
||||||
|
return std::unique_lock(m_lock);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::span<const u8> BufferedStreamSocket::AcquireReadBuffer() const
|
||||||
|
{
|
||||||
|
return std::span<const u8>(m_receive_buffer.data() + m_receive_buffer_offset, m_receive_buffer_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
void BufferedStreamSocket::ReleaseReadBuffer(size_t bytes_consumed)
|
||||||
|
{
|
||||||
|
DebugAssert(bytes_consumed <= m_receive_buffer_size);
|
||||||
|
m_receive_buffer_offset += static_cast<u32>(bytes_consumed);
|
||||||
|
m_receive_buffer_size -= static_cast<u32>(bytes_consumed);
|
||||||
|
|
||||||
|
// Anything left? If not, reset offset.
|
||||||
|
m_receive_buffer_offset = (m_receive_buffer_size == 0) ? 0 : m_receive_buffer_offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::span<u8> BufferedStreamSocket::AcquireWriteBuffer(size_t wanted_bytes, bool allow_smaller /* = false */)
|
||||||
|
{
|
||||||
|
// If to get the desired space, we need to move backwards, do so.
|
||||||
|
if ((m_send_buffer_offset + m_send_buffer_size + wanted_bytes) > m_send_buffer.size())
|
||||||
|
{
|
||||||
|
if ((m_send_buffer_size + wanted_bytes) > m_send_buffer.size() && !allow_smaller)
|
||||||
|
{
|
||||||
|
// Not enough space.
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shuffle buffer backwards.
|
||||||
|
std::memmove(m_send_buffer.data(), m_send_buffer.data() + m_send_buffer_offset, m_send_buffer_size);
|
||||||
|
m_send_buffer_offset = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
DebugAssert((m_send_buffer_offset + m_send_buffer_size + wanted_bytes) <= m_send_buffer.size());
|
||||||
|
return std::span<u8>(m_send_buffer.data() + m_send_buffer_offset + m_send_buffer_size,
|
||||||
|
m_send_buffer.size() - m_send_buffer_offset - m_send_buffer_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
void BufferedStreamSocket::ReleaseWriteBuffer(size_t bytes_written, bool commit /* = true */)
|
||||||
|
{
|
||||||
|
DebugAssert((m_send_buffer_offset + m_send_buffer_size + bytes_written) <= m_send_buffer.size());
|
||||||
|
m_send_buffer_size += static_cast<u32>(bytes_written);
|
||||||
|
|
||||||
|
// Send as much as we can.
|
||||||
|
if (commit && m_send_buffer_size > 0)
|
||||||
|
{
|
||||||
|
const ssize_t res = send(m_descriptor, reinterpret_cast<const char*>(m_send_buffer.data() + m_send_buffer_offset),
|
||||||
|
SIZE_CAST(m_send_buffer_size), 0);
|
||||||
|
if (res < 0 && WSAGetLastError() != WSAEWOULDBLOCK)
|
||||||
|
{
|
||||||
|
CloseWithError();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
m_send_buffer_offset += static_cast<size_t>(res);
|
||||||
|
m_send_buffer_size -= static_cast<size_t>(res);
|
||||||
|
if (m_send_buffer_size == 0)
|
||||||
|
{
|
||||||
|
m_send_buffer_offset = 0;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// Register for writes to finish it off.
|
||||||
|
m_multiplexer.SetNotificationMask(this, m_descriptor, POLLIN | POLLOUT);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t BufferedStreamSocket::Read(void* buffer, size_t buffer_size)
|
||||||
|
{
|
||||||
|
// Read from receive buffer.
|
||||||
|
const std::span<const u8> rdbuf = AcquireReadBuffer();
|
||||||
|
if (rdbuf.empty())
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
const size_t bytes_to_read = std::min(rdbuf.size(), buffer_size);
|
||||||
|
std::memcpy(buffer, rdbuf.data(), bytes_to_read);
|
||||||
|
ReleaseReadBuffer(bytes_to_read);
|
||||||
|
return bytes_to_read;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t BufferedStreamSocket::Write(const void* buffer, size_t buffer_size)
|
||||||
|
{
|
||||||
|
// Read from receive buffer.
|
||||||
|
const std::span<u8> wrbuf = AcquireWriteBuffer(buffer_size, true);
|
||||||
|
if (wrbuf.empty())
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
const size_t bytes_to_write = std::min(wrbuf.size(), buffer_size);
|
||||||
|
std::memcpy(wrbuf.data(), buffer, bytes_to_write);
|
||||||
|
ReleaseWriteBuffer(bytes_to_write);
|
||||||
|
return bytes_to_write;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t BufferedStreamSocket::WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers)
|
||||||
|
{
|
||||||
|
if (!m_connected || num_buffers == 0)
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
size_t total_size = 0;
|
||||||
|
for (size_t i = 0; i < num_buffers; i++)
|
||||||
|
total_size += buffer_lengths[i];
|
||||||
|
|
||||||
|
const std::span<u8> wrbuf = AcquireWriteBuffer(total_size, true);
|
||||||
|
if (wrbuf.empty())
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
size_t written_bytes = 0;
|
||||||
|
for (size_t i = 0; i < num_buffers; i++)
|
||||||
|
{
|
||||||
|
const size_t bytes_to_write = std::min(wrbuf.size() - written_bytes, buffer_lengths[i]);
|
||||||
|
if (bytes_to_write == 0)
|
||||||
|
break;
|
||||||
|
|
||||||
|
std::memcpy(&wrbuf[written_bytes], buffers[i], bytes_to_write);
|
||||||
|
written_bytes += buffer_lengths[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return written_bytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
void BufferedStreamSocket::OnReadEvent()
|
||||||
|
{
|
||||||
|
std::unique_lock lock(m_lock);
|
||||||
|
if (!m_connected)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Pull as many bytes as possible into the read buffer.
|
||||||
|
for (;;)
|
||||||
|
{
|
||||||
|
const size_t buffer_space = m_receive_buffer.size() - m_receive_buffer_offset - m_receive_buffer_size;
|
||||||
|
if (buffer_space == 0) [[unlikely]]
|
||||||
|
{
|
||||||
|
// If we're here again, it means OnRead() didn't consume the data, and we overflowed.
|
||||||
|
ERROR_LOG("Receive buffer overflow, dropping client {}.", GetRemoteAddress().ToString());
|
||||||
|
CloseWithError();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const ssize_t res = recv(
|
||||||
|
m_descriptor, reinterpret_cast<char*>(m_receive_buffer.data() + m_receive_buffer_offset + m_receive_buffer_size),
|
||||||
|
SIZE_CAST(buffer_space), 0);
|
||||||
|
if (res <= 0 && WSAGetLastError() != WSAEWOULDBLOCK)
|
||||||
|
{
|
||||||
|
CloseWithError();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
m_receive_buffer_size += static_cast<size_t>(res);
|
||||||
|
OnRead();
|
||||||
|
|
||||||
|
// Are we at the end?
|
||||||
|
if ((m_receive_buffer_offset + m_receive_buffer_size) == m_receive_buffer.size())
|
||||||
|
{
|
||||||
|
// Try to claw back some of the buffer, and try reading again.
|
||||||
|
if (m_receive_buffer_offset > 0)
|
||||||
|
{
|
||||||
|
std::memmove(m_receive_buffer.data(), m_receive_buffer.data() + m_receive_buffer_offset, m_receive_buffer_size);
|
||||||
|
m_receive_buffer_offset = 0;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void BufferedStreamSocket::OnWriteEvent()
|
||||||
|
{
|
||||||
|
std::unique_lock lock(m_lock);
|
||||||
|
if (!m_connected)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Send as much as we can.
|
||||||
|
if (m_send_buffer_size > 0)
|
||||||
|
{
|
||||||
|
const ssize_t res = send(m_descriptor, reinterpret_cast<const char*>(m_send_buffer.data() + m_send_buffer_offset),
|
||||||
|
SIZE_CAST(m_send_buffer_size), 0);
|
||||||
|
if (res < 0 && WSAGetLastError() != WSAEWOULDBLOCK)
|
||||||
|
{
|
||||||
|
CloseWithError();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
m_send_buffer_offset += static_cast<size_t>(res);
|
||||||
|
m_send_buffer_size -= static_cast<size_t>(res);
|
||||||
|
if (m_send_buffer_size == 0)
|
||||||
|
m_send_buffer_offset = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
OnWrite();
|
||||||
|
|
||||||
|
if (m_send_buffer_size == 0)
|
||||||
|
{
|
||||||
|
// Are we done? Switch back to reads only.
|
||||||
|
m_multiplexer.SetNotificationMask(this, m_descriptor, POLLIN);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void BufferedStreamSocket::OnWrite()
|
||||||
|
{
|
||||||
|
}
|
268
src/util/sockets.h
Normal file
268
src/util/sockets.h
Normal file
|
@ -0,0 +1,268 @@
|
||||||
|
// SPDX-FileCopyrightText: 2015-2024 Connor McLaughlin <stenzek@gmail.com>
|
||||||
|
// SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0)
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "common/error.h"
|
||||||
|
#include "common/heap_array.h"
|
||||||
|
#include "common/small_string.h"
|
||||||
|
#include "common/threading.h"
|
||||||
|
#include "common/types.h"
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <mutex>
|
||||||
|
#include <optional>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <span>
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
using SocketDescriptor = uintptr_t;
|
||||||
|
#else
|
||||||
|
using SocketDescriptor = int;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
struct pollfd;
|
||||||
|
|
||||||
|
class BaseSocket;
|
||||||
|
class ListenSocket;
|
||||||
|
class StreamSocket;
|
||||||
|
class BufferedStreamSocket;
|
||||||
|
class SocketMultiplexer;
|
||||||
|
|
||||||
|
struct SocketAddress final
|
||||||
|
{
|
||||||
|
enum class Type
|
||||||
|
{
|
||||||
|
Unknown,
|
||||||
|
IPv4,
|
||||||
|
IPv6,
|
||||||
|
Unix,
|
||||||
|
};
|
||||||
|
|
||||||
|
// accessors
|
||||||
|
const void* GetData() const { return m_data; }
|
||||||
|
u32 GetLength() const { return m_length; }
|
||||||
|
|
||||||
|
// parse interface
|
||||||
|
static std::optional<SocketAddress> Parse(Type type, const char* address, u32 port, Error* error);
|
||||||
|
|
||||||
|
// resolve interface
|
||||||
|
static std::optional<SocketAddress> Resolve(const char* address, u32 port, Error* error);
|
||||||
|
|
||||||
|
// to string interface
|
||||||
|
SmallString ToString() const;
|
||||||
|
|
||||||
|
// initializers
|
||||||
|
void SetFromSockaddr(const void* sa, size_t length);
|
||||||
|
|
||||||
|
private:
|
||||||
|
u8 m_data[128] = {};
|
||||||
|
u32 m_length = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class BaseSocket : public std::enable_shared_from_this<BaseSocket>
|
||||||
|
{
|
||||||
|
friend SocketMultiplexer;
|
||||||
|
|
||||||
|
public:
|
||||||
|
BaseSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor);
|
||||||
|
virtual ~BaseSocket();
|
||||||
|
|
||||||
|
ALWAYS_INLINE SocketDescriptor GetDescriptor() const { return m_descriptor; }
|
||||||
|
|
||||||
|
virtual void Close() = 0;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual void OnReadEvent() = 0;
|
||||||
|
virtual void OnWriteEvent() = 0;
|
||||||
|
|
||||||
|
SocketMultiplexer& m_multiplexer;
|
||||||
|
SocketDescriptor m_descriptor;
|
||||||
|
};
|
||||||
|
|
||||||
|
class SocketMultiplexer final
|
||||||
|
{
|
||||||
|
// TODO: Re-introduce worker threads.
|
||||||
|
|
||||||
|
public:
|
||||||
|
typedef std::shared_ptr<StreamSocket> (*CreateStreamSocketCallback)(SocketMultiplexer& multiplexer,
|
||||||
|
SocketDescriptor descriptor);
|
||||||
|
friend BaseSocket;
|
||||||
|
friend ListenSocket;
|
||||||
|
friend StreamSocket;
|
||||||
|
friend BufferedStreamSocket;
|
||||||
|
|
||||||
|
public:
|
||||||
|
~SocketMultiplexer();
|
||||||
|
|
||||||
|
// Factory method.
|
||||||
|
static std::unique_ptr<SocketMultiplexer> Create(Error* error);
|
||||||
|
|
||||||
|
// Public interface
|
||||||
|
template<class T>
|
||||||
|
std::shared_ptr<ListenSocket> CreateListenSocket(const SocketAddress& address, Error* error);
|
||||||
|
template<class T>
|
||||||
|
std::shared_ptr<T> ConnectStreamSocket(const SocketAddress& address, Error* error);
|
||||||
|
|
||||||
|
// Returns true if any sockets are currently registered.
|
||||||
|
bool HasAnyOpenSockets();
|
||||||
|
|
||||||
|
// Close all sockets on this multiplexer.
|
||||||
|
void CloseAll();
|
||||||
|
|
||||||
|
// Poll for events. Returns false if there are no sockets registered.
|
||||||
|
bool PollEventsWithTimeout(u32 milliseconds);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// Internal interface
|
||||||
|
std::shared_ptr<ListenSocket> InternalCreateListenSocket(const SocketAddress& address,
|
||||||
|
CreateStreamSocketCallback callback, Error* error);
|
||||||
|
std::shared_ptr<StreamSocket> InternalConnectStreamSocket(const SocketAddress& address,
|
||||||
|
CreateStreamSocketCallback callback, Error* error);
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Hide the constructor.
|
||||||
|
SocketMultiplexer();
|
||||||
|
|
||||||
|
// Tracking of open sockets.
|
||||||
|
void AddOpenSocket(std::shared_ptr<BaseSocket> socket);
|
||||||
|
void RemoveOpenSocket(BaseSocket* socket);
|
||||||
|
|
||||||
|
// Register for notifications
|
||||||
|
void SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events);
|
||||||
|
|
||||||
|
private:
|
||||||
|
// We store the fd in the struct to avoid the cache miss reading the object.
|
||||||
|
using SocketMap = std::unordered_map<SocketDescriptor, std::shared_ptr<BaseSocket>>;
|
||||||
|
|
||||||
|
std::mutex m_poll_array_lock;
|
||||||
|
pollfd* m_poll_array = nullptr;
|
||||||
|
size_t m_poll_array_active_size = 0;
|
||||||
|
size_t m_poll_array_max_size = 0;
|
||||||
|
|
||||||
|
std::mutex m_open_sockets_lock;
|
||||||
|
SocketMap m_open_sockets;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
std::shared_ptr<ListenSocket> SocketMultiplexer::CreateListenSocket(const SocketAddress& address, Error* error)
|
||||||
|
{
|
||||||
|
const CreateStreamSocketCallback callback = [](SocketMultiplexer& multiplexer,
|
||||||
|
SocketDescriptor descriptor) -> std::shared_ptr<StreamSocket> {
|
||||||
|
return std::static_pointer_cast<StreamSocket>(std::make_shared<T>(multiplexer, descriptor));
|
||||||
|
};
|
||||||
|
return InternalCreateListenSocket(address, callback, error);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
std::shared_ptr<T> SocketMultiplexer::ConnectStreamSocket(const SocketAddress& address, Error* error)
|
||||||
|
{
|
||||||
|
const CreateStreamSocketCallback callback = [](SocketMultiplexer& multiplexer,
|
||||||
|
SocketDescriptor descriptor) -> std::shared_ptr<StreamSocket> {
|
||||||
|
return std::static_pointer_cast<StreamSocket>(std::make_shared<T>(multiplexer, descriptor));
|
||||||
|
};
|
||||||
|
return std::static_pointer_cast<T>(InternalConnectStreamSocket(address, callback, error));
|
||||||
|
}
|
||||||
|
|
||||||
|
class ListenSocket final : public BaseSocket
|
||||||
|
{
|
||||||
|
friend SocketMultiplexer;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ListenSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor,
|
||||||
|
SocketMultiplexer::CreateStreamSocketCallback accept_callback);
|
||||||
|
virtual ~ListenSocket() override;
|
||||||
|
|
||||||
|
const SocketAddress* GetLocalAddress() const { return &m_local_address; }
|
||||||
|
u32 GetConnectionsAccepted() const { return m_num_connections_accepted; }
|
||||||
|
|
||||||
|
void Close() override final;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void OnReadEvent() override final;
|
||||||
|
void OnWriteEvent() override final;
|
||||||
|
|
||||||
|
private:
|
||||||
|
SocketMultiplexer::CreateStreamSocketCallback m_accept_callback;
|
||||||
|
SocketAddress m_local_address = {};
|
||||||
|
u32 m_num_connections_accepted = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class StreamSocket : public BaseSocket
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
StreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor);
|
||||||
|
virtual ~StreamSocket() override;
|
||||||
|
|
||||||
|
static u32 GetSocketProtocolForAddress(const SocketAddress& sa);
|
||||||
|
|
||||||
|
virtual void Close() override final;
|
||||||
|
|
||||||
|
// Accessors
|
||||||
|
const SocketAddress& GetLocalAddress() const { return m_local_address; }
|
||||||
|
const SocketAddress& GetRemoteAddress() const { return m_remote_address; }
|
||||||
|
bool IsConnected() const { return m_connected; }
|
||||||
|
|
||||||
|
// Read/write
|
||||||
|
size_t Read(void* buffer, size_t buffer_size);
|
||||||
|
size_t Write(const void* buffer, size_t buffer_size);
|
||||||
|
size_t WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual void OnConnected() = 0;
|
||||||
|
virtual void OnDisconnected(const Error& error) = 0;
|
||||||
|
virtual void OnRead() = 0;
|
||||||
|
|
||||||
|
virtual void OnReadEvent() override;
|
||||||
|
virtual void OnWriteEvent() override;
|
||||||
|
|
||||||
|
void CloseWithError();
|
||||||
|
|
||||||
|
private:
|
||||||
|
void InitialSetup();
|
||||||
|
|
||||||
|
SocketAddress m_local_address = {};
|
||||||
|
SocketAddress m_remote_address = {};
|
||||||
|
std::recursive_mutex m_lock;
|
||||||
|
bool m_connected = true;
|
||||||
|
|
||||||
|
// Ugly, but needed in order to call the events.
|
||||||
|
friend SocketMultiplexer;
|
||||||
|
friend ListenSocket;
|
||||||
|
friend BufferedStreamSocket;
|
||||||
|
};
|
||||||
|
|
||||||
|
class BufferedStreamSocket : public StreamSocket
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
BufferedStreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor, size_t receive_buffer_size = 16384,
|
||||||
|
size_t send_buffer_size = 16384);
|
||||||
|
virtual ~BufferedStreamSocket() override;
|
||||||
|
|
||||||
|
// Must hold the lock when not part of OnRead().
|
||||||
|
std::unique_lock<std::recursive_mutex> GetLock();
|
||||||
|
std::span<const u8> AcquireReadBuffer() const;
|
||||||
|
void ReleaseReadBuffer(size_t bytes_consumed);
|
||||||
|
std::span<u8> AcquireWriteBuffer(size_t wanted_bytes, bool allow_smaller = false);
|
||||||
|
void ReleaseWriteBuffer(size_t bytes_written, bool commit = true);
|
||||||
|
|
||||||
|
// Hide StreamSocket read/write methods.
|
||||||
|
size_t Read(void* buffer, size_t buffer_size);
|
||||||
|
size_t Write(const void* buffer, size_t buffer_size);
|
||||||
|
size_t WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void OnReadEvent() override final;
|
||||||
|
void OnWriteEvent() override final;
|
||||||
|
virtual void OnWrite();
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<u8> m_receive_buffer;
|
||||||
|
size_t m_receive_buffer_offset = 0;
|
||||||
|
size_t m_receive_buffer_size = 0;
|
||||||
|
|
||||||
|
std::vector<u8> m_send_buffer;
|
||||||
|
size_t m_send_buffer_offset = 0;
|
||||||
|
size_t m_send_buffer_size = 0;
|
||||||
|
};
|
|
@ -87,6 +87,7 @@
|
||||||
<ClInclude Include="sdl_input_source.h" />
|
<ClInclude Include="sdl_input_source.h" />
|
||||||
<ClInclude Include="shadergen.h" />
|
<ClInclude Include="shadergen.h" />
|
||||||
<ClInclude Include="shiftjis.h" />
|
<ClInclude Include="shiftjis.h" />
|
||||||
|
<ClInclude Include="sockets.h" />
|
||||||
<ClInclude Include="state_wrapper.h" />
|
<ClInclude Include="state_wrapper.h" />
|
||||||
<ClInclude Include="cd_xa.h" />
|
<ClInclude Include="cd_xa.h" />
|
||||||
<ClInclude Include="vulkan_builders.h">
|
<ClInclude Include="vulkan_builders.h">
|
||||||
|
@ -204,6 +205,7 @@
|
||||||
<ClCompile Include="shadergen.cpp" />
|
<ClCompile Include="shadergen.cpp" />
|
||||||
<ClCompile Include="shiftjis.cpp" />
|
<ClCompile Include="shiftjis.cpp" />
|
||||||
<ClCompile Include="page_fault_handler.cpp" />
|
<ClCompile Include="page_fault_handler.cpp" />
|
||||||
|
<ClCompile Include="sockets.cpp" />
|
||||||
<ClCompile Include="state_wrapper.cpp" />
|
<ClCompile Include="state_wrapper.cpp" />
|
||||||
<ClCompile Include="cd_xa.cpp" />
|
<ClCompile Include="cd_xa.cpp" />
|
||||||
<ClCompile Include="vulkan_builders.cpp">
|
<ClCompile Include="vulkan_builders.cpp">
|
||||||
|
|
|
@ -72,6 +72,7 @@
|
||||||
<ClInclude Include="opengl_context_egl_x11.h" />
|
<ClInclude Include="opengl_context_egl_x11.h" />
|
||||||
<ClInclude Include="opengl_context_wgl.h" />
|
<ClInclude Include="opengl_context_wgl.h" />
|
||||||
<ClInclude Include="image.h" />
|
<ClInclude Include="image.h" />
|
||||||
|
<ClInclude Include="sockets.h" />
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
<ClCompile Include="jit_code_buffer.cpp" />
|
<ClCompile Include="jit_code_buffer.cpp" />
|
||||||
|
@ -153,6 +154,7 @@
|
||||||
<ClCompile Include="opengl_context_wgl.cpp" />
|
<ClCompile Include="opengl_context_wgl.cpp" />
|
||||||
<ClCompile Include="image.cpp" />
|
<ClCompile Include="image.cpp" />
|
||||||
<ClCompile Include="sdl_audio_stream.cpp" />
|
<ClCompile Include="sdl_audio_stream.cpp" />
|
||||||
|
<ClCompile Include="sockets.cpp" />
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
<None Include="metal_shaders.metal" />
|
<None Include="metal_shaders.metal" />
|
||||||
|
|
Loading…
Reference in a new issue