From 249f3e0ee8b87f94b980363ee5e9e5c164501e29 Mon Sep 17 00:00:00 2001 From: Zhao Liang Date: Fri, 13 Jan 2023 17:04:30 +0800 Subject: [PATCH] [example] Update quaternion arithmetics in fractal_3d_ggui (#7139) This PR updates the code in `fractal_3d_ggui.py`. --- .../examples/ggui_examples/fractal3d_ggui.py | 136 ++++++------------ 1 file changed, 41 insertions(+), 95 deletions(-) diff --git a/python/taichi/examples/ggui_examples/fractal3d_ggui.py b/python/taichi/examples/ggui_examples/fractal3d_ggui.py index b4de653818d1a..04d99bfdff904 100644 --- a/python/taichi/examples/ggui_examples/fractal3d_ggui.py +++ b/python/taichi/examples/ggui_examples/fractal3d_ggui.py @@ -1,62 +1,22 @@ import taichi as ti +import taichi.math as tm arch = ti.vulkan if ti._lib.core.with_vulkan() else ti.cuda ti.init(arch=arch) +vec3 = tm.vec3 +vec4 = tm.vec4 + @ti.func def quat_mul(v1, v2): - return ti.Vector([ - v1.x * v2.x - v1.y * v2.y - v1.z * v2.z - v1.w * v2.w, - v1.x * v2.y + v1.y * v2.x + v1.z * v2.w - v1.w * v2.z, - v1.x * v2.z + v1.z * v2.x + v1.w * v2.y - v1.y * v2.w, - v1.x * v2.w + v1.w * v2.x + v1.y * v2.z - v1.z * v2.y - ]) + return vec4(v1.x * v2.x - tm.dot(v1.yzw, v2.yzw), + v1.x * v2.yzw + v2.x * v1.yzw + tm.cross(v1.yzw, v2.yzw)) @ti.func def quat_conj(q): - return ti.Vector([q[0], -q[1], -q[2], -q[3]]) - - -@ti.func -def dot(x, y): - return x.dot(y) - - -@ti.func -def xy(v): - return ti.Vector([v.x, v.y]) - - -@ti.func -def yx(v): - return ti.Vector([v.y, v.x]) - - -@ti.func -def xz(v): - return ti.Vector([v.x, v.z]) - - -@ti.func -def zx(v): - return ti.Vector([v.z, v.x]) - - -@ti.func -def xw(v): - return ti.Vector([v.x, v.w]) - - -@ti.func -def wx(v): - return ti.Vector([v.w, v.x]) - - -@ti.func -def xyz(v): - return ti.Vector([v.x, v.y, v.z]) + return vec4(q.x, -q.yzw) iters = 10 @@ -67,7 +27,7 @@ def xyz(v): def compute_sdf(z, c): md2 = 1.0 - mz2 = dot(z, z) + mz2 = tm.dot(z, z) for _ in range(iters): md2 *= max_norm * mz2 @@ -82,9 +42,9 @@ def compute_sdf(z, c): @ti.func def compute_normal(z, c): - J0 = ti.Vector([1.0, 0.0, 0.0, 0.0]) - J1 = ti.Vector([0.0, 1.0, 0.0, 0.0]) - J2 = ti.Vector([0.0, 0.0, 1.0, 0.0]) + J0 = vec4(1, 0, 0, 0) + J1 = vec4(0, 1, 0, 0) + J2 = vec4(0, 0, 1, 0) z_curr = z @@ -92,31 +52,18 @@ def compute_normal(z, c): while z_curr.norm() < max_norm and iterations < iters: cz = quat_conj(z_curr) - J0 = ti.Vector([ - dot(J0, cz), - dot(xy(J0), yx(z_curr)), - dot(xz(J0), zx(z_curr)), - dot(xw(J0), wx(z_curr)) - ]) - J1 = ti.Vector([ - dot(J1, cz), - dot(xy(J1), yx(z_curr)), - dot(xz(J1), zx(z_curr)), - dot(xw(J1), wx(z_curr)) - ]) - J2 = ti.Vector([ - dot(J2, cz), - dot(xy(J2), yx(z_curr)), - dot(xz(J2), zx(z_curr)), - dot(xw(J2), wx(z_curr)) - ]) + J0 = vec4(tm.dot(J0, cz), tm.dot(J0.xy, z_curr.yx), + tm.dot(J0.xz, z_curr.zx), tm.dot(J0.xw, z_curr.wx)) + J1 = vec4(tm.dot(J1, cz), tm.dot(J1.xy, z_curr.yx), + tm.dot(J1.xz, z_curr.zx), tm.dot(J1.xw, z_curr.wx)) + J2 = vec4(tm.dot(J2, cz), tm.dot(J2.xy, z_curr.yx), + tm.dot(J2.xz, z_curr.zx), tm.dot(J2.xw, z_curr.wx)) z_curr = quat_mul(z_curr, z_curr) + c iterations += 1 - return ti.Vector([dot(J0, z_curr), - dot(J1, z_curr), - dot(J2, z_curr)]).normalized() + return tm.normalize( + tm.vec3(tm.dot(z_curr, J0), tm.dot(z_curr, J1), tm.dot(z_curr, J2))) image_res = (1280, 720) @@ -130,39 +77,40 @@ def __init__(self): @ti.func def shade(self, pos, surface_color, normal, light_pos): _ = self # make pylint happy - light_color = ti.Vector([1, 1, 1]) + light_color = vec3(1) - light_dir = (light_pos - pos).normalized() - return light_color * surface_color * max(0, dot(light_dir, normal)) + light_dir = tm.normalize(light_pos - pos) + return light_color * surface_color * ti.max(0, tm.dot( + light_dir, normal)) @ti.kernel def march(self, time_arg: float): time = time_arg * 0.15 c = 0.45 * ti.cos( - ti.Vector([0.5, 3.9, 1.4, 1.1]) + time * - ti.Vector([1.2, 1.7, 1.3, 2.5])) - ti.Vector([0.3, 0.0, 0.0, 0.0]) + vec4(0.5, 3.9, 1.4, 1.1) + time * vec4(1.2, 1.7, 1.3, 2.5)) - vec4( + 0.3, 0, 0, 0) r = 1.8 - o3 = ti.Vector([ - r * ti.cos(0.3 + 0.37 * time), 0.3 + - 0.8 * r * ti.cos(1.0 + 0.33 * time), r * ti.cos(2.2 + 0.31 * time) - ]).normalized() * r - ta = ti.Vector([0.0, 0.0, 0.0]) + o3 = tm.normalize( + vec3(r * ti.cos(0.3 + 0.37 * time), + 0.3 + 0.8 * r * ti.cos(1.0 + 0.33 * time), + r * ti.cos(2.2 + 0.31 * time))) * r + ta = vec3(0) cr = 0.1 * ti.cos(0.1 * time) for x, y in self.image: - p = (-ti.Vector([image_res[0], image_res[1]]) + - 2.0 * ti.Vector([x, y])) / (image_res[1] * 0.75) + p = (-tm.vec2(image_res) + 2.0 * tm.vec2(x, y)) / (image_res[1] * + 0.75) - cw = (ta - o3).normalized() - cp = ti.Vector([ti.sin(cr), ti.cos(cr), 0.0]) - cu = cw.cross(cp).normalized() - cv = cu.cross(cw).normalized() + cw = tm.normalize(ta - o3) + cp = vec3(ti.sin(cr), ti.cos(cr), 0) + cu = tm.normalize(cw.cross(cp)) + cv = tm.normalize(cu.cross(cw)) - d3 = (p.x * cu + p.y * cv + 2.0 * cw).normalized() + d3 = tm.normalize(p.x * cu + p.y * cv + 2.0 * cw) - o = ti.Vector([o3.x, o3.y, o3.z, 0.0]) - d = ti.Vector([d3.x, d3.y, d3.z, 0.0]) + o = vec4(o3, 0) + d = vec4(d3, 0) max_t = 10 @@ -174,8 +122,8 @@ def march(self, time_arg: float): break if t < max_t: normal = compute_normal(o + t * d, c) - color = abs(xyz(o + t * d)) / 1.3 - pos = xyz(o + t * d) + color = abs((o + t * d).xyz) / 1.3 + pos = (o + t * d).xyz self.image[x, y] = self.shade(pos, color, normal, o3) else: self.image[x, y] = (0, 0, 0) @@ -195,9 +143,7 @@ def main(): while window.running: frame_id += 1 - canvas.set_image(julia.get_image(frame_id / 60)) - window.show()