Skip to content

Commit

Permalink
Test and handle all tensor dtypes as images (#1840)
Browse files Browse the repository at this point in the history
* Support i64 and u64 tensors

* api_demo: log all image types

* Don't even print out the contents of a tensor

* Handle unfilterable float textures

* fix typo

* py-format

* Simplify is_float_filterable

* Add a helper function pad_and_narrow_and_cast

* Reuse existing image

* Exclude image_tensors demo from default api_demo

* Still run all api demos in e2e test

* pyformat
  • Loading branch information
emilk authored Apr 13, 2023
1 parent 437fe2b commit 1211ef4
Show file tree
Hide file tree
Showing 9 changed files with 285 additions and 96 deletions.
20 changes: 19 additions & 1 deletion crates/re_log_types/src/component_types/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl ArrowDeserialize for TensorId {
/// ),
/// );
/// ```
#[derive(Clone, Debug, PartialEq, ArrowField, ArrowSerialize, ArrowDeserialize)]
#[derive(Clone, PartialEq, ArrowField, ArrowSerialize, ArrowDeserialize)]
#[arrow_field(type = "dense")]
pub enum TensorData {
U8(BinaryBuffer),
Expand Down Expand Up @@ -198,6 +198,24 @@ impl TensorData {
}
}

impl std::fmt::Debug for TensorData {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::U8(_) => write!(f, "U8({} bytes)", self.size_in_bytes()),
Self::U16(_) => write!(f, "U16({} bytes)", self.size_in_bytes()),
Self::U32(_) => write!(f, "U32({} bytes)", self.size_in_bytes()),
Self::U64(_) => write!(f, "U64({} bytes)", self.size_in_bytes()),
Self::I8(_) => write!(f, "I8({} bytes)", self.size_in_bytes()),
Self::I16(_) => write!(f, "I16({} bytes)", self.size_in_bytes()),
Self::I32(_) => write!(f, "I32({} bytes)", self.size_in_bytes()),
Self::I64(_) => write!(f, "I64({} bytes)", self.size_in_bytes()),
Self::F32(_) => write!(f, "F32({} bytes)", self.size_in_bytes()),
Self::F64(_) => write!(f, "F64({} bytes)", self.size_in_bytes()),
Self::JPEG(_) => write!(f, "JPEG({} bytes)", self.size_in_bytes()),
}
}
}

/// Flattened `Tensor` data payload
///
/// ## Examples
Expand Down
22 changes: 15 additions & 7 deletions crates/re_renderer/shader/rectangle.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
// Keep in sync with mirror in rectangle.rs

// Which texture to read from?
const SAMPLE_TYPE_FLOAT = 1u;
const SAMPLE_TYPE_SINT = 2u;
const SAMPLE_TYPE_UINT = 3u;
const SAMPLE_TYPE_FLOAT_FILTER = 1u;
const SAMPLE_TYPE_FLOAT_NOFILTER = 2u;
const SAMPLE_TYPE_SINT_NOFILTER = 3u;
const SAMPLE_TYPE_UINT_NOFILTER = 4u;

// How do we do colormapping?
const COLOR_MAPPER_OFF = 1u;
Expand Down Expand Up @@ -67,6 +68,9 @@ var texture_uint: texture_2d<u32>;
@group(1) @binding(5)
var colormap_texture: texture_2d<f32>;

@group(1) @binding(6)
var texture_float_filterable: texture_2d<f32>;


