Skip to content

Commit

Permalink
Add PPO inference (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
yunjhongwu authored Nov 24, 2023
1 parent 0f5a774 commit 9f2eebc
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
3 changes: 1 addition & 2 deletions burn-rl/src/agent/ppo/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ impl<E: Environment, B: Backend, M: PPOModel<B>> Agent<E> for PPO<E, B, M> {
self.model
.as_ref()
.unwrap()
.forward(to_state_tensor(*state).unsqueeze())
.policies,
.inference(to_state_tensor(*state).unsqueeze()),
)
}
}
Expand Down
4 changes: 3 additions & 1 deletion burn-rl/src/agent/ppo/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ impl<B: Backend> PPOOutput<B> {
}
}

pub trait PPOModel<B: Backend>: Model<B, Tensor<B, 2>, PPOOutput<B>> {}
pub trait PPOModel<B: Backend>: Model<B, Tensor<B, 2>, PPOOutput<B>> {
fn inference(&self, input: Tensor<B, 2>) -> Tensor<B, 2>;
}
7 changes: 6 additions & 1 deletion examples/src/ppo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ impl<B: Backend> Model<B, Tensor<B, 2>, PPOOutput<B>> for Net<B> {
}
}

impl<B: Backend> PPOModel<B> for Net<B> {}
impl<B: Backend> PPOModel<B> for Net<B> {
fn inference(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let layer_0_output = relu(self.linear.forward(input));
softmax(self.linear_actor.forward(layer_0_output.clone()), 1)
}
}

#[allow(unused)]
const MEMORY_SIZE: usize = 512;
Expand Down

0 comments on commit 9f2eebc

Please sign in to comment.