Skip to content

Commit

Permalink
models.py transfer of grasp weights to place
Browse files Browse the repository at this point in the history
  • Loading branch information
ahundt committed Sep 6, 2019
1 parent 04cc3bd commit 9ffcfda
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,11 @@ def transfer_grasp_to_place(self):
if self.network == 'densenet' or efficientnet_pytorch is None:
# placenet tests block stacking
if self.place:
self.place_color_trunk.load_state_dict(grasp_color_trunk.state_dict())
self.place_depth_trunk.load_state_dict(grasp_depth_trunk.state_dict())
self.place_color_trunk.load_state_dict(self.grasp_color_trunk.state_dict())
self.place_depth_trunk.load_state_dict(self.grasp_depth_trunk.state_dict())
fc_channels = 2048
second_fc_channels = 64
# The push and place efficientnet model is shared, so we don't need to load that.
if place:
self.placenet.load_state_dict(self.graspnet.state_dict())
# The push and place efficientnet model is shared, so we don't need to transfer that.
if self.place:
# we rename the dictionary names of the grasp weights to place, then load them into the placenet
self.placenet.load_state_dict(dict(map(lambda t: (t[0].replace('grasp', 'place'), t[1]), self.graspnet.state_dict().items())))

0 comments on commit 9ffcfda

Please sign in to comment.