struct VertexOut {
@builtin(position) position: Vec4,
Expand All @@ -90,12 +94,16 @@ fn vs_main(@builtin(vertex_index) v_idx: u32) -> VertexOut {
fn fs_main(in: VertexOut) -> @location(0) Vec4 {
// Sample the main texture:
var sampled_value: Vec4;
if rect_info.sample_type == SAMPLE_TYPE_FLOAT {
sampled_value = textureSampleLevel(texture_float, texture_sampler, in.texcoord, 0.0); // TODO(emilk): support mipmaps
} else if rect_info.sample_type == SAMPLE_TYPE_SINT {
if rect_info.sample_type == SAMPLE_TYPE_FLOAT_FILTER {
// TODO(emilk): support mipmaps
sampled_value = textureSampleLevel(texture_float_filterable, texture_sampler, in.texcoord, 0.0);
} else if rect_info.sample_type == SAMPLE_TYPE_FLOAT_NOFILTER {
let icoords = IVec2(in.texcoord * Vec2(textureDimensions(texture_float).xy));
sampled_value = Vec4(textureLoad(texture_float, icoords, 0));
} else if rect_info.sample_type == SAMPLE_TYPE_SINT_NOFILTER {
let icoords = IVec2(in.texcoord * Vec2(textureDimensions(texture_sint).xy));
sampled_value = Vec4(textureLoad(texture_sint, icoords, 0));
} else if rect_info.sample_type == SAMPLE_TYPE_UINT {
} else if rect_info.sample_type == SAMPLE_TYPE_UINT_NOFILTER {
let icoords = IVec2(in.texcoord * Vec2(textureDimensions(texture_uint).xy));
sampled_value = Vec4(textureLoad(texture_uint, icoords, 0));
} else {
Expand Down
56 changes: 44 additions & 12 deletions crates/re_renderer/src/renderer/rectangles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,10 @@ mod gpu_data {
// Keep in sync with mirror in rectangle.wgsl

// Which texture to read from?
const SAMPLE_TYPE_FLOAT: u32 = 1;
const SAMPLE_TYPE_SINT: u32 = 2;
const SAMPLE_TYPE_UINT: u32 = 3;
const SAMPLE_TYPE_FLOAT_FILTER: u32 = 1;
const SAMPLE_TYPE_FLOAT_NOFILTER: u32 = 2;
const SAMPLE_TYPE_SINT_NOFILTER: u32 = 3;
const SAMPLE_TYPE_UINT_NOFILTER: u32 = 4;

// How do we do colormapping?
const COLOR_MAPPER_OFF: u32 = 1;
Expand Down Expand Up @@ -232,12 +233,18 @@ mod gpu_data {
} = &rectangle.colormapped_texture;

let sample_type = match texture_info.sample_type {
wgpu::TextureSampleType::Float { .. } => SAMPLE_TYPE_FLOAT,
wgpu::TextureSampleType::Float { .. } => {
if super::is_float_filterable(texture_format) {
SAMPLE_TYPE_FLOAT_FILTER
} else {
SAMPLE_TYPE_FLOAT_NOFILTER
}
}
wgpu::TextureSampleType::Depth => {
return Err(RectangleError::DepthTexturesNotSupported);
}
wgpu::TextureSampleType::Sint => SAMPLE_TYPE_SINT,
wgpu::TextureSampleType::Uint => SAMPLE_TYPE_UINT,
wgpu::TextureSampleType::Sint => SAMPLE_TYPE_SINT_NOFILTER,
wgpu::TextureSampleType::Uint => SAMPLE_TYPE_UINT_NOFILTER,
};

let mut colormap_function = 0;
Expand Down Expand Up @@ -378,14 +385,19 @@ impl RectangleDrawData {
));
}

// We set up three texture sources, then instruct the shader to read from at most one of them.
let mut texture_float = ctx.texture_manager_2d.zeroed_texture_float().handle;
// We set up several texture sources, then instruct the shader to read from at most one of them.
let mut texture_float_filterable = ctx.texture_manager_2d.zeroed_texture_float().handle;
let mut texture_float_nofilter = ctx.texture_manager_2d.zeroed_texture_float().handle;
let mut texture_sint = ctx.texture_manager_2d.zeroed_texture_sint().handle;
let mut texture_uint = ctx.texture_manager_2d.zeroed_texture_uint().handle;

match texture_description.sample_type {
wgpu::TextureSampleType::Float { .. } => {
texture_float = texture.handle;
if is_float_filterable(&texture_format) {
texture_float_filterable = texture.handle;
} else {
texture_float_nofilter = texture.handle;
}
}
wgpu::TextureSampleType::Depth => {
return Err(RectangleError::DepthTexturesNotSupported);
Expand Down Expand Up @@ -422,10 +434,11 @@ impl RectangleDrawData {
entries: smallvec![
uniform_buffer,
BindGroupEntry::Sampler(sampler),
BindGroupEntry::DefaultTextureView(texture_float),
BindGroupEntry::DefaultTextureView(texture_float_nofilter),
BindGroupEntry::DefaultTextureView(texture_sint),
BindGroupEntry::DefaultTextureView(texture_uint),
BindGroupEntry::DefaultTextureView(colormap_texture),
BindGroupEntry::DefaultTextureView(texture_float_filterable),
],
layout: rectangle_renderer.bind_group_layout,
},
Expand Down Expand Up @@ -483,12 +496,12 @@ impl Renderer for RectangleRenderer {
ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
count: None,
},
// float texture:
// float textures without filtering (e.g. R32Float):
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::FRAGMENT,
ty: wgpu::BindingType::Texture {
sample_type: wgpu::TextureSampleType::Float { filterable: true },
sample_type: wgpu::TextureSampleType::Float { filterable: false },
view_dimension: wgpu::TextureViewDimension::D2,
multisampled: false,
},
Expand Down Expand Up @@ -527,6 +540,17 @@ impl Renderer for RectangleRenderer {
},
count: None,
},
// float textures with filtering (e.g. Rgba8UnormSrgb):
wgpu::BindGroupLayoutEntry {
binding: 6,
visibility: wgpu::ShaderStages::FRAGMENT,
ty: wgpu::BindingType::Texture {
sample_type: wgpu::TextureSampleType::Float { filterable: true },
view_dimension: wgpu::TextureViewDimension::D2,
multisampled: false,
},
count: None,
},
],
},
);
Expand Down Expand Up @@ -653,3 +677,11 @@ impl Renderer for RectangleRenderer {
]
}
}

