Skip to content

Commit

Permalink
New version
Browse files Browse the repository at this point in the history
  • Loading branch information
aurorarossi committed Mar 9, 2024
1 parent 3884e00 commit 3c9c1d9
Showing 1 changed file with 79 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ begin
using LinearAlgebra
using MLDatasets: TemporalBrains
using CUDA
using cuDNN
end

# ╔═╡ 69d00ec8-da47-11ee-1bba-13a14e8a6db2
md"In this tutorial, we will learn how to extend the graph classification task to the case of temporal graphs, i.e., graphs whose topology and features are time-varying.
We will design and train a simple temporal graph neural network architecture to classify subjects' gender (female or male) using the temporal graphs extracted from their brain fMRI scan signals.
We will design and train a simple temporal graph neural network architecture to classify subjects' gender (female or male) using the temporal graphs extracted from their brain fMRI scan signals. Given the large amount of data, we will implement the training so that it can also run on the GPU.
"

# ╔═╡ ef8406e4-117a-4cc6-9fa5-5028695b1a4f
Expand Down Expand Up @@ -69,16 +70,16 @@ function data_loader(brain_dataset)
dataset[i].tgdata.g = Float32.(Flux.onehot(graph.graph_data.g, ["F", "M"]))
end
# Split the dataset into a 80% training set and a 20% test set
train_loader = dataset[1:80]
test_loader = dataset[81:100]
train_loader = dataset[1:200]
test_loader = dataset[201:250]
return train_loader, test_loader
end

# ╔═╡ d4732340-9179-4ada-b82e-a04291d745c2
md"
The first part of the `data_loader` function calls the `mlgraph2gnngraph` function for each snapshot, which takes the graph and converts it to a `GNNGraph`. The vector of `GNNGraph`s is then rewritten to a `TemporalSnapshotsGNNGraph`.
The second part adds the graph and node features to the temporal graphs, in particular it adds the one-hot encoding of the label of the graph and appends the mean activation of the node of the snapshot. For the graph feature, it adds the one-hot encoding of the gender.
The second part adds the graph and node features to the temporal graphs, in particular it adds the one-hot encoding of the label of the graph (in this case we directly use the identity matrix) and appends the mean activation of the node of the snapshot (which is contained in the vector `dataset[i].snapshots[t].ndata.x`, where `i` is the index indicating the subject and `t` is the snapshot). For the graph feature, it adds the one-hot encoding of gender.
The last part splits the dataset.
"
Expand Down Expand Up @@ -140,7 +141,7 @@ end
md"
## Training
We train the model for 200 epochs, using the Adam optimizer with a learning rate of 0.001. We use the `logitbinarycrossentropy` as the loss function, which is typically used as the loss in two-class classification, where the labels are given in a one-hot format.
We train the model for 100 epochs, using the Adam optimizer with a learning rate of 0.001. We use the `logitbinarycrossentropy` as the loss function, which is typically used as the loss in two-class classification, where the labels are given in a one-hot format.
The accuracy expresses the number of correct classifications.
"

Expand All @@ -150,15 +151,20 @@ lossfunction(ŷ, y) = Flux.logitbinarycrossentropy(ŷ, y)
# ╔═╡ cc2ebdcf-72de-4a3b-af46-5bddab6689cc
function eval_loss_accuracy(model, data_loader)
error = mean([lossfunction(model(g), g.tgdata.g) for g in data_loader])
acc = mean([round(
100 *
mean(Flux.onecold(model(g)) .== Flux.onecold(g.tgdata.g));
digits = 2) for g in data_loader])
acc = mean([round(100 * mean(Flux.onecold(model(g)) .== Flux.onecold(g.tgdata.g)); digits = 2) for g in data_loader])
return (loss = error, acc = acc)
end

# ╔═╡ d64be72e-8c1f-4551-b4f2-28c8b78466c0
function train(graphs; kws...)
function train(dataset; usecuda::Bool, kws...)

if usecuda && CUDA.functional() #check if GPU is available
my_device = gpu
@info "Training on GPU"
else
my_device = cpu
@info "Training on CPU"
end

function report(epoch)
train_loss, train_acc = eval_loss_accuracy(model, train_loader)
Expand All @@ -167,14 +173,16 @@ function train(graphs; kws...)
return (train_loss, train_acc, test_loss, test_acc)
end

model = GenderPredictionModel()
model = GenderPredictionModel() |> my_device

