Skip to content

Commit

Permalink
feat: add OpenCL device compatibility check + fix typo (#169)
Browse files Browse the repository at this point in the history
* add compatibility check against selected OpenCL device

* fix typo in function names
  • Loading branch information
glemercier authored and hashmap committed Jan 31, 2019
1 parent 8f73783 commit 2c5942b
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 48 deletions.
166 changes: 120 additions & 46 deletions ocl_cuckaroo/src/trimmer.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use ocl;
use ocl::enums::{ArgVal, ProfilingInfo};
use ocl::flags::CommandQueueProperties;
use ocl::enums::{ArgVal, DeviceInfo, DeviceInfoResult, ProfilingInfo};
use ocl::flags::{CommandQueueProperties, MemFlags};
use ocl::prm::{Uint2, Ulong4};
use ocl::{
Buffer, Context, Device, Event, EventList, Kernel, Platform, Program, Queue, SpatialDims,
};
use std::collections::HashMap;
use std::env;

const DUCK_SIZE_A: usize = 129; // AMD 126 + 3
Expand All @@ -29,6 +30,11 @@ pub struct Trimmer {
is_nvidia: bool,
}

struct ClBufferParams {
size: usize,
flags: MemFlags,
}

macro_rules! clear_buffer (
($buf:expr) => (
$buf.cmd().fill(0, None).enq()?;
Expand All @@ -47,6 +53,15 @@ macro_rules! kernel_enq(
}
));

macro_rules! get_device_info {
($dev:ident, $name:ident) => {{
match $dev.info(DeviceInfo::$name) {
Ok(DeviceInfoResult::$name(value)) => value,
_ => panic!("Failed to retrieve device {}", stringify!($name)),
}
}};
}

#[cfg(feature = "profile")]
fn queue_props() -> Option<CommandQueueProperties> {
Some(CommandQueueProperties::PROFILING_ENABLE)
Expand Down Expand Up @@ -74,10 +89,62 @@ impl Trimmer {
env::set_var("GPU_SINGLE_ALLOC_PERCENT", "100");
env::set_var("GPU_64BIT_ATOMICS", "1");
env::set_var("GPU_MAX_WORKGROUP_SIZE", "1024");
let platform = find_paltform(platform_name)
let platform = find_platform(platform_name)
.ok_or::<ocl::Error>("Can't find OpenCL platform".into())?;
let p_name = platform.name()?;
let device = find_device(&platform, device_id)?;
let mut buffers = HashMap::new();
buffers.insert(
"A1".to_string(),
ClBufferParams {
size: BUFFER_SIZE_A1,
flags: MemFlags::empty(),
},
);
buffers.insert(
"A2".to_string(),
ClBufferParams {
size: BUFFER_SIZE_A2,
flags: MemFlags::empty(),
},
);
buffers.insert(
"B".to_string(),
ClBufferParams {
size: BUFFER_SIZE_B,
flags: MemFlags::empty(),
},
);
buffers.insert(
"I1".to_string(),
ClBufferParams {
size: INDEX_SIZE,
flags: MemFlags::empty(),
},
);
buffers.insert(
"I2".to_string(),
ClBufferParams {
size: INDEX_SIZE,
flags: MemFlags::empty(),
},
);
buffers.insert(
"R".to_string(),
ClBufferParams {
size: 42 * 2,
flags: MemFlags::READ_ONLY,
},
);
buffers.insert(
"NONCES".to_string(),
ClBufferParams {
size: INDEX_SIZE,
flags: MemFlags::empty(),
},
);

check_device_compatibility(&device, &buffers)?;

let context = Context::builder()
.platform(platform)
Expand All @@ -91,48 +158,13 @@ impl Trimmer {
.src(SRC)
.build(&context)?;

let buffer_a1 = Buffer::<u32>::builder()
.queue(q.clone())
.len(BUFFER_SIZE_A1)
.fill_val(0)
.build()?;

let buffer_a2 = Buffer::<u32>::builder()
.queue(q.clone())
.len(BUFFER_SIZE_A2)
.fill_val(0)
.build()?;

let buffer_b = Buffer::<u32>::builder()
.queue(q.clone())
.len(BUFFER_SIZE_B)
.fill_val(0)
.build()?;

let buffer_i1 = Buffer::<u32>::builder()
.queue(q.clone())
.len(INDEX_SIZE)
.fill_val(0)
.build()?;

let buffer_i2 = Buffer::<u32>::builder()
.queue(q.clone())
.len(INDEX_SIZE)
.fill_val(0)
.build()?;

let buffer_r = Buffer::<u32>::builder()
.queue(q.clone())
.len(42 * 2)
.flags(ocl::flags::MemFlags::READ_ONLY)
.fill_val(0)
.build()?;

let buffer_nonces = Buffer::<u32>::builder()
.queue(q.clone())
.len(INDEX_SIZE)
.fill_val(0)
.build()?;
let buffer_a1 = build_buffer(buffers.get("A1"), &q)?;
let buffer_a2 = build_buffer(buffers.get("A2"), &q)?;
let buffer_b = build_buffer(buffers.get("B"), &q)?;
let buffer_i1 = build_buffer(buffers.get("I1"), &q)?;
let buffer_i2 = build_buffer(buffers.get("I2"), &q)?;
let buffer_r = build_buffer(buffers.get("R"), &q)?;
let buffer_nonces = build_buffer(buffers.get("NONCES"), &q)?;

Ok(Trimmer {
q,
Expand Down Expand Up @@ -397,7 +429,7 @@ fn print_event(name: &str, ev: &Event) {
#[cfg(not(feature = "profile"))]
fn print_event(_name: &str, _ev: &Event) {}

fn find_paltform(selector: Option<&str>) -> Option<Platform> {
fn find_platform(selector: Option<&str>) -> Option<Platform> {
match selector {
None => Some(Platform::default()),
Some(sel) => Platform::list().into_iter().find(|p| {
Expand All @@ -417,6 +449,48 @@ fn find_device(platform: &Platform, selector: Option<usize>) -> ocl::Result<Devi
}
}

fn check_device_compatibility(
device: &Device,
buffers: &HashMap<String, ClBufferParams>,
) -> ocl::Result<()> {
let max_alloc_size: u64 = get_device_info!(device, MaxMemAllocSize);
let global_memory_size: u64 = get_device_info!(device, GlobalMemSize);
let mut total_alloc: u64 = 0;

// Check that no buffer is bigger than the max memory allocation size
for (k, v) in buffers {
total_alloc += v.size as u64;
if v.size as u64 > max_alloc_size {
return Err(ocl::Error::from(format!(
"Buffer {} is bigger than maximum alloc size ({})",
k, max_alloc_size
)));
}
}

// Check that total buffer allocation does not exceed global memory size
if total_alloc > global_memory_size {
return Err(ocl::Error::from(format!(
"Total needed memory is bigger than device's capacity ({})",
global_memory_size
)));
}

Ok(())
}

fn build_buffer(params: Option<&ClBufferParams>, q: &Queue) -> ocl::Result<Buffer<u32>> {
match params {
None => Err(ocl::Error::from("Invalid parameters")),
Some(p) => Buffer::<u32>::builder()
.queue(q.clone())
.len(p.size)
.flags(p.flags)
.fill_val(0)
.build(),
}
}

const SRC: &str = r#"
// Cuckaroo Cycle, a memory-hard proof-of-work by John Tromp and team Grin
// Copyright (c) 2018 Jiri Photon Vadura and John Tromp
Expand Down
4 changes: 2 additions & 2 deletions ocl_cuckatoo/src/trimmer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl Trimmer {
device_id: Option<usize>,
edge_bits: u8,
) -> ocl::Result<Trimmer> {
let platform = find_paltform(platform_name)
let platform = find_platform(platform_name)
.ok_or::<ocl::Error>("Can't find OpenCL platform".into())?;
let device = find_device(&platform, device_id)?;

Expand Down Expand Up @@ -147,7 +147,7 @@ impl Trimmer {
}
}

fn find_paltform(selector: Option<&str>) -> Option<Platform> {
fn find_platform(selector: Option<&str>) -> Option<Platform> {
match selector {
None => Some(Platform::default()),
Some(sel) => Platform::list().into_iter().find(|p| {
Expand Down

0 comments on commit 2c5942b

Please sign in to comment.