教程:用于多跳研究的在线强化学习¶
警告:此功能是新功能,并且是高度实验性的。与 DSPy 中的几乎所有其他功能不同,它目前纯粹处于概念验证和开发模式,但我们发布它以鼓励社区参与。
如果您想在新功能合并之前就抢先体验,可以通过 pip install git+https://github.com/stanfordnlp/dspy.git@refs/pull/8171/head
安装 dspy.GRPO
PR,并跟着教程操作。
对于本教程,您还需要 DSPy 的 Arbor RL 服务器。
> pip install arbor-ai
> python -m arbor.cli serve --arbor-config arbor.yaml
在您的目录中创建 arbor.yaml
文件,其中包含类似如下的计划
inference:
gpu_ids: '0'
training:
gpu_ids: '1, 2'
它将 GPU 0 分配给推理,将 GPU 1 和 2 分配给训练。
import dspy
from dspy.clients.lm_local_arbor import ArborProvider
port = 7453
local_lm_name = "Qwen/Qwen2.5-7B-Instruct"
local_lm = dspy.LM(
model=f"openai/arbor:{local_lm_name}",
provider=ArborProvider(),
temperature=0.7,
api_base=f"https://:{port}/v1/",
api_key="arbor",
)
dspy.configure(lm=local_lm)
openai_lm = dspy.LM(model="openai/gpt-4.1-mini")
安装依赖并下载数据¶
为了进行检索,我们将使用很酷的 BM25S 库,因为它相当轻量级。您可以随意替换这个组件。
> pip install -U bm25s PyStemmer "jax[cpu]"
接下来,我们将下载截至 2017 年所有 500 万个维基百科页面的摘要(即第一段)。我们将使用它作为我们的检索语料库。
压缩后为 500MB,因此下载和解压可能需要 2-3 分钟。
from dspy.utils import download
download("https://hugging-face.cn/dspy/cache/resolve/main/wiki.abstracts.2017.tar.gz")
!tar -xzvf wiki.abstracts.2017.tar.gz
然后让我们为 BM25 检索建立索引!这将需要 2-3 分钟。
import ujson
import bm25s
import Stemmer
corpus = []
with open("wiki.abstracts.2017.jsonl") as f:
for line in f:
line = ujson.loads(line)
corpus.append(f"{line['title']} | {' '.join(line['text'])}")
stemmer = Stemmer.Stemmer("english")
corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
retriever = bm25s.BM25(k1=0.9, b=0.4)
retriever.index(corpus_tokens)
加载 HoVer 数据集。¶
让我们为任务加载一个数据集。我们将加载来自 HoVer 多跳任务的示例,其中输入是一个(非常!)复杂的声明,而我们寻求的输出是用于事实核查该声明所需的维基百科页面集合。
import random
from dspy.datasets import DataLoader
kwargs = dict(fields=("claim", "supporting_facts", "hpqa_id", "num_hops"), input_keys=("claim",))
hover = DataLoader().from_huggingface(dataset_name="hover-nlp/hover", split="train", trust_remote_code=True, **kwargs)
hpqa_ids = set()
hover = [
dspy.Example(claim=x.claim, titles=list(set([y["key"] for y in x.supporting_facts]))).with_inputs("claim")
for x in hover
if x["num_hops"] == 3 and x["hpqa_id"] not in hpqa_ids and not hpqa_ids.add(x["hpqa_id"])
]
random.Random(0).shuffle(hover)
trainset, devset, testset = hover[:600], hover[600:900], hover[900:]
len(trainset), len(devset), len(testset)
现在,让我们定义一个在维基百科中进行搜索的函数。这将使用我们的 BM25 索引。
def search(query: str, k: int) -> list[str]:
tokens = bm25s.tokenize(query, stopwords="en", stemmer=stemmer, show_progress=False)
results, scores = retriever.retrieve(tokens, k=k, n_threads=1, show_progress=False)
run = {corpus[doc]: float(score) for doc, score in zip(results[0], scores[0])}
return list(run.keys())
用于多跳研究的 DSPy 程序¶
现在,让我们在 DSPy 中定义多跳程序。它将非常简单,由 generate_query
和 append_notes
模块组成。我们将仔细定义指令,尽管它们通常不是必需的。
instr1 = """
Given a claim and some key facts, generate a follow-up search query to find the next most essential clue towards verifying or refuting the claim. The goal ultimately is to find all documents implicated by the claim.
""".strip()
instr2 = """
Given a claim, some key facts, and new search results, identify any new learnings from the new search results, which will extend the key facts known so far about the whether the claim is true or false. The goal is to ultimately collect all facts that would help us find all documents implicated by the claim.
"""
class ResearchHop(dspy.Module):
def __init__(self, num_docs, num_hops):
self.num_docs, self.num_hops = num_docs, num_hops
self.generate_query = dspy.ChainOfThought(dspy.Signature("claim, key_facts -> followup_search_query", instr1))
self.append_notes = dspy.ChainOfThought(dspy.Signature("claim, key_facts, new_search_results -> new_key_facts", instr2))
def forward(self, claim: str) -> list[str]:
key_facts = []
retrieved_docs = []
for hop_idx in range(self.num_hops):
query = self.generate_query(claim=claim, key_facts=key_facts).followup_search_query if hop_idx else claim
search_results = search(query, k=self.num_docs)
retrieved_docs.extend(search_results)
if hop_idx == self.num_hops - 1:
break
prediction = self.append_notes(claim=claim, key_facts=key_facts, new_search_results=search_results)
key_facts.append(prediction.new_key_facts)
return dspy.Prediction(key_facts=key_facts, retrieved_docs=retrieved_docs)
定义此任务的成功指标¶
def recall(example, pred, trace=None):
gold_titles = example.titles
retrieved_titles = [doc.split(" | ")[0] for doc in pred.retrieved_docs]
return sum(x in retrieved_titles for x in set(gold_titles)) / len(gold_titles)
evaluate = dspy.Evaluate(devset=devset, metric=recall, num_threads=16, display_progress=True, display_table=5)
使用 dspy.GRPO
优化 ResearchHop
系统¶
from dspy.teleprompt.grpo import GRPO
program = ResearchHop(num_docs=4, num_hops=2)
program.set_lm(local_lm)
# NOTE: Training on 6 GPUs.
train_kwargs = {
"update_interval": 3,
"per_device_train_batch_size": 2,
"gradient_accumulation_steps": 4,
"temperature": 0.7,
"beta": 0.04,
"learning_rate": 2e-5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False},
"bf16": True,
"lr_scheduler_type": "constant_with_warmup",
"max_prompt_length": None,
"max_completion_length": None,
"scale_rewards": True,
"max_grad_norm": 0.5,
"lora": True,
}
compiler = GRPO(
metric=recall,
multitask=True,
num_dspy_examples_per_grpo_step=6,
num_samples_per_input=8,
exclude_demos=True,
num_train_steps=500,
num_threads=24,
use_train_as_val=False,
num_steps_for_val=10,
train_kwargs=train_kwargs,
report_train_scores=False,
)
optimized_program = compiler.compile(
student=program,
trainset=trainset,
valset=devset,
)
现在,您可以使用经过 GRPO 优化的程序了。
example = devset[0]
optimized_program(**example.inputs())
在我们的初步实验中,上述训练约 18 小时后,(开发集上的)召回率从 61.8% 提升至 66.2%。这在成本/质量方面通常不如运行 prompt 优化器 dspy.MIPROv2 或 dspy.SIMBA 的效果好,但对于小型语言模型上的任意 LM 程序进行在线强化学习来说,这仍然是一个非常坚实的开端。