Skip to content

Commit

Permalink
Bump burn to 0.12
Browse files Browse the repository at this point in the history
  • Loading branch information
yunjhongwu committed Feb 1, 2024
1 parent 0f2de6e commit 2dfe08c
Show file tree
Hide file tree
Showing 12 changed files with 29 additions and 27 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ members = [
]

[workspace.dependencies]
-serde = "1.0.192"
+serde = "1"
2 changes: 1 addition & 1 deletion burn-rl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ publish = false

[dependencies]
rand = "0.8.5"
-burn = { version = "0.11.1", features = ["ndarray", "autodiff"] }
+burn = { version = "0.12.0", features = ["ndarray", "autodiff"] }
gym-rs = { git = "https://github.com/MathisWellmann/gym-rs.git" }
ringbuffer = "0.15.0"
serde = { workspace = true }
6 changes: 4 additions & 2 deletions burn-rl/src/agent/ppo/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ impl<E: Environment, B: AutodiffBackend, M: PPOModel<B> + AutodiffModule<B>> PPO
.map(|x| *x as i32)
.collect::<Vec<_>>()
.as_slice(),
+                &Default::default(),
);

let state_batch =
Expand Down Expand Up @@ -199,7 +200,8 @@ pub(crate) fn get_gae<B: Backend>(
}

