From ef524d7dead3f167c96c3f97e95211291a97e76f Mon Sep 17 00:00:00 2001 From: Connor McLaughlin <stenzek@gmail.com> Date: Sun, 21 Feb 2021 18:30:58 +1000 Subject: [PATCH] FrontendCommon: Add HTTPDownloader class --- src/frontend-common/CMakeLists.txt | 13 + src/frontend-common/frontend-common.vcxproj | 25 +- .../frontend-common.vcxproj.filters | 4 + src/frontend-common/http_downloader.cpp | 191 ++++++++++++ src/frontend-common/http_downloader.h | 84 ++++++ .../http_downloader_winhttp.cpp | 285 ++++++++++++++++++ src/frontend-common/http_downloader_winhttp.h | 39 +++ 7 files changed, 623 insertions(+), 18 deletions(-) create mode 100644 src/frontend-common/http_downloader.cpp create mode 100644 src/frontend-common/http_downloader.h create mode 100644 src/frontend-common/http_downloader_winhttp.cpp create mode 100644 src/frontend-common/http_downloader_winhttp.h diff --git a/src/frontend-common/CMakeLists.txt b/src/frontend-common/CMakeLists.txt index 7859af4e2..177000a85 100644 --- a/src/frontend-common/CMakeLists.txt +++ b/src/frontend-common/CMakeLists.txt @@ -91,6 +91,19 @@ if(ENABLE_DISCORD_PRESENCE) target_link_libraries(frontend-common PRIVATE discord-rpc) endif() +if(ENABLE_RETROACHIEVEMENTS) + target_sources(frontend-common PRIVATE + http_downloader.cpp + http_downloader.h + ) + if(WIN32) + target_sources(frontend-common PRIVATE + http_downloader_winhttp.cpp + http_downloader_winhttp.h + ) + endif() +endif() + # Copy the provided data directory to the output directory. add_custom_command(TARGET frontend-common POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory "${CMAKE_SOURCE_DIR}/data" "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}" diff --git a/src/frontend-common/frontend-common.vcxproj b/src/frontend-common/frontend-common.vcxproj index cd93015fa..3a0d7ddd5 100644 --- a/src/frontend-common/frontend-common.vcxproj +++ b/src/frontend-common/frontend-common.vcxproj @@ -92,6 +92,8 @@ <ClCompile Include="fullscreen_ui_progress_callback.cpp" /> <ClCompile Include="game_list.cpp" /> <ClCompile Include="game_settings.cpp" /> + <ClCompile Include="http_downloader.cpp" /> + <ClCompile Include="http_downloader_wininet.cpp" /> <ClCompile Include="icon.cpp" /> <ClCompile Include="imgui_fullscreen.cpp" /> <ClCompile Include="imgui_impl_dx11.cpp" /> @@ -120,6 +122,8 @@ <ClInclude Include="fullscreen_ui_progress_callback.h" /> <ClInclude Include="game_list.h" /> <ClInclude Include="game_settings.h" /> + <ClInclude Include="http_downloader.h" /> + <ClInclude Include="http_downloader_wininet.h" /> <ClInclude Include="icon.h" /> <ClInclude Include="imgui_fullscreen.h" /> <ClInclude Include="imgui_impl_dx11.h" /> @@ -361,7 +365,6 @@ </Link> <Lib /> <Lib> - <AdditionalDependencies>SDL2.lib;%(AdditionalDependencies)</AdditionalDependencies> <AdditionalLibraryDirectories>$(SolutionDir)\dep\msvc\lib32-debug;%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories> </Lib> </ItemDefinitionGroup> @@ -392,7 +395,6 @@ </Link> <Lib /> <Lib> - <AdditionalDependencies>SDL2.lib;%(AdditionalDependencies)</AdditionalDependencies> <AdditionalLibraryDirectories>$(SolutionDir)\dep\msvc\lib32-debug;%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories> </Lib> </ItemDefinitionGroup> @@ -420,7 +422,6 @@ </Link> <Lib /> <Lib> - <AdditionalDependencies>SDL2.lib;%(AdditionalDependencies)</AdditionalDependencies> <AdditionalLibraryDirectories>$(SolutionDir)\dep\msvc\lib64-debug;%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories> </Lib> </ItemDefinitionGroup> @@ -448,7 +449,6 @@ </Link> <Lib /> <Lib> - <AdditionalDependencies>SDL2.lib;%(AdditionalDependencies)</AdditionalDependencies> <AdditionalLibraryDirectories>$(SolutionDir)\dep\msvc\lib64-debug;%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories> </Lib> </ItemDefinitionGroup> @@ -479,7 +479,6 @@ </Link> <Lib /> <Lib> - <AdditionalDependencies>SDL2.lib;%(AdditionalDependencies)</AdditionalDependencies> <AdditionalLibraryDirectories>$(SolutionDir)\dep\msvc\lib64-debug;%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories> </Lib> </ItemDefinitionGroup> @@ -510,7 +509,6 @@ </Link> <Lib /> <Lib> - <AdditionalDependencies>SDL2.lib;%(AdditionalDependencies)</AdditionalDependencies> <AdditionalLibraryDirectories>$(SolutionDir)\dep\msvc\lib64-debug;%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories> </Lib> </ItemDefinitionGroup> @@ -543,9 +541,7 @@ <Lib /> <Lib /> <Lib /> - <Lib> - <AdditionalDependencies>SDL2.lib;%(AdditionalDependencies)</AdditionalDependencies> - </Lib> + <Lib /> </ItemDefinitionGroup> <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='ReleaseLTCG|Win32'"> <ClCompile> @@ -575,7 +571,6 @@ </Link> <Lib /> <Lib> - <AdditionalDependencies>SDL2.lib;%(AdditionalDependencies)</AdditionalDependencies> <AdditionalLibraryDirectories>$(SolutionDir)\dep\msvc\lib32;%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories> </Lib> </ItemDefinitionGroup> @@ -608,9 +603,7 @@ <Lib /> <Lib /> <Lib /> - <Lib> - <AdditionalDependencies>SDL2.lib;%(AdditionalDependencies)</AdditionalDependencies> - </Lib> + <Lib /> </ItemDefinitionGroup> <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|ARM64'"> <ClCompile> @@ -641,9 +634,7 @@ <Lib /> <Lib /> <Lib /> - <Lib> - <AdditionalDependencies>SDL2.lib;%(AdditionalDependencies)</AdditionalDependencies> - </Lib> + <Lib /> </ItemDefinitionGroup> <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='ReleaseLTCG|x64'"> <ClCompile> @@ -673,7 +664,6 @@ </Link> <Lib /> <Lib> - <AdditionalDependencies>SDL2.lib;%(AdditionalDependencies)</AdditionalDependencies> <AdditionalLibraryDirectories>$(SolutionDir)\dep\msvc\lib64;%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories> </Lib> </ItemDefinitionGroup> @@ -705,7 +695,6 @@ </Link> <Lib /> <Lib> - <AdditionalDependencies>SDL2.lib;%(AdditionalDependencies)</AdditionalDependencies> <AdditionalLibraryDirectories>$(SolutionDir)\dep\msvc\lib64;%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories> </Lib> </ItemDefinitionGroup> diff --git a/src/frontend-common/frontend-common.vcxproj.filters b/src/frontend-common/frontend-common.vcxproj.filters index 2665afae2..22e30be64 100644 --- a/src/frontend-common/frontend-common.vcxproj.filters +++ b/src/frontend-common/frontend-common.vcxproj.filters @@ -27,6 +27,8 @@ <ClCompile Include="imgui_fullscreen.cpp" /> <ClCompile Include="fullscreen_ui.cpp" /> <ClCompile Include="fullscreen_ui_progress_callback.cpp" /> + <ClCompile Include="http_downloader_wininet.cpp" /> + <ClCompile Include="http_downloader.cpp" /> </ItemGroup> <ItemGroup> <ClInclude Include="icon.h" /> @@ -55,6 +57,8 @@ <ClInclude Include="imgui_fullscreen.h" /> <ClInclude Include="fullscreen_ui.h" /> <ClInclude Include="fullscreen_ui_progress_callback.h" /> + <ClInclude Include="http_downloader_wininet.h" /> + <ClInclude Include="http_downloader.h" /> </ItemGroup> <ItemGroup> <None Include="font_roboto_regular.inl" /> diff --git a/src/frontend-common/http_downloader.cpp b/src/frontend-common/http_downloader.cpp new file mode 100644 index 000000000..7e5ff5f26 --- /dev/null +++ b/src/frontend-common/http_downloader.cpp @@ -0,0 +1,191 @@ +#include "http_downloader.h" +#include "common/assert.h" +#include "common/log.h" +#include "common/timer.h" +Log_SetChannel(HTTPDownloader); + +static constexpr char DEFAULT_USER_AGENT[] = + "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:85.0) Gecko/20100101 Firefox/85.0"; +static constexpr float DEFAULT_TIMEOUT_IN_SECONDS = 30; +static constexpr u32 DEFAULT_MAX_ACTIVE_REQUESTS = 4; + +namespace FrontendCommon { + +HTTPDownloader::HTTPDownloader() + : m_user_agent(DEFAULT_USER_AGENT), m_timeout(DEFAULT_TIMEOUT_IN_SECONDS), + m_max_active_requests(DEFAULT_MAX_ACTIVE_REQUESTS) +{ +} + +HTTPDownloader::~HTTPDownloader() = default; + +void HTTPDownloader::SetUserAgent(std::string name) +{ + m_user_agent = std::move(name); +} + +void HTTPDownloader::SetTimeout(float timeout) +{ + m_timeout = timeout; +} + +void HTTPDownloader::SetMaxActiveRequests(u32 max_active_requests) +{ + Assert(max_active_requests > 0); + m_max_active_requests = max_active_requests; +} + +void HTTPDownloader::CreateRequest(std::string url, Request::Callback callback) +{ + Request* req = InternalCreateRequest(); + req->parent = this; + req->type = Request::Type::Get; + req->url = std::move(url); + req->callback = std::move(callback); + req->start_time = Common::Timer::GetValue(); + + std::unique_lock<std::mutex> lock(m_pending_http_request_lock); + if (LockedGetActiveRequestCount() < m_max_active_requests) + { + if (!StartRequest(req)) + return; + } + + LockedAddRequest(req); +} + +void HTTPDownloader::CreatePostRequest(std::string url, std::string post_data, Request::Callback callback) +{ + Request* req = InternalCreateRequest(); + req->parent = this; + req->type = Request::Type::Post; + req->url = std::move(url); + req->post_data = std::move(post_data); + req->callback = std::move(callback); + req->start_time = Common::Timer::GetValue(); + + std::unique_lock<std::mutex> lock(m_pending_http_request_lock); + if (LockedGetActiveRequestCount() < m_max_active_requests) + { + if (!StartRequest(req)) + return; + } + + LockedAddRequest(req); +} + +void HTTPDownloader::LockedPollRequests(std::unique_lock<std::mutex>& lock) +{ + if (m_pending_http_requests.empty()) + return; + + InternalPollRequests(); + + const Common::Timer::Value current_time = Common::Timer::GetValue(); + u32 active_requests = 0; + u32 unstarted_requests = 0; + + for (size_t index = 0; index < m_pending_http_requests.size();) + { + Request* req = m_pending_http_requests[index]; + if (req->state == Request::State::Pending) + { + unstarted_requests++; + index++; + continue; + } + + if (req->state == Request::State::Started && current_time >= req->start_time && + Common::Timer::ConvertValueToSeconds(current_time - req->start_time) >= m_timeout) + { + // request timed out + Log_ErrorPrintf("Request for '%s' timed out", req->url.c_str()); + + req->state.store(Request::State::Cancelled); + m_pending_http_requests.erase(m_pending_http_requests.begin() + index); + lock.unlock(); + + req->callback(-1, Request::Data()); + + CloseRequest(req); + + lock.lock(); + continue; + } + + if (req->state != Request::State::Complete) + { + active_requests++; + index++; + continue; + } + + // request complete + Log_VerbosePrintf("Request for '%s' complete, returned status code %u and %zu bytes", req->url.c_str(), + req->status_code, req->data.size()); + m_pending_http_requests.erase(m_pending_http_requests.begin() + index); + + // run callback with lock unheld + lock.unlock(); + req->callback(req->status_code, req->data); + CloseRequest(req); + lock.lock(); + } + + // start new requests when we finished some + if (unstarted_requests > 0 && active_requests < m_max_active_requests) + { + for (size_t index = 0; index < m_pending_http_requests.size();) + { + Request* req = m_pending_http_requests[index]; + if (req->state != Request::State::Pending) + { + index++; + continue; + } + + if (!StartRequest(req)) + { + m_pending_http_requests.erase(m_pending_http_requests.begin() + index); + continue; + } + + active_requests++; + index++; + + if (active_requests >= m_max_active_requests) + break; + } + } +} + +void HTTPDownloader::PollRequests() +{ + std::unique_lock<std::mutex> lock(m_pending_http_request_lock); + LockedPollRequests(lock); +} + +void HTTPDownloader::WaitForAllRequests() +{ + std::unique_lock<std::mutex> lock(m_pending_http_request_lock); + while (!m_pending_http_requests.empty()) + LockedPollRequests(lock); +} + +void HTTPDownloader::LockedAddRequest(Request* request) +{ + m_pending_http_requests.push_back(request); +} + +u32 HTTPDownloader::LockedGetActiveRequestCount() +{ + u32 count = 0; + for (Request* req : m_pending_http_requests) + { + if (req->state == Request::State::Started || req->state == Request::State::Receiving) + count++; + } + return count; +} + +} // namespace FrontendCommon \ No newline at end of file diff --git a/src/frontend-common/http_downloader.h b/src/frontend-common/http_downloader.h new file mode 100644 index 000000000..fafd4a849 --- /dev/null +++ b/src/frontend-common/http_downloader.h @@ -0,0 +1,84 @@ +#pragma once +#include "common/types.h" +#include <atomic> +#include <functional> +#include <mutex> +#include <string> +#include <vector> + +namespace FrontendCommon { + +class HTTPDownloader +{ +public: + enum : s32 + { + HTTP_OK = 200 + }; + + struct Request + { + using Data = std::vector<u8>; + using Callback = std::function<void(s32 status_code, const Data& data)>; + + enum class Type + { + Get, + Post, + }; + + enum class State + { + Pending, + Cancelled, + Started, + Receiving, + Complete, + }; + + HTTPDownloader* parent; + Callback callback; + std::string url; + std::string post_data; + Data data; + u64 start_time; + s32 status_code = 0; + u32 content_length = 0; + Type type = Type::Get; + std::atomic<State> state{State::Pending}; + }; + + HTTPDownloader(); + virtual ~HTTPDownloader(); + + static std::unique_ptr<HTTPDownloader> Create(); + + void SetUserAgent(std::string name); + void SetTimeout(float timeout); + void SetMaxActiveRequests(u32 max_active_requests); + + void CreateRequest(std::string url, Request::Callback callback); + void CreatePostRequest(std::string url, std::string post_data, Request::Callback callback); + void PollRequests(); + void WaitForAllRequests(); + +protected: + virtual Request* InternalCreateRequest() = 0; + virtual void InternalPollRequests() = 0; + + virtual bool StartRequest(Request* request) = 0; + virtual void CloseRequest(Request* request) = 0; + + void LockedAddRequest(Request* request); + u32 LockedGetActiveRequestCount(); + void LockedPollRequests(std::unique_lock<std::mutex>& lock); + + std::string m_user_agent; + float m_timeout; + u32 m_max_active_requests; + + std::mutex m_pending_http_request_lock; + std::vector<Request*> m_pending_http_requests; +}; + +} // namespace FrontendCommon \ No newline at end of file diff --git a/src/frontend-common/http_downloader_winhttp.cpp b/src/frontend-common/http_downloader_winhttp.cpp new file mode 100644 index 000000000..7f49bc26c --- /dev/null +++ b/src/frontend-common/http_downloader_winhttp.cpp @@ -0,0 +1,285 @@ +#include "http_downloader_winhttp.h" +#include "common/assert.h" +#include "common/log.h" +#include "common/string_util.h" +#include "common/timer.h" +#include <algorithm> +Log_SetChannel(HTTPDownloaderWinHttp); + +#pragma comment(lib, "winhttp.lib") + +namespace FrontendCommon { + +HTTPDownloaderWinHttp::HTTPDownloaderWinHttp() : HTTPDownloader() {} + +HTTPDownloaderWinHttp::~HTTPDownloaderWinHttp() +{ + if (m_hSession) + { + WinHttpSetStatusCallback(m_hSession, nullptr, WINHTTP_CALLBACK_FLAG_ALL_NOTIFICATIONS, NULL); + WinHttpCloseHandle(m_hSession); + } +} + +std::unique_ptr<HTTPDownloader> HTTPDownloader::Create() +{ + std::unique_ptr<HTTPDownloaderWinHttp> instance(std::make_unique<HTTPDownloaderWinHttp>()); + if (!instance->Initialize()) + return {}; + + return instance; +} + +bool HTTPDownloaderWinHttp::Initialize() +{ + m_hSession = WinHttpOpen(StringUtil::UTF8StringToWideString(m_user_agent).c_str(), + WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY, nullptr, nullptr, WINHTTP_FLAG_ASYNC); + if (m_hSession == NULL) + return false; + + const DWORD notification_flags = WINHTTP_CALLBACK_FLAG_ALL_COMPLETIONS | WINHTTP_CALLBACK_FLAG_REQUEST_ERROR | + WINHTTP_CALLBACK_FLAG_HANDLES | WINHTTP_CALLBACK_FLAG_SECURE_FAILURE; + WinHttpSetStatusCallback(m_hSession, HTTPStatusCallback, notification_flags, NULL); + return true; +} + +void CALLBACK HTTPDownloaderWinHttp::HTTPStatusCallback(HINTERNET hRequest, DWORD_PTR dwContext, DWORD dwInternetStatus, + LPVOID lpvStatusInformation, DWORD dwStatusInformationLength) +{ + Request* req = reinterpret_cast<Request*>(dwContext); + switch (dwInternetStatus) + { + case WINHTTP_CALLBACK_STATUS_HANDLE_CREATED: + return; + + case WINHTTP_CALLBACK_STATUS_HANDLE_CLOSING: + { + if (!req) + return; + + DebugAssert(hRequest == req->hRequest); + + HTTPDownloaderWinHttp* parent = static_cast<HTTPDownloaderWinHttp*>(req->parent); + std::unique_lock<std::mutex> lock(parent->m_pending_http_request_lock); + Assert(std::none_of(parent->m_pending_http_requests.begin(), parent->m_pending_http_requests.end(), + [req](HTTPDownloader::Request* it) { return it == req; })); + + // we can clean up the connection as well + DebugAssert(req->hConnection != NULL); + WinHttpCloseHandle(req->hConnection); + delete req; + return; + } + + case WINHTTP_CALLBACK_STATUS_REQUEST_ERROR: + { + const WINHTTP_ASYNC_RESULT* res = reinterpret_cast<const WINHTTP_ASYNC_RESULT*>(lpvStatusInformation); + Log_ErrorPrintf("WinHttp async function %p returned error %u", res->dwResult, res->dwError); + req->status_code = -1; + req->state.store(Request::State::Complete); + return; + } + case WINHTTP_CALLBACK_STATUS_SENDREQUEST_COMPLETE: + { + Log_DevPrintf("SendRequest complete"); + if (!WinHttpReceiveResponse(hRequest, nullptr)) + { + Log_ErrorPrintf("WinHttpReceiveResponse() failed: %u", GetLastError()); + req->status_code = -1; + req->state.store(Request::State::Complete); + } + + return; + } + case WINHTTP_CALLBACK_STATUS_HEADERS_AVAILABLE: + { + Log_DevPrintf("Headers available"); + + DWORD buffer_size = sizeof(req->status_code); + if (!WinHttpQueryHeaders(hRequest, WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER, + WINHTTP_HEADER_NAME_BY_INDEX, &req->status_code, &buffer_size, WINHTTP_NO_HEADER_INDEX)) + { + Log_ErrorPrintf("WinHttpQueryHeaders() for status code failed: %u", GetLastError()); + req->status_code = -1; + req->state.store(Request::State::Complete); + return; + } + + buffer_size = sizeof(req->content_length); + if (!WinHttpQueryHeaders(hRequest, WINHTTP_QUERY_CONTENT_LENGTH | WINHTTP_QUERY_FLAG_NUMBER, + WINHTTP_HEADER_NAME_BY_INDEX, &req->content_length, &buffer_size, + WINHTTP_NO_HEADER_INDEX)) + { + if (GetLastError() != ERROR_WINHTTP_HEADER_NOT_FOUND) + Log_WarningPrintf("WinHttpQueryHeaders() for content length failed: %u", GetLastError()); + + req->content_length = 0; + } + + Log_DevPrintf("Status code %d, content-length is %u", req->status_code, req->content_length); + req->data.reserve(req->content_length); + req->state = Request::State::Receiving; + + // start reading + if (!WinHttpQueryDataAvailable(hRequest, nullptr) && GetLastError() != ERROR_IO_PENDING) + { + Log_ErrorPrintf("WinHttpQueryDataAvailable() failed: %u", GetLastError()); + req->status_code = -1; + req->state.store(Request::State::Complete); + } + + return; + } + case WINHTTP_CALLBACK_STATUS_DATA_AVAILABLE: + { + DWORD bytes_available; + std::memcpy(&bytes_available, lpvStatusInformation, sizeof(bytes_available)); + if (bytes_available == 0) + { + // end of request + Log_DevPrintf("End of request '%s', %zu bytes received", req->url.c_str(), req->data.size()); + req->state.store(Request::State::Complete); + return; + } + + // start the transfer + Log_DevPrintf("%u bytes available", bytes_available); + req->io_position = static_cast<u32>(req->data.size()); + req->data.resize(req->io_position + bytes_available); + if (!WinHttpReadData(hRequest, req->data.data() + req->io_position, bytes_available, nullptr) && + GetLastError() != ERROR_IO_PENDING) + { + Log_ErrorPrintf("WinHttpReadData() failed: %u", GetLastError()); + req->status_code = -1; + req->state.store(Request::State::Complete); + } + + return; + } + case WINHTTP_CALLBACK_STATUS_READ_COMPLETE: + { + Log_DevPrintf("Read of %u complete", dwStatusInformationLength); + + const u32 new_size = req->io_position + dwStatusInformationLength; + Assert(new_size <= req->data.size()); + req->data.resize(new_size); + req->start_time = Common::Timer::GetValue(); + + if (!WinHttpQueryDataAvailable(hRequest, nullptr) && GetLastError() != ERROR_IO_PENDING) + { + Log_ErrorPrintf("WinHttpQueryDataAvailable() failed: %u", GetLastError()); + req->status_code = -1; + req->state.store(Request::State::Complete); + } + + return; + } + default: + // unhandled, ignore + return; + } +} + +HTTPDownloader::Request* HTTPDownloaderWinHttp::InternalCreateRequest() +{ + Request* req = new Request(); + return req; +} + +void HTTPDownloaderWinHttp::InternalPollRequests() +{ + // noop - it uses windows's worker threads +} + +bool HTTPDownloaderWinHttp::StartRequest(HTTPDownloader::Request* request) +{ + Request* req = static_cast<Request*>(request); + + std::wstring host_name; + host_name.resize(req->url.size()); + req->object_name.resize(req->url.size()); + + URL_COMPONENTSW uc = {}; + uc.dwStructSize = sizeof(uc); + uc.lpszHostName = host_name.data(); + uc.dwHostNameLength = static_cast<DWORD>(host_name.size()); + uc.lpszUrlPath = req->object_name.data(); + uc.dwUrlPathLength = static_cast<DWORD>(req->object_name.size()); + + const std::wstring url_wide(StringUtil::UTF8StringToWideString(req->url)); + if (!WinHttpCrackUrl(url_wide.c_str(), static_cast<DWORD>(url_wide.size()), 0, &uc)) + { + Log_ErrorPrintf("WinHttpCrackUrl() failed: %u", GetLastError()); + req->callback(-1, req->data); + delete req; + return false; + } + + host_name.resize(uc.dwHostNameLength); + req->object_name.resize(uc.dwUrlPathLength); + + req->hConnection = WinHttpConnect(m_hSession, host_name.c_str(), uc.nPort, 0); + if (!req->hConnection) + { + Log_ErrorPrintf("Failed to start HTTP request for '%s': %u", req->url.c_str(), GetLastError()); + req->callback(-1, req->data); + delete req; + return false; + } + + const DWORD request_flags = uc.nScheme == INTERNET_SCHEME_HTTPS ? WINHTTP_FLAG_SECURE : 0; + req->hRequest = + WinHttpOpenRequest(req->hConnection, (req->type == HTTPDownloader::Request::Type::Post) ? L"POST" : L"GET", + req->object_name.c_str(), NULL, NULL, NULL, request_flags); + if (!req->hRequest) + { + Log_ErrorPrintf("WinHttpOpenRequest() failed: %u", GetLastError()); + WinHttpCloseHandle(req->hConnection); + return false; + } + + BOOL result; + if (req->type == HTTPDownloader::Request::Type::Post) + { + result = WinHttpSendRequest(req->hRequest, WINHTTP_NO_ADDITIONAL_HEADERS, 0, req->post_data.data(), + static_cast<DWORD>(req->post_data.size()), static_cast<DWORD>(req->post_data.size()), + reinterpret_cast<DWORD_PTR>(req)); + } + else + { + result = WinHttpSendRequest(req->hRequest, WINHTTP_NO_ADDITIONAL_HEADERS, 0, WINHTTP_NO_REQUEST_DATA, 0, 0, + reinterpret_cast<DWORD_PTR>(req)); + } + + if (!result && GetLastError() != ERROR_IO_PENDING) + { + Log_ErrorPrintf("WinHttpSendRequest() failed: %u", GetLastError()); + req->status_code = -1; + req->state.store(Request::State::Complete); + } + + Log_DevPrintf("Started HTTP request for '%s'", req->url.c_str()); + req->state = Request::State::Started; + req->start_time = Common::Timer::GetValue(); + return true; +} + +void HTTPDownloaderWinHttp::CloseRequest(HTTPDownloader::Request* request) +{ + Request* req = static_cast<Request*>(request); + + if (req->hRequest != NULL) + { + // req will be freed by the callback. + // the callback can fire immediately here if there's nothing running async, so don't touch req afterwards + WinHttpCloseHandle(req->hRequest); + return; + } + + if (req->hConnection != NULL) + WinHttpCloseHandle(req->hConnection); + + delete req; +} + +} // namespace FrontendCommon \ No newline at end of file diff --git a/src/frontend-common/http_downloader_winhttp.h b/src/frontend-common/http_downloader_winhttp.h new file mode 100644 index 000000000..1f5849a13 --- /dev/null +++ b/src/frontend-common/http_downloader_winhttp.h @@ -0,0 +1,39 @@ +#pragma once +#include "http_downloader.h" + +#include "common/windows_headers.h" + +#include <winhttp.h> + +namespace FrontendCommon { + +class HTTPDownloaderWinHttp final : public HTTPDownloader +{ +public: + HTTPDownloaderWinHttp(); + ~HTTPDownloaderWinHttp() override; + + bool Initialize(); + +protected: + Request* InternalCreateRequest() override; + void InternalPollRequests() override; + bool StartRequest(HTTPDownloader::Request* request) override; + void CloseRequest(HTTPDownloader::Request* request) override; + +private: + struct Request : HTTPDownloader::Request + { + std::wstring object_name; + HINTERNET hConnection = NULL; + HINTERNET hRequest = NULL; + u32 io_position = 0; + }; + + static void CALLBACK HTTPStatusCallback(HINTERNET hInternet, DWORD_PTR dwContext, DWORD dwInternetStatus, + LPVOID lpvStatusInformation, DWORD dwStatusInformationLength); + + HINTERNET m_hSession = NULL; +}; + +} // namespace FrontendCommon \ No newline at end of file