Skip to content

Commit

Permalink
readme
Browse files Browse the repository at this point in the history
  • Loading branch information
gerkone committed Jan 20, 2023
1 parent b0bcc02 commit 51e8402
Showing 1 changed file with 31 additions and 28 deletions.
59 changes: 31 additions & 28 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,11 @@ Upgrade `jax` to the gpu version
pip install --upgrade "jax[cuda]==0.4.1" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

### Experiments
The validation experiments are adapted from the original implementation, so additionally `torch` and `torch_geometric` are needed (cpu versions are enough).
```
pip3 install torch==1.12.1 --extra-index-url https://download.pytorch.org/whl/cpu
python -m pip install -r experiments/requirements.txt
```

## Experiments and datasets
The N-body (charged and gravity) and QM9 datasets are included for completeness from the original paper.

QM9 is automatically downloaded and processed when running the respective experiment.

The N-body datasets have to be generated locally from the directory [experiments/nbody/data](experiments/nbody/data)
#### Charged dataset (5 bodies, 10000 training samples)
```
python3 -u generate_dataset.py --simulation=charged
```
#### Gravity dataset (100 bodies, 10000 training samples)
```
python3 -u generate_dataset.py --simulation=gravity --n-balls=100
```
## Validation
N-body (charged and gravity) and QM9 datasets are included for completeness from the original paper.
The implementation is validated on all three of them, getting close results and considerably faster runtimes.

### Results
<table>
<tr>
<td></td>
Expand Down Expand Up @@ -78,19 +61,39 @@ python3 -u generate_dataset.py --simulation=gravity --n-balls=100

** padded

## Usage for validation
Validation experiments are only included in the github repo, so it needs to be cloned first.
### Validation install

The experiments are only included in the github repo, so it needs to be cloned first.
```
git clone https://github.com/gerkone/segnn-jax
```

### Nbody
#### Charged experiment
They are adapted from the original implementation, so additionally `torch` and `torch_geometric` are needed (cpu versions are enough).
```
pip3 install torch==1.12.1 --extra-index-url https://download.pytorch.org/whl/cpu
python -m pip install -r experiments/requirements.txt
```

### Datasets
QM9 is automatically downloaded and processed when running the respective experiment.

The N-body datasets have to be generated locally from the directory [experiments/nbody/data](experiments/nbody/data) (it will take some time, especially n-body `gravity`)
#### Charged dataset (5 bodies, 10000 training samples)
```
python3 -u generate_dataset.py --simulation=charged
```
#### Gravity dataset (100 bodies, 10000 training samples)
```
python3 -u generate_dataset.py --simulation=gravity --n-balls=100
```

### Usage
#### N-body (charged)
```
python main.py --dataset=charged --epochs=200 --max-samples=3000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-4 --weight-decay=1e-8
```

#### Gravity experiment
#### N-body (gravity)
```
python main.py --dataset=gravity --epochs=100 --target=pos --max-samples=10000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=1e-4 --weight-decay=1e-8 --neighbours=5 --n-bodies=100
```
Expand All @@ -102,7 +105,7 @@ python main.py --dataset=qm9 --epochs=1000 --target=alpha --lmax-hidden=2 --lmax

(configurations used in validation)


## Acknowledgments
- [e3nn_jax](https://github.com/e3nn/e3nn-jax) made this reimplementation possible.
- [Artur Toshev](https://github.com/arturtoshev) and [Johannes Brandsetter](https://github.com/brandstetter-johannes), for the developement support.

- [Artur Toshev](https://github.com/arturtoshev) and [Johannes Brandsetter](https://github.com/brandstetter-johannes), for supporting developement.

0 comments on commit 51e8402

Please sign in to comment.