Skip to content

Commit

Permalink
call train! once with an iterator instead of 110 times with a for-loop
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshuaWhittemore committed Mar 9, 2019
1 parent 414501c commit 0591270
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 25 deletions.
36 changes: 18 additions & 18 deletions other/iris/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This file is machine-generated - editing it directly is not advised

[[AbstractTrees]]
deps = ["Markdown", "Test"]
git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
Expand Down Expand Up @@ -27,9 +29,9 @@ version = "0.5.3"

[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.1"
version = "0.5.2"

[[ColorTypes]]
deps = ["FixedPointNumbers", "Random", "Test"]
Expand All @@ -51,9 +53,9 @@ version = "0.2.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "49269e311ffe11ac5b334681d212329002a9832a"
git-tree-sha1 = "195a3ffcb8b0762684b6821de18f83a16455c6ea"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.5.1"
version = "2.0.0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
Expand Down Expand Up @@ -82,7 +84,7 @@ uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "0.0.10"

[[Distributed]]
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[FixedPointNumbers]]
Expand All @@ -92,12 +94,10 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.5.3"

[[Flux]]
deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DelimitedFiles", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"]
git-tree-sha1 = "1a683b7d156e0b1bf7909e17a6766df9f32f18e4"
repo-rev = "add-iris-dataset"
repo-url = "https://github.com/joshua-whittemore/Flux.jl"
deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DiffRules", "ForwardDiff", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "SpecialFunctions", "Statistics", "StatsBase", "Test", "ZipFile"]
git-tree-sha1 = "28e6dbf663fed71ea607414bc5f2f099d2831c0c"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.7.3+"
version = "0.7.3"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
Expand All @@ -106,14 +106,14 @@ uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.3"

[[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8"
git-tree-sha1 = "dc568a3dbc4d0505d252d104bed03710a9a39441"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.5.4"
version = "0.5.5"

[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
Expand Down Expand Up @@ -244,19 +244,19 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[StatsBase]]
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
git-tree-sha1 = "8f68351fc2600bab59e68406b980b13b2100c472"
git-tree-sha1 = "435707791dc85a67d98d671c1c3fcf1b20b00f94"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.28.1"
version = "0.29.0"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TranscodingStreams]]
deps = ["Pkg", "Random", "Test"]
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
git-tree-sha1 = "90f845c65c50bc57d6ffc815dbab2a4003ccf75c"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.8.1"
version = "0.9.1"

[[URIParser]]
deps = ["Test", "Unicode"]
Expand All @@ -265,7 +265,7 @@ uuid = "30578b45-9adc-5946-b283-645ec420af67"
version = "0.4.0"

[[UUIDs]]
deps = ["Random"]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[Unicode]]
Expand Down
1 change: 0 additions & 1 deletion other/iris/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
[deps]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
6 changes: 4 additions & 2 deletions other/iris/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,17 @@ using Pkg; Pkg.activate("."); Pkg.instantiate()
Then train and evaluate the model:

```julia

julia> include("iris.jl")
Starting training.

Accuracy: 0.92
Accuracy: 0.94

Confusion Matrix:

3×3 Array{Int64,2}:
16 0 0
0 15 2
0 16 1
0 2 15

julia>
Expand Down
8 changes: 4 additions & 4 deletions other/iris/iris.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ loss(x, y) = crossentropy(model(x), y)
optimiser = Descent(0.5)


# Start Training.
for epoch in 1:100
Flux.train!(loss, params(model), [(X_train, y_train)], optimiser)
end
# Create iterator to train model over 110 epochs.
data_iterator = Iterators.repeated((X_train, y_train), 110)

println("Starting training.")
Flux.train!(loss, params(model), data_iterator, optimiser)

# Evaluate trained model against test set.
accuracy(x, y) = mean(onecold(model(x)) .== onecold(y))
Expand Down

0 comments on commit 0591270

Please sign in to comment.