Skip to content

Commit

Permalink
Merge pull request #22 from theseus-rs/update-class-loader-interface
Browse files Browse the repository at this point in the history
refactor!: change class loader function names
  • Loading branch information
brianheineman authored Aug 9, 2024
2 parents 5701571 + b243d9b commit 419b8ed
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 39 deletions.
31 changes: 17 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ indoc = "2.0.5"
rayon = "1.10.0"
reqwest = "0.12.5"
tar = "0.4.41"
tempfile = "3.11.0"
thiserror = "1.0.63"
tokio = { version = "1.39.2", default-features = false, features = ["macros", "rt", "sync"] }
tracing = "0.1.40"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Crates for the [JVM Specification](https://docs.oracle.com/javase/specs/jvms/se2

Crates for the [JVM Specification](https://docs.oracle.com/javase/specs/jvms/se22/html/)

Supports reading, writing, verifying and loading classed for any version of Java version up to 24. Verification of
Supports reading, writing, verifying and loading classes for any version of Java version up to 24. Verification of
class files is supported, but is still a work in progress.

# Examples
Expand Down
1 change: 1 addition & 0 deletions ristretto_classloader/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ zip = { workspace = true }

[dev-dependencies]
indoc = { workspace = true }
tempfile = { workspace = true }
tokio = { workspace = true, features = ["rt-multi-thread"] }

[features]
Expand Down
4 changes: 2 additions & 2 deletions ristretto_classloader/src/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl Debug for Class {
/// Formats the class for debugging.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Class")
.field("class_loader", &self.class_loader.get_name())
.field("class_loader", &self.class_loader.name())
.field("class_file", &self.class_file)
.finish()
}
Expand All @@ -57,7 +57,7 @@ mod tests {
let mut cursor = Cursor::new(bytes);
let class_file = ClassFile::from_bytes(&mut cursor)?;
let class = Class::new(Arc::new(class_loader), Arc::new(class_file));
assert_eq!("bootstrap", class.get_class_loader().get_name());
assert_eq!("bootstrap", class.get_class_loader().name());
assert_eq!("Simple", class.get_class_file().class_name()?);
Ok(())
}
Expand Down
35 changes: 19 additions & 16 deletions ristretto_classloader/src/class_loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,18 @@ impl ClassLoader {

/// Get the name of the class loader.
#[must_use]
pub fn get_name(&self) -> &str {
pub fn name(&self) -> &str {
&self.name
}

/// Get the class path.
#[must_use]
pub fn get_class_path(&self) -> &ClassPath {
pub fn class_path(&self) -> &ClassPath {
&self.class_path
}

/// Get the parent class loader.
pub fn get_parent(&self) -> Option<Arc<ClassLoader>> {
pub fn parent(&self) -> Option<Arc<ClassLoader>> {
self.parent.as_ref().map(Arc::clone)
}

Expand All @@ -60,7 +60,7 @@ impl ClassLoader {
// Convert hierarchy of class loaders to a flat list.
let mut class_loader = Arc::clone(loader);
let mut class_loaders = vec![Arc::clone(&class_loader)];
while let Some(parent) = class_loader.get_parent() {
while let Some(parent) = class_loader.parent() {
class_loader = parent;
class_loaders.push(Arc::clone(&class_loader));
}
Expand Down Expand Up @@ -92,7 +92,7 @@ impl Default for ClassLoader {
impl PartialEq for ClassLoader {
/// Compare class loaders by name.
fn eq(&self, other: &Self) -> bool {
self.name == other.name && self.get_parent() == other.get_parent()
self.name == other.name && self.parent() == other.parent()
}
}

Expand All @@ -105,16 +105,16 @@ mod tests {
fn test_new() {
let name = "test";
let class_loader = ClassLoader::new(name, ClassPath::default());
assert_eq!(name, class_loader.get_name());
assert_eq!(&ClassPath::default(), class_loader.get_class_path());
assert!(class_loader.get_parent().is_none());
assert_eq!(name, class_loader.name());
assert_eq!(&ClassPath::default(), class_loader.class_path());
assert!(class_loader.parent().is_none());
}

#[test]
fn test_default() {
let class_loader = ClassLoader::default();
assert_eq!("bootstrap", class_loader.get_name());
assert!(class_loader.get_parent().is_none());
assert_eq!("bootstrap", class_loader.name());
assert!(class_loader.parent().is_none());
}

#[test]
Expand All @@ -136,10 +136,7 @@ mod tests {
let mut class_loader1 = ClassLoader::new("test1", ClassPath::default());
let class_loader2 = ClassLoader::new("test2", ClassPath::default());
class_loader1.set_parent(Some(Arc::new(class_loader2)));
assert_eq!(
"test2",
class_loader1.get_parent().expect("parent").get_name()
);
assert_eq!("test2", class_loader1.parent().expect("parent").name());
}

#[tokio::test]
Expand All @@ -150,9 +147,15 @@ mod tests {

let class_path = class_path_entries.join(":");
let class_loader = Arc::new(ClassLoader::new("test", ClassPath::from(&class_path)));
let class = ClassLoader::load_class(&class_loader, "HelloWorld").await?;
let class_name = "HelloWorld";
let class = ClassLoader::load_class(&class_loader, class_name).await?;
let class_file = class.get_class_file();
assert_eq!("HelloWorld", class_file.class_name()?);
assert_eq!(class_name, class_file.class_name()?);

// Load the same class again to test caching
let class = ClassLoader::load_class(&class_loader, class_name).await?;
let class_file = class.get_class_file();
assert_eq!(class_name, class_file.class_name()?);
Ok(())
}

Expand Down
55 changes: 49 additions & 6 deletions ristretto_classloader/src/class_path_entry/jar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,19 @@ impl Jar {
class_files: &DashMap<String, Arc<ClassFile>>,
) -> Result<()> {
let reader = io::Cursor::new(bytes);
let mut archive =
ZipArchive::new(reader).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let mut archive = ZipArchive::new(reader)?;

// Decompress all the bytes from the jar and store in a map to be converted into class files
let mut class_bytes = HashMap::new();
for i in 0..archive.len() {
let mut file = archive
.by_index(i)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let mut file = archive.by_index(i)?;
let file_name = file.name().to_string();
if !file_name.ends_with(".class") {
continue;
}

let mut bytes = Vec::new();
io::copy(&mut file, &mut bytes).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
io::copy(&mut file, &mut bytes)?;
let class_name = file_name.replace('/', ".").replace(".class", "");
class_bytes.insert(class_name, bytes);
}
Expand Down Expand Up @@ -141,6 +138,8 @@ impl PartialEq for Jar {
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use zip::write::SimpleFileOptions;

#[test]
fn test_new() {
Expand Down Expand Up @@ -196,4 +195,48 @@ mod tests {
assert!(matches!(result, Err(ClassNotFound(_))));
Ok(())
}

#[tokio::test]
async fn test_bad_class_file() -> Result<()> {
let temp_dir = tempfile::tempdir()?;

// Create a jar with a bad class file
let jar_path = temp_dir.path().join("invalid.jar");
let mut archive = zip::ZipWriter::new(fs::File::create(&jar_path)?);
archive.start_file("HelloWorld.class", SimpleFileOptions::default())?;
archive.write_all(&[0x00, 0x01, 0x02])?;
archive.finish()?;

// Test reading the class file
let jar = Jar::new(jar_path.to_string_lossy());
let result = jar.read_class("HelloWorld").await;
assert!(matches!(result, Err(ClassNotFound(_))));
Ok(())
}

#[tokio::test]
async fn test_invalid_class_file() -> Result<()> {
let temp_dir = tempfile::tempdir()?;

// Create an invalid class file
let class_file = ClassFile {
this_class: 42,
..Default::default()
};
let mut bytes = Vec::new();
class_file.to_bytes(&mut bytes)?;

// Create a jar with an invalid class file
let jar_path = temp_dir.path().join("invalid.jar");
let mut archive = zip::ZipWriter::new(fs::File::create(&jar_path)?);
archive.start_file("HelloWorld.class", SimpleFileOptions::default())?;
archive.write_all(bytes.as_slice())?;
archive.finish()?;

// Test reading the class file
let jar = Jar::new(jar_path.to_string_lossy());
let result = jar.read_class("HelloWorld").await;
assert!(matches!(result, Err(ClassNotFound(_))));
Ok(())
}
}
3 changes: 3 additions & 0 deletions ristretto_classloader/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,7 @@ pub enum Error {
#[cfg(feature = "url")]
#[error(transparent)]
RequestError(#[from] reqwest::Error),
/// An error while reading a jar or module
#[error(transparent)]
ZipError(#[from] zip::result::ZipError),
}

0 comments on commit 419b8ed

Please sign in to comment.