Mac M2之LLaMA3-8B微调(llama3-fine-tuning)
演示
咱们这次使用 书生·浦语大模型挑战赛(春季赛)Top12,创意应用奖
的数据集,使用LLaMA3-8B大模型微调
环境
- 点击下载 LLaMA3-8B 微调代码压缩包,并解压
- 在终端进入解压后的文件夹,创建一个新的
Conda 虚拟环境
1
2
3cd llama3-ft
conda create -n llama3-ft python=3.10
conda activate llama3-ft - 安装依赖包
1
pip install -r requirements.txt
数据集
你可以直接使用 dataset/huanhuan.json
数据集(该数据集来源于 https://github.com/KMnO4-zx ),也可以自己准备数据集 ,比如你的客服对话(FAQ)
数据集,这样就可以微调一个更适合你的智能客服的模型,客服回答更准确。
数据集的格式也比较简单,示例如下:instruction
是问题,output
是回答
1 | [ |
微调
- 模型选择
我使用的是 LLM-Research/Meta-Llama-3-8B-Instruct ,你也可以选择一个其他模型,只需要修改train.py
文件里面的model_id
变量即可。
由于国内访问HuggingFace
比较困难,因此使用ModelScope
提供的模型。1
2
3
4
5
6
7
8# 需要微调的基座模型
# https://www.modelscope.cn/studios/LLM-Research/Chat_Llama-3-8B/summary
model_id = 'LLM-Research/Meta-Llama-3-8B-Instruct'
# 比如你也可以使用 Qwen1.5-4B-Chat 模型
# https://www.modelscope.cn/models/qwen/Qwen1.5-4B-Chat/summary
# model_id = 'qwen/Qwen1.5-4B-Chat' - 开始微调
只需要在项目根目录下执行以下命令即可。1
python train.py
测试
微调完成后,你可以执行以下命令启动一个 ChatBot 进行对话测试。
1 | streamlit run chat.py |
该命令执行后,会自动打开浏览器对话页面
其他说明
微调的时间会根据你的数据集大小和模型大小而定。
我由于没有 GPU,因此耗时2个小时,如果你有 GPU,大概需要 30 分钟。代码会自动下载模型,然后开始微调
微调完成后,所有的文件会保存在 models 文件夹下面,结构如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22├── models
├── checkpoint #【模型微调的 checkpoint】
│ ├── LLM-Research
│ │ └── Meta-Llama-3-8B-Instruct
│ │ ├── checkpoint-100
│ │ ├── checkpoint-200
│ │ ├── checkpoint-xxx
│ └── qwen
│ └── Qwen1.5-4B-Chat
│ ├── checkpoint-100
│ ├── checkpoint-200
│ ├── checkpoint-xxx
├── lora #【模型微调的 lora 文件】
│ ├── LLM-Research
│ │ └── Meta-Llama-3-8B-Instruct
│ └── qwen
│ └── Qwen1.5-4B-Chat
└── model #【自动下载的基座模型】
├── LLM-Research
│ └── Meta-Llama-3-8B-Instruct
└── qwen
└── Qwen1___5-4B-ChatCannot copy out of meta tensor; no data
报错1
`NotImplementedError:Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.`
解决:强制设置 device = “mps”
1
2
3# 检查CUDA是否可用,然后检查MPS是否可用,最后回退到CPU
# device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
device = "mps"