Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: validate UTF-8 at C++ -> Lean boundary #3963

Merged
merged 4 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Init/Data/String/Extra.lean
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ where
decreasing_by exact Nat.sub_lt_sub_left ‹_› (Nat.lt_add_of_pos_right c.utf8Size_pos)

/-- Converts a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded `ByteArray` string to `String`. -/
@[extern "lean_string_from_utf8"]
@[extern "lean_string_from_utf8_unchecked"]
def fromUTF8 (a : @& ByteArray) (h : validateUTF8 a) : String :=
loop 0 ""
where
Expand Down
6 changes: 5 additions & 1 deletion src/Lean/Compiler/IR/EmitC.lean
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,11 @@ def emitLit (z : VarId) (t : IRType) (v : LitVal) : M Unit := do
emitLhs z;
match v with
| LitVal.num v => emitNumLit t v; emitLn ";"
| LitVal.str v => emit "lean_mk_string_from_bytes("; emit (quoteString v); emit ", "; emit v.utf8ByteSize; emitLn ");"
| LitVal.str v =>
emit "lean_mk_string_unchecked(";
emit (quoteString v); emit ", ";
emit v.utf8ByteSize; emit ", ";
emit v.length; emitLn ");"

def emitVDecl (z : VarId) (t : IRType) (v : Expr) : M Unit :=
match v with
Expand Down
13 changes: 7 additions & 6 deletions src/Lean/Compiler/IR/EmitLLVM.lean
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,14 @@ def callLeanUnsignedToNatFn (builder : LLVM.Builder llvmctx)
let nv ← constIntUnsigned n
LLVM.buildCall2 builder fnty f #[nv] name

def callLeanMkStringFromBytesFn (builder : LLVM.Builder llvmctx)
(strPtr nBytes : LLVM.Value llvmctx) (name : String) : M llvmctx (LLVM.Value llvmctx) := do
let fnName := "lean_mk_string_from_bytes"
def callLeanMkStringUncheckedFn (builder : LLVM.Builder llvmctx)
(strPtr nBytes nChars : LLVM.Value llvmctx) (name : String) : M llvmctx (LLVM.Value llvmctx) := do
let fnName := "lean_mk_string_unchecked"
let retty ← LLVM.voidPtrType llvmctx
let argtys := #[← LLVM.voidPtrType llvmctx, ← LLVM.size_tType llvmctx]
let argtys := #[← LLVM.voidPtrType llvmctx, ← LLVM.size_tType llvmctx, ← LLVM.size_tType llvmctx]
let fn ← getOrCreateFunctionPrototype (← getLLVMModule) retty fnName argtys
let fnty ← LLVM.functionType retty argtys
LLVM.buildCall2 builder fnty fn #[strPtr, nBytes] name
LLVM.buildCall2 builder fnty fn #[strPtr, nBytes, nChars] name

def callLeanMkString (builder : LLVM.Builder llvmctx)
(strPtr : LLVM.Value llvmctx) (name : String) : M llvmctx (LLVM.Value llvmctx) := do
Expand Down Expand Up @@ -772,7 +772,8 @@ def emitLit (builder : LLVM.Builder llvmctx)
(← LLVM.opaquePointerTypeInContext llvmctx)
str_global #[zero] ""
let nbytes ← constIntSizeT v.utf8ByteSize
callLeanMkStringFromBytesFn builder strPtr nbytes ""
let nchars ← constIntSizeT v.length
callLeanMkStringUncheckedFn builder strPtr nbytes nchars ""
LLVM.buildStore builder zv zslot
return zslot

