欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 社会 > unsloth 在medical-o1-reasoning-SFT对话数据集上进行指令微调

unsloth 在medical-o1-reasoning-SFT对话数据集上进行指令微调

2025/6/22 22:56:49 来源:https://blog.csdn.net/u011564831/article/details/148811791  浏览:    关键词:unsloth 在medical-o1-reasoning-SFT对话数据集上进行指令微调

简述

medical-o1-reasoning-SFT 数据集上使用 Unsloth 进行 指令微调(Instruction Tuning)

medical-o1-reasoning-SFT 是一个带有医学问题 + 推理步骤(CoT)+ 最终答复的数据集,涵盖中文、英文,且问题来源权威(医考/临床)。使用指令微调,可以让模型学会:

  1. 遵循人类提问指令(Instruction Following)

  2. 进行逐步医学推理(Reasoning)

  3. 输出可信、专业、结构化的答复

使用 Unsloth 进行指令微调,主要有三大实际意义:

关键点Unsloth 的优势对微调效果的影响
✅ 超长上下文支持支持 32K+ tokens,适合长 CoT完整吸收推理链条,提升 reasoning 能力
✅ 训练速度极快FlashAttention2 + Triton 加速用同样算力训练更大数据,提高收敛质量
✅ 低成本高性能 SFT内置 LoRA,节省参数与显存适配中小型显卡(20GB 也能训练)

关于unsloth和peft

对比与关系

项目UnslothPEFT
类型完整的模型训练加速库 + LoRA 微调封装专注于参数高效微调(LoRA、Prompt-Tuning等)
作者Unsloth.ai(专注高性能SFT训练)Hugging Face 官方
核心功能- 高速 SFT/QLoRA
- Flash-Attn2、Triton 支持
- 最多训练 128K context
- 封装各种 PEFT 方法(LoRA, PrefixTuning等)
模型支持LLaMA、Mistral、Qwen、Gemma 等几乎所有 transformers 支持的模型
性能优势极致速度优化(PyTorch 2.2、flash-attn、fused kernel)易集成性强,灵活通用,但性能略逊一筹

PEFT 更像是微调算法的工具箱,而 Unsloth 是一个极致优化的 “SFT 微调引擎”,集成了 PEFT 的核心能力(比如 LoRA),但对训练过程进行了大量性能级优化

要处理的 FreedomIntelligence/medical-o1-reasoning-SFT 数据集是典型的 SFT + 长文本 CoT + 生成任务,这时候 Unsloth 的以下优势特别明显:

1. 支持超长输入上下文

  • 医疗推理数据中每条样本可能包含较长的 chain-of-thought(CoT),一般在 2048-8192 tokens 左右。

  • Unsloth 专门优化了对 32K、64K、128K上下文长度 的支持,远优于普通 transformers + peft

2. 训练速度极快

  • 基于 Flash Attention 2、Triton、xFormers 等内核技术,训练速度比 HuggingFace 官方快 5-8 倍,尤其适合你这种大数据训练场景。

3. 默认集成 PEFT(LoRA)

  • 你只需要一句代码就能启用 LoRA,而不需要再显式调用 PEFT:

    model = FastLanguageModel.get_peft_model(...)

4. 兼容所有 HuggingFace 权重

  • medical-o1 是针对 HuatuoGPT-o1 微调的,而它本身就是 transformers 模型。Unsloth 100% 兼容这些格式。

5. 训练代码更简洁

Unsloth 封装非常好,几行代码即可运行完整 SFT:

from unsloth import FastLanguageModelmodel, tokenizer = FastLanguageModel.from_pretrained(...)
model = FastLanguageModel.get_peft_model(model, ...)
model.fit(...)  # 开始训练

推荐建议

条件推荐方式
只需微调少量 LoRA 参数,模型小、数据短PEFT(例如 8-bit LLaMA)
数据包含长 CoT 推理,训练集大,目标是高性能 SFT(如 medical-o1)Unsloth + LoRA
有强化学习(如 PPO)可先 SFT 用 Unsloth,然后 RLHF 用 TRL/HF

