From cc6f22163cedac2a6d5974b2fb8b4c1029d9564c Mon Sep 17 00:00:00 2001 From: Stenzek Date: Fri, 24 Nov 2023 15:54:43 +1000 Subject: [PATCH] HTTPDownloader: Add support for progress updates/cancelling --- src/core/achievements.cpp | 14 ++----- src/core/game_list.cpp | 4 +- src/core/host.cpp | 7 ++++ src/util/host.h | 3 ++ src/util/http_downloader.cpp | 59 ++++++++++++++++++++++------ src/util/http_downloader.h | 16 +++++--- src/util/http_downloader_curl.cpp | 20 +++++++--- src/util/http_downloader_curl.h | 2 +- src/util/http_downloader_winhttp.cpp | 6 +-- src/util/http_downloader_winhttp.h | 2 +- 10 files changed, 92 insertions(+), 41 deletions(-) diff --git a/src/core/achievements.cpp b/src/core/achievements.cpp index 9eb6e251f..71eab67a6 100644 --- a/src/core/achievements.cpp +++ b/src/core/achievements.cpp @@ -125,7 +125,6 @@ static void ReportRCError(int err, fmt::format_string fmt, T&&... args); static void EnsureCacheDirectoriesExist(); static void ClearGameInfo(); static void ClearGameHash(); -static std::string GetUserAgent(); static std::string GetGameHash(CDImage* image); static void SetHardcoreMode(bool enabled, bool force_display_message); static bool IsLoggedInOrLoggingIn(); @@ -241,11 +240,6 @@ const rc_client_user_game_summary_t& Achievements::GetGameSummary() return s_game_summary; } -std::string Achievements::GetUserAgent() -{ - return fmt::format("DuckStation for {} ({}) {}", TARGET_OS_STR, CPU_ARCH_STR, g_scm_tag_str); -} - void Achievements::ReportError(const std::string_view& sv) { std::string error = fmt::format("Achievements error: {}", sv); @@ -311,7 +305,7 @@ std::string Achievements::GetGameHash(CDImage* image) void Achievements::DownloadImage(std::string url, std::string cache_filename) { - auto callback = [cache_filename](s32 status_code, std::string content_type, HTTPDownloader::Request::Data data) { + auto callback = [cache_filename](s32 status_code, const std::string& content_type, HTTPDownloader::Request::Data data) { if (status_code != HTTPDownloader::HTTP_STATUS_OK) return; @@ -437,7 +431,7 @@ bool Achievements::Initialize() bool Achievements::CreateClient(rc_client_t** client, std::unique_ptr* http) { - *http = HTTPDownloader::Create(GetUserAgent().c_str()); + *http = HTTPDownloader::Create(Host::GetHTTPUserAgent()); if (!*http) { Host::ReportErrorAsync("Achievements Error", "Failed to create HTTPDownloader, cannot use achievements"); @@ -621,7 +615,7 @@ uint32_t Achievements::ClientReadMemory(uint32_t address, uint8_t* buffer, uint3 void Achievements::ClientServerCall(const rc_api_request_t* request, rc_client_server_callback_t callback, void* callback_data, rc_client_t* client) { - HTTPDownloader::Request::Callback hd_callback = [callback, callback_data](s32 status_code, std::string content_type, + HTTPDownloader::Request::Callback hd_callback = [callback, callback_data](s32 status_code, const std::string& content_type, HTTPDownloader::Request::Data data) { rc_api_server_response_t rr; rr.http_status_code = (status_code <= 0) ? (status_code == HTTPDownloader::HTTP_STATUS_CANCELLED ? @@ -2959,7 +2953,7 @@ void Achievements::SwitchToRAIntegration() void Achievements::RAIntegration::InitializeRAIntegration(void* main_window_handle) { RA_InitClient((HWND)main_window_handle, "DuckStation", g_scm_tag_str); - RA_SetUserAgentDetail(Achievements::GetUserAgent().c_str()); + RA_SetUserAgentDetail(Host::GetHTTPUserAgent().c_str()); RA_InstallSharedFunctions(RACallbackIsActive, RACallbackCauseUnpause, RACallbackCausePause, RACallbackRebuildMenu, RACallbackEstimateTitle, RACallbackResetEmulator, RACallbackLoadROM); diff --git a/src/core/game_list.cpp b/src/core/game_list.cpp index d18984eeb..32fff7e04 100644 --- a/src/core/game_list.cpp +++ b/src/core/game_list.cpp @@ -1139,7 +1139,7 @@ bool GameList::DownloadCovers(const std::vector& url_templates, boo return false; } - std::unique_ptr downloader(HTTPDownloader::Create()); + std::unique_ptr downloader(HTTPDownloader::Create(Host::GetHTTPUserAgent())); if (!downloader) { progress->DisplayError("Failed to create HTTP downloader."); @@ -1171,7 +1171,7 @@ bool GameList::DownloadCovers(const std::vector& url_templates, boo std::string filename(HTTPDownloader::URLDecode(url)); downloader->CreateRequest( std::move(url), [use_serial, &save_callback, entry_path = std::move(entry_path), filename = std::move(filename)]( - s32 status_code, std::string content_type, HTTPDownloader::Request::Data data) { + s32 status_code, const std::string& content_type, HTTPDownloader::Request::Data data) { if (status_code != HTTPDownloader::HTTP_STATUS_OK || data.empty()) return; diff --git a/src/core/host.cpp b/src/core/host.cpp index c7bd861c6..0797384de 100644 --- a/src/core/host.cpp +++ b/src/core/host.cpp @@ -8,6 +8,8 @@ #include "shader_cache_version.h" #include "system.h" +#include "scmversion/scmversion.h" + #include "util/gpu_device.h" #include "util/imgui_manager.h" @@ -217,6 +219,11 @@ void Host::Internal::SetInputSettingsLayer(SettingsInterface* sif) s_layered_settings_interface.SetLayer(LayeredSettingsInterface::LAYER_INPUT, sif); } +std::string Host::GetHTTPUserAgent() +{ + return fmt::format("DuckStation for {} ({}) {}", TARGET_OS_STR, CPU_ARCH_STR, g_scm_tag_str); +} + void Host::ReportFormattedDebuggerMessage(const char* format, ...) { std::va_list ap; diff --git a/src/util/host.h b/src/util/host.h index 27bd4d750..707b8b890 100644 --- a/src/util/host.h +++ b/src/util/host.h @@ -33,6 +33,9 @@ void ReportFormattedErrorAsync(const std::string_view& title, const char* format bool ConfirmMessage(const std::string_view& title, const std::string_view& message); bool ConfirmFormattedMessage(const std::string_view& title, const char* format, ...); +/// Returns the user agent to use for HTTP requests. +std::string GetHTTPUserAgent(); + /// Opens a URL, using the default application. void OpenURL(const std::string_view& url); diff --git a/src/util/http_downloader.cpp b/src/util/http_downloader.cpp index 802be317d..d24cfe2b9 100644 --- a/src/util/http_downloader.cpp +++ b/src/util/http_downloader.cpp @@ -5,6 +5,7 @@ #include "common/assert.h" #include "common/log.h" +#include "common/progress_callback.h" #include "common/string_util.h" #include "common/timer.h" @@ -34,13 +35,14 @@ void HTTPDownloader::SetMaxActiveRequests(u32 max_active_requests) m_max_active_requests = max_active_requests; } -void HTTPDownloader::CreateRequest(std::string url, Request::Callback callback) +void HTTPDownloader::CreateRequest(std::string url, Request::Callback callback, ProgressCallback* progress) { Request* req = InternalCreateRequest(); req->parent = this; req->type = Request::Type::Get; req->url = std::move(url); req->callback = std::move(callback); + req->progress = progress; req->start_time = Common::Timer::GetCurrentValue(); std::unique_lock lock(m_pending_http_request_lock); @@ -53,7 +55,8 @@ void HTTPDownloader::CreateRequest(std::string url, Request::Callback callback) LockedAddRequest(req); } -void HTTPDownloader::CreatePostRequest(std::string url, std::string post_data, Request::Callback callback) +void HTTPDownloader::CreatePostRequest(std::string url, std::string post_data, Request::Callback callback, + ProgressCallback* progress) { Request* req = InternalCreateRequest(); req->parent = this; @@ -61,6 +64,7 @@ void HTTPDownloader::CreatePostRequest(std::string url, std::string post_data, R req->url = std::move(url); req->post_data = std::move(post_data); req->callback = std::move(callback); + req->progress = progress; req->start_time = Common::Timer::GetCurrentValue(); std::unique_lock lock(m_pending_http_request_lock); @@ -73,12 +77,6 @@ void HTTPDownloader::CreatePostRequest(std::string url, std::string post_data, R LockedAddRequest(req); } -bool HTTPDownloader::HasAnyRequests() -{ - std::unique_lock lock(m_pending_http_request_lock); - return !m_pending_http_requests.empty(); -} - void HTTPDownloader::LockedPollRequests(std::unique_lock& lock) { if (m_pending_http_requests.empty()) @@ -100,11 +98,12 @@ void HTTPDownloader::LockedPollRequests(std::unique_lock& lock) continue; } - if (req->state == Request::State::Started && current_time >= req->start_time && + if ((req->state == Request::State::Started || req->state == Request::State::Receiving) && + current_time >= req->start_time && Common::Timer::ConvertValueToSeconds(current_time - req->start_time) >= m_timeout) { // request timed out - Log_ErrorPrintf("Request for '%s' timed out", req->url.c_str()); + Log_ErrorFmt("Request for '{}' timed out", req->url); req->state.store(Request::State::Cancelled); m_pending_http_requests.erase(m_pending_http_requests.begin() + index); @@ -117,22 +116,50 @@ void HTTPDownloader::LockedPollRequests(std::unique_lock& lock) lock.lock(); continue; } + else if ((req->state == Request::State::Started || req->state == Request::State::Receiving) && req->progress && + req->progress->IsCancelled()) + { + // request timed out + Log_ErrorFmt("Request for '{}' cancelled", req->url); + + req->state.store(Request::State::Cancelled); + m_pending_http_requests.erase(m_pending_http_requests.begin() + index); + lock.unlock(); + + req->callback(HTTP_STATUS_CANCELLED, std::string(), Request::Data()); + + CloseRequest(req); + + lock.lock(); + continue; + } if (req->state != Request::State::Complete) { + if (req->progress) + { + const u32 size = static_cast(req->data.size()); + if (size != req->last_progress_update) + { + req->last_progress_update = size; + req->progress->SetProgressRange(req->content_length); + req->progress->SetProgressValue(req->last_progress_update); + } + } + active_requests++; index++; continue; } // request complete - Log_VerbosePrintf("Request for '%s' complete, returned status code %u and %zu bytes", req->url.c_str(), - req->status_code, req->data.size()); + Log_VerboseFmt("Request for '{}' complete, returned status code {} and {} bytes", req->url, req->status_code, + req->data.size()); m_pending_http_requests.erase(m_pending_http_requests.begin() + index); // run callback with lock unheld lock.unlock(); - req->callback(req->status_code, std::move(req->content_type), std::move(req->data)); + req->callback(req->status_code, req->content_type, std::move(req->data)); CloseRequest(req); lock.lock(); } @@ -197,6 +224,12 @@ u32 HTTPDownloader::LockedGetActiveRequestCount() return count; } +bool HTTPDownloader::HasAnyRequests() +{ + std::unique_lock lock(m_pending_http_request_lock); + return !m_pending_http_requests.empty(); +} + std::string HTTPDownloader::URLEncode(const std::string_view& str) { std::string ret; diff --git a/src/util/http_downloader.h b/src/util/http_downloader.h index 13038c393..45e5b6d24 100644 --- a/src/util/http_downloader.h +++ b/src/util/http_downloader.h @@ -13,6 +13,8 @@ #include #include +class ProgressCallback; + class HTTPDownloader { public: @@ -27,7 +29,7 @@ public: struct Request { using Data = std::vector; - using Callback = std::function; + using Callback = std::function; enum class Type { @@ -46,6 +48,7 @@ public: HTTPDownloader* parent; Callback callback; + ProgressCallback* progress; std::string url; std::string post_data; std::string content_type; @@ -53,6 +56,7 @@ public: u64 start_time; s32 status_code = 0; u32 content_length = 0; + u32 last_progress_update = 0; Type type = Type::Get; std::atomic state{State::Pending}; }; @@ -60,7 +64,7 @@ public: HTTPDownloader(); virtual ~HTTPDownloader(); - static std::unique_ptr Create(const char* user_agent = DEFAULT_USER_AGENT); + static std::unique_ptr Create(std::string user_agent = DEFAULT_USER_AGENT); static std::string URLEncode(const std::string_view& str); static std::string URLDecode(const std::string_view& str); static std::string GetExtensionForContentType(const std::string& content_type); @@ -68,12 +72,12 @@ public: void SetTimeout(float timeout); void SetMaxActiveRequests(u32 max_active_requests); - void CreateRequest(std::string url, Request::Callback callback); - void CreatePostRequest(std::string url, std::string post_data, Request::Callback callback); - - bool HasAnyRequests(); + void CreateRequest(std::string url, Request::Callback callback, ProgressCallback* progress = nullptr); + void CreatePostRequest(std::string url, std::string post_data, Request::Callback callback, + ProgressCallback* progress = nullptr); void PollRequests(); void WaitForAllRequests(); + bool HasAnyRequests(); static const char DEFAULT_USER_AGENT[]; diff --git a/src/util/http_downloader_curl.cpp b/src/util/http_downloader_curl.cpp index ecd62f8ff..617a77e74 100644 --- a/src/util/http_downloader_curl.cpp +++ b/src/util/http_downloader_curl.cpp @@ -25,10 +25,10 @@ HTTPDownloaderCurl::~HTTPDownloaderCurl() curl_multi_cleanup(m_multi_handle); } -std::unique_ptr HTTPDownloader::Create(const char* user_agent) +std::unique_ptr HTTPDownloader::Create(std::string user_agent) { std::unique_ptr instance(std::make_unique()); - if (!instance->Initialize(user_agent)) + if (!instance->Initialize(std::move(user_agent))) return {}; return instance; @@ -37,7 +37,7 @@ std::unique_ptr HTTPDownloader::Create(const char* user_agent) static bool s_curl_initialized = false; static std::once_flag s_curl_initialized_once_flag; -bool HTTPDownloaderCurl::Initialize(const char* user_agent) +bool HTTPDownloaderCurl::Initialize(std::string user_agent) { if (!s_curl_initialized) { @@ -65,7 +65,7 @@ bool HTTPDownloaderCurl::Initialize(const char* user_agent) return false; } - m_user_agent = user_agent; + m_user_agent = std::move(user_agent); return true; } @@ -76,7 +76,16 @@ size_t HTTPDownloaderCurl::WriteCallback(char* ptr, size_t size, size_t nmemb, v const size_t transfer_size = size * nmemb; const size_t new_size = current_size + transfer_size; req->data.resize(new_size); + req->start_time = Common::Timer::GetCurrentValue(); std::memcpy(&req->data[current_size], ptr, transfer_size); + + if (req->content_length == 0) + { + curl_off_t length; + if (curl_easy_getinfo(req->handle, CURLINFO_CONTENT_LENGTH_DOWNLOAD_T, &length) == CURLE_OK) + req->content_length = static_cast(length); + } + return nmemb; } @@ -160,8 +169,9 @@ bool HTTPDownloaderCurl::StartRequest(HTTPDownloader::Request* request) curl_easy_setopt(req->handle, CURLOPT_USERAGENT, m_user_agent.c_str()); curl_easy_setopt(req->handle, CURLOPT_WRITEFUNCTION, &HTTPDownloaderCurl::WriteCallback); curl_easy_setopt(req->handle, CURLOPT_WRITEDATA, req); - curl_easy_setopt(req->handle, CURLOPT_NOSIGNAL, 1); + curl_easy_setopt(req->handle, CURLOPT_NOSIGNAL, 1L); curl_easy_setopt(req->handle, CURLOPT_PRIVATE, req); + curl_easy_setopt(req->handle, CURLOPT_FOLLOWLOCATION, 1L); if (request->type == Request::Type::Post) { diff --git a/src/util/http_downloader_curl.h b/src/util/http_downloader_curl.h index bd16e2a19..7b3a80091 100644 --- a/src/util/http_downloader_curl.h +++ b/src/util/http_downloader_curl.h @@ -15,7 +15,7 @@ public: HTTPDownloaderCurl(); ~HTTPDownloaderCurl() override; - bool Initialize(const char* user_agent); + bool Initialize(std::string user_agent); protected: Request* InternalCreateRequest() override; diff --git a/src/util/http_downloader_winhttp.cpp b/src/util/http_downloader_winhttp.cpp index 7079f2779..acb896f88 100644 --- a/src/util/http_downloader_winhttp.cpp +++ b/src/util/http_downloader_winhttp.cpp @@ -25,16 +25,16 @@ HTTPDownloaderWinHttp::~HTTPDownloaderWinHttp() } } -std::unique_ptr HTTPDownloader::Create(const char* user_agent) +std::unique_ptr HTTPDownloader::Create(std::string user_agent) { std::unique_ptr instance(std::make_unique()); - if (!instance->Initialize(user_agent)) + if (!instance->Initialize(std::move(user_agent))) return {}; return instance; } -bool HTTPDownloaderWinHttp::Initialize(const char* user_agent) +bool HTTPDownloaderWinHttp::Initialize(std::string user_agent) { static constexpr DWORD dwAccessType = WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY; diff --git a/src/util/http_downloader_winhttp.h b/src/util/http_downloader_winhttp.h index b1efa3b6b..9f4080947 100644 --- a/src/util/http_downloader_winhttp.h +++ b/src/util/http_downloader_winhttp.h @@ -14,7 +14,7 @@ public: HTTPDownloaderWinHttp(); ~HTTPDownloaderWinHttp() override; - bool Initialize(const char* user_agent); + bool Initialize(std::string user_agent); protected: Request* InternalCreateRequest() override;