Skip to content

Commit

Permalink
make changes following @brucala's suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
joshday committed Oct 5, 2021
1 parent 420fd7d commit f1162f7
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 18 deletions.
2 changes: 0 additions & 2 deletions src/OnlineStatsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ roughly to:
for x in 1:10
fit!(o, x)
end
fit!(o, 11:20)
"""
fit!(o::OnlineStat{T}, yi::T) where {T} = (_fit!(o, yi); return o)

Expand Down
28 changes: 14 additions & 14 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ mutable struct CountMissing{T, O<:OnlineStat{T}} <: StatWrapper{Union{Missing,T}
nmissing::Int
end
CountMissing(stat::OnlineStat) = CountMissing(stat, 0)
value(o::CountMissing) = (nmissing=o.nmissing, stat=o.stat)
nobs(o::CountMissing) = nobs(o.stat) + o.nmissing
additional_info(o::CountMissing) = (; nmissing=o.nmissing)

_fit!(o::CountMissing, x) = _fit!(o.stat, x)
_fit!(o::CountMissing, ::Missing) = (o.nmissing += 1)
Expand All @@ -45,22 +44,23 @@ your transformation, you may need to specify the type of a single observation (`
o = FilterTransform(String => (x->true) => (x->parse(Int,x)) => Mean())
fit!(o, "1")
"""
struct FilterTransform{S, T, O<:OnlineStat{T},F,F2} <: StatWrapper{S}
mutable struct FilterTransform{S, T, O<:OnlineStat{T},F,F2} <: StatWrapper{S}
stat::O
filter::F
transform::F2
nfiltered::Int
end
FilterTransform(intype::DataType, stat::OnlineStat; kw...) = FilterTransform(stat, intype; kw...)
function FilterTransform(stat::OnlineStat{T}, intype=T; filter=always_true, transform=identity) where {T}
FilterTransform{intype, T, typeof(stat), typeof(filter), typeof(transform)}(stat, filter, transform)
FilterTransform{intype, T, typeof(stat), typeof(filter), typeof(transform)}(stat, filter, transform, 0)
end
FilterTransform(intype::DataType, stat::OnlineStat; kw...) = FilterTransform(stat, intype; kw...)
function FilterTransform(p::Pair{DataType, <:Pair{<:Function, <:Pair{<:Function, <:OnlineStat}}})
FilterTransform(p[1], p[2][2][2]; filter=p[2][1], transform=p[2][2][1])
end

_fit!(o::FilterTransform, y) = o.filter(y) && _fit!(o.stat, o.transform(y))
_fit!(o::FilterTransform, y) = o.filter(y) ? _fit!(o.stat, o.transform(y)) : (o.nfiltered += 1)

additional_info(o::FilterTransform) = (; filter=o.filter, transform=o.transform)
additional_info(o::FilterTransform) = (; filter=o.filter, transform=o.transform, nfiltered=o.nfiltered)

always_true(x) = true

Expand Down Expand Up @@ -112,21 +112,21 @@ function TryCatch(stat::OnlineStat; error_limit=1000, error_message_limit=90)
end

errors(o::TryCatch) = value(o.errors)
nerrors(o::TryCatch) = sum(values(value(o.errors)))

function additional_info(o::TryCatch)
ex = errors(o)
nex = length(ex)
msg = length(ex) o.error_limit ? "$nex (limit reached)" : nex
nex == 0 ? () : (; errors=msg)
end
additional_info(o::TryCatch) = (;nerrors = nerrors(o))

function handle_error!(o::TryCatch, ex)
io = IOBuffer()
Base.showerror(io, ex)
s = String(take!(io))
lim = o.error_message_limit
s = length(s) > lim ? s[1:lim] * "..." : s
length(value(o.errors)) < o.error_limit && _fit!(o.errors, s)
if length(value(o.errors)) < o.error_limit || haskey(value(o.errors), s)
_fit!(o.errors, s)
else
_fit!(o.errors, "Other (error_limit reached)")
end
end

function fit!(o::TryCatch{T}, y::T) where {T}
Expand Down
9 changes: 7 additions & 2 deletions test/test_stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ println(" > CountMissing")
data2 = Vector{Union{Missing,Float64}}(copy(y2))
data[x] .= missing
data2[x2] .= missing
a, b = mergevals(CountMissing(Mean()), data, data2)
@test value(a.stat) value(b.stat)
a, b = mergestats(CountMissing(Mean()), data, data2, nobs_equals_length=false)
@test value(a) value(b)
@test a.nmissing == b.nmissing
end
#-----------------------------------------------------------------------# CovMatrix
Expand Down Expand Up @@ -150,6 +150,11 @@ println(" > FilterTransform")
o = FilterTransform(String => (x->true) => (x -> parse(Int,x)) => Mean())
fit!(o, ["1", "3", "5"])
@test value(o) 3

o = FilterTransform(String => (x -> x != "1") => (x -> parse(Int,x)) => Mean())
fit!(o, ["1", "3", "5"])
@test value(o) 4
@test o.nfiltered == 1
end

#-----------------------------------------------------------------------# Group
Expand Down

0 comments on commit f1162f7

Please sign in to comment.