Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix meshlet vertex attribute interpolation #13775

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 46 additions & 31 deletions crates/bevy_pbr/src/meshlet/visibility_buffer_resolve.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
unpack_meshlet_vertex,
},
mesh_view_bindings::view,
mesh_functions::mesh_position_local_to_world,
mesh_types::MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT,
mesh_functions::{mesh_position_local_to_world, sign_determinant_model_3x3m},
mesh_types::{Mesh, MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT},
view_transformations::{position_world_to_clip, frag_coord_to_ndc},
}
#import bevy_render::maths::{affine3_to_square, mat2x4_f32_to_mat3x3_unpack}
Expand Down Expand Up @@ -99,6 +99,7 @@ fn resolve_vertex_output(frag_coord: vec4<f32>) -> VertexOutput {
let cluster_id = packed_ids >> 6u;
let meshlet_id = meshlet_cluster_meshlet_ids[cluster_id];
let meshlet = meshlets[meshlet_id];

let triangle_id = extractBits(packed_ids, 0u, 6u);
let index_ids = meshlet.start_index_id + vec3(triangle_id * 3u) + vec3(0u, 1u, 2u);
let indices = meshlet.start_vertex_id + vec3(get_meshlet_index(index_ids.x), get_meshlet_index(index_ids.y), get_meshlet_index(index_ids.z));
Expand All @@ -108,9 +109,9 @@ fn resolve_vertex_output(frag_coord: vec4<f32>) -> VertexOutput {
let vertex_3 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.z]);

let instance_id = meshlet_cluster_instance_ids[cluster_id];
let instance_uniform = meshlet_instance_uniforms[instance_id];
let world_from_local = affine3_to_square(instance_uniform.world_from_local);
var instance_uniform = meshlet_instance_uniforms[instance_id];

let world_from_local = affine3_to_square(instance_uniform.world_from_local);
let world_position_1 = mesh_position_local_to_world(world_from_local, vec4(vertex_1.position, 1.0));
let world_position_2 = mesh_position_local_to_world(world_from_local, vec4(vertex_2.position, 1.0));
let world_position_3 = mesh_position_local_to_world(world_from_local, vec4(vertex_3.position, 1.0));
Expand All @@ -126,43 +127,27 @@ fn resolve_vertex_output(frag_coord: vec4<f32>) -> VertexOutput {
);

let world_position = mat3x4(world_position_1, world_position_2, world_position_3) * partial_derivatives.barycentrics;
let vertex_normal = mat3x3(vertex_1.normal, vertex_2.normal, vertex_3.normal) * partial_derivatives.barycentrics;
let world_normal = normalize(
mat2x4_f32_to_mat3x3_unpack(
instance_uniform.local_from_world_transpose_a,
instance_uniform.local_from_world_transpose_b,
) * vertex_normal
);
let world_normal = mat3x3(
normal_local_to_world(vertex_1.normal, &instance_uniform),
normal_local_to_world(vertex_2.normal, &instance_uniform),
normal_local_to_world(vertex_3.normal, &instance_uniform),
) * partial_derivatives.barycentrics;
let uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.barycentrics;
let ddx_uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.ddx;
let ddy_uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.ddy;
let vertex_tangent = mat3x4(vertex_1.tangent, vertex_2.tangent, vertex_3.tangent) * partial_derivatives.barycentrics;
let world_tangent = vec4(
normalize(
mat3x3(
world_from_local[0].xyz,
world_from_local[1].xyz,
world_from_local[2].xyz
) * vertex_tangent.xyz
),
vertex_tangent.w * (f32(bool(instance_uniform.flags & MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT)) * 2.0 - 1.0)
);
let world_tangent = mat3x4(
tangent_local_to_world(vertex_1.tangent, world_from_local, instance_uniform.flags),
tangent_local_to_world(vertex_2.tangent, world_from_local, instance_uniform.flags),
tangent_local_to_world(vertex_3.tangent, world_from_local, instance_uniform.flags),
) * partial_derivatives.barycentrics;

