Skip to content

Commit

Permalink
Start compute.
Browse files Browse the repository at this point in the history
  • Loading branch information
tychedelia committed Aug 6, 2024
1 parent cbe0da6 commit 91f7d19
Show file tree
Hide file tree
Showing 5 changed files with 546 additions and 16 deletions.
21 changes: 11 additions & 10 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ tokio = { version = "1", features = ["full"]}
[target.'cfg(target_arch = "wasm32")'.dependencies]
tokio = { version = "1", features = ["rt"]}

[features]
egui = ["nannou/egui"]
isf = ["nannou/isf"]
video = ["nannou/video"]

# Audio
[[example]]
Expand All @@ -63,6 +59,11 @@ path = "communication/osc_receiver.rs"
name = "osc_sender"
path = "communication/osc_sender.rs"

# Compute
[[example]]
name = "game_of_life"
path = "compute/game_of_life.rs"

# Draw
[[example]]
name = "draw"
Expand Down Expand Up @@ -131,7 +132,7 @@ path = "draw/draw_transform.rs"
[[example]]
name = "isf_simple"
path = "isf/simple.rs"
required-features = ["isf", "egui"]
required-features = ["nannou/isf", "nannou/egui"]

# Laser
[[example]]
Expand All @@ -143,7 +144,7 @@ path = "laser/laser_frame_stream.rs"
[[example]]
name = "laser_frame_stream_gui"
path = "laser/laser_frame_stream_gui.rs"
required-features = ["egui"]
required-features = ["nannou/egui"]
[[example]]
name = "laser_raw_stream"
path = "laser/laser_raw_stream.rs"
Expand Down Expand Up @@ -208,19 +209,19 @@ path = "templates/template_sketch.rs"
[[example]]
name = "circle_packing"
path = "ui/egui/circle_packing.rs"
required-features = ["egui"]
required-features = ["nannou/egui"]
[[example]]
name = "tune_color"
path = "ui/egui/tune_color.rs"
required-features = ["egui"]
required-features = ["nannou/egui"]
[[example]]
name = "simple_ui"
path = "ui/egui/simple_ui.rs"
required-features = ["egui"]
required-features = ["nannou/egui"]
[[example]]
name = "inspector_ui"
path = "ui/egui/inspector_ui.rs"
required-features = ["egui"]
required-features = ["nannou/egui"]

# Video
[[example]]
Expand Down
74 changes: 74 additions & 0 deletions examples/assets/shaders/game_of_life.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// The shader reads the previous frame's state from the `input` texture, and writes the new state of
// each pixel to the `output` texture. The textures are flipped each step to progress the
// simulation.
// Two textures are needed for the game of life as each pixel of step N depends on the state of its
// neighbors at step N-1.

struct GameOfLife {
index: u32,
}

@group(0) @binding(0) var input: texture_storage_2d<r32float, read>;
@group(0) @binding(1) var output: texture_storage_2d<r32float, write>;

fn hash(value: u32) -> u32 {
var state = value;
state = state ^ 2747636419u;
state = state * 2654435769u;
state = state ^ state >> 16u;
state = state * 2654435769u;
state = state ^ state >> 16u;
state = state * 2654435769u;
return state;
}

fn randomFloat(value: u32) -> f32 {
return f32(hash(value)) / 4294967295.0;
}

@compute @workgroup_size(8, 8, 1)
fn init(@builtin(global_invocation_id) invocation_id: vec3<u32>, @builtin(num_workgroups) num_workgroups: vec3<u32>) {
let location = vec2<i32>(i32(invocation_id.x), i32(invocation_id.y));

let randomNumber = randomFloat(invocation_id.y << 16u | invocation_id.x);
let alive = randomNumber > 0.9;
let color = vec4<f32>(f32(alive));

textureStore(output, location, color);
}

fn is_alive(location: vec2<i32>, offset_x: i32, offset_y: i32) -> i32 {
let value: vec4<f32> = textureLoad(input, location + vec2<i32>(offset_x, offset_y));
return i32(value.x);
}

fn count_alive(location: vec2<i32>) -> i32 {
return is_alive(location, -1, -1) +
is_alive(location, -1, 0) +
is_alive(location, -1, 1) +
is_alive(location, 0, -1) +
is_alive(location, 0, 1) +
is_alive(location, 1, -1) +
is_alive(location, 1, 0) +
is_alive(location, 1, 1);
}