Some(GAEOutput::new(
-        Tensor::from_floats(returns.as_slice()).reshape([returns.len(), 1]),
-        Tensor::from_floats(advantages.as_slice()).reshape([advantages.len(), 1]),
+        Tensor::from_floats(returns.as_slice(), &Default::default()).reshape([returns.len(), 1]),
+        Tensor::from_floats(advantages.as_slice(), &Default::default())
+            .reshape([advantages.len(), 1]),
))
}
2 changes: 1 addition & 1 deletion burn-rl/src/agent/sac/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub struct SACTemperature<B: Backend> {
impl<B: Backend> Default for SACTemperature<B> {
fn default() -> Self {
Self {
-            temperature: Param::from(Tensor::zeros([1, 1])),
+            temperature: Param::from(Tensor::zeros([1, 1], &Default::default())),
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions burn-rl/src/base/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ mod tests {
type Data = [ElemType; 2];

fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
-            Tensor::<B, 1>::from_floats(self.data)
+            Tensor::<B, 1>::from_floats(self.data, &Default::default())
}
fn size() -> usize {
1
Expand Down Expand Up @@ -217,7 +217,7 @@ mod tests {
});
assert_eq!(state_batch.shape(), Shape::new([5, 2]));
let state_sample = state_batch
-            .select(0, Tensor::from_ints([0, 1]))
+            .select(0, Tensor::from_ints([0, 1], &Default::default()))
.to_data()
.value;
assert_eq!(state_sample[0] * 2.0, state_sample[1]);
Expand All @@ -226,7 +226,7 @@ mod tests {
get_batch(memory.next_states(), &sample_indices, ref_to_state_tensor);
assert_eq!(next_state_batch.shape(), Shape::new([5, 2]));
let next_state_sample = next_state_batch
-            .select(0, Tensor::from_ints([0, 1]))
+            .select(0, Tensor::from_ints([0, 1], &Default::default()))
.to_data()
.value;
assert_eq!(next_state_sample[0] * -3.0, next_state_sample[1]);
Expand Down
2 changes: 1 addition & 1 deletion burn-rl/src/environment/cart_pole.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl From<CartPoleObservation> for CartPoleState {
impl State for CartPoleState {
type Data = StateData;
fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
-        Tensor::<B, 1>::from_floats(self.data)
+        Tensor::<B, 1>::from_floats(self.data, &Default::default())
}

fn size() -> usize {
Expand Down
2 changes: 1 addition & 1 deletion burn-rl/src/environment/mountain_car.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl From<MountainCarObservation> for MountainCarState {
impl State for MountainCarState {
type Data = StateData;
fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
-        Tensor::<B, 1>::from_floats(self.data)
+        Tensor::<B, 1>::from_floats(self.data, &Default::default())
}

fn size() -> usize {
Expand Down
6 changes: 3 additions & 3 deletions burn-rl/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ pub(crate) fn convert_tenor_to_action<A: Action, B: Backend>(output: Tensor<B, 2
}

pub(crate) fn to_action_tensor<A: Action, B: Backend>(action: A) -> Tensor<B, 1, Int> {
-    Tensor::<B, 1, Int>::from_ints([action.into() as i32])
+    Tensor::<B, 1, Int>::from_ints([action.into() as i32], &Default::default())
}

pub(crate) fn ref_to_action_tensor<A: Action, B: Backend>(action: &A) -> Tensor<B, 1, Int> {
to_action_tensor(*action)
}

pub(crate) fn to_reward_tensor<B: Backend>(reward: impl Into<ElemType> + Clone) -> Tensor<B, 1> {
-    Tensor::from_floats([reward.into()])
+    Tensor::from_floats([reward.into()], &Default::default())
}

pub(crate) fn ref_to_reward_tensor<B: Backend>(
Expand All @@ -45,7 +45,7 @@ pub(crate) fn ref_to_reward_tensor<B: Backend>(
to_reward_tensor(reward.clone())
}
pub(crate) fn to_not_done_tensor<B: Backend>(done: bool) -> Tensor<B, 1> {
-    Tensor::from_floats([if done { 0.0 } else { 1.0 }])
+    Tensor::from_floats([if done { 0.0 } else { 1.0 }], &Default::default())
}

pub(crate) fn ref_to_not_done_tensor<B: Backend>(done: &bool) -> Tensor<B, 1> {
Expand Down
4 changes: 2 additions & 2 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ edition = "2021"
publish = false

[dependencies]
-burn = { version = "0.11.1", features = ["ndarray", "autodiff"] }
-burn-autodiff = "0.11.1"
+burn = { version = "0.12.0", features = ["ndarray", "autodiff"] }
+burn-autodiff = "0.12.0"
serde = { workspace = true }

burn-rl = { path = "../burn-rl" }
Expand Down
6 changes: 3 additions & 3 deletions examples/src/dqn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ impl<B: Backend> Net<B> {
#[allow(unused)]
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
Self {
-            linear_0: LinearConfig::new(input_size, dense_size).init(),
-            linear_1: LinearConfig::new(dense_size, dense_size).init(),
-            linear_2: LinearConfig::new(dense_size, output_size).init(),
+            linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
+            linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
+            linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions examples/src/ppo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ impl<B: Backend> Net<B> {
Self {
linear: LinearConfig::new(input_size, dense_size)
.with_initializer(initializer.clone())
-                .init(),
+                .init(&Default::default()),
linear_actor: LinearConfig::new(dense_size, output_size)
.with_initializer(initializer.clone())
-                .init(),
+                .init(&Default::default()),
linear_critic: LinearConfig::new(dense_size, 1)
.with_initializer(initializer)
-                .init(),
+                .init(&Default::default()),
}
}
}
Expand Down
12 changes: 6 additions & 6 deletions examples/src/sac.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ pub struct Actor<B: Backend> {
impl<B: Backend> Actor<B> {
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
Self {
-            linear_0: LinearConfig::new(input_size, dense_size).init(),
-            linear_1: LinearConfig::new(dense_size, dense_size).init(),
-            linear_2: LinearConfig::new(dense_size, output_size).init(),
+            linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
+            linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
+            linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
}
}
}
Expand Down Expand Up @@ -50,9 +50,9 @@ pub struct Critic<B: Backend> {
impl<B: Backend> Critic<B> {
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
Self {
-            linear_0: LinearConfig::new(input_size, dense_size).init(),
-            linear_1: LinearConfig::new(dense_size, dense_size).init(),
-            linear_2: LinearConfig::new(dense_size, output_size).init(),
+            linear_0: LinearConfig::new(input_size, dense_size).init(&Default::default()),
+            linear_1: LinearConfig::new(dense_size, dense_size).init(&Default::default()),
+            linear_2: LinearConfig::new(dense_size, output_size).init(&Default::default()),
}
}
}
Expand Down

0 comments on commit 2dfe08c

Please sign in to comment.