Arithmetic Coding
We know that Huffman coding is the best per-symbol encoder. But there is a fundamental limitation with any per-symbol scheme: the overhead can be up to 1 bit per symbol over the entropy, which can be very significant. One simple case of this is:
p = {A: 0.1, B: 0.9}, H(p) = 0.47
Even though the entropy is about half a bit, we know that no per-symbol coder can do better than 1 bit per symbol. Thus, we are effectively going to use twice as many bits to encode the source if we use Huffman coding.
One possible solution to this problem is to make a source out of multiple symbols; for example, if we make a new source out of tuples, we get:
p2 = {AA: 0.01, AB: 0.09, BA: 0.09, BB: 0.81}
This should perform better, as we now have a maximum overhead of 1 bit per 2 symbols. In this case, Huffman coding the tuples gives an average of 0.65 bits/symbol, which is better than 1.0 but still not equal to the entropy of 0.47. We can of course continue and combine three symbols, etc., and eventually Huffman coding will converge to the entropy. However, this is computationally inefficient, as the alphabet size grows exponentially with the number of symbols we combine.
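To make these numbers concrete, here is a small sanity check of the figures above. The Huffman code lengths for the tuple source are worked out by hand here (BB: 1 bit, BA: 2 bits, AB: 3 bits, AA: 3 bits) rather than taken from any library, so treat this as an illustrative sketch:
import numpy as np
# single-symbol source and its entropy (~0.47 bits/symbol)
p = {'A': 0.1, 'B': 0.9}
entropy = -sum(q * np.log2(q) for q in p.values())
print(entropy)  # ~0.469
# tuple source and hand-computed Huffman code lengths
p2 = {'AA': 0.01, 'AB': 0.09, 'BA': 0.09, 'BB': 0.81}
huffman_lengths = {'AA': 3, 'AB': 3, 'BA': 2, 'BB': 1}
avg_bits_per_tuple = sum(p2[t] * huffman_lengths[t] for t in p2)
print(avg_bits_per_tuple / 2)  # ~0.645 bits/symbol, the 0.65 figure above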
Arithmetic coding solves this problem: theoretically, its overhead is just 2 bits over the optimal codelength for the entire sequence. Even in practice, arithmetic coding and its variants achieve incredible computational and compression performance.
Along with computational efficiency, arithmetic coding offers several other desirable properties:
- Adaptability -> Arithmetic coding can use different distributions for different symbols, and still be optimal.
- Model/coding separation -> As arithmetic coding is optimal in essentially any scenario, it separates the compression problem into two parts: coming up with a model (i.e. a distribution) for the data, and encoding the data using that distribution. Because of the optimality of arithmetic coding, for most purposes we can focus on the task of thinking about the model for the data.
The core idea of Arithmetic encoding can be explained using the following two steps:
- STEP I: Represent the entire input sequence as an interval [low, high) within the interval [0, 1]
|----------------[.......)------------|
0                low     high         1
- STEP II: Represent the [low, high) range using a single floating point number (the state) within the range which has a short binary expansion. For example, for low = 0.1, high = 0.6, one possibility for the state is state = 0.25 = b0.01. As state = b0.01, the final arithmetic code for the input becomes 01. Note that we need low <= state < high.
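As a tiny illustration of STEP II, the snippet below (a sketch using the float_to_bitarrays utility that also appears later on this page) checks that state = 0.25 lies inside [0.1, 0.6) and has the 2-bit binary expansion b0.01:
from utils.bitarray_utils import float_to_bitarrays
# toy interval from the example above
low, high = 0.1, 0.6
state = 0.25
assert low <= state < high  # the state lies inside [low, high)
# fractional binary expansion of the state, truncated to 2 bits
_, state_bits = float_to_bitarrays(state, max_precision=2)
print(state_bits)  # bitarray('01') -> the arithmetic code is 01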
The decoder then has to perform the reverse operation to infer the input data. We will next try to understand how to decide the interval and then how to represent this interval using bits.
NOTE: For this discussion we are going to assume we have infinite precision, and that we can represent any floating point number exactly
The process to get the [low, high) range corresponding to the input is also known as the cake-cutting method. It will be obvious why so! We start with low=0, high=1, and then proceed to recursively shrink the range into a smaller range based on the input symbols. Let's take a concrete example, as this will be much clearer that way.
from core.data_block import DataBlock
from core.prob_dist import ProbabilityDist
# define a sample distribution
prob = ProbabilityDist({'A': 0.2, 'B': 0.4, 'C': 0.4})
# define a sample input
data = DataBlock(['B', 'A', 'C', 'B'])
The code block to recursively compute the low, high range is given below:
# initialize low, high values
low, high = 0.0, 1.0
# recursively shrink the range
for s in data.data_list:
    rng = (high - low)
    low = low + prob.cumulative_prob_dict[s]*rng
    high = low + prob.probability(s)*rng
As one can see from the python snippet above:
- We split the current range rng = (high - low) into slices which are proportional to the probabilities of the symbols in the distribution. For example, initially when low=0.0, high=1.0, the slices are [0.0, 0.2), [0.2, 0.6), [0.6, 1.0), corresponding to A, B, C respectively.
- We continue this process until we are done with all the input symbols. The progression of the low, high values for the sample input is shown below:
initial range: low 0.0000, high: 1.0000
0: symbol B, low 0.2000, high: 0.6000
1: symbol A, low 0.2000, high: 0.2800
2: symbol C, low 0.2480, high: 0.2800
3: symbol B, low 0.2544, high: 0.2672
Notice that the final range [0.2544, 0.2672) losslessly represents the entire input sequence ['B', 'A', 'C', 'B']. Thus, if the decoder knows this range, it can recover the entire sequence.
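To make this concrete, below is a rough decoding sketch (a hypothetical illustration, not the library's actual decoder, which is discussed in the next section). It assumes the decoder knows the distribution, the number of symbols, and some value state lying inside the final range; it then replays the cake-cutting to recover the symbols one at a time:
# rough sketch: recover the input from any value lying inside the final range
def decode_from_state(state, prob, num_symbols):
    decoded = []
    low, high = 0.0, 1.0
    for _ in range(num_symbols):
        rng = high - low
        # find the symbol whose slice of [low, high) contains the state
        for s in prob.cumulative_prob_dict:
            s_low = low + prob.cumulative_prob_dict[s] * rng
            s_high = s_low + prob.probability(s) * rng
            if s_low <= state < s_high:
                decoded.append(s)
                low, high = s_low, s_high
                break
    return decoded
# any value inside [0.2544, 0.2672) works, e.g.:
# decode_from_state(0.26, prob, 4) -> ['B', 'A', 'C', 'B']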
One way to communicate the range information is to communicate a number which lies inside the range [0.2544, 0.2672). This can be achieved as follows:
- We know that (low + high)/2 lies in the interval [low, high). Thus we want to communicate this floating point number. Let us call this the mid:
mid = (low + high)/2
- The floating point number mid can have infinitely many bits in its binary expansion (for example, 1/3 in binary = b0.010101...), so it can be impossible to communicate it exactly. For instance, in our example:
from utils.bitarray_utils import float_to_bitarrays
# low ~ 0.2544, high ~ 0.2672
mid = (low + high)/2  # mid = 0.26080000000000003
_, float_bitarray = float_to_bitarrays(mid, max_precision=20)
# mid = b0.01000010110000111100...
- Note that if we truncate the binary expansion of mid, the resulting floating point value will be close to mid but will now be feasible to represent. Thus, the final step is to truncate the binary expansion of mid to a sufficient number of bits so that the resulting fraction (let's call it the state) still lies inside [low, high). We also want to be mindful not to use too many bits, as after all we want to compress the input :).
If we truncate the binary expansion of mid to k bits after the decimal point, then it is clear that the resulting fraction mid_k satisfies:
(mid - mid_k) < 2^{-k}
Since mid = (low + high)/2 and mid_k <= mid < high, it suffices to ensure that (mid - mid_k) <= (high - low)/2, which guarantees mid_k >= low and hence that mid_k lies in [low, high). Thus we can pick k such that:
(mid - mid_k) < 2^{-k} <= |high - low|/2
which implies:
k >= -log2(|high - low|) + 1
In our example, we can thus calculate k and the state as:
import numpy as np
# low ~ 0.2544, high ~ 0.2672
k = np.ceil(-np.log2(high - low) + 1)
# get the truncated mid-point
_, code = float_to_bitarrays(mid, max_precision=int(k))
# >> code = bitarray('01000010')
# state = b0.01000010
Thus, the Arithmetic encoder has encoded the sequence ['B', 'A', 'C', 'B'] as 01000010, which is just 8 bits! The decoder operations should be clear, but we will look into them explicitly in the next section. One remaining question, however, is whether the Arithmetic coder explained above is any good, i.e. how well does it compress data?
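As a quick back-of-the-envelope check of the "2 bits over the optimal codelength" claim from the beginning of this page (an informal calculation, not part of the library): the width of the final interval is exactly the probability of the whole input sequence, so the number of bits we used stays within 2 bits of the optimal codelength -log2(P(sequence)):
import numpy as np
# the final interval width equals the product of the symbol probabilities
seq_prob = 0.4 * 0.2 * 0.4 * 0.4   # P('B')*P('A')*P('C')*P('B') = 0.0128
optimal_bits = -np.log2(seq_prob)  # ~6.29 bits
print(np.ceil(optimal_bits + 1))   # 8.0 -> matches the 8 bits computed above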
NOTE: I found this sequence of lectures on arithmetic coding extremely useful in understanding the intricacies.
- Introduction to Arithmetic coding: a good introduction to the key benefits of arithmetic coding
- Arithmetic coding examples: Example 1, Example 2: the examples are quite useful to demonstrate what happens "theoretically" in arithmetic coding
- Why arithmetic coding intervals need to be contained: explains why the intervals need to be contained
- Rescaling operation for AEC: explains the core intuition of how the rescaling occurs in arithmetic coding