From 01123d38cb8a40f216a34077bfe35877d0fa7b22 Mon Sep 17 00:00:00 2001 From: Eric Heiden Date: Fri, 22 Oct 2021 14:34:18 -0700 Subject: [PATCH] Unset requires_grad for state variables, fix render_usd example --- dflex/model.py | 20 ++++++++++---------- examples/render_usd.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/dflex/model.py b/dflex/model.py index 7bbe874..cb49759 100644 --- a/dflex/model.py +++ b/dflex/model.py @@ -348,23 +348,23 @@ def state(self) -> State: # derived state (output only) if (self.particle_count): - s.particle_f = torch.empty_like(self.particle_qd, requires_grad=True) + s.particle_f = torch.empty_like(self.particle_qd, requires_grad=False) if (self.link_count): # joints - s.joint_qdd = torch.zeros_like(self.joint_qd, requires_grad=True) - s.joint_tau = torch.zeros_like(self.joint_qd, requires_grad=True) - s.joint_S_s = torch.empty((self.joint_dof_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=True) + s.joint_qdd = torch.zeros_like(self.joint_qd, requires_grad=False) + s.joint_tau = torch.zeros_like(self.joint_qd, requires_grad=False) + s.joint_S_s = torch.empty((self.joint_dof_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=False) # derived rigid body data (maximal coordinates) - s.body_X_sc = torch.empty((self.link_count, 7), dtype=torch.float32, device=self.adapter, requires_grad=True) - s.body_X_sm = torch.empty((self.link_count, 7), dtype=torch.float32, device=self.adapter, requires_grad=True) - s.body_I_s = torch.empty((self.link_count, 6, 6), dtype=torch.float32, device=self.adapter, requires_grad=True) - s.body_v_s = torch.empty((self.link_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=True) - s.body_a_s = torch.empty((self.link_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=True) - s.body_f_s = torch.zeros((self.link_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=True) + s.body_X_sc = torch.empty((self.link_count, 7), dtype=torch.float32, device=self.adapter, requires_grad=False) + s.body_X_sm = torch.empty((self.link_count, 7), dtype=torch.float32, device=self.adapter, requires_grad=False) + s.body_I_s = torch.empty((self.link_count, 6, 6), dtype=torch.float32, device=self.adapter, requires_grad=False) + s.body_v_s = torch.empty((self.link_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=False) + s.body_a_s = torch.empty((self.link_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=False) + s.body_f_s = torch.zeros((self.link_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=False) #s.body_ft_s = torch.zeros((self.link_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=False) #s.body_f_ext_s = torch.zeros((self.link_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=False) diff --git a/examples/render_usd.py b/examples/render_usd.py index 0c67396..0d6227c 100644 --- a/examples/render_usd.py +++ b/examples/render_usd.py @@ -27,7 +27,7 @@ # fmt: on settings = load_settings("examples/config/ansys_cylinder_jello.json") -settings.sim_duration = 0.2 +settings.sim_duration = 2.2 settings.sim_substeps = 100 settings.sim_dt = 5e-5 settings.initial_y = 0.08