Skip to content

Latest commit

 

History

History
396 lines (350 loc) · 14 KB

main.org

File metadata and controls

396 lines (350 loc) · 14 KB

main

#+AUTO-TANGLE: t

Imports

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from importlib import reload
import os

import utils.color as color
import utils.graph as graph
import utils.gft as gft
import utils.morton as morton
import utils.clustering as clustering
import utils.quantization as quantization
import utils.bitcoding as bitcoding
import utils.ply as ply

Main function

def run(method='kmeans', sequence='loot', frame_number=1, qsteps=[16], bsize=16, clusters_count=1500, lambd=0.2, colorspace='lab', qstep_centers=10, ref_iterations=1, ref_method='none', beta=2.0, bitstream_directory='/tmp', export_ply_blocks=False, decoder_match_test=False):
    # Encode one point-cloud frame and collect one row of measurements for
    # every quantization step in `qsteps`.
    #
    # The `<<name>>` lines are noweb references: when this file is tangled each
    # one is replaced by the source block of that name. The variables used
    # further down (V, N, t, YPSNR_coeff, bs_*, dupes_count, A_quant, ...) are
    # defined by those blocks, not in this function text.
    #
    # Returns a 2-D array with 11 columns and one row per processed qstep:
    #   qstep, YPSNR_coeff, bs_total, bs_coeffs, bs_centers, bs_dupes,
    #   dupes_count, N, t['partitioning'], t['gft'], t['centers_ref']
    # or an empty 1-D array when the lambd/ref_iterations guard below triggers.
    #
    # NOTE: `qsteps` has a mutable default; the code visible here only iterates
    # and prints it, it is never modified.

    # initialize results array
    results = np.empty((0, 11)) # 11 is the number of measurements we make
    # check that we don't do refinement for lambda values of 0.0 (makes it easier to start jobs this way because of the zipping of arguments)
    # NOTE: the test looks only at ref_iterations, not at ref_method, so the
    # default ref_iterations=1 also stops a run with lambd=0.0 and ref_method='none'
    if lambd == 0.0 and ref_iterations > 0:
        print(f"We don't continue if lambda is 0.0 and ref_iterations is bigger than 0")
        return np.array([])
    # process parameters
    <<parameters>>
    <<printParameters>>

    ## Encoder
    ### Load/prepare data
    <<data>>

    ### Partitioning
    <<partitioning>>

    ### Centers refinement
    <<centers_refinement>>

    ### Export PLY for blocks visualization
    # NOTE(review): no block named exportBlocks is visible in this view of the file
    <<exportBlocks>>

    ### GFT per block
    <<gft>>

    # everything above is independent of qstep and is computed once;
    # quantization, sorting and bit coding are repeated per qstep
    for qstep in qsteps:
        print(f'Processing for qstep={qstep}...')
        ### Quantize
        <<quantize>>

        ### Sort
        <<sort>>

        ### Bit coding
        <<bitCoding>>

        ## Decoder
        if decoder_match_test:
            print(f'Encoder/decoder check...')
            ### Inverse GFT (necessary only for enc/dec match test)
            <<inverseGFT>>

            # run the decoder on the bitstreams just written to bitstream_directory
            _, A_quant_dec = decode(
                V=V, N=N, method=method, qstep=qstep, bsize=bsize, clusters_count=clusters_count, lambd=lambd, colorspace=colorspace, qstep_centers=qstep_centers, bitstream_directory=bitstream_directory
            )
            ### Encoder-decoder match check
            # sum of absolute differences; 0 means encoder and decoder reconstructions agree
            print(f'Encoder-decoder match: { np.sum(np.abs(A_quant - A_quant_dec)) }')

        # append the measurements for the current qstep to the results array
        current_result = np.array([qstep, YPSNR_coeff, bs_total, bs_coeffs, bs_centers, bs_dupes, dupes_count, N, t['partitioning'], t['gft'], t['centers_ref']])
        results = np.concatenate((results, np.array(current_result)[np.newaxis,:]))

    return results

# default call when the script is executed directly (not imported):
# runs the encoder with k-means partitioning and all other arguments at their defaults
if __name__ == "__main__":
    run(method='kmeans')

Decoder function

