Skip to content

Commit

Permalink
Add ByteAddressBuffer support
Browse files Browse the repository at this point in the history
  • Loading branch information
spencer-lunarg committed Jan 12, 2024
1 parent 42878ed commit 30efe84
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 1 deletion.
12 changes: 12 additions & 0 deletions common/output_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1920,6 +1920,18 @@ void SpvReflectToYaml::WriteDescriptorBinding(std::ostream& os, const SpvReflect
assert(itor != descriptor_binding_to_index_.end());
os << t1 << "uav_counter_binding: *db" << itor->second << " # " << SafeString(db.uav_counter_binding->name) << std::endl;
}

if (db.byte_address_buffer_offset_count > 0) {
os << t1 << "ByteAddressBuffer offsets: [";
for (uint32_t i = 0; i < db.byte_address_buffer_offset_count; i++) {
os << db.byte_address_buffer_offsets[i];
if (i < (db.byte_address_buffer_offset_count - 1)) {
os << ", ";
}
}
os << "]\n";
}

if (verbosity_ >= 1) {
// SpvReflectTypeDescription* type_description;
if (db.type_description == nullptr) {
Expand Down
122 changes: 121 additions & 1 deletion spirv_reflect.c
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ typedef struct SpvReflectPrvString {
// OpAtomicIAdd -> OpAccessChain -> OpVariable
// OpAtomicLoad -> OpImageTexelPointer -> OpVariable
typedef struct SpvReflectPrvAccessedVariable {
SpvReflectPrvNode* p_node;
uint32_t result_id;
uint32_t variable_ptr;
} SpvReflectPrvAccessedVariable;
Expand Down Expand Up @@ -981,6 +982,15 @@ static SpvReflectResult ParseNodes(SpvReflectPrvParser* p_parser) {
case SpvOpFunctionParameter: {
CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->result_id);
} break;
case SpvOpBitcast:
case SpvOpShiftRightLogical:
case SpvOpIAdd:
case SpvOpISub:
case SpvOpIMul:
case SpvOpUDiv:
case SpvOpSDiv: {
CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->result_id);
} break;
}

if (p_node->is_type) {
Expand Down Expand Up @@ -1152,6 +1162,7 @@ static SpvReflectResult ParseFunction(SpvReflectPrvParser* p_parser, SpvReflectP
const uint32_t ptr_index = p_node->word_offset + 3;
SpvReflectPrvAccessedVariable* access_ptr = &p_func->accessed_variables[p_func->accessed_variable_count];

access_ptr->p_node = p_node;
// Need to track Result ID as not sure there has been any memory access through here yet
CHECKED_READU32(p_parser, result_index, access_ptr->result_id);
CHECKED_READU32(p_parser, ptr_index, access_ptr->variable_ptr);
Expand All @@ -1160,11 +1171,12 @@ static SpvReflectResult ParseFunction(SpvReflectPrvParser* p_parser, SpvReflectP
case SpvOpStore: {
const uint32_t result_index = p_node->word_offset + 2;
CHECKED_READU32(p_parser, result_index, p_func->accessed_variables[p_func->accessed_variable_count].variable_ptr);
p_func->accessed_variables[p_func->accessed_variable_count].p_node = p_node;
(++p_func->accessed_variable_count);
} break;
case SpvOpCopyMemory:
case SpvOpCopyMemorySized: {
// There is no result_id is being zero is same as being invalid
// There is no result_id or node, being zero is same as being invalid
CHECKED_READU32(p_parser, p_node->word_offset + 1,
p_func->accessed_variables[p_func->accessed_variable_count].variable_ptr);
(++p_func->accessed_variable_count);
Expand Down Expand Up @@ -3221,6 +3233,106 @@ static SpvReflectResult TraverseCallGraph(SpvReflectPrvParser* p_parser, SpvRefl
return SPV_REFLECT_RESULT_SUCCESS;
}

static uint32_t GetUint32Constant(SpvReflectPrvParser* p_parser, uint32_t id) {
uint32_t result = (uint32_t)INVALID_VALUE;
SpvReflectPrvNode* p_node = FindNode(p_parser, id);
if (p_node && p_node->op == SpvOpConstant) {
UNCHECKED_READU32(p_parser, p_node->word_offset + 3, result);
}
return result;
}

static SpvReflectResult ParseByteAddressBuffer(SpvReflectPrvParser* p_parser, SpvReflectPrvNode* p_node,
SpvReflectDescriptorBinding* p_binding) {
const SpvReflectResult not_found = SPV_REFLECT_RESULT_SUCCESS;
if (IsNull(p_node) || p_node->op != SpvOpAccessChain || p_node->word_count != 6) {
return not_found;
} else if (p_binding->user_type != SPV_REFLECT_USER_TYPE_BYTE_ADDRESS_BUFFER &&
p_binding->user_type != SPV_REFLECT_USER_TYPE_RW_BYTE_ADDRESS_BUFFER) {
return not_found;
}

uint32_t base_id = 0;
// expect first index of 2D access is zero
UNCHECKED_READU32(p_parser, p_node->word_offset + 4, base_id);
if (GetUint32Constant(p_parser, base_id) != 0) {
return not_found;
}
UNCHECKED_READU32(p_parser, p_node->word_offset + 5, base_id);
SpvReflectPrvNode* p_next_node = FindNode(p_parser, base_id);
if (IsNull(p_next_node)) {
return not_found;
}

// there is usually 2 (sometimes 3) instrucitons that make up the arithmetic logic to calculate the offset
SpvReflectPrvNode* arithmetic_node_stack[5];
uint32_t arithmetic_count = 0;

while (IsNotNull(p_next_node)) {
if (p_next_node->op == SpvOpLoad || p_next_node->op == SpvOpBitcast) {
break; // arithmetic starts here
}
arithmetic_node_stack[arithmetic_count++] = p_next_node;
if (arithmetic_count > 5) {
return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
}

UNCHECKED_READU32(p_parser, p_next_node->word_offset + 3, base_id);
p_next_node = FindNode(p_parser, base_id);
}

uint32_t offset = 0; // starting offset
const uint32_t count = arithmetic_count;
for (uint32_t i = 0; i < count; i++) {
p_next_node = arithmetic_node_stack[--arithmetic_count];
// All arithmetic ops takes 2 operands, assumption is the 2nd operand has the constant
UNCHECKED_READU32(p_parser, p_next_node->word_offset + 4, base_id);
uint32_t value = GetUint32Constant(p_parser, base_id);
if (value == INVALID_VALUE) {
return not_found;
}

switch (p_next_node->op) {
case SpvOpShiftRightLogical:
offset >>= value;
break;
case SpvOpIAdd:
offset += value;
break;
case SpvOpISub:
offset -= value;
break;
case SpvOpIMul:
offset *= value;
break;
case SpvOpUDiv:
offset /= value;
break;
case SpvOpSDiv:
// OpConstant might be signed, but value should never be negative
assert((int32_t)value > 0);
offset /= value;
break;
default:
return not_found;
}
}

// It is costly to get the size before, so trade off a larger memory footprint (only if SPIR-V has ByteAddressBuffer user type)
if (IsNull(p_binding->byte_address_buffer_offsets)) {
p_binding->byte_address_buffer_offsets = (uint32_t*)calloc(SPV_REFLECT_MAX_BYTE_ADDRESS_BUFFER, sizeof(uint32_t));
if (IsNull(p_binding->byte_address_buffer_offsets)) {
return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
}
} else if (p_binding->byte_address_buffer_offset_count >= SPV_REFLECT_MAX_BYTE_ADDRESS_BUFFER) {
return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
}

p_binding->byte_address_buffer_offsets[p_binding->byte_address_buffer_offset_count] = offset;
p_binding->byte_address_buffer_offset_count++;
return SPV_REFLECT_RESULT_SUCCESS;
}

static SpvReflectResult ParseStaticallyUsedResources(SpvReflectPrvParser* p_parser, SpvReflectShaderModule* p_module,
SpvReflectEntryPoint* p_entry, size_t uniform_count, uint32_t* uniforms,
size_t push_constant_count, uint32_t* push_constants) {
Expand Down Expand Up @@ -3309,6 +3421,11 @@ static SpvReflectResult ParseStaticallyUsedResources(SpvReflectPrvParser* p_pars
for (uint32_t j = 0; j < used_acessed_count; j++) {
if (p_used_accesses[j].variable_ptr == p_binding->spirv_id) {
p_binding->accessed = 1;

SpvReflectResult result2 = ParseByteAddressBuffer(p_parser, p_used_accesses[j].p_node, p_binding);
if (result2 != SPV_REFLECT_RESULT_SUCCESS) {
return result2;
}
}
}
}
Expand Down Expand Up @@ -4112,6 +4229,9 @@ void spvReflectDestroyShaderModule(SpvReflectShaderModule* p_module) {
// Descriptor binding blocks
for (size_t i = 0; i < p_module->descriptor_binding_count; ++i) {
SpvReflectDescriptorBinding* p_descriptor = &p_module->descriptor_bindings[i];
if (IsNotNull(p_descriptor->byte_address_buffer_offsets)) {
SafeFree(p_descriptor->byte_address_buffer_offsets);
}
SafeFreeBlockVariables(&p_descriptor->block);
}
SafeFree(p_module->descriptor_bindings);
Expand Down
3 changes: 3 additions & 0 deletions spirv_reflect.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ typedef enum SpvReflectGenerator {
enum {
SPV_REFLECT_MAX_ARRAY_DIMS = 32,
SPV_REFLECT_MAX_DESCRIPTOR_SETS = 64,
SPV_REFLECT_MAX_BYTE_ADDRESS_BUFFER = 64,
};

enum {
Expand Down Expand Up @@ -481,6 +482,8 @@ typedef struct SpvReflectDescriptorBinding {
uint32_t accessed;
uint32_t uav_counter_id;
struct SpvReflectDescriptorBinding* uav_counter_binding;
uint32_t byte_address_buffer_offset_count;
uint32_t* byte_address_buffer_offsets;

SpvReflectTypeDescription* type_description;

Expand Down
1 change: 1 addition & 0 deletions tests/user_type/byte_address_buffer.spv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ all_descriptor_bindings:
accessed: 1
uav_counter_id: 4294967295
uav_counter_binding:
ByteAddressBuffer offsets: [4, 5, 11, 13]
type_description: *td1
word_offset: { binding: 129, set: 125 }
user_type: ByteAddressBuffer
Expand Down
1 change: 1 addition & 0 deletions tests/user_type/rw_byte_address_buffer.spv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ all_descriptor_bindings:
accessed: 1
uav_counter_id: 4294967295
uav_counter_binding:
ByteAddressBuffer offsets: [4, 5, 11, 13]
type_description: *td1
word_offset: { binding: 130, set: 126 }
user_type: RWByteAddressBuffer
Expand Down

0 comments on commit 30efe84

Please sign in to comment.