From 438fe00978abc35941a92329f3e3c412c91e3d62 Mon Sep 17 00:00:00 2001 From: dweiller <4678790+dweiller@users.noreplay.github.com> Date: Mon, 23 Jan 2023 16:26:03 +1100 Subject: [PATCH] std.compress.zstandard: clean up api --- lib/std/compress/zstandard.zig | 1 + lib/std/compress/zstandard/decompress.zig | 207 ++++++++++++---------- lib/std/compress/zstandard/types.zig | 6 +- 3 files changed, 113 insertions(+), 101 deletions(-) diff --git a/lib/std/compress/zstandard.zig b/lib/std/compress/zstandard.zig index c1e9cef58c66..d83b3a333660 100644 --- a/lib/std/compress/zstandard.zig +++ b/lib/std/compress/zstandard.zig @@ -1,6 +1,7 @@ const std = @import("std"); pub const decompress = @import("zstandard/decompress.zig"); +pub usingnamespace @import("zstandard/types.zig"); test "decompression" { const uncompressed = @embedFile("testdata/rfc8478.txt"); diff --git a/lib/std/compress/zstandard/decompress.zig b/lib/std/compress/zstandard/decompress.zig index 22ed22c0de15..59268dad5717 100644 --- a/lib/std/compress/zstandard/decompress.zig +++ b/lib/std/compress/zstandard/decompress.zig @@ -3,10 +3,10 @@ const assert = std.debug.assert; const types = @import("types.zig"); const frame = types.frame; -const Literals = types.compressed_block.Literals; -const Sequences = types.compressed_block.Sequences; +const LiteralsSection = types.compressed_block.LiteralsSection; +const SequencesSection = types.compressed_block.SequencesSection; const Table = types.compressed_block.Table; -const RingBuffer = @import("RingBuffer.zig"); +pub const RingBuffer = @import("RingBuffer.zig"); const readInt = std.mem.readIntLittle; const readIntSlice = std.mem.readIntSliceLittle; @@ -55,7 +55,7 @@ pub fn decodeFrame(dest: []u8, src: []const u8, verify_checksum: bool) !ReadWrit }; } -const DecodeState = struct { +pub const DecodeState = struct { repeat_offsets: [3]u32, offset: StateData(8), @@ -70,7 +70,7 @@ const DecodeState = struct { literal_stream_reader: ReverseBitReader, literal_stream_index: usize, - huffman_tree: ?Literals.HuffmanTree, + huffman_tree: ?LiteralsSection.HuffmanTree, literal_written_count: usize, @@ -84,7 +84,55 @@ const DecodeState = struct { }; } - fn readInitialState(self: *DecodeState, bit_reader: anytype) !void { + pub fn prepare( + self: *DecodeState, + src: []const u8, + literals: LiteralsSection, + sequences_header: SequencesSection.Header, + ) !usize { + if (literals.huffman_tree) |tree| { + self.huffman_tree = tree; + } else if (literals.header.block_type == .treeless and self.huffman_tree == null) { + return error.TreelessLiteralsFirst; + } + + switch (literals.header.block_type) { + .raw, .rle => {}, + .compressed, .treeless => { + self.literal_stream_index = 0; + switch (literals.streams) { + .one => |slice| try self.initLiteralStream(slice), + .four => |streams| try self.initLiteralStream(streams[0]), + } + }, + } + + if (sequences_header.sequence_count > 0) { + var bytes_read = try self.updateFseTable( + src, + .literal, + sequences_header.literal_lengths, + ); + + bytes_read += try self.updateFseTable( + src[bytes_read..], + .offset, + sequences_header.offsets, + ); + + bytes_read += try self.updateFseTable( + src[bytes_read..], + .match, + sequences_header.match_lengths, + ); + self.fse_tables_undefined = false; + + return bytes_read; + } + return 0; + } + + pub fn readInitialFseState(self: *DecodeState, bit_reader: anytype) !void { self.literal.state = try bit_reader.readBitsNoEof(u9, self.literal.accuracy_log); self.offset.state = try bit_reader.readBitsNoEof(u8, self.offset.accuracy_log); self.match.state = try bit_reader.readBitsNoEof(u9, self.match.accuracy_log); @@ -130,7 +178,7 @@ const DecodeState = struct { self: *DecodeState, src: []const u8, comptime choice: DataType, - mode: Sequences.Header.Mode, + mode: SequencesSection.Header.Mode, ) !usize { const field_name = @tagName(choice); switch (mode) { @@ -213,7 +261,13 @@ const DecodeState = struct { }; } - fn executeSequenceSlice(self: *DecodeState, dest: []u8, write_pos: usize, literals: Literals, sequence: Sequence) !void { + fn executeSequenceSlice( + self: *DecodeState, + dest: []u8, + write_pos: usize, + literals: LiteralsSection, + sequence: Sequence, + ) !void { try self.decodeLiteralsSlice(dest[write_pos..], literals, sequence.literal_length); // TODO: should we validate offset against max_window_size? @@ -225,7 +279,12 @@ const DecodeState = struct { std.mem.copy(u8, dest[write_pos + sequence.literal_length ..], dest[copy_start..copy_end]); } - fn executeSequenceRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: Literals, sequence: Sequence) !void { + fn executeSequenceRingBuffer( + self: *DecodeState, + dest: *RingBuffer, + literals: LiteralsSection, + sequence: Sequence, + ) !void { try self.decodeLiteralsRingBuffer(dest, literals, sequence.literal_length); // TODO: check that ring buffer window is full enough for match copies const copy_slice = dest.sliceAt(dest.write_index + dest.data.len - sequence.offset, sequence.match_length); @@ -234,11 +293,11 @@ const DecodeState = struct { for (copy_slice.second) |b| dest.writeAssumeCapacity(b); } - fn decodeSequenceSlice( + pub fn decodeSequenceSlice( self: *DecodeState, dest: []u8, write_pos: usize, - literals: Literals, + literals: LiteralsSection, bit_reader: anytype, last_sequence: bool, ) !usize { @@ -255,10 +314,10 @@ const DecodeState = struct { return sequence.match_length + sequence.literal_length; } - fn decodeSequenceRingBuffer( + pub fn decodeSequenceRingBuffer( self: *DecodeState, dest: *RingBuffer, - literals: Literals, + literals: LiteralsSection, bit_reader: anytype, last_sequence: bool, ) !usize { @@ -280,7 +339,7 @@ const DecodeState = struct { return sequence.match_length + sequence.literal_length; } - fn nextLiteralMultiStream(self: *DecodeState, literals: Literals) !void { + fn nextLiteralMultiStream(self: *DecodeState, literals: LiteralsSection) !void { self.literal_stream_index += 1; try self.initLiteralStream(literals.streams.four[self.literal_stream_index]); } @@ -290,7 +349,7 @@ const DecodeState = struct { try self.literal_stream_reader.init(bytes); } - fn decodeLiteralsSlice(self: *DecodeState, dest: []u8, literals: Literals, len: usize) !void { + pub fn decodeLiteralsSlice(self: *DecodeState, dest: []u8, literals: LiteralsSection, len: usize) !void { if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength; switch (literals.header.block_type) { .raw => { @@ -310,7 +369,7 @@ const DecodeState = struct { // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4; const huffman_tree = self.huffman_tree orelse unreachable; const max_bit_count = huffman_tree.max_bit_count; - const starting_bit_count = Literals.HuffmanTree.weightToBitCount( + const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount( huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight, max_bit_count, ); @@ -345,7 +404,7 @@ const DecodeState = struct { }, .index => |index| { huffman_tree_index = index; - const bit_count = Literals.HuffmanTree.weightToBitCount( + const bit_count = LiteralsSection.HuffmanTree.weightToBitCount( huffman_tree.nodes[index].weight, max_bit_count, ); @@ -359,7 +418,7 @@ const DecodeState = struct { } } - fn decodeLiteralsRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: Literals, len: usize) !void { + pub fn decodeLiteralsRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: LiteralsSection, len: usize) !void { if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength; switch (literals.header.block_type) { .raw => { @@ -378,7 +437,7 @@ const DecodeState = struct { // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4; const huffman_tree = self.huffman_tree orelse unreachable; const max_bit_count = huffman_tree.max_bit_count; - const starting_bit_count = Literals.HuffmanTree.weightToBitCount( + const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount( huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight, max_bit_count, ); @@ -413,7 +472,7 @@ const DecodeState = struct { }, .index => |index| { huffman_tree_index = index; - const bit_count = Literals.HuffmanTree.weightToBitCount( + const bit_count = LiteralsSection.HuffmanTree.weightToBitCount( huffman_tree.nodes[index].weight, max_bit_count, ); @@ -647,54 +706,6 @@ fn decodeRleBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, return block_size; } -fn prepareDecodeState( - decode_state: *DecodeState, - src: []const u8, - literals: Literals, - sequences_header: Sequences.Header, -) !usize { - if (literals.huffman_tree) |tree| { - decode_state.huffman_tree = tree; - } else if (literals.header.block_type == .treeless and decode_state.huffman_tree == null) { - return error.TreelessLiteralsFirst; - } - - switch (literals.header.block_type) { - .raw, .rle => {}, - .compressed, .treeless => { - decode_state.literal_stream_index = 0; - switch (literals.streams) { - .one => |slice| try decode_state.initLiteralStream(slice), - .four => |streams| try decode_state.initLiteralStream(streams[0]), - } - }, - } - - if (sequences_header.sequence_count > 0) { - var bytes_read = try decode_state.updateFseTable( - src, - .literal, - sequences_header.literal_lengths, - ); - - bytes_read += try decode_state.updateFseTable( - src[bytes_read..], - .offset, - sequences_header.offsets, - ); - - bytes_read += try decode_state.updateFseTable( - src[bytes_read..], - .match, - sequences_header.match_lengths, - ); - decode_state.fse_tables_undefined = false; - - return bytes_read; - } - return 0; -} - pub fn decodeBlock( dest: []u8, src: []const u8, @@ -715,7 +726,7 @@ pub fn decodeBlock( const literals = try decodeLiteralsSection(src, &bytes_read); const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read); - bytes_read += try prepareDecodeState(decode_state, src[bytes_read..], literals, sequences_header); + bytes_read += try decode_state.prepare(src[bytes_read..], literals, sequences_header); var bytes_written: usize = 0; if (sequences_header.sequence_count > 0) { @@ -723,7 +734,7 @@ pub fn decodeBlock( var bit_stream: ReverseBitReader = undefined; try bit_stream.init(bit_stream_bytes); - try decode_state.readInitialState(&bit_stream); + try decode_state.readInitialFseState(&bit_stream); var i: usize = 0; while (i < sequences_header.sequence_count) : (i += 1) { @@ -780,7 +791,7 @@ pub fn decodeBlockRingBuffer( const literals = try decodeLiteralsSection(src, &bytes_read); const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read); - bytes_read += try prepareDecodeState(decode_state, src[bytes_read..], literals, sequences_header); + bytes_read += try decode_state.prepare(src[bytes_read..], literals, sequences_header); var bytes_written: usize = 0; if (sequences_header.sequence_count > 0) { @@ -788,7 +799,7 @@ pub fn decodeBlockRingBuffer( var bit_stream: ReverseBitReader = undefined; try bit_stream.init(bit_stream_bytes); - try decode_state.readInitialState(&bit_stream); + try decode_state.readInitialFseState(&bit_stream); var i: usize = 0; while (i < sequences_header.sequence_count) : (i += 1) { @@ -928,7 +939,7 @@ pub fn decodeBlockHeader(src: *const [3]u8) frame.ZStandard.Block.Header { }; } -pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals { +pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !LiteralsSection { // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks) var bytes_read: usize = 0; const header = decodeLiteralsHeader(src, &bytes_read); @@ -936,7 +947,7 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals .raw => { const stream = src[bytes_read .. bytes_read + header.regenerated_size]; consumed_count.* += header.regenerated_size + bytes_read; - return Literals{ + return LiteralsSection{ .header = header, .huffman_tree = null, .streams = .{ .one = stream }, @@ -945,7 +956,7 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals .rle => { const stream = src[bytes_read .. bytes_read + 1]; consumed_count.* += 1 + bytes_read; - return Literals{ + return LiteralsSection{ .header = header, .huffman_tree = null, .streams = .{ .one = stream }, @@ -966,7 +977,7 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals const stream = src[bytes_read .. bytes_read + total_streams_size]; bytes_read += total_streams_size; consumed_count.* += bytes_read; - return Literals{ + return LiteralsSection{ .header = header, .huffman_tree = huffman_tree, .streams = .{ .one = stream }, @@ -988,7 +999,7 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals consumed_count.* += total_streams_size + bytes_read; - return Literals{ + return LiteralsSection{ .header = header, .huffman_tree = huffman_tree, .streams = .{ .four = .{ @@ -1002,7 +1013,7 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals } } -fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !Literals.HuffmanTree { +fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !LiteralsSection.HuffmanTree { var bytes_read: usize = 0; bytes_read += 1; const header = src[0]; @@ -1094,7 +1105,7 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !Literals.HuffmanT weights[symbol_count - 1] = @intCast(u4, std.math.log2_int(u16, next_power_of_two - weight_power_sum) + 1); log.debug("weights[{d}] = {d}", .{ symbol_count - 1, weights[symbol_count - 1] }); - var weight_sorted_prefixed_symbols: [256]Literals.HuffmanTree.PrefixedSymbol = undefined; + var weight_sorted_prefixed_symbols: [256]LiteralsSection.HuffmanTree.PrefixedSymbol = undefined; for (weight_sorted_prefixed_symbols[0..symbol_count]) |_, i| { weight_sorted_prefixed_symbols[i] = .{ .symbol = @intCast(u8, i), @@ -1104,7 +1115,7 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !Literals.HuffmanT } std.sort.sort( - Literals.HuffmanTree.PrefixedSymbol, + LiteralsSection.HuffmanTree.PrefixedSymbol, weight_sorted_prefixed_symbols[0..symbol_count], weights, lessThanByWeight, @@ -1137,7 +1148,7 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !Literals.HuffmanT } } consumed_count.* += bytes_read; - const tree = Literals.HuffmanTree{ + const tree = LiteralsSection.HuffmanTree{ .max_bit_count = max_number_of_bits, .symbol_count_minus_one = @intCast(u8, prefixed_symbol_count - 1), .nodes = weight_sorted_prefixed_symbols, @@ -1148,8 +1159,8 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !Literals.HuffmanT fn lessThanByWeight( weights: [256]u4, - lhs: Literals.HuffmanTree.PrefixedSymbol, - rhs: Literals.HuffmanTree.PrefixedSymbol, + lhs: LiteralsSection.HuffmanTree.PrefixedSymbol, + rhs: LiteralsSection.HuffmanTree.PrefixedSymbol, ) bool { // NOTE: this function relies on the use of a stable sorting algorithm, // otherwise a special case of if (weights[lhs] == weights[rhs]) return lhs < rhs; @@ -1157,11 +1168,11 @@ fn lessThanByWeight( return weights[lhs.symbol] < weights[rhs.symbol]; } -pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) Literals.Header { +pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) LiteralsSection.Header { // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks) const start = consumed_count.*; const byte0 = src[0]; - const block_type = @intToEnum(Literals.BlockType, byte0 & 0b11); + const block_type = @intToEnum(LiteralsSection.BlockType, byte0 & 0b11); const size_format = @intCast(u2, (byte0 & 0b1100) >> 2); var regenerated_size: u20 = undefined; var compressed_size: ?u18 = null; @@ -1220,7 +1231,7 @@ pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) Literals.He compressed_size, }, ); - return Literals.Header{ + return LiteralsSection.Header{ .block_type = block_type, .size_format = size_format, .regenerated_size = regenerated_size, @@ -1228,7 +1239,7 @@ pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) Literals.He }; } -fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !Sequences.Header { +pub fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !SequencesSection.Header { var sequence_count: u24 = undefined; var bytes_read: usize = 0; @@ -1237,7 +1248,7 @@ fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !Sequences.Hea bytes_read += 1; log.debug("decoded sequences header '{}': sequence count = 0", .{std.fmt.fmtSliceHexUpper(src[0..bytes_read])}); consumed_count.* += bytes_read; - return Sequences.Header{ + return SequencesSection.Header{ .sequence_count = 0, .offsets = undefined, .match_lengths = undefined, @@ -1258,9 +1269,9 @@ fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !Sequences.Hea bytes_read += 1; consumed_count.* += bytes_read; - const matches_mode = @intToEnum(Sequences.Header.Mode, (compression_modes & 0b00001100) >> 2); - const offsets_mode = @intToEnum(Sequences.Header.Mode, (compression_modes & 0b00110000) >> 4); - const literal_mode = @intToEnum(Sequences.Header.Mode, (compression_modes & 0b11000000) >> 6); + const matches_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00001100) >> 2); + const offsets_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00110000) >> 4); + const literal_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b11000000) >> 6); log.debug("decoded sequences header '{}': (sc={d},o={s},m={s},l={s})", .{ std.fmt.fmtSliceHexUpper(src[0..bytes_read]), sequence_count, @@ -1270,7 +1281,7 @@ fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !Sequences.Hea }); if (compression_modes & 0b11 != 0) return error.ReservedBitSet; - return Sequences.Header{ + return SequencesSection.Header{ .sequence_count = sequence_count, .offsets = offsets_mode, .match_lengths = matches_mode, @@ -1428,25 +1439,25 @@ const ReversedByteReader = struct { } }; -const ReverseBitReader = struct { +pub const ReverseBitReader = struct { byte_reader: ReversedByteReader, bit_reader: std.io.BitReader(.Big, ReversedByteReader.Reader), - fn init(self: *ReverseBitReader, bytes: []const u8) !void { + pub fn init(self: *ReverseBitReader, bytes: []const u8) !void { self.byte_reader = ReversedByteReader.init(bytes); self.bit_reader = std.io.bitReader(.Big, self.byte_reader.reader()); while (0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) {} } - fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U { + pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U { return self.bit_reader.readBitsNoEof(U, num_bits); } - fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U { + pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U { return try self.bit_reader.readBits(U, num_bits, out_bits); } - fn alignToByte(self: *@This()) void { + pub fn alignToByte(self: *@This()) void { self.bit_reader.alignToByte(); } }; @@ -1514,7 +1525,7 @@ fn dumpFseTable(prefix: []const u8, table: []const Table.Fse) void { } } -fn dumpHuffmanTree(tree: Literals.HuffmanTree) void { +fn dumpHuffmanTree(tree: LiteralsSection.HuffmanTree) void { log.debug("Huffman tree: max bit count = {}, symbol count = {}", .{ tree.max_bit_count, tree.symbol_count_minus_one + 1 }); for (tree.nodes[0 .. tree.symbol_count_minus_one + 1]) |node| { log.debug("symbol = {[symbol]d}, prefix = {[prefix]d}, weight = {[weight]d}", node); diff --git a/lib/std/compress/zstandard/types.zig b/lib/std/compress/zstandard/types.zig index edac66f68638..f703dc29ebb3 100644 --- a/lib/std/compress/zstandard/types.zig +++ b/lib/std/compress/zstandard/types.zig @@ -52,7 +52,7 @@ pub const frame = struct { }; pub const compressed_block = struct { - pub const Literals = struct { + pub const LiteralsSection = struct { header: Header, huffman_tree: ?HuffmanTree, streams: Streams, @@ -119,8 +119,8 @@ pub const compressed_block = struct { } }; - pub const Sequences = struct { - header: Sequences.Header, + pub const SequencesSection = struct { + header: SequencesSection.Header, literals_length_table: Table, offset_table: Table, match_length_table: Table,