# Decoder counterpart of run(): rebuilds the attributes from the bitstreams
# stored in `bitstream_directory`, given the geometry V and the point count N.
#
# The `<<name>>` lines are noweb references that are replaced by the source
# blocks of that name when the file is tangled; those blocks define
# V_ordered_dec and A_quant_dec, which are returned as a pair.
#
# NOTE(review): `lambd` and `colorspace` are accepted but not read by any of
# the decoder blocks visible in this file.
def decode(
        V, N, method='kmeans', qstep=16, bsize=16, clusters_count=1500, lambd=0.3, colorspace='lab', qstep_centers=10, bitstream_directory='/tmp'
):
    ### Bit decoding
    <<bitDecoding>>

    ### Partitioning
    <<partitioning_dec>>

    ### Create GFT block matrices and frequencies
    <<GFTBlocksFreqs>>

    ### Inverse sorting
    <<inverseSort>>

    ### Inverse quantization
    <<inverseQuant>>

    ### Inverse GFT
    <<inverseGftDec>>

    return V_ordered_dec, A_quant_dec

Parameters

Process parameters

# Number that precedes the first frame of each supported sequence; the frame
# index used in the file name is this value plus `frame_number`.
start_frames = {
    'loot': 999,
    'longdress': 1050,
    'soldier': 535,
    'redandblack': 1449,
}
if sequence not in start_frames:
    # Without this check an unknown sequence name left `start_frame` unassigned
    # and the run failed two lines later with an unrelated NameError.
    raise ValueError(
        f"unknown sequence {sequence!r}; expected one of {sorted(start_frames)}"
    )
start_frame = start_frames[sequence]

# absolute frame index and the PLY file it maps to
frame = start_frame + frame_number
filename = f'/path/to/8iVFBv2/{sequence}/Ply/{sequence}_vox10_{frame:04d}.ply'

Print parameters

# Report the run configuration on stdout before any work starts.
run_banner = (
    f'Processing sequence {sequence}, frame number {frame_number} of 300\n'
    f'using {bitstream_directory} as temporary bitstream directory\n'
    f'using parameters:\n'
    f'quantize step sizes: {*qsteps,}\n'
    f'method: {method}'
)
print(run_banner)

# Method-specific settings: k-means has its own parameter set, octree only
# has the block size.
if method == 'kmeans':
    kmeans_settings = (
        f'clusters count: {clusters_count}\n'
        f'lambda parameter: {lambd}\n'
        f'colorspace: {colorspace}\n'
        f'centers quantization step: {qstep_centers}\n'
        f'refinement method: {ref_method}\n'
        f'beta parameter: {beta}\n'
        f'refinement iterations: {ref_iterations}\n'
    )
    print(kmeans_settings)
elif method == 'octree':
    print(f'block size: {bsize}')

if decoder_match_test:
    print('Encoder/decoder match will be checked!\n')

Encoder

Load/prepare data

# Read the frame from disk. The reader hands back the vertices, the RGB
# attributes, the number of points and a bit-resolution value.
print('Loading pointcloud data...', end='', flush=True)
ply_contents = ply.read(filename)
V, A_rgb, N, bitresolution = ply_contents
print(f'loaded {N} points.')

# The transform stage works on YUV attributes, so convert once up front.
A = color.rgb_to_yuv(A_rgb)

# Optional Morton-order check, left disabled by the author:
# if morton.check_morton3D(x=V[:, 2], y=V[:, 1], z=V[:, 0]):
    # print('Morton order correct')

Partitioning

# Split the point cloud into blocks. Every later stage works on the
# (idx_start, idx_stop, N_block) description and the reordered V/A built here.
if method not in ('octree', 'kmeans'):
    # Fail early: otherwise neither branch below runs, idx_start and friends
    # stay undefined and the run dies later with an unrelated NameError.
    raise ValueError(
        f"unknown partitioning method {method!r}; expected 'octree' or 'kmeans'"
    )

# `t` collects the wall-clock duration (seconds) of the individual stages
t = {}
print('Partitioning pointcloud...', end='', flush=True)
t['partitioning'] = time.time()
if method == 'octree':
    idx_start, idx_stop, N_block = graph.block_indices(V=V, bsize=bsize)
    # the octree path keeps the input order of points and attributes
    V_ordered = V
    A_ordered = A
    # Octree partitioning has no cluster centers. The names are still defined
    # because the centers-refinement block reads `centers_decoder`
    # unconditionally; leaving them unset made every octree run crash there.
    centers_encoder = labels_encoder = None
    centers_decoder = labels_decoder = None

else:  # method == 'kmeans'
    emulate_dec = True
    centers_encoder, labels_encoder, centers_decoder, labels_decoder, V_ordered, A_ordered, idx_start, idx_stop, N_block = clustering.kmeans_encoder(
        V=V, A=A, A_rgb=A_rgb, partition_count=clusters_count, lam_part=lambd, colorspace=colorspace, qstep_centers=qstep_centers, emulate_dec=emulate_dec
    )
