// SPDX-FileCopyrightText: 2015-2024 Connor McLaughlin // SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0) #pragma once #include "common/error.h" #include "common/heap_array.h" #include "common/small_string.h" #include "common/threading.h" #include "common/types.h" #include #include #include #include #include #include #ifdef _WIN32 using SocketDescriptor = uintptr_t; #else using SocketDescriptor = int; #endif struct pollfd; class BaseSocket; class ListenSocket; class StreamSocket; class BufferedStreamSocket; class SocketMultiplexer; struct SocketAddress final { enum class Type { Unknown, IPv4, IPv6, Unix, }; // accessors const void* GetData() const { return m_data; } u32 GetLength() const { return m_length; } // parse interface static std::optional Parse(Type type, const char* address, u32 port, Error* error); // resolve interface static std::optional Resolve(const char* address, u32 port, Error* error); // to string interface SmallString ToString() const; // initializers void SetFromSockaddr(const void* sa, size_t length); /// Returns true if the address is IP. bool IsIPAddress() const; private: u8 m_data[128] = {}; u32 m_length = 0; }; class BaseSocket : public std::enable_shared_from_this { friend SocketMultiplexer; public: BaseSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor); virtual ~BaseSocket(); ALWAYS_INLINE SocketDescriptor GetDescriptor() const { return m_descriptor; } virtual void Close() = 0; protected: virtual void OnReadEvent() = 0; virtual void OnWriteEvent() = 0; virtual void OnHangupEvent() = 0; SocketMultiplexer& m_multiplexer; SocketDescriptor m_descriptor; }; class SocketMultiplexer final { // TODO: Re-introduce worker threads. public: typedef std::shared_ptr (*CreateStreamSocketCallback)(SocketMultiplexer& multiplexer, SocketDescriptor descriptor); friend BaseSocket; friend ListenSocket; friend StreamSocket; friend BufferedStreamSocket; public: ~SocketMultiplexer(); // Factory method. static std::unique_ptr Create(Error* error); // Public interface template std::shared_ptr CreateListenSocket(const SocketAddress& address, Error* error); template std::shared_ptr ConnectStreamSocket(const SocketAddress& address, Error* error); // Returns true if any sockets are currently registered. bool HasAnyOpenSockets(); // Returns true if any client sockets are currently connected. bool HasAnyClientSockets(); // Returns the number of current client sockets. size_t GetClientSocketCount(); // Close all sockets on this multiplexer. void CloseAll(); // Poll for events. Returns false if there are no sockets registered. bool PollEventsWithTimeout(u32 milliseconds); protected: // Internal interface std::shared_ptr InternalCreateListenSocket(const SocketAddress& address, CreateStreamSocketCallback callback, Error* error); std::shared_ptr InternalConnectStreamSocket(const SocketAddress& address, CreateStreamSocketCallback callback, Error* error); private: // Hide the constructor. SocketMultiplexer(); // Tracking of open sockets. void AddOpenSocket(std::shared_ptr socket); void AddClientSocket(std::shared_ptr socket); void RemoveOpenSocket(BaseSocket* socket); void RemoveClientSocket(BaseSocket* socket); // Register for notifications void SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events); private: // We store the fd in the struct to avoid the cache miss reading the object. using SocketMap = std::unordered_map>; std::mutex m_poll_array_lock; pollfd* m_poll_array = nullptr; size_t m_poll_array_active_size = 0; size_t m_poll_array_max_size = 0; std::mutex m_open_sockets_lock; SocketMap m_open_sockets; std::atomic_size_t m_client_socket_count{0}; }; template std::shared_ptr SocketMultiplexer::CreateListenSocket(const SocketAddress& address, Error* error) { const CreateStreamSocketCallback callback = [](SocketMultiplexer& multiplexer, SocketDescriptor descriptor) -> std::shared_ptr { return std::static_pointer_cast(std::make_shared(multiplexer, descriptor)); }; return InternalCreateListenSocket(address, callback, error); } template std::shared_ptr SocketMultiplexer::ConnectStreamSocket(const SocketAddress& address, Error* error) { const CreateStreamSocketCallback callback = [](SocketMultiplexer& multiplexer, SocketDescriptor descriptor) -> std::shared_ptr { return std::static_pointer_cast(std::make_shared(multiplexer, descriptor)); }; return std::static_pointer_cast(InternalConnectStreamSocket(address, callback, error)); } class ListenSocket final : public BaseSocket { friend SocketMultiplexer; public: ListenSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor, SocketMultiplexer::CreateStreamSocketCallback accept_callback); virtual ~ListenSocket() override; const SocketAddress* GetLocalAddress() const { return &m_local_address; } u32 GetConnectionsAccepted() const { return m_num_connections_accepted; } void Close() override final; protected: void OnReadEvent() override final; void OnWriteEvent() override final; void OnHangupEvent() override final; private: SocketMultiplexer::CreateStreamSocketCallback m_accept_callback; SocketAddress m_local_address = {}; u32 m_num_connections_accepted = 0; }; class StreamSocket : public BaseSocket { public: StreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor); virtual ~StreamSocket() override; static u32 GetSocketProtocolForAddress(const SocketAddress& sa); virtual void Close() override; // Accessors const SocketAddress& GetLocalAddress() const { return m_local_address; } const SocketAddress& GetRemoteAddress() const { return m_remote_address; } bool IsConnected() const { return m_connected; } // Read/write size_t Read(void* buffer, size_t buffer_size); size_t Write(const void* buffer, size_t buffer_size); size_t WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers); /// Disables Nagle's buffering algorithm, i.e. TCP_NODELAY. bool SetNagleBuffering(bool enabled, Error* error = nullptr); protected: virtual void OnConnected() = 0; virtual void OnDisconnected(const Error& error) = 0; virtual void OnRead() = 0; virtual void OnReadEvent() override; virtual void OnWriteEvent() override; virtual void OnHangupEvent() override; void CloseWithError(); private: void InitialSetup(); SocketAddress m_local_address = {}; SocketAddress m_remote_address = {}; std::recursive_mutex m_lock; bool m_connected = true; // Ugly, but needed in order to call the events. friend SocketMultiplexer; friend ListenSocket; friend BufferedStreamSocket; }; class BufferedStreamSocket : public StreamSocket { public: BufferedStreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor, size_t receive_buffer_size = 16384, size_t send_buffer_size = 16384); virtual ~BufferedStreamSocket() override; // Must hold the lock when not part of OnRead(). std::unique_lock GetLock(); std::span AcquireReadBuffer() const; void ReleaseReadBuffer(size_t bytes_consumed); std::span AcquireWriteBuffer(size_t wanted_bytes, bool allow_smaller = false); void ReleaseWriteBuffer(size_t bytes_written, bool commit = true); // Hide StreamSocket read/write methods. size_t Read(void* buffer, size_t buffer_size); size_t Write(const void* buffer, size_t buffer_size); size_t WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers); virtual void Close() override; protected: void OnReadEvent() override final; void OnWriteEvent() override final; virtual void OnWrite(); private: std::vector m_receive_buffer; size_t m_receive_buffer_offset = 0; size_t m_receive_buffer_size = 0; std::vector m_send_buffer; size_t m_send_buffer_offset = 0; size_t m_send_buffer_size = 0; };