关于数据集

huggingface上数据集地址

https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT/tree/mainhttps://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT/tree/main

FreedomIntelligence/medical‑o1‑reasoning‑SFT 是一个专为 微调医学大语言模型 而构建的监督学习(SFT)数据集,用于提升其“病例推理 → 生成答案”的能力:

  • Question:真实的医学问题(来自医学考试等可验证来源)

  • Complex_CoT:GPT‑4o 自动生成的“链式思考”步骤,表示模型的推理链条

  • Response:最终的答案或建议

数据版本包含内容适用场景
medical_o1_sft.json纯医学问答(CoT + Answer)专注医学领域、节省算力资源
medical_o1_sft_mix.json医学 + 通用指令训练数据同时提升医学表现和通用对话能力

这个数据集的主要价值在于:

  • 采用 GPT‑4o + 医学验证器 确保推理和答案的准确性 github.com+3huggingface.co+3hyper.ai+3

  • 通过包含完整的 CoT,可显著提升下游模型对复杂医学问题的理解与应对能力 reddit.com

  • 被广泛用于微调 SFT 阶段,后续还可结合 PPO 强化学习进一步优化模型 oxen.ai+10github.com+10blog.csdn.net+10

  • 当前版本包含约 90 k 条中英文问答对(en 约 19 k 条,zh 约20 k 条),单条记录包括 Question + Complex_CoT + Response,供开发者直接用于训练或分析

使用工具包依赖版本

transformers              4.52.4
triton                    3.3.0
trl                       0.18.2
typeguard                 4.4.4
typing_extensions         4.14.0
typing-inspection         0.4.1
tyro                      0.9.24
tzdata                    2025.2
unsloth                   2025.6.3
unsloth_zoo               2025.6.2
urllib3                   2.5.0
wandb                     0.20.1
watchdog                  6.0.0
wheel                     0.45.1
xformers                  0.0.30

训练

