Skip to content

Commit

Permalink
Unset requires_grad for state variables, fix render_usd example
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-heiden committed Oct 22, 2021
1 parent af89116 commit 01123d3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
20 changes: 10 additions & 10 deletions dflex/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,23 +348,23 @@ def state(self) -> State:
# derived state (output only)

if (self.particle_count):
s.particle_f = torch.empty_like(self.particle_qd, requires_grad=True)
s.particle_f = torch.empty_like(self.particle_qd, requires_grad=False)


if (self.link_count):

# joints
s.joint_qdd = torch.zeros_like(self.joint_qd, requires_grad=True)
s.joint_tau = torch.zeros_like(self.joint_qd, requires_grad=True)
s.joint_S_s = torch.empty((self.joint_dof_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=True)
s.joint_qdd = torch.zeros_like(self.joint_qd, requires_grad=False)
s.joint_tau = torch.zeros_like(self.joint_qd, requires_grad=False)
s.joint_S_s = torch.empty((self.joint_dof_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=False)

# derived rigid body data (maximal coordinates)
s.body_X_sc = torch.empty((self.link_count, 7), dtype=torch.float32, device=self.adapter, requires_grad=True)
s.body_X_sm = torch.empty((self.link_count, 7), dtype=torch.float32, device=self.adapter, requires_grad=True)
s.body_I_s = torch.empty((self.link_count, 6, 6), dtype=torch.float32, device=self.adapter, requires_grad=True)
s.body_v_s = torch.empty((self.link_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=True)
s.body_a_s = torch.empty((self.link_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=True)
s.body_f_s = torch.zeros((self.link_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=True)
s.body_X_sc = torch.empty((self.link_count, 7), dtype=torch.float32, device=self.adapter, requires_grad=False)
s.body_X_sm = torch.empty((self.link_count, 7), dtype=torch.float32, device=self.adapter, requires_grad=False)
s.body_I_s = torch.empty((self.link_count, 6, 6), dtype=torch.float32, device=self.adapter, requires_grad=False)
s.body_v_s = torch.empty((self.link_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=False)
s.body_a_s = torch.empty((self.link_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=False)
s.body_f_s = torch.zeros((self.link_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=False)
#s.body_ft_s = torch.zeros((self.link_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=False)
#s.body_f_ext_s = torch.zeros((self.link_count, 6), dtype=torch.float32, device=self.adapter, requires_grad=False)

Expand Down
2 changes: 1 addition & 1 deletion examples/render_usd.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# fmt: on

settings = load_settings("examples/config/ansys_cylinder_jello.json")
settings.sim_duration = 0.2
settings.sim_duration = 2.2
settings.sim_substeps = 100
settings.sim_dt = 5e-5
settings.initial_y = 0.08
Expand Down

0 comments on commit 01123d3

Please sign in to comment.