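Tutorial: Classification Fine-tuning
Let's walk through a quick example of fine-tuning the LM weights within a DSPy program. We'll apply it to a simple 77-way classification task.
Our fine-tuned program will use a tiny Llama-3.2-1B language model, hosted locally on your GPU. To make this more interesting, we'll assume that (i) we don't have any training labels, but (ii) we do have 500 unlabeled training examples.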
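Install dependencies and download data
Install the latest DSPy via pip install -U "dspy>=2.6.0" and follow along (or use uv pip, if you prefer). The quotes keep the shell from treating >= as a redirect. This tutorial depends on DSPy >= 2.6.0.
This tutorial currently requires a local GPU for inference, though we plan to support serving fine-tuned models with Ollama as well.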
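You will also need the following dependencies:
- Inference: We use SGLang to run local inference servers. You can install the latest version by following the instructions here: https://docs.sglang.com.cn/start/install.html The commands below were the most recent install commands as of April 2, 2025, but we recommend following the current instructions at the installation link. This keeps the fine-tuning packages and the sglang package in sync.
> pip install --upgrade pip
> pip install uv
> uv pip install "sglang[all]>=0.4.4.post3" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
- Fine-tuning: We use the following packages. Note that we pin the version of the transformers package as a temporary fix for a recent issue: https://github.com/huggingface/trl/issues/2338
> uv pip install -U torch transformers==4.48.3 accelerate trl peft
We recommend using the uv package manager to speed up the installation.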
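Recommended: Set up MLflow Tracing to understand what's happening under the hood.
MLflow DSPy Integration
MLflow is an LLMOps tool that integrates natively with DSPy and offers explainability and experiment tracking. In this tutorial, you can use MLflow to visualize prompts and optimization progress as traces, which makes DSPy's behavior easier to understand. You can set up MLflow in the four steps below.
- Install MLflow
%pip install "mlflow>=2.20"
- Start the MLflow UI in a separate terminal
mlflow ui --port 5000
- Connect the notebook to MLflow
import mlflow
mlflow.set_tracking_uri("http://localhost:5000")
mlflow.set_experiment("DSPy")
- Enable tracing
mlflow.dspy.autolog()
To learn more about the integration, visit the MLflow DSPy documentation.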
Dataset
For this tutorial, we will use the Banking77 dataset.
import dspy
import random
from dspy.datasets import DataLoader
from datasets import load_dataset
# Load the Banking77 dataset.
CLASSES = load_dataset("PolyAI/banking77", split="train", trust_remote_code=True).features['label'].names
kwargs = dict(fields=("text", "label"), input_keys=("text",), split="train", trust_remote_code=True)
# Load the first 1000 examples from the dataset, and replace each integer label with its class name.
raw_data = [
    dspy.Example(x, label=CLASSES[x.label]).with_inputs("text")
    for x in DataLoader().from_huggingface(dataset_name="PolyAI/banking77", **kwargs)[:1000]
]
random.Random(0).shuffle(raw_data)
This dataset has 77 different categories for classification. Let's review some of them.
len(CLASSES), CLASSES[:10]
(77, ['activate_my_card', 'age_limit', 'apple_pay_or_google_pay', 'atm_support', 'automatic_top_up', 'balance_not_updated_after_bank_transfer', 'balance_not_updated_after_cheque_or_cash_deposit', 'beneficiary_not_allowed', 'cancel_transfer', 'card_about_to_expire'])
Let's sample 500 (unlabeled) queries from Banking77. We'll use these for our bootstrapped fine-tuning.
unlabeled_trainset = [dspy.Example(text=x.text).with_inputs("text") for x in raw_data[:500]]
unlabeled_trainset[0]
Example({'text': 'What if there is an error on the exchange rate?'}) (input_keys={'text'})
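DSPy program
Let's say that we want a program that takes the text input, reasons step by step, and then selects one of the classes from Banking77.
Note that this is meant mainly for illustration, or for cases where you want to inspect the model's reasoning (e.g., for a small degree of explainability). In other words, this type of task is not necessarily likely to benefit much from explicit reasoning.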
from typing import Literal
classify = dspy.ChainOfThought(f"text -> label: Literal{CLASSES}")
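To see what the inline signature string actually contains, the sketch below expands the same f-string with a shortened, three-class list. The short list is an assumption made for brevity (the tutorial uses all 77 names), and the snippet uses only plain Python, so it runs without DSPy.

```python
# A shortened stand-in for the 77 Banking77 class names (illustrative only).
CLASSES = ["activate_my_card", "age_limit", "card_arrival"]

# Interpolating a list into an f-string inserts its string form, brackets and
# quotes included, so the result reads like a typing.Literal annotation.
signature = f"text -> label: Literal{CLASSES}"
print(signature)
# prints: text -> label: Literal['activate_my_card', 'age_limit', 'card_arrival']
```

The field name before the colon (label) and the input name before the arrow (text) are the names that later appear on the prediction and in the prompt.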
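Bootstrapped fine-tuning
There are many ways to go about this, e.g. letting the model teach itself, or using inference-time compute (e.g., ensembling) to identify high-confidence cases without labels.
Perhaps the simplest approach is to take a model that we expect to do a reasonable job at this task, use it as a teacher of reasoning and classification, and distill it into our small model. All of these patterns can be expressed in a handful of lines of code.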
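As a rough illustration of the ensembling idea mentioned above, the sketch below keeps a pseudo-label only when enough independent predictions agree. The function name majority_vote, the 0.8 threshold, and the sample predictions are all hypothetical; this is not a DSPy API, just one simple way to estimate confidence without gold labels.

```python
from collections import Counter

def majority_vote(labels, min_agreement=0.8):
    """Return (label, agreement); label is None when agreement is too low."""
    top_label, count = Counter(labels).most_common(1)[0]
    agreement = count / len(labels)
    return (top_label if agreement >= min_agreement else None), agreement

# Five hypothetical predictions for the same query, e.g. from sampled runs.
confident = majority_vote(["card_arrival"] * 4 + ["card_delivery_estimate"])
uncertain = majority_vote(
    ["card_arrival", "card_arrival", "card_delivery_estimate",
     "pending_transfer", "card_linking"]
)
print(confident)  # ('card_arrival', 0.8)
print(uncertain)  # (None, 0.4)
```

Queries that come back with None would simply be left out of the bootstrapped training data.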
Let's set up the tiny Llama-3.2-1B-Instruct as the student LM. We'll use GPT-4o-mini as the teacher LM.
from dspy.clients.lm_local import LocalProvider
student_lm_name = "meta-llama/Llama-3.2-1B-Instruct"
student_lm = dspy.LM(model=f"openai/local:{student_lm_name}", provider=LocalProvider(), max_tokens=2000)
teacher_lm = dspy.LM('openai/gpt-4o-mini', max_tokens=3000)
Now, let's give each LM its own copy of the classifier.
student_classify = classify.deepcopy()
student_classify.set_lm(student_lm)
teacher_classify = classify.deepcopy()
teacher_classify.set_lm(teacher_lm)
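Let's now launch the bootstrapped fine-tuning. The word "bootstrapped" here means that the program itself is invoked on the training inputs, and the resulting traces across all modules are recorded and used for fine-tuning. This is the weight-optimizing variant of the various BootstrapFewShot methods in DSPy.
For every question in the (unlabeled) training set, this invokes the teacher program, which produces reasoning and selects a class. These calls are traced, and the traces then form a training set for every module in the student program (in this case, just the one CoT module).
When the compile method is called, the BootstrapFinetune optimizer uses the passed teacher program (or programs; you can pass a list!) to create a training dataset. It then uses this dataset to create a fine-tuned version of the LM set for the student program, and replaces that LM with the trained one. Note that the trained LM is a new LM instance (the student_lm object we instantiated here remains untouched!).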
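The data flow described above can be sketched in plain Python. Everything here is a simplified stand-in: toy_teacher is a hypothetical keyword rule rather than an LM, and bootstrap_traces is an illustrative helper, not the BootstrapFinetune implementation. The point is only that the teacher's outputs on unlabeled inputs become the student's training pairs.

```python
def bootstrap_traces(teacher, trainset):
    """Run the teacher on every input and record (input, output) pairs."""
    traces = []
    for example in trainset:
        prediction = teacher(example["text"])
        traces.append({"input": example, "output": prediction})
    return traces

def toy_teacher(text):
    # Hypothetical stand-in for teacher_classify: a keyword rule, not an LM.
    label = "card_arrival" if "card" in text else "exchange_rate"
    return {"reasoning": f"Keyword match for {label}.", "label": label}

unlabeled = [
    {"text": "why hasnt my card come in yet?"},
    {"text": "Where can I find your exchange rates?"},
]
finetune_data = bootstrap_traces(toy_teacher, unlabeled)
print([t["output"]["label"] for t in finetune_data])
# prints: ['card_arrival', 'exchange_rate']
```

Each recorded pair contains both the reasoning and the label, which is why the fine-tuned student learns to produce reasoning as well.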
Note: If you have labels, you can pass metric to the constructor of BootstrapFinetune. If you want to apply this in practice, you can pass train_kwargs to the constructor to control the local LM training settings: device, use_peft, num_train_epochs, per_device_train_batch_size, gradient_accumulation_steps, learning_rate, max_seq_length, packing, bf16, and output_dir.
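For concreteness, a train_kwargs dictionary might look like the sketch below. The keys are taken from the list above; every value is an assumption chosen for illustration and should be tuned for your model and hardware. The DSPy call is left commented out so the snippet runs on its own.

```python
# Illustrative values only (assumptions); tune them for your model and GPU.
train_kwargs = {
    "use_peft": True,
    "num_train_epochs": 3,
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "learning_rate": 1e-5,
    "bf16": True,
    "output_dir": "/path/to/dir",
}

# Number of examples contributing to each optimizer step on a single device.
effective_batch = (
    train_kwargs["per_device_train_batch_size"]
    * train_kwargs["gradient_accumulation_steps"]
)
print(effective_batch)  # 16

# With DSPy installed, the dict would be passed to the constructor:
# optimizer = dspy.BootstrapFinetune(num_threads=16, train_kwargs=train_kwargs)
```

Raising gradient_accumulation_steps while lowering the per-device batch size keeps the effective batch the same and reduces peak GPU memory, which can help on a single small GPU.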
# Optional:
# [1] You can set the `DSPY_FINETUNEDIR` environment variable to control which directory is used to store the
# checkpoints and fine-tuning data. If this is not set, `DSPY_CACHEDIR` is used by default.
# [2] You can set the `CUDA_VISIBLE_DEVICES` environment variable to control the GPU that will be used for fine-tuning
# and inference. If this is not set and the default GPU that's used by HuggingFace's `transformers` library is
# occupied, an OutOfMemoryError might be raised.
#
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# os.environ["DSPY_FINETUNEDIR"] = "/path/to/dir"
dspy.settings.experimental = True # fine-tuning is an experimental feature, so we set a flag to enable it
optimizer = dspy.BootstrapFinetune(num_threads=16) # if you *do* have labels, pass metric=your_metric here!
classify_ft = optimizer.compile(student_classify, teacher=teacher_classify, trainset=unlabeled_trainset)
Since this is a local model, we need to launch it explicitly.
classify_ft.get_lm().launch()
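Validating the fine-tuned program
Let's now figure out whether this was successful. We can ask the system one question and inspect its behavior.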
classify_ft(text="I didn't receive my money earlier and it says the transaction is still in progress. Can you fix it?")
Prediction( reasoning='The user is inquiring about a specific issue, which they did not receive and is still showing as a pending transaction. This situation typically indicates a problem with the cash withdrawal process, as the user is not receiving the money they attempted to withdraw. The appropriate label for this scenario is "pending_cash_withdrawal," as it directly relates to the status of the cash withdrawal transaction.', label='pending_cash_withdrawal' )
We can also take a small set of gold labels and see whether the system generalizes to unseen queries.
devset = raw_data[500:600]
devset[0]
Example({'text': 'Which fiat currencies do you currently support? Will this change in this future?', 'label': 'fiat_currency_support'}) (input_keys={'text'})
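Because the shuffle is seeded and the two slices do not overlap, the dev set is reproducible and disjoint from the 500 training queries. The sketch below checks this with integer IDs standing in for the 1000 loaded examples (an assumption for illustration; the slicing mirrors raw_data[:500] and raw_data[500:600]).

```python
import random

raw_ids = list(range(1000))         # stand-in for the 1000 loaded examples
random.Random(0).shuffle(raw_ids)   # same seeded shuffle as in the tutorial

train_ids = raw_ids[:500]
dev_ids = raw_ids[500:600]

print(len(train_ids), len(dev_ids))        # 500 100
print(set(train_ids).isdisjoint(dev_ids))  # True

# Re-running the shuffle with the same seed reproduces the same order.
again = list(range(1000))
random.Random(0).shuffle(again)
print(again == raw_ids)                    # True
```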
Let's define an evaluator on this small dev set, with a metric that ignores the reasoning and only checks that the label is exactly correct.
metric = (lambda x, y, trace=None: x.label == y.label)
evaluate = dspy.Evaluate(devset=devset, metric=metric, display_progress=True, display_table=5, num_threads=16)
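To make the metric's behavior concrete, the sketch below applies the same lambda to plain objects. SimpleNamespace is used here as a stand-in for dspy.Example and dspy.Prediction (an assumption so the snippet runs without DSPy), and the reasoning strings are made up.

```python
from types import SimpleNamespace

# Same metric as in the tutorial: compare labels, ignore everything else.
metric = lambda x, y, trace=None: x.label == y.label

gold = SimpleNamespace(text="why hasnt my card come in yet?", label="card_arrival")
right = SimpleNamespace(reasoning="Asks about delivery status.", label="card_arrival")
wrong = SimpleNamespace(reasoning="Asks about delivery time.", label="card_delivery_estimate")

print(metric(gold, right))  # True
print(metric(gold, wrong))  # False
```

The comparison is an exact string match, so a near-miss such as card_delivery_estimate counts as fully wrong.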
Now, let's evaluate the fine-tuned 1B classifier.
evaluate(classify_ft)
Average Metric: 51.00 / 99 (51.5%): 100%|██████████| 100/100 [00:35<00:00, 2.79it/s]
 | text | example_label | reasoning | pred_label | <lambda> | label |
---|---|---|---|---|---|---|
0 | Which fiat currencies do you currently support? Will this change i... | fiat_currency_support | The user is inquiring about the current support for fiat currencie... | fiat_currency_support | ✔️ [True] | NaN |
1 | I didn't receive my money earlier and it says the transaction is s... | pending_cash_withdrawal | The user is inquiring about a specific issue, which they did not r... | pending_cash_withdrawal | ✔️ [True] | NaN |
2 | what currencies do you accept? | fiat_currency_support | The user is inquiring about the currencies that are accepted, whic... | fiat_currency_support | ✔️ [True] | NaN |
3 | Where can I find your exchange rates? | exchange_rate | The user is inquiring about where to find exchange rates, which re... | exchange_rate | ✔️ [True] | NaN |
4 | why hasnt my card come in yet? | card_arrival | The user is inquiring about the status of their card, which sugges... | card_arrival | ✔️ [True] | NaN |
51.0
Tracking evaluation results in an MLflow Experiment
To track and visualize evaluation results over time, you can record them in an MLflow Experiment.
import mlflow
with mlflow.start_run(run_name="classifier_evaluation"):
    evaluate_correctness = dspy.Evaluate(
        devset=devset,
        metric=metric,
        num_threads=16,
        display_progress=True,
        # To record the outputs and detailed scores to MLflow
        return_all_scores=True,
        return_outputs=True,
    )
    # Evaluate the program as usual
    aggregated_score, outputs, all_scores = evaluate_correctness(classify_ft)
    # Log the aggregated score
    mlflow.log_metric("exact_match", aggregated_score)
    # Log the detailed evaluation results as a table
    mlflow.log_table(
        {
            "Text": [example.text for example in devset],
            "Expected": [example.label for example in devset],
            "Predicted": outputs,
            "Exact match": all_scores,
        },
        artifact_file="eval_results.json",
    )
To learn more about the integration, visit the MLflow DSPy documentation.
Not bad, given that we started with no labels for the task. Even without labels, you can use various strategies to boost the quality of the bootstrapped training data.
To try that next, let's first free our GPU memory by killing the fine-tuned LM.
classify_ft.get_lm().kill()
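Bootstrapped fine-tuning against a metric
If you have labels, you can generally boost performance by a lot. To do so, pass a metric to BootstrapFinetune, which will use it to filter the trajectories over your program before it builds the fine-tuning data.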
optimizer = dspy.BootstrapFinetune(num_threads=16, metric=metric)
classify_ft = optimizer.compile(student_classify, teacher=teacher_classify, trainset=raw_data[:500])
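The filtering step can be sketched as follows. This is a simplified illustration of the idea, not the optimizer's actual code: filter_traces is a hypothetical helper, SimpleNamespace stands in for DSPy's example and prediction objects, and the two traces are made up.

```python
from types import SimpleNamespace

metric = lambda x, y, trace=None: x.label == y.label

def filter_traces(traces, metric):
    """Keep only the teacher traces whose prediction passes the metric."""
    return [t for t in traces if metric(t["example"], t["prediction"])]

traces = [
    {   # Teacher got this one right: kept for fine-tuning.
        "example": SimpleNamespace(text="why hasnt my card come in yet?", label="card_arrival"),
        "prediction": SimpleNamespace(label="card_arrival"),
    },
    {   # Teacher got this one wrong: dropped.
        "example": SimpleNamespace(text="what currencies do you accept?", label="fiat_currency_support"),
        "prediction": SimpleNamespace(label="exchange_rate"),
    },
]

kept = filter_traces(traces, metric)
print(len(kept), kept[0]["prediction"].label)  # 1 card_arrival
```

Dropping the incorrect traces means the student is trained only on reasoning that led to the right label, which is one plausible reason a student can end up more accurate than its teacher.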
Let's now launch and evaluate this.
classify_ft.get_lm().launch()
evaluate(classify_ft)
Average Metric: 85.00 / 98 (86.7%): 100%|██████████| 100/100 [00:46<00:00, 2.14it/s]
 | text | example_label | reasoning | pred_label | <lambda> | label |
---|---|---|---|---|---|---|
0 | Which fiat currencies do you currently support? Will this change i... | fiat_currency_support | The user is inquiring about the fiat currencies currently supporte... | fiat_currency_support | ✔️ [True] | NaN |
1 | I didn't receive my money earlier and it says the transaction is s... | pending_cash_withdrawal | The user is inquiring about an unexpected fee on their account, wh... | extra_charge_on_statement | | NaN |
2 | what currencies do you accept? | fiat_currency_support | The user is inquiring about the types of currencies that are accep... | fiat_currency_support | ✔️ [True] | NaN |
3 | Where can I find your exchange rates? | exchange_rate | The user is inquiring about where to find exchange rates, which re... | exchange_rate | ✔️ [True] | NaN |
4 | why hasnt my card come in yet? | card_arrival | The user is inquiring about the status of their card delivery, whi... | card_arrival | ✔️ [True] | NaN |
85.0
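That's quite a bit better, given just 500 labels. In fact, it appears to be much stronger than the teacher LM is out of the box, as the evaluation below shows!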
evaluate(teacher_classify)
Average Metric: 55.00 / 100 (55.0%): 100%|██████████| 100/100 [00:11<00:00, 8.88it/s]
2025/01/08 12:38:35 INFO dspy.evaluate.evaluate: Average Metric: 55 / 100 (55.0%)
 | text | example_label | reasoning | pred_label | <lambda> |
---|---|---|---|---|---|
0 | Which fiat currencies do you currently support? Will this change i... | fiat_currency_support | The user is inquiring about the fiat currencies supported by the s... | fiat_currency_support | ✔️ [True] |
1 | I didn't receive my money earlier and it says the transaction is s... | pending_cash_withdrawal | The user is experiencing an issue with a transaction that is still... | pending_transfer | |
2 | what currencies do you accept? | fiat_currency_support | The question is asking about the types of currencies accepted, whi... | fiat_currency_support | ✔️ [True] |
3 | Where can I find your exchange rates? | exchange_rate | The user is inquiring about where to find exchange rates, which re... | exchange_rate | ✔️ [True] |
4 | why hasnt my card come in yet? | card_arrival | The user is inquiring about the status of their card delivery, whi... | card_delivery_estimate | |
55.0
Thanks to bootstrapping, the model learns to apply our modules to get the right label, in this case by reasoning explicitly:
classify_ft(text="why hasnt my card come in yet?")
dspy.inspect_history()
[2025-01-08T12:39:42.143798] System message: Your input fields are: 1. `text` (str) Your output fields are: 1. `reasoning` (str) 2. `label` (Literal[activate_my_card, age_limit, apple_pay_or_google_pay, atm_support, automatic_top_up, balance_not_updated_after_bank_transfer, balance_not_updated_after_cheque_or_cash_deposit, beneficiary_not_allowed, cancel_transfer, card_about_to_expire, card_acceptance, card_arrival, card_delivery_estimate, card_linking, card_not_working, card_payment_fee_charged, card_payment_not_recognised, card_payment_wrong_exchange_rate, card_swallowed, cash_withdrawal_charge, cash_withdrawal_not_recognised, change_pin, compromised_card, contactless_not_working, country_support, declined_card_payment, declined_cash_withdrawal, declined_transfer, direct_debit_payment_not_recognised, disposable_card_limits, edit_personal_details, exchange_charge, exchange_rate, exchange_via_app, extra_charge_on_statement, failed_transfer, fiat_currency_support, get_disposable_virtual_card, get_physical_card, getting_spare_card, getting_virtual_card, lost_or_stolen_card, lost_or_stolen_phone, order_physical_card, passcode_forgotten, pending_card_payment, pending_cash_withdrawal, pending_top_up, pending_transfer, pin_blocked, receiving_money, Refund_not_showing_up, request_refund, reverted_card_payment?, supported_cards_and_currencies, terminate_account, top_up_by_bank_transfer_charge, top_up_by_card_charge, top_up_by_cash_or_cheque, top_up_failed, top_up_limits, top_up_reverted, topping_up_by_card, transaction_charged_twice, transfer_fee_charged, transfer_into_account, transfer_not_received_by_recipient, transfer_timing, unable_to_verify_identity, verify_my_identity, verify_source_of_funds, verify_top_up, virtual_card_not_working, visa_or_mastercard, why_verify_identity, wrong_amount_of_cash_received, wrong_exchange_rate_for_cash_withdrawal]) All interactions will be structured in the following way, with the appropriate values filled in. 
[[ ## text ## ]] {text} [[ ## reasoning ## ]] {reasoning} [[ ## label ## ]] {label} # note: the value you produce must be one of: activate_my_card; age_limit; apple_pay_or_google_pay; atm_support; automatic_top_up; balance_not_updated_after_bank_transfer; balance_not_updated_after_cheque_or_cash_deposit; beneficiary_not_allowed; cancel_transfer; card_about_to_expire; card_acceptance; card_arrival; card_delivery_estimate; card_linking; card_not_working; card_payment_fee_charged; card_payment_not_recognised; card_payment_wrong_exchange_rate; card_swallowed; cash_withdrawal_charge; cash_withdrawal_not_recognised; change_pin; compromised_card; contactless_not_working; country_support; declined_card_payment; declined_cash_withdrawal; declined_transfer; direct_debit_payment_not_recognised; disposable_card_limits; edit_personal_details; exchange_charge; exchange_rate; exchange_via_app; extra_charge_on_statement; failed_transfer; fiat_currency_support; get_disposable_virtual_card; get_physical_card; getting_spare_card; getting_virtual_card; lost_or_stolen_card; lost_or_stolen_phone; order_physical_card; passcode_forgotten; pending_card_payment; pending_cash_withdrawal; pending_top_up; pending_transfer; pin_blocked; receiving_money; Refund_not_showing_up; request_refund; reverted_card_payment?; supported_cards_and_currencies; terminate_account; top_up_by_bank_transfer_charge; top_up_by_card_charge; top_up_by_cash_or_cheque; top_up_failed; top_up_limits; top_up_reverted; topping_up_by_card; transaction_charged_twice; transfer_fee_charged; transfer_into_account; transfer_not_received_by_recipient; transfer_timing; unable_to_verify_identity; verify_my_identity; verify_source_of_funds; verify_top_up; virtual_card_not_working; visa_or_mastercard; why_verify_identity; wrong_amount_of_cash_received; wrong_exchange_rate_for_cash_withdrawal [[ ## completed ## ]] In adhering to this structure, your objective is: Given the fields `text`, produce the fields `label`. 
User message: [[ ## text ## ]] why hasnt my card come in yet? Respond with the corresponding output fields, starting with the field `[[ ## reasoning ## ]]`, then `[[ ## label ## ]]` (must be formatted as a valid Python Literal[activate_my_card, age_limit, apple_pay_or_google_pay, atm_support, automatic_top_up, balance_not_updated_after_bank_transfer, balance_not_updated_after_cheque_or_cash_deposit, beneficiary_not_allowed, cancel_transfer, card_about_to_expire, card_acceptance, card_arrival, card_delivery_estimate, card_linking, card_not_working, card_payment_fee_charged, card_payment_not_recognised, card_payment_wrong_exchange_rate, card_swallowed, cash_withdrawal_charge, cash_withdrawal_not_recognised, change_pin, compromised_card, contactless_not_working, country_support, declined_card_payment, declined_cash_withdrawal, declined_transfer, direct_debit_payment_not_recognised, disposable_card_limits, edit_personal_details, exchange_charge, exchange_rate, exchange_via_app, extra_charge_on_statement, failed_transfer, fiat_currency_support, get_disposable_virtual_card, get_physical_card, getting_spare_card, getting_virtual_card, lost_or_stolen_card, lost_or_stolen_phone, order_physical_card, passcode_forgotten, pending_card_payment, pending_cash_withdrawal, pending_top_up, pending_transfer, pin_blocked, receiving_money, Refund_not_showing_up, request_refund, reverted_card_payment?, supported_cards_and_currencies, terminate_account, top_up_by_bank_transfer_charge, top_up_by_card_charge, top_up_by_cash_or_cheque, top_up_failed, top_up_limits, top_up_reverted, topping_up_by_card, transaction_charged_twice, transfer_fee_charged, transfer_into_account, transfer_not_received_by_recipient, transfer_timing, unable_to_verify_identity, verify_my_identity, verify_source_of_funds, verify_top_up, virtual_card_not_working, visa_or_mastercard, why_verify_identity, wrong_amount_of_cash_received, wrong_exchange_rate_for_cash_withdrawal]), and then ending with the marker for `[[ ## 
completed ## ]]`. Response: [[ ## reasoning ## ]] The user is inquiring about the status of their card delivery, which suggests they are concerned about when they will receive their card. This aligns with the topic of card arrival and delivery estimates. [[ ## label ## ]] card_arrival [[ ## completed ## ]]
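The response in the history above uses `[[ ## name ## ]]` markers to delimit fields. As a rough sketch of how such a completion can be split back into fields, the snippet below uses a regular expression. The helper parse_fields is hypothetical and is not DSPy's own parser; the sample response is shortened from the one shown above.

```python
import re

def parse_fields(completion):
    """Split a completion on `[[ ## name ## ]]` markers into a dict."""
    parts = re.split(r"\[\[ ## (\w+) ## \]\]", completion)
    # parts looks like [prefix, name1, body1, name2, body2, ...]
    fields = {}
    for name, body in zip(parts[1::2], parts[2::2]):
        if name != "completed":  # the end marker carries no content
            fields[name] = body.strip()
    return fields

response = (
    "[[ ## reasoning ## ]] The user is inquiring about the status of their card delivery. "
    "[[ ## label ## ]] card_arrival [[ ## completed ## ]]"
)
fields = parse_fields(response)
print(fields["label"])  # card_arrival
```

Because the fine-tuned model was trained on traces in this format, it reliably emits both markers, which is what lets the label be extracted and checked against the allowed class names.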
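Saving fine-tuned programs in an MLflow Experiment
To deploy the fine-tuned program in production or share it with your team, you can save it in an MLflow Experiment. Compared to simply saving it to a local file, MLflow offers the following benefits:
- Dependency management: MLflow automatically saves the frozen environment metadata along with the program to ensure reproducibility.
- Experiment tracking: With MLflow, you can track the program's performance and cost alongside the program itself.
- Collaboration: You can share the program and results with your team members by sharing the MLflow Experiment.
To save the program in MLflow, run the following code:
import mlflow
# Start an MLflow Run and save the program
with mlflow.start_run(run_name="optimized_classifier"):
    model_info = mlflow.dspy.log_model(
        classify_ft,
        artifact_path="model", # Any name to save the program in MLflow
    )
# Load the program back from MLflow
loaded = mlflow.dspy.load_model(model_info.model_uri)
To learn more about the integration, visit the MLflow DSPy documentation.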