JSP-GFN

Installation

To avoid conflicts with your existing Python setup, we suggest working in a virtual environment:

python -m venv venv
source venv/bin/activate

Follow the official JAX installation instructions to install the version of JAX that matches your versions of CUDA and cuDNN, then install the remaining dependencies:

pip install -r requirements.txt
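
Before training, it can be worth checking that the JAX build you installed actually sees the accelerator you expect. The snippet below is a minimal sketch and not part of this repository: the helper name `describe_jax_backend` is hypothetical, and it only uses `jax.devices()`, `jax.default_backend()` and `jax.__version__`.

```python
import importlib.util


def describe_jax_backend():
    """Return a one-line summary of the JAX install, or a hint if JAX is missing.

    Hypothetical helper for illustration; it is not provided by JSP-GFN.
    """
    # Avoid an ImportError if the virtual environment is not set up yet.
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed in this environment"

    import jax

    devices = jax.devices()          # devices visible to the default backend
    backend = jax.default_backend()  # e.g. "cpu" or "gpu"
    return f"jax {jax.__version__} on {backend} with {len(devices)} device(s)"


print(describe_jax_backend())
```

If the reported backend is `cpu` on a machine with a GPU, the installed JAX build probably does not match the local CUDA and cuDNN versions.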

Example

python train.py --batch_size 256 --lr 1e-4 --params_num_samples 64 --model lingauss_diag --artifact ...