# replace the start timestamp by the elapsed time
t['partitioning'] = time.time() - t['partitioning']
print('done')

Centers refinement

# Optional refinement of the k-means cluster centers. On exit
# `centers_decoder_refined` is always defined; the bit-coding block reads it.
t['centers_ref'] = time.time()

if method != 'kmeans':
    # Only the k-means partitioning produces centers. Reading `centers_decoder`
    # here for any other method raised a NameError, so skip explicitly.
    print('No centers refinement will be performed.')
    centers_decoder_refined = None
elif ref_method == 'none' or ref_iterations < 1:
    print('No centers refinement will be performed.')
    centers_decoder_refined = centers_decoder
elif ref_method not in ('VA', 'weight', 'weight1'):
    # Previously an unsupported name printed "done" without refining anything
    # and left `centers_decoder_refined` undefined for the bit-coding stage.
    raise ValueError(
        f"unknown ref_method {ref_method!r}; expected 'none', 'VA', 'weight' or 'weight1'"
    )
else:
    print('Refining centers...', end='', flush=True)
    # make centers 6-dimensional for the VA method
    if ref_method == 'VA':
        centers_ref_VA_init = np.zeros((centers_decoder.shape[0], 6))
        # columns 0..2: the existing center values
        centers_ref_VA_init[:, 0:3] = centers_decoder

        # columns 3..5: lambda-weighted mean Lab attribute of the points that
        # carry the center's label
        A_lab = color.rgb_to_lab(rgb=A_rgb)
        for cluster_id in range(centers_decoder.shape[0]):
            centers_ref_VA_init[cluster_id, 3:] = lambd * np.mean(A_lab[labels_decoder==cluster_id, :], axis=0)

        centers_decoder = centers_ref_VA_init

    # the refinement also re-derives the ordering and the block indices
    centers_decoder_refined, labels_refined, V_ordered, A_ordered, idx_start, idx_stop, N_block, dist_refined = clustering.refine_centers(
        V=V, A_yuv=A, A_rgb=A_rgb, centers_init=centers_decoder, labels_init=labels_decoder, N_iter=ref_iterations, method=ref_method, beta=beta, lam_part=lambd, scaling=False
    )
    print('done')

# replace the start timestamp by the elapsed time
t['centers_ref'] = time.time() - t['centers_ref']

GFT per block

# Block-wise graph Fourier transform of the ordered attributes; the elapsed
# wall-clock time is stored in t['gft'].
print('Calculate GFT...')
gft_started = time.time()

# one weight of 1 per point
Q = np.ones((N, 1))

gft_outputs = gft.transform_block_gft(
    V=V_ordered,
    A=A_ordered,
    Q=Q,
    idx_start=idx_start,
    idx_stop=idx_stop,
)
Ahat, res, GFT_blocks, Gfreq_blocks = gft_outputs

t['gft'] = time.time() - gft_started

Quantize GFT coefficients

# Quantize the transform coefficients with the current step size. The call
# yields the quantized values together with their integer indices.
print('Quantizing GFT coefficients...', end='', flush=True)
quantizer_output = quantization.quantize(x=Ahat, qstep=qstep)
Ahat_quant, Ahat_quant_idx = quantizer_output
print('done')

# Distortion measured between the original and the quantized coefficients.
YPSNR_coeff = color.YPSNR(Ahat, Ahat_quant, N)
print('YPSNR|coeff={:2.4f} dB'.format(YPSNR_coeff))

Inverse GFT (only necessary for decoder check)

# Encoder-side reconstruction of the attributes from the quantized
# coefficients. A_quant is only needed as the reference that the decoder
# output is compared against, so it is skipped unless the match test is on.
if decoder_match_test:
    print('Perform inverse GFT (for decoder check)...')
    # reuses the per-block GFT matrices computed by the forward transform
    A_quant = gft.itransform_block_gft(
        V=V_ordered, Ahat=Ahat_quant, Q=Q,
        idx_start=idx_start, idx_stop=idx_stop, GFT_blocks=GFT_blocks)

Sort coefficients for coding

# Coefficient ordering applied before entropy coding. The decoder block sets
# the same value when it rebuilds the sort masks.
sort_method = 'dc_subgraphs'  # other values noted by the author: 'none', 'dc'

print('Sorting coefficients...', end='', flush=True)
sorted_outputs = gft.sort_block_gft_coeffs(
    Ahat=Ahat_quant_idx,
    Gfreq_blocks=Gfreq_blocks,
    idx_start=idx_start,
    idx_stop=idx_stop,
    N_block=N_block,
    sort_method=sort_method,
)
Ahat_quant_idx_sorted, mask_lo, mask_hi, num_subgraphs_blocks = sorted_outputs
print('done')

