Skip to content

Commit

Permalink
Add common augmentations. (#180)
Browse files Browse the repository at this point in the history
* Restructure augmentation code.

* Add python bindings.

* Clean up augmentations.

* Remove comment.

* Remove dead code.

* Update var name.

* Fix clippy.

* Update comments.

* Remove print statement.

* Fix duplicate transform.

* Improve augmentations.

* Fix docstrings.
  • Loading branch information
benjaminrwilson authored May 7, 2023
1 parent 0e43d0e commit 2e75c52
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 42 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ pyo3 = { version = "0.18.3", features = ["extension-module"] }
pyo3-polars = { git = "https://github.com/benjaminrwilson/pyo3-polars", rev = "993d22a6bf54ceb8f93c5ed9082621330b186a52", features = [
"serde",
] }
rand = "0.8.5"
rand_distr = "0.4.3"
rayon = "1.7.0"
serde = "1.0.160"
strum = "0.24.1"
Expand Down
121 changes: 108 additions & 13 deletions rust/src/geometry/augmentations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,117 @@
//!
//! Geometric augmentations.
use std::f32::consts::PI;
use itertools::Itertools;
use ndarray::{concatenate, Axis};
use polars::{
lazy::dsl::{col, GetOutput},
prelude::{DataFrame, DataType, IntoLazy},
series::Series,
};
use rand_distr::{Bernoulli, Distribution};

use ndarray::{Array, ArrayView, Ix2};
use crate::share::{data_frame_to_ndarray_f32, ndarray_to_expr_vec};

use crate::geometry::so3::{quat_to_yaw, yaw_to_quat};
use super::so3::{
reflect_orientation_x, reflect_orientation_y, reflect_translation_x, reflect_translation_y,
};

/// Reflect pose across the x-axis.
pub fn reflect_pose_x(quat_wxyz: &ArrayView<f32, Ix2>) -> Array<f32, Ix2> {
let yaw_rad = quat_to_yaw(quat_wxyz);
let reflected_yaw_rad = -yaw_rad;
yaw_to_quat(&reflected_yaw_rad.view())
/// Sample a scene reflection.
/// This reflects both a point cloud and cuboids across the x-axis.
pub fn sample_scene_reflection_x(
lidar: DataFrame,
cuboids: DataFrame,
p: f64,
) -> (DataFrame, DataFrame) {
let distribution = Bernoulli::new(p).unwrap();
let is_augmented = distribution.sample(&mut rand::thread_rng());
if is_augmented {
let augmented_lidar = lidar
.lazy()
.with_column(col("y").map(
move |x| {
Ok(Some(
x.f32()
.unwrap()
.into_no_null_iter()
.map(|y| -y)
.collect::<Series>(),
))
},
GetOutput::from_type(DataType::Float32),
))
.collect()
.unwrap();

let translation_column_names = vec!["tx_m", "ty_m", "tz_m"];
let txyz_m = data_frame_to_ndarray_f32(cuboids.clone(), translation_column_names.clone());
let augmentation_translation = reflect_translation_x(&txyz_m.view());

let orientation_column_names = vec!["qw", "qx", "qy", "qz"];
let quat_wxyz =
data_frame_to_ndarray_f32(cuboids.clone(), orientation_column_names.clone());
let augmented_orientation = reflect_orientation_x(&quat_wxyz.view());
let augmented_poses =
concatenate![Axis(1), augmentation_translation, augmented_orientation];

let column_names = translation_column_names
.into_iter()
.chain(orientation_column_names)
.collect_vec();
let series_vec = ndarray_to_expr_vec(augmented_poses, column_names);
let augmented_cuboids = cuboids.lazy().with_columns(series_vec).collect().unwrap();
(augmented_lidar, augmented_cuboids)
} else {
(lidar, cuboids)
}
}

/// Reflect pose across the y-axis.
pub fn reflect_pose_y(quat_wxyz: &ArrayView<f32, Ix2>) -> Array<f32, Ix2> {
let yaw_rad = quat_to_yaw(quat_wxyz);
let reflected_yaw_rad = PI - yaw_rad;
yaw_to_quat(&reflected_yaw_rad.view())
/// Sample a scene reflection.
/// This reflects both a point cloud and cuboids across the y-axis.
pub fn sample_scene_reflection_y(
lidar: DataFrame,
cuboids: DataFrame,
p: f64,
) -> (DataFrame, DataFrame) {
let distribution: Bernoulli = Bernoulli::new(p).unwrap();
let is_augmented = distribution.sample(&mut rand::thread_rng());
if is_augmented {
let augmented_lidar = lidar
.lazy()
.with_column(col("x").map(
move |x| {
Ok(Some(
x.f32()
.unwrap()
.into_no_null_iter()
.map(|x| -x)
.collect::<Series>(),
))
},
GetOutput::from_type(DataType::Float32),
))
.collect()
.unwrap();

let translation_column_names = vec!["tx_m", "ty_m", "tz_m"];
let txyz_m = data_frame_to_ndarray_f32(cuboids.clone(), translation_column_names.clone());
let augmentation_translation = reflect_translation_y(&txyz_m.view());

let orientation_column_names = vec!["qw", "qx", "qy", "qz"];
let quat_wxyz =
data_frame_to_ndarray_f32(cuboids.clone(), orientation_column_names.clone());
let augmented_orientation = reflect_orientation_y(&quat_wxyz.view());
let augmented_poses =
concatenate![Axis(1), augmentation_translation, augmented_orientation];

let column_names = translation_column_names
.into_iter()
.chain(orientation_column_names)
.collect_vec();
let series_vec = ndarray_to_expr_vec(augmented_poses, column_names);
let augmented_cuboids = cuboids.lazy().with_columns(series_vec).collect().unwrap();
(augmented_lidar, augmented_cuboids)
} else {
(lidar, cuboids)
}
}
3 changes: 2 additions & 1 deletion rust/src/geometry/se3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
//!
//! Special Euclidean Group 3.
use ndarray::{s, Array1, Array2, ArrayView2};
use ndarray::ArrayView2;
use ndarray::{s, Array1, Array2};

/// Special Euclidean Group 3 (SE(3)).
/// Rigid transformation parameterized by a rotation and translation in $R^3$.
Expand Down
38 changes: 37 additions & 1 deletion rust/src/geometry/so3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
//!
//! Special Orthogonal Group 3 (SO(3)).
use ndarray::{par_azip, Array, Array2, ArrayView, Ix1, Ix2};
use std::f32::consts::PI;

use ndarray::{par_azip, s, Array, Array2, ArrayView, Ix1, Ix2};

/// Convert a quaternion in scalar-first format to a 3x3 rotation matrix.
pub fn quat_to_mat3(quat_wxyz: &ArrayView<f32, Ix1>) -> Array<f32, Ix2> {
Expand Down Expand Up @@ -71,3 +73,37 @@ pub fn _yaw_to_quat(yaw_rad: f32) -> Array<f32, Ix1> {
let qz = f32::sin(0.5 * yaw_rad);
Array::<f32, Ix1>::from_vec(vec![qw, 0.0, 0.0, qz])
}

/// Reflect orientation across the x-axis.
/// (N,4) `quat_wxyz` orientation of `N` rigid objects.
pub fn reflect_orientation_x(quat_wxyz: &ArrayView<f32, Ix2>) -> Array<f32, Ix2> {
let yaw_rad = quat_to_yaw(quat_wxyz);
let reflected_yaw_rad = -yaw_rad;
yaw_to_quat(&reflected_yaw_rad.view())
}

/// Reflect orientation across the y-axis.
/// (N,4) `quat_wxyz` orientation of `N` rigid objects.
pub fn reflect_orientation_y(quat_wxyz: &ArrayView<f32, Ix2>) -> Array<f32, Ix2> {
let yaw_rad = quat_to_yaw(quat_wxyz);
let reflected_yaw_rad = PI - yaw_rad;
yaw_to_quat(&reflected_yaw_rad.view())
}

/// Reflect translation across the x-axis.
pub fn reflect_translation_x(xyz_m: &ArrayView<f32, Ix2>) -> Array<f32, Ix2> {
let mut augmented_xyz_m = xyz_m.to_owned();
augmented_xyz_m
.slice_mut(s![.., 1])
.par_mapv_inplace(|y| -y);
augmented_xyz_m
}

/// Reflect translation across the y-axis.
pub fn reflect_translation_y(xyz_m: &ArrayView<f32, Ix2>) -> Array<f32, Ix2> {
let mut augmented_xyz_m = xyz_m.to_owned();
augmented_xyz_m
.slice_mut(s![.., 0])
.par_mapv_inplace(|x| -x);
augmented_xyz_m
}
23 changes: 0 additions & 23 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ pub mod share;
pub mod structures;

use data_loader::{DataLoader, Sweep};
use geometry::augmentations::{reflect_pose_x, reflect_pose_y};
use ndarray::{Dim, Ix1, Ix2};
use numpy::PyReadonlyArray;
use numpy::{IntoPyArray, PyArray};
Expand Down Expand Up @@ -85,35 +84,13 @@ fn py_yaw_to_quat<'py>(
yaw_to_quat(&quat_wxyz.as_array().view()).into_pyarray(py)
}

#[pyfunction]
#[pyo3(name = "reflect_pose_x")]
#[allow(clippy::type_complexity)]
fn py_reflect_pose_x<'py>(
py: Python<'py>,
quat_wxyz: PyReadonlyArray<f32, Ix2>,
) -> &'py PyArray<f32, Ix2> {
reflect_pose_x(&quat_wxyz.as_array().view()).into_pyarray(py)
}

#[pyfunction]
#[pyo3(name = "reflect_pose_y")]
#[allow(clippy::type_complexity)]
fn py_reflect_pose_y<'py>(
py: Python<'py>,
quat_wxyz: PyReadonlyArray<f32, Ix2>,
) -> &'py PyArray<f32, Ix2> {
reflect_pose_y(&quat_wxyz.as_array().view()).into_pyarray(py)
}

