// SPDX-FileCopyrightText: 2015-2024 Connor McLaughlin <stenzek@gmail.com>
// SPDX-License-Identifier: (GPL-3.0 OR CC-BY-NC-ND-4.0)

#pragma once

#include "common/error.h"
#include "common/heap_array.h"
#include "common/small_string.h"
#include "common/threading.h"
#include "common/types.h"

#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <unordered_map>
#include <span>

#ifdef _WIN32
using SocketDescriptor = uintptr_t;
#else
using SocketDescriptor = int;
#endif

struct pollfd;

class BaseSocket;
class ListenSocket;
class StreamSocket;
class BufferedStreamSocket;
class SocketMultiplexer;

struct SocketAddress final
{
  enum class Type
  {
    Unknown,
    IPv4,
    IPv6,
    Unix,
  };

  // accessors
  const void* GetData() const { return m_data; }
  u32 GetLength() const { return m_length; }

  // parse interface
  static std::optional<SocketAddress> Parse(Type type, const char* address, u32 port, Error* error);

  // resolve interface
  static std::optional<SocketAddress> Resolve(const char* address, u32 port, Error* error);

  // to string interface
  SmallString ToString() const;

  // initializers
  void SetFromSockaddr(const void* sa, size_t length);

private:
  u8 m_data[128] = {};
  u32 m_length = 0;
};

class BaseSocket : public std::enable_shared_from_this<BaseSocket>
{
  friend SocketMultiplexer;

public:
  BaseSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor);
  virtual ~BaseSocket();

  ALWAYS_INLINE SocketDescriptor GetDescriptor() const { return m_descriptor; }

  virtual void Close() = 0;

protected:
  virtual void OnReadEvent() = 0;
  virtual void OnWriteEvent() = 0;

  SocketMultiplexer& m_multiplexer;
  SocketDescriptor m_descriptor;
};

class SocketMultiplexer final
{
  // TODO: Re-introduce worker threads.

public:
  typedef std::shared_ptr<StreamSocket> (*CreateStreamSocketCallback)(SocketMultiplexer& multiplexer,
                                                                      SocketDescriptor descriptor);
  friend BaseSocket;
  friend ListenSocket;
  friend StreamSocket;
  friend BufferedStreamSocket;

public:
  ~SocketMultiplexer();

  // Factory method.
  static std::unique_ptr<SocketMultiplexer> Create(Error* error);

  // Public interface
  template<class T>
  std::shared_ptr<ListenSocket> CreateListenSocket(const SocketAddress& address, Error* error);
  template<class T>
  std::shared_ptr<T> ConnectStreamSocket(const SocketAddress& address, Error* error);

  // Returns true if any sockets are currently registered.
  bool HasAnyOpenSockets();

  // Close all sockets on this multiplexer.
  void CloseAll();

  // Poll for events. Returns false if there are no sockets registered.
  bool PollEventsWithTimeout(u32 milliseconds);

protected:
  // Internal interface
  std::shared_ptr<ListenSocket> InternalCreateListenSocket(const SocketAddress& address,
                                                           CreateStreamSocketCallback callback, Error* error);
  std::shared_ptr<StreamSocket> InternalConnectStreamSocket(const SocketAddress& address,
                                                            CreateStreamSocketCallback callback, Error* error);

private:
  // Hide the constructor.
  SocketMultiplexer();

  // Tracking of open sockets.
  void AddOpenSocket(std::shared_ptr<BaseSocket> socket);
  void RemoveOpenSocket(BaseSocket* socket);

  // Register for notifications
  void SetNotificationMask(BaseSocket* socket, SocketDescriptor descriptor, u32 events);

private:
  // We store the fd in the struct to avoid the cache miss reading the object.
  using SocketMap = std::unordered_map<SocketDescriptor, std::shared_ptr<BaseSocket>>;

  std::mutex m_poll_array_lock;
  pollfd* m_poll_array = nullptr;
  size_t m_poll_array_active_size = 0;
  size_t m_poll_array_max_size = 0;

  std::mutex m_open_sockets_lock;
  SocketMap m_open_sockets;
};