fn is_float_filterable(format: &wgpu::TextureFormat) -> bool {
format
.describe()
.guaranteed_format_features
.flags
.contains(wgpu::TextureFormatFeatureFlags::FILTERABLE)
}
8 changes: 6 additions & 2 deletions crates/re_viewer/src/misc/caches/tensor_image_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,12 @@ fn color_tensor_as_color_image(tensor: &Tensor) -> anyhow::Result<ColorImage> {
Ok(ColorImage { size, pixels })
}

(_depth, dtype) => {
anyhow::bail!("Don't know how to turn a tensor of shape={:?} and dtype={dtype:?} into a color image", tensor.shape)
(_depth, _dtype) => {
anyhow::bail!(
"Don't know how to turn a tensor of shape={:?} and dtype={:?} into a color image",
tensor.shape,
tensor.dtype()
)
}
}
}
Expand Down
Loading

1 comment on commit 1211ef4

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rust Benchmark

Benchmark suite Current: 1211ef4 Previous: 437fe2b Ratio
datastore/num_rows=1000/num_instances=1000/packed=false/insert/default 2983729 ns/iter (± 236803) 2856891 ns/iter (± 36142) 1.04
datastore/num_rows=1000/num_instances=1000/packed=false/latest_at/default 373 ns/iter (± 2) 371 ns/iter (± 1) 1.01
datastore/num_rows=1000/num_instances=1000/packed=false/latest_at_missing/primary/default 260 ns/iter (± 2) 262 ns/iter (± 0) 0.99
datastore/num_rows=1000/num_instances=1000/packed=false/latest_at_missing/secondaries/default 421 ns/iter (± 5) 422 ns/iter (± 0) 1.00
datastore/num_rows=1000/num_instances=1000/packed=false/range/default 2906501 ns/iter (± 190324) 3004714 ns/iter (± 39140) 0.97
datastore/num_rows=1000/num_instances=1000/gc/default 2424905 ns/iter (± 30247) 2373232 ns/iter (± 2831) 1.02
mono_points_arrow/generate_message_bundles 25537096 ns/iter (± 1841190) 27253563 ns/iter (± 691392) 0.94
mono_points_arrow/generate_messages 112833858 ns/iter (± 2059820) 113629539 ns/iter (± 886175) 0.99
mono_points_arrow/encode_log_msg 144913631 ns/iter (± 2403506) 144366861 ns/iter (± 785098) 1.00
mono_points_arrow/encode_total 285462508 ns/iter (± 2631675) 282627296 ns/iter (± 1104580) 1.01
mono_points_arrow/decode_log_msg 177851559 ns/iter (± 1902386) 177347684 ns/iter (± 665930) 1.00
mono_points_arrow/decode_message_bundles 57587572 ns/iter (± 2094906) 58626705 ns/iter (± 923501) 0.98
mono_points_arrow/decode_total 234335805 ns/iter (± 2962071) 234974619 ns/iter (± 1075048) 1.00
mono_points_arrow_batched/generate_message_bundles 20310483 ns/iter (± 1874770) 19349150 ns/iter (± 634073) 1.05
mono_points_arrow_batched/generate_messages 4107772 ns/iter (± 394221) 4025677 ns/iter (± 60601) 1.02
mono_points_arrow_batched/encode_log_msg 1374643 ns/iter (± 12370) 1390614 ns/iter (± 5633) 0.99
mono_points_arrow_batched/encode_total 27449206 ns/iter (± 2329110) 26255579 ns/iter (± 930136) 1.05
mono_points_arrow_batched/decode_log_msg 779914 ns/iter (± 7853) 775674 ns/iter (± 1299) 1.01
mono_points_arrow_batched/decode_message_bundles 7658651 ns/iter (± 443227) 7629545 ns/iter (± 96788) 1.00
mono_points_arrow_batched/decode_total 8709255 ns/iter (± 732431) 8521246 ns/iter (± 166257) 1.02
batch_points_arrow/generate_message_bundles 194632 ns/iter (± 1078) 238921 ns/iter (± 477) 0.81
batch_points_arrow/generate_messages 5102 ns/iter (± 56) 5076 ns/iter (± 11) 1.01
batch_points_arrow/encode_log_msg 260024 ns/iter (± 3017) 259069 ns/iter (± 829) 1.00
batch_points_arrow/encode_total 488796 ns/iter (± 8339) 532449 ns/iter (± 1772) 0.92
batch_points_arrow/decode_log_msg 213170 ns/iter (± 2039) 210106 ns/iter (± 440) 1.01
batch_points_arrow/decode_message_bundles 1924 ns/iter (± 19) 1853 ns/iter (± 5) 1.04
batch_points_arrow/decode_total 223283 ns/iter (± 2848) 218044 ns/iter (± 1698) 1.02
arrow_mono_points/insert 2336885135 ns/iter (± 4319013) 2283918228 ns/iter (± 5900823) 1.02
arrow_mono_points/query 1189206 ns/iter (± 27197) 1182176 ns/iter (± 10820) 1.01
arrow_batch_points/insert 1155818 ns/iter (± 10281) 1154299 ns/iter (± 1732) 1.00
arrow_batch_points/query 14331 ns/iter (± 84) 14749 ns/iter (± 124) 0.97
arrow_batch_vecs/insert 26414 ns/iter (± 363) 26396 ns/iter (± 46) 1.00
arrow_batch_vecs/query 325474 ns/iter (± 1206) 326363 ns/iter (± 785) 1.00
tuid/Tuid::random 34 ns/iter (± 0) 34 ns/iter (± 0) 1

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.