Environment Setup
See: Qwen3 Embedding test
Code and Explanation
# Import the required libraries
import torch  # PyTorch deep learning framework
from transformers import AutoTokenizer, AutoModelForCausalLM  # Hugging Face transformers library

# Load the Qwen3-Reranker-0.6B tokenizer
# padding_side='left' pads sequences on the left, which matters for causal language models
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-0.6B", padding_side='left')
# Load the model and switch it to evaluation mode
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-0.6B").eval()
This code initializes the model and tokenizer in preparation for the reranking task. Qwen3-Reranker is a model designed specifically for document reranking.
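As an aside, from_pretrained leaves the model on CPU by default; if a GPU is available you can move it there. A minimal sketch, not part of the original walkthrough:

# Optional: run on GPU when available (assumption: the 0.6B model fits in GPU memory)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)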
# Instruction for the reranking task
task = 'Given a web search query, retrieve relevant passages that answer the query'

# Example queries
queries = [
    "What is the capital of China?",
    "Explain gravity",
]
# Corresponding documents
documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
]

# Format an instruction, query, and document into the model's input format
def format_instruction(instruction, query, doc):
    if instruction is None:
        instruction = 'Given a web search query, retrieve relevant passages that answer the query'
    output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
        instruction=instruction, query=query, doc=doc)
    return output

# Build a formatted input for each query-document pair
pairs = [format_instruction(task, query, doc) for query, doc in zip(queries, documents)]
pairs
This code prepares the input data. It defines the queries and their corresponding documents, then uses the format_instruction function to turn them into a structure the model can understand. Each formatted input contains three parts, the instruction, the query, and the document, which tells the model exactly what to judge: whether the document satisfies the query.
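For the first pair, the formatted string can be read directly off the template in format_instruction; printing it should produce something like the following (illustrative, reconstructed from the template above):

print(pairs[0])
# <Instruct>: Given a web search query, retrieve relevant passages that answer the query
# <Query>: What is the capital of China?
# <Document>: The capital of China is Beijing.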
# Look up the token IDs for "no" and "yes" in the vocabulary
token_false_id = tokenizer.convert_tokens_to_ids("no")
token_true_id = tokenizer.convert_tokens_to_ids("yes")

# Maximum sequence length
max_length = 8192

# System-prompt prefix instructing the model to make a binary judgment
prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
# Suffix containing the start of the assistant's reply
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

# Tokenize the prefix and suffix
prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)
# Total number of prefix + suffix tokens
len(prefix_tokens) + len(suffix_tokens)
This code sets up the model's input format and parameters. It looks up the token IDs of "yes" and "no", which will later be used to extract the relevance probability from the model's output. It also defines the system-prompt prefix and suffix, which recast reranking as a binary classification question: does the document satisfy the query?
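To inspect the complete chat-formatted prompt as plain text, you can simply concatenate the three pieces. This is only a sketch for inspection, since process_inputs below performs the same assembly at the token level:

# Text-level view of the full prompt for the first pair (inspection only)
full_prompt = prefix + pairs[0] + suffix
print(full_prompt)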
# Function to process the input data
def process_inputs(pairs):
    # Tokenize the input texts without padding
    inputs = tokenizer(
        pairs, padding=False, truncation='longest_first',
        return_attention_mask=False,
        max_length=max_length - len(prefix_tokens) - len(suffix_tokens)
    )
    # Prepend the prefix tokens and append the suffix tokens to each input
    for i, ele in enumerate(inputs['input_ids']):
        inputs['input_ids'][i] = prefix_tokens + ele + suffix_tokens
    # Pad all sequences and convert them to PyTorch tensors
    inputs = tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
    # Move the inputs to the model's device
    for key in inputs:
        inputs[key] = inputs[key].to(model.device)
    return inputs

# Process the example inputs
inputs = process_inputs(pairs)
# Inspect the shape of the processed inputs
inputs["input_ids"].shape
This code defines the process_inputs function for preparing the input data. It first tokenizes the text, then adds the prefix and suffix tokens, and finally pads the sequences and converts them to PyTorch tensors. The output shows the processed input has shape [2, 104]: 2 samples, each 104 tokens long.
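A quick sanity check, not in the original, is to decode one processed sequence back to text; because padding_side='left', any pad tokens will appear at the start:

# Verify that prefix, pair, and suffix were assembled as expected
print(tokenizer.decode(inputs["input_ids"][0]))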
# The torch.no_grad() decorator disables gradient computation during inference, saving memory
@torch.no_grad()
def compute_logits(inputs, **kwargs):
    # Run the model on the inputs and take the output logits
    res = model(**inputs).logits
    # Take the output at the last position, i.e. the vocabulary distribution at the answer position
    batch_scores = res[:, -1, :]
    # Extract the logit of the "yes" token
    true_vector = batch_scores[:, token_true_id]
    # Extract the logit of the "no" token
    false_vector = batch_scores[:, token_false_id]
    # Stack the "no" and "yes" logits into a new [batch_size, 2] tensor
    batch_scores = torch.stack([false_vector, true_vector], dim=1)
    # Apply log_softmax to turn the logits into log-probabilities
    batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
    # Take the probability of "yes", map back to probability space, and convert to a Python list
    scores = batch_scores[:, 1].exp().tolist()
    return scores

# Compute relevance scores for the example inputs
scores = compute_logits(inputs)
scores
This code defines the compute_logits function, which computes a relevance score for each query-document pair. Its core idea is to treat reranking as binary classification:
- Run the model on the input to get its predictions
- Take the vocabulary distribution at the last position (the answer position)
- Keep only the logits of the "yes" and "no" tokens
- Stack the two values and apply log_softmax to get log-probabilities
- Take the probability of "yes" as the relevance score
The output shows relevance scores of 0.9995 and 0.9994 for the two examples, indicating the model considers these documents highly relevant to their queries.
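Since only two logits enter the softmax, the resulting "yes" probability is mathematically identical to a sigmoid of the logit difference, P(yes) = σ(logit_yes − logit_no). A quick numeric check with made-up example logits:

# softmax over [no, yes] equals sigmoid(yes - no); values here are arbitrary, for illustration
l_no, l_yes = torch.tensor(1.2), torch.tensor(4.7)
p_softmax = torch.softmax(torch.stack([l_no, l_yes]), dim=0)[1]
p_sigmoid = torch.sigmoid(l_yes - l_no)
print(p_softmax.item(), p_sigmoid.item())  # both ≈ 0.9707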
Another Example
Relevant Documents
# Add more test cases
task = 'Given a web search query, retrieve relevant passages that answer the query'

# Example 1: more query-document pairs
queries = [
    "What is the capital of China?",
    "Explain gravity",
    "How does photosynthesis work?",  # new query
    "What are the benefits of exercise?",  # new query
]
documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
    "Photosynthesis is the process by which green plants and some other organisms use sunlight to synthesize foods with carbon dioxide and water. It converts light energy into chemical energy, releasing oxygen as a byproduct.",  # relevant document
    "Regular physical activity can improve muscle strength, boost endurance, deliver oxygen and nutrients to tissues, and help your cardiovascular system work more efficiently.",  # relevant document
]

# Build all query-document pairs
pairs = [format_instruction(task, query, doc) for query, doc in zip(queries, documents)]

# Process the inputs and compute relevance scores
inputs = process_inputs(pairs)
scores = compute_logits(inputs)

# Print the relevance score for each query-document pair
for i, (query, doc, score) in enumerate(zip(queries, documents, scores)):
    print(f"Query {i+1}: {query}")
    print(f"Document: {doc[:50]}..." if len(doc) > 50 else f"Document: {doc}")
    print(f"Relevance Score: {score:.6f}")
    print("-" * 50)
Query 1: What is the capital of China?
Document: The capital of China is Beijing.
Relevance Score: 0.999498
--------------------------------------------------
Query 2: Explain gravity
Document: Gravity is a force that attracts two bodies toward...
Relevance Score: 0.999362
--------------------------------------------------
Query 3: How does photosynthesis work?
Document: Photosynthesis is the process by which green plant...
Relevance Score: 0.998967
--------------------------------------------------
Query 4: What are the benefits of exercise?
Document: Regular physical activity can improve muscle stren...
Relevance Score: 0.981571
--------------------------------------------------
Irrelevant Documents
# Test mismatched query-document pairs
mismatched_queries = [
    "What is the capital of China?",
    "How does photosynthesis work?",
]
mismatched_documents = [
    "Paris is the capital of France and one of the most populous cities in Europe.",  # irrelevant document
    "Machine learning is a branch of artificial intelligence that focuses on building systems that learn from data.",  # irrelevant document
]

# Format the mismatched query-document pairs
mismatched_pairs = [format_instruction(task, query, doc) for query, doc in zip(mismatched_queries, mismatched_documents)]

# Process the inputs and compute relevance scores
mismatched_inputs = process_inputs(mismatched_pairs)
mismatched_scores = compute_logits(mismatched_inputs)

# Print the relevance score for each mismatched pair
for i, (query, doc, score) in enumerate(zip(mismatched_queries, mismatched_documents, mismatched_scores)):
    print(f"Query {i+1}: {query}")
    print(f"Document: {doc[:50]}..." if len(doc) > 50 else f"Document: {doc}")
    print(f"Relevance Score: {score:.6f}")
    print("-" * 50)
Query 1: What is the capital of China?
Document: Paris is the capital of France and one of the most...
Relevance Score: 0.000153
--------------------------------------------------
Query 2: How does photosynthesis work?
Document: Machine learning is a branch of artificial intelli...
Relevance Score: 0.000047
--------------------------------------------------
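Putting the pieces together: for actual reranking you score several candidate documents against one query and sort by score. The helper below is a hypothetical convenience wrapper (the name rerank is introduced here, not taken from the model card), reusing the functions defined above; the candidate pool mixes a relevant and an irrelevant document from the earlier examples:

# Hypothetical wrapper around the pipeline above; `rerank` is our own name
def rerank(query, docs, instruction=task):
    rerank_pairs = [format_instruction(instruction, query, doc) for doc in docs]
    rerank_scores = compute_logits(process_inputs(rerank_pairs))
    # Return documents sorted by descending relevance score
    return sorted(zip(docs, rerank_scores), key=lambda x: x[1], reverse=True)

for doc, score in rerank("How does photosynthesis work?", [
    "Photosynthesis is the process by which green plants and some other organisms use sunlight to synthesize foods with carbon dioxide and water.",
    "Machine learning is a branch of artificial intelligence that focuses on building systems that learn from data.",
]):
    print(f"{score:.6f}  {doc[:50]}...")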