/// A Python module implemented in Rust.
#[pymodule]
fn _r(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<DataLoader>()?;
m.add_class::<Sweep>()?;
m.add_function(wrap_pyfunction!(py_quat_to_mat3, m)?)?;
m.add_function(wrap_pyfunction!(py_quat_to_yaw, m)?)?;
m.add_function(wrap_pyfunction!(py_reflect_pose_x, m)?)?;
m.add_function(wrap_pyfunction!(py_reflect_pose_y, m)?)?;
m.add_function(wrap_pyfunction!(py_voxelize, m)?)?;
m.add_function(wrap_pyfunction!(py_yaw_to_quat, m)?)?;
Ok(())
Expand Down
26 changes: 22 additions & 4 deletions rust/src/share.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
//! Conversion methods between different libraries.
use ndarray::{Array, Ix2};
use polars::{prelude::NamedFrom, series::Series};
use polars::{
lazy::dsl::{cols, lit, Expr},
prelude::{DataFrame, Float32Type, IntoLazy, NamedFrom},
series::Series,
};

/// Convert the columns of an `ndarray::Array` into a vector of `polars::series::Series`.
pub fn ndarray_to_series_vec(arr: Array<f32, Ix2>, column_names: Vec<&str>) -> Vec<Series> {
/// Convert the columns of an `ndarray::Array` into a vector of `polars` expressions.
pub fn ndarray_to_expr_vec(arr: Array<f32, Ix2>, column_names: Vec<&str>) -> Vec<Expr> {
let num_dims = arr.shape()[1];
if num_dims != column_names.len() {
panic!("Number of array columns and column names must match.");
Expand All @@ -18,7 +22,21 @@ pub fn ndarray_to_series_vec(arr: Array<f32, Ix2>, column_names: Vec<&str>) -> V
column_name,
column.as_standard_layout().to_owned().into_raw_vec(),
);
series_vec.push(series);
series_vec.push(lit(series));
}
series_vec
}

/// Convert a data frame to an `ndarray::Array::<f32, Ix2>`.
pub fn data_frame_to_ndarray_f32(
data_frame: DataFrame,
column_names: Vec<&str>,
) -> Array<f32, Ix2> {
data_frame
.lazy()
.select(&[cols(column_names)])
.collect()
.unwrap()
.to_ndarray::<Float32Type>()
.unwrap()
}

0 comments on commit 2e75c52

Please sign in to comment.