"""
第1步:初始化设置和登录
设置访问令牌并登录到HuggingFace和Weights&Biases平台
(略)
""""""
第2步:加载模型和分词器
使用unsloth优化的FastLanguageModel加载预训练模型
"""
from unsloth import FastLanguageModel# 模型配置参数
max_seq_length = 2048  # 最大序列长度
dtype = None          # 数据类型,None表示自动选择
load_in_4bit = True   # 使用4bit量化加载模型以节省显存# 加载预训练模型和分词器
model, tokenizer = FastLanguageModel.from_pretrained(#model_name = "unsloth/DeepSeek-R1-Distill-Qwen-7B",model_name = "/opt/chenrui/qwq32b/base_model/qwen2-7b",max_seq_length = max_seq_length,dtype = dtype,load_in_4bit = load_in_4bit,#token = hf_token,
)"""
第3步:定义提示模板和进行微调前的推理测试
"""
prompt_style = """以下是描述任务的指令,以及提供更多上下文的输入。
请写出恰当完成该请求的回答。
在回答之前,请仔细思考问题,并创建一个逐步的思维链,以确保回答合乎逻辑且准确。### Instruction:
你是一位在临床推理、诊断和治疗计划方面具有专业知识的医学专家。
请回答以下医学问题。### Question:
{}### Response:
<think>{}"""# 测试用医学问题
question = "一位61岁的女性,长期存在咳嗽或打喷嚏等活动时不自主尿失禁的病史,但夜间无漏尿。她接受了妇科检查和Q-tip测试。基于这些发现,膀胱测压最可能显示她的残余尿量和逼尿肌收缩情况如何?"# 设置模型为推理模式
FastLanguageModel.for_inference(model)
inputs = tokenizer([prompt_style.format(question, "")], return_tensors="pt").to("cuda")# 生成回答
outputs = model.generate(input_ids=inputs.input_ids,attention_mask=inputs.attention_mask,max_new_tokens=1200,use_cache=True,
)
response = tokenizer.batch_decode(outputs)
print("### 微调前模型推理结果:")
print(response[0].split("### Response:")[1])"""
第4步:数据集处理函数
"""
train_prompt_style = """以下是描述任务的指令,以及提供更多上下文的输入。请写出恰当完成该请求的回答。在回答之前,请仔细思考问题,并创建一个逐步的思维链,以确保回答合乎逻辑且准确。### Instruction:你是一位在临床推理、诊断和治疗计划方面具有专业知识的医学专家。请回答以下医学问题。### Question:{}### Response:<think>{}</think>{}"""EOS_TOKEN = tokenizer.eos_token  # 添加结束符标记#格式化提示函数,用于处理数据集中的示例
def formatting_prompts_func(examples):# 从examples中提取问题、思维链和回答inputs = examples["Question"]      # 医学问题列表cots = examples["Complex_CoT"]     # 思维链列表outputs = examples["Response"]     # 回答列表# 存储格式化后的文本texts = []# 遍历每个示例,将问题、思维链和回答组合成指定格式for input, cot, output in zip(inputs, cots, outputs):# 使用train_prompt_style模板格式化文本,并添加结束符text = train_prompt_style.format(input, cot, output) + EOS_TOKENtexts.append(text)# 返回格式化后的文本字典return {"text": texts,}# 加载数据集并应用格式化
from datasets import load_dataset
dataset = load_dataset("json",data_files="/opt/chenrui/chatdoctor/dataset/medical_o1_sft.jsonl",split="train",trust_remote_code=True,
)
dataset = dataset.map(formatting_prompts_func, batched = True,)"""
第5步:配置LoRA微调参数
使用LoRA技术进行参数高效微调
"""
FastLanguageModel.for_training(model)model = FastLanguageModel.get_peft_model(# 原始模型model,# LoRA秩,用于控制低秩矩阵的维度,值越大表示可训练参数越多,模型性能可能更好但训练开销更大# 建议: 8-32之间r=16,# 需要应用LoRA的目标模块列表target_modules=["q_proj", "k_proj", "v_proj", "o_proj",  # attention相关层"gate_proj", "up_proj", "down_proj",     # FFN相关层],# LoRA缩放因子,用于控制LoRA更新的幅度。值越大,LoRA的更新影响越大。lora_alpha=16,# LoRA层的dropout率,用于防止过拟合,这里设为0表示不使用dropout。# 如果数据集较小,建议设置0.1左右。lora_dropout=0,# 是否对bias参数进行微调,none表示不微调bias# none: 不微调偏置参数;# all: 微调所有参数;# lora_only: 只微调LoRA参数。bias="none",# 是否使用梯度检查点技术节省显存,使用unsloth优化版本# 会略微降低训练速度,但可以显著减少显存使用use_gradient_checkpointing="unsloth",# 随机数种子,用于结果复现random_state=3407,# 是否使用rank-stabilized LoRA,这里不使用# 会略微降低训练速度,但可以显著减少显存使用use_rslora=False,# LoFTQ配置,这里不使用该量化技术,用于进一步压缩模型大小loftq_config=None,
)"""
第6步:配置训练参数和初始化训练器
"""
from trl import SFTTrainer  # 用于监督微调的训练器
from transformers import TrainingArguments  # 用于配置训练参数
from unsloth import is_bfloat16_supported  # 检查是否支持bfloat16精度训练# 初始化SFT训练器
trainer = SFTTrainer(model=model,  # 待训练的模型tokenizer=tokenizer,  # 分词器train_dataset=dataset,  # 训练数据集dataset_text_field="text",  # 数据集字段的名称max_seq_length=max_seq_length,  # 最大序列长度dataset_num_proc=2,  # 数据集处理的并行进程数,提高CPU利用率args=TrainingArguments(per_device_train_batch_size=2,  # 每个GPU的训练批次大小gradient_accumulation_steps=4,   # 梯度累积步数,用于模拟更大的batch sizewarmup_steps=5,  # 预热步数,逐步增加学习率learning_rate=2e-4,  # 学习率lr_scheduler_type="linear",  # 线性学习率调度器max_steps=60,    # 最大训练步数(一步 = 处理一个batch的数据)# 根据硬件支持选择训练精度fp16=not is_bfloat16_supported(),  # 如果不支持bf16则使用fp16bf16=is_bfloat16_supported(),      # 如果支持则使用bf16logging_steps=10,  # 每10步记录一次日志optim="adamw_8bit",  # 使用8位AdamW优化器节省显存,几乎不影响训练效果weight_decay=0.01,   # 权重衰减系数,用于正则化,防止过拟合seed=3407,  # 随机数种子output_dir="outputs",  # 保存模型检查点和训练日志),
)"""
第7步 开始训练
"""
trainer.train()

