
Commit

fix
DoubleClass committed Mar 22, 2024
1 parent d211516 commit 0c8ee63
Showing 14 changed files with 194 additions and 687 deletions.
12 changes: 7 additions & 5 deletions README.md
@@ -8,7 +8,7 @@ Git clone our repository, create a python environment and activate it via the

```bash
git clone https://github.com/DoubleClass/GMM
-cd MiniGPT-4
+cd GMM
conda create -n GMM python=3.9
conda install --yes --file requirements.txt
conda activate GMM
```
@@ -22,17 +22,17 @@ You can get the LLM Vicuna in [huggingface](https://huggingface.co/meta-llama/Ll
Then set the downloaded vicuna folder path [here](minigpt4/configs/models/minigpt4_vicuna0.yaml) and the initial checkpoint [here](train_configs/minigpt4_stage2_finetune.yaml#L9)

### EVA_VIT_G
-The code will automatically downloading the eva_vit_g.pth, we alse put it [here](https://pan.baidu.com/s/1kyc6gp7f2CXkocljhERKVg?pwd=2mux), you can manually download it and put it in 'root/.cache/torch/hub/checkpoints/eva_vit_g.pth'
+The code will automatically download eva_vit_g.pth; we also put it [here](https://pan.baidu.com/s/1kyc6gp7f2CXkocljhERKVg?pwd=2mux), so you can manually download it and put it in the cache dir
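
For reference, a minimal sketch of placing the manually downloaded checkpoint into the default torch hub cache (the path is taken from the previous wording of this line; adjust if your cache lives elsewhere):

```bash
# assumed default torch hub cache location
mkdir -p ~/.cache/torch/hub/checkpoints
mv eva_vit_g.pth ~/.cache/torch/hub/checkpoints/eva_vit_g.pth
```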

### bert-base-uncased
-The code will automatically downloading this, but in case you don't have access to huggingface, we also put it [here](https://pan.baidu.com/s/1XzAidcFinjsNxdz58M465w?pwd=b98f), you can manually download it and put it in '~/.cache/huggingface/hub/models--bert-base-uncased'
+The code will automatically download this, but in case you don't have access to huggingface, we also put it [here](https://pan.baidu.com/s/1XzAidcFinjsNxdz58M465w?pwd=b98f), so you can manually download it and put it in the cache dir as well
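
Likewise, a sketch for the Hugging Face cache; the directory name comes from the previous wording of this line, and the exact snapshot layout under it is managed by huggingface_hub, so treat this as an assumption:

```bash
# assumed Hugging Face hub cache location
mkdir -p ~/.cache/huggingface/hub/models--bert-base-uncased
# unpack the downloaded files under this directory
```
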
### datasets
#### ImageNet-R
You can download it [here](https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar)
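
If you prefer the command line, a sketch of fetching and unpacking it (the destination path is a placeholder, and the tarball is assumed to unpack into an imagenet-r/ folder):

```bash
wget https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar
mkdir -p /path/to/datasets
tar -xf imagenet-r.tar -C /path/to/datasets
```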

-Then set the dataset folder path [here](clip_base/datasets.py#L281)
+Then set the dataset folder path [here](clip_base/datasets.py#L134)

-Besides, you need to customize the dataset for the GPT fine-tuning process. We prepare a example here you can follow. [baidu](https://pan.baidu.com/s/1xMkqOiSylWyKY74Oef4h4g?pwd=yyea)
+Besides, you need to customize the dataset for the GPT fine-tuning process. We have prepared an example you can follow: [download](https://pan.baidu.com/s/1xMkqOiSylWyKY74Oef4h4g?pwd=yyea)

After downloading the customized dataset, you can set the data root path [here](minigpt4/configs/datasets/cc_sbu/align.yaml#L7) and the indexing file [here](minigpt4/datasets/builders/image_text_pair_builder.py#L121)

@@ -48,6 +48,8 @@ python train.py --cfg-path train_configs/minigpt4_stage2_finetune.yaml
## Testing
After training, you will get a model checkpoint from the last continual learning stage. Put its path into the script eval_all.sh and specify a results directory.

+Then set the results path in [get_score_all.py](https://vscode.dev/github/DoubleClass/GMM/get_score_all.py#L1)

Run the script:

```bash
# assumed invocation of the eval script mentioned above
bash eval_all.sh
```
32 changes: 1 addition & 31 deletions batch_eval.py
@@ -22,6 +22,7 @@
from clip_base.datasets import build_cl_scenarios
from torch.utils.data import DataLoader
import clip
+from tqdm import tqdm

def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
@@ -43,11 +44,9 @@ def parse_args():

def setup_seeds(config):
    seed = config.run_cfg.seed + get_rank()
-    # import pdb; pdb.set_trace()
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
-
    cudnn.benchmark = False
    cudnn.deterministic = True

@@ -68,9 +67,6 @@ def setup_seeds(config):


model_config = cfg.model_cfg
-# import pdb; pdb.set_trace()
-
-
model_config.device_8bit = args.gpu_id
model_config.ckpt = args.ckpt_path
model_cls = registry.get_model_class(model_config.arch)
@@ -99,46 +95,20 @@ def get_ordered_class_name(class_order, class_name):
    cfg_o, is_train=False, transforms=transforms
)

-# import pdb; pdb.set_trace()
new_class_name = get_ordered_class_name(cfg_o.class_order, classes_names)

-# import pdb; pdb.set_trace()
-
-from tqdm import tqdm
-# import warnings
-# warnings.filterwarnings("ignore")

with open(args.txt_path, 'w') as f:
    eval_loader = DataLoader(eval_dataset[:args.task_id+1], batch_size=cfg_o.batch)
    names = new_class_name[:cfg_o.initial_increment + args.task_id * cfg_o.increment]
    for inputs, targets, task_ids in tqdm(eval_loader):
-
-        # import pdb; pdb.set_trace()
        chat_state = CONV_VISION.copy()
        img_list = []
-        # gr_img = torch.randn(2,3,224,224)
-        # gr_img = item[0]
        llm_message = chat.upload_img(inputs, chat_state, img_list)
-        # llm_message = chat.upload_img(gr_img, chat_state, img_list)
-        # import pdb; pdb.set_trace()
        chat.ask('what is this photo of?', chat_state)
        llm_message = chat.answer(conv=chat_state, img_list=img_list, num_beams=1, temperature=0.01, max_new_tokens=300, max_length=2000)[0]
-        # llm_message = chat.answer(conv=chat_state,img_list=img_list,num_beams=1,temperature=0.01, max_new_tokens=300,max_length=2000)[0]
-        # print('the label is ', new_class_name[item[1]])
-        # print('the message is ', llm_message)
        for i in range(inputs.shape[0]):
            str1 = 'the label is ' + new_class_name[targets[i]] + '\n'
            str2 = 'msg: ' + llm_message[i] + '\n'
-            # print(str1)
-            # print(str2)
-            # import pdb; pdb.set_trace()
            f.write(str1)
            f.write(str2)
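
For orientation, a hypothetical invocation of batch_eval.py: the flag spellings are guessed from the args.* attributes visible in this diff (gpu_id, ckpt_path, task_id, txt_path) plus the --cfg-path convention used by train.py above, so check parse_args for the exact names:

```bash
# hypothetical flag names and paths; see parse_args in batch_eval.py
python batch_eval.py \
  --cfg-path eval_configs/minigpt4_eval.yaml \
  --gpu-id 0 \
  --ckpt-path /path/to/last_stage_checkpoint.pth \
  --task-id 0 \
  --txt-path results/task0.txt
```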
46 changes: 8 additions & 38 deletions clip_base/datasets.py

