From 9f2eebce94dbfcff14a645d542cf3de7cd2d6cc8 Mon Sep 17 00:00:00 2001
From: Yun-Jhong Wu
Date: Fri, 24 Nov 2023 10:59:09 -0600
Subject: [PATCH] Add PPO inference (#30)

---
 burn-rl/src/agent/ppo/agent.rs | 3 +--
 burn-rl/src/agent/ppo/model.rs | 4 +++-
 examples/src/ppo.rs | 7 ++++++-
 3 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/burn-rl/src/agent/ppo/agent.rs b/burn-rl/src/agent/ppo/agent.rs
index bd982e1..b498a2e 100644
--- a/burn-rl/src/agent/ppo/agent.rs
+++ b/burn-rl/src/agent/ppo/agent.rs
@@ -25,8 +25,7 @@ impl> Agent for PPO {
 self.model
 .as_ref()
 .unwrap()
- .forward(to_state_tensor(*state).unsqueeze())
- .policies,
+ .inference(to_state_tensor(*state).unsqueeze()),
 )
 }
 }
diff --git a/burn-rl/src/agent/ppo/model.rs b/burn-rl/src/agent/ppo/model.rs
index 57cf0c2..00c3f36 100644
--- a/burn-rl/src/agent/ppo/model.rs
+++ b/burn-rl/src/agent/ppo/model.rs
@@ -13,4 +13,6 @@ impl PPOOutput {
 }
 }
 
-pub trait PPOModel: Model, PPOOutput> {}
+pub trait PPOModel: Model, PPOOutput> {
+ fn inference(&self, input: Tensor) -> Tensor;
+}
diff --git a/examples/src/ppo.rs b/examples/src/ppo.rs
index 773d370..3133a48 100644
--- a/examples/src/ppo.rs
+++ b/examples/src/ppo.rs
@@ -42,7 +42,12 @@ impl Model, PPOOutput> for Net {
 }
 }
 
-impl PPOModel for Net {}
+impl PPOModel for Net {
+ fn inference(&self, input: Tensor) -> Tensor {
+ let layer_0_output = relu(self.linear.forward(input));
+ softmax(self.linear_actor.forward(layer_0_output.clone()), 1)
+ }
+}
 
 #[allow(unused)]
 const MEMORY_SIZE: usize = 512;