Skip to content

Commit

Permalink
std.compress.zstandard: fix error sets for streaming API
Browse files Browse the repository at this point in the history
  • Loading branch information
dweiller committed Feb 21, 2023
1 parent c6ef83e commit 0f2439b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
29 changes: 23 additions & 6 deletions lib/std/compress/zstandard/decode/huffman.zig
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,22 @@ pub const Error = error{
MalformedHuffmanTree,
MalformedFseTable,
MalformedAccuracyLog,
EndOfStream,
};

fn decodeFseHuffmanTree(source: anytype, compressed_size: usize, buffer: []u8, weights: *[256]u4) !usize {
fn decodeFseHuffmanTree(
source: anytype,
compressed_size: usize,
buffer: []u8,
weights: *[256]u4,
) !usize {
var stream = std.io.limitedReader(source, compressed_size);
var bit_reader = readers.bitReader(stream.reader());

var entries: [1 << 6]Table.Fse = undefined;
const table_size = decodeFseTable(&bit_reader, 256, 6, &entries) catch |err| switch (err) {
error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e,
error.EndOfStream => return error.MalformedFseTable,
else => |e| return e,
};
const accuracy_log = std.math.log2_int_ceil(usize, table_size);

Expand All @@ -46,15 +51,21 @@ fn decodeFseHuffmanTreeSlice(src: []const u8, compressed_size: usize, weights: *
};
const accuracy_log = std.math.log2_int_ceil(usize, table_size);

const start_index = std.math.cast(usize, counting_reader.bytes_read) orelse return error.MalformedHuffmanTree;
const start_index = std.math.cast(usize, counting_reader.bytes_read) orelse
return error.MalformedHuffmanTree;
var huff_data = src[start_index..compressed_size];
var huff_bits: readers.ReverseBitReader = undefined;
huff_bits.init(huff_data) catch return error.MalformedHuffmanTree;

return assignWeights(&huff_bits, accuracy_log, &entries, weights);
}

fn assignWeights(huff_bits: *readers.ReverseBitReader, accuracy_log: usize, entries: *[1 << 6]Table.Fse, weights: *[256]u4) !usize {
fn assignWeights(
huff_bits: *readers.ReverseBitReader,
accuracy_log: usize,
entries: *[1 << 6]Table.Fse,
weights: *[256]u4,
) !usize {
var i: usize = 0;
var even_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree;
var odd_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree;
Expand Down Expand Up @@ -173,7 +184,10 @@ fn buildHuffmanTree(weights: *[256]u4, symbol_count: usize) error{MalformedHuffm
return tree;
}

pub fn decodeHuffmanTree(source: anytype, buffer: []u8) !LiteralsSection.HuffmanTree {
pub fn decodeHuffmanTree(
source: anytype,
buffer: []u8,
) (@TypeOf(source).Error || Error)!LiteralsSection.HuffmanTree {
const header = try source.readByte();
var weights: [256]u4 = undefined;
const symbol_count = if (header < 128)
Expand All @@ -185,7 +199,10 @@ pub fn decodeHuffmanTree(source: anytype, buffer: []u8) !LiteralsSection.Huffman
return buildHuffmanTree(&weights, symbol_count);
}

pub fn decodeHuffmanTreeSlice(src: []const u8, consumed_count: *usize) Error!LiteralsSection.HuffmanTree {
pub fn decodeHuffmanTreeSlice(
src: []const u8,
consumed_count: *usize,
) Error!LiteralsSection.HuffmanTree {
if (src.len == 0) return error.MalformedHuffmanTree;
const header = src[0];
var bytes_read: usize = 1;
Expand Down
6 changes: 4 additions & 2 deletions lib/std/compress/zstandard/decompress.zig
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pub const HeaderError = error{ BadMagic, EndOfStream, ReservedBitSet };
/// - `error.EndOfStream` if `source` contains fewer than 4 bytes
/// - `error.ReservedBitSet` if the frame is a Zstandard frame and any of the
/// reserved bits are set
pub fn decodeFrameHeader(source: anytype) HeaderError!FrameHeader {
pub fn decodeFrameHeader(source: anytype) (@TypeOf(source).Error || HeaderError)!FrameHeader {
const magic = try source.readIntLittle(u32);
const frame_type = try frameType(magic);
switch (frame_type) {
Expand Down Expand Up @@ -596,7 +596,9 @@ pub fn frameWindowSize(header: ZstandardHeader) ?u64 {
/// Errors returned:
/// - `error.ReservedBitSet` if any of the reserved bits of the header are set
/// - `error.EndOfStream` if `source` does not contain a complete header
pub fn decodeZstandardHeader(source: anytype) error{ EndOfStream, ReservedBitSet }!ZstandardHeader {
pub fn decodeZstandardHeader(
source: anytype,
) (@TypeOf(source).Error || error{ EndOfStream, ReservedBitSet })!ZstandardHeader {
const descriptor = @bitCast(ZstandardHeader.Descriptor, try source.readByte());

if (descriptor.reserved) return error.ReservedBitSet;
Expand Down

0 comments on commit 0f2439b

Please sign in to comment.