Updating the creation of the SBT

This commit is contained in:
mklefrancois 2021-09-08 15:59:03 +02:00
parent f476512440
commit 10637464fc
17 changed files with 477 additions and 444 deletions

View file

@ -384,13 +384,16 @@ void HelloVulkan::destroyResources()
// #VKRay
#ifdef USE_SBT_WRAPPER
m_sbtWrapper.destroy();
#else
m_alloc.destroy(m_rtSBTBuffer);
#endif
m_rtBuilder.destroy();
vkDestroyPipeline(m_device, m_rtPipeline, nullptr);
vkDestroyPipelineLayout(m_device, m_rtPipelineLayout, nullptr);
vkDestroyDescriptorPool(m_device, m_rtDescPool, nullptr);
vkDestroyDescriptorSetLayout(m_device, m_rtDescSetLayout, nullptr);
m_alloc.destroy(m_rtSBTBuffer);
m_alloc.deinit();
}
@ -595,7 +598,9 @@ void HelloVulkan::initRayTracing()
m_rtBuilder.setup(m_device, &m_alloc, m_graphicsQueueIndex);
#ifdef USE_SBT_WRAPPER
m_sbtWrapper.setup(m_device, m_graphicsQueueIndex, &m_alloc, m_rtProperties);
#endif
}
//--------------------------------------------------------------------------------------------------
@ -848,11 +853,15 @@ void HelloVulkan::createRtPipeline()
vkCreateRayTracingPipelinesKHR(m_device, {}, {}, 1, &rayPipelineInfo, nullptr, &m_rtPipeline);
#ifdef USE_SBT_WRAPPER
// Find handle indices and add data
m_sbtWrapper.addIndices(rayPipelineInfo);
m_sbtWrapper.addData(SBTWrapper::eHit, 1, m_hitShaderRecord[0]);
m_sbtWrapper.addData(SBTWrapper::eHit, 2, m_hitShaderRecord[1]);
m_sbtWrapper.addData(nvvk::SBTWrapper::eHit, 1, m_hitShaderRecord[0]);
m_sbtWrapper.addData(nvvk::SBTWrapper::eHit, 2, m_hitShaderRecord[1]);
m_sbtWrapper.create(m_rtPipeline);
#else
createRtShaderBindingTable();
#endif
for(auto& s : stages)
vkDestroyShaderModule(m_device, s.module, nullptr);
@ -864,78 +873,84 @@ void HelloVulkan::createRtPipeline()
// - Besides exception, this could be always done like this
// See how the SBT buffer is used in run()
//
#ifndef USE_SBT_WRAPPER
void HelloVulkan::createRtShaderBindingTable()
{
auto groupCount = static_cast<uint32_t>(m_rtShaderGroups.size()); // 4 shaders: raygen, 2 miss, chit
uint32_t groupHandleSize = m_rtProperties.shaderGroupHandleSize; // Size of a program identifier
// Compute the actual size needed per SBT entry (round-up to alignment needed).
uint32_t groupSizeAligned = nvh::align_up(groupHandleSize, m_rtProperties.shaderGroupBaseAlignment);
// Bytes needed for the SBT.
uint32_t sbtSize = groupCount * groupSizeAligned;
// Fetch all the shader handles used in the pipeline. This is opaque data,/ so we store it in a vector of bytes.
// The order of handles follow the stage entry.
std::vector<uint8_t> shaderHandleStorage(sbtSize);
auto result = vkGetRayTracingShaderGroupHandlesKHR(m_device, m_rtPipeline, 0, groupCount, sbtSize, shaderHandleStorage.data());
uint32_t missCount{2};
uint32_t hitCount{3};
auto handleCount = 1 + missCount + hitCount;
uint32_t handleSize = m_rtProperties.shaderGroupHandleSize;
// Get the shader group handles
uint32_t dataSize = handleCount * handleSize;
std::vector<uint8_t> handles(dataSize);
auto result = vkGetRayTracingShaderGroupHandlesKHR(m_device, m_rtPipeline, 0, handleCount, dataSize, handles.data());
assert(result == VK_SUCCESS);
// Retrieve the handle pointers
std::vector<uint8_t*> handles(groupCount);
for(uint32_t i = 0; i < groupCount; i++)
// The SBT (buffer) need to have starting groups to be aligned and handles in the group to be aligned.
uint32_t handleSizeAligned = nvh::align_up(handleSize, m_rtProperties.shaderGroupHandleAlignment);
m_rgenRegion.stride = nvh::align_up(handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);
m_rgenRegion.size = m_rgenRegion.stride; // The size member of pRayGenShaderBindingTable must be equal to its stride member
m_missRegion.stride = handleSizeAligned;
m_missRegion.size = nvh::align_up(missCount * handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);
m_hitRegion.stride = nvh::align_up(handleSize + sizeof(HitRecordBuffer), m_rtProperties.shaderGroupHandleAlignment);
m_hitRegion.size = nvh::align_up(hitCount * handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);
// Allocate a buffer for storing the SBT.
VkDeviceSize sbtSize = m_rgenRegion.size + m_missRegion.size + m_hitRegion.size + m_callRegion.size;
m_rtSBTBuffer = m_alloc.createBuffer(sbtSize,
VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT
| VK_BUFFER_USAGE_SHADER_BINDING_TABLE_BIT_KHR,
VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
m_debug.setObjectName(m_rtSBTBuffer.buffer, std::string("SBT")); // Give it a debug name for NSight.
// Find the SBT addresses of each group
VkBufferDeviceAddressInfo info{VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO, nullptr, m_rtSBTBuffer.buffer};
VkDeviceAddress sbtAddress = vkGetBufferDeviceAddress(m_device, &info);
m_rgenRegion.deviceAddress = sbtAddress;
m_missRegion.deviceAddress = sbtAddress + m_rgenRegion.size;
m_hitRegion.deviceAddress = sbtAddress + m_rgenRegion.size + m_missRegion.size;
// Helper to retrieve the handle data
auto getHandle = [&](int i) { return handles.data() + i * handleSize; };
// Map the SBT buffer and write in the handles.
auto* pSBTBuffer = reinterpret_cast<uint8_t*>(m_alloc.map(m_rtSBTBuffer));
uint8_t* pData{nullptr};
uint32_t handleIdx{0};
// Raygen
pData = pSBTBuffer;
memcpy(pData, getHandle(handleIdx++), handleSize);
// Miss
pData = pSBTBuffer + m_rgenRegion.size;
for(uint32_t c = 0; c < missCount; c++)
{
handles[i] = &shaderHandleStorage[i * groupHandleSize];
memcpy(pData, getHandle(handleIdx++), handleSize);
pData += m_missRegion.stride;
}
// Hit
pData = pSBTBuffer + m_rgenRegion.size + m_missRegion.size;
memcpy(pData, getHandle(handleIdx++), handleSize);
// Sizes
uint32_t rayGenSize = groupSizeAligned;
uint32_t missSize = groupSizeAligned;
uint32_t hitSize = nvh::align_up(groupHandleSize + static_cast<int>(sizeof(HitRecordBuffer)), groupSizeAligned);
uint32_t newSbtSize = rayGenSize + 2 * missSize + 3 * hitSize;
// hit 1
pData = pSBTBuffer + m_rgenRegion.size + m_missRegion.size + m_hitRegion.stride;
memcpy(pData, getHandle(handleIdx++), handleSize);
pData += handleSize;
memcpy(pData, &m_hitShaderRecord[0], sizeof(HitRecordBuffer)); // Hit 1 data
std::vector<uint8_t> sbtBuffer(newSbtSize);
{
uint8_t* pBuffer = sbtBuffer.data();
memcpy(pBuffer, handles[0], groupHandleSize); // Raygen
pBuffer += rayGenSize;
memcpy(pBuffer, handles[1], groupHandleSize); // Miss 0
pBuffer += missSize;
memcpy(pBuffer, handles[2], groupHandleSize); // Miss 1
pBuffer += missSize;
uint8_t* pHitBuffer = pBuffer;
memcpy(pHitBuffer, handles[3], groupHandleSize); // Hit 0
// No data
pBuffer += hitSize;
pHitBuffer = pBuffer;
memcpy(pHitBuffer, handles[4], groupHandleSize); // Hit 1
pHitBuffer += groupHandleSize;
memcpy(pHitBuffer, &m_hitShaderRecord[0], sizeof(HitRecordBuffer)); // Hit 1 data
pBuffer += hitSize;
pHitBuffer = pBuffer;
memcpy(pHitBuffer, handles[4], groupHandleSize); // Hit 2
pHitBuffer += groupHandleSize;
memcpy(pHitBuffer, &m_hitShaderRecord[1], sizeof(HitRecordBuffer)); // Hit 2 data
// pBuffer += hitSize;
}
// Write the handles in the SBT
nvvk::CommandPool genCmdBuf(m_device, m_graphicsQueueIndex);
VkCommandBuffer cmdBuf = genCmdBuf.createCommandBuffer();
m_rtSBTBuffer = m_alloc.createBuffer(cmdBuf, sbtBuffer,
VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | VK_BUFFER_USAGE_SHADER_BINDING_TABLE_BIT_KHR);
m_debug.setObjectName(m_rtSBTBuffer.buffer, "SBT");
// hit 2
pData = pSBTBuffer + m_rgenRegion.size + m_missRegion.size + (2 * m_hitRegion.stride);
memcpy(pData, getHandle(handleIdx++), handleSize);
pData += handleSize;
memcpy(pData, &m_hitShaderRecord[1], sizeof(HitRecordBuffer)); // Hit 2 data
genCmdBuf.submitAndWait(cmdBuf);
m_alloc.unmap(m_rtSBTBuffer);
m_alloc.finalizeAndReleaseStaging();
}
#endif
//--------------------------------------------------------------------------------------------------
// Ray Tracing the scene
@ -959,39 +974,12 @@ void HelloVulkan::raytrace(const VkCommandBuffer& cmdBuf, const nvmath::vec4f& c
0, sizeof(PushConstantRay), &m_pcRay);
//// Size of a program identifier
//uint32_t groupSize =
// nvh::align_up(m_rtProperties.shaderGroupHandleSize, m_rtProperties.shaderGroupBaseAlignment);
//uint32_t groupStride = groupSize;
//vk::DeviceSize hitGroupSize =
// nvh::align_up(m_rtProperties.shaderGroupHandleSize + sizeof(HitRecordBuffer),
// m_rtProperties.shaderGroupBaseAlignment);
//vk::DeviceAddress sbtAddress = m_device.getBufferAddress({m_rtSBTBuffer.buffer});
//using Stride = vk::StridedDeviceAddressRegionKHR;
//std::array<Stride, 4> strideAddresses{
// Stride{sbtAddress + 0u * groupSize, groupStride, groupSize * 1}, // raygen
// Stride{sbtAddress + 1u * groupSize, groupStride, groupSize * 2}, // miss
// Stride{sbtAddress + 3u * groupSize, hitGroupSize, hitGroupSize * 2}, // hit
// Stride{0u, 0u, 0u}}; // callable
//strideAddresses[0] = m_sbtWrapper.getRaygenRegion();
//strideAddresses[1] = m_sbtWrapper.getMissRegion();
//strideAddresses[2] = m_sbtWrapper.getHitRegion();
//cmdBuf.traceRaysKHR(&strideAddresses[0], &strideAddresses[1], &strideAddresses[2],
// &strideAddresses[3], //
// m_size.width, m_size.height, 1); //
//std::array<vk::StridedDeviceAddressRegionKHR, 4> regions;
//regions[0] = m_sbtWrapper.getRegion(SBTWrapper::eRaygen);
//regions[1] = m_sbtWrapper.getRegion(SBTWrapper::eMiss);
//regions[2] = m_sbtWrapper.getRegion(SBTWrapper::eHit);
//regions[3] = m_sbtWrapper.getRegion(SBTWrapper::eCallable);
#ifdef USE_SBT_WRAPPER
auto& regions = m_sbtWrapper.getRegions();
vkCmdTraceRaysKHR(cmdBuf, &regions[0], &regions[1], &regions[2], &regions[3], m_size.width, m_size.height, 1);
#else
vkCmdTraceRaysKHR(cmdBuf, &m_rgenRegion, &m_missRegion, &m_hitRegion, &m_callRegion, m_size.width, m_size.height, 1);
#endif
m_debug.endLabel(cmdBuf);
}