#ifdef PREPASS_FRAGMENT
#ifdef MOTION_VECTOR_PREPASS
let previous_world_from_local = affine3_to_square(instance_uniform.previous_world_from_local);
let previous_world_position_1 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_1.position, 1.0));
let previous_world_position_2 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_2.position, 1.0));
let previous_world_position_3 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_3.position, 1.0));
let previous_clip_position_1 = previous_view_uniforms.clip_from_world * vec4(previous_world_position_1.xyz, 1.0);
let previous_clip_position_2 = previous_view_uniforms.clip_from_world * vec4(previous_world_position_2.xyz, 1.0);
let previous_clip_position_3 = previous_view_uniforms.clip_from_world * vec4(previous_world_position_3.xyz, 1.0);
let previous_partial_derivatives = compute_partial_derivatives(
array(previous_clip_position_1, previous_clip_position_2, previous_clip_position_3),
frag_coord_ndc,
view.viewport.zw,
);
let previous_world_position = mat3x4(previous_world_position_1, previous_world_position_2, previous_world_position_3) * previous_partial_derivatives.barycentrics;
let previous_world_position = mat3x4(previous_world_position_1, previous_world_position_2, previous_world_position_3) * partial_derivatives.barycentrics;
let motion_vector = calculate_motion_vector(world_position, previous_world_position);
#endif
#endif
Expand All @@ -184,4 +169,34 @@ fn resolve_vertex_output(frag_coord: vec4<f32>) -> VertexOutput {
#endif
);
}

fn normal_local_to_world(vertex_normal: vec3<f32>, instance_uniform: ptr<function, Mesh>) -> vec3<f32> {
if any(vertex_normal != vec3<f32>(0.0)) {
return normalize(
mat2x4_f32_to_mat3x3_unpack(
(*instance_uniform).local_from_world_transpose_a,
(*instance_uniform).local_from_world_transpose_b,
) * vertex_normal
);
} else {
return vertex_normal;
}
}

fn tangent_local_to_world(vertex_tangent: vec4<f32>, world_from_local: mat4x4<f32>, mesh_flags: u32) -> vec4<f32> {
if any(vertex_tangent != vec4<f32>(0.0)) {
return vec4<f32>(
normalize(
mat3x3<f32>(
world_from_local[0].xyz,
world_from_local[1].xyz,
world_from_local[2].xyz,
) * vertex_tangent.xyz
),
vertex_tangent.w * sign_determinant_model_3x3m(mesh_flags)
);
} else {
return vertex_tangent;
}
}
#endif
8 changes: 4 additions & 4 deletions crates/bevy_pbr/src/render/mesh_functions.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ fn mesh_normal_local_to_world(vertex_normal: vec3<f32>, instance_index: u32) ->

// Calculates the sign of the determinant of the 3x3 model matrix based on a
// mesh flag
fn sign_determinant_model_3x3m(instance_index: u32) -> f32 {
fn sign_determinant_model_3x3m(mesh_flags: u32) -> f32 {
// bool(u32) is false if 0u else true
// f32(bool) is 1.0 if true else 0.0
// * 2.0 - 1.0 remaps 0.0 or 1.0 to -1.0 or 1.0 respectively
return f32(bool(mesh[instance_index].flags & MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT)) * 2.0 - 1.0;
return f32(bool(mesh_flags & MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT)) * 2.0 - 1.0;
}

fn mesh_tangent_local_to_world(world_from_local: mat4x4<f32>, vertex_tangent: vec4<f32>, instance_index: u32) -> vec4<f32> {
Expand All @@ -76,12 +76,12 @@ fn mesh_tangent_local_to_world(world_from_local: mat4x4<f32>, vertex_tangent: ve
mat3x3<f32>(
world_from_local[0].xyz,
world_from_local[1].xyz,
world_from_local[2].xyz
world_from_local[2].xyz,
) * vertex_tangent.xyz
),
// NOTE: Multiplying by the sign of the determinant of the 3x3 model matrix accounts for
// situations such as negative scaling.
vertex_tangent.w * sign_determinant_model_3x3m(instance_index)
vertex_tangent.w * sign_determinant_model_3x3m(mesh[instance_index].flags)
);
} else {
return vertex_tangent;
Expand Down