1st author
- Shishir Patil
- Tianjun Zhang
paper: [2305.15334] Gorilla: Large Language Model Connected with Massive APIs
code: ShishirPatil/gorilla: Gorilla: Training and Evaluating LLMs for Function Calls (Tool Calls)
5. 总结 (结果先行)
Gorilla 这篇论文针对 LLM 在大规模、动态 API 调用方面的挑战,提出了一个以数据为驱动、结合微调与检索的解决方案。
贡献:
- 创建了 APIBench,一个包含大量真实世界 API 调用和指令的数据集,为训练和评估 LLM 的工具使用能力提供了重要基石。
- 提出了 Gorilla 模型,通过在 APIBench 上对 LLaMA 进行检索感知微调,使其在 API 调用准确性上超越了包括 GPT-4 在内的先进模型,并显著降低了幻觉。
- 验证了检索增强对于 LLM 适应 API 文档动态变化的关键作用,使得 LLM 能够使用最新的工具信息,而不是受限于训练时的静态知识。
- 引入 AST 子树匹配作为一种更鲁棒的 API 调用评估方法。
前瞻与展望:
- 更强大的检索器: 当前检索器的性能仍是瓶颈 (Oracle 检索器的结果远好于实际检索器)。
- API 执行与反馈: 目前 Gorilla 仅生成 API 调用,并未实际执行。集成执行模块,并根据执行结果 (成功、失败、错误信息) 进行学习和修正,将是重要一步。
- 多轮 API 调用与规划: 许多复杂任务需要一系列 API 调用。如何让 LLM 进行规划、分解任务并有序调用多个 API 是一个更具挑战性的问题。
- 更广泛的 API 覆盖: 将 APIBench 扩展到更多领域 (如 Web API, 操作系统 API) 和更复杂的 API 结构。
1. 思想
大型语言模型 (LLM) 在数学推理、程序综合等任务上取得了显著进展,但其有效利用外部工具 (tools) 的潜力仍有待挖掘,尤其是在通过应用程序接口 (API) 调用工具时。即使是像 GPT-4 这样的顶尖模型,在这方面也面临挑战,主要原因在于:
大问题: 如何让 LLM 准确、可靠地调用数量庞大且动态变化的 API 来完成复杂任务,而不局限于模型内部固化的知识和计算能力?
小问题:
- 准确性:LLM 难以生成具有正确输入参数的 API 调用。API 的参数通常有严格的类型、格式和依赖关系要求。
- 幻觉: LLM 倾向于虚构出不存在的 API 用法或参数 (hallucination),或者错误地选择功能不匹配的 API。
- 适应性: API 文档会频繁更新,版本会迭代。LLM 如何适应这些测试时 (test-time) 的变化?
关键思路 :
Gorilla 提出了一种通过微调 (fine-tuning) 和检索增强 (retrieval augmentation) 的方法,使 LLM 能够更有效地使用大量 API。
- 准确性和幻觉: 构建一个大规模的、包含真实世界 API 调用及其指令的数据集 (APIBench),并在此数据集上微调一个基础 LLM (LLaMA)。这种特定任务的微调旨在教会模型理解 API 的语法和语义。
- 适应性: 将微调后的 LLM 与一个文档检索器 (document retriever) 相结合。在推理时,当用户给出指令,检索器首先从最新的 API 文档库中找到相关的 API 信息,然后将这些信息连同用户指令一起提供给 LLM,引导其生成正确的 API 调用。
2. 方法
主要包含三个部分:数据集构建 (APIBench),模型训练 (Gorilla),以及评估机制。
2.1 APIBench 数据集构建
目标是创建一个全面、真实的 API 调用指令数据集。
-
数据源: 抓取三大主流机器学习模型中心 (TorchHub, TensorHub, HuggingFace) 的 API 文档和用例。
- TorchHub: 94 个 API 调用 (详尽)。
- TensorHub (v2): 626 个 API 调用 (详尽)。
- HuggingFace: 925 个模型 (每个任务类别下载量前 20 的模型,因为模型数量庞大且部分缺乏标准规范)。
- 总计约 1645 个独特的 API 数据点。
-
信息提取: 将每个 API 的模型卡片信息转换为 JSON 对象,包含字段如:
{domain, framework, functionality, api_name, api_call, api_arguments, environment_requirements, example_code, performance, description}
。
这些字段的选择旨在泛化到机器学习领域之外的 RESTful API。 -
指令生成 (Instruction Generation): 利用 Self-Instruct 范式,并使用 GPT-4 作为生成器,为每个 API 生成 10 条对应的自然语言指令。
- 具体做法是:提供 3 个上下文示例 (in-context examples) 和参考 API 文档,要求 GPT-4 生成真实场景下会调用该 API 的用户指令,并且明确指示 GPT-4 在生成指令时避免使用任何 API 名称或提示。
- 最终产出约 16450 个
{指令, API}
对。
2.2 Gorilla 模型
Gorilla 是一个基于 LLaMA-7B
微调得到的模型,专门用于 API 调用。
-
训练数据格式: 将 APIBench 中的
{指令, API}
对转换为用户-助手 (user-agent) 对话风格的数据。 -
标准指令微调 (Standard Instruction Finetuning): 在 LLaMA-7B 基础上进行微调。
-
检索感知训练 (Retriever-Aware Training): 这是 Gorilla 的一个创新。在训练时,除了用户指令和目标 API 调用外,还在用户提示 (prompt) 中加入了相关的 API 文档片段,格式如下:
User: <user_prompt> Use this API documentation for reference: <retrieved_API_doc_JSON> Assistant: <target_API_call>
这里的
<retrieved_API_doc_JSON>
在训练时使用的是与目标 API 对应的真实文档 (ground truth document)。目的是教会 LLM 在有参考文档的情况下,如何解析文档并结合用户需求生成 API 调用。 -
推理 (Inference):
- 零样本 (Zero-shot): 直接将用户指令输入 Gorilla 模型。
- 带检索器 (With Retriever):
- 用户提供自然语言指令。
- 检索器 (如 BM25 或基于 LLM 的 GPT-Index) 从 API 数据库中检索最相关的 API 文档。
- 将用户指令和检索到的 API 文档拼接后,作为输入提供给 Gorilla 模型。
- Gorilla 输出待调用的 API。
2.3 验证 API-AST 子树匹配
评估 LLM 生成的 API 调用的正确性是一个挑战,因为简单的字符串匹配无法处理功能等价但表达形式不同的情况 (例如参数顺序、默认参数等)。
- 核心思想: 采用抽象语法树 (Abstract Syntax Tree, AST) 子树匹配策略来判断生成的 API 调用是否与数据集中的某个参考 API 功能一致。
- 流程:
- 将 LLM 生成的代码解析成 AST。
- 检查该 AST 中是否存在一个子树,其根节点是我们关心的 API 调用 (例如
torch.hub.load
),并且其关键参数与 APIBench 数据集中某个 API 的 AST 子树匹配。 - 参数匹配: 并非所有参数都需要匹配。例如,Python 允许默认参数。Gorilla 在其数据库中为每个 API 定义了哪些参数是必须匹配的。例如,对于
torch.hub.load
,会检查repo_or_dir
和model
参数。而像pretrained=True
这样的可选参数则不强制检查其值。
- 幻觉定义: 如果一个生成的 API 调用解析后的 AST 无法在 APIBench 数据库中找到任何匹配的子树,则认为这是一次幻觉——即模型调用了一个完全虚构的工具。这与调用了错误的 API (error) 是有区别的。
3. 优势
- 更高的 API 调用准确性: 通过在专门构建的 APIBench 数据集上进行微调,Gorilla 在生成正确 API 调用方面(尤其是在函数名和关键参数上)显著优于 GPT-4 等通用大模型。
- 显著减少幻觉: 针对性的微调和 AST 验证机制使得 Gorilla 更少地生成不存在或完全错误的 API 调用。
- 对 API 文档变化的适应性: 通过检索感知训练和推理时整合检索器,Gorilla 能够适应测试时 API 文档的更新或版本变更。当 API 文档变化时,检索器可以提供最新的信息,引导模型做出正确的调整。
- 处理带约束的 API 调用: Gorilla 能够理解并尝试满足用户在指令中提出的约束条件,例如模型参数量、最低准确率等。
- 开放的 APIBench 数据集和模型: Gorilla 团队开源了其模型、代码和 APIBench 数据集。
4. 实验
4.1 实验设置
- 数据集:
- APIBench: 如前所述,包含 TorchHub, HuggingFace, TensorHub 的 API 调用。训练集和测试集进行了划分(例如 HuggingFace 90% 训练/10% 测试,Torch/Tensor Hub 80% 训练/20% 测试)。
- 基线模型 (Baselines):
- GPT-4 (OpenAI,
gpt-4-0314
) - GPT-3.5-turbo (OpenAI,
gpt-3.5-turbo-0301
) - Claude (Anthropic,
claude-v1
) - LLaMA-7B (Meta, 基础模型)
- GPT-4 (OpenAI,
- 检索器 (Retrievers):
- Zero-shot (0-shot): 无检索器。
- BM25: 基于稀疏向量的经典检索算法。每个 API 视为一个文档。
- GPT-Index: 使用 OpenAI 的
text-davinci-003
作为检索模型。 - Oracle Retriever: 理想情况,假设检索器总能 100% 准确地返回相关的 API 文档。用于评估检索系统性能的上限。
- 评估指标:
- 整体准确率 (Overall Accuracy): 基于 AST 子树匹配,正确调用 API 的比例。
- 幻觉率 (Hallucination Rate, hallu↓): 生成无法在数据库中匹配到的 API 的比例 (越低越好)。
- 错误率 (Error Rate, err↓): 调用了数据库中存在的 API,但非目标 API 或参数错误 (越低越好)。
4.2 实验结果
-
API 调用准确性:
- 在零样本设置下,Gorilla (微调后的 LLaMA-7B) 在 TorchHub, HuggingFace, TensorHub 上的 API 调用准确率均显著高于 GPT-4, GPT-3.5, Claude 和原始 LLaMA。例如,在 HuggingFace 上,Gorilla (0-shot) 准确率为 71.68%,而 GPT-4 (0-shot) 为 19.80%。
- 加入检索器后,所有模型的性能通常都有所提升。Gorilla 配合 GPT-Index 检索器在 TorchHub (61.82%) 和 HuggingFace (47.46%) 上表现最佳,在 TensorHub (64.96%) 上与 GPT-3.5 (GPT-Index) 表现接近。
- 一个有趣的发现是,对于未进行检索感知训练的模型 (如 GPT-4),在测试时加入一个非最优的检索器 (如 BM25) 有时反而会降低性能,说明模型可能被不相关的检索结果误导。而 Gorilla 由于进行了检索感知训练,更能有效利用检索信息。
-
幻觉减少:
- Gorilla 在所有设置下都表现出较低的幻觉率。例如,在 HuggingFace (0-shot) 上,Gorilla 幻觉率为 10.95%,而 GPT-4 为 37.16%,GPT-3.5 为 35.73%。
- 令人惊讶的是,实验发现在某些情况下 GPT-3.5 的幻觉比 GPT-4 少,作者推测这可能与 RLHF (Reinforcement Learning from Human Feedback) 的训练方式有关,使得模型更“谨慎”。
-
对测试时文档变化的适应性:
- 实验展示了当 API 发生变化时 (例如 FCN ResNet-50 升级到 ResNet-101,或模型仓库从
pytorch/vision
迁移到NVIDIA/DeepLearningExamples:torchhub
),Gorilla 凭借其检索器能够正确调用更新后的 API,而没有检索能力的模型则会失败。这凸显了检索感知训练的重要性。
- 实验展示了当 API 发生变化时 (例如 FCN ResNet-50 升级到 ResNet-101,或模型仓库从
-
带约束的 API 调用:
- 实验在 TorchHub 子集上评估了模型在满足特定约束条件 (如模型在 ImageNet 上的 top-1 准确率) 下选择 API 的能力。
- Gorilla 在零样本和使用检索器的情况下,在满足约束的准确性上均表现优异,有时能匹配甚至超越 GPT-3.5 的表现。这表明 Gorilla 不仅能找到功能相关的 API,还能在一定程度上理解和权衡约束。
-
检索感知训练的有效性:
- 对比“Gorilla without Retriever” (训练时不加入参考文档) 和 “Gorilla with Oracle retriever” (训练时加入真实参考文档) 的结果。
- 结果显示,在训练中整合检索信息 (即使是理想的 Oracle 信息) 能够显著提升模型性能。例如,在 HuggingFace 上,使用 Oracle 检索器训练的 Gorilla 准确率达到 91.26%,远高于不使用检索信息训练的 Gorilla (45.58%,当测试时也使用 Oracle 检索)。
- 这证明了让模型在训练阶段就学会如何利用检索到的文档至关重要。