Updating the creation of the SBT
This commit is contained in:
parent
f476512440
commit
10637464fc
17 changed files with 477 additions and 444 deletions
|
|
@ -384,13 +384,16 @@ void HelloVulkan::destroyResources()
|
|||
|
||||
|
||||
// #VKRay
|
||||
#ifdef USE_SBT_WRAPPER
|
||||
m_sbtWrapper.destroy();
|
||||
#else
|
||||
m_alloc.destroy(m_rtSBTBuffer);
|
||||
#endif
|
||||
m_rtBuilder.destroy();
|
||||
vkDestroyPipeline(m_device, m_rtPipeline, nullptr);
|
||||
vkDestroyPipelineLayout(m_device, m_rtPipelineLayout, nullptr);
|
||||
vkDestroyDescriptorPool(m_device, m_rtDescPool, nullptr);
|
||||
vkDestroyDescriptorSetLayout(m_device, m_rtDescSetLayout, nullptr);
|
||||
m_alloc.destroy(m_rtSBTBuffer);
|
||||
|
||||
m_alloc.deinit();
|
||||
}
|
||||
|
|
@ -595,7 +598,9 @@ void HelloVulkan::initRayTracing()
|
|||
|
||||
m_rtBuilder.setup(m_device, &m_alloc, m_graphicsQueueIndex);
|
||||
|
||||
#ifdef USE_SBT_WRAPPER
|
||||
m_sbtWrapper.setup(m_device, m_graphicsQueueIndex, &m_alloc, m_rtProperties);
|
||||
#endif
|
||||
}
|
||||
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
|
|
@ -848,11 +853,15 @@ void HelloVulkan::createRtPipeline()
|
|||
vkCreateRayTracingPipelinesKHR(m_device, {}, {}, 1, &rayPipelineInfo, nullptr, &m_rtPipeline);
|
||||
|
||||
|
||||
#ifdef USE_SBT_WRAPPER
|
||||
// Find handle indices and add data
|
||||
m_sbtWrapper.addIndices(rayPipelineInfo);
|
||||
m_sbtWrapper.addData(SBTWrapper::eHit, 1, m_hitShaderRecord[0]);
|
||||
m_sbtWrapper.addData(SBTWrapper::eHit, 2, m_hitShaderRecord[1]);
|
||||
m_sbtWrapper.addData(nvvk::SBTWrapper::eHit, 1, m_hitShaderRecord[0]);
|
||||
m_sbtWrapper.addData(nvvk::SBTWrapper::eHit, 2, m_hitShaderRecord[1]);
|
||||
m_sbtWrapper.create(m_rtPipeline);
|
||||
#else
|
||||
createRtShaderBindingTable();
|
||||
#endif
|
||||
|
||||
for(auto& s : stages)
|
||||
vkDestroyShaderModule(m_device, s.module, nullptr);
|
||||
|
|
@ -864,78 +873,84 @@ void HelloVulkan::createRtPipeline()
|
|||
// - Besides exception, this could be always done like this
|
||||
// See how the SBT buffer is used in run()
|
||||
//
|
||||
#ifndef USE_SBT_WRAPPER
|
||||
void HelloVulkan::createRtShaderBindingTable()
|
||||
{
|
||||
auto groupCount = static_cast<uint32_t>(m_rtShaderGroups.size()); // 4 shaders: raygen, 2 miss, chit
|
||||
uint32_t groupHandleSize = m_rtProperties.shaderGroupHandleSize; // Size of a program identifier
|
||||
// Compute the actual size needed per SBT entry (round-up to alignment needed).
|
||||
uint32_t groupSizeAligned = nvh::align_up(groupHandleSize, m_rtProperties.shaderGroupBaseAlignment);
|
||||
// Bytes needed for the SBT.
|
||||
uint32_t sbtSize = groupCount * groupSizeAligned;
|
||||
|
||||
// Fetch all the shader handles used in the pipeline. This is opaque data, so we store it in a vector of bytes.
|
||||
// The order of the handles follows the order of the stage entries.
|
||||
std::vector<uint8_t> shaderHandleStorage(sbtSize);
|
||||
auto result = vkGetRayTracingShaderGroupHandlesKHR(m_device, m_rtPipeline, 0, groupCount, sbtSize, shaderHandleStorage.data());
|
||||
uint32_t missCount{2};
|
||||
uint32_t hitCount{3};
|
||||
auto handleCount = 1 + missCount + hitCount;
|
||||
uint32_t handleSize = m_rtProperties.shaderGroupHandleSize;
|
||||
|
||||
// Get the shader group handles
|
||||
uint32_t dataSize = handleCount * handleSize;
|
||||
std::vector<uint8_t> handles(dataSize);
|
||||
auto result = vkGetRayTracingShaderGroupHandlesKHR(m_device, m_rtPipeline, 0, handleCount, dataSize, handles.data());
|
||||
assert(result == VK_SUCCESS);
|
||||
|
||||
// Retrieve the handle pointers
|
||||
std::vector<uint8_t*> handles(groupCount);
|
||||
for(uint32_t i = 0; i < groupCount; i++)
|
||||
// The SBT (buffer) needs the start of each group to be aligned, and the handles within each group to be aligned as well.
|
||||
uint32_t handleSizeAligned = nvh::align_up(handleSize, m_rtProperties.shaderGroupHandleAlignment);
|
||||
|
||||
m_rgenRegion.stride = nvh::align_up(handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);
|
||||
m_rgenRegion.size = m_rgenRegion.stride; // The size member of pRayGenShaderBindingTable must be equal to its stride member
|
||||
m_missRegion.stride = handleSizeAligned;
|
||||
m_missRegion.size = nvh::align_up(missCount * handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);
|
||||
m_hitRegion.stride = nvh::align_up(handleSize + sizeof(HitRecordBuffer), m_rtProperties.shaderGroupHandleAlignment);
|
||||
m_hitRegion.size = nvh::align_up(hitCount * handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);
|
||||
|
||||
// Allocate a buffer for storing the SBT.
|
||||
VkDeviceSize sbtSize = m_rgenRegion.size + m_missRegion.size + m_hitRegion.size + m_callRegion.size;
|
||||
m_rtSBTBuffer = m_alloc.createBuffer(sbtSize,
|
||||
VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT
|
||||
| VK_BUFFER_USAGE_SHADER_BINDING_TABLE_BIT_KHR,
|
||||
VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
|
||||
m_debug.setObjectName(m_rtSBTBuffer.buffer, std::string("SBT")); // Give it a debug name for NSight.
|
||||
|
||||
// Find the SBT addresses of each group
|
||||
VkBufferDeviceAddressInfo info{VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO, nullptr, m_rtSBTBuffer.buffer};
|
||||
VkDeviceAddress sbtAddress = vkGetBufferDeviceAddress(m_device, &info);
|
||||
m_rgenRegion.deviceAddress = sbtAddress;
|
||||
m_missRegion.deviceAddress = sbtAddress + m_rgenRegion.size;
|
||||
m_hitRegion.deviceAddress = sbtAddress + m_rgenRegion.size + m_missRegion.size;
|
||||
|
||||
// Helper to retrieve the handle data
|
||||
auto getHandle = [&](int i) { return handles.data() + i * handleSize; };
|
||||
|
||||
// Map the SBT buffer and write in the handles.
|
||||
auto* pSBTBuffer = reinterpret_cast<uint8_t*>(m_alloc.map(m_rtSBTBuffer));
|
||||
uint8_t* pData{nullptr};
|
||||
uint32_t handleIdx{0};
|
||||
// Raygen
|
||||
pData = pSBTBuffer;
|
||||
memcpy(pData, getHandle(handleIdx++), handleSize);
|
||||
// Miss
|
||||
pData = pSBTBuffer + m_rgenRegion.size;
|
||||
for(uint32_t c = 0; c < missCount; c++)
|
||||
{
|
||||
handles[i] = &shaderHandleStorage[i * groupHandleSize];
|
||||
memcpy(pData, getHandle(handleIdx++), handleSize);
|
||||
pData += m_missRegion.stride;
|
||||
}
|
||||
// Hit
|
||||
pData = pSBTBuffer + m_rgenRegion.size + m_missRegion.size;
|
||||
memcpy(pData, getHandle(handleIdx++), handleSize);
|
||||
|
||||
// Sizes
|
||||
uint32_t rayGenSize = groupSizeAligned;
|
||||
uint32_t missSize = groupSizeAligned;
|
||||
uint32_t hitSize = nvh::align_up(groupHandleSize + static_cast<int>(sizeof(HitRecordBuffer)), groupSizeAligned);
|
||||
uint32_t newSbtSize = rayGenSize + 2 * missSize + 3 * hitSize;
|
||||
// hit 1
|
||||
pData = pSBTBuffer + m_rgenRegion.size + m_missRegion.size + m_hitRegion.stride;
|
||||
memcpy(pData, getHandle(handleIdx++), handleSize);
|
||||
pData += handleSize;
|
||||
memcpy(pData, &m_hitShaderRecord[0], sizeof(HitRecordBuffer)); // Hit 1 data
|
||||
|
||||
std::vector<uint8_t> sbtBuffer(newSbtSize);
|
||||
{
|
||||
uint8_t* pBuffer = sbtBuffer.data();
|
||||
|
||||
memcpy(pBuffer, handles[0], groupHandleSize); // Raygen
|
||||
pBuffer += rayGenSize;
|
||||
memcpy(pBuffer, handles[1], groupHandleSize); // Miss 0
|
||||
pBuffer += missSize;
|
||||
memcpy(pBuffer, handles[2], groupHandleSize); // Miss 1
|
||||
pBuffer += missSize;
|
||||
|
||||
uint8_t* pHitBuffer = pBuffer;
|
||||
memcpy(pHitBuffer, handles[3], groupHandleSize); // Hit 0
|
||||
// No data
|
||||
pBuffer += hitSize;
|
||||
|
||||
pHitBuffer = pBuffer;
|
||||
memcpy(pHitBuffer, handles[4], groupHandleSize); // Hit 1
|
||||
pHitBuffer += groupHandleSize;
|
||||
memcpy(pHitBuffer, &m_hitShaderRecord[0], sizeof(HitRecordBuffer)); // Hit 1 data
|
||||
pBuffer += hitSize;
|
||||
|
||||
pHitBuffer = pBuffer;
|
||||
memcpy(pHitBuffer, handles[4], groupHandleSize); // Hit 2
|
||||
pHitBuffer += groupHandleSize;
|
||||
memcpy(pHitBuffer, &m_hitShaderRecord[1], sizeof(HitRecordBuffer)); // Hit 2 data
|
||||
// pBuffer += hitSize;
|
||||
}
|
||||
|
||||
// Write the handles in the SBT
|
||||
nvvk::CommandPool genCmdBuf(m_device, m_graphicsQueueIndex);
|
||||
VkCommandBuffer cmdBuf = genCmdBuf.createCommandBuffer();
|
||||
|
||||
m_rtSBTBuffer = m_alloc.createBuffer(cmdBuf, sbtBuffer,
|
||||
VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | VK_BUFFER_USAGE_SHADER_BINDING_TABLE_BIT_KHR);
|
||||
|
||||
m_debug.setObjectName(m_rtSBTBuffer.buffer, "SBT");
|
||||
// hit 2
|
||||
pData = pSBTBuffer + m_rgenRegion.size + m_missRegion.size + (2 * m_hitRegion.stride);
|
||||
memcpy(pData, getHandle(handleIdx++), handleSize);
|
||||
pData += handleSize;
|
||||
memcpy(pData, &m_hitShaderRecord[1], sizeof(HitRecordBuffer)); // Hit 2 data
|
||||
|
||||
|
||||
genCmdBuf.submitAndWait(cmdBuf);
|
||||
|
||||
m_alloc.unmap(m_rtSBTBuffer);
|
||||
m_alloc.finalizeAndReleaseStaging();
|
||||
}
|
||||
#endif
|
||||
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
// Ray Tracing the scene
|
||||
|
|
@ -959,39 +974,12 @@ void HelloVulkan::raytrace(const VkCommandBuffer& cmdBuf, const nvmath::vec4f& c
|
|||
0, sizeof(PushConstantRay), &m_pcRay);
|
||||
|
||||
|
||||
//// Size of a program identifier
|
||||
//uint32_t groupSize =
|
||||
// nvh::align_up(m_rtProperties.shaderGroupHandleSize, m_rtProperties.shaderGroupBaseAlignment);
|
||||
//uint32_t groupStride = groupSize;
|
||||
//vk::DeviceSize hitGroupSize =
|
||||
// nvh::align_up(m_rtProperties.shaderGroupHandleSize + sizeof(HitRecordBuffer),
|
||||
// m_rtProperties.shaderGroupBaseAlignment);
|
||||
//vk::DeviceAddress sbtAddress = m_device.getBufferAddress({m_rtSBTBuffer.buffer});
|
||||
|
||||
//using Stride = vk::StridedDeviceAddressRegionKHR;
|
||||
//std::array<Stride, 4> strideAddresses{
|
||||
// Stride{sbtAddress + 0u * groupSize, groupStride, groupSize * 1}, // raygen
|
||||
// Stride{sbtAddress + 1u * groupSize, groupStride, groupSize * 2}, // miss
|
||||
// Stride{sbtAddress + 3u * groupSize, hitGroupSize, hitGroupSize * 2}, // hit
|
||||
// Stride{0u, 0u, 0u}}; // callable
|
||||
|
||||
//strideAddresses[0] = m_sbtWrapper.getRaygenRegion();
|
||||
//strideAddresses[1] = m_sbtWrapper.getMissRegion();
|
||||
//strideAddresses[2] = m_sbtWrapper.getHitRegion();
|
||||
|
||||
//cmdBuf.traceRaysKHR(&strideAddresses[0], &strideAddresses[1], &strideAddresses[2],
|
||||
// &strideAddresses[3], //
|
||||
// m_size.width, m_size.height, 1); //
|
||||
|
||||
//std::array<vk::StridedDeviceAddressRegionKHR, 4> regions;
|
||||
//regions[0] = m_sbtWrapper.getRegion(SBTWrapper::eRaygen);
|
||||
//regions[1] = m_sbtWrapper.getRegion(SBTWrapper::eMiss);
|
||||
//regions[2] = m_sbtWrapper.getRegion(SBTWrapper::eHit);
|
||||
//regions[3] = m_sbtWrapper.getRegion(SBTWrapper::eCallable);
|
||||
|
||||
#ifdef USE_SBT_WRAPPER
|
||||
auto& regions = m_sbtWrapper.getRegions();
|
||||
vkCmdTraceRaysKHR(cmdBuf, &regions[0], &regions[1], &regions[2], &regions[3], m_size.width, m_size.height, 1);
|
||||
|
||||
#else
|
||||
vkCmdTraceRaysKHR(cmdBuf, &m_rgenRegion, &m_missRegion, &m_hitRegion, &m_callRegion, m_size.width, m_size.height, 1);
|
||||
#endif
|
||||
|
||||
m_debug.endLabel(cmdBuf);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue