From f1162f755922e1141a3dd92eb6aef279dc2033f2 Mon Sep 17 00:00:00 2001 From: joshday Date: Tue, 5 Oct 2021 10:16:40 -0400 Subject: [PATCH] make changes following @brucala's suggestions --- src/OnlineStatsBase.jl | 2 -- src/wrappers.jl | 28 ++++++++++++++-------------- test/test_stats.jl | 9 +++++++-- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/src/OnlineStatsBase.jl b/src/OnlineStatsBase.jl index 8692b08..87c6c81 100644 --- a/src/OnlineStatsBase.jl +++ b/src/OnlineStatsBase.jl @@ -113,8 +113,6 @@ roughly to: for x in 1:10 fit!(o, x) end - - fit!(o, 11:20) """ fit!(o::OnlineStat{T}, yi::T) where {T} = (_fit!(o, yi); return o) diff --git a/src/wrappers.jl b/src/wrappers.jl index 7dc7fe1..1d97d70 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -21,8 +21,7 @@ mutable struct CountMissing{T, O<:OnlineStat{T}} <: StatWrapper{Union{Missing,T} nmissing::Int end CountMissing(stat::OnlineStat) = CountMissing(stat, 0) -value(o::CountMissing) = (nmissing=o.nmissing, stat=o.stat) -nobs(o::CountMissing) = nobs(o.stat) + o.nmissing +additional_info(o::CountMissing) = (; nmissing=o.nmissing) _fit!(o::CountMissing, x) = _fit!(o.stat, x) _fit!(o::CountMissing, ::Missing) = (o.nmissing += 1) @@ -45,22 +44,23 @@ your transformation, you may need to specify the type of a single observation (` o = FilterTransform(String => (x->true) => (x->parse(Int,x)) => Mean()) fit!(o, "1") """ -struct FilterTransform{S, T, O<:OnlineStat{T},F,F2} <: StatWrapper{S} +mutable struct FilterTransform{S, T, O<:OnlineStat{T},F,F2} <: StatWrapper{S} stat::O filter::F transform::F2 + nfiltered::Int end -FilterTransform(intype::DataType, stat::OnlineStat; kw...) = FilterTransform(stat, intype; kw...) function FilterTransform(stat::OnlineStat{T}, intype=T; filter=always_true, transform=identity) where {T} - FilterTransform{intype, T, typeof(stat), typeof(filter), typeof(transform)}(stat, filter, transform) + FilterTransform{intype, T, typeof(stat), typeof(filter), typeof(transform)}(stat, filter, transform, 0) end +FilterTransform(intype::DataType, stat::OnlineStat; kw...) = FilterTransform(stat, intype; kw...) function FilterTransform(p::Pair{DataType, <:Pair{<:Function, <:Pair{<:Function, <:OnlineStat}}}) FilterTransform(p[1], p[2][2][2]; filter=p[2][1], transform=p[2][2][1]) end -_fit!(o::FilterTransform, y) = o.filter(y) && _fit!(o.stat, o.transform(y)) +_fit!(o::FilterTransform, y) = o.filter(y) ? _fit!(o.stat, o.transform(y)) : (o.nfiltered += 1) -additional_info(o::FilterTransform) = (; filter=o.filter, transform=o.transform) +additional_info(o::FilterTransform) = (; filter=o.filter, transform=o.transform, nfiltered=o.nfiltered) always_true(x) = true @@ -112,13 +112,9 @@ function TryCatch(stat::OnlineStat; error_limit=1000, error_message_limit=90) end errors(o::TryCatch) = value(o.errors) +nerrors(o::TryCatch) = sum(values(value(o.errors))) -function additional_info(o::TryCatch) - ex = errors(o) - nex = length(ex) - msg = length(ex) ≥ o.error_limit ? "$nex (limit reached)" : nex - nex == 0 ? () : (; errors=msg) -end +additional_info(o::TryCatch) = (;nerrors = nerrors(o)) function handle_error!(o::TryCatch, ex) io = IOBuffer() @@ -126,7 +122,11 @@ function handle_error!(o::TryCatch, ex) s = String(take!(io)) lim = o.error_message_limit s = length(s) > lim ? s[1:lim] * "..." : s - length(value(o.errors)) < o.error_limit && _fit!(o.errors, s) + if length(value(o.errors)) < o.error_limit || haskey(value(o.errors), s) + _fit!(o.errors, s) + else + _fit!(o.errors, "Other (error_limit reached)") + end end function fit!(o::TryCatch{T}, y::T) where {T} diff --git a/test/test_stats.jl b/test/test_stats.jl index b3792af..3112151 100644 --- a/test/test_stats.jl +++ b/test/test_stats.jl @@ -90,8 +90,8 @@ println(" > CountMissing") data2 = Vector{Union{Missing,Float64}}(copy(y2)) data[x] .= missing data2[x2] .= missing - a, b = mergevals(CountMissing(Mean()), data, data2) - @test value(a.stat) ≈ value(b.stat) + a, b = mergestats(CountMissing(Mean()), data, data2, nobs_equals_length=false) + @test value(a) ≈ value(b) @test a.nmissing == b.nmissing end #-----------------------------------------------------------------------# CovMatrix @@ -150,6 +150,11 @@ println(" > FilterTransform") o = FilterTransform(String => (x->true) => (x -> parse(Int,x)) => Mean()) fit!(o, ["1", "3", "5"]) @test value(o) ≈ 3 + + o = FilterTransform(String => (x -> x != "1") => (x -> parse(Int,x)) => Mean()) + fit!(o, ["1", "3", "5"]) + @test value(o) ≈ 4 + @test o.nfiltered == 1 end #-----------------------------------------------------------------------# Group