@compute @workgroup_size(8, 8, 1)
fn update(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
let location = vec2<i32>(i32(invocation_id.x), i32(invocation_id.y));

let n_alive = count_alive(location);

var alive: bool;
if (n_alive == 3) {
alive = true;
} else if (n_alive == 2) {
let currently_alive = is_alive(location, 0, 0);
alive = bool(currently_alive);
} else {
alive = false;
}
let color = vec4<f32>(f32(alive));

textureStore(output, location, color);
}
126 changes: 126 additions & 0 deletions examples/compute/game_of_life.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
use nannou::prelude::*;

const DISPLAY_FACTOR: u32 = 4;
const SIZE: (u32, u32) = (1280 / DISPLAY_FACTOR, 720 / DISPLAY_FACTOR);
const WORKGROUP_SIZE: u32 = 8;

fn main() {
nannou::app(model)
.update(update)
.compute(compute)
.view(view)
.run()
}

struct Model {
texture_a: Handle<Image>,
texture_b: Handle<Image>,
displayed: Handle<Image>,
}

#[derive(Clone, Default)]
enum State {
#[default]
Init,
Update(usize),
}


#[derive(AsBindGroup, Clone)]
struct ComputeModel {
#[storage_texture(0, image_format = R32Float, access = ReadOnly)]
texture_read: Handle<Image>,
#[storage_texture(1, image_format = R32Float, access = WriteOnly)]
texture_write: Handle<Image>,
}

impl ComputeShader for ComputeModel {
type State = State;

fn compute_shader() -> ShaderRef {
ShaderRef::Path("shaders/game_of_life.wgsl".into())
}

fn shader_entry(state: &Self::State) -> String {
match state {
State::Init => "init".into(),
State::Update(_) => "update".into(),
}
}

fn workgroup_size(state: &Self::State) -> (u32, u32, u32) {
(SIZE.0 / WORKGROUP_SIZE, SIZE.1 / WORKGROUP_SIZE, 1)
}
}

fn model(app: &App) -> Model {
let _window = app.new_window::<Model>().size(SIZE.0 * DISPLAY_FACTOR, SIZE.1 * DISPLAY_FACTOR).build();

let mut image = Image::new_fill(
Extent3d {
width: SIZE.0,
height: SIZE.1,
depth_or_array_layers: 1,
},
TextureDimension::D2,
&[0, 0, 0, 255],
TextureFormat::R32Float,
RenderAssetUsages::RENDER_WORLD,
);
image.texture_descriptor.usage =
TextureUsages::COPY_DST | TextureUsages::STORAGE_BINDING | TextureUsages::TEXTURE_BINDING;

info!("Adding image to assets");
let image0 = app.assets_mut().add(image.clone());
let image1 = app.assets_mut().add(image);
info!("Added image to assets");
Model {
texture_a: image0.clone(),
texture_b: image1,
displayed: image0.clone(),
}
}

fn update(app: &App, model: &mut Model) {
if model.displayed == model.texture_a {
model.displayed = model.texture_b.clone();
} else {
model.displayed = model.texture_a.clone();
}
}

fn compute(app: &App, model: &Model, state: &mut State, view: Entity) -> ComputeModel {
if let State::Init = state {
*state = State::Update(0);
return ComputeModel {
texture_read: model.texture_a.clone(),
texture_write: model.texture_b.clone(),
}
}

if model.displayed == model.texture_a {
*state = State::Update(0);
ComputeModel {
texture_read: model.texture_a.clone(),
texture_write: model.texture_b.clone(),
}
} else {
*state = State::Update(1);
ComputeModel {
texture_read: model.texture_b.clone(),
texture_write: model.texture_a.clone(),
}
}
}

fn view(
app: &App,
model: &Model,
view: Entity,
) {
let draw = app.draw();
let window_rect = app.window_rect();
draw.rect()
.w_h(window_rect.w(), window_rect.h())
.texture(&model.displayed);
}
65 changes: 60 additions & 5 deletions nannou/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ use bevy::reflect::{
ApplyError, DynamicTypePath, GetTypeRegistration, ReflectMut, ReflectOwned, ReflectRef,
TypeInfo,
};
use bevy::render::extract_component::ExtractComponent;
use bevy::render::extract_resource::{extract_resource, ExtractResource};
use bevy::render::render_graph::ViewNodeRunner;
use bevy::render::render_resource::{AsBindGroup, BindGroup};
use bevy::window::{ExitCondition, PrimaryWindow, WindowClosed, WindowFocused, WindowResized};
use bevy::winit::{UpdateMode, WinitEvent, WinitSettings};
#[cfg(feature = "egui")]
Expand All @@ -44,23 +46,29 @@ use bevy_inspector_egui::quick::WorldInspectorPlugin;
#[cfg(feature = "egui")]
use bevy_inspector_egui::DefaultInspectorConfigPlugin;
use find_folder;

