Skip to content

Commit

Permalink
Add additional transformations in Rust. (#179)
Browse files Browse the repository at this point in the history
* Add additional rigid transformations.

* Add reflection methods.

* Add parallelized processing.

* Reorganize geometric methods.

* Add ndarray to series vec.

* Update augmentations.

* Remove print statement.

* Fix formatting.

* Simplify yaw to quat.

* Add python bindings.
  • Loading branch information
benjaminrwilson authored May 5, 2023
1 parent f31b13a commit 0e43d0e
Show file tree
Hide file tree
Showing 11 changed files with 182 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions rust/src/bin/build_accumulated_sweeps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ const MEMORY_MAPPED: bool = false;

static DST_DATASET_NAME: Lazy<String> =
Lazy::new(|| format!("{DATASET_NAME}_{NUM_ACCUMULATED_SWEEPS}_sweep"));
static SRC_PREFIX: Lazy<PathBuf> =
Lazy::new(|| ROOT_DIR.join(DATASET_NAME.clone()).join(DATASET_TYPE));
static SRC_PREFIX: Lazy<PathBuf> = Lazy::new(|| ROOT_DIR.join(DATASET_NAME).join(DATASET_TYPE));
static DST_PREFIX: Lazy<PathBuf> =
Lazy::new(|| ROOT_DIR.join(DST_DATASET_NAME.clone()).join(DATASET_TYPE));

Expand Down
23 changes: 23 additions & 0 deletions rust/src/geometry/augmentations.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//! # augmentations
//!
//! Geometric augmentations.
use std::f32::consts::PI;

use ndarray::{Array, ArrayView, Ix2};

use crate::geometry::so3::{quat_to_yaw, yaw_to_quat};

/// Reflect pose across the x-axis.
pub fn reflect_pose_x(quat_wxyz: &ArrayView<f32, Ix2>) -> Array<f32, Ix2> {
let yaw_rad = quat_to_yaw(quat_wxyz);
let reflected_yaw_rad = -yaw_rad;
yaw_to_quat(&reflected_yaw_rad.view())
}

/// Reflect pose across the y-axis.
pub fn reflect_pose_y(quat_wxyz: &ArrayView<f32, Ix2>) -> Array<f32, Ix2> {
let yaw_rad = quat_to_yaw(quat_wxyz);
let reflected_yaw_rad = PI - yaw_rad;
yaw_to_quat(&reflected_yaw_rad.view())
}
5 changes: 4 additions & 1 deletion rust/src/geometry/camera/pinhole_camera.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use polars::{
prelude::{DataFrame, IntoLazy},
};

use crate::{geometry::utils::cart_to_hom, io::read_feather_eager, se3::SE3, so3::quat_to_mat3};
use crate::{
geometry::se3::SE3, geometry::so3::quat_to_mat3, geometry::utils::cart_to_hom,
io::read_feather_eager,
};

/// Pinhole camera intrinsics.
#[derive(Clone, Debug)]
Expand Down
7 changes: 7 additions & 0 deletions rust/src/geometry/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
//!
//! Geometric operations for data processing.
/// Geometric augmentations.
pub mod augmentations;
/// Camera models.
pub mod camera;
/// Special Euclidean Group 3.
pub mod se3;
/// Special Orthogonal Group 3.
pub mod so3;
/// Geometric utility functions.
pub mod utils;
File renamed without changes.
73 changes: 73 additions & 0 deletions rust/src/geometry/so3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//! # SO(3)
//!
//! Special Orthogonal Group 3 (SO(3)).
use ndarray::{par_azip, Array, Array2, ArrayView, Ix1, Ix2};

/// Convert a quaternion in scalar-first format to a 3x3 rotation matrix.
pub fn quat_to_mat3(quat_wxyz: &ArrayView<f32, Ix1>) -> Array<f32, Ix2> {
let w = quat_wxyz[0];
let x = quat_wxyz[1];
let y = quat_wxyz[2];
let z = quat_wxyz[3];

let e_00 = 1. - 2. * y.powi(2) - 2. * z.powi(2);
let e_01: f32 = 2. * x * y - 2. * z * w;
let e_02: f32 = 2. * x * z + 2. * y * w;

let e_10 = 2. * x * y + 2. * z * w;
let e_11 = 1. - 2. * x.powi(2) - 2. * z.powi(2);
let e_12 = 2. * y * z - 2. * x * w;

let e_20 = 2. * x * z - 2. * y * w;
let e_21 = 2. * y * z + 2. * x * w;
let e_22 = 1. - 2. * x.powi(2) - 2. * y.powi(2);

// Safety: We will always have nine elements.
unsafe {
Array2::from_shape_vec_unchecked(
[3, 3],
vec![e_00, e_01, e_02, e_10, e_11, e_12, e_20, e_21, e_22],
)
}
}

/// Convert a scalar-first quaternion to yaw.
/// In the Argoverse 2 coordinate system, this is counter-clockwise rotation about the +z axis.
/// Parallelized for batch processing.
pub fn quat_to_yaw(quat_wxyz: &ArrayView<f32, Ix2>) -> Array<f32, Ix2> {
let num_quats = quat_wxyz.shape()[0];
let mut yaws_rad = Array::<f32, Ix2>::zeros((num_quats, 1));
par_azip!((mut y in yaws_rad.outer_iter_mut(), q in quat_wxyz.outer_iter()) {
y[0] = _quat_to_yaw(&q);
});
yaws_rad
}

/// Convert a scalar-first quaternion to yaw.
/// In the Argoverse 2 coordinate system, this is counter-clockwise rotation about the +z axis.
pub fn _quat_to_yaw(quat_wxyz: &ArrayView<f32, Ix1>) -> f32 {
let (qw, qx, qy, qz) = (quat_wxyz[0], quat_wxyz[1], quat_wxyz[2], quat_wxyz[3]);
let siny_cosp = 2. * (qw * qz + qx * qy);
let cosy_cosp = 1. - 2. * (qy * qy + qz * qz);
siny_cosp.atan2(cosy_cosp)
}

/// Convert a scalar-first quaternion to yaw.
/// In the Argoverse 2 coordinate system, this is counter-clockwise rotation about the +z axis.
/// Parallelized for batch processing.
pub fn yaw_to_quat(yaw_rad: &ArrayView<f32, Ix2>) -> Array<f32, Ix2> {
let num_yaws = yaw_rad.shape()[0];
let mut quat_wxyz = Array::<f32, Ix2>::zeros((num_yaws, 4));
par_azip!((mut q in quat_wxyz.outer_iter_mut(), y in yaw_rad.outer_iter()) {
q.assign(&_yaw_to_quat(y[0]));
});
quat_wxyz
}

/// Convert rotation about the z-axis to a scalar-first quaternion.
pub fn _yaw_to_quat(yaw_rad: f32) -> Array<f32, Ix1> {
let qw = f32::cos(0.5 * yaw_rad);
let qz = f32::sin(0.5 * yaw_rad);
Array::<f32, Ix1>::from_vec(vec![qw, 0.0, 0.0, qz])
}
4 changes: 2 additions & 2 deletions rust/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ use std::fs::File;
use std::path::PathBuf;

use crate::constants::POSE_COLUMNS;
use crate::se3::SE3;
use crate::geometry::se3::SE3;
use image::io::Reader as ImageReader;

use crate::so3::quat_to_mat3;
use crate::geometry::so3::quat_to_mat3;

/// Read a feather file and load into a `polars` dataframe.
pub fn read_feather_eager(path: &PathBuf, memory_mapped: bool) -> DataFrame {
Expand Down
50 changes: 47 additions & 3 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@ pub mod geometry;
pub mod io;
pub mod ops;
pub mod path;
pub mod se3;
pub mod so3;
pub mod share;
pub mod structures;

use data_loader::{DataLoader, Sweep};
use geometry::augmentations::{reflect_pose_x, reflect_pose_y};
use ndarray::{Dim, Ix1, Ix2};
use numpy::PyReadonlyArray;
use numpy::{IntoPyArray, PyArray};
use pyo3::prelude::*;

use geometry::so3::{quat_to_mat3, quat_to_yaw, yaw_to_quat};
use numpy::PyReadonlyArray2;
use so3::quat_to_mat3;

use crate::ops::voxelize;

Expand Down Expand Up @@ -65,12 +65,56 @@ fn py_quat_to_mat3<'py>(
quat_to_mat3(&quat_wxyz.as_array().view()).into_pyarray(py)
}

#[pyfunction]
#[pyo3(name = "quat_to_yaw")]
#[allow(clippy::type_complexity)]
fn py_quat_to_yaw<'py>(
py: Python<'py>,
quat_wxyz: PyReadonlyArray<f32, Ix2>,
) -> &'py PyArray<f32, Ix2> {
quat_to_yaw(&quat_wxyz.as_array().view()).into_pyarray(py)
}

