欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 会展 > 大模型知识蒸馏(Qwen2.5系列模型KL散度蒸馏)

大模型知识蒸馏(Qwen2.5系列模型KL散度蒸馏)

2025/5/12 15:28:48 来源:https://blog.csdn.net/weixin_43679037/article/details/147797469  浏览:    关键词:大模型知识蒸馏(Qwen2.5系列模型KL散度蒸馏)

文章目录

    • 一、蒸馏原理(Qwen2.5系列模型KL散度蒸馏)
    • 二、蒸馏方法(只是对输出对齐,没对模型中间层对齐)
      • 1)黑盒知识蒸馏
      • 2)白盒知识蒸馏
        • (1)前向KL散度
        • (2)反向kl散度
        • (3)偏向前kl散度
        • (4)偏向反kl散度
    • 三、测试
    • 四、代码解析
      • 1)前向KL散度
      • 2)训练代码解析

一、蒸馏原理(Qwen2.5系列模型KL散度蒸馏)

将待压缩的模型作为教师模型,将体积更小的模型作为学生模型,让学生模型在教师模型的监督下进行优化,将学生模型学习到教师模型的概率分布,然后通过KL散度进行控制

二、蒸馏方法(只是对输出对齐,没对模型中间层对齐)

1)黑盒知识蒸馏

  • 原理:
    使用大模型生成数据,通过这些数据去微调更小的模型,来达到蒸馏的目的,缺点是蒸馏效率低,优点是实现简单

2)白盒知识蒸馏

  • 原理:
    获取学生模型和教师模型的输出概率分布(或者中间隐藏层的概率分布),通过KL散度将学生模型的概率分布向教师模型对齐。下面介绍和测试白盒知识蒸馏:白盒知识蒸馏主要在于模型分布的对齐,模型对齐主要依赖KL散度,对于KL散度的使用又有以下几种方式
(1)前向KL散度

也就是我们经常说的kl散度。
在这里插入图片描述
p为教师模型的概率分布,q为学生模型的概率分布,minillm论文中提到前向kl散度可能会使学生模型高估教师模型中概率比较低的位置,结合公式来看,当p增大时,为了使得kl散度小,则q也需要增大,但是当p趋于0时,无论q取任何值,kl散度都比较小,因为此时p(x)log((p(x)/q(x)))的大小主要受p(x)控制,这样起不到优化q分布的效果,可能会使q分布高估p分布中概率低的位置。 下图展示了前向kl散度的拟合情况,前向kl散度是一种均值搜索,更倾向于拟合多峰
在这里插入图片描述

(2)反向kl散度

为了缓解前向kl散度的缺点,提出了反向kl散度。
在这里插入图片描述
p为教师模型的概率分布,q为学生模型的概率分布,当p趋于零时,为了使kl散度小,q也需趋于0。 minillm论文中说对于大模型的知识蒸馏,反向kl散度优于前向kl散度,但是也有其他论文说反向kl散度不一定比前向kl散度更优,实际选择中,可能要基于实验驱动。 反向kl散度是一种模式搜索,更倾向于拟合单个峰
在这里插入图片描述

(3)偏向前kl散度

对学生模型的分布和教师模型的分布进行加权作为学生模型的分布

(4)偏向反kl散度

对学生模型的分布和教师模型的分布进行加权作为教师模型的分布。

三、测试

  • 前提
    qwen2.5-3b作为教师模型,qwen2.5-0.5b作为学生模型

  • 流程

1、将qwen2.5-3b模型在指定数据集上微调(训练数据5000条,测试数据1000条,测试准确度为81.1%)
2、探索如下三种方案下的蒸馏效果(均使用前向kl散度):
2.1 不微调学生模型+kl散度损失(前向KL散度)
蒸馏1个epoch,准确度70.5%
蒸馏2个epoch,准确度73%
2.2 微调学生模型(模型准确度80.3%)+kl散度损失(前向KL散度)
蒸馏2个epoch,准确度61.9%
2.3 不微调学生模型+kl散度损失(前向KL散度)50%和交叉熵损失加权50%
蒸馏2个epoch,70.5%
3、上述实验中只使用kl散度(前向KL散度)的效果最好,如下实验中使用kl散度的变种进行测试,经过测试,效果都不如前向kl散度效果好。
3.1 反向kl散度
准确率只有54%
3.2 偏向前向kl散度
损失下降异常,效果很差,不断重复输出
  • 备注
    由于资源和时间的限制,所有测试均保持相同的超参数,未针对不同损失设置不同超参数

四、代码解析

  • 源码
knowledge_distillation_llm夸克网盘链接如下(需要可自取):
https://pan.quark.cn/s/82f1bbae549d
github链接如下:https://github.com/wyf3/llm_related

1)前向KL散度

  • 公式
    在这里插入图片描述
    p为教师模型的概率分布,q为学生模型的概率分布

  • 代码