use bevy_nannou::prelude::render::ExtendedNannouMaterial;
use bevy_nannou::prelude::{draw, DrawHolder};
use wgpu::naga::ShaderStage::Compute;
use bevy_nannou::prelude::render::{ExtendedNannouMaterial, NannouCamera};
use bevy_nannou::prelude::{draw, DrawHolder, ShaderRef};
use bevy_nannou::NannouPlugin;

use crate::frame::{Frame, FramePlugin};
use crate::prelude::bevy_ecs::system::SystemState;
use crate::prelude::bevy_render::extract_component::ExtractComponentPlugin;
use crate::prelude::render::NannouMesh;
use crate::prelude::NannouMaterialPlugin;
use crate::render::{NannouRenderNode, RenderApp, RenderPlugin};
use crate::render::{
ComputeModel, ComputePlugin, ComputeShader, ComputeShaderHandle, ComputeState,
NannouRenderNode, RenderApp, RenderPlugin,
};
use crate::window::WindowUserFunctions;
use crate::{camera, geom, light, window};

/// The user function type for initialising their model.
pub type ModelFn<Model> = fn(&App) -> Model;

/// The user function type for producing the compute model post-update.
pub type ComputeUpdateFn<Model, ComputeModel: ComputeShader> =
fn(&App, &Model, &mut ComputeModel::State, Entity) -> ComputeModel;

/// The user function type for updating their model in accordance with some event.
pub type EventFn<Model, Event> = fn(&App, &mut Model, &Event);

Expand Down Expand Up @@ -154,6 +162,9 @@ struct EventFnRes<M, E>(Option<EventFn<M, E>>);
#[derive(Resource, Deref, DerefMut)]
struct UpdateFnRes<M>(Option<UpdateFn<M>>);

#[derive(Resource, Deref, DerefMut)]
struct ComputeUpdateFnRes<M, CM: ComputeShader>(ComputeUpdateFn<M, CM>);

#[derive(Resource, Deref, DerefMut)]
pub(crate) struct RenderFnRes<M>(Option<RenderFn<M>>);

Expand Down Expand Up @@ -354,6 +365,27 @@ where
self
}

pub fn compute<CM: ComputeShader>(mut self, compute_fn: ComputeUpdateFn<M, CM>) -> Self {
let render_app = self.app.sub_app_mut(bevy::render::RenderApp);
render_app.insert_resource(ComputeShaderHandle(CM::compute_shader()));
self.app
.add_systems(
First,
|mut commands: Commands, views_q: Query<Entity, Added<NannouCamera>>| {
for view in views_q.iter() {
info!("Adding compute state to view {:?}", view);
commands
.entity(view)
.insert(ComputeState(CM::State::default()));
}
},
)
.insert_resource(ComputeUpdateFnRes(compute_fn))
.add_systems(Update, compute::<M, CM>.after(update::<M>))
.add_plugins(ComputePlugin::<CM>::default());
self
}

#[cfg(any(feature = "config_json", feature = "config_toml"))]
pub fn init_config<T>(mut self) -> Self
where
Expand Down Expand Up @@ -1097,6 +1129,29 @@ fn update<M>(
*ticks += 1;
}

fn compute<M, CM>(
world: &mut World,
state: &mut SystemState<(
Commands,
ResMut<ModelHolder<M>>,
Res<ComputeUpdateFnRes<M, CM>>,
Query<(Entity, &mut ComputeState<CM::State>)>,
)>,
)
where
M: 'static + Send + Sync,
CM: ComputeShader
{
let (mut app, (mut commands, mut model, compute, mut views_q)) = get_app_and_state(world, state);
let compute = compute.0;
for (view, mut state) in views_q.iter_mut() {
let compute_model = compute(&app, &model, &mut state.0, view);
info!("Updating compute model for view {:?}", view);
let id = commands.spawn(ComputeModel(compute_model)).id();
info!("Spawned compute model with id {:?}", id);
}
}

fn events<M, E>(
world: &mut World,
state: &mut SystemState<(
Expand Down
Loading

0 comments on commit 91f7d19

Please sign in to comment.