dspy.KNNFewShot

dspy.KNNFewShot(k: int, trainset: list[Example], vectorizer: Embedder, **few_shot_bootstrap_args)

Bases: Teleprompter

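KNNFewShot is an optimizer that uses an in-memory KNN retriever to find the k nearest neighbors in a trainset at test time. For each input example in a forward call, it identifies the k most similar examples from the trainset and attaches them as demonstrations to the student module.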
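
The retrieval step described above can be illustrated with a small standalone sketch. This is not the library's KNN implementation: it assumes cosine similarity as the ranking measure and uses hand-written 2-D vectors in place of real embeddings. The helper names `cosine_similarity` and `k_nearest` are hypothetical.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def k_nearest(query_vec, train_vecs, k):
    """Return the indices of the k training vectors most similar to the query."""
    scored = [(cosine_similarity(query_vec, v), i) for i, v in enumerate(train_vecs)]
    scored.sort(key=lambda pair: pair[0], reverse=True)  # most similar first
    return [i for _, i in scored[:k]]


# Toy 2-D "embeddings" standing in for the output of a real embedder (assumption).
train_vecs = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]]
query_vec = [1.0, 0.05]

nearest = k_nearest(query_vec, train_vecs, k=2)
print(nearest)  # indices of the two closest training vectors
```

In the optimizer, the selected training examples (rather than bare indices) would then be used as demonstrations for the student module.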

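Parameters

Name Type Description Default
k int

The number of nearest neighbors to attach to the student model.

required
trainset list[Example]

The training set to use for few-shot prompting.

required
vectorizer Embedder

The Embedder to use for vectorization.

required
**few_shot_bootstrap_args

Additional arguments for the BootstrapFewShot optimizer.

{}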
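
As a rough illustration of how the `**few_shot_bootstrap_args` parameter behaves, the sketch below uses two hypothetical stand-in classes (`StubBootstrap` and `StubKNNOptimizer`, neither part of the library) to show extra keyword arguments being captured as a dict and unpacked later into an inner optimizer. The option name `max_labeled_demos` is used only as an example key.

```python
class StubBootstrap:
    """Hypothetical stand-in for an inner optimizer; it records its keyword arguments."""

    def __init__(self, **options):
        self.options = options


class StubKNNOptimizer:
    """Hypothetical stand-in mirroring the constructor shape documented above."""

    def __init__(self, k, trainset, vectorizer, **few_shot_bootstrap_args):
        self.k = k
        self.trainset = trainset
        self.vectorizer = vectorizer
        # Extra keyword arguments are captured as a plain dict ...
        self.few_shot_bootstrap_args = few_shot_bootstrap_args

    def make_inner(self):
        # ... and unpacked again when the inner optimizer is constructed.
        return StubBootstrap(**self.few_shot_bootstrap_args)


opt = StubKNNOptimizer(k=3, trainset=[], vectorizer=None, max_labeled_demos=4)
inner = opt.make_inner()
print(opt.few_shot_bootstrap_args)
print(inner.options)

# With no extra keyword arguments, the captured dict is empty (the documented default).
default_opt = StubKNNOptimizer(k=1, trainset=[], vectorizer=None)
print(default_opt.few_shot_bootstrap_args)
```
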
Example
import dspy
from sentence_transformers import SentenceTransformer

# Define a QA module with chain of thought
qa = dspy.ChainOfThought("question -> answer")

# Create a training dataset with examples
trainset = [
    dspy.Example(question="What is the capital of France?", answer="Paris").with_inputs("question"),
    # ... more examples ...
]

# Initialize KNNFewShot with a sentence transformer model
knn_few_shot = dspy.KNNFewShot(
    k=3,
    trainset=trainset,
    vectorizer=dspy.Embedder(SentenceTransformer("all-MiniLM-L6-v2").encode)
)

# Compile the QA module with few-shot learning
compiled_qa = knn_few_shot.compile(qa)

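# Use the compiled module (inputs are passed as keyword arguments;
# an LM must already be configured, e.g. via dspy.configure)
result = compiled_qa(question="What is the capital of Belgium?")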
Source code in dspy/teleprompt/knn_fewshot.py
def __init__(self, k: int, trainset: list[Example], vectorizer: Embedder, **few_shot_bootstrap_args):
    """
    KNNFewShot is an optimizer that uses an in-memory KNN retriever to find the k nearest neighbors
    in a trainset at test time. For each input example in a forward call, it identifies the k most
    similar examples from the trainset and attaches them as demonstrations to the student module.

    Args:
        k: The number of nearest neighbors to attach to the student model.
        trainset: The training set to use for few-shot prompting.
        vectorizer: The `Embedder` to use for vectorization
        **few_shot_bootstrap_args: Additional arguments for the `BootstrapFewShot` optimizer.

    Example:
        ```python
        import dspy
        from sentence_transformers import SentenceTransformer

        # Define a QA module with chain of thought
        qa = dspy.ChainOfThought("question -> answer")

        # Create a training dataset with examples
        trainset = [
            dspy.Example(question="What is the capital of France?", answer="Paris").with_inputs("question"),
            # ... more examples ...
        ]

        # Initialize KNNFewShot with a sentence transformer model
        knn_few_shot = KNNFewShot(
            k=3,
            trainset=trainset,
            vectorizer=dspy.Embedder(SentenceTransformer("all-MiniLM-L6-v2").encode)
        )

        # Compile the QA module with few-shot learning
        compiled_qa = knn_few_shot.compile(qa)

        # Use the compiled module
        result = compiled_qa("What is the capital of Belgium?")
        ```
    """
    self.KNN = KNN(k, trainset, vectorizer=vectorizer)
    self.few_shot_bootstrap_args = few_shot_bootstrap_args

Functions

compile(student, *, teacher=None)

Source code in dspy/teleprompt/knn_fewshot.py
def compile(self, student, *, teacher=None):
    student_copy = student.reset_copy()

    def forward_pass(_, **kwargs):
        knn_trainset = self.KNN(**kwargs)
        few_shot_bootstrap = BootstrapFewShot(**self.few_shot_bootstrap_args)
        compiled_program = few_shot_bootstrap.compile(
            student,
            teacher=teacher,
            trainset=knn_trainset,
        )
        return compiled_program(**kwargs)

    student_copy.forward = types.MethodType(forward_pass, student_copy)
    return student_copy
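
The listing above replaces `forward` on a copy of the student by binding a closure with `types.MethodType`. Because retrieval and the `BootstrapFewShot` compile step sit inside `forward_pass`, they run on every call. The sketch below shows only the rebinding pattern, using a hypothetical `Echo` class and `wrap_forward` helper that are not part of the library.

```python
import types


class Echo:
    """Hypothetical module with a default forward method."""

    def forward(self, **kwargs):
        return {"source": "original", **kwargs}

    def __call__(self, **kwargs):
        return self.forward(**kwargs)


def wrap_forward(module, calls):
    # Replacement forward; the first parameter receives the bound instance.
    def forward_pass(_, **kwargs):
        calls.append(kwargs)  # stands in for per-call work such as retrieval
        return {"source": "patched", **kwargs}

    # Bind the function to this one instance; the class itself is untouched.
    module.forward = types.MethodType(forward_pass, module)
    return module


calls = []
patched = wrap_forward(Echo(), calls)
untouched = Echo()

result = patched(question="hi")
print(result)
print(untouched(question="hi"))
```

Since the replacement accepts only keyword arguments after the instance, a module patched this way has to be called with named inputs.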

get_params() -> dict[str, Any]

Get the parameters of the teleprompter.

Returns

Type Description
dict[str, Any]

The parameters of the teleprompter.

Source code in dspy/teleprompt/teleprompt.py
def get_params(self) -> dict[str, Any]:
    """
    Get the parameters of the teleprompter.

    Returns:
        The parameters of the teleprompter.
    """
    return self.__dict__
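
Since `get_params` returns `self.__dict__`, the result is the instance's own attribute dictionary rather than a copy. Going by the `__init__` listing earlier on this page, a `KNNFewShot` instance would presumably expose the keys `KNN` and `few_shot_bootstrap_args`. The sketch below demonstrates the aliasing behavior with a hypothetical `StubTeleprompter` class that reuses the same method body.

```python
from typing import Any


class StubTeleprompter:
    """Hypothetical class with the same get_params body as the listing above."""

    def __init__(self, k: int, **extra: Any):
        self.k = k
        self.extra = extra

    def get_params(self) -> dict[str, Any]:
        return self.__dict__


tp = StubTeleprompter(k=3, max_rounds=1)
params = tp.get_params()
print(params)

# The returned dict is the live attribute dict, so mutating it changes the instance.
params["k"] = 5
print(tp.k)

# Take a shallow copy when an independent snapshot is wanted.
snapshot = dict(tp.get_params())
snapshot["k"] = 7
print(tp.k)
```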