Skip to content

Commit

Permalink
fix #29885, make grisu digit buffer task-local
Browse files Browse the repository at this point in the history
  • Loading branch information
JeffBezanson committed Nov 3, 2018
1 parent e2aba7a commit 0a427a3
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 32 deletions.
23 changes: 17 additions & 6 deletions base/grisu/grisu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,26 @@ include("grisu/bignum.jl")
const DIGITS = Vector{UInt8}(undef, 309+17)
const BIGNUMS = [Bignums.Bignum(),Bignums.Bignum(),Bignums.Bignum(),Bignums.Bignum()]

# thread-safe code should use a per-thread DIGITS buffer DIGITSs[Threads.threadid()]
# NOTE: DIGITS[s] is deprecated; you should use getbuf() instead.
const DIGITSs = [DIGITS]
const BIGNUMSs = [BIGNUMS]
function __init__()
Threads.resize_nthreads!(DIGITSs)
Threads.resize_nthreads!(BIGNUMSs)
end

function getbuf()
tls = task_local_storage()
d = get(tls, :DIGITS, nothing)
if d === nothing
d = Vector{UInt8}(undef, 309+17)
tls[:DIGITS] = d
end
return d::Vector{UInt8}
end

"""
(len, point, neg) = Grisu.grisu(v::AbstractFloat, mode, requested_digits,
buffer=DIGITSs[Threads.threadid()], bignums=BIGNUMSs[Threads.threadid()])
(len, point, neg) = Grisu.grisu(v::AbstractFloat, mode, requested_digits, [buffer], [bignums])
Convert the number `v` to decimal using the Grisu algorithm.
Expand All @@ -38,7 +47,7 @@ Convert the number `v` to decimal using the Grisu algorithm.
- `Grisu.FIXED`: round to `requested_digits` digits.
- `Grisu.PRECISION`: round to `requested_digits` significant digits.
The characters are written as bytes to `buffer`, with a terminating NUL byte, and `bignums` are used internally as part of the correction step.
The characters are written as bytes to `buffer`, with a terminating NUL byte, and `bignums` are used internally as part of the correction step. You can call `Grisu.getbuf()` to obtain a suitable task-local buffer.
The returned tuple contains:
Expand Down Expand Up @@ -94,7 +103,8 @@ function _show(io::IO, x::AbstractFloat, mode, n::Int, typed, compact)
return
end
typed && isa(x,Float16) && print(io, "Float16(")
(len,pt,neg),buffer = grisu(x,mode,n),DIGITSs[Threads.threadid()]
buffer = getbuf()
len, pt, neg = grisu(x,mode,n,buffer)
pdigits = pointer(buffer)
if mode == PRECISION
while len > 1 && buffer[len] == 0x30
Expand Down Expand Up @@ -182,7 +192,8 @@ function _print_shortest(io::IO, x::AbstractFloat, dot::Bool, mode, n::Int)
isnan(x) && return print(io, "NaN")
x < 0 && print(io,'-')
isinf(x) && return print(io, "Inf")
(len,pt,neg),buffer = grisu(x,mode,n),DIGITSs[Threads.threadid()]
buffer = getbuf()
len, pt, neg = grisu(x,mode,n,buffer)
pdigits = pointer(buffer)
e = pt-len
k = -9<=e<=9 ? 1 : 2
Expand Down
52 changes: 26 additions & 26 deletions base/printf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ end

# note: if print_fixed is changed, print_fixed_width should be changed accordingly
function print_fixed(out, precision, pt, ndigits, trailingzeros=true)
pdigits = pointer(DIGITSs[Threads.threadid()])
pdigits = pointer(Grisu.getbuf())
if pt <= 0
# 0.0dddd0
print(out, '0')
Expand Down Expand Up @@ -346,7 +346,7 @@ function gen_d(flags::String, width::Int, precision::Int, c::Char)
push!(blk.args, pad(width-1, zeros, '0'))
end
# print integer
push!(blk.args, :(unsafe_write(out, pointer(DIGITSs[Threads.threadid()]), pt)))
push!(blk.args, :(unsafe_write(out, pointer($Grisu.getbuf()), pt)))
# print padding
if padding !== nothing && '-' in flags
push!(blk.args, pad(width-precision, padding, ' '))
Expand Down Expand Up @@ -405,7 +405,7 @@ function gen_f(flags::String, width::Int, precision::Int, c::Char)
if precision > 0
push!(blk.args, :(print_fixed(out,$precision,pt,len)))
else
push!(blk.args, :(unsafe_write(out, pointer(DIGITSs[Threads.threadid()]), len)))
push!(blk.args, :(unsafe_write(out, pointer($Grisu.getbuf()), len)))
push!(blk.args, :(while pt >= (len+=1) print(out,'0') end))
'#' in flags && push!(blk.args, :(print(out, '.')))
end
Expand Down Expand Up @@ -438,9 +438,9 @@ function gen_e(flags::String, width::Int, precision::Int, c::Char, inside_g::Boo
end
# interpret the number
if precision < 0; precision = 6; end
ndigits = min(precision+1,length(DIGITSs[Threads.threadid()])-1)
ndigits = min(precision+1,length(Grisu.getbuf())-1)
push!(blk.args, :((do_out, args) = ini_dec(out,$x,$ndigits, $flags, $width, $precision, $c)))
push!(blk.args, :(digits = DIGITSs[Threads.threadid()]))
push!(blk.args, :(digits = $Grisu.getbuf()))
ifblk = Expr(:if, :do_out, Expr(:block))
push!(blk.args, ifblk)
blk = ifblk.args[2]
Expand Down Expand Up @@ -554,10 +554,10 @@ function gen_a(flags::String, width::Int, precision::Int, c::Char)
if precision < 0
push!(blk.args, :((do_out, args) = $fn(out,$x, $flags, $width, $precision, $c)))
else
ndigits = min(precision+1,length(DIGITSs[Threads.threadid()])-1)
ndigits = min(precision+1,length(Grisu.getbuf())-1)
push!(blk.args, :((do_out, args) = $fn(out,$x,$ndigits, $flags, $width, $precision, $c)))
end
push!(blk.args, :(digits = DIGITSs[Threads.threadid()]))
push!(blk.args, :(digits = $Grisu.getbuf()))
ifblk = Expr(:if, :do_out, Expr(:block))
push!(blk.args, ifblk)
blk = ifblk.args[2]
Expand Down Expand Up @@ -750,7 +750,7 @@ function gen_g(flags::String, width::Int, precision::Int, c::Char)
#
x, ex, blk = special_handler(flags,width)
if precision < 0; precision = 6; end
ndigits = min(precision+1,length(DIGITSs[Threads.threadid()])-1)
ndigits = min(precision+1,length(Grisu.getbuf())-1)
# See if anyone else wants to handle it
push!(blk.args, :((do_out, args) = ini_dec(out,$x,$ndigits, $flags, $width, $precision, $c)))
ifblk = Expr(:if, :do_out, Expr(:block))
Expand Down Expand Up @@ -859,7 +859,7 @@ end

function decode_oct(d::Integer)
neg, x = handlenegative(d)
digits = DIGITSs[Threads.threadid()]
digits = Grisu.getbuf()
@handle_zero x digits
pt = i = div((sizeof(x)<<3)-leading_zeros(x)+2,3)
while i > 0
Expand All @@ -874,7 +874,7 @@ function decode_0ct(d::Integer)
neg, x = handlenegative(d)
# doesn't need special handling for zero
pt = i = div((sizeof(x)<<3)-leading_zeros(x)+5,3)
digits = DIGITSs[Threads.threadid()]
digits = Grisu.getbuf()
while i > 0
digits[i] = 48+(x&0x7)
x >>= 3
Expand All @@ -885,7 +885,7 @@ end

function decode_dec(d::Integer)
neg, x = handlenegative(d)
digits = DIGITSs[Threads.threadid()]
digits = Grisu.getbuf()
@handle_zero x digits
pt = i = Base.ndigits0z(x)
while i > 0
Expand All @@ -898,7 +898,7 @@ end

function decode_hex(d::Integer, symbols::AbstractArray{UInt8,1})
neg, x = handlenegative(d)
digits = DIGITSs[Threads.threadid()]
digits = Grisu.getbuf()
@handle_zero x digits
pt = i = (sizeof(x)<<1)-(leading_zeros(x)>>2)
while i > 0
Expand All @@ -918,7 +918,7 @@ decode_HEX(x::Integer) = decode_hex(x,HEX_symbols)
function decode(b::Int, x::BigInt)
neg = x.size < 0
pt = Base.ndigits(x, base=abs(b))
digits = DIGITSs[Threads.threadid()]
digits = Grisu.getbuf()
length(digits) < pt+1 && resize!(digits, pt+1)
neg && (x.size = -x.size)
GMP.MPZ.get_str!(digits, b, x)
Expand All @@ -932,7 +932,7 @@ decode_HEX(x::BigInt) = decode(-16, x)

function decode_0ct(x::BigInt)
neg = x.size < 0
digits = DIGITSs[Threads.threadid()]
digits = Grisu.getbuf()
digits[1] = '0'
if x.size == 0
return Int32(1), Int32(1), neg
Expand Down Expand Up @@ -961,12 +961,12 @@ end
#

function decode_dec(x::SmallFloatingPoint)
digits = DIGITSs[Threads.threadid()]
digits = Grisu.getbuf()
if x == 0.0
digits[1] = '0'
return (Int32(1), Int32(1), false)
end
len,pt,neg = grisu(x,Grisu.FIXED,0)
len,pt,neg = grisu(x,Grisu.FIXED,0,digits)
if len == 0
digits[1] = '0'
return (Int32(1), Int32(1), false)
Expand All @@ -991,9 +991,9 @@ fix_dec(x::Real, n::Int) = fix_dec(float(x),n)
fix_dec(x::Integer, n::Int) = decode_dec(x)

function fix_dec(x::SmallFloatingPoint, n::Int)
digits = DIGITSs[Threads.threadid()]
digits = Grisu.getbuf()
if n > length(digits)-1; n = length(digits)-1; end
len,pt,neg = grisu(x,Grisu.FIXED,n)
len,pt,neg = grisu(x,Grisu.FIXED,n,digits)
if len == 0
digits[1] = '0'
return (Int32(1), Int32(1), neg)
Expand All @@ -1013,7 +1013,7 @@ ini_dec(x::Real, n::Int) = ini_dec(float(x),n)
function ini_dec(d::Integer, n::Int)
neg, x = handlenegative(d)
k = ndigits(x)
digits = DIGITSs[Threads.threadid()]
digits = Grisu.getbuf()
if k <= n
pt = k
for i = k:-1:1
Expand Down Expand Up @@ -1045,24 +1045,24 @@ end

function ini_dec(x::SmallFloatingPoint, n::Int)
if x == 0.0
ccall(:memset, Ptr{Cvoid}, (Ptr{Cvoid}, Cint, Csize_t), DIGITSs[Threads.threadid()], '0', n)
ccall(:memset, Ptr{Cvoid}, (Ptr{Cvoid}, Cint, Csize_t), Grisu.getbuf(), '0', n)
return Int32(1), Int32(1), signbit(x)
else
len,pt,neg = grisu(x,Grisu.PRECISION,n)
len,pt,neg = grisu(x,Grisu.PRECISION,n,Grisu.getbuf())
end
return Int32(len), Int32(pt), neg
end

function ini_dec(x::BigInt, n::Int)
if x.size == 0
ccall(:memset, Ptr{Cvoid}, (Ptr{Cvoid}, Cint, Csize_t), DIGITSs[Threads.threadid()], '0', n)
ccall(:memset, Ptr{Cvoid}, (Ptr{Cvoid}, Cint, Csize_t), Grisu.getbuf(), '0', n)
return Int32(1), Int32(1), false
end
d = Base.ndigits0z(x)
if d <= n
info = decode_dec(x)
d == n && return info
p = convert(Ptr{Cvoid}, DIGITSs[Threads.threadid()]) + info[2]
p = convert(Ptr{Cvoid}, Grisu.getbuf()) + info[2]
ccall(:memset, Ptr{Cvoid}, (Ptr{Cvoid}, Cint, Csize_t), p, '0', n - info[2])
return info
end
Expand All @@ -1081,7 +1081,7 @@ ini_hex(x::Real, symbols::AbstractArray{UInt8,1}) = ini_hex(float(x), symbols)

function ini_hex(x::SmallFloatingPoint, n::Int, symbols::AbstractArray{UInt8,1})
x = Float64(x)
digits = DIGITSs[Threads.threadid()]
digits = Grisu.getbuf()
if x == 0.0
ccall(:memset, Ptr{Cvoid}, (Ptr{Cvoid}, Cint, Csize_t), digits, '0', n)
return Int32(1), Int32(0), signbit(x)
Expand All @@ -1107,7 +1107,7 @@ end

function ini_hex(x::SmallFloatingPoint, symbols::AbstractArray{UInt8,1})
x = Float64(x)
digits = DIGITSs[Threads.threadid()]
digits = Grisu.getbuf()
if x == 0.0
ccall(:memset, Ptr{Cvoid}, (Ptr{Cvoid}, Cint, Csize_t), digits, '0', 1)
return Int32(1), Int32(0), signbit(x)
Expand Down Expand Up @@ -1174,7 +1174,7 @@ function bigfloat_printf(out, d::BigFloat, flags::String, width::Int, precision:
write(fmt, UInt8(0))
printf_fmt = take!(fmt)
@assert length(printf_fmt) == fmt_len
digits = DIGITSs[Threads.threadid()]
digits = Grisu.getbuf()
bufsiz = length(digits)
lng = ccall((:mpfr_snprintf,:libmpfr), Int32,
(Ptr{UInt8}, Culong, Ptr{UInt8}, Ref{BigFloat}...),
Expand Down
11 changes: 11 additions & 0 deletions test/grisu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1751,3 +1751,14 @@ len,point,neg = Grisu.grisu(1.0, Grisu.FIXED, 0, buffer)
@test 1 >= len-1
@test "1" == unsafe_string(pointer(buffer))
@test !neg

# issue #29885
@sync let p = Pipe(), q = Pipe()
Base.link_pipe!(p, reader_supports_async=true, writer_supports_async=true)
Base.link_pipe!(q, reader_supports_async=true, writer_supports_async=true)
@async write(p, zeros(UInt8, 2^18))
@async (print(p, 12.345); close(p.in))
@async print(q, 9.8)
read(p, 2^18)
@test read(p, String) == "12.345"
end

0 comments on commit 0a427a3

Please sign in to comment.