Commit 7a519f6: update the evaluation guidance
zhaoziheng committed Nov 6, 2024
1 parent 57552d6 commit 7a519f6
Showing 5 changed files with 76 additions and 28 deletions.
README.md: 14 changes (7 additions, 7 deletions)
@@ -98,19 +98,19 @@ The input image should be with shape `H,W,D` Our data process code will normaliz…

## Train Guidance:
Some preparation before starting the training:
-1. you need to build your training data following this [repo](https://github.com/zhaoziheng/SAT-DS/tree/main), a jsonl containing all the training samples is required.
+1. you need to build your training data following this [repo](https://github.com/zhaoziheng/SAT-DS/tree/main), specifically, from step 1 to step 5. A jsonl containing all the training samples is required.
2. you need to fetch the text encoder checkpoint from https://huggingface.co/zzh99/SAT to generate prompts.
We recommend 8 or more A100-80G GPUs to train SAT-Nano, and 16 or more to train SAT-Pro. Please use the slurm script in `sh/` to start the training process. Take SAT-Pro for example:
```
sbatch sh/train_sat_pro.sh
```
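
For reference, each line of the training jsonl is one JSON record describing a sample. A minimal sketch of such a record; the field names below are illustrative assumptions, not the exact SAT-DS schema (see the SAT-DS repo for the real one):
```
import json

# Hypothetical training record: field names are assumptions for illustration,
# not the exact schema produced by SAT-DS.
sample = {
    "image": "data/case0001_img.npy",   # preprocessed H,W,D volume
    "mask": "data/case0001_seg.npy",    # segmentation annotation
    "dataset": "ExampleDataset",
    "modality": "CT",
    "label": ["liver", "kidney"],       # class names used as text prompts
}
print(json.dumps(sample))  # one JSON object per line -> a jsonl file
```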



## TODO
- [ ] Inference demo on website.
- [x] Release the data preprocess code to build SAT-DS.
- [x] Release the train guidance.
## Evaluation Guidance:
This also requires building the test data following this [repo](https://github.com/zhaoziheng/SAT-DS/tree/main).
You may refer to the slurm script `sh/evaluate_sat_pro.sh` to start the evaluation process:
```
sbatch sh/evaluate_sat_pro.sh
```
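
The evaluator records scores per dataset; `evaluate/evaluator.py` (see the diff below) writes one xlsx sheet per dataset. A minimal sketch for aggregating the results afterwards, assuming pandas and that the output path is assembled from the `--rcd_dir`/`--rcd_file` values passed to the script:
```
import pandas as pd

# Hypothetical result path built from the --rcd_dir / --rcd_file arguments.
sheets = pd.read_excel("your_rcd_dir/your_rcd_file_name.xlsx", sheet_name=None)
for dataset, df in sheets.items():
    # Average the numeric columns (e.g. Dice / NSD) of each dataset sheet.
    print(dataset, df.mean(numeric_only=True).to_dict())
```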

## Citation
If you use this code for your research or project, please cite:
evaluate.py: 9 changes (5 additions, 4 deletions)
@@ -67,7 +67,10 @@ def main(args):
evaluated_samples.add(f'{line[0]}_{line[2]}')

# dataset and loader
-testset = Evaluate_Dataset_OnlineCrop(args.datasets_jsonl, args.max_queries, args.batchsize_3d, args.crop_size, evaluated_samples)
+if args.online_crop:
+    testset = Evaluate_Dataset_OnlineCrop(args.datasets_jsonl, args.max_queries, args.batchsize_3d, args.crop_size, evaluated_samples)
+else:
+    testset = Evaluate_Dataset(args.datasets_jsonl, args.max_queries, args.batchsize_3d, args.crop_size, evaluated_samples)
sampler = DistributedSampler(testset)
testloader = DataLoader(testset, sampler=sampler, batch_size=1, pin_memory=args.pin_memory, num_workers=args.num_workers, collate_fn=collate_fn, shuffle=False)
sampler.set_epoch(0)
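
For context, `evaluated_samples` above is rebuilt from a previous run's csv when `--resume True`, so finished samples are skipped. A rough, self-contained sketch of that resume step; the csv layout is an assumption, and only the `f'{line[0]}_{line[2]}'` key format is taken from the code above:
```
import csv

# Sketch of the resume step: rebuild the set of already-finished samples from
# a previous run's results csv. The csv column layout is an assumption; the
# key format f'{line[0]}_{line[2]}' mirrors the hunk above.
def load_evaluated_samples(results_csv: str) -> set:
    done = set()
    with open(results_csv) as f:
        for line in csv.reader(f):
            done.add(f'{line[0]}_{line[2]}')  # e.g. dataset name + sample id
    return done
```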
@@ -106,9 +109,7 @@ def main(args):
save_interval=args.save_interval,
dice_score=args.dice,
nsd_score=args.nsd,
-visualization=args.visualization,
-region_split_json=args.region_split_json,
-label_statistic_json=args.label_statistic_json)
+visualization=args.visualization)

if __name__ == '__main__':
# get configs
evaluate/evaluator.py: 13 changes (7 additions, 6 deletions)
@@ -46,9 +46,7 @@ def evaluate(model,
csv_path,
resume,
save_interval,
-visualization,
-region_split_json,
-label_statistic_json):
+visualization):

# if to store pred、gt、img (as nii.gz
if visualization:
@@ -376,13 +374,16 @@ def evaluate(model,
name = name[len(name)-31:]
df.to_excel(writer, sheet_name=name, index=True)

-avg_dice_over_merged_labels, avg_nsd_over_merged_labels = merge(region_split_json, label_statistic_json, xlsx_path, xlsx_path)
+# avg_dice_over_merged_labels, avg_nsd_over_merged_labels = merge(region_split_json, label_statistic_json, xlsx_path, xlsx_path)

os.remove(csv_path.replace('.csv', '.pkl'))

else:
-    avg_dice_over_merged_labels = avg_nsd_over_merged_labels = 0
+    pass

-return avg_dice_over_merged_labels, avg_nsd_over_merged_labels
+# avg_dice_over_merged_labels = avg_nsd_over_merged_labels = 0
+
+return  # avg_dice_over_merged_labels, avg_nsd_over_merged_labels
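
Note that with the `merge(...)` call commented out, `evaluate(...)` now returns `None`: the per-sample and per-label scores are still written to the csv and the per-dataset xlsx sheets, which become the sole outputs. Any caller that previously unpacked `avg_dice_over_merged_labels, avg_nsd_over_merged_labels` needs a matching update.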


evaluate/params.py: 18 changes (7 additions, 11 deletions)
@@ -64,16 +64,6 @@ def parse_args():
type=str2bool,
default=True,
)
-parser.add_argument(
-    "--region_split_json",
-    type=str,
-    default='/mnt/petrelfs/share_data/wuchaoyi/SAM/processed_files_v4/mod_lab(72).json',
-)
-parser.add_argument(
-    "--label_statistic_json",
-    type=str,
-    default='/mnt/petrelfs/share_data/wuchaoyi/SAM/processed_files_v4/mod_lab_accum_statis(72).json',
-)

# Med SAM Dataset

@@ -84,6 +74,12 @@

# Sampler and Loader

+parser.add_argument(
+    "--online_crop",
+    type=str2bool,
+    default='False',
+    help='load pre-cropped image patches directly, or crop online',
+)
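
The string default `'False'` works because `--online_crop` goes through the same `str2bool` converter as the other boolean flags in `params.py`. The helper itself is not part of this diff; a typical implementation, assumed to be close to the repo's, looks like:
```
import argparse

# Common argparse bool converter; assumed to match this repo's str2bool helper.
def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
```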
parser.add_argument(
"--crop_size",
type=int,
@@ -133,7 +129,7 @@ def parse_args():
parser.add_argument(
"--vision_backbone",
type=str,
-help='UNETs UMamba or SwinUNETR'
+help='UNET UNET-L UMamba or SwinUNETR'
)
parser.add_argument(
"--patch_size",
sh/evaluate_sat_pro.sh: 50 changes (50 additions, 0 deletions; new file)
@@ -0,0 +1,50 @@
#!/bin/bash
#SBATCH --job-name=eval_pro
#SBATCH --quotatype=auto
#SBATCH --partition=medai
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --cpus-per-task=16
#SBATCH --mem-per-cpu=128G
#SBATCH --chdir=/mnt/petrelfs/zhaoziheng/Knowledge-Enhanced-Medical-Segmentation/medical-universal-segmentation/log/sbatch
#SBATCH --output=/mnt/petrelfs/zhaoziheng/Knowledge-Enhanced-Medical-Segmentation/medical-universal-segmentation/log/sbatch/%x-%j.out
#SBATCH --error=/mnt/petrelfs/zhaoziheng/Knowledge-Enhanced-Medical-Segmentation/medical-universal-segmentation/log/sbatch/%x-%j.error
###SBATCH -w SH-IDC1-10-140-0-[...], SH-IDC1-10-140-1-[...]
###SBATCH -x SH-IDC1-10-140-0-[...], SH-IDC1-10-140-1-[...]

export NCCL_DEBUG=INFO
export NCCL_IBEXT_DISABLE=1
export NCCL_IB_DISABLE=1
export NCCL_SOCKET_IFNAME=eth0
echo NODELIST=${SLURM_NODELIST}
master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr
MASTER_PORT=$((RANDOM % 101 + 20000))
echo "MASTER_ADDR="$MASTER_ADDR

srun torchrun \
--nnodes 1 \
--nproc_per_node 1 \
--rdzv_id 100 \
--rdzv_backend c10d \
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT /mnt/petrelfs/zhaoziheng/Knowledge-Enhanced-Medical-Segmentation/medical-universal-segmentation/evaluate.py \
--rcd_dir 'your_rcd_dir' \
--rcd_file 'your_rcd_file_name' \
--resume False \
--visualization False \
--deep_supervision False \
--datasets_jsonl 'jsonl generated from SAT-DS Step 4' \
--crop_size 288 288 96 \
--online_crop True \
--vision_backbone 'UNET-L' \
--checkpoint 'your ckpt' \
--partial_load True \
--text_encoder 'ours' \
--text_encoder_checkpoint 'your text encoder ckpt' \
--batchsize_3d 2 \
--max_queries 256 \
--pin_memory False \
--num_workers 4 \
--dice True \
--nsd True
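
The script above is tied to one specific SLURM cluster (partition name, NCCL settings, absolute paths). For a quick single-GPU sanity check without SLURM, a hedged sketch that mirrors the same `torchrun` rendezvous setup; every path and checkpoint below is a placeholder to replace:
```
import random
import subprocess

# Mirrors MASTER_PORT=$((RANDOM % 101 + 20000)) in the script above:
# a random rendezvous port in [20000, 20100].
master_port = random.randint(20000, 20100)

subprocess.run([
    "torchrun", "--nnodes", "1", "--nproc_per_node", "1",
    "--rdzv_backend", "c10d", "--rdzv_endpoint", f"localhost:{master_port}",
    "evaluate.py",
    "--rcd_dir", "your_rcd_dir",
    "--rcd_file", "your_rcd_file_name",
    "--datasets_jsonl", "jsonl generated from SAT-DS Step 4",
    "--crop_size", "288", "288", "96",
    "--online_crop", "True",
    "--vision_backbone", "UNET-L",
    "--checkpoint", "your ckpt",
    "--text_encoder", "ours",
    "--text_encoder_checkpoint", "your text encoder ckpt",
    "--batchsize_3d", "2", "--max_queries", "256",
    "--dice", "True", "--nsd", "True",
], check=True)
```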
