Skip to content

Commit

Permalink
Implement fast padding-less base64 encode/decode for python
Browse files Browse the repository at this point in the history
  • Loading branch information
kovidgoyal committed Jun 29, 2023
1 parent 37680aa commit 5c2ac8a
Show file tree
Hide file tree
Showing 11 changed files with 197 additions and 44 deletions.
5 changes: 3 additions & 2 deletions gen-apc-parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ def generate(
payload_case = f'''
case PAYLOAD: {{
sz = screen->parser_buf_pos - pos;
const char *err = base64_decode(screen->parser_buf + pos, sz, payload, sizeof(payload), &g.payload_sz);
if (err != NULL) {{ REPORT_ERROR("Failed to parse {command_class} command payload with error: %s", err); return; }}
g.payload_sz = sizeof(payload);
if (!base64_decode32(screen->parser_buf + pos, sz, payload, &g.payload_sz)) {{
REPORT_ERROR("Failed to parse {command_class} command payload with error: %s", "output buffer for base64_decode too small"); return; }}
pos = screen->parser_buf_pos;
}}
break;
Expand Down
4 changes: 2 additions & 2 deletions kittens/transfer/ftc.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func NewFileTransmissionCommand(serialized string) (ans *FileTransmissionCommand
case reflect.String:
switch field.Tag.Get("encoding") {
case "base64":
b, err := base64.StdEncoding.DecodeString(serialized_val)
b, err := base64.RawStdEncoding.DecodeString(serialized_val)
if err != nil {
return fmt.Errorf("The field %#v has invalid base64 encoded value with error: %w", key, err)
}
Expand All @@ -260,7 +260,7 @@ func NewFileTransmissionCommand(serialized string) (ans *FileTransmissionCommand
case reflect.Slice:
switch val.Type().Elem().Kind() {
case reflect.Uint8:
b, err := base64.StdEncoding.DecodeString(serialized_val)
b, err := base64.RawStdEncoding.DecodeString(serialized_val)
if err != nil {
return fmt.Errorf("The field %#v has invalid base64 encoded value with error: %w", key, err)
}
Expand Down
125 changes: 125 additions & 0 deletions kitty/base64.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Copyright (C) 2023 Kovid Goyal <kovid at kovidgoyal.net>
*
* Distributed under terms of the GPL3 license.
*/

#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>

// Width in bits of one input unit consumed by the decoder. Defaults to 8 when
// the including file has not defined it before including this header.
#ifndef B64_INPUT_BITSIZE
#define B64_INPUT_BITSIZE 8
#endif

// Map the generic names used in the rest of this header onto width-specific
// symbols, so the declarations/definitions below refer to one concrete set of
// functions for the selected input width.
#if B64_INPUT_BITSIZE == 8
#define INPUT_T uint8_t
#define inner_func base64_decode_inner8
#define decode_func base64_decode8
#define encode_func base64_encode8
#else
// Any value other than 8 selects the uint32_t variants.
#define INPUT_T uint32_t
#define inner_func base64_decode_inner32
#define decode_func base64_decode32
#define encode_func base64_encode32
#endif

// Decode the base64 text in src (src_sz input units) into dest. On entry
// *dest_sz must hold the capacity of dest in bytes; on success it is set to the
// number of decoded bytes. Returns false when the decoded data plus one
// terminating zero byte does not fit in that capacity.
bool decode_func(const INPUT_T *src, size_t src_sz, uint8_t *dest, size_t *dest_sz);
// Encode src_len bytes from src as base64 text into out. On entry *out_len must
// hold the capacity of out, which has to be at least
// required_buffer_size_for_base64_encode(src_len); on success it is set to the
// number of characters written, not counting the terminating NUL. Trailing '='
// padding is emitted only when add_padding is true.
bool encode_func(const unsigned char *src, size_t src_len, unsigned char *out, size_t *out_len, bool add_padding);
#ifndef B64_INCLUDED_ONCE
// Buffer sizes large enough for the decoded/encoded output of src_sz input
// units/bytes, including the trailing zero byte both functions write.
static inline size_t required_buffer_size_for_base64_decode(size_t src_sz) { return (src_sz / 4) * 3 + 4; }
static inline size_t required_buffer_size_for_base64_encode(size_t src_sz) { return (src_sz / 3) * 4 + 5; }
#endif

#ifndef B64_INCLUDED_ONCE
#define B64_INCLUDED_ONCE
#endif

#ifdef INCLUDE_BASE64_DEFINITIONS
#if B64_INPUT_BITSIZE == 8
// standard decoding using + and / with = being the padding character
static uint8_t b64_decoding_table[256] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 62, 0, 0, 0, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
};
#endif

// Decode base64 quads from src into dest.
//
// Input is consumed four units at a time until src_sz is reached, so the
// caller must make sure four units are readable for every quad that is
// started. A '=' unit contributes a zero sextet; every other unit is looked up
// in b64_decoding_table using its low 8 bits. At most dest_sz bytes are
// written to dest; output bytes beyond that are dropped.
static void
inner_func(const INPUT_T *src, size_t src_sz, uint8_t *dest, const size_t dest_sz) {
    size_t in_pos = 0, out_pos = 0;
    while (in_pos < src_sz) {
        // Pack four 6-bit values into the low 24 bits of one word.
        uint32_t triple = 0;
        for (int k = 0; k < 4; k++, in_pos++) {
            const INPUT_T ch = src[in_pos];
            const uint32_t sextet = (ch == '=') ? 0 : b64_decoding_table[ch & 0xff];
            triple = (triple << 6) | sextet;
        }
        // Emit the three bytes, most significant first, while room remains.
        for (int shift = 16; shift >= 0 && out_pos < dest_sz; shift -= 8) {
            dest[out_pos++] = (triple >> shift) & 0xFF;
        }
    }
}

// Decode base64 text, with or without trailing '=' padding, into dest.
//
// src/src_sz: the encoded input and its length in input units.
// dest:       output buffer.
// dest_sz:    in: capacity of dest in bytes; out: number of decoded bytes.
//
// Returns false if the decoded data plus one terminating zero byte does not
// fit in the supplied capacity. On success dest[*dest_sz] is set to zero,
// except for empty input, where dest is left untouched. A single leftover
// input unit (length % 4 == 1 after stripping padding) carries no complete
// byte and is ignored.
bool
decode_func(const INPUT_T *src, size_t src_sz, uint8_t *dest, size_t *dest_sz) {
    while (src_sz && src[src_sz-1] == '=') src_sz--;  // remove trailing padding
    if (!src_sz) { *dest_sz = 0; return true; }
    const size_t dest_capacity = *dest_sz;
    const size_t extra = src_sz % 4;
    src_sz -= extra;
    *dest_sz = (src_sz / 4) * 3;
    if (*dest_sz > dest_capacity) return false;
    if (src_sz) inner_func(src, src_sz, dest, *dest_sz);
    if (extra > 1) {
        // Decode the unpadded tail through a zero-filled quad so that
        // inner_func can always read four units.
        INPUT_T buf[4] = {0};
        for (size_t i = 0; i < extra; i++) buf[i] = src[src_sz+i];
        // Use a separate cursor for the tail: dest itself must keep pointing
        // at the start of the buffer so that the terminator below is written
        // at the true end of the data rather than past it.
        uint8_t *tail = dest + *dest_sz;
        *dest_sz += extra - 1;
        if (*dest_sz > dest_capacity) return false;
        inner_func(buf, extra, tail, extra-1);
    }
    if (*dest_sz + 1 > dest_capacity) return false;
    dest[*dest_sz] = 0;  // ensure zero-terminated
    return true;
}

#if B64_INPUT_BITSIZE == 8
static const unsigned char base64_table[65] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
#endif

// Encode src_len bytes from src as base64 text in out.
//
// *out_len must hold the capacity of out on entry and has to be at least
// required_buffer_size_for_base64_encode(src_len), otherwise false is returned
// and nothing is written. On success the output is NUL terminated and *out_len
// is set to the number of characters written, excluding the terminator. '='
// padding is appended only when add_padding is true.
bool
encode_func(const unsigned char *src, size_t src_len, unsigned char *out, size_t *out_len, bool add_padding) {
    if (*out_len < required_buffer_size_for_base64_encode(src_len)) return false;

    const size_t remainder = src_len % 3;
    const unsigned char *tail = src + (src_len - remainder);
    unsigned char *w = out;

    // Whole 3-byte groups each become four output characters.
    for (const unsigned char *p = src; p < tail; p += 3) {
        *w++ = base64_table[p[0] >> 2];
        *w++ = base64_table[((p[0] & 0x03) << 4) | (p[1] >> 4)];
        *w++ = base64_table[((p[1] & 0x0f) << 2) | (p[2] >> 6)];
        *w++ = base64_table[p[2] & 0x3f];
    }

    // One or two leftover bytes become two or three characters plus padding.
    switch (remainder) {
        case 1:
            *w++ = base64_table[tail[0] >> 2];
            *w++ = base64_table[(tail[0] & 0x03) << 4];
            if (add_padding) { *w++ = '='; *w++ = '='; }
            break;
        case 2:
            *w++ = base64_table[tail[0] >> 2];
            *w++ = base64_table[((tail[0] & 0x03) << 4) | (tail[1] >> 4)];
            *w++ = base64_table[(tail[1] & 0x0f) << 2];
            if (add_padding) *w++ = '=';
            break;
        default:
            break;
    }
    *w = '\0';
    *out_len = (size_t)(w - out);
    return true;
}
#undef encode_func
#undef decode_func
#undef inner_func
#undef INPUT_T
#undef B64_INPUT_BITSIZE
#endif
33 changes: 5 additions & 28 deletions kitty/charsets.c

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions kitty/data-types.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#endif

#include "data-types.h"
#include "base64.h"
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
Expand Down Expand Up @@ -75,6 +76,32 @@ redirect_std_streams(PyObject UNUSED *self, PyObject *args) {
Py_RETURN_NONE;
}

// Python binding: base64_encode(src: bytes, add_padding: bool = False) -> bytes
//
// Encodes src with base64_encode8() into a freshly allocated bytes object that
// is then shrunk to the number of characters actually produced. Returns NULL
// with a Python exception set on failure.
static PyObject*
pybase64_encode(PyObject UNUSED *self, PyObject *args) {
    int add_padding = 0;
    const char *src; Py_ssize_t src_len;
    if (!PyArg_ParseTuple(args, "y#|p", &src, &src_len, &add_padding)) return NULL;
    // sz is the buffer capacity on input and the encoded length on output.
    size_t sz = required_buffer_size_for_base64_encode(src_len);
    PyObject *ans = PyBytes_FromStringAndSize(NULL, sz);
    if (!ans) return NULL;
    if (!base64_encode8((const unsigned char*)src, src_len, (unsigned char*)PyBytes_AS_STRING(ans), &sz, add_padding)) {
        // Do not hand back an uninitialised buffer if the encoder refused to run.
        Py_DECREF(ans);
        PyErr_SetString(PyExc_RuntimeError, "output buffer for base64_encode too small");
        return NULL;
    }
    // On failure _PyBytes_Resize() releases the object and sets ans to NULL.
    if (_PyBytes_Resize(&ans, sz) != 0) return NULL;
    return ans;
}

// Python binding: base64_decode(src: bytes) -> bytes
//
// Decodes src with base64_decode8() into a freshly allocated bytes object that
// is then shrunk to the decoded length. Returns NULL with a Python exception
// set on failure.
static PyObject*
pybase64_decode(PyObject UNUSED *self, PyObject *args) {
    const char *src; Py_ssize_t src_len;
    if (!PyArg_ParseTuple(args, "y#", &src, &src_len)) return NULL;
    // sz is the buffer capacity on input and the decoded length on output.
    size_t sz = required_buffer_size_for_base64_decode(src_len);
    PyObject *ans = PyBytes_FromStringAndSize(NULL, sz);
    if (!ans) return NULL;
    if (!base64_decode8((const unsigned char*)src, src_len, (unsigned char*)PyBytes_AS_STRING(ans), &sz)) {
        // When decoding fails sz no longer describes valid data, so resizing to
        // it could expose uninitialised memory; report the error instead.
        Py_DECREF(ans);
        PyErr_SetString(PyExc_RuntimeError, "output buffer for base64_decode too small");
        return NULL;
    }
    // On failure _PyBytes_Resize() releases the object and sets ans to NULL.
    if (_PyBytes_Resize(&ans, sz) != 0) return NULL;
    return ans;
}


static PyObject*
pyset_iutf8(PyObject UNUSED *self, PyObject *args) {
int fd, on;
Expand Down Expand Up @@ -306,6 +333,8 @@ static PyMethodDef module_methods[] = {
{"raw_tty", raw_tty, METH_VARARGS, ""},
{"close_tty", close_tty, METH_VARARGS, ""},
{"set_iutf8_fd", (PyCFunction)pyset_iutf8, METH_VARARGS, ""},
{"base64_encode", (PyCFunction)pybase64_encode, METH_VARARGS, ""},
{"base64_decode", (PyCFunction)pybase64_decode, METH_VARARGS, ""},
{"thread_write", (PyCFunction)cm_thread_write, METH_VARARGS, ""},
{"parse_bytes", (PyCFunction)parse_bytes, METH_VARARGS, ""},
{"parse_bytes_dump", (PyCFunction)parse_bytes_dump, METH_VARARGS, ""},
Expand Down
1 change: 0 additions & 1 deletion kitty/data-types.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,6 @@ attrs_to_cursor(const CellAttrs attrs, Cursor *c) {


// Global functions
const char* base64_decode(const uint32_t *src, size_t src_sz, uint8_t *dest, size_t dest_capacity, size_t *dest_sz);
Line* alloc_line(void);
Cursor* alloc_cursor(void);
LineBuf* alloc_linebuf(unsigned int, unsigned int);
Expand Down
2 changes: 2 additions & 0 deletions kitty/fast_data_types.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1534,3 +1534,5 @@ def expand_ansi_c_escapes(test: str) -> str: ...
def update_tab_bar_edge_colors(os_window_id: int) -> bool: ...
def mask_kitty_signals_process_wide() -> None: ...
def is_modifier_key(key: int) -> bool: ...
def base64_encode(src: bytes, add_padding: bool = False) -> bytes: ...
def base64_decode(src: bytes) -> bytes: ...
20 changes: 14 additions & 6 deletions kitty/file_transmission.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
import stat
import tempfile
from base64 import standard_b64decode, standard_b64encode
from base64 import standard_b64decode
from collections import defaultdict, deque
from contextlib import suppress
from dataclasses import Field, dataclass, field, fields
Expand All @@ -19,7 +19,7 @@

from kittens.transfer.librsync import LoadSignature, PatchFile, delta_for_file, signature_of_file
from kittens.transfer.utils import IdentityCompressor, ZlibCompressor, abspath, expand_home, home_path
from kitty.fast_data_types import FILE_TRANSFER_CODE, OSC, add_timer, get_boss, get_options
from kitty.fast_data_types import FILE_TRANSFER_CODE, OSC, add_timer, base64_encode, get_boss, get_options
from kitty.types import run_once

from .utils import log_error, sanitize_control_codes
Expand Down Expand Up @@ -246,6 +246,14 @@ def serialized_to_field_map() -> Dict[bytes, 'Field[Any]']:
return ans


def b64decode(val: memoryview) -> bytes:
    """Decode standard base64 data whose trailing ``=`` padding may be missing.

    The input length is rounded up to a multiple of four by appending ``=``
    characters before it is handed to ``standard_b64decode``. Input that is
    already a multiple of four long is decoded as is, without copying.
    """
    missing = -len(val) % 4
    if missing:
        # The padded copy is already a bytes object, so it can be decoded
        # directly without wrapping it in another memoryview.
        return standard_b64decode(bytes(val) + b'=' * missing)
    return standard_b64decode(val)


@dataclass
class FileTransmissionCommand:

Expand Down Expand Up @@ -307,10 +315,10 @@ def get_serialized_fields(self, prefix_with_osc_code: bool = False) -> Iterator[
if issubclass(k.type, Enum):
yield val.name
elif k.type is bytes:
yield standard_b64encode(val)
yield base64_encode(val)
elif k.type is str:
if k.metadata.get('base64'):
yield standard_b64encode(val.encode('utf-8'))
yield base64_encode(val.encode('utf-8'))
else:
yield safe_string(val)
elif k.type is int:
Expand All @@ -334,12 +342,12 @@ def handle_item(key: memoryview, val: memoryview) -> None:
if issubclass(field.type, Enum):
setattr(ans, field.name, field.type[decode_utf8_buffer(val)])
elif field.type is bytes:
setattr(ans, field.name, standard_b64decode(val))
setattr(ans, field.name, b64decode(val))
elif field.type is int:
setattr(ans, field.name, int(val))
elif field.type is str:
if field.metadata.get('base64'):
sval = standard_b64decode(val).decode('utf-8')
sval = b64decode(val).decode('utf-8')
else:
sval = safe_string(decode_utf8_buffer(val))
setattr(ans, field.name, sanitize_control_codes(sval))
Expand Down
8 changes: 4 additions & 4 deletions kitty/parse-graphics-command.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions kitty/parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#define _POSIX_C_SOURCE 200809L

#include "data-types.h"
#define B64_INPUT_BITSIZE 32
#include "base64.h"
#include "control-codes.h"
#include "screen.h"
#include "graphics.h"
Expand Down
12 changes: 11 additions & 1 deletion kitty_tests/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from binascii import hexlify
from functools import partial

from kitty.fast_data_types import CURSOR_BLOCK, parse_bytes, parse_bytes_dump
from kitty.fast_data_types import CURSOR_BLOCK, parse_bytes, parse_bytes_dump, base64_decode, base64_encode
from kitty.notify import NotificationCommand, handle_notification_cmd, notification_activated, reset_registry

from . import BaseTest
Expand Down Expand Up @@ -41,6 +41,16 @@ def parse_bytes_dump(self, s, x, *cmds):
q.append(('draw', current))
self.ae(tuple(q), cmds)

def test_base64(self):
    # Each key is the padded encoding of its value. The decoder must accept
    # the text with and without padding, and the encoder must produce the
    # unpadded form by default and the padded form on request.
    for src, expected in {
        'bGlnaHQgdw==': 'light w',
        'bGlnaHQgd28=': 'light wo',
        'bGlnaHQgd29y': 'light wor',
    }.items():
        unpadded = src.replace('=', '')
        self.ae(base64_decode(src.encode()), expected.encode(), f'Decoding of {src} failed')
        self.ae(base64_decode(unpadded.encode()), expected.encode(), f'Decoding of {src} failed')
        self.ae(base64_encode(expected.encode()), unpadded.encode(), f'Encoding of {expected} failed')
        self.ae(base64_encode(expected.encode(), True), src.encode(), f'Padded encoding of {expected} failed')
    # Empty input must round-trip to empty output.
    self.ae(base64_decode(b''), b'')
    self.ae(base64_encode(b''), b'')

def test_simple_parsing(self):
s = self.create_screen()
pb = partial(self.parse_bytes_dump, s)
Expand Down

0 comments on commit 5c2ac8a

Please sign in to comment.