跳到内容

dspy.Embedder

dspy.Embedder(model, batch_size=200, caching=True, **kwargs)

DSPy 嵌入类。

用于计算文本输入嵌入的类。此类为以下两者提供了统一接口:

  1. 通过 litellm 集成托管的嵌入模型(例如 OpenAI 的 text-embedding-3-small)
  2. 您提供的自定义嵌入函数

对于托管模型,只需将模型名称作为字符串传递(例如,“openai/text-embedding-3-small”)。此类将使用 litellm 处理 API 调用和缓存。

对于自定义嵌入模型,请传递一个可调用函数,该函数: - 将字符串列表作为输入。 - 返回以下形式的嵌入: - 一个 float32 值的二维 numpy 数组 - 一个 float32 值的二维列表 - 每行应代表一个嵌入向量

参数

名称 类型 描述 默认值
model

要使用的嵌入模型。这可以是一个字符串(表示托管嵌入模型的名称,必须是 litellm 支持的嵌入模型)或一个表示自定义嵌入模型的可调用对象。

必需
batch_size int

分批处理输入的默认批处理大小。默认为 200。

200
caching bool

使用托管模型时是否缓存嵌入响应。默认为 True。

True
**kwargs

要传递给嵌入模型的附加默认关键字参数。

{}

示例

示例 1:使用托管模型。

import dspy

embedder = dspy.Embedder("openai/text-embedding-3-small", batch_size=100)
embeddings = embedder(["hello", "world"])

assert embeddings.shape == (2, 1536)

示例 2:使用任何本地嵌入模型,例如来自 https://hugging-face.cn/models?library=sentence-transformers 的模型。

# pip install sentence_transformers
import dspy
from sentence_transformers import SentenceTransformer

# Load an extremely efficient local model for retrieval
model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")

embedder = dspy.Embedder(model.encode)
embeddings = embedder(["hello", "world"], batch_size=1)

assert embeddings.shape == (2, 1024)

示例 3:使用自定义函数。

import dspy
import numpy as np

def my_embedder(texts):
    return np.random.rand(len(texts), 10)

embedder = dspy.Embedder(my_embedder)
embeddings = embedder(["hello", "world"], batch_size=1)

assert embeddings.shape == (2, 10)
源代码位于 dspy/clients/embedding.py
def __init__(self, model, batch_size=200, caching=True, **kwargs):
    self.model = model
    self.batch_size = batch_size
    self.caching = caching
    self.default_kwargs = kwargs

函数

__call__(inputs, batch_size=None, caching=None, **kwargs)

计算给定输入的嵌入。

参数

名称 类型 描述 默认值
inputs

要计算嵌入的输入,可以是单个字符串或字符串列表。

必需
batch_size int

处理输入的批处理大小。如果为 None,则默认为初始化时设置的 batch_size。

None
caching bool

使用托管模型时是否缓存嵌入响应。如果为 None,则默认为初始化时的缓存设置。

None
**kwargs

要传递给嵌入模型的附加关键字参数。这些将覆盖初始化时提供的默认 kwargs。

{}

返回值

类型 描述

numpy.ndarray:如果输入是单个字符串,则返回表示嵌入的一维 numpy 数组。

如果输入是字符串列表,则返回嵌入的二维 numpy 数组,每行一个嵌入。

源代码位于 dspy/clients/embedding.py
def __call__(self, inputs, batch_size=None, caching=None, **kwargs):
    """Compute embeddings for the given inputs.

    Args:
        inputs: The inputs to compute embeddings for, can be a single string or a list of strings.
        batch_size (int, optional): The batch size for processing inputs. If None, defaults to the batch_size set
            during initialization.
        caching (bool, optional): Whether to cache the embedding response when using a hosted model. If None,
            defaults to the caching setting from initialization.
        **kwargs: Additional keyword arguments to pass to the embedding model. These will override the default
            kwargs provided during initialization.

    Returns:
        numpy.ndarray: If the input is a single string, returns a 1D numpy array representing the embedding.
        If the input is a list of strings, returns a 2D numpy array of embeddings, one embedding per row.
    """
    input_batches, caching, kwargs, is_single_input = self._preprocess(inputs, batch_size, caching, **kwargs)

    compute_embeddings = _cached_compute_embeddings if caching else _compute_embeddings

    embeddings_list = []

    for batch in input_batches:
        embeddings_list.extend(compute_embeddings(self.model, batch, caching=caching, **kwargs))
    return self._postprocess(embeddings_list, is_single_input)

acall(inputs, batch_size=None, caching=None, **kwargs) 异步

源代码位于 dspy/clients/embedding.py
async def acall(self, inputs, batch_size=None, caching=None, **kwargs):
    input_batches, caching, kwargs, is_single_input = self._preprocess(inputs, batch_size, caching, **kwargs)

    embeddings_list = []
    acompute_embeddings = _cached_acompute_embeddings if caching else _acompute_embeddings

    for batch in input_batches:
        embeddings_list.extend(await acompute_embeddings(self.model, batch, caching=caching, **kwargs))
    return self._postprocess(embeddings_list, is_single_input)