Updating the creation of the SBT

This commit is contained in:
mklefrancois 2021-09-08 15:59:03 +02:00
parent f476512440
commit 10637464fc
17 changed files with 477 additions and 444 deletions

View file

@ -1116,41 +1116,71 @@ void HelloVulkan::createRtPipeline()
// The Shader Binding Table (SBT)
// - getting all shader handles and write them in a SBT buffer
// - Besides exception, this could be always done like this
// See how the SBT buffer is used in run()
//
void HelloVulkan::createRtShaderBindingTable()
{
auto groupCount = static_cast<uint32_t>(m_rtShaderGroups.size());
assert(groupCount == 8 && "Update Comment"); // 8 shaders: raygen, 3 miss, 4 chit
uint32_t missCount{3};
uint32_t hitCount{4};
auto handleCount = 1 + missCount + hitCount;
uint32_t handleSize = m_rtProperties.shaderGroupHandleSize;
uint32_t groupHandleSize = m_rtProperties.shaderGroupHandleSize; // Size of a program identifier
// Compute the actual size needed per SBT entry (round-up to alignment needed).
uint32_t groupSizeAligned = nvh::align_up(groupHandleSize, m_rtProperties.shaderGroupBaseAlignment);
// Bytes needed for the SBT.
uint32_t sbtSize = groupCount * groupSizeAligned;
// Fetch all the shader handles used in the pipeline. This is opaque data,
// so we store it in a vector of bytes.
std::vector<uint8_t> shaderHandleStorage(sbtSize);
auto result = vkGetRayTracingShaderGroupHandlesKHR(m_device, m_rtPipeline, 0, groupCount, sbtSize, shaderHandleStorage.data());
// The SBT (buffer) need to have starting groups to be aligned and handles in the group to be aligned.
uint32_t handleSizeAligned = nvh::align_up(handleSize, m_rtProperties.shaderGroupHandleAlignment);
m_rgenRegion.stride = nvh::align_up(handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);
m_rgenRegion.size = m_rgenRegion.stride; // The size member of pRayGenShaderBindingTable must be equal to its stride member
m_missRegion.stride = handleSizeAligned;
m_missRegion.size = nvh::align_up(missCount * handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);
m_hitRegion.stride = handleSizeAligned;
m_hitRegion.size = nvh::align_up(hitCount * handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);
// Get the shader group handles
uint32_t dataSize = handleCount * handleSize;
std::vector<uint8_t> handles(dataSize);
auto result = vkGetRayTracingShaderGroupHandlesKHR(m_device, m_rtPipeline, 0, handleCount, dataSize, handles.data());
assert(result == VK_SUCCESS);
// Allocate a buffer for storing the SBT. Give it a debug name for NSight.
m_rtSBTBuffer = m_alloc.createBuffer(sbtSize,
// Allocate a buffer for storing the SBT.
VkDeviceSize sbtSize = m_rgenRegion.size + m_missRegion.size + m_hitRegion.size + m_callRegion.size;
m_rtSBTBuffer = m_alloc.createBuffer(sbtSize,
VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT
| VK_BUFFER_USAGE_SHADER_BINDING_TABLE_BIT_KHR,
VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
m_debug.setObjectName(m_rtSBTBuffer.buffer, std::string("SBT"));
m_debug.setObjectName(m_rtSBTBuffer.buffer, std::string("SBT")); // Give it a debug name for NSight.
// Find the SBT addresses of each group
VkBufferDeviceAddressInfo info{VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO, nullptr, m_rtSBTBuffer.buffer};
VkDeviceAddress sbtAddress = vkGetBufferDeviceAddress(m_device, &info);
m_rgenRegion.deviceAddress = sbtAddress;
m_missRegion.deviceAddress = sbtAddress + m_rgenRegion.size;
m_hitRegion.deviceAddress = sbtAddress + m_rgenRegion.size + m_missRegion.size;
// Helper to retrieve the handle data
auto getHandle = [&](int i) { return handles.data() + i * handleSize; };
// Map the SBT buffer and write in the handles.
void* mapped = m_alloc.map(m_rtSBTBuffer);
auto* pData = reinterpret_cast<uint8_t*>(mapped);
for(uint32_t g = 0; g < groupCount; g++)
auto* pSBTBuffer = reinterpret_cast<uint8_t*>(m_alloc.map(m_rtSBTBuffer));
uint8_t* pData{nullptr};
uint32_t handleIdx{0};
// Raygen
pData = pSBTBuffer;
memcpy(pData, getHandle(handleIdx++), handleSize);
// Miss
pData = pSBTBuffer + m_rgenRegion.size;
for(uint32_t c = 0; c < missCount; c++)
{
memcpy(pData, shaderHandleStorage.data() + g * groupHandleSize, groupHandleSize);
pData += groupSizeAligned;
memcpy(pData, getHandle(handleIdx++), handleSize);
pData += m_missRegion.stride;
}
// Hit
pData = pSBTBuffer + m_rgenRegion.size + m_missRegion.size;
for(uint32_t c = 0; c < hitCount; c++)
{
memcpy(pData, getHandle(handleIdx++), handleSize);
pData += m_hitRegion.stride;
}
m_alloc.unmap(m_rtSBTBuffer);
m_alloc.finalizeAndReleaseStaging();
}
@ -1323,22 +1353,8 @@ void HelloVulkan::raytrace(const VkCommandBuffer& cmdBuf, const nvmath::vec4f& c
0, sizeof(PushConstantRay), &m_pcRay);
// Size of a program identifier
uint32_t groupSize = nvh::align_up(m_rtProperties.shaderGroupHandleSize, m_rtProperties.shaderGroupBaseAlignment);
uint32_t groupStride = groupSize;
vkCmdTraceRaysKHR(cmdBuf, &m_rgenRegion, &m_missRegion, &m_hitRegion, &m_callRegion, m_size.width, m_size.height, 1);
VkDeviceAddress sbtAddress = nvvk::getBufferDeviceAddress(m_device, m_rtSBTBuffer.buffer);
using Stride = VkStridedDeviceAddressRegionKHR;
std::array<Stride, 4> strideAddresses{Stride{sbtAddress + 0u * groupSize, groupStride, groupSize * 1}, // raygen
Stride{sbtAddress + 1u * groupSize, groupStride, groupSize * 3}, // miss
Stride{sbtAddress + 4u * groupSize, groupStride, groupSize * 4}, // hit
Stride{0u, 0u, 0u}}; // callable
// First pass, illuminate scene with global light.
vkCmdTraceRaysKHR(cmdBuf, &strideAddresses[0], &strideAddresses[1], &strideAddresses[2], &strideAddresses[3],
m_size.width, m_size.height, 1);
// Lantern passes, ensure previous pass completed, then add light contribution from each lantern.
for(int i = 0; i < static_cast<int>(m_lanternCount); ++i)
@ -1372,9 +1388,7 @@ void HelloVulkan::raytrace(const VkCommandBuffer& cmdBuf, const nvmath::vec4f& c
nvvk::getBufferDeviceAddress(m_device, m_lanternIndirectBuffer.buffer) + i * sizeof(LanternIndirectEntry);
// Execute lantern pass.
vkCmdTraceRaysIndirectKHR(cmdBuf, &strideAddresses[0], &strideAddresses[1], //
&strideAddresses[2], &strideAddresses[3], //
indirectDeviceAddress);
vkCmdTraceRaysIndirectKHR(cmdBuf, &m_rgenRegion, &m_missRegion, &m_hitRegion, &m_callRegion, indirectDeviceAddress);
}
m_debug.endLabel(cmdBuf);