1 加载预训练模型与分词器

  • 使用 FastLanguageModel.from_pretrained() 加载本地 Qwen‑7B 基础模型,同时启用了 4‑bit 量化(load_in_4bit=True),极大节省显存。

  • max_seq_length=2048 支持长上下文,适合处理复杂医学推理任务。

  • 4‑bit 量化虽然可能略微影响精度,但能让 7B 参数模型在单卡或中小 GPU 上可行。

2 定义了带有 “链式思维” 格式的 prompt 模板:

先载入模型到推理模式 (FastLanguageModel.for_inference(model)),用一个现实的临床问题执行生成测试,评估 baseline 的推理能力,作为后续微调对比。

在回答之前,请仔细思考问题,并创建一个逐步的思维链,以确保回答合乎逻辑且准确。

3  训练数据格式化

  • 设置了 template,包含 <think>…</think> 区分“思维链”与最终回答。

  • 使用 map() 将原 SFT 数据集(含 Question / CoT / Response)格式化为完整 prompt,并添加 EOS token。

  • 定制化 prompt 结构为下游提供 清晰、可区分的训练信号

4  LoRA 参数高效微调,只微调少量参数(≈0.1–1%),通过 LoRA 实现 SFT 微调但成本大幅下降,适合资源受限环境。

  • 锁定 attention 与 FFN 模块的关键子模块 (q_proj, k_proj, v_proj, gate_proj, 等) 做低秩微调。

  • 使用 r=16, lora_alpha=16, dropout=0 初步设置合理 trade‑off。

  • bias="none" 限制偏置被微调,减少训练复杂度和过拟合风险。

  • 启用 use_gradient_checkpointing="unsloth" 进一步节省显存。

5  Trainer 配置细节

  • 使用 SFTTrainer 搭配 transformers.TrainingArguments

    • 微批次 per GPU 设为 2,配合 gradient_accumulation_steps=4 实际等效批次大小为 8。

    • 步数 max_steps=60,小批次数据快速迭代 SFT。

    • 混合精度:优先使用 bf16,如果不支持则 fallback fp16。

    • 使用 8‑bit AdamW 优化器,内存更友好。

    • 学习率 2e‑4、线性 warmup + decay,利于稳定训练。

  • 调用 trainer.train(),LoRA 参数更新。

  • 训练周期短(60 步),适用于快速迭代与验证 prompt / 模型组合效果。

  • 每 10 步 log 一次,有助于及时观察 loss、生成样本及指标变化。

开启训练

生成lora部分adapter模型权重和基座权重进行合并

new_model_local = "./Medical-COT-Qwen-7B"
model.save_pretrained(new_model_local)

模型问答测试

使用streamlit 启动web页面测试问答

