Updating the creation of the SBT
This commit is contained in:
parent
f476512440
commit
10637464fc
17 changed files with 477 additions and 444 deletions
|
|
@ -843,39 +843,70 @@ void HelloVulkan::createRtPipeline()
|
|||
// The Shader Binding Table (SBT)
|
||||
// - getting all shader handles and write them in a SBT buffer
|
||||
// - Besides exception, this could be always done like this
|
||||
// See how the SBT buffer is used in run()
|
||||
//
|
||||
void HelloVulkan::createRtShaderBindingTable()
|
||||
{
|
||||
auto groupCount = static_cast<uint32_t>(m_rtShaderGroups.size()); // 4 shaders: raygen, 2 miss, chit
|
||||
uint32_t groupHandleSize = m_rtProperties.shaderGroupHandleSize; // Size of a program identifier
|
||||
// Compute the actual size needed per SBT entry (round-up to alignment needed).
|
||||
uint32_t groupSizeAligned = nvh::align_up(groupHandleSize, m_rtProperties.shaderGroupBaseAlignment);
|
||||
// Bytes needed for the SBT.
|
||||
uint32_t sbtSize = groupCount * groupSizeAligned;
|
||||
uint32_t missCount{2};
|
||||
uint32_t hitCount{1};
|
||||
auto handleCount = 1 + missCount + hitCount;
|
||||
uint32_t handleSize = m_rtProperties.shaderGroupHandleSize;
|
||||
|
||||
// Fetch all the shader handles used in the pipeline. This is opaque data,/ so we store it in a vector of bytes.
|
||||
// The order of handles follow the stage entry.
|
||||
std::vector<uint8_t> shaderHandleStorage(sbtSize);
|
||||
auto result = vkGetRayTracingShaderGroupHandlesKHR(m_device, m_rtPipeline, 0, groupCount, sbtSize, shaderHandleStorage.data());
|
||||
// The SBT (buffer) need to have starting groups to be aligned and handles in the group to be aligned.
|
||||
uint32_t handleSizeAligned = nvh::align_up(handleSize, m_rtProperties.shaderGroupHandleAlignment);
|
||||
|
||||
m_rgenRegion.stride = nvh::align_up(handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);
|
||||
m_rgenRegion.size = m_rgenRegion.stride; // The size member of pRayGenShaderBindingTable must be equal to its stride member
|
||||
m_missRegion.stride = handleSizeAligned;
|
||||
m_missRegion.size = nvh::align_up(missCount * handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);
|
||||
m_hitRegion.stride = handleSizeAligned;
|
||||
m_hitRegion.size = nvh::align_up(hitCount * handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);
|
||||
|
||||
// Get the shader group handles
|
||||
uint32_t dataSize = handleCount * handleSize;
|
||||
std::vector<uint8_t> handles(dataSize);
|
||||
auto result = vkGetRayTracingShaderGroupHandlesKHR(m_device, m_rtPipeline, 0, handleCount, dataSize, handles.data());
|
||||
assert(result == VK_SUCCESS);
|
||||
|
||||
// Allocate a buffer for storing the SBT. Give it a debug name for NSight.
|
||||
m_rtSBTBuffer = m_alloc.createBuffer(sbtSize,
|
||||
// Allocate a buffer for storing the SBT.
|
||||
VkDeviceSize sbtSize = m_rgenRegion.size + m_missRegion.size + m_hitRegion.size + m_callRegion.size;
|
||||
m_rtSBTBuffer = m_alloc.createBuffer(sbtSize,
|
||||
VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT
|
||||
| VK_BUFFER_USAGE_SHADER_BINDING_TABLE_BIT_KHR,
|
||||
VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
|
||||
m_debug.setObjectName(m_rtSBTBuffer.buffer, std::string("SBT"));
|
||||
m_debug.setObjectName(m_rtSBTBuffer.buffer, std::string("SBT")); // Give it a debug name for NSight.
|
||||
|
||||
// Find the SBT addresses of each group
|
||||
VkBufferDeviceAddressInfo info{VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO, nullptr, m_rtSBTBuffer.buffer};
|
||||
VkDeviceAddress sbtAddress = vkGetBufferDeviceAddress(m_device, &info);
|
||||
m_rgenRegion.deviceAddress = sbtAddress;
|
||||
m_missRegion.deviceAddress = sbtAddress + m_rgenRegion.size;
|
||||
m_hitRegion.deviceAddress = sbtAddress + m_rgenRegion.size + m_missRegion.size;
|
||||
|
||||
// Helper to retrieve the handle data
|
||||
auto getHandle = [&](int i) { return handles.data() + i * handleSize; };
|
||||
|
||||
// Map the SBT buffer and write in the handles.
|
||||
void* mapped = m_alloc.map(m_rtSBTBuffer);
|
||||
auto* pData = reinterpret_cast<uint8_t*>(mapped);
|
||||
for(uint32_t g = 0; g < groupCount; g++)
|
||||
auto* pSBTBuffer = reinterpret_cast<uint8_t*>(m_alloc.map(m_rtSBTBuffer));
|
||||
uint8_t* pData{nullptr};
|
||||
uint32_t handleIdx{0};
|
||||
// Raygen
|
||||
pData = pSBTBuffer;
|
||||
memcpy(pData, getHandle(handleIdx++), handleSize);
|
||||
// Miss
|
||||
pData = pSBTBuffer + m_rgenRegion.size;
|
||||
for(uint32_t c = 0; c < missCount; c++)
|
||||
{
|
||||
memcpy(pData, shaderHandleStorage.data() + g * groupHandleSize, groupHandleSize);
|
||||
pData += groupSizeAligned;
|
||||
memcpy(pData, getHandle(handleIdx++), handleSize);
|
||||
pData += m_missRegion.stride;
|
||||
}
|
||||
// Hit
|
||||
pData = pSBTBuffer + m_rgenRegion.size + m_missRegion.size;
|
||||
for(uint32_t c = 0; c < hitCount; c++)
|
||||
{
|
||||
memcpy(pData, getHandle(handleIdx++), handleSize);
|
||||
pData += m_hitRegion.stride;
|
||||
}
|
||||
|
||||
m_alloc.unmap(m_rtSBTBuffer);
|
||||
m_alloc.finalizeAndReleaseStaging();
|
||||
}
|
||||
|
|
@ -906,21 +937,7 @@ void HelloVulkan::raytrace(const VkCommandBuffer& cmdBuf, const nvmath::vec4f& c
|
|||
0, sizeof(PushConstantRay), &m_pcRay);
|
||||
|
||||
|
||||
// Size of a program identifier
|
||||
uint32_t groupSize = nvh::align_up(m_rtProperties.shaderGroupHandleSize, m_rtProperties.shaderGroupBaseAlignment);
|
||||
uint32_t groupStride = groupSize;
|
||||
|
||||
VkDeviceAddress sbtAddress = nvvk::getBufferDeviceAddress(m_device, m_rtSBTBuffer.buffer);
|
||||
|
||||
using Stride = VkStridedDeviceAddressRegionKHR;
|
||||
std::array<Stride, 4> strideAddresses{Stride{sbtAddress + 0u * groupSize, groupStride, groupSize * 1}, // raygen
|
||||
Stride{sbtAddress + 1u * groupSize, groupStride, groupSize * 2}, // miss
|
||||
Stride{sbtAddress + 3u * groupSize, groupStride, groupSize * 1}, // hit
|
||||
Stride{0u, 0u, 0u}}; // callable
|
||||
|
||||
|
||||
vkCmdTraceRaysKHR(cmdBuf, &strideAddresses[0], &strideAddresses[1], &strideAddresses[2], &strideAddresses[3],
|
||||
m_size.width, m_size.height, 1);
|
||||
vkCmdTraceRaysKHR(cmdBuf, &m_rgenRegion, &m_missRegion, &m_hitRegion, &m_callRegion, m_size.width, m_size.height, 1);
|
||||
|
||||
|
||||
m_debug.endLabel(cmdBuf);
|
||||
|
|
|
|||
|
|
@ -145,8 +145,14 @@ public:
|
|||
std::vector<VkRayTracingShaderGroupCreateInfoKHR> m_rtShaderGroups;
|
||||
VkPipelineLayout m_rtPipelineLayout;
|
||||
VkPipeline m_rtPipeline;
|
||||
nvvk::Buffer m_rtSBTBuffer;
|
||||
int m_maxFrames{10};
|
||||
|
||||
nvvk::Buffer m_rtSBTBuffer;
|
||||
VkStridedDeviceAddressRegionKHR m_rgenRegion{};
|
||||
VkStridedDeviceAddressRegionKHR m_missRegion{};
|
||||
VkStridedDeviceAddressRegionKHR m_hitRegion{};
|
||||
VkStridedDeviceAddressRegionKHR m_callRegion{};
|
||||
|
||||
int m_maxFrames{10};
|
||||
|
||||
// Push constant for ray tracer
|
||||
PushConstantRay m_pcRay{};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue