Skip to content

Commit

Permalink
updating samples to use new push constants API
Browse files Browse the repository at this point in the history
  • Loading branch information
natevm committed Oct 8, 2023
1 parent 29b8973 commit 8bb9e20
Show file tree
Hide file tree
Showing 38 changed files with 466 additions and 401 deletions.
58 changes: 32 additions & 26 deletions gprt/gprt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8531,7 +8531,7 @@ gprtGeomSetParameters(GPRTGeom _geometry, void *parameters, int deviceID) {

void
gprtGeomTypeRasterize(GPRTContext _context, GPRTGeomType _geomType, uint32_t numGeometry, GPRTGeom *_geometry,
uint32_t rasterType, uint32_t *instanceCounts) {
uint32_t rasterType, uint32_t *instanceCounts, size_t pushConstantsSize, void* pushConstants) {
LOG_API_CALL();

Context *context = (Context *) _context;
Expand Down Expand Up @@ -8592,6 +8592,12 @@ gprtGeomTypeRasterize(GPRTContext _context, GPRTGeomType _geomType, uint32_t num

err = vkBeginCommandBuffer(context->graphicsCommandBuffer, &cmdBufInfo);

if (pushConstantsSize > 0) {
if (pushConstantsSize > 128) LOG_ERROR("Push constants size exceeds maximum 128 byte limit!");
vkCmdPushConstants(context->graphicsCommandBuffer, geometryType->raster[rasterType].pipelineLayout,
VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT, 0, pushConstantsSize, pushConstants);
}

// Transition our attachments into optimal attachment formats
geometryType->raster[rasterType].colorAttachment->setImageLayout(
context->graphicsCommandBuffer, geometryType->raster[rasterType].colorAttachment->image,
Expand Down Expand Up @@ -9861,19 +9867,19 @@ gprtBuildShaderBindingTable(GPRTContext _context, GPRTBuildSBTFlags flags) {
}

GPRT_API void
gprtRayGenLaunch1D(GPRTContext _context, GPRTRayGen _rayGen, uint32_t dims_x) {
gprtRayGenLaunch1D(GPRTContext _context, GPRTRayGen _rayGen, uint32_t dims_x, size_t pushConstantsSize, void* pushConstants) {
LOG_API_CALL();
gprtRayGenLaunch2D(_context, _rayGen, dims_x, 1);
gprtRayGenLaunch2D(_context, _rayGen, dims_x, 1, pushConstantsSize, pushConstants);
}

GPRT_API void
gprtRayGenLaunch2D(GPRTContext _context, GPRTRayGen _rayGen, uint32_t dims_x, uint32_t dims_y) {
gprtRayGenLaunch2D(GPRTContext _context, GPRTRayGen _rayGen, uint32_t dims_x, uint32_t dims_y, size_t pushConstantsSize, void* pushConstants) {
LOG_API_CALL();
gprtRayGenLaunch3D(_context, _rayGen, dims_x, dims_y, 1);
gprtRayGenLaunch3D(_context, _rayGen, dims_x, dims_y, 1, pushConstantsSize, pushConstants);
}

GPRT_API void
gprtRayGenLaunch3D(GPRTContext _context, GPRTRayGen _rayGen, uint32_t dims_x, uint32_t dims_y, uint32_t dims_z) {
gprtRayGenLaunch3D(GPRTContext _context, GPRTRayGen _rayGen, uint32_t dims_x, uint32_t dims_y, uint32_t dims_z, size_t pushConstantsSize, void* pushConstants) {
LOG_API_CALL();
assert(_rayGen);

Expand All @@ -9900,14 +9906,14 @@ gprtRayGenLaunch3D(GPRTContext _context, GPRTRayGen _rayGen, uint32_t dims_x, ui
vkCmdBindPipeline(context->graphicsCommandBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR,
context->raytracingPipeline);

struct PushConstants {
uint64_t pad[16] = {requestedFeatures.numRayTypes, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
} pushConstants;
vkCmdPushConstants(context->graphicsCommandBuffer, context->raytracingPipelineLayout,
VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_ANY_HIT_BIT_KHR |
VK_SHADER_STAGE_INTERSECTION_BIT_KHR | VK_SHADER_STAGE_MISS_BIT_KHR |
VK_SHADER_STAGE_RAYGEN_BIT_KHR,
0, sizeof(PushConstants), &pushConstants);
if (pushConstantsSize > 0) {
if (pushConstantsSize > 128) LOG_ERROR("Push constants size exceeds maximum 128 byte limit!");
vkCmdPushConstants(context->graphicsCommandBuffer, context->raytracingPipelineLayout,
VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_ANY_HIT_BIT_KHR |
VK_SHADER_STAGE_INTERSECTION_BIT_KHR | VK_SHADER_STAGE_MISS_BIT_KHR |
VK_SHADER_STAGE_RAYGEN_BIT_KHR,
0, pushConstantsSize, pushConstants);
}

auto getBufferDeviceAddress = [](VkDevice device, VkBuffer buffer) -> uint64_t {
VkBufferDeviceAddressInfoKHR bufferDeviceAI{};
Expand Down Expand Up @@ -10000,7 +10006,7 @@ gprtRayGenLaunch3D(GPRTContext _context, GPRTRayGen _rayGen, uint32_t dims_x, ui
}

void
gprtComputeLaunch(GPRTContext _context, GPRTCompute _compute, uint32_t dims_x, uint32_t dims_y, uint32_t dims_z) {
gprtComputeLaunch(GPRTContext _context, GPRTCompute _compute, uint32_t dims_x, uint32_t dims_y, uint32_t dims_z, size_t pushConstantsSize, void* pushConstants) {
assert(_compute);

Context *context = (Context *) _context;
Expand Down Expand Up @@ -10054,11 +10060,11 @@ gprtComputeLaunch(GPRTContext _context, GPRTCompute _compute, uint32_t dims_x, u
std::to_string(context->deviceProperties.limits.maxComputeWorkGroupCount[2]) + ")\n");
}

struct PushConstants {
uint64_t pad[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
} pushConstants;
vkCmdPushConstants(context->graphicsCommandBuffer, compute->pipelineLayout, VK_SHADER_STAGE_COMPUTE_BIT, 0,
sizeof(PushConstants), &pushConstants);
if (pushConstantsSize > 0) {
if (pushConstantsSize > 128) LOG_ERROR("Push constants size exceeds maximum 128 byte limit!");
vkCmdPushConstants(context->graphicsCommandBuffer, compute->pipelineLayout,
VK_SHADER_STAGE_COMPUTE_BIT, 0, pushConstantsSize, pushConstants);
}

vkCmdDispatch(context->graphicsCommandBuffer, dims_x, dims_y, dims_z);

Expand Down Expand Up @@ -10090,21 +10096,21 @@ gprtComputeLaunch(GPRTContext _context, GPRTCompute _compute, uint32_t dims_x, u
}

GPRT_API void
gprtComputeLaunch1D(GPRTContext _context, GPRTCompute _compute, uint32_t dims_x) {
gprtComputeLaunch1D(GPRTContext _context, GPRTCompute _compute, uint32_t dims_x, size_t pushConstantsSize, void* pushConstants) {
LOG_API_CALL();
gprtComputeLaunch(_context, _compute, dims_x, 1, 1);
gprtComputeLaunch(_context, _compute, dims_x, 1, 1, pushConstantsSize, pushConstants);
}

GPRT_API void
gprtComputeLaunch2D(GPRTContext _context, GPRTCompute _compute, uint32_t dims_x, uint32_t dims_y) {
gprtComputeLaunch2D(GPRTContext _context, GPRTCompute _compute, uint32_t dims_x, uint32_t dims_y, size_t pushConstantsSize, void* pushConstants) {
LOG_API_CALL();
gprtComputeLaunch(_context, _compute, dims_x, dims_y, 1);
gprtComputeLaunch(_context, _compute, dims_x, dims_y, 1, pushConstantsSize, pushConstants);
}

GPRT_API void
gprtComputeLaunch3D(GPRTContext _context, GPRTCompute _compute, uint32_t dims_x, uint32_t dims_y, uint32_t dims_z) {
gprtComputeLaunch3D(GPRTContext _context, GPRTCompute _compute, uint32_t dims_x, uint32_t dims_y, uint32_t dims_z, size_t pushConstantsSize, void* pushConstants) {
LOG_API_CALL();
gprtComputeLaunch(_context, _compute, dims_x, dims_y, dims_z);
gprtComputeLaunch(_context, _compute, dims_x, dims_y, dims_z, pushConstantsSize, pushConstants);
}

GPRT_API void
Expand Down
134 changes: 107 additions & 27 deletions gprt/gprt_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,8 @@ gprtTrianglesSetIndices(GPRTGeomOf<T1> triangles, GPRTBufferOf<T2> indices, uint
for the given AABB geometry. This _has_ to be set before the accel(s)
that this geom is used in get built. */
GPRT_API void gprtAABBsSetPositions(GPRTGeom aabbs, GPRTBuffer positions, uint32_t count,
uint32_t stride GPRT_IF_CPP(= 2 * sizeof(float3)), uint32_t offset GPRT_IF_CPP(= 0));
uint32_t stride GPRT_IF_CPP(= 2 * sizeof(float3)),
uint32_t offset GPRT_IF_CPP(= 0));

template <typename T1, typename T2>
void
Expand Down Expand Up @@ -1198,16 +1199,41 @@ gprtGeomTypeSetRasterAttachments(GPRTGeomTypeOf<T1> type, int rasterType, GPRTTe
(GPRTTexture) depthAttachment);
}

/**
* @brief Rasterize a list of GPRT geometry. (Currently assuming all geometry are GPRT_TRIANGLES)
*
* @param context The GPRT context used to rasterize the triangles
* @param geomType The geometry type to fetch raster programs from
* @param numGeometry The number of GPRTGeoms to rasterize
* @param geometry A pointer to a list of GPRT geometry, to be rasterized in the order given
* @param rasterType Controls which rasterization programs to use. Analogous to "Ray Type", in
* that it indexes into the shader binding table.
* @param instanceCounts How many instances of each geometry in the "geometry" list to rasterize.
* Useful for rasterizing the same geometry in many different locations. If null pointer, this parameter is ignored.
* Otherwise, we expect a list of length "numGeometry"
* @param pushConstantsSize The size of the push constants structure to upload to the device.
* If 0, no push constants are updated. Currently limited to 128 bytes or less.
* @param pushConstants A pointer to a structure of push constants to upload to the device.
*/
void gprtGeomTypeRasterize(GPRTContext context, GPRTGeomType geomType, uint32_t numGeometry, GPRTGeom *geometry,
uint32_t rasterType = 0, uint32_t *instanceCounts = nullptr);
uint32_t rasterType, uint32_t *instanceCounts,
size_t pushConstantsSize GPRT_IF_CPP(= 0),
void *pushConstants GPRT_IF_CPP(= 0));

template <typename T>
template <typename RecordType>
void
gprtGeomTypeRasterize(GPRTContext context, GPRTGeomTypeOf<T> geomType, uint32_t numGeometry, GPRTGeomOf<T> *geometry,
uint32_t rayType = 0, uint32_t *instanceCounts = nullptr) {
gprtGeomTypeRasterize(GPRTContext context, GPRTGeomTypeOf<RecordType> geomType, uint32_t numGeometry, GPRTGeomOf<RecordType> *geometry,
uint32_t rayType, uint32_t *instanceCounts) {
gprtGeomTypeRasterize(context, (GPRTGeomType) geomType, numGeometry, (GPRTGeom *) geometry, rayType, instanceCounts);
}

template <typename RecordType, typename PushConstantsType>
void
gprtGeomTypeRasterize(GPRTContext context, GPRTGeomTypeOf<RecordType> geomType, uint32_t numGeometry, GPRTGeomOf<RecordType> *geometry,
uint32_t rayType, uint32_t *instanceCounts, PushConstantsType pc) {
gprtGeomTypeRasterize(context, (GPRTGeomType) geomType, numGeometry, (GPRTGeom *) geometry, rayType, instanceCounts, sizeof(PushConstantsType), &pc);
}

/**
* @brief Creates a sampler object for use in sampling textures. Behavior below
* defines how texture.SampleLevel and texture.SampleGrad operate. The "sampled
Expand Down Expand Up @@ -1593,9 +1619,9 @@ GPRT_API void gprtBufferTextureCopy(GPRTContext context, GPRTBuffer buffer, GPRT
template <typename T1, typename T2>
void
gprtBufferTextureCopy(GPRTContext context, GPRTBufferOf<T1> buffer, GPRTTextureOf<T2> texture, uint32_t bufferOffset,
uint32_t bufferRowLength, uint32_t bufferImageHeight, uint32_t imageOffsetX, uint32_t imageOffsetY,
uint32_t imageOffsetZ, uint32_t imageExtentX, uint32_t imageExtentY, uint32_t imageExtentZ,
int srcDeviceID GPRT_IF_CPP(= 0), int dstDeviceID GPRT_IF_CPP(= 0)) {
uint32_t bufferRowLength, uint32_t bufferImageHeight, uint32_t imageOffsetX,
uint32_t imageOffsetY, uint32_t imageOffsetZ, uint32_t imageExtentX, uint32_t imageExtentY,
uint32_t imageExtentZ, int srcDeviceID GPRT_IF_CPP(= 0), int dstDeviceID GPRT_IF_CPP(= 0)) {
gprtBufferTextureCopy(context, (GPRTBuffer) buffer, (GPRTTexture) texture, bufferOffset, bufferRowLength,
bufferImageHeight, imageOffsetX, imageOffsetY, imageOffsetZ, imageExtentX, imageExtentY,
imageExtentZ, srcDeviceID, dstDeviceID);
Expand Down Expand Up @@ -1702,59 +1728,113 @@ gprtBufferSaveImage(GPRTBufferOf<T> buffer, uint32_t width, uint32_t height, con
gprtBufferSaveImage((GPRTBuffer) buffer, width, height, imageName);
}

GPRT_API void gprtRayGenLaunch1D(GPRTContext context, GPRTRayGen rayGen, uint32_t dims_x);
GPRT_API void gprtRayGenLaunch1D(GPRTContext context, GPRTRayGen rayGen, uint32_t dims_x,
size_t pushConstantsSize GPRT_IF_CPP(= 0), void *pushConstants GPRT_IF_CPP(= 0));

template <typename T>
template <typename RecordType>
void
gprtRayGenLaunch1D(GPRTContext context, GPRTRayGenOf<T> rayGen, uint32_t dims_x) {
gprtRayGenLaunch1D(GPRTContext context, GPRTRayGenOf<RecordType> rayGen, uint32_t dims_x) {
gprtRayGenLaunch1D(context, (GPRTRayGen) rayGen, dims_x);
}

template <typename RecordType, typename PushConstantsType>
void
gprtRayGenLaunch1D(GPRTContext context, GPRTRayGenOf<RecordType> rayGen, uint32_t dims_x, PushConstantsType pushConstants) {
static_assert(sizeof(PushConstantsType) <= 128, "Current GPRT push constant size limited to 128 bytes or less");
gprtRayGenLaunch1D(context, (GPRTRayGen) rayGen, dims_x, sizeof(PushConstantsType), &pushConstants);
}

/*! Executes a ray tracing pipeline with the given raygen program.
This call will block until the raygen program returns. */
GPRT_API void gprtRayGenLaunch2D(GPRTContext context, GPRTRayGen rayGen, uint32_t dims_x, uint32_t dims_y);
GPRT_API void gprtRayGenLaunch2D(GPRTContext context, GPRTRayGen rayGen, uint32_t dims_x, uint32_t dims_y,
size_t pushConstantsSize GPRT_IF_CPP(= 0), void *pushConstants GPRT_IF_CPP(= 0));

template <typename T>
template <typename RecordType>
void
gprtRayGenLaunch2D(GPRTContext context, GPRTRayGenOf<T> rayGen, uint32_t dims_x, uint32_t dims_y) {
gprtRayGenLaunch2D(GPRTContext context, GPRTRayGenOf<RecordType> rayGen, uint32_t dims_x, uint32_t dims_y) {
gprtRayGenLaunch2D(context, (GPRTRayGen) rayGen, dims_x, dims_y);
}

template <typename RecordType, typename PushConstantsType>
void
gprtRayGenLaunch2D(GPRTContext context, GPRTRayGenOf<RecordType> rayGen, uint32_t dims_x, uint32_t dims_y, PushConstantsType pushConstants) {
static_assert(sizeof(PushConstantsType) <= 128, "Current GPRT push constant size limited to 128 bytes or less");
gprtRayGenLaunch2D(context, (GPRTRayGen) rayGen, dims_x, dims_y, sizeof(PushConstantsType), &pushConstants);
}

/*! 3D-launch variant of \see gprtRayGenLaunch2D */
GPRT_API void gprtRayGenLaunch3D(GPRTContext context, GPRTRayGen rayGen, uint32_t dims_x, uint32_t dims_y, uint32_t dims_z);
GPRT_API void gprtRayGenLaunch3D(GPRTContext context, GPRTRayGen rayGen, uint32_t dims_x, uint32_t dims_y,
uint32_t dims_z, size_t pushConstantsSize GPRT_IF_CPP(= 0),
void *pushConstants GPRT_IF_CPP(= 0));

template <typename T>
template <typename RecordType>
void
gprtRayGenLaunch3D(GPRTContext context, GPRTRayGenOf<T> rayGen, uint32_t dims_x, uint32_t dims_y, uint32_t dims_z) {
gprtRayGenLaunch3D(GPRTContext context, GPRTRayGenOf<RecordType> rayGen, uint32_t dims_x, uint32_t dims_y, uint32_t dims_z) {
gprtRayGenLaunch3D(context, (GPRTRayGen) rayGen, dims_x, dims_y, dims_z);
}

GPRT_API void gprtComputeLaunch1D(GPRTContext context, GPRTCompute compute, uint32_t x_workgroups);
template <typename RecordType, typename PushConstantsType>
void
gprtRayGenLaunch3D(GPRTContext context, GPRTRayGenOf<RecordType> rayGen, uint32_t dims_x, uint32_t dims_y, uint32_t dims_z, PushConstantsType pushConstants) {
static_assert(sizeof(PushConstantsType) <= 128, "Current GPRT push constant size limited to 128 bytes or less");
gprtRayGenLaunch3D(context, (GPRTRayGen) rayGen, dims_x, dims_y, dims_z, sizeof(PushConstantsType), &pushConstants);
}

template <typename T>
GPRT_API void gprtComputeLaunch1D(GPRTContext context, GPRTCompute compute, uint32_t x_workgroups,
size_t pushConstantsSize GPRT_IF_CPP(= 0),
void *pushConstants GPRT_IF_CPP(= 0));

template <typename RecordType>
void
gprtComputeLaunch1D(GPRTContext context, GPRTComputeOf<T> compute, uint32_t x_workgroups) {
gprtComputeLaunch1D(GPRTContext context, GPRTComputeOf<RecordType> compute, uint32_t x_workgroups) {
gprtComputeLaunch1D(context, (GPRTCompute) compute, x_workgroups);
}

GPRT_API void gprtComputeLaunch2D(GPRTContext context, GPRTCompute compute, uint32_t x_workgroups, uint32_t y_workgroups);
template <typename RecordType, typename PushConstantsType>
void
gprtComputeLaunch1D(GPRTContext context, GPRTComputeOf<RecordType> compute, uint32_t x_workgroups, PushConstantsType pushConstants) {
static_assert(sizeof(PushConstantsType) <= 128, "Current GPRT push constant size limited to 128 bytes or less");
gprtComputeLaunch1D(context, (GPRTCompute) compute, x_workgroups, sizeof(PushConstantsType), &pushConstants);
}

template <typename T>
GPRT_API void gprtComputeLaunch2D(GPRTContext context, GPRTCompute compute, uint32_t x_workgroups,
uint32_t y_workgroups,
size_t pushConstantsSize GPRT_IF_CPP(= 0),
void *pushConstants GPRT_IF_CPP(= 0));

template <typename RecordType>
void
gprtComputeLaunch2D(GPRTContext context, GPRTComputeOf<T> compute, uint32_t x_workgroups, uint32_t y_workgroups) {
gprtComputeLaunch2D(GPRTContext context, GPRTComputeOf<RecordType> compute, uint32_t x_workgroups, uint32_t y_workgroups) {
gprtComputeLaunch2D(context, (GPRTCompute) compute, x_workgroups, y_workgroups);
}

GPRT_API void gprtComputeLaunch3D(GPRTContext context, GPRTCompute compute, uint32_t x_workgroups, uint32_t y_workgroups,
uint32_t z_workgroups);
template <typename RecordType, typename PushConstantsType>
void
gprtComputeLaunch2D(GPRTContext context, GPRTComputeOf<RecordType> compute, uint32_t x_workgroups, uint32_t y_workgroups, PushConstantsType pushConstants) {
static_assert(sizeof(PushConstantsType) <= 128, "Current GPRT push constant size limited to 128 bytes or less");
gprtComputeLaunch2D(context, (GPRTCompute) compute, x_workgroups, y_workgroups, sizeof(PushConstantsType), &pushConstants);
}

template <typename T>
GPRT_API void gprtComputeLaunch3D(GPRTContext context, GPRTCompute compute, uint32_t x_workgroups,
uint32_t y_workgroups, uint32_t z_workgroups,
size_t pushConstantsSize GPRT_IF_CPP(= 0),
void *pushConstants GPRT_IF_CPP(= 0));

template <typename RecordType>
void
gprtComputeLaunch3D(GPRTContext context, GPRTComputeOf<T> compute, uint32_t x_workgroups, uint32_t y_workgroups,
gprtComputeLaunch3D(GPRTContext context, GPRTComputeOf<RecordType> compute, uint32_t x_workgroups, uint32_t y_workgroups,
uint32_t z_workgroups) {
gprtComputeLaunch3D(context, (GPRTCompute) compute, x_workgroups, y_workgroups, z_workgroups);
}

template <typename RecordType, typename PushConstantsType>
void
gprtComputeLaunch3D(GPRTContext context, GPRTComputeOf<RecordType> compute, uint32_t x_workgroups, uint32_t y_workgroups,
uint32_t z_workgroups, PushConstantsType pushConstants) {
static_assert(sizeof(PushConstantsType) <= 128, "Current GPRT push constant size limited to 128 bytes or less");
gprtComputeLaunch3D(context, (GPRTCompute) compute, x_workgroups, y_workgroups, z_workgroups, sizeof(PushConstantsType), &pushConstants);
}

GPRT_API void gprtBeginProfile(GPRTContext context);

// returned results are in milliseconds
Expand Down
6 changes: 4 additions & 2 deletions samples/s01-singleTriangle/deviceCode.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

#include "sharedCode.h"

[[vk::push_constant]] PushConstants pc;

struct [raypayload] Payload {
float3 color : read(caller) : write(closesthit, miss);
};
Expand All @@ -40,9 +42,9 @@ GPRT_RAYGEN_PROGRAM(simpleRayGen, (RayGenData, record)) {
float2 screen = (float2(pixelID) + float2(.5f, .5f)) / float2(fbSize);

RayDesc rayDesc;
rayDesc.Origin = record.camera.pos;
rayDesc.Origin = pc.camera.pos;
rayDesc.Direction =
normalize(record.camera.dir_00 + screen.x * record.camera.dir_du + screen.y * record.camera.dir_dv);
normalize(pc.camera.dir_00 + screen.x * pc.camera.dir_du + screen.y * pc.camera.dir_dv);
rayDesc.TMin = 0.001;
rayDesc.TMax = 10000.0;
RaytracingAccelerationStructure world = gprt::getAccelHandle(record.world);
Expand Down
Loading

0 comments on commit 8bb9e20

Please sign in to comment.