Skip to content

Commit

Permalink
Add ByteAddressBuffer support (#248)
Browse files Browse the repository at this point in the history
* Add ByteAddressBuffer support

* Fix byteAddressBuffer multi-entrypoint
  • Loading branch information
spencer-lunarg authored Jan 26, 2024
1 parent 42878ed commit 82ac7e5
Show file tree
Hide file tree
Showing 16 changed files with 985 additions and 12 deletions.
12 changes: 12 additions & 0 deletions common/output_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1920,6 +1920,18 @@ void SpvReflectToYaml::WriteDescriptorBinding(std::ostream& os, const SpvReflect
assert(itor != descriptor_binding_to_index_.end());
os << t1 << "uav_counter_binding: *db" << itor->second << " # " << SafeString(db.uav_counter_binding->name) << std::endl;
}

if (db.byte_address_buffer_offset_count > 0) {
os << t1 << "ByteAddressBuffer offsets: [";
for (uint32_t i = 0; i < db.byte_address_buffer_offset_count; i++) {
os << db.byte_address_buffer_offsets[i];
if (i < (db.byte_address_buffer_offset_count - 1)) {
os << ", ";
}
}
os << "]\n";
}

if (verbosity_ >= 1) {
// SpvReflectTypeDescription* type_description;
if (db.type_description == nullptr) {
Expand Down
187 changes: 175 additions & 12 deletions spirv_reflect.c
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ typedef struct SpvReflectPrvString {
// OpAtomicIAdd -> OpAccessChain -> OpVariable
// OpAtomicLoad -> OpImageTexelPointer -> OpVariable
typedef struct SpvReflectPrvAccessedVariable {
// Instruction node that performed the access (set for load-style and OpStore accesses).
// Not set for OpCopyMemory / OpCopyMemorySized, where zero is treated the same as invalid.
SpvReflectPrvNode* p_node;
// Result ID of the accessing instruction, tracked because the access may later be reached
// through this ID. Not set for OpStore / OpCopyMemory / OpCopyMemorySized.
uint32_t result_id;
// ID read from the instruction's pointer operand; compared against a binding's spirv_id
// to decide whether that binding is accessed.
uint32_t variable_ptr;
} SpvReflectPrvAccessedVariable;
Expand Down Expand Up @@ -981,6 +982,15 @@ static SpvReflectResult ParseNodes(SpvReflectPrvParser* p_parser) {
case SpvOpFunctionParameter: {
CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->result_id);
} break;
case SpvOpBitcast:
case SpvOpShiftRightLogical:
case SpvOpIAdd:
case SpvOpISub:
case SpvOpIMul:
case SpvOpUDiv:
case SpvOpSDiv: {
CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->result_id);
} break;
}

if (p_node->is_type) {
Expand Down Expand Up @@ -1152,6 +1162,7 @@ static SpvReflectResult ParseFunction(SpvReflectPrvParser* p_parser, SpvReflectP
const uint32_t ptr_index = p_node->word_offset + 3;
SpvReflectPrvAccessedVariable* access_ptr = &p_func->accessed_variables[p_func->accessed_variable_count];

access_ptr->p_node = p_node;
// Need to track Result ID as not sure there has been any memory access through here yet
CHECKED_READU32(p_parser, result_index, access_ptr->result_id);
CHECKED_READU32(p_parser, ptr_index, access_ptr->variable_ptr);
Expand All @@ -1160,11 +1171,12 @@ static SpvReflectResult ParseFunction(SpvReflectPrvParser* p_parser, SpvReflectP
case SpvOpStore: {
const uint32_t result_index = p_node->word_offset + 2;
CHECKED_READU32(p_parser, result_index, p_func->accessed_variables[p_func->accessed_variable_count].variable_ptr);
p_func->accessed_variables[p_func->accessed_variable_count].p_node = p_node;
(++p_func->accessed_variable_count);
} break;
case SpvOpCopyMemory:
case SpvOpCopyMemorySized: {
// There is no result_id is being zero is same as being invalid
// There is no result_id or node, being zero is same as being invalid
CHECKED_READU32(p_parser, p_node->word_offset + 1,
p_func->accessed_variables[p_func->accessed_variable_count].variable_ptr);
(++p_func->accessed_variable_count);
Expand Down Expand Up @@ -3221,6 +3233,106 @@ static SpvReflectResult TraverseCallGraph(SpvReflectPrvParser* p_parser, SpvRefl
return SPV_REFLECT_RESULT_SUCCESS;
}

// Looks up the node whose result ID is `id` and, if it is an OpConstant, returns the word stored at
// offset 3 of that instruction (the constant's literal value).
// Returns (uint32_t)INVALID_VALUE when no node is found or the node is not an OpConstant.
static uint32_t GetUint32Constant(SpvReflectPrvParser* p_parser, uint32_t id) {
  SpvReflectPrvNode* p_constant = FindNode(p_parser, id);
  if (!p_constant || p_constant->op != SpvOpConstant) {
    return (uint32_t)INVALID_VALUE;
  }

  uint32_t literal = (uint32_t)INVALID_VALUE;
  UNCHECKED_READU32(p_parser, p_constant->word_offset + 3, literal);
  return literal;
}

// Returns true when `p_node` is a 6-word OpAccessChain and `p_binding` is tagged with the
// ByteAddressBuffer or RWByteAddressBuffer user type. Null arguments yield false.
static bool HasByteAddressBufferOffset(SpvReflectPrvNode* p_node, SpvReflectDescriptorBinding* p_binding) {
  if (!IsNotNull(p_node) || !IsNotNull(p_binding)) {
    return false;
  }

  // Only an access chain with exactly two indices (6 words total) is of interest
  if (p_node->op != SpvOpAccessChain || p_node->word_count != 6) {
    return false;
  }

  const bool is_read_only = p_binding->user_type == SPV_REFLECT_USER_TYPE_BYTE_ADDRESS_BUFFER;
  const bool is_read_write = p_binding->user_type == SPV_REFLECT_USER_TYPE_RW_BYTE_ADDRESS_BUFFER;
  return is_read_only || is_read_write;
}

// Tries to recover the constant offset used by one ByteAddressBuffer access and, on success,
// appends it to p_binding->byte_address_buffer_offsets (incrementing byte_address_buffer_offset_count).
//
// p_parser  - parser state used to look up nodes by ID and read instruction words
// p_node    - instruction that accessed the binding; only a 6-word OpAccessChain is examined
// p_binding - binding being accessed; must have the (RW)ByteAddressBuffer user type to be examined
//
// Always returns SPV_REFLECT_RESULT_SUCCESS: an access whose offset cannot be resolved to a
// constant expression is silently skipped rather than treated as an error.
// The caller must have allocated room for one more entry in byte_address_buffer_offsets.
static SpvReflectResult ParseByteAddressBuffer(SpvReflectPrvParser* p_parser, SpvReflectPrvNode* p_node,
                                               SpvReflectDescriptorBinding* p_binding) {
  const SpvReflectResult not_found = SPV_REFLECT_RESULT_SUCCESS;
  if (!HasByteAddressBufferOffset(p_node, p_binding)) {
    return not_found;
  }

  uint32_t offset = 0;  // starting offset

  uint32_t base_id = 0;
  // expect first index of 2D access is zero
  UNCHECKED_READU32(p_parser, p_node->word_offset + 4, base_id);
  if (GetUint32Constant(p_parser, base_id) != 0) {
    return not_found;
  }
  UNCHECKED_READU32(p_parser, p_node->word_offset + 5, base_id);
  SpvReflectPrvNode* p_next_node = FindNode(p_parser, base_id);
  if (IsNull(p_next_node)) {
    return not_found;
  } else if (p_next_node->op == SpvOpConstant) {
    // The access chain might just be a constant right to the offset
    offset = GetUint32Constant(p_parser, base_id);
    p_binding->byte_address_buffer_offsets[p_binding->byte_address_buffer_offset_count] = offset;
    p_binding->byte_address_buffer_offset_count++;
    return SPV_REFLECT_RESULT_SUCCESS;
  }

  // there are usually 2 (sometimes 3) instructions that make up the arithmetic logic to calculate the offset
  SpvReflectPrvNode* arithmetic_node_stack[8];
  uint32_t arithmetic_count = 0;

  // Walk from the access chain index back through the first operand of each arithmetic instruction
  while (IsNotNull(p_next_node)) {
    if (p_next_node->op == SpvOpLoad || p_next_node->op == SpvOpBitcast || p_next_node->op == SpvOpConstant) {
      break;  // arithmetic starts here
    }

    // Only binary arithmetic ops handled below are supported. Reject anything else before its
    // operand words are read, since an unknown instruction may not have operands at those positions.
    const bool is_supported_op = p_next_node->op == SpvOpShiftRightLogical || p_next_node->op == SpvOpIAdd ||
                                 p_next_node->op == SpvOpISub || p_next_node->op == SpvOpIMul ||
                                 p_next_node->op == SpvOpUDiv || p_next_node->op == SpvOpSDiv;
    if (!is_supported_op || p_next_node->word_count < 5) {
      return not_found;
    }

    arithmetic_node_stack[arithmetic_count++] = p_next_node;
    if (arithmetic_count >= 8) {
      return not_found;
    }

    UNCHECKED_READU32(p_parser, p_next_node->word_offset + 3, base_id);
    p_next_node = FindNode(p_parser, base_id);
  }

  // Replay the arithmetic innermost-first (reverse of the order it was discovered in)
  const uint32_t count = arithmetic_count;
  for (uint32_t i = 0; i < count; i++) {
    p_next_node = arithmetic_node_stack[--arithmetic_count];
    // All arithmetic ops takes 2 operands, assumption is the 2nd operand has the constant
    UNCHECKED_READU32(p_parser, p_next_node->word_offset + 4, base_id);
    uint32_t value = GetUint32Constant(p_parser, base_id);
    if (value == INVALID_VALUE) {
      return not_found;
    }

    switch (p_next_node->op) {
      case SpvOpShiftRightLogical:
        // Shifting a 32-bit value by 32 or more is undefined behavior in C
        if (value >= 32) {
          return not_found;
        }
        offset >>= value;
        break;
      case SpvOpIAdd:
        offset += value;
        break;
      case SpvOpISub:
        offset -= value;
        break;
      case SpvOpIMul:
        offset *= value;
        break;
      case SpvOpUDiv:
        // Never divide by a zero constant coming from the shader
        if (value == 0) {
          return not_found;
        }
        offset /= value;
        break;
      case SpvOpSDiv:
        // OpConstant might be signed, but value should never be negative (or zero)
        if ((int32_t)value <= 0) {
          return not_found;
        }
        offset /= value;
        break;
      default:
        return not_found;
    }
  }

  p_binding->byte_address_buffer_offsets[p_binding->byte_address_buffer_offset_count] = offset;
  p_binding->byte_address_buffer_offset_count++;
  return SPV_REFLECT_RESULT_SUCCESS;
}

static SpvReflectResult ParseStaticallyUsedResources(SpvReflectPrvParser* p_parser, SpvReflectShaderModule* p_module,
SpvReflectEntryPoint* p_entry, size_t uniform_count, uint32_t* uniforms,
size_t push_constant_count, uint32_t* push_constants) {
Expand Down Expand Up @@ -3253,6 +3365,7 @@ static SpvReflectResult ParseStaticallyUsedResources(SpvReflectPrvParser* p_pars
called_function_count = 0;
result = TraverseCallGraph(p_parser, p_func, &called_function_count, p_called_functions, 0);
if (result != SPV_REFLECT_RESULT_SUCCESS) {
SafeFree(p_called_functions);
return result;
}

Expand Down Expand Up @@ -3296,30 +3409,77 @@ static SpvReflectResult ParseStaticallyUsedResources(SpvReflectPrvParser* p_pars

// Do set intersection to find the used uniform and push constants
size_t used_uniform_count = 0;
SpvReflectResult result0 = IntersectSortedAccessedVariable(p_used_accesses, used_acessed_count, uniforms, uniform_count,
&p_entry->used_uniforms, &used_uniform_count);
result = IntersectSortedAccessedVariable(p_used_accesses, used_acessed_count, uniforms, uniform_count, &p_entry->used_uniforms,
&used_uniform_count);
if (result != SPV_REFLECT_RESULT_SUCCESS) {
SafeFree(p_used_accesses);
return result;
}

size_t used_push_constant_count = 0;
SpvReflectResult result1 =
IntersectSortedAccessedVariable(p_used_accesses, used_acessed_count, push_constants, push_constant_count,
&p_entry->used_push_constants, &used_push_constant_count);
result = IntersectSortedAccessedVariable(p_used_accesses, used_acessed_count, push_constants, push_constant_count,
&p_entry->used_push_constants, &used_push_constant_count);
if (result != SPV_REFLECT_RESULT_SUCCESS) {
SafeFree(p_used_accesses);
return result;
}

for (uint32_t i = 0; i < p_module->descriptor_binding_count; ++i) {
SpvReflectDescriptorBinding* p_binding = &p_module->descriptor_bindings[i];
uint32_t byte_address_buffer_offset_count = 0;

for (uint32_t j = 0; j < used_acessed_count; j++) {
if (p_used_accesses[j].variable_ptr == p_binding->spirv_id) {
p_binding->accessed = 1;

if (HasByteAddressBufferOffset(p_used_accesses[j].p_node, p_binding)) {
byte_address_buffer_offset_count++;
}
}
}

// only if SPIR-V has ByteAddressBuffer user type
if (byte_address_buffer_offset_count > 0) {
bool multi_entrypoint = p_binding->byte_address_buffer_offset_count > 0;
if (multi_entrypoint) {
// An earlier entry point already recorded offsets for this binding. Grow the array, keep the earlier offsets and append
// the ones accessed by this entry point; the combined list is sorted and de-duplicated further down
uint32_t* prev_byte_address_buffer_offsets = p_binding->byte_address_buffer_offsets;
p_binding->byte_address_buffer_offsets =
(uint32_t*)calloc(byte_address_buffer_offset_count + p_binding->byte_address_buffer_offset_count, sizeof(uint32_t));
memcpy(p_binding->byte_address_buffer_offsets, prev_byte_address_buffer_offsets,
sizeof(uint32_t) * p_binding->byte_address_buffer_offset_count);
SafeFree(prev_byte_address_buffer_offsets);
} else {
// possible not all allocated offset slots are used, but this will be a max per binding
p_binding->byte_address_buffer_offsets = (uint32_t*)calloc(byte_address_buffer_offset_count, sizeof(uint32_t));
}

if (IsNull(p_binding->byte_address_buffer_offsets)) {
SafeFree(p_used_accesses);
return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
}

for (uint32_t j = 0; j < used_acessed_count; j++) {
if (p_used_accesses[j].variable_ptr == p_binding->spirv_id) {
result = ParseByteAddressBuffer(p_parser, p_used_accesses[j].p_node, p_binding);
if (result != SPV_REFLECT_RESULT_SUCCESS) {
SafeFree(p_used_accesses);
return result;
}
}
}

if (multi_entrypoint) {
qsort(p_binding->byte_address_buffer_offsets, p_binding->byte_address_buffer_offset_count,
sizeof(*(p_binding->byte_address_buffer_offsets)), SortCompareUint32);
p_binding->byte_address_buffer_offset_count =
(uint32_t)DedupSortedUint32(p_binding->byte_address_buffer_offsets, p_binding->byte_address_buffer_offset_count);
}
}
}

SafeFree(p_used_accesses);
if (result0 != SPV_REFLECT_RESULT_SUCCESS) {
return result0;
}
if (result1 != SPV_REFLECT_RESULT_SUCCESS) {
return result1;
}

p_entry->used_uniform_count = (uint32_t)used_uniform_count;
p_entry->used_push_constant_count = (uint32_t)used_push_constant_count;
Expand Down Expand Up @@ -4112,6 +4272,9 @@ void spvReflectDestroyShaderModule(SpvReflectShaderModule* p_module) {
// Descriptor binding blocks
for (size_t i = 0; i < p_module->descriptor_binding_count; ++i) {
SpvReflectDescriptorBinding* p_descriptor = &p_module->descriptor_bindings[i];
if (IsNotNull(p_descriptor->byte_address_buffer_offsets)) {
SafeFree(p_descriptor->byte_address_buffer_offsets);
}
SafeFreeBlockVariables(&p_descriptor->block);
}
SafeFree(p_module->descriptor_bindings);
Expand Down
2 changes: 2 additions & 0 deletions spirv_reflect.h
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,8 @@ typedef struct SpvReflectDescriptorBinding {
uint32_t accessed;
uint32_t uav_counter_id;
struct SpvReflectDescriptorBinding* uav_counter_binding;
uint32_t byte_address_buffer_offset_count;
uint32_t* byte_address_buffer_offsets;

SpvReflectTypeDescription* type_description;

Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ all_descriptor_bindings:
accessed: 1
uav_counter_id: 4294967295
uav_counter_binding:
ByteAddressBuffer offsets: [4, 5, 11, 13]
type_description: *td1
word_offset: { binding: 129, set: 125 }
user_type: ByteAddressBuffer
Expand Down
25 changes: 25 additions & 0 deletions tests/user_type/byte_address_buffer_1.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// dxc -spirv -fspv-reflect -T cs_6_0 -E csmain -fspv-target-env=vulkan1.2
uint g_global;

struct MaterialData_t {
float4 g_vTest;
float2 g_vTest2;
float3 g_vTest3;
uint g_tTexture1;
uint g_tTexture2;
bool g_bTest1;
bool g_bTest2;
};

static MaterialData_t _g_MaterialData;

ByteAddressBuffer g_MaterialData : register (t4 , space1);
RWStructuredBuffer<uint2> Output : register(u1);

[numthreads(1, 1, 1)]
void csmain(uint3 tid : SV_DispatchThreadID) {
uint2 a = g_MaterialData.Load2( tid.x + g_global );
uint b = g_MaterialData.Load( tid.x + g_global );
uint2 c = g_MaterialData.Load2(4);
Output[tid.x] = a * uint2(b, b) * c;
}
Binary file added tests/user_type/byte_address_buffer_1.spv
Binary file not shown.
Loading

0 comments on commit 82ac7e5

Please sign in to comment.