import streamlit as st
import torch
from unsloth import FastLanguageModel
import os
import random
import numpy as np
import re# 禁用 wandb
os.environ["WANDB_DISABLED"] = "true"# 模型配置参数
max_seq_length = 2048
dtype = None
load_in_4bit = True# 提示模板
SYSTEM_PROMPT = """你是一位在临床推理、诊断和治疗计划方面具有专业知识的医学专家。请回答以下医学问题,并提供详细的推理过程。格式要求:
<reasoning>...</reasoning>
<answer>...</answer>"""# Streamlit 页面配置
st.set_page_config(page_title="医疗问答 Web Demo", page_icon="🏥", layout="wide", initial_sidebar_state="collapsed")# 自定义 CSS 样式
st.markdown("""<style>/* 按钮样式 */.stButton button {border-radius: 50% !important;width: 32px !important;height: 32px !important;padding: 0 !important;background-color: transparent !important;border: 1px solid #ddd !important;display: flex !important;align-items: center !important;justify-content: center !important;font-size: 14px !important;color: #666 !important;margin: 5px 10px 5px 0 !important;}.stButton button:hover {border-color: #999 !important;color: #333 !important;background-color: #f5f5f5 !important;}.stMainBlockContainer > div:first-child {margin-top: -50px !important;}.stApp > div:last-child {margin-bottom: -35px !important;}</style>
""", unsafe_allow_html=True)# 处理模型输出(添加可折叠的推理内容)
def process_assistant_content(content):if '<reasoning>' in content and '</reasoning>' in content:content = re.sub(r'(<reasoning>)(.*?)(</reasoning>)',r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理内容(展开)</summary>\2</details>',content,flags=re.DOTALL)elif '<reasoning>' in content and '</reasoning>' not in content:content = re.sub(r'<reasoning>(.*?)$',r'<details open style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理中...</summary>\1</details>',content,flags=re.DOTALL)elif '<reasoning>' not in content and '</reasoning>' in content:content = re.sub(r'(.*?)</reasoning>',r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight:bold;">推理内容(展开)</summary>\1</details>',content,flags=re.DOTALL)return content# 加载模型和分词器
@st.cache_resource
def load_model_and_tokenizer():model_path = "/opt/chenrui/chatdoctor/Medical-COT-Qwen-7B"# 检查目录是否存在if not os.path.exists(model_path):st.error(f"模型目录 '{model_path}' 不存在。请检查路径或重新运行训练脚本。")raise FileNotFoundError(f"Model directory '{model_path}' does not exist.")try:# 加载合并模型model, tokenizer = FastLanguageModel.from_pretrained(model_name=model_path,max_seq_length=max_seq_length,dtype=dtype,load_in_4bit=load_in_4bit,local_files_only=True)st.success("成功加载合并模型!")st.write(f"Model class: {type(model)}")st.write(f"Model config model_type: {model.config.model_type}")st.write(f"Tokenizer pad_token_id: {tokenizer.pad_token_id}")st.write(f"Tokenizer eos_token_id: {tokenizer.eos_token_id}")# 确保 pad_token_id 已设置if tokenizer.pad_token_id is None:tokenizer.pad_token = tokenizer.eos_tokentokenizer.pad_token_id = tokenizer.eos_token_idst.write(f"Set pad_token_id to eos_token_id: {tokenizer.pad_token_id}")FastLanguageModel.for_inference(model)return model, tokenizerexcept Exception as e:st.error(f"加载合并模型失败: {str(e)}")st.error(f"请检查 '{model_path}' 是否包含有效模型文件,或重新运行训练脚本。")raisemodel, tokenizer = load_model_and_tokenizer()# 设置随机种子
def setup_seed(seed):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = False# 侧边栏参数调整
st.sidebar.title("模型参数调整")
st.session_state.history_chat_num = st.sidebar.slider("历史对话轮数", 0, 6, 0, step=2)
st.session_state.max_new_tokens = st.sidebar.slider("最大生成长度", 256, 8192, 1200, step=1)
st.session_state.top_p = st.sidebar.slider("Top-P", 0.8, 0.99, 0.85, step=0.01)
st.session_state.temperature = st.sidebar.slider("Temperature", 0.6, 1.2, 0.85, step=0.01)# 侧边栏信息
st.sidebar.header("关于")
st.sidebar.markdown("""
- **模型**:Qwen2-7B(微调后)
- **数据集**:医疗问答数据集
- **技术栈**:Streamlit, Unsloth, PyTorch
- **功能**:支持生成医疗推理和回答
""")# 页面标题和标语
st.markdown(f'<div style="display: flex; flex-direction: column; align-items: center; text-align: center; margin: 0; padding: 0;">'f'<div style="font-style: italic; font-weight: 900; margin: 0; padding-top: 4px; display: flex; align-items: center; justify-content: center; flex-wrap: wrap; width: 100%;">'f'<span style="font-size: 26px; margin-left: 10px;">Hi, I\'m Medical-CoT-Qwen-7B</span>''</div>''<span style="color: #bbb; font-style: italic; margin-top: 6px; margin-bottom: 10px;">内容完全由AI生成,请务必仔细甄别<br>Content AI-generated, please discern with care</span>''</div>',unsafe_allow_html=True
)
st.image("icon.jpg", width=45)# 初始化对话历史
if "messages" not in st.session_state:st.session_state.messages = []st.session_state.chat_messages = []# 显示历史对话
for i, message in enumerate(st.session_state.messages):if message["role"] == "assistant":with st.chat_message("assistant", avatar="icon.jpg"):st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)if st.button("×", key=f"delete_{i}"):st.session_state.messages.pop(i)st.session_state.messages.pop(i - 1)  # 删除对应的用户消息st.session_state.chat_messages.pop(i)st.session_state.chat_messages.pop(i - 1)st.rerun()else:st.markdown(f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px; background-color: gray; border-radius: 10px; color: white;">{message["content"]}</div></div>',unsafe_allow_html=True)# 用户输入和生成
prompt = st.chat_input(key="input",placeholder="请输入医学问题")if prompt:# 显示用户输入st.markdown(f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px; background-color: gray; border-radius: 10px; color: white;">{prompt}</div></div>',unsafe_allow_html=True)st.session_state.messages.append({"role": "user", "content": prompt})st.session_state.chat_messages.append({"role": "user", "content": prompt})# 格式化输入chat_history = st.session_state.chat_messages[-(st.session_state.history_chat_num + 1):]chat_history[-1]["content"] = f"{SYSTEM_PROMPT}\n\n### Question:\n{prompt}\n\n### Response:\n<reasoning></reasoning>\n<answer></answer>"new_prompt = tokenizer.apply_chat_template(chat_history,tokenize=False,add_generation_prompt=True)# 生成回答with st.chat_message("assistant", avatar="icon.jpg"):placeholder = st.empty()with st.spinner("生成回答中..."):random_seed = random.randint(0, 2**32 - 1)setup_seed(random_seed)inputs = tokenizer([new_prompt], return_tensors="pt", padding=True, truncation=True, max_length=max_seq_length).to("cuda")try:outputs = model.generate(input_ids=inputs.input_ids,attention_mask=inputs.attention_mask,max_new_tokens=st.session_state.max_new_tokens,temperature=st.session_state.temperature,top_p=st.session_state.top_p,use_cache=True,do_sample=True)response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]# 提取 ### Response: 后的内容if "### Response:" in response:assistant_answer = response.split("### Response:")[1].strip()else:assistant_answer = response.strip()placeholder.markdown(process_assistant_content(assistant_answer), unsafe_allow_html=True)except Exception as e:st.error(f"生成失败: {str(e)}")assistant_answer = "生成回答时发生错误,请稍后重试。"# 保存完整回答st.session_state.messages.append({"role": "assistant", "content": assistant_answer})st.session_state.chat_messages.append({"role": "assistant", "content": assistant_answer})

运行时加载模型

输入测试问题

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词