Bit coding

# Write the bitstreams for the current qstep and collect their sizes:
#   bs_dupes   - side information with the number of unused/duplicated centers
#   bs_coeffs  - quantized coefficient indices
#   bs_centers - cluster centers (k-means with lambda != 0 only)
print('Bit coding coefficients (and centers)...', end='', flush=True)
# encode the number of unused/duplicated centers if we used kmeans
bs_dupes = 0
# The results row built in run() reads dupes_count for every method; without
# this default a non-kmeans run failed there with a NameError.
dupes_count = 0
if method == 'kmeans':
    dupes_count = clusters_count - centers_decoder_refined.shape[0]
    bs_dupes = bitcoding.write_number_to_file(
        x=dupes_count, filename='dupes_count.bin', bitstream_directory=bitstream_directory
    )

# encode quantized coefficients indices
bs_coeffs = bitcoding.code_YUV(Ahat_quant_idx_sorted, N=N, bitstream_directory=bitstream_directory)

# encode centers if using kmeans and if lambda is different from 0
bs_centers = 0
if method == 'kmeans' and lambd != 0.0:
    # differential coding of the centers indices

    _, centers_decoder_refined_idx_int = quantization.quantize(centers_decoder_refined, qstep_centers)  # centers_decoder_refined is unsigned, np.diff() yields signed values
    centers_decoder_refined_idx_diff = np.vstack((
        centers_decoder_refined_idx_int[0, :],  # save first entry
        np.diff(centers_decoder_refined_idx_int, axis=0)  # and then all the differences
    ))

    # flatten('F') is column-major: all first-column values, then the second
    # column, and so on. The decoder undoes this with reshape((3, -1)).T.
    bs_centers = bitcoding.encode_rlgr(
        data=centers_decoder_refined_idx_diff.flatten('F').tolist(),
        filename=os.path.join(bitstream_directory, 'bitstream_centers.bin'),
        is_signed=1  # differences have a sign
    )

bs_total = bs_coeffs + bs_centers + bs_dupes
print('done')
print('Sorted: Coded Y,U,V separately: rate={:2.4f} bits/symbol'.format(bs_total/N))

Decoder

Bit decoding

# Read back everything the encoder wrote: the number of dropped centers, the
# quantized coefficient indices and (k-means only) the cluster centers.
print('Decoding coefficients and centers...', end='', flush=True)
uses_kmeans = method == 'kmeans'

if uses_kmeans:
    # the encoder stored how many of the requested clusters were dropped
    dupes_count_dec = bitcoding.get_number_from_file(
        filename='dupes_count.bin', bitstream_directory=bitstream_directory
    )
    clusters_count_dec = clusters_count - dupes_count_dec

# decode quantization indices
Ahat_quant_idx_sorted_dec = bitcoding.decode_YUV(N, bitstream_directory)

# decode cluster centers
# NOTE(review): the encoder writes bitstream_centers.bin only when lambd != 0.0,
# while this branch reads it for every kmeans run - confirm the lambd == 0.0 case.
if uses_kmeans:
    centers_path = os.path.join(bitstream_directory, 'bitstream_centers.bin')
    center_diffs = bitcoding.decode_rlgr(
        filename=centers_path, N=clusters_count_dec*3, is_signed=1
    )
    # the encoder flattened column-major, so three rows are rebuilt and
    # transposed back to one center per row
    center_diffs = center_diffs.reshape((3, -1)).T

    # invert np.diff(): first row is the start value, the rest are increments
    center_indices = np.cumsum(center_diffs, axis=0)
    # dequantize and cast to unsigned integers
    clusters_dec = quantization.dequantize(center_indices, qstep_centers).astype(np.uint32)
print('done')

Partitioning

# Rebuild the block structure on the decoder side from the geometry (and, for
# k-means, the decoded centers). Defines idx_start_dec, idx_stop_dec,
# N_block_dec and V_ordered_dec for the following decoder stages.
if method not in ('octree', 'kmeans'):
    # Fail early: otherwise the names above stay undefined and the decoder
    # dies later with an unrelated NameError.
    raise ValueError(
        f"unknown partitioning method {method!r}; expected 'octree' or 'kmeans'"
    )

print('Partitioning pointcloud...', end='', flush=True)
if method == 'octree':
    idx_start_dec, idx_stop_dec, N_block_dec = graph.block_indices(V=V, bsize=bsize)
    # The octree path keeps the input order, exactly as the encoder does with
    # V_ordered = V. This assignment was missing, so V_ordered_dec was
    # undefined for octree in the GFT stage and in decode()'s return value.
    V_ordered_dec = V

