MetalDevice: Use TranspileAndCreateShaderFromSource()

This commit is contained in:
Stenzek 2024-09-13 17:31:59 +10:00
parent 68b82ab55b
commit 191957547a
No known key found for this signature in database
4 changed files with 67 additions and 25 deletions

View file

@ -279,6 +279,8 @@ public:
m_size = 0; m_size = 0;
} }
void assign(const std::span<const T> data) { assign(data.data(), data.size()); }
void assign(const T* begin, const T* end) void assign(const T* begin, const T* end)
{ {
const size_t size = reinterpret_cast<const char*>(end) - reinterpret_cast<const char*>(begin); const size_t size = reinterpret_cast<const char*>(end) - reinterpret_cast<const char*>(begin);

View file

@ -1780,6 +1780,13 @@ std::unique_ptr<GPUShader> GPUDevice::TranspileAndCreateShaderFromSource(
GPUShaderStage stage, GPUShaderLanguage source_language, std::string_view source, const char* entry_point, GPUShaderStage stage, GPUShaderLanguage source_language, std::string_view source, const char* entry_point,
GPUShaderLanguage target_language, u32 target_version, DynamicHeapArray<u8>* out_binary, Error* error) GPUShaderLanguage target_language, u32 target_version, DynamicHeapArray<u8>* out_binary, Error* error)
{ {
// Currently, entry points must be "main". TODO: rename the entry point in the SPIR-V.
if (std::strcmp(entry_point, "main") != 0)
{
Error::SetStringView(error, "Entry point must be main.");
return {};
}
// Disable optimization when targeting OpenGL GLSL, otherwise, the name-based linking will fail. // Disable optimization when targeting OpenGL GLSL, otherwise, the name-based linking will fail.
const bool optimization = const bool optimization =
(!m_debug_device && target_language != GPUShaderLanguage::GLSL && target_language != GPUShaderLanguage::GLSLES); (!m_debug_device && target_language != GPUShaderLanguage::GLSL && target_language != GPUShaderLanguage::GLSLES);
@ -1827,7 +1834,14 @@ std::unique_ptr<GPUShader> GPUDevice::TranspileAndCreateShaderFromSource(
if (!TranslateVulkanSpvToLanguage(spv, stage, target_language, target_version, &dest_source, error)) if (!TranslateVulkanSpvToLanguage(spv, stage, target_language, target_version, &dest_source, error))
return {}; return {};
// TODO: MSL needs entry point suffixed. #ifdef __APPLE__
// MSL converter suffixes 0.
if (target_language == GPUShaderLanguage::MSL)
{
return CreateShaderFromSource(stage, target_language, dest_source,
TinyString::from_format("{}0", entry_point).c_str(), out_binary, error);
}
#endif
return CreateShaderFromSource(stage, target_language, dest_source, entry_point, out_binary, error); return CreateShaderFromSource(stage, target_language, dest_source, entry_point, out_binary, error);
} }

View file

@ -317,14 +317,12 @@ private:
void SetFeatures(FeatureMask disabled_features); void SetFeatures(FeatureMask disabled_features);
bool LoadShaders(); bool LoadShaders();
std::unique_ptr<GPUShader> CreateShaderFromMSL(GPUShaderStage stage, std::string_view source,
std::string_view entry_point, Error* error);
id<MTLFunction> GetFunctionFromLibrary(id<MTLLibrary> library, NSString* name); id<MTLFunction> GetFunctionFromLibrary(id<MTLLibrary> library, NSString* name);
id<MTLComputePipelineState> CreateComputePipeline(id<MTLFunction> function, NSString* name); id<MTLComputePipelineState> CreateComputePipeline(id<MTLFunction> function, NSString* name);
ClearPipelineConfig GetCurrentClearPipelineConfig() const; ClearPipelineConfig GetCurrentClearPipelineConfig() const;
id<MTLRenderPipelineState> GetClearDepthPipeline(const ClearPipelineConfig& config); id<MTLRenderPipelineState> GetClearDepthPipeline(const ClearPipelineConfig& config);
std::unique_ptr<GPUShader> CreateShaderFromMSL(GPUShaderStage stage, std::string_view source,
std::string_view entry_point, Error* error);
id<MTLDepthStencilState> GetDepthState(const GPUPipeline::DepthState& ds); id<MTLDepthStencilState> GetDepthState(const GPUPipeline::DepthState& ds);
void CreateCommandBuffer(); void CreateCommandBuffer();

View file

@ -24,6 +24,18 @@ Log_SetChannel(MetalDevice);
// TODO: Disable hazard tracking and issue barriers explicitly. // TODO: Disable hazard tracking and issue barriers explicitly.
// Used for shader "binaries".
namespace {
struct MetalShaderBinaryHeader
{
u32 entry_point_offset;
u32 entry_point_length;
u32 source_offset;
u32 source_length;
};
static_assert(sizeof(MetalShaderBinaryHeader) == 16);
} // namespace
// Looking across a range of GPUs, the optimal copy alignment for Vulkan drivers seems // Looking across a range of GPUs, the optimal copy alignment for Vulkan drivers seems
// to be between 1 (AMD/NV) and 64 (Intel). So, we'll go with 64 here. // to be between 1 (AMD/NV) and 64 (Intel). So, we'll go with 64 here.
static constexpr u32 TEXTURE_UPLOAD_ALIGNMENT = 64; static constexpr u32 TEXTURE_UPLOAD_ALIGNMENT = 64;
@ -648,39 +660,55 @@ std::unique_ptr<GPUShader> MetalDevice::CreateShaderFromMSL(GPUShaderStage stage
std::unique_ptr<GPUShader> MetalDevice::CreateShaderFromBinary(GPUShaderStage stage, std::span<const u8> data, std::unique_ptr<GPUShader> MetalDevice::CreateShaderFromBinary(GPUShaderStage stage, std::span<const u8> data,
Error* error) Error* error)
{ {
const std::string_view str_data(reinterpret_cast<const char*>(data.data()), data.size()); if (data.size() < sizeof(MetalShaderBinaryHeader))
return CreateShaderFromMSL(stage, str_data, "main0", error); {
Error::SetStringView(error, "Invalid header.");
return {};
}
// Need to copy for alignment reasons.
MetalShaderBinaryHeader hdr;
std::memcpy(&hdr, data.data(), sizeof(hdr));
if (static_cast<size_t>(hdr.entry_point_offset) + static_cast<size_t>(hdr.entry_point_length) > data.size() ||
static_cast<size_t>(hdr.source_offset) + static_cast<size_t>(hdr.source_length) > data.size())
{
Error::SetStringView(error, "Out of range fields in header.");
return {};
}
const std::string_view entry_point(reinterpret_cast<const char*>(data.data() + hdr.entry_point_offset),
hdr.entry_point_length);
const std::string source(reinterpret_cast<const char*>(data.data() + hdr.source_offset), hdr.source_length);
return CreateShaderFromMSL(stage, source, entry_point, error);
} }
std::unique_ptr<GPUShader> MetalDevice::CreateShaderFromSource(GPUShaderStage stage, GPUShaderLanguage language, std::unique_ptr<GPUShader> MetalDevice::CreateShaderFromSource(GPUShaderStage stage, GPUShaderLanguage language,
std::string_view source, const char* entry_point, std::string_view source, const char* entry_point,
DynamicHeapArray<u8>* out_binary, Error* error) DynamicHeapArray<u8>* out_binary, Error* error)
{ {
static constexpr bool dump_shaders = false; if (language != GPUShaderLanguage::MSL)
DynamicHeapArray<u8> spv;
if (!CompileGLSLShaderToVulkanSpv(stage, language, source, entry_point, !m_debug_device, false, &spv, error))
return {};
std::string msl;
if (!TranslateVulkanSpvToLanguage(spv.cspan(), stage, GPUShaderLanguage::MSL, 230, &msl, error))
return {};
if constexpr (dump_shaders)
{ {
static unsigned s_next_id = 0; return TranspileAndCreateShaderFromSource(stage, language, source, entry_point, GPUShaderLanguage::MSL,
++s_next_id; m_render_api_version, out_binary, error);
DumpShader(s_next_id, "_input", source);
DumpShader(s_next_id, "_msl", msl);
} }
// Source is the "binary" here, since Metal doesn't allow us to access the bytecode :(
const std::span<const u8> msl(reinterpret_cast<const u8*>(source.data()), source.size());
if (out_binary) if (out_binary)
{ {
out_binary->resize(msl.size()); MetalShaderBinaryHeader hdr;
std::memcpy(out_binary->data(), msl.data(), msl.size()); hdr.entry_point_offset = sizeof(MetalShaderBinaryHeader);
hdr.entry_point_length = static_cast<u32>(std::strlen(entry_point));
hdr.source_offset = hdr.entry_point_offset + hdr.entry_point_length;
hdr.source_length = static_cast<u32>(source.size());
out_binary->resize(sizeof(hdr) + hdr.entry_point_length + hdr.source_length);
std::memcpy(out_binary->data(), &hdr, sizeof(hdr));
std::memcpy(&out_binary->data()[hdr.entry_point_offset], entry_point, hdr.entry_point_length);
std::memcpy(&out_binary->data()[hdr.source_offset], source.data(), hdr.source_length);
} }
return CreateShaderFromMSL(stage, msl, "main0", error); return CreateShaderFromMSL(stage, source, entry_point, error);
} }
MetalPipeline::MetalPipeline(id<MTLRenderPipelineState> pipeline, id<MTLDepthStencilState> depth, MTLCullMode cull_mode, MetalPipeline::MetalPipeline(id<MTLRenderPipelineState> pipeline, id<MTLDepthStencilState> depth, MTLCullMode cull_mode,