Expand Down
3 changes: 3 additions & 0 deletions src/include/lean/lean.h
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,10 @@ static inline size_t lean_string_capacity(lean_object * o) { return lean_to_stri
static inline size_t lean_string_byte_size(lean_object * o) { return sizeof(lean_string_object) + lean_string_capacity(o); }
/* instance : inhabited char := ⟨'A'⟩ */
static inline uint32_t lean_char_default_value() { return 'A'; }
LEAN_EXPORT lean_obj_res lean_mk_string_unchecked(char const * s, size_t sz, size_t len);
LEAN_EXPORT lean_obj_res lean_mk_string_from_bytes(char const * s, size_t sz);
LEAN_EXPORT lean_obj_res lean_mk_string_from_bytes_unchecked(char const * s, size_t sz);
LEAN_EXPORT lean_obj_res lean_mk_ascii_string_unchecked(char const * s);
LEAN_EXPORT lean_obj_res lean_mk_string(char const * s);
static inline char const * lean_string_cstr(b_lean_obj_arg o) {
assert(lean_is_string(o));
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ extern "C" LEAN_EXPORT obj_res lean_io_get_random_bytes (size_t nbytes, obj_arg
#if !defined(LEAN_WINDOWS)
int fd_urandom = open("/dev/urandom", O_RDONLY | O_CLOEXEC);
if (fd_urandom < 0) {
return io_result_mk_error(decode_io_error(errno, lean_mk_string("/dev/urandom")));
return io_result_mk_error(decode_io_error(errno, lean_mk_ascii_string_unchecked("/dev/urandom")));
}
#endif

Expand Down Expand Up @@ -1092,7 +1092,7 @@ extern "C" LEAN_EXPORT obj_res lean_io_exit(uint8_t code, obj_arg /* w */) {
}

void initialize_io() {
g_io_error_nullptr_read = lean_mk_io_user_error(mk_string("null reference read"));
g_io_error_nullptr_read = lean_mk_io_user_error(mk_ascii_string_unchecked("null reference read"));
mark_persistent(g_io_error_nullptr_read);
g_io_handle_external_class = lean_register_external_class(io_handle_finalizer, io_handle_foreach);
#if defined(LEAN_WINDOWS)
Expand Down
87 changes: 58 additions & 29 deletions src/runtime/object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,12 +355,12 @@ extern "C" LEAN_EXPORT lean_object * lean_array_data(lean_obj_arg a) {
}

extern "C" LEAN_EXPORT lean_obj_res lean_array_get_panic(lean_obj_arg def_val) {
return lean_panic_fn(def_val, lean_mk_string("Error: index out of bounds"));
return lean_panic_fn(def_val, lean_mk_ascii_string_unchecked("Error: index out of bounds"));
}

extern "C" LEAN_EXPORT lean_obj_res lean_array_set_panic(lean_obj_arg a, lean_obj_arg v) {
lean_dec(v);
return lean_panic_fn(a, lean_mk_string("Error: index out of bounds"));
return lean_panic_fn(a, lean_mk_ascii_string_unchecked("Error: index out of bounds"));
}

// =======================================
Expand Down Expand Up @@ -1561,9 +1561,9 @@ extern "C" LEAN_EXPORT lean_obj_res lean_float_to_string(double a) {
if (isnan(a))
// override NaN because we don't want NaNs to be distinguishable
// because the sign bit / payload bits can be architecture-dependent
return mk_string("NaN");
return mk_ascii_string_unchecked("NaN");
else
return mk_string(std::to_string(a));
return mk_ascii_string_unchecked(std::to_string(a));
}

extern "C" LEAN_EXPORT double lean_float_scaleb(double a, b_lean_obj_arg b) {
Expand Down Expand Up @@ -1607,28 +1607,59 @@ static object * string_ensure_capacity(object * o, size_t extra) {
}
}

extern "C" LEAN_EXPORT object * lean_mk_string_core(char const * s, size_t sz, size_t len) {
extern "C" LEAN_EXPORT object * lean_mk_string_unchecked(char const * s, size_t sz, size_t len) {
size_t rsz = sz + 1;
object * r = lean_alloc_string(rsz, rsz, len);
memcpy(w_string_cstr(r), s, sz);
w_string_cstr(r)[sz] = 0;
return r;
}

object * lean_mk_string_lossy_recover(char const * s, size_t sz, size_t pos, size_t i) {
std::string str(s, pos);
size_t start = pos;
while (pos < sz) {
if (!validate_utf8_one((const uint8_t *)s, sz, pos)) {
str.append(s + start, pos - start);
str.append("\ufffd"); // U+FFFD REPLACEMENT CHARACTER
do pos++; while (pos < sz && (s[pos] & 0xc0) == 0x80);
start = pos;
}
i++;
}
str.append(s + start, pos - start);
return lean_mk_string_unchecked(str.data(), str.size(), i);
}

extern "C" LEAN_EXPORT object * lean_mk_string_from_bytes(char const * s, size_t sz) {
return lean_mk_string_core(s, sz, utf8_strlen(s, sz));
size_t pos = 0, i = 0;
if (validate_utf8((const uint8_t *)s, sz, pos, i)) {
return lean_mk_string_unchecked(s, pos, i);
} else {
return lean_mk_string_lossy_recover(s, sz, pos, i);
}
}

extern "C" LEAN_EXPORT object * lean_mk_string_from_bytes_unchecked(char const * s, size_t sz) {
return lean_mk_string_unchecked(s, sz, utf8_strlen(s, sz));
}

extern "C" LEAN_EXPORT object * lean_mk_string(char const * s) {
return lean_mk_string_from_bytes(s, strlen(s));
}

extern "C" LEAN_EXPORT obj_res lean_string_from_utf8(b_obj_arg a) {
return lean_mk_string_from_bytes(reinterpret_cast<char *>(lean_sarray_cptr(a)), lean_sarray_size(a));
extern "C" LEAN_EXPORT object * lean_mk_ascii_string_unchecked(char const * s) {
size_t len = strlen(s);
return lean_mk_string_unchecked(s, len, len);
}

extern "C" LEAN_EXPORT obj_res lean_string_from_utf8_unchecked(b_obj_arg a) {
return lean_mk_string_from_bytes_unchecked(reinterpret_cast<char *>(lean_sarray_cptr(a)), lean_sarray_size(a));
}

extern "C" LEAN_EXPORT uint8 lean_string_validate_utf8(b_obj_arg a) {
return validate_utf8(lean_sarray_cptr(a), lean_sarray_size(a));
size_t pos = 0, i = 0;
return validate_utf8(lean_sarray_cptr(a), lean_sarray_size(a), pos, i);
}

extern "C" LEAN_EXPORT obj_res lean_string_to_utf8(b_obj_arg s) {
Expand All @@ -1642,8 +1673,8 @@ object * mk_string(std::string const & s) {
return lean_mk_string_from_bytes(s.data(), s.size());
}

object * mk_ascii_string(std::string const & s) {
return lean_mk_string_core(s.data(), s.size(), s.size());
object * mk_ascii_string_unchecked(std::string const & s) {
return lean_mk_string_unchecked(s.data(), s.size(), s.size());
}

std::string string_to_std(b_obj_arg o) {
Expand Down Expand Up @@ -1713,16 +1744,6 @@ extern "C" LEAN_EXPORT bool lean_string_lt(object * s1, object * s2) {
return r < 0 || (r == 0 && sz1 < sz2);
}

static std::string list_as_string(b_obj_arg lst) {
std::string s;
b_obj_arg o = lst;
while (!lean_is_scalar(o)) {
push_unicode_scalar(s, lean_unbox_uint32(lean_ctor_get(o, 0)));
o = lean_ctor_get(o, 1);
}
return s;
}

static obj_res string_to_list_core(std::string const & s, bool reverse = false) {
std::vector<unsigned> tmp;
utf8_decode(s, tmp);
Expand All @@ -1741,9 +1762,16 @@ static obj_res string_to_list_core(std::string const & s, bool reverse = false)
}

extern "C" LEAN_EXPORT obj_res lean_string_mk(obj_arg cs) {
std::string s = list_as_string(cs);
std::string s;
b_obj_arg o = cs;
size_t len = 0;
while (!lean_is_scalar(o)) {
push_unicode_scalar(s, lean_unbox_uint32(lean_ctor_get(o, 0)));
o = lean_ctor_get(o, 1);
len++;
}
lean_dec(cs);
return mk_string(s);
return lean_mk_string_unchecked(s.data(), s.size(), len);
}

extern "C" LEAN_EXPORT obj_res lean_string_data(obj_arg s) {
Expand Down Expand Up @@ -1876,7 +1904,7 @@ extern "C" LEAN_EXPORT obj_res lean_string_utf8_get_opt(b_obj_arg s, b_obj_arg i
}

static uint32 lean_string_utf8_get_panic() {
lean_panic_fn(lean_box(0), lean_mk_string("Error: invalid `String.Pos` at `String.get!`"));
lean_panic_fn(lean_box(0), lean_mk_ascii_string_unchecked("Error: invalid `String.Pos` at `String.get!`"));
return lean_char_default_value();
}

Expand Down Expand Up @@ -1957,10 +1985,10 @@ extern "C" LEAN_EXPORT obj_res lean_string_utf8_extract(b_obj_arg s, b_obj_arg b
usize e = lean_unbox(e0);
char const * str = lean_string_cstr(s);
usize sz = lean_string_size(s) - 1;
if (b >= e || b >= sz) return lean_mk_string("");
if (b >= e || b >= sz) return lean_mk_string_unchecked("", 0, 0);
/* In the reference implementation if `b` is not pointing to a valid UTF8
character start position, the result is the empty string. */
if (!is_utf8_first_byte(str[b])) return lean_mk_string("");
if (!is_utf8_first_byte(str[b])) return lean_mk_string_unchecked("", 0, 0);
if (e > sz) e = sz;
lean_assert(b < e);
lean_assert(e > 0);
Expand All @@ -1969,7 +1997,7 @@ extern "C" LEAN_EXPORT obj_res lean_string_utf8_extract(b_obj_arg s, b_obj_arg b
if (e < sz && !is_utf8_first_byte(str[e])) e = sz;
usize new_sz = e - b;
lean_assert(new_sz > 0);
return lean_mk_string_from_bytes(lean_string_cstr(s) + b, new_sz);
return lean_mk_string_from_bytes_unchecked(lean_string_cstr(s) + b, new_sz);
}

extern "C" LEAN_EXPORT obj_res lean_string_utf8_prev(b_obj_arg s, b_obj_arg i0) {
Expand Down Expand Up @@ -2018,9 +2046,10 @@ extern "C" LEAN_EXPORT obj_res lean_string_utf8_set(obj_arg s, b_obj_arg i0, uin
std::string tmp;
push_unicode_scalar(tmp, c);
std::string new_s = string_to_std(s);
usize len = lean_string_len(s);
dec(s);
new_s.replace(i, get_utf8_char_size_at(new_s, i), tmp);
return mk_string(new_s);
return lean_mk_string_unchecked(new_s.data(), new_s.size(), len);
}

extern "C" LEAN_EXPORT uint64 lean_string_hash(b_obj_arg s) {
Expand All @@ -2030,7 +2059,7 @@ extern "C" LEAN_EXPORT uint64 lean_string_hash(b_obj_arg s) {
}

extern "C" LEAN_EXPORT obj_res lean_string_of_usize(size_t n) {
return mk_ascii_string(std::to_string(n));
return mk_ascii_string_unchecked(std::to_string(n));
}

// =======================================
Expand Down
1 change: 1 addition & 0 deletions src/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ inline size_t string_capacity(object * o) { return lean_string_capacity(o); }
inline uint32 char_default_value() { return lean_char_default_value(); }
inline obj_res alloc_string(size_t size, size_t capacity, size_t len) { return lean_alloc_string(size, capacity, len); }
inline obj_res mk_string(char const * s) { return lean_mk_string(s); }
LEAN_EXPORT obj_res mk_ascii_string_unchecked(std::string const & s);
LEAN_EXPORT obj_res mk_string(std::string const & s);
LEAN_EXPORT std::string string_to_std(b_obj_arg o);
inline char const * string_cstr(b_obj_arg o) { return lean_string_cstr(o); }
Expand Down
79 changes: 42 additions & 37 deletions src/runtime/utf8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,52 +220,57 @@ void utf8_decode(std::string const & str, std::vector<unsigned> & out) {
}
}

bool validate_utf8(uint8_t const * str, size_t size) {
size_t i = 0;
while (i < size) {
unsigned c = str[i];
if ((c & 0x80) == 0) {
/* zero continuation (0 to 0x7F) */
i++;
} else if ((c & 0xe0) == 0xc0) {
/* one continuation (0x80 to 0x7FF) */
if (i + 1 >= size) return false;
bool validate_utf8_one(uint8_t const * str, size_t size, size_t & pos) {
unsigned c = str[pos];
if ((c & 0x80) == 0) {
/* zero continuation (0 to 0x7F) */
pos++;
} else if ((c & 0xe0) == 0xc0) {
/* one continuation (0x80 to 0x7FF) */
if (pos + 1 >= size) return false;

unsigned c1 = str[i+1];
if ((c1 & 0xc0) != 0x80) return false;
unsigned c1 = str[pos+1];
if ((c1 & 0xc0) != 0x80) return false;

unsigned r = ((c & 0x1f) << 6) | (c1 & 0x3f);
if (r < 0x80) return false;
unsigned r = ((c & 0x1f) << 6) | (c1 & 0x3f);
if (r < 0x80) return false;

i += 2;
} else if ((c & 0xf0) == 0xe0) {
/* two continuations (0x800 to 0xD7FF and 0xE000 to 0xFFFF) */
if (i + 2 >= size) return false;
pos += 2;
} else if ((c & 0xf0) == 0xe0) {
/* two continuations (0x800 to 0xD7FF and 0xE000 to 0xFFFF) */
if (pos + 2 >= size) return false;

unsigned c1 = str[i+1];
unsigned c2 = str[i+2];
if ((c1 & 0xc0) != 0x80 || (c2 & 0xc0) != 0x80) return false;
unsigned c1 = str[pos+1];
unsigned c2 = str[pos+2];
if ((c1 & 0xc0) != 0x80 || (c2 & 0xc0) != 0x80) return false;

unsigned r = ((c & 0x0f) << 12) | ((c1 & 0x3f) << 6) | (c2 & 0x3f);
if (r < 0x800 || (r >= 0xD800 && r <= 0xDFFF)) return false;
unsigned r = ((c & 0x0f) << 12) | ((c1 & 0x3f) << 6) | (c2 & 0x3f);
if (r < 0x800 || (r >= 0xD800 && r <= 0xDFFF)) return false;

i += 3;
} else if ((c & 0xf8) == 0xf0) {
/* three continuations (0x10000 to 0x10FFFF) */
if (i + 3 >= size) return false;
pos += 3;
} else if ((c & 0xf8) == 0xf0) {
/* three continuations (0x10000 to 0x10FFFF) */
if (pos + 3 >= size) return false;

unsigned c1 = str[i+1];
unsigned c2 = str[i+2];
unsigned c3 = str[i+3];
if ((c1 & 0xc0) != 0x80 || (c2 & 0xc0) != 0x80 || (c3 & 0xc0) != 0x80) return false;
unsigned c1 = str[pos+1];
unsigned c2 = str[pos+2];
unsigned c3 = str[pos+3];
if ((c1 & 0xc0) != 0x80 || (c2 & 0xc0) != 0x80 || (c3 & 0xc0) != 0x80) return false;

unsigned r = ((c & 0x07) << 18) | ((c1 & 0x3f) << 12) | ((c2 & 0x3f) << 6) | (c3 & 0x3f);
if (r < 0x10000 || r > 0x10FFFF) return false;
unsigned r = ((c & 0x07) << 18) | ((c1 & 0x3f) << 12) | ((c2 & 0x3f) << 6) | (c3 & 0x3f);
if (r < 0x10000 || r > 0x10FFFF) return false;

i += 4;
} else {
return false;
}
pos += 4;
} else {
return false;
}
return true;
}

bool validate_utf8(uint8_t const * str, size_t size, size_t & pos, size_t & i) {
while (pos < size) {
if (!validate_utf8_one(str, size, pos)) return false;
i++;
}
return true;
}
Expand Down
5 changes: 4 additions & 1 deletion src/runtime/utf8.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@ LEAN_EXPORT unsigned next_utf8(char const * str, size_t size, size_t & i);
/* Decode a UTF-8 encoded string `str` into unicode scalar values */
LEAN_EXPORT void utf8_decode(std::string const & str, std::vector<unsigned> & out);

/* Returns true if the given character is valid UTF-8 */
LEAN_EXPORT bool validate_utf8_one(uint8_t const * str, size_t size, size_t & pos);

/* Returns true if the provided string is valid UTF-8 */
LEAN_EXPORT bool validate_utf8(uint8_t const * str, size_t size);
LEAN_EXPORT bool validate_utf8(uint8_t const * str, size_t size, size_t & pos, size_t & i);

/* Push a unicode scalar value into a utf-8 encoded string */
LEAN_EXPORT void push_unicode_scalar(std::string & s, unsigned code);
Expand Down
1 change: 1 addition & 0 deletions src/util/shell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Author: Leonardo de Moura
#include "runtime/load_dynlib.h"
#include "runtime/array_ref.h"
#include "runtime/object_ref.h"
#include "runtime/utf8.h"
#include "util/timer.h"
#include "util/macros.h"
#include "util/io.h"
Expand Down
3 changes: 1 addition & 2 deletions tests/lean/1690.lean.expected.out
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
1690.lean:1:10-1:11: error: unknown identifier 'ó
'
1690.lean:1:10: error: expected token
Loading
Loading