opt = Flux.setup(Adam(1.0f-3), model)

train_loader, test_loader = data_loader(graphs)
train_loader, test_loader = data_loader(dataset)
train_loader = train_loader |> my_device
test_loader = test_loader |> my_device

report(0)
for epoch in 1:200
for epoch in 1:100
for g in train_loader
grads = Flux.gradient(model) do model
= model(g)
Expand All @@ -191,19 +199,18 @@ end


# ╔═╡ 483f17ba-871c-4769-88bd-8ec781d1909d
train(brain_dataset)
train(brain_dataset; usecuda = true)

# ╔═╡ b4a3059a-db7d-47f1-9ae5-b8c3d896c5e5
md"
Training the whole dataset takes a lot of time, especially since we are working on CPU, so we only train on 100 subjects.
To speed up the training, you can see the linked example that inspired this tutorial [here](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/master/examples/graph_classification_temporalbrains.jl), where the training can also be done on the GPU with the whole dataset.
We set up the training on the GPU because training the whole dataset takes a lot of time, especially when working on the CPU.
"

# ╔═╡ cb4eed19-2658-411d-886c-e0c9c2b44219
md"
## Conclusions
In this tutorial, we implemented a very simple architecture to classify temporal graphs in the context of gender classification using brain data. We then trained the model for 200 epochs on a small subset of the TemporalBrains dataset. The accuracy of the model is better than chance, but can be improved by training on more data and fine-tuning the parameters.
In this tutorial, we implemented a very simple architecture to classify temporal graphs in the context of gender classification using brain data. We then trained the model on the GPU for 100 epochs on the TemporalBrains dataset. The accuracy of the model is approximately 75-80%, but can be improved by fine-tuning the parameters and training on more data.
"