else:  # method == 'kmeans'
    # assign every point to a decoded center, then derive the block indices
    labels_dec = clustering.labels_from_centers(
        V=V, centers=clusters_dec
    )

    # attributes are unknown at this point; zeros serve as a placeholder and
    # the reordered attribute output is discarded
    idx_start_dec, idx_stop_dec, N_block_dec, V_ordered_dec, _ = clustering.block_indices(
        V=V, A=np.zeros_like(V), labels=labels_dec, clusters_count=clusters_count_dec
    )
print('done')

Create GFT block matrices and frequencies

# The decoder needs only the per-block GFT matrices and graph frequencies.
# The attributes are not available yet, so an all-zero array is passed and
# the first two outputs of the transform are discarded.
print('Calculate GFT...')
Q_dec = np.ones((N, 1))
placeholder_attributes = np.zeros_like(V)
gft_outputs_dec = gft.transform_block_gft(
    V=V_ordered_dec,
    A=placeholder_attributes,
    Q=Q_dec,
    idx_start=idx_start_dec,
    idx_stop=idx_stop_dec,
    ret_GFT=True,
)
_, _, GFT_blocks_dec, Gfreq_blocks_dec = gft_outputs_dec

Inverse sort the coefficients

# Undo the coefficient sorting: rebuild the sort masks from the decoder-side
# graph frequencies, then apply them in reverse.
print('Inverse sort the coefficients...', end='', flush=True)
sort_method = 'dc_subgraphs'  # other values noted by the author: 'none', 'dc'

sort_masks_dec = gft.create_sort_masks_subgraphs(
    Gfreq_blocks=Gfreq_blocks_dec,
    idx_start=idx_start_dec,
    idx_stop=idx_stop_dec,
    N_block=N_block_dec,
    sort_method=sort_method,
)
# the third output is not needed here
mask_lo_dec, mask_hi_dec, _ = sort_masks_dec

Ahat_quant_idx_dec = gft.reverse_sort_block_gft_coeffs(
    Ahat_sort=Ahat_quant_idx_sorted_dec,
    mask_lo=mask_lo_dec,
    mask_hi=mask_hi_dec,
)
print('done')

Dequantize the coefficients

# Map the decoded quantization indices back to coefficient values, using the
# same step size (qstep) that is passed to the decoder.
print('Dequantize the coefficients...', end='', flush=True)
Ahat_quant_dec = quantization.dequantize(Ahat_quant_idx_dec, qstep)
print('done')

Inverse GFT

# Final decoder stage: transform the dequantized coefficients back to the
# attribute domain, using the GFT matrices rebuilt on the decoder side.
print('Inverse GFT...')
A_quant_dec = gft.itransform_block_gft(
    V=V_ordered_dec, Ahat=Ahat_quant_dec, Q=Q_dec,
    idx_start=idx_start_dec, idx_stop=idx_stop_dec, GFT_blocks=GFT_blocks_dec)

Encoder-decoder match check

# Stage-by-stage comparison of encoder and decoder intermediates. Each match
# value is a sum of absolute differences, so 0 means the two sides agree.
# NOTE(review): no <<...>> placeholder in the visible file references this
# block, and it needs encoder and decoder variables in the same scope.

# YPSNR in the color domain and in the coefficient domain
YPSNR_dec = color.YPSNR(A_ordered, A_quant_dec, N)
YPSNR_coeff_dec = color.YPSNR(Ahat, Ahat_quant_dec, N)
print('PSNR_Y={:2.4f} dB, PSNR_Y|coeff={:2.4f} dB'.format(YPSNR_dec, YPSNR_coeff_dec))

# match for sorted quantization indices
sorted_idx_mismatch = np.sum(np.abs(Ahat_quant_idx_sorted - Ahat_quant_idx_sorted_dec))
print(f'sorted quantization indices match={ sorted_idx_mismatch }')

# match for unsorted quantization indices
unsorted_idx_mismatch = np.sum(np.abs(Ahat_quant_idx - Ahat_quant_idx_dec))
print(f'unsorted quantization indices match={ unsorted_idx_mismatch }')

# match for quantized coefficients
coeff_mismatch = np.sum(np.abs(Ahat_quant - Ahat_quant_dec))
print(f'quantized coefficients match={ coeff_mismatch }')

# final match for resulting distorted attributes
attribute_mismatch = np.sum(np.abs(A_quant - A_quant_dec))
print(f'Encoder-decoder match: { attribute_mismatch }')