diff --git a/src/utilities.jl b/src/utilities.jl index 7288c30e..969fce4c 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -135,20 +135,6 @@ function recursive_setproperty!(obj, ex::Expr, value) return recursive_setproperty!(last_obj, field, value) end -""" - check_dimensions(X, Y) - -Internal function to check two arrays have the same shape. - -""" -@inline function check_dimensions(X, Y) - size(X) == size(Y) || - throw(DimensionMismatch( - "Encountered two objects with sizes $(size(X)) and "* - "$(size(Y)) which needed to match but don't. ")) - return nothing -end - """ check_same_nrows(X, Y) diff --git a/test/composition/learning_networks/replace.jl b/test/composition/learning_networks/replace.jl index fab3f16c..6186bf9c 100644 --- a/test/composition/learning_networks/replace.jl +++ b/test/composition/learning_networks/replace.jl @@ -33,8 +33,6 @@ zhat = inverse_transform(standM, uhat) yhat = exp(zhat) enode = @node mae(ys, yhat) -_header(accel) = - @testset "replace() method; $(typeof(accel))" for accel in (CPU1(), CPUThreads()) fit!(yhat, verbosity=0, acceleration=accel) @@ -50,15 +48,12 @@ _header(accel) = knn2 = deepcopy(knn) # duplicate the network with `yhat` as glb: - yhat_clone = @test_logs( - (:warn, r"No replacement"), - replace( - yhat, - hot=>hot2, - knn=>knn2, - ys=>source(42); - copy_models_deeply=false, - ), + yhat_clone = replace( + yhat, + hot=>hot2, + knn=>knn2, + ys=>source(42); + copy_unspecified_deeply=false, ) # test models and sources duplicated correctly: @@ -79,16 +74,13 @@ _header(accel) = @test all(isempty, sources(yhat_ser)) # duplicate a signature: - signature = (predict=yhat, report=(mae=enode,)) |> MLJBase.signature - signature_clone = @test_logs( - (:warn, r"No replacement"), - replace( - signature, - hot=>hot2, - knn=>knn2, - ys=>source(42); - copy_models_deeply=false, - ) + signature = (predict=yhat, report=(mae=enode,)) |> MLJBase.Signature + signature_clone = replace( + signature, + hot=>hot2, + knn=>knn2, + ys=>source(2*y); + copy_unspecified_deeply=false, ) glb_node = glb(signature_clone) models_clone = MLJBase.models(glb_node) @@ -97,28 +89,20 @@ _header(accel) = @test models_clone[3] === hot2 sources_clone = sources(glb_node) @test sources_clone[1]() == X - @test sources_clone[2]() === 42 + @test sources_clone[2]() == 2*y + + # warning thrown + @test_logs( + (:warn, r"No replacement"), + replace( + signature, + hot=>hot2, + knn=>knn2, + ys=>source(2*y); + ), + ) - # duplicate a learning network machine: - mach = machine(Deterministic(), Xs, ys; - predict=yhat, - report=(mae=enode,)) - mach2 = replace(mach, hot=>hot2, knn=>knn2, - ys=>source(ys.data); - empty_unspecified_sources=true) - ss = sources(glb(mach2)) - @test isempty(ss[1]) - mach2 = @test_logs((:warn, r"No replacement"), - replace(mach, hot=>hot2, knn=>knn2, - ys=>source(ys.data))) - yhat2 = mach2.fitresult.predict - fit!(mach, verbosity=0) - fit!(mach2, verbosity=0) - @test predict(mach, X) ≈ predict(mach2, X) - @test report(mach).mae ≈ report(mach2).mae - - @test mach2.args[1]() == Xs() - @test mach2.args[2]() == ys() + yhat2 = MLJBase.operation_nodes(signature_clone).predict ## EXTRA TESTS FOR TRAINING SEQUENCE @@ -141,9 +125,7 @@ _header(accel) = @test length(MLJBase.machines(yhat)) == length(MLJBase.machines(yhat2)) @test MLJBase.models(yhat) == MLJBase.models(yhat2) - @test sources(yhat) == sources(yhat2) - @test MLJBase.tree(yhat) == MLJBase.tree(yhat2) - @test yhat() ≈ yhat2() + @test 2yhat() ≈ yhat2() # this change should trigger retraining of all machines except the # univariate standardizer: @@ -159,7 +141,6 @@ _header(accel) = (:train, oakM2), (:train, knnM2)]) end - end # module true diff --git a/test/utilities.jl b/test/utilities.jl index 03be2877..5356ce66 100644 --- a/test/utilities.jl +++ b/test/utilities.jl @@ -205,6 +205,16 @@ MLJBase.target_scitype(::Type{<:DRegressor2}) = @test MLJBase.guess_model_target_observation_scitype(DRegressor2()) == Continuous +@testset "pretty" begin + X = (x=fill(1, 3), y=fill(2, 3)) + io = IOBuffer() + pretty(X) + pretty(io, X) + str = take!(io) |> String + @test contains(str, "x") + @test contains(str, "y") + @test contains(str, "│") +end end # module true