-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
32 lines (29 loc) · 1 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from typing import List
from dataclasses import dataclass, field
class Config:
    """Static training configuration for the ORPO fine-tuning run.

    Every setting is a plain class attribute, so it can be read as
    ``Config.<name>`` without creating an instance. The paths below are
    absolute Windows paths from the author's machine and must be edited
    before running anywhere else.
    """

    # ---- Model ------------------------------------------------------------
    # Text-generation base model; download from
    # https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat
    gpt_model: str = r"E:\ai_model\model\qwen0.5"
    # Training data file. Note the path mixes "\" and "/" separators.
    data_path: str = r"E:\ai_model\RLHF\my_code\RLHF_ORPO/data/train_data.json"
    # Presumably the output location for trained LoRA weights (per the name).
    save_lora_path: str = r"E:\ai_model\model\orpo\save_lora"
    # Evaluated once, when the class body runs at import time:
    # the first CUDA device if one is available, otherwise the CPU.
    device: str = "cuda:0" if torch.cuda.is_available() else "cpu"

    # ---- Optimisation -----------------------------------------------------
    batch_size: int = 2
    epochs: int = 30
    lr: float = 0.001

    # ---- ORPO -------------------------------------------------------------
    # ORPO loss hyper-parameter (presumably the weight on the odds-ratio
    # term -- confirm against the training loop that reads it).
    alpha: int = 3
@dataclass
class LoraArguments:
    """LoRA adapter and quantization options with project defaults.

    Every option is a real dataclass field, so each one can be overridden by
    keyword (e.g. ``LoraArguments(lora_r=8)``) and shows up in
    ``dataclasses.fields`` / ``dataclasses.asdict``. The per-field notes
    follow the usual PEFT/LoRA naming; confirm them against the training
    script that consumes this object.
    """

    # LoRA rank.
    lora_r: int = 2
    # LoRA alpha (scaling) value.
    lora_alpha: int = 8
    # Dropout probability for the LoRA layers; 0.0 means no dropout.
    lora_dropout: float = 0.0
    # Names of the modules that receive LoRA adapters. ``default_factory``
    # gives every instance its own list, so mutating one instance's list
    # never leaks into another instance.
    lora_target_modules: List[str] = field(
        default_factory=lambda: ['k_proj', 'v_proj']
    )
    # Path to previously saved LoRA weights (presumably "" means none).
    lora_weight_path: str = ""
    # Quantization switches, presumably applied when loading the base model.
    q_lora: bool = False
    load_in_4bit: bool = False
    load_in_8bit: bool = False
    # Whether to continue training from the previously trained model.
    # The annotation is required: without it ``@dataclass`` treats this as a
    # plain class attribute, so it could not be passed to the constructor.
    is_reload_trained_params: bool = True