diff --git a/dep/reshadefx/src/effect_codegen_spirv.cpp b/dep/reshadefx/src/effect_codegen_spirv.cpp index e95c78202..994b1377e 100644 --- a/dep/reshadefx/src/effect_codegen_spirv.cpp +++ b/dep/reshadefx/src/effect_codegen_spirv.cpp @@ -390,6 +390,25 @@ private: return {}; const function &entry_point = *entry_point_it->get(); + const auto write_entry_point = [this](const spirv_instruction& oins, std::basic_string& spirv) { + assert(oins.operands.size() > 2); + spirv_instruction nins(oins.op, oins.type, oins.result); + nins.add(oins.operands[0]); + nins.add(oins.operands[1]); + nins.add_string("main"); + + size_t param_start_index = 2; + while (param_start_index < oins.operands.size() && (oins.operands[param_start_index] & 0xFF000000) != 0) + param_start_index++; + + // skip zero + param_start_index++; + + for (size_t i = param_start_index; i < oins.operands.size(); i++) + nins.add(oins.operands[i]); + nins.write(spirv); + }; + // Build list of IDs to remove std::vector variables_to_remove; #if 1 @@ -414,7 +433,7 @@ private: // Only add the matching entry point if (inst.operands[1] == entry_point.id) { - inst.write(spirv); + write_entry_point(inst, spirv); } else { @@ -482,8 +501,9 @@ private: if (func.definition.instructions.empty()) continue; - assert(func.declaration.instructions[_debug_info ? 1 : 0].op == spv::OpFunction); - const spv::Id definition = func.declaration.instructions[_debug_info ? 1 : 0].result; + const bool has_line = (_debug_info && func.declaration.instructions[0].op == spv::OpLine); + assert(func.declaration.instructions[has_line ? 1 : 0].op == spv::OpFunction); + const spv::Id definition = func.declaration.instructions[has_line ? 1 : 0].result; #if 1 if (std::find(functions_to_remove.begin(), functions_to_remove.end(), definition) != functions_to_remove.end()) diff --git a/src/util/postprocessing_shader_fx.cpp b/src/util/postprocessing_shader_fx.cpp index 3e022d721..219b398e6 100644 --- a/src/util/postprocessing_shader_fx.cpp +++ b/src/util/postprocessing_shader_fx.cpp @@ -82,8 +82,8 @@ static std::unique_ptr CreateRFXCodegen() case RenderAPI::Vulkan: case RenderAPI::Metal: { - return std::unique_ptr(reshadefx::create_codegen_glsl( - 460, false, true, debug_info, uniforms_to_spec_constants, false, (rapi == RenderAPI::Vulkan))); + return std::unique_ptr(reshadefx::create_codegen_spirv( + true, debug_info, uniforms_to_spec_constants, false, (rapi == RenderAPI::Vulkan))); } case RenderAPI::OpenGL: @@ -1303,9 +1303,6 @@ GPUTexture* PostProcessing::ReShadeFXShader::GetTextureByID(TextureID id, GPUTex bool PostProcessing::ReShadeFXShader::CompilePipeline(GPUTexture::Format format, u32 width, u32 height, ProgressCallback* progress) { - const RenderAPI api = g_gpu_device->GetRenderAPI(); - const bool needs_main_defn = (api != RenderAPI::D3D11 && api != RenderAPI::D3D12); - m_valid = false; m_textures.clear(); m_passes.clear(); @@ -1342,20 +1339,16 @@ bool PostProcessing::ReShadeFXShader::CompilePipeline(GPUTexture::Format format, return false; } - // TODO: If using spv, this will be populated. - // const std::string effect_code = cg->finalize_code(); - - auto get_shader = [api, needs_main_defn, &cg](const std::string& name, const std::span samplers, - GPUShaderStage stage) { + const RenderAPI api = g_gpu_device->GetRenderAPI(); + auto get_shader = [api, &cg](const std::string& name, const std::span samplers, GPUShaderStage stage) { const std::string real_code = cg->finalize_code_for_entry_point(name); - -#if 0 - FileSystem::WriteStringToFile(fmt::format("D:\\reshade_{}.txt", Path::SanitizeFileName(name)).c_str(), real_code); -#endif + const GPUShaderLanguage lang = (api == RenderAPI::Vulkan || api == RenderAPI::Metal) ? + GPUShaderLanguage::SPV : + ShaderGen::GetShaderLanguageForAPI(api); + const char* entry_point = (lang == GPUShaderLanguage::HLSL) ? name.c_str() : "main"; Error error; - std::unique_ptr sshader = g_gpu_device->CreateShader( - stage, ShaderGen::GetShaderLanguageForAPI(api), real_code, &error, needs_main_defn ? "main" : name.c_str()); + std::unique_ptr sshader = g_gpu_device->CreateShader(stage, lang, real_code, &error, entry_point); if (!sshader) ERROR_LOG("Failed to compile function '{}': {}", name, error.GetDescription());