# ╔═╡ 00000000-0000-0000-0000-000000000001
Expand All @@ -216,12 +223,14 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
[compat]
CUDA = "~5.2.0"
Flux = "~0.14.12"
GraphNeuralNetworks = "~0.6.17"
Flux = "~0.14.13"
GraphNeuralNetworks = "~0.6.18"
MLDatasets = "~0.7.14"
cuDNN = "~1.3.0"
"""

# ╔═╡ 00000000-0000-0000-0000-000000000002
Expand All @@ -230,7 +239,7 @@ PLUTO_MANIFEST_TOML_CONTENTS = """
julia_version = "1.9.2"
manifest_format = "2.0"
project_hash = "04cba1aff578ba316e7651fcb3c634bbcd843c14"
project_hash = "1fa6881859d6c994f6e511ad5e2c8c160c221413"
[[deps.AbstractFFTs]]
deps = ["LinearAlgebra"]
Expand All @@ -245,9 +254,9 @@ weakdeps = ["ChainRulesCore", "Test"]
[[deps.Adapt]]
deps = ["LinearAlgebra", "Requires"]
git-tree-sha1 = "e2a9873379849ce2ac9f9fa34b0e37bde5d5fe0a"
git-tree-sha1 = "cea4ac3f5b4bc4b3000aa55afb6e5626518948fa"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "4.0.2"
version = "4.0.3"
weakdeps = ["StaticArrays"]
[deps.Adapt.extensions]
Expand Down Expand Up @@ -367,6 +376,12 @@ git-tree-sha1 = "8e25c009d2bf16c2c31a70a6e9e8939f7325cc84"
uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
version = "0.11.1+0"
[[deps.CUDNN_jll]]
deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
git-tree-sha1 = "75923dce4275ead3799b238e10178a68c07dbd3b"
uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645"
version = "8.9.4+0"
[[deps.ChainRules]]
deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"]
git-tree-sha1 = "4e42872be98fa3343c4f8458cbda8c5c6a6fa97c"
Expand Down Expand Up @@ -635,9 +650,9 @@ version = "0.8.4"
[[deps.Flux]]
deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"]
git-tree-sha1 = "fd7b23aa8013a7528563d429f6eaf406f60364ed"
git-tree-sha1 = "5a626d6ef24ae0a8590c22dc12096fb65eb66325"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.14.12"
version = "0.14.13"
[deps.Flux.extensions]
FluxAMDGPUExt = "AMDGPU"
Expand Down Expand Up @@ -702,9 +717,9 @@ version = "1.3.1"
[[deps.GraphNeuralNetworks]]
deps = ["Adapt", "ChainRulesCore", "DataStructures", "Flux", "Functors", "Graphs", "KrylovKit", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NearestNeighbors", "Random", "Reexport", "SparseArrays", "Statistics", "StatsBase"]
git-tree-sha1 = "c2c372f348d79c71498df4a490f0996b4c8565d8"
git-tree-sha1 = "b1c10955fdcf293160368869a86aa296bc536208"
uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694"
version = "0.6.17"
version = "0.6.18"
weakdeps = ["CUDA"]
[deps.GraphNeuralNetworks.extensions]
Expand All @@ -729,10 +744,10 @@ version = "0.17.1"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
[[deps.HDF5_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"]
git-tree-sha1 = "e4591176488495bf44d7456bd73179d87d5e6eab"
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"]
git-tree-sha1 = "38c8874692d48d5440d5752d6c74b0c6b0b60739"
uuid = "0234f1f7-429e-5d53-9886-15a909be8d59"
version = "1.14.3+1"
version = "1.14.2+1"
[[deps.HTTP]]
deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"]
Expand Down Expand Up @@ -886,6 +901,12 @@ git-tree-sha1 = "2e5c102cfc41f48ae4740c7eca7743cc7e7b75ea"
uuid = "8b046642-f1f6-4319-8d3c-209ddc03c586"
version = "1.0.0"
[[deps.LLVMOpenMP_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "d986ce2d884d49126836ea94ed5bfb0f12679713"
uuid = "1d63c593-3942-5779-bab2-d838dc0a180e"
version = "15.0.7+0"
[[deps.LaTeXStrings]]
git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec"
uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
Expand Down Expand Up @@ -1139,10 +1160,10 @@ uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
version = "0.8.1+0"
[[deps.OpenMPI_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"]
git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762"
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "PMIx_jll", "TOML", "Zlib_jll", "libevent_jll", "prrte_jll"]
git-tree-sha1 = "f46caf663e069027a06942d00dced37f1eb3d8ad"
uuid = "fe0851c0-eecd-5654-98d4-656369965a5c"
version = "4.1.6+0"
version = "5.0.2+0"
[[deps.OpenSSL]]
deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"]
Expand Down Expand Up @@ -1173,6 +1194,12 @@ git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.6.3"
[[deps.PMIx_jll]]
deps = ["Artifacts", "Hwloc_jll", "JLLWrappers", "Libdl", "Zlib_jll", "libevent_jll"]
git-tree-sha1 = "8b3b19351fa24791f94d7ae85faf845ca1362541"
uuid = "32165bc3-0280-59bc-8c0b-c33b6203efab"
version = "4.2.7+0"
[[deps.PaddedViews]]
deps = ["OffsetArrays"]
git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f"
Expand Down Expand Up @@ -1210,9 +1237,9 @@ version = "1.4.3"
[[deps.PrecompileTools]]
deps = ["Preferences"]
git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f"
git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f"
uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
version = "1.2.0"
version = "1.2.1"
[[deps.Preferences]]
deps = ["TOML"]
Expand Down Expand Up @@ -1602,6 +1629,12 @@ git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.2.5"
[[deps.cuDNN]]
deps = ["CEnum", "CUDA", "CUDA_Runtime_Discovery", "CUDNN_jll"]
git-tree-sha1 = "d433ec29756895512190cac9c96666d879f07b92"
uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
version = "1.3.0"
[[deps.libaec_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "46bf7be2917b59b761247be3f317ddf75e50e997"
Expand All @@ -1613,6 +1646,12 @@ deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.8.0+0"
[[deps.libevent_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "OpenSSL_jll"]
git-tree-sha1 = "f04ec6d9a186115fb38f858f05c0c4e1b7fc9dcb"
uuid = "1080aeaf-3a6a-583e-a51c-c537b09f60ec"
version = "2.1.13+1"
[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
Expand All @@ -1622,6 +1661,12 @@ version = "1.48.0+0"
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
version = "17.4.0+0"
[[deps.prrte_jll]]
deps = ["Artifacts", "Hwloc_jll", "JLLWrappers", "Libdl", "PMIx_jll", "libevent_jll"]
git-tree-sha1 = "5adb2d7a18a30280feb66cad6f1a1dfdca2dc7b0"
uuid = "eb928a42-fffd-568d-ab9c-3f5d54fc65b9"
version = "3.0.2+0"
"""

# ╔═╡ Cell order:
Expand Down

0 comments on commit 3c9c1d9

Please sign in to comment.