Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ByteAddressBuffer support #248

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions common/output_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1920,6 +1920,18 @@ void SpvReflectToYaml::WriteDescriptorBinding(std::ostream& os, const SpvReflect
assert(itor != descriptor_binding_to_index_.end());
os << t1 << "uav_counter_binding: *db" << itor->second << " # " << SafeString(db.uav_counter_binding->name) << std::endl;
}

if (db.byte_address_buffer_offset_count > 0) {
os << t1 << "ByteAddressBuffer offsets: [";
for (uint32_t i = 0; i < db.byte_address_buffer_offset_count; i++) {
os << db.byte_address_buffer_offsets[i];
if (i < (db.byte_address_buffer_offset_count - 1)) {
os << ", ";
}
}
os << "]\n";
}

if (verbosity_ >= 1) {
// SpvReflectTypeDescription* type_description;
if (db.type_description == nullptr) {
Expand Down
167 changes: 155 additions & 12 deletions spirv_reflect.c
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ typedef struct SpvReflectPrvString {
// OpAtomicIAdd -> OpAccessChain -> OpVariable
// OpAtomicLoad -> OpImageTexelPointer -> OpVariable
typedef struct SpvReflectPrvAccessedVariable {
SpvReflectPrvNode* p_node;
uint32_t result_id;
uint32_t variable_ptr;
} SpvReflectPrvAccessedVariable;
Expand Down Expand Up @@ -981,6 +982,15 @@ static SpvReflectResult ParseNodes(SpvReflectPrvParser* p_parser) {
case SpvOpFunctionParameter: {
CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->result_id);
} break;
case SpvOpBitcast:
case SpvOpShiftRightLogical:
case SpvOpIAdd:
case SpvOpISub:
case SpvOpIMul:
case SpvOpUDiv:
case SpvOpSDiv: {
CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->result_id);
} break;
}

if (p_node->is_type) {
Expand Down Expand Up @@ -1152,6 +1162,7 @@ static SpvReflectResult ParseFunction(SpvReflectPrvParser* p_parser, SpvReflectP
const uint32_t ptr_index = p_node->word_offset + 3;
SpvReflectPrvAccessedVariable* access_ptr = &p_func->accessed_variables[p_func->accessed_variable_count];

access_ptr->p_node = p_node;
// Need to track Result ID as not sure there has been any memory access through here yet
CHECKED_READU32(p_parser, result_index, access_ptr->result_id);
CHECKED_READU32(p_parser, ptr_index, access_ptr->variable_ptr);
Expand All @@ -1160,11 +1171,12 @@ static SpvReflectResult ParseFunction(SpvReflectPrvParser* p_parser, SpvReflectP
case SpvOpStore: {
const uint32_t result_index = p_node->word_offset + 2;
CHECKED_READU32(p_parser, result_index, p_func->accessed_variables[p_func->accessed_variable_count].variable_ptr);
p_func->accessed_variables[p_func->accessed_variable_count].p_node = p_node;
(++p_func->accessed_variable_count);
} break;
case SpvOpCopyMemory:
case SpvOpCopyMemorySized: {
// There is no result_id is being zero is same as being invalid
// There is no result_id or node, being zero is same as being invalid
CHECKED_READU32(p_parser, p_node->word_offset + 1,
p_func->accessed_variables[p_func->accessed_variable_count].variable_ptr);
(++p_func->accessed_variable_count);
Expand Down Expand Up @@ -3221,6 +3233,106 @@ static SpvReflectResult TraverseCallGraph(SpvReflectPrvParser* p_parser, SpvRefl
return SPV_REFLECT_RESULT_SUCCESS;
}

static uint32_t GetUint32Constant(SpvReflectPrvParser* p_parser, uint32_t id) {
uint32_t result = (uint32_t)INVALID_VALUE;
SpvReflectPrvNode* p_node = FindNode(p_parser, id);
if (p_node && p_node->op == SpvOpConstant) {
UNCHECKED_READU32(p_parser, p_node->word_offset + 3, result);
}
return result;
}

static bool HasByteAddressBufferOffset(SpvReflectPrvNode* p_node, SpvReflectDescriptorBinding* p_binding) {
return IsNotNull(p_node) && IsNotNull(p_binding) && p_node->op == SpvOpAccessChain && p_node->word_count == 6 &&
(p_binding->user_type == SPV_REFLECT_USER_TYPE_BYTE_ADDRESS_BUFFER ||
p_binding->user_type == SPV_REFLECT_USER_TYPE_RW_BYTE_ADDRESS_BUFFER);
}

static SpvReflectResult ParseByteAddressBuffer(SpvReflectPrvParser* p_parser, SpvReflectPrvNode* p_node,
SpvReflectDescriptorBinding* p_binding) {
const SpvReflectResult not_found = SPV_REFLECT_RESULT_SUCCESS;
if (!HasByteAddressBufferOffset(p_node, p_binding)) {
return not_found;
}

uint32_t offset = 0; // starting offset

uint32_t base_id = 0;
// expect first index of 2D access is zero
UNCHECKED_READU32(p_parser, p_node->word_offset + 4, base_id);
if (GetUint32Constant(p_parser, base_id) != 0) {
return not_found;
}
UNCHECKED_READU32(p_parser, p_node->word_offset + 5, base_id);
SpvReflectPrvNode* p_next_node = FindNode(p_parser, base_id);
if (IsNull(p_next_node)) {
return not_found;
} else if (p_next_node->op == SpvOpConstant) {
// The access chain might just be a constant right to the offset
offset = GetUint32Constant(p_parser, base_id);
p_binding->byte_address_buffer_offsets[p_binding->byte_address_buffer_offset_count] = offset;
p_binding->byte_address_buffer_offset_count++;
return SPV_REFLECT_RESULT_SUCCESS;
}

// there is usually 2 (sometimes 3) instrucitons that make up the arithmetic logic to calculate the offset
SpvReflectPrvNode* arithmetic_node_stack[8];
uint32_t arithmetic_count = 0;

while (IsNotNull(p_next_node)) {
if (p_next_node->op == SpvOpLoad || p_next_node->op == SpvOpBitcast || p_next_node->op == SpvOpConstant) {
break; // arithmetic starts here
}
arithmetic_node_stack[arithmetic_count++] = p_next_node;
if (arithmetic_count >= 8) {
Copy link
Contributor

@danginsburg danginsburg Jan 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this check be above this? Otherwise it will return if you have 8 operations when it could have succeeded (if you have exactly 8). I also don't understand why there is such a low cap on this, it seems rather inflexible. Would it be better to scan for how many first then allocate an array with a cap of like the size of the spir-v or something?

The problem is that SPIR-V reflection will completely fail in this case, not just that you won't get the offsets.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so before it failed because I returned an error, not I return the same as "could not find" it it hits the cap

I guess there is no reason for a hard limit, it just I have yet to find a case where it is more then 2 here

Would it be better to scan for how many first

So the other issue is I do need "a cap" when doing this search anyways, there is no "correct" way for HLSL/Slang to flatten the ByteAddressBuffer so in the future, if some strange pattern arises, I still need some arbitrary number to stop at incase I get in some strange situation and don't want to crash

return not_found;
}

UNCHECKED_READU32(p_parser, p_next_node->word_offset + 3, base_id);
p_next_node = FindNode(p_parser, base_id);
}

const uint32_t count = arithmetic_count;
for (uint32_t i = 0; i < count; i++) {
p_next_node = arithmetic_node_stack[--arithmetic_count];
// All arithmetic ops takes 2 operands, assumption is the 2nd operand has the constant
UNCHECKED_READU32(p_parser, p_next_node->word_offset + 4, base_id);
uint32_t value = GetUint32Constant(p_parser, base_id);
if (value == INVALID_VALUE) {
return not_found;
}

switch (p_next_node->op) {
case SpvOpShiftRightLogical:
offset >>= value;
break;
case SpvOpIAdd:
offset += value;
break;
case SpvOpISub:
offset -= value;
break;
case SpvOpIMul:
offset *= value;
break;
case SpvOpUDiv:
offset /= value;
break;
case SpvOpSDiv:
// OpConstant might be signed, but value should never be negative
assert((int32_t)value > 0);
offset /= value;
break;
default:
return not_found;
}
}

p_binding->byte_address_buffer_offsets[p_binding->byte_address_buffer_offset_count] = offset;
p_binding->byte_address_buffer_offset_count++;
return SPV_REFLECT_RESULT_SUCCESS;
}

static SpvReflectResult ParseStaticallyUsedResources(SpvReflectPrvParser* p_parser, SpvReflectShaderModule* p_module,
SpvReflectEntryPoint* p_entry, size_t uniform_count, uint32_t* uniforms,
size_t push_constant_count, uint32_t* push_constants) {
Expand Down Expand Up @@ -3253,6 +3365,7 @@ static SpvReflectResult ParseStaticallyUsedResources(SpvReflectPrvParser* p_pars
called_function_count = 0;
result = TraverseCallGraph(p_parser, p_func, &called_function_count, p_called_functions, 0);
if (result != SPV_REFLECT_RESULT_SUCCESS) {
SafeFree(p_called_functions);
return result;
}

Expand Down Expand Up @@ -3296,30 +3409,57 @@ static SpvReflectResult ParseStaticallyUsedResources(SpvReflectPrvParser* p_pars

// Do set intersection to find the used uniform and push constants
size_t used_uniform_count = 0;
SpvReflectResult result0 = IntersectSortedAccessedVariable(p_used_accesses, used_acessed_count, uniforms, uniform_count,
&p_entry->used_uniforms, &used_uniform_count);
result = IntersectSortedAccessedVariable(p_used_accesses, used_acessed_count, uniforms, uniform_count, &p_entry->used_uniforms,
&used_uniform_count);
if (result != SPV_REFLECT_RESULT_SUCCESS) {
SafeFree(p_used_accesses);
return result;
}

size_t used_push_constant_count = 0;
SpvReflectResult result1 =
IntersectSortedAccessedVariable(p_used_accesses, used_acessed_count, push_constants, push_constant_count,
&p_entry->used_push_constants, &used_push_constant_count);
result = IntersectSortedAccessedVariable(p_used_accesses, used_acessed_count, push_constants, push_constant_count,
&p_entry->used_push_constants, &used_push_constant_count);
if (result != SPV_REFLECT_RESULT_SUCCESS) {
SafeFree(p_used_accesses);
return result;
}

for (uint32_t i = 0; i < p_module->descriptor_binding_count; ++i) {
SpvReflectDescriptorBinding* p_binding = &p_module->descriptor_bindings[i];
uint32_t byte_address_buffer_offset_count = 0;

for (uint32_t j = 0; j < used_acessed_count; j++) {
if (p_used_accesses[j].variable_ptr == p_binding->spirv_id) {
p_binding->accessed = 1;

if (HasByteAddressBufferOffset(p_used_accesses[j].p_node, p_binding)) {
byte_address_buffer_offset_count++;
}
}
}

// only if SPIR-V has ByteAddressBuffer user type
if (byte_address_buffer_offset_count > 0) {
// possible not all allocated offset slots are used, but this will be a max per binding
p_binding->byte_address_buffer_offsets = (uint32_t*)calloc(byte_address_buffer_offset_count, sizeof(uint32_t));
spencer-lunarg marked this conversation as resolved.
Show resolved Hide resolved
if (IsNull(p_binding->byte_address_buffer_offsets)) {
SafeFree(p_used_accesses);
return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
}

for (uint32_t j = 0; j < used_acessed_count; j++) {
if (p_used_accesses[j].variable_ptr == p_binding->spirv_id) {
result = ParseByteAddressBuffer(p_parser, p_used_accesses[j].p_node, p_binding);
if (result != SPV_REFLECT_RESULT_SUCCESS) {
SafeFree(p_used_accesses);
return result;
}
}
}
}
}

SafeFree(p_used_accesses);
if (result0 != SPV_REFLECT_RESULT_SUCCESS) {
return result0;
}
if (result1 != SPV_REFLECT_RESULT_SUCCESS) {
return result1;
}

p_entry->used_uniform_count = (uint32_t)used_uniform_count;
p_entry->used_push_constant_count = (uint32_t)used_push_constant_count;
Expand Down Expand Up @@ -4112,6 +4252,9 @@ void spvReflectDestroyShaderModule(SpvReflectShaderModule* p_module) {
// Descriptor binding blocks
for (size_t i = 0; i < p_module->descriptor_binding_count; ++i) {
SpvReflectDescriptorBinding* p_descriptor = &p_module->descriptor_bindings[i];
if (IsNotNull(p_descriptor->byte_address_buffer_offsets)) {
SafeFree(p_descriptor->byte_address_buffer_offsets);
}
SafeFreeBlockVariables(&p_descriptor->block);
}
SafeFree(p_module->descriptor_bindings);
Expand Down
2 changes: 2 additions & 0 deletions spirv_reflect.h
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,8 @@ typedef struct SpvReflectDescriptorBinding {
uint32_t accessed;
uint32_t uav_counter_id;
struct SpvReflectDescriptorBinding* uav_counter_binding;
uint32_t byte_address_buffer_offset_count;
uint32_t* byte_address_buffer_offsets;

SpvReflectTypeDescription* type_description;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ all_descriptor_bindings:
accessed: 1
uav_counter_id: 4294967295
uav_counter_binding:
ByteAddressBuffer offsets: [4, 5, 11, 13]
type_description: *td1
word_offset: { binding: 129, set: 125 }
user_type: ByteAddressBuffer
Expand Down
25 changes: 25 additions & 0 deletions tests/user_type/byte_address_buffer_1.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// dxc -spirv -fspv-reflect -T cs_6_0 -E csmain -fspv-target-env=vulkan1.2
uint g_global;

struct MaterialData_t {
float4 g_vTest;
float2 g_vTest2;
float3 g_vTest3;
uint g_tTexture1;
uint g_tTexture2;
bool g_bTest1;
bool g_bTest2;
};

static MaterialData_t _g_MaterialData;

ByteAddressBuffer g_MaterialData : register (t4 , space1);
RWStructuredBuffer<uint2> Output : register(u1);

[numthreads(1, 1, 1)]
void csmain(uint3 tid : SV_DispatchThreadID) {
uint2 a = g_MaterialData.Load2( tid.x + g_global );
uint b = g_MaterialData.Load( tid.x + g_global );
uint2 c = g_MaterialData.Load2(4);
Output[tid.x] = a * uint2(b, b) * c;
}
Binary file added tests/user_type/byte_address_buffer_1.spv
Binary file not shown.
Loading