Skip to content

Commit

Permalink
Update heatmap asset generation (#5)
Browse files Browse the repository at this point in the history
* Update heatmap asset generation

* Update CI workflows
  • Loading branch information
adrhill authored Oct 14, 2024
1 parent 13c0cb4 commit 2431f32
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 21 deletions.
13 changes: 2 additions & 11 deletions .github/workflows/Assets.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,10 @@ jobs:
pages: write
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
- uses: julia-actions/setup-julia@v2
with:
version: '1'
- uses: actions/cache@v4
env:
cache-name: cache-artifacts
with:
path: ~/.julia/artifacts
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
restore-keys: |
${{ runner.os }}-test-${{ env.cache-name }}-
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/cache@v2
- name: Install dependencies
run: julia --project=assets --color=yes -e 'using Pkg; Pkg.instantiate()'
- name: Aggregate and deploy
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/MultiDocumenter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ jobs:
pages: write
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
- uses: julia-actions/setup-julia@v2
with:
version: '1.8'
version: '1'
- uses: actions/cache@v4
env:
cache-name: cache-artifacts
Expand Down
11 changes: 10 additions & 1 deletion assets/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,17 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
RelevancePropagation = "0be6dd02-ae9e-43eb-b318-c6e81d6890d8"
TextHeatmaps = "2dd6718a-6083-4824-b9f7-90e4a57f72d2"
VisionHeatmaps = "27106da1-f8bc-4ca8-8c66-9b8289f1e035"
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ExplainableAI = "0.9"
RelevancePropagation = "3"
VisionHeatmaps = "1"
XAIBase = "4"
17 changes: 10 additions & 7 deletions assets/make_assets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
# to generate up-to-date heatmaps for the docs and READMEs of the Julia-XAI ecosystem.
using ExplainableAI
using RelevancePropagation
using Metalhead # pre-trained vision models
using HTTP, FileIO, ImageMagick # load image from URL
using DataAugmentation # preprocess image
using VisionHeatmaps # visualization of explanations as heatmaps
using Zygote # load autodiff backend for gradient-based methods
using Flux, Metalhead # pre-trained vision models in Flux
using DataAugmentation # input preprocessing
using HTTP, FileIO, ImageIO # load image from URL

assets_dir = "assets/heatmaps"

Expand All @@ -24,10 +26,11 @@ PYTORCH_MEAN = (0.485, 0.456, 0.406)
PYTORCH_STD = (0.229, 0.224, 0.225)

# Preprocess input
tfm =
CenterResizeCrop((224, 224)) |> ImageToTensor() |> Normalize(PYTORCH_MEAN, PYTORCH_STD)
input_data = apply(tfm, Image(img))
input = view(PermutedDimsArray(input_data.data, (2, 1, 3)), :, :, :, :);
mean = (0.485f0, 0.456f0, 0.406f0)
std = (0.229f0, 0.224f0, 0.225f0)
tfm = CenterResizeCrop((224, 224)) |> ImageToTensor() |> Normalize(mean, std)
input = apply(tfm, Image(img)) # apply DataAugmentation transform
input = reshape(input.data, 224, 224, 3, :) # unpack data and add batch dimension

# Assert model weights are loaded correctly
n_castle = 484
Expand Down

0 comments on commit 2431f32

Please sign in to comment.