template<class T>
std::shared_ptr<ListenSocket> SocketMultiplexer::CreateListenSocket(const SocketAddress& address, Error* error)
{
  const CreateStreamSocketCallback callback = [](SocketMultiplexer& multiplexer,
                                                 SocketDescriptor descriptor) -> std::shared_ptr<StreamSocket> {
    return std::static_pointer_cast<StreamSocket>(std::make_shared<T>(multiplexer, descriptor));
  };
  return InternalCreateListenSocket(address, callback, error);
}

template<class T>
std::shared_ptr<T> SocketMultiplexer::ConnectStreamSocket(const SocketAddress& address, Error* error)
{
  const CreateStreamSocketCallback callback = [](SocketMultiplexer& multiplexer,
                                                 SocketDescriptor descriptor) -> std::shared_ptr<StreamSocket> {
    return std::static_pointer_cast<StreamSocket>(std::make_shared<T>(multiplexer, descriptor));
  };
  return std::static_pointer_cast<T>(InternalConnectStreamSocket(address, callback, error));
}

class ListenSocket final : public BaseSocket
{
  friend SocketMultiplexer;

public:
  ListenSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor,
               SocketMultiplexer::CreateStreamSocketCallback accept_callback);
  virtual ~ListenSocket() override;

  const SocketAddress* GetLocalAddress() const { return &m_local_address; }
  u32 GetConnectionsAccepted() const { return m_num_connections_accepted; }

  void Close() override final;

protected:
  void OnReadEvent() override final;
  void OnWriteEvent() override final;

private:
  SocketMultiplexer::CreateStreamSocketCallback m_accept_callback;
  SocketAddress m_local_address = {};
  u32 m_num_connections_accepted = 0;
};

class StreamSocket : public BaseSocket
{
public:
  StreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor);
  virtual ~StreamSocket() override;

  static u32 GetSocketProtocolForAddress(const SocketAddress& sa);

  virtual void Close() override;

  // Accessors
  const SocketAddress& GetLocalAddress() const { return m_local_address; }
  const SocketAddress& GetRemoteAddress() const { return m_remote_address; }
  bool IsConnected() const { return m_connected; }

  // Read/write
  size_t Read(void* buffer, size_t buffer_size);
  size_t Write(const void* buffer, size_t buffer_size);
  size_t WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers);

protected:
  virtual void OnConnected() = 0;
  virtual void OnDisconnected(const Error& error) = 0;
  virtual void OnRead() = 0;

  virtual void OnReadEvent() override;
  virtual void OnWriteEvent() override;

  void CloseWithError();

private:
  void InitialSetup();

  SocketAddress m_local_address = {};
  SocketAddress m_remote_address = {};
  std::recursive_mutex m_lock;
  bool m_connected = true;

  // Ugly, but needed in order to call the events.
  friend SocketMultiplexer;
  friend ListenSocket;
  friend BufferedStreamSocket;
};

class BufferedStreamSocket : public StreamSocket
{
public:
  BufferedStreamSocket(SocketMultiplexer& multiplexer, SocketDescriptor descriptor, size_t receive_buffer_size = 16384,
                       size_t send_buffer_size = 16384);
  virtual ~BufferedStreamSocket() override;

  // Must hold the lock when not part of OnRead().
  std::unique_lock<std::recursive_mutex> GetLock();
  std::span<const u8> AcquireReadBuffer() const;
  void ReleaseReadBuffer(size_t bytes_consumed);
  std::span<u8> AcquireWriteBuffer(size_t wanted_bytes, bool allow_smaller = false);
  void ReleaseWriteBuffer(size_t bytes_written, bool commit = true);

  // Hide StreamSocket read/write methods.
  size_t Read(void* buffer, size_t buffer_size);
  size_t Write(const void* buffer, size_t buffer_size);
  size_t WriteVector(const void** buffers, const size_t* buffer_lengths, size_t num_buffers);
  virtual void Close() override;

protected:
  void OnReadEvent() override final;
  void OnWriteEvent() override final;
  virtual void OnWrite();

private:
  std::vector<u8> m_receive_buffer;
  size_t m_receive_buffer_offset = 0;
  size_t m_receive_buffer_size = 0;

  std::vector<u8> m_send_buffer;
  size_t m_send_buffer_offset = 0;
  size_t m_send_buffer_size = 0;
};