#[pyfunction]
#[pyo3(name = "yaw_to_quat")]
#[allow(clippy::type_complexity)]
fn py_yaw_to_quat<'py>(
py: Python<'py>,
quat_wxyz: PyReadonlyArray<f32, Ix2>,
) -> &'py PyArray<f32, Ix2> {
yaw_to_quat(&quat_wxyz.as_array().view()).into_pyarray(py)
}

#[pyfunction]
#[pyo3(name = "reflect_pose_x")]
#[allow(clippy::type_complexity)]
fn py_reflect_pose_x<'py>(
py: Python<'py>,
quat_wxyz: PyReadonlyArray<f32, Ix2>,
) -> &'py PyArray<f32, Ix2> {
reflect_pose_x(&quat_wxyz.as_array().view()).into_pyarray(py)
}

#[pyfunction]
#[pyo3(name = "reflect_pose_y")]
#[allow(clippy::type_complexity)]
fn py_reflect_pose_y<'py>(
py: Python<'py>,
quat_wxyz: PyReadonlyArray<f32, Ix2>,
) -> &'py PyArray<f32, Ix2> {
reflect_pose_y(&quat_wxyz.as_array().view()).into_pyarray(py)
}

/// A Python module implemented in Rust.
#[pymodule]
fn _r(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<DataLoader>()?;
m.add_class::<Sweep>()?;
m.add_function(wrap_pyfunction!(py_quat_to_mat3, m)?)?;
m.add_function(wrap_pyfunction!(py_quat_to_yaw, m)?)?;
m.add_function(wrap_pyfunction!(py_reflect_pose_x, m)?)?;
m.add_function(wrap_pyfunction!(py_reflect_pose_y, m)?)?;
m.add_function(wrap_pyfunction!(py_voxelize, m)?)?;
m.add_function(wrap_pyfunction!(py_yaw_to_quat, m)?)?;
Ok(())
}
24 changes: 24 additions & 0 deletions rust/src/share.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//! # share
//!
//! Conversion methods between different libraries.
use ndarray::{Array, Ix2};
use polars::{prelude::NamedFrom, series::Series};

/// Convert the columns of an `ndarray::Array` into a vector of `polars::series::Series`.
pub fn ndarray_to_series_vec(arr: Array<f32, Ix2>, column_names: Vec<&str>) -> Vec<Series> {
let num_dims = arr.shape()[1];
if num_dims != column_names.len() {
panic!("Number of array columns and column names must match.");
}

let mut series_vec = vec![];
for (column, column_name) in arr.columns().into_iter().zip(column_names) {
let series = Series::new(
column_name,
column.as_standard_layout().to_owned().into_raw_vec(),
);
series_vec.push(series);
}
series_vec
}
33 changes: 0 additions & 33 deletions rust/src/so3.rs

This file was deleted.

0 comments on commit 0e43d0e

Please sign in to comment.