Skip to content

Commit

Permalink
fix Session init race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
rbran committed Feb 4, 2025
1 parent f6db8b0 commit 80f7b95
Show file tree
Hide file tree
Showing 21 changed files with 86 additions and 87 deletions.
60 changes: 40 additions & 20 deletions rust/src/headless.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ use crate::{
};
use std::io;
use std::path::{Path, PathBuf};
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::SeqCst;
use thiserror::Error;

use crate::enterprise::release_license;
Expand All @@ -35,7 +33,7 @@ use std::time::Duration;
static MAIN_THREAD_HANDLE: Mutex<Option<JoinHandle<()>>> = Mutex::new(None);

/// Used to prevent shutting down Binary Ninja if there are other [`Session`]'s.
static SESSION_COUNT: AtomicUsize = AtomicUsize::new(0);
static SESSION_COUNT: Mutex<usize> = Mutex::new(0);

#[derive(Error, Debug)]
pub enum InitializationError {
Expand All @@ -47,6 +45,8 @@ pub enum InitializationError {
InvalidLicense,
#[error("no license could located, please see `binaryninja::set_license` for details")]
NoLicenseFound,
#[error("unable to apply options to an previously created `binaryninja::headless::Session`")]
SessionAlreadyInitialized,
}

/// Loads plugins, core architecture, platform, etc.
Expand Down Expand Up @@ -284,39 +284,56 @@ pub fn license_location() -> Option<LicenseLocation> {
}

/// Wrapper for [`init`] and [`shutdown`]. Instantiating this at the top of your script will initialize everything correctly and then clean itself up at exit as well.
pub struct Session {}
pub struct Session {
/// lock that don't allow the user to create a session directly
_lock: std::marker::PhantomData<()>,
}

impl Session {
/// Get a registered [`Session`] for use.
///
/// This is required so that we can keep track of the [`SESSION_COUNT`].
fn registered_session() -> Self {
let _previous_count = SESSION_COUNT.fetch_add(1, SeqCst);
Self {}
fn register_session(
options: Option<InitializationOptions>,
) -> Result<Self, InitializationError> {
// if we were able to locate a license, continue with initialization.
if license_location().is_none() {
// otherwise you must call [Self::new_with_license].
return Err(InitializationError::NoLicenseFound);
}

// This is required so that we call init only once
let mut session_count = SESSION_COUNT.lock().unwrap();
match (*session_count, options) {
// no session, just create one
(0, options) => init_with_opts(options.unwrap_or_default())?,
// session already created, can't apply options
(1.., Some(_)) => return Err(InitializationError::SessionAlreadyInitialized),
// NOTE if the existing session was created with options,
// returning the current session may not be exactly what the
// user expects.
(1.., None) => {}
}
*session_count += 1;
Ok(Self {
_lock: std::marker::PhantomData,
})
}

/// Before calling new you must make sure that the license is retrievable, otherwise the core won't be able to initialize.
///
/// If you cannot otherwise provide a license via `BN_LICENSE_FILE` environment variable or the Binary Ninja user directory
/// you can call [`Session::new_with_opts`] instead of this function.
pub fn new() -> Result<Self, InitializationError> {
if license_location().is_some() {
// We were able to locate a license, continue with initialization.
init()?;
Ok(Self::registered_session())
} else {
// There was no license that could be automatically retrieved, you must call [Self::new_with_license].
Err(InitializationError::NoLicenseFound)
}
Self::register_session(None)
}

/// Initialize with options, the same rules apply as [`Session::new`], see [`InitializationOptions::default`] for the regular options passed.
///
/// This differs from [`Session::new`] in that it does not check to see if there is a license that the core
/// can discover by itself, therefor it is expected that you know where your license is when calling this directly.
pub fn new_with_opts(options: InitializationOptions) -> Result<Self, InitializationError> {
init_with_opts(options)?;
Ok(Self::registered_session())
Self::register_session(Some(options))
}

/// ```no_run
Expand Down Expand Up @@ -410,10 +427,13 @@ impl Session {

impl Drop for Session {
fn drop(&mut self) {
let previous_count = SESSION_COUNT.fetch_sub(1, SeqCst);
if previous_count == 1 {
let mut session_count = SESSION_COUNT.lock().unwrap();
match *session_count {
0 => unreachable!(),
// We were the last session, therefor we can safely shut down.
shutdown();
1 => shutdown(),
2.. => {}
}
*session_count -= 1;
}
}
7 changes: 3 additions & 4 deletions rust/tests/background_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@ use binaryninja::headless::Session;
use rstest::*;

#[fixture]
#[once]
fn session() -> Session {
Session::new().expect("Failed to initialize session")
}

#[rstest]
fn test_background_task_registered(_session: &Session) {
fn test_background_task_registered(_session: Session) {
let task_progress = "test registered";
let task = BackgroundTask::new(task_progress, false);
BackgroundTask::running_tasks()
Expand All @@ -25,7 +24,7 @@ fn test_background_task_registered(_session: &Session) {
}

#[rstest]
fn test_background_task_cancellable(_session: &Session) {
fn test_background_task_cancellable(_session: Session) {
let task_progress = "test cancellable";
let task = BackgroundTask::new(task_progress, false);
BackgroundTask::running_tasks()
Expand All @@ -38,7 +37,7 @@ fn test_background_task_cancellable(_session: &Session) {
}

#[rstest]
fn test_background_task_progress(_session: &Session) {
fn test_background_task_progress(_session: Session) {
let task = BackgroundTask::new("test progress", false);
let first_progress = task.progress_text().to_string();
assert_eq!(first_progress, "test progress");
Expand Down
5 changes: 2 additions & 3 deletions rust/tests/binary_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ use std::io::{Read, Seek, SeekFrom};
use std::path::PathBuf;

#[fixture]
#[once]
fn session() -> Session {
Session::new().expect("Failed to initialize session")
}

#[rstest]
fn test_binary_reader_seek(_session: &Session) {
fn test_binary_reader_seek(_session: Session) {
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
let mut reader = BinaryReader::new(&view);
Expand Down Expand Up @@ -50,7 +49,7 @@ fn test_binary_reader_seek(_session: &Session) {
}

#[rstest]
fn test_binary_reader_read(_session: &Session) {
fn test_binary_reader_read(_session: Session) {
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
let mut reader = BinaryReader::new(&view);
Expand Down
7 changes: 3 additions & 4 deletions rust/tests/binary_view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ use rstest::*;
use std::path::PathBuf;

#[fixture]
#[once]
fn session() -> Session {
Session::new().expect("Failed to initialize session")
}

#[rstest]
fn test_binary_loading(_session: &Session) {
fn test_binary_loading(_session: Session) {
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
assert!(view.has_initial_analysis(), "No initial analysis");
Expand All @@ -22,7 +21,7 @@ fn test_binary_loading(_session: &Session) {
}

#[rstest]
fn test_binary_saving(_session: &Session) {
fn test_binary_saving(_session: Session) {
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
// Verify the contents before we modify.
Expand All @@ -45,7 +44,7 @@ fn test_binary_saving(_session: &Session) {
}

#[rstest]
fn test_binary_saving_database(_session: &Session) {
fn test_binary_saving_database(_session: Session) {
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
// Update a symbol to verify modification
Expand Down
5 changes: 2 additions & 3 deletions rust/tests/binary_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@ use std::io::{Read, Seek, SeekFrom, Write};
use std::path::PathBuf;

#[fixture]
#[once]
fn session() -> Session {
Session::new().expect("Failed to initialize session")
}

#[rstest]
fn test_binary_writer_seek(_session: &Session) {
fn test_binary_writer_seek(_session: Session) {
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
let mut writer = BinaryWriter::new(&view);
Expand Down Expand Up @@ -51,7 +50,7 @@ fn test_binary_writer_seek(_session: &Session) {
}

#[rstest]
fn test_binary_writer_write(_session: &Session) {
fn test_binary_writer_write(_session: Session) {
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
let mut reader = BinaryReader::new(&view);
Expand Down
7 changes: 3 additions & 4 deletions rust/tests/collaboration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use serial_test::serial;
use std::path::PathBuf;

#[fixture]
#[once]
fn session() -> Session {
Session::new().expect("Failed to initialize session")
}
Expand Down Expand Up @@ -58,7 +57,7 @@ fn temp_project_scope<T: Fn(&RemoteProject)>(remote: &Remote, project_name: &str

#[rstest]
#[serial]
fn test_connection(_session: &Session) {
fn test_connection(_session: Session) {
if !has_collaboration_support() {
eprintln!("No collaboration support, skipping test...");
return;
Expand All @@ -74,7 +73,7 @@ fn test_connection(_session: &Session) {

#[rstest]
#[serial]
fn test_project_creation(_session: &Session) {
fn test_project_creation(_session: Session) {
if !has_collaboration_support() {
eprintln!("No collaboration support, skipping test...");
return;
Expand Down Expand Up @@ -155,7 +154,7 @@ fn test_project_creation(_session: &Session) {

#[rstest]
#[serial]
fn test_project_sync(_session: &Session) {
fn test_project_sync(_session: Session) {
if !has_collaboration_support() {
eprintln!("No collaboration support, skipping test...");
return;
Expand Down
3 changes: 1 addition & 2 deletions rust/tests/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@ use rstest::*;
use std::path::PathBuf;

#[fixture]
#[once]
fn session() -> Session {
Session::new().expect("Failed to initialize session")
}

#[rstest]
fn test_component_creation(_session: &Session) {
fn test_component_creation(_session: Session) {
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
let component = ComponentBuilder::new(view.clone()).name("test").finalize();
Expand Down
5 changes: 2 additions & 3 deletions rust/tests/demangler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ use binaryninja::types::{QualifiedName, Type};
use rstest::*;

#[fixture]
#[once]
fn session() -> Session {
Session::new().expect("Failed to initialize session")
}

#[rstest]
fn test_demangler_simple(_session: &Session) {
fn test_demangler_simple(_session: Session) {
let placeholder_arch = CoreArchitecture::by_name("x86").expect("x86 exists");
// Example LLVM-style mangled name
let llvm_mangled = "_Z3fooi"; // "foo(int)" in LLVM mangling
Expand Down Expand Up @@ -46,7 +45,7 @@ fn test_demangler_simple(_session: &Session) {
}

#[rstest]
fn test_custom_demangler(_session: &Session) {
fn test_custom_demangler(_session: Session) {
struct TestDemangler;

impl CustomDemangler for TestDemangler {
Expand Down
3 changes: 1 addition & 2 deletions rust/tests/high_level_il.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@ use rstest::*;
use std::path::PathBuf;

#[fixture]
#[once]
fn session() -> Session {
Session::new().expect("Failed to initialize session")
}

#[rstest]
fn test_hlil_info(_session: &Session) {
fn test_hlil_info(_session: Session) {
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");

Expand Down
5 changes: 2 additions & 3 deletions rust/tests/low_level_il.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@ use rstest::*;
use std::path::PathBuf;

#[fixture]
#[once]
fn session() -> Session {
Session::new().expect("Failed to initialize session")
}

#[rstest]
fn test_llil_info(_session: &Session) {
fn test_llil_info(_session: Session) {
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");

Expand Down Expand Up @@ -170,7 +169,7 @@ fn test_llil_info(_session: &Session) {
}

#[rstest]
fn test_llil_visitor(_session: &Session) {
fn test_llil_visitor(_session: Session) {
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
let platform = view.default_platform().unwrap();
Expand Down
5 changes: 2 additions & 3 deletions rust/tests/main_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,18 @@ use rstest::*;
// TODO: Add a test for MainThreadHandler

#[fixture]
#[once]
fn session() -> Session {
Session::new().expect("Failed to initialize session")
}

#[rstest]
fn test_not_main_thread(_session: &Session) {
fn test_not_main_thread(_session: Session) {
// We should never be the main thread.
assert!(!binaryninja::is_main_thread())
}

#[rstest]
fn test_main_thread_different(_session: &Session) {
fn test_main_thread_different(_session: Session) {
let calling_thread = std::thread::current();
binaryninja::main_thread::execute_on_main_thread_and_wait(move || {
let main_thread = std::thread::current();
Expand Down
3 changes: 1 addition & 2 deletions rust/tests/medium_level_il.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@ use rstest::*;
use std::path::PathBuf;

#[fixture]
#[once]
fn session() -> Session {
Session::new().expect("Failed to initialize session")
}

#[rstest]
fn test_mlil_info(_session: &Session) {
fn test_mlil_info(_session: Session) {
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");

Expand Down
Loading

0 comments on commit 80f7b95

Please sign in to comment.