文章目录
- 一、蒸馏原理(Qwen2.5系列模型KL散度蒸馏)
- 二、蒸馏方法(只是对输出对齐,没对模型中间层对齐)
- 1)黑盒知识蒸馏
- 2)白盒知识蒸馏
- (1)前向KL散度
- (2)反向kl散度
- (3)偏向前kl散度
- (4)偏向反kl散度
- 三、测试
- 四、代码解析
- 1)前向KL散度
- 2)训练代码解析
一、蒸馏原理(Qwen2.5系列模型KL散度蒸馏)
将待压缩的模型作为教师模型,将体积更小的模型作为学生模型,让学生模型在教师模型的监督下进行优化,将学生模型学习到教师模型的概率分布,然后通过KL散度
进行控制
二、蒸馏方法(只是对输出对齐,没对模型中间层对齐)
1)黑盒知识蒸馏
- 原理:
使用大模型生成数据,通过这些数据去微调更小的模型,来达到蒸馏的目的,缺点是蒸馏效率低,优点是实现简单
2)白盒知识蒸馏
- 原理:
获取学生模型和教师模型的输出概率分布(或者中间隐藏层的概率分布),通过KL散度将学生模型的概率分布向教师模型对齐。下面介绍和测试白盒知识蒸馏:白盒知识蒸馏主要在于模型分布的对齐,模型对齐主要依赖KL散度,对于KL散度的使用又有以下几种方式
(1)前向KL散度
也就是我们经常说的kl散度。
p为教师模型的概率分布,q为学生模型的概率分布,minillm论文中提到前向kl散度可能会使学生模型高估教师模型中概率比较低的位置
,结合公式来看,当p增大时,为了使得kl散度小,则q也需要增大,但是当p趋于0时,无论q取任何值,kl散度都比较小,因为此时p(x)log((p(x)/q(x)))的大小主要受p(x)控制,这样起不到优化q分布的效果,可能会使q分布高估p分布中概率低的位置。 下图展示了前向kl散度的拟合情况,前向kl散度是一种均值搜索,更倾向于拟合多峰
(2)反向kl散度
为了缓解前向kl散度的缺点,提出了反向kl散度。
p为教师模型的概率分布,q为学生模型的概率分布,当p趋于零时,为了使kl散度小,q也需趋于0。 minillm论文中说对于大模型的知识蒸馏,反向kl散度优于前向kl散度,但是也有其他论文说反向kl散度不一定比前向kl散度更优,实际选择中,可能要基于实验驱动。 反向kl散度是一种模式搜索,更倾向于拟合单个峰
(3)偏向前kl散度
对学生模型的分布和教师模型的分布进行加权作为学生模型的分布
(4)偏向反kl散度
对学生模型的分布和教师模型的分布进行加权作为教师模型的分布。
三、测试
-
前提
qwen2.5-3b作为教师模型,qwen2.5-0.5b作为学生模型 -
流程
1、将qwen2.5-3b模型在指定数据集上微调(训练数据5000条,测试数据1000条,测试准确度为81.1%)
2、探索如下三种方案下的蒸馏效果(均使用前向kl散度):
2.1 不微调学生模型+kl散度损失(前向KL散度)
蒸馏1个epoch,准确度70.5%
蒸馏2个epoch,准确度73%
2.2 微调学生模型(模型准确度80.3%)+kl散度损失(前向KL散度)
蒸馏2个epoch,准确度61.9%
2.3 不微调学生模型+kl散度损失(前向KL散度)50%和交叉熵损失加权50%
蒸馏2个epoch,70.5%
3、上述实验中只使用kl散度(前向KL散度)的效果最好,如下实验中使用kl散度的变种进行测试,经过测试,效果都不如前向kl散度效果好。
3.1 反向kl散度
准确率只有54%
3.2 偏向前向kl散度
损失下降异常,效果很差,不断重复输出
- 备注
由于资源和时间的限制,所有测试均保持相同的超参数,未针对不同损失设置不同超参数
四、代码解析
- 源码
knowledge_distillation_llm夸克网盘链接如下(需要可自取):
https://pan.quark.cn/s/82f1bbae549d
github链接如下:https://github.com/wyf3/llm_related
1)前向KL散度
-
公式
p为教师模型的概率分布,q为学生模型的概率分布 -
代码
# 计算前向kl散度
def compute_fkl(logits, teacher_logits, target, padding_id,reduction="sum",temp = 1.0, ):logits = logits / tempteacher_logits = teacher_logits / templog_probs = torch.log_softmax(logits, -1, dtype=torch.float32)teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32)teacher_log_probs = torch.log_softmax(teacher_logits, -1, dtype=torch.float32)#因为log p(x)/q(x) = log p(x) - log q(x)kl = (teacher_probs * (teacher_log_probs - log_probs)) kl = kl.sum(-1)if reduction == "sum":pad_mask = target.eq(padding_id)kl = kl.masked_fill_(pad_mask, 0.0)kl = kl.sum()
- 参数含义
①temp:temp 代表温度系数(Temperature)。它主要用于调整模型输出的概率分布平滑度。
( T > 1 ) 时:温度升高,概率分布会更平滑,模型输出的各个类别概率差异会变小。这能让模型更关注到那些概率较小的类别,有助于学生模型学习到教师模型更丰富的知识。
( T < 1 ) 时:温度降低,概率分布会更尖锐,模型输出的高概率类别会更突出,低概率类别则更趋近于 0。
( T = 1 ) 时:就是普通的 softmax 函数。
②log_probs :计算的就是log q(x)
③teacher_probs:计算的就是p(x)
④teacher_log_probs :计算的就是log p(x)
⑤ kl = kl.sum(-1):对 kl 张量沿着最后一个维度进行求和操作。下面详细解释其含义与用途。
代码解释
import torch# 创建一个示例张量
kl = torch.tensor([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],[[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]
], dtype=torch.float32)print("原始张量形状:", kl.shape) # 输出: torch.Size([2, 2, 3])# 沿着最后一个维度求和
kl_sum = kl.sum(-1)
print("求和后张量形状:", kl_sum.shape) # 输出: torch.Size([2, 2])
print("求和后张量值:\n", kl_sum)
⑥if reduction == “sum”:这段代码是为了把输入或填充的向量那部分的分布mask掉,只需要输出那部分的分布
2)训练代码解析
- 代码解析
①KGTrainer类是继承了transformer的训练类
②只需要重写计算损失的函数compute_loss
③计算损失的时候,教师模型不随输入参数更新所以用的with torch.no_grad()
④loss为损失,logits是概率分布
⑤相对于0.5B,7B,16B,32B可能输出分布差别较大,所以需要做一些填充处理:对学生模型进行padding或对教师模型进行截断
1)填充gap = teacher_logits.shape[-1] - logits.shape[-1]if gap > 0:pad_logits = torch.zeros((logits.shape[0], logits.shape[1], gap)).to(logits.device)logits = torch.cat([logits, pad_logits], dim=-1)
2)截断if logits.shape[-1] != teacher_logits.shape[-1]:teacher_logits = teacher_logits[:, :, :logits.shape[-1]]
⑥compute_fkl的padding_id参数为-100是因为transformer对于-100是不计算损失的
⑦数据集:标准sft的数据,promt输入的损失是不计算的所以乘以-100
input_ids = prompt_input_ids + answer_input_idslabels = [-100] * len(prompt_input_ids) + answer_input_ids
⑧输入长度调整:多的要截断,少的要填充,注意填充部分是不计算损失的
if text_len > self.max_seq_len:input_ids = input_ids[:self.max_seq_len]labels = labels[:self.max_seq_len]attention_mask = attention_mask[:self.max_seq_len]else:input_ids = input_ids + [self.tokenizer.pad_token_id] * (self.max_seq_len - text_len)labels = labels + [-100] * (self.max_seq_len - text_len)attention_mask = attention_mask + [0] * (self.max_seq_len - text_len)
- 代码
class KGTrainer(Trainer):def __init__(self,model = None, #学生模型teacher_model = None, #教师模型if_use_entropy = False, #是否计算交叉熵损失args = None,data_collator = None, train_dataset = None,eval_dataset = None,tokenizer = None,model_init = None, compute_metrics = None, callbacks = None,optimizers = (None, None), preprocess_logits_for_metrics = None,):super().__init__(model,args,data_collator,train_dataset,eval_dataset,tokenizer,model_init,compute_metrics,callbacks,optimizers,preprocess_logits_for_metrics,)self.teacher_model = teacher_modelself.if_use_entropy = if_use_entropydef compute_loss(self, model, inputs, return_outputs=False):outputs = model(**inputs)with torch.no_grad():teacher_outputs = self.teacher_model(**inputs)loss = outputs.losslogits = outputs.logitsteacher_logits = teacher_outputs.logits# 如果教师模型和学生模型输出形状不匹配,对学生模型进行padding或对教师模型进行截断if logits.shape[-1] != teacher_logits.shape[-1]:# gap = teacher_logits.shape[-1] - logits.shape[-1]# if gap > 0:# pad_logits = torch.zeros((logits.shape[0], logits.shape[1], gap)).to(logits.device)# logits = torch.cat([logits, pad_logits], dim=-1)teacher_logits = teacher_logits[:, :, :logits.shape[-1]]labels = inputs['labels']kl = compute_fkl(logits, teacher_logits, labels, padding_id=-100, temp=2.0)if self.if_use_entropy:loss_total = 0.5 * kl + 0.5 * losselse:loss_total = klreturn (loss_total, outputs) if return_outputs else loss_total