# 计算前向kl散度
def compute_fkl(logits, teacher_logits, target, padding_id,reduction="sum",temp = 1.0, ):logits = logits / tempteacher_logits = teacher_logits / templog_probs = torch.log_softmax(logits, -1, dtype=torch.float32)teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32)teacher_log_probs = torch.log_softmax(teacher_logits, -1, dtype=torch.float32)#因为log p(x)/q(x)  = log p(x) - log q(x)kl = (teacher_probs * (teacher_log_probs - log_probs)) kl = kl.sum(-1)if reduction == "sum":pad_mask = target.eq(padding_id)kl = kl.masked_fill_(pad_mask, 0.0)kl = kl.sum()
  • 参数含义
    ①temp:temp 代表温度系数(Temperature)。它主要用于调整模型输出的概率分布平滑度。
    ( T > 1 ) 时:温度升高,概率分布会更平滑,模型输出的各个类别概率差异会变小。这能让模型更关注到那些概率较小的类别,有助于学生模型学习到教师模型更丰富的知识。
    ( T < 1 ) 时:温度降低,概率分布会更尖锐,模型输出的高概率类别会更突出,低概率类别则更趋近于 0。
    ( T = 1 ) 时:就是普通的 softmax 函数。
    ②log_probs :计算的就是log q(x)
    ③teacher_probs:计算的就是p(x)
    ④teacher_log_probs :计算的就是log p(x)
    ⑤ kl = kl.sum(-1):对 kl 张量沿着最后一个维度进行求和操作。下面详细解释其含义与用途。
    代码解释
import torch# 创建一个示例张量
kl = torch.tensor([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],[[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]
], dtype=torch.float32)print("原始张量形状:", kl.shape)  # 输出: torch.Size([2, 2, 3])# 沿着最后一个维度求和
kl_sum = kl.sum(-1)
print("求和后张量形状:", kl_sum.shape)  # 输出: torch.Size([2, 2])
print("求和后张量值:\n", kl_sum)

⑥if reduction == “sum”:这段代码是为了把输入或填充的向量那部分的分布mask掉,只需要输出那部分的分布

2)训练代码解析

  • 代码解析
    ①KGTrainer类是继承了transformer的训练类
    ②只需要重写计算损失的函数compute_loss
    ③计算损失的时候,教师模型不随输入参数更新所以用的with torch.no_grad()
    ④loss为损失,logits是概率分布
    ⑤相对于0.5B,7B,16B,32B可能输出分布差别较大,所以需要做一些填充处理:对学生模型进行padding或对教师模型进行截断
1)填充gap = teacher_logits.shape[-1] - logits.shape[-1]if gap > 0:pad_logits = torch.zeros((logits.shape[0], logits.shape[1], gap)).to(logits.device)logits = torch.cat([logits, pad_logits], dim=-1)
2)截断if logits.shape[-1] != teacher_logits.shape[-1]:teacher_logits = teacher_logits[:, :, :logits.shape[-1]]

⑥compute_fkl的padding_id参数为-100是因为transformer对于-100是不计算损失的
⑦数据集:标准sft的数据,promt输入的损失是不计算的所以乘以-100

        input_ids = prompt_input_ids + answer_input_idslabels = [-100] * len(prompt_input_ids) + answer_input_ids

⑧输入长度调整:多的要截断,少的要填充,注意填充部分是不计算损失的

        if text_len > self.max_seq_len:input_ids = input_ids[:self.max_seq_len]labels = labels[:self.max_seq_len]attention_mask = attention_mask[:self.max_seq_len]else:input_ids = input_ids + [self.tokenizer.pad_token_id] * (self.max_seq_len - text_len)labels = labels + [-100] * (self.max_seq_len - text_len)attention_mask = attention_mask + [0] * (self.max_seq_len - text_len)
  • 代码
class KGTrainer(Trainer):def __init__(self,model = None,            #学生模型teacher_model = None,    #教师模型if_use_entropy = False,  #是否计算交叉熵损失args = None,data_collator = None, train_dataset = None,eval_dataset = None,tokenizer = None,model_init = None, compute_metrics = None, callbacks = None,optimizers = (None, None), preprocess_logits_for_metrics = None,):super().__init__(model,args,data_collator,train_dataset,eval_dataset,tokenizer,model_init,compute_metrics,callbacks,optimizers,preprocess_logits_for_metrics,)self.teacher_model = teacher_modelself.if_use_entropy = if_use_entropydef compute_loss(self, model, inputs, return_outputs=False):outputs = model(**inputs)with torch.no_grad():teacher_outputs = self.teacher_model(**inputs)loss = outputs.losslogits = outputs.logitsteacher_logits = teacher_outputs.logits# 如果教师模型和学生模型输出形状不匹配,对学生模型进行padding或对教师模型进行截断if logits.shape[-1] != teacher_logits.shape[-1]:# gap = teacher_logits.shape[-1] - logits.shape[-1]# if gap > 0:#     pad_logits = torch.zeros((logits.shape[0], logits.shape[1], gap)).to(logits.device)#     logits = torch.cat([logits, pad_logits], dim=-1)teacher_logits = teacher_logits[:, :, :logits.shape[-1]]labels = inputs['labels']kl = compute_fkl(logits, teacher_logits, labels, padding_id=-100, temp=2.0)if self.if_use_entropy:loss_total = 0.5 * kl + 0.5 * losselse:loss_total = klreturn (loss_total, outputs) if return_outputs else loss_total

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词