Skip to content

Commit

Permalink
Add version in requirements.txt
Browse files Browse the repository at this point in the history
  • Loading branch information
chakkritte committed May 11, 2022
1 parent 18f305e commit 9f95515
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 44 deletions.
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ This official implementation of PKD (Pseudo Knowledge Distillation) from On-devic
**This code is based on the implementation of [EML-NET-Saliency](https://github.com/SenJia/EML-NET-Saliency), [SimpleNet](https://github.com/samyak0210/saliency), [MSI-Net](https://github.com/alexanderkroner/saliency), and [EEEA-Net](https://github.com/chakkritte/EEEA-Net).**

## Prerequisite for server
- Tested on Ubuntu OS version 20.04.x
- Tested on Ubuntu OS version 20.04.4 LTS
- Tested on Python 3.6.13
- Tested on CUDA 11.6
- Tested on PyTorch 1.10.2 and TorchVision 0.11.3
- Tested on NVIDIA V100 32 GB (four cards)

Expand Down Expand Up @@ -55,15 +57,15 @@ PKD
### Creating new environments

```
conda create -n pkd python=3.6
conda create -n pkd python=3.6.13
conda activate pkd
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
```

### Install Requirements

```
pip install -r requirements.txt
pip install -r requirements.txt --no-cache-dir
```

## Usage
Expand Down
2 changes: 0 additions & 2 deletions data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from scipy import io
import random

# seed_everything(42, workers=True)

def _get_file_list(data_path):
"""This function detects all image files within the specified parent
directory for either training or testing. The path content cannot
Expand Down
30 changes: 3 additions & 27 deletions download.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from matplotlib.pyplot import imread, imsave
from scipy.io import loadmat
from scipy.ndimage import gaussian_filter
import gdown

def download_salicon(data_path):
"""Downloads the SALICON dataset. Three folders are then created that
Expand Down Expand Up @@ -41,15 +42,7 @@ def download_salicon(data_path):
session = requests.Session()

for count, url in enumerate(urls):
response = session.get(url, params={"id": id}, stream=True)
token = _get_confirm_token(response)

if token:
params = {"id": id, "confirm": token}
response = session.get(url, params=params, stream=True)

_save_response_content(response, data_path + "tmp.zip")

gdown.download(url, data_path + "tmp.zip", quiet=False)
with zipfile.ZipFile(data_path + "tmp.zip", "r") as zip_ref:
for file in zip_ref.namelist():
if "test" not in file:
Expand Down Expand Up @@ -417,21 +410,4 @@ def download_fiwi(data_path):

os.remove(data_path + "tmp.zip")

print("done!", flush=True)


def _get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith("download_warning"):
return value

return None

def _save_response_content(response, file_path):
chunk_size = 32768

with open(file_path, "wb") as data:
for chunk in response.iter_content(chunk_size):
if chunk:
data.write(chunk)

print("done!", flush=True)
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
parser = ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=0.0001)
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--dataset_dir", type=str, default="/home/chakkritt/proj/datasets/")
parser.add_argument("--dataset_dir", type=str, default="data/")
parser.add_argument('--input_size_h',default=384, type=int)
parser.add_argument('--input_size_w',default=384, type=int)
parser.add_argument('--no_workers',default=16, type=int)
parser.add_argument('--no_workers',default=8, type=int)
parser.add_argument('--no_epochs',default=10, type=int)
parser.add_argument('--log_interval',default=20, type=int)
parser.add_argument('--lr_sched',default=True, type=bool)
Expand Down
4 changes: 2 additions & 2 deletions main_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
parser = ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=0.0001)
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--dataset_dir", type=str, default="/home/mllab/proj/2021/supernet/data/")
parser.add_argument("--dataset_dir", type=str, default="data/")
parser.add_argument('--input_size_h',default=256, type=int)
parser.add_argument('--input_size_w',default=256, type=int)
parser.add_argument('--no_workers',default=16, type=int)
parser.add_argument('--no_workers',default=8, type=int)
parser.add_argument('--no_epochs',default=10, type=int)
parser.add_argument('--log_interval',default=20, type=int)
parser.add_argument('--lr_sched',default=True, type=bool)
Expand Down
18 changes: 10 additions & 8 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
darmo
timm
kornia
opencv-python
ptflops
h5py
scikit-image
requests
darmo==0.1.12
timm==0.4.12
kornia==0.6.4
opencv-python==4.5.5.64
ptflops==0.6.9
h5py==3.1.0
scikit-image==0.17.2
requests==2.27.1
matplotlib==3.3.4
gdown==4.4.0

0 comments on commit 9f95515

Please sign in to comment.