Skip to content

Commit

Permalink
custom parser
Browse files Browse the repository at this point in the history
  • Loading branch information
dagou committed Sep 7, 2024
1 parent b9cbf77 commit a2d15f5
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 29 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "gtdb_tree"
version = "0.1.8"
version = "0.1.9"
edition = "2021"
description = "A library for parsing Newick format files, especially GTDB tree files."
homepage = "https://github.com/eric9n/gtdb_tree"
Expand Down
26 changes: 25 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Add this crate to your `Cargo.toml`:

```toml
[dependencies]
gtdb_tree = "0.1.0"
gtdb_tree = "0.1.9"
```

## Usage
Expand Down Expand Up @@ -48,3 +48,27 @@ result = gtdb_tree.parse_tree("((A:0.1,B:0.2):0.3,C:0.4);")
print(result)
```

## Advanced Usage
### Custom Node Parser
You can provide a custom parser function to handle special node formats:

```python
import gtdb_tree

def custom_parser(node_str):
# Custom parsing logic
name, length = node_str.split(':')
return name, 100.0, float(length) # name, bootstrap, length

result = gtdb_tree.parse_tree("((A:0.1,B:0.2):0.3,C:0.4);", custom_parser=custom_parser)
print(result)
```

## Working with Node Objects
## Each Node object in the result has the following attributes:

* id: Unique identifier for the node
* name: Name of the node
* bootstrap: Bootstrap value (if available)
* length: Branch length
* parent: ID of the parent node
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ features = ["python"]

[project]
name = "gtdb_tree"
version = "0.1.8"
version = "0.1.9"
description = "A Python package for parsing GTDB trees using Rust"
readme = "README.md"
authors = [{ name = "dagou", email = "eric9n@gmail.com" }]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

setup(
name="gtdb_tree",
version="0.1.8",
version="0.1.9",
rust_extensions=[RustExtension("gtdb_tree.gtdb_tree", binding=Binding.PyO3)],
packages=["gtdb_tree"],
# rust extensions are not zip safe, just like C-extensions.
Expand Down
2 changes: 2 additions & 0 deletions src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ pub enum ParseError {
UnexpectedEndOfInput,
#[allow(dead_code)]
InvalidFormat(String),
PythonError(String),
}

impl std::fmt::Display for ParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ParseError::UnexpectedEndOfInput => write!(f, "Unexpected end of input"),
ParseError::InvalidFormat(msg) => write!(f, "Invalid format: {}", msg),
ParseError::PythonError(msg) => write!(f, "Python error: {}", msg),
}
}
}
Expand Down
100 changes: 97 additions & 3 deletions src/python.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
use crate::node::Node as RustNode;
use crate::tree;
use crate::node::ParseError;
use crate::tree::{self, NodeParser};
use std::convert::From;
use std::sync::Arc;

// 添加一个从 PyErr 到 ParseError 的转换实现
impl From<PyErr> for ParseError {
fn from(err: PyErr) -> Self {
ParseError::PythonError(err.to_string())
}
}

#[cfg(feature = "python")]
use pyo3::prelude::*;
Expand Down Expand Up @@ -39,10 +49,94 @@ impl Node {
}
}

// #[cfg(feature = "python")]
// #[pyfunction]
// pub fn parse_tree(newick_str: &str) -> PyResult<Vec<Node>> {
// tree::parse_tree(newick_str)
// .map(|rust_nodes| rust_nodes.into_iter().map(|rn| Node { node: rn }).collect())
// .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
// }

#[cfg(feature = "python")]
#[pyfunction]
pub fn parse_tree(newick_str: &str) -> PyResult<Vec<Node>> {
tree::parse_tree(newick_str)
#[pyo3(signature = (newick_str, custom_parser = None))]
#[pyo3(text_signature = "(newick_str, custom_parser=None)")]
/// Parse a Newick format string into a list of Node objects.
///
/// This function takes a Newick format string and optionally a custom parser function,
/// and returns a list of Node objects representing the phylogenetic tree.
///
/// Parameters:
/// -----------
/// newick_str : str
/// The Newick format string representing the phylogenetic tree.
/// custom_parser : callable, optional
/// A custom parsing function for node information. If not provided, the default parser will be used.
/// The custom parser should have the following signature:
///
/// def custom_parser(node_str: str) -> Tuple[str, float, float]:
/// '''
/// Parse a node string and return name, bootstrap, and length.
///
/// Parameters:
/// -----------
/// node_str : str
/// The node string to parse.
///
/// Returns:
/// --------
/// Tuple[str, float, float]
/// A tuple containing (name, bootstrap, length) for the node.
/// '''
/// # Your custom parsing logic here
/// return name, bootstrap, length
///
/// Returns:
/// --------
/// List[Node]
/// A list of Node objects representing the parsed phylogenetic tree.
///
/// Raises:
/// -------
/// ValueError
/// If the Newick string is invalid or parsing fails.
///
/// Example:
/// --------
/// >>> newick_str = "(A:0.1,B:0.2,(C:0.3,D:0.4)70:0.5);"
/// >>> nodes = parse_tree(newick_str)
/// >>>
/// >>> # Using a custom parser
/// >>> def my_parser(node_str):
/// ... parts = node_str.split(':')
/// ... name = parts[0]
/// ... length = float(parts[1]) if len(parts) > 1 else 0.0
/// ... return name, 100.0, length # Always set bootstrap to 100.0
/// >>>
/// >>> nodes_custom = parse_tree(newick_str, custom_parser=my_parser)
pub fn parse_tree(
_py: Python,
newick_str: &str,
custom_parser: Option<PyObject>,
) -> PyResult<Vec<Node>> {
let parser = match custom_parser {
Some(py_func) => {
let py_func = Arc::new(py_func);
NodeParser::Custom(Box::new(
move |node_str: &str| -> Result<(String, f64, f64), ParseError> {
Python::with_gil(|py| {
let result = py_func.call1(py, (node_str,))?;
let (name, bootstrap, length): (String, f64, f64) = result.extract(py)?;
Ok((name, bootstrap, length))
})
.map_err(|e: PyErr| ParseError::PythonError(e.to_string()))
},
))
}
None => NodeParser::Default,
};

tree::parse_tree(newick_str, parser)
.map(|rust_nodes| rust_nodes.into_iter().map(|rn| Node { node: rn }).collect())
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
}
87 changes: 66 additions & 21 deletions src/tree.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
use crate::node::{Node, ParseError};
use memchr::memchr2;

// 修改 NodeParser 枚举以使用 trait 对象
pub enum NodeParser {
Default,
Custom(Box<dyn Fn(&str) -> Result<(String, f64, f64), ParseError> + Send>),
}

impl Default for NodeParser {
fn default() -> Self {
NodeParser::Default
}
}

/// Parse the label of a node from a Newick tree string.
///
/// This function takes a byte slice representing a node in a Newick tree string,
/// and returns the name and length of the node as a tuple.
///
/// # Arguments
///
/// * `label` - A string slice representing the node in a Newick tree string.
fn parse_label(label: &str) -> Result<(String, f64), ParseError> {
let label = label.trim_end_matches(";").trim_matches('\'').to_string();

Expand Down Expand Up @@ -31,27 +51,41 @@ fn parse_label(label: &str) -> Result<(String, f64), ParseError> {
///
/// # Arguments
///
/// * `node_bytes` - A byte slice representing the node in a Newick tree string.
/// * `node_str` - A string slice representing the node in a Newick tree string.
///
/// # Returns
///
/// Returns a `Result` containing a tuple of the name and length on success,
/// Returns a `Result` containing a tuple of the name, bootstrap, and length on success,
/// or an `Err(ParseError)` on failure.
///
/// # Example
///
/// ```
/// use gtdb_tree::tree::parse_node;
/// use gtdb_tree::tree::parse_node_default;
///
/// let node_bytes = b"A:0.1";
/// let (name, bootstrap, length) = parse_node(node_bytes).unwrap();
/// let node_str = "A:0.1";
/// let (name, bootstrap, length) = parse_node_default(node_str).unwrap();
/// assert_eq!(name, "A");
/// assert_eq!(bootstrap, 0.0);
/// assert_eq!(length, 0.1);
/// ```
pub fn parse_node(node_bytes: &[u8]) -> Result<(String, f64, f64), ParseError> {
let node_str = std::str::from_utf8(node_bytes).expect("UTF-8 sequence");
// gtdb
pub fn parse_node_default(node_str: &str) -> Result<(String, f64, f64), ParseError> {
// 处理 "AD:0.03347[21.0]" 格式
if let Some((name_length, bootstrap_str)) = node_str.rsplit_once('[') {
if let Some((name, length_str)) = name_length.rsplit_once(':') {
let bootstrap = bootstrap_str
.trim_end_matches(']')
.parse::<f64>()
.map_err(|_| {
ParseError::InvalidFormat(format!("Invalid bootstrap value: {}", bootstrap_str))
})?;
let length = length_str.parse::<f64>().map_err(|_| {
ParseError::InvalidFormat(format!("Invalid length value: {}", length_str))
})?;
return Ok((name.to_string(), bootstrap, length));
}
}

// Check if node_str contains single quotes and ensure they are together
if node_str.matches('\'').count() % 2 != 0 {
return Err(ParseError::InvalidFormat(format!(
Expand Down Expand Up @@ -102,12 +136,13 @@ pub fn parse_node(node_bytes: &[u8]) -> Result<(String, f64, f64), ParseError> {
///
/// ```
/// use gtdb_tree::tree::parse_tree;
/// use gtdb_tree::tree::NodeParser;
///
/// let newick_str = "((A:0.1,B:0.2):0.3,C:0.4);";
/// let nodes = parse_tree(newick_str).unwrap();
/// let nodes = parse_tree(newick_str, NodeParser::default()).unwrap();
/// assert_eq!(nodes.len(), 5);
/// ```
pub fn parse_tree(newick_str: &str) -> Result<Vec<Node>, ParseError> {
pub fn parse_tree(newick_str: &str, parser: NodeParser) -> Result<Vec<Node>, ParseError> {
let mut nodes: Vec<Node> = Vec::new();
let mut pos = 0;

Expand All @@ -132,7 +167,16 @@ pub fn parse_tree(newick_str: &str) -> Result<Vec<Node>, ParseError> {
let end_pos = memchr2(b',', b')', &bytes[pos..]).unwrap_or(bytes_len - pos);
let node_end_pos = pos + end_pos;
let node_bytes = &bytes[pos..node_end_pos];
let (name, bootstrap, length) = parse_node(node_bytes)?;

let mut node_str = std::str::from_utf8(node_bytes).expect("UTF-8 sequence");
if node_end_pos == bytes_len {
node_str = node_str.trim_end_matches(';');
}
let (name, bootstrap, length) = match &parser {
NodeParser::Default => parse_node_default(node_str)?,
NodeParser::Custom(func) => func(node_str)?,
};

let node_id = if &bytes[pos - 1] == &b')' {
stack.pop().unwrap_or(0)
} else {
Expand Down Expand Up @@ -161,8 +205,9 @@ mod tests {
use super::*;

#[test]
fn test_parse_tree() {
fn test_parse_tree() -> Result<(), ParseError> {
let test_cases = vec![
"(A:0.1,B:0.2,(C:0.3,D:0.4)AD:0.03347[21.0]);",
"((A:0.1,B:0.2)'56:F;H;':0.3,C:0.4);",
"(,,(,));", // no nodes are named
"(A,B,(C,D));", // leaf nodes are named
Expand All @@ -175,15 +220,15 @@ mod tests {
];

for newick_str in test_cases {
match parse_tree(newick_str) {
Ok(nodes) => println!(
"Parsed nodes for '{}': {:?}, len: {}",
newick_str,
nodes,
nodes.len()
),
Err(e) => println!("Error parsing '{}': {:?}", newick_str, e),
}
let nodes = parse_tree(newick_str, NodeParser::default())?;
println!(
"Parsed nodes for '{}': {:?}, len: {}",
newick_str,
nodes,
nodes.len()
)
}

Ok(())
}
}

0 comments on commit a2d15f5

Please sign in to comment.