dspy.BestOfN

dspy.BestOfN(module: Module, N: int, reward_fn: Callable[[dict, Prediction], float], threshold: float, fail_count: Optional[int] = None)

Bases: Module

Runs a module up to N times with different temperatures and returns the best prediction out of N attempts, or the first prediction that passes the threshold.

Parameters

module (Module): The module to run. Required.

N (int): The number of times to run the module. Required.

reward_fn (Callable[[dict, Prediction], float]): The reward function, which takes in the args passed to the module and the resulting prediction, and returns a scalar reward. Required.

threshold (float): The threshold for the reward function. Required.

fail_count (Optional[int]): The number of times the module can fail before raising an error. Defaults to N if not provided. Default: None.
Example
import dspy

dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))

# Define a QA module with chain of thought
qa = dspy.ChainOfThought("question -> answer")

# Define a reward function that checks for one-word answers
def one_word_answer(args, pred):
    return 1.0 if len(pred.answer.split()) == 1 else 0.0

# Create a refined module that tries up to 3 times
best_of_3 = dspy.BestOfN(module=qa, N=3, reward_fn=one_word_answer, threshold=1.0)

# Use the refined module
result = best_of_3(question="What is the capital of Belgium?").answer
# Returns: Brussels
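
As a minimal variation (a sketch reusing the qa module and one_word_answer reward from above), fail_count can be set below N so that errors are raised sooner:

# Tolerate at most one failed attempt before re-raising the error
strict_best_of_3 = dspy.BestOfN(
    module=qa, N=3, reward_fn=one_word_answer, threshold=1.0, fail_count=1
)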
Source code in dspy/predict/best_of_n.py
def __init__(
    self,
    module: Module,
    N: int,  # noqa: N803
    reward_fn: Callable[[dict, Prediction], float],
    threshold: float,
    fail_count: Optional[int] = None,
):
    """
    Runs a module up to `N` times with different temperatures and returns the best prediction
    out of `N` attempts or the first prediction that passes the `threshold`.

    Args:
        module (Module): The module to run.
        N (int): The number of times to run the module.
        reward_fn (Callable[[dict, Prediction], float]): The reward function which takes in the args passed to the module, the resulting prediction, and returns a scalar reward.
        threshold (float): The threshold for the reward function.
        fail_count (Optional[int], optional): The number of times the module can fail before raising an error. Defaults to N if not provided.

    Example:
        ```python
        import dspy

        dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))

        # Define a QA module with chain of thought
        qa = dspy.ChainOfThought("question -> answer")

        # Define a reward function that checks for one-word answers
        def one_word_answer(args, pred):
            return 1.0 if len(pred.answer.split()) == 1 else 0.0

        # Create a refined module that tries up to 3 times
        best_of_3 = dspy.BestOfN(module=qa, N=3, reward_fn=one_word_answer, threshold=1.0)

        # Use the refined module
        result = best_of_3(question="What is the capital of Belgium?").answer
        # Returns: Brussels
        ```
    """
    self.module = module
    self.reward_fn = lambda *args: reward_fn(*args)  # to prevent this from becoming a parameter
    self.threshold = threshold
    self.N = N
    self.fail_count = fail_count or N  # default to N if fail_count is not provided

Functions

__call__(*args, **kwargs)

Source code in dspy/primitives/program.py
@with_callbacks
def __call__(self, *args, **kwargs):
    if settings.track_usage and settings.usage_tracker is None:
        with track_usage() as usage_tracker:
            output = self.forward(*args, **kwargs)
            output.set_lm_usage(usage_tracker.get_total_tokens())
            return output

    return self.forward(*args, **kwargs)

acall(*args, **kwargs) async

Source code in dspy/primitives/program.py
@with_callbacks
async def acall(self, *args, **kwargs):
    if settings.track_usage and settings.usage_tracker is None:
        with track_usage() as usage_tracker:
            output = await self.aforward(*args, **kwargs)
            output.set_lm_usage(usage_tracker.get_total_tokens())
            return output

    return await self.aforward(*args, **kwargs)

batch(examples, num_threads: Optional[int] = None, max_errors: int = 10, return_failed_examples: bool = False, provide_traceback: Optional[bool] = None, disable_progress_bar: bool = False)

Processes a list of dspy.Example instances in parallel using the Parallel module.

:param examples: List of dspy.Example instances to process.
:param num_threads: Number of threads to use for parallel processing.
:param max_errors: Maximum number of errors allowed before stopping execution.
:param return_failed_examples: Whether to return failed examples and exceptions.
:param provide_traceback: Whether to include traceback information in error logs.
:return: List of results, and optionally failed examples and exceptions.
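
A minimal usage sketch, assuming the best_of_3 module and LM configuration from the example above; the questions are hypothetical:

# Build dspy.Example inputs and run them in parallel through the module
examples = [
    dspy.Example(question="What is the capital of Belgium?").with_inputs("question"),
    dspy.Example(question="What is the capital of France?").with_inputs("question"),
]
results = best_of_3.batch(examples, num_threads=2)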

Source code in dspy/primitives/program.py
def batch(
    self,
    examples,
    num_threads: Optional[int] = None,
    max_errors: int = 10,
    return_failed_examples: bool = False,
    provide_traceback: Optional[bool] = None,
    disable_progress_bar: bool = False,
):
    """
    Processes a list of dspy.Example instances in parallel using the Parallel module.

    :param examples: List of dspy.Example instances to process.
    :param num_threads: Number of threads to use for parallel processing.
    :param max_errors: Maximum number of errors allowed before stopping execution.
    :param return_failed_examples: Whether to return failed examples and exceptions.
    :param provide_traceback: Whether to include traceback information in error logs.
    :return: List of results, and optionally failed examples and exceptions.
    """
    # Create a list of execution pairs (self, example)
    exec_pairs = [(self, example.inputs()) for example in examples]

    # Create an instance of Parallel
    parallel_executor = Parallel(
        num_threads=num_threads,
        max_errors=max_errors,
        return_failed_examples=return_failed_examples,
        provide_traceback=provide_traceback,
        disable_progress_bar=disable_progress_bar,
    )

    # Execute the forward method of Parallel
    if return_failed_examples:
        results, failed_examples, exceptions = parallel_executor.forward(exec_pairs)
        return results, failed_examples, exceptions
    else:
        results = parallel_executor.forward(exec_pairs)
        return results

deepcopy()

Deep copy the module.

This is a tweak to the default Python deepcopy that only deep copies self.parameters(); for other attributes, we just do a shallow copy.
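
A minimal sketch, assuming the qa module from the example above:

# Copy the module; changes to the copy's parameters do not affect the original
qa_copy = qa.deepcopy()
assert qa_copy is not qa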

Source code in dspy/primitives/module.py
def deepcopy(self):
    """Deep copy the module.

    This is a tweak to the default python deepcopy that only deep copies `self.parameters()`, and for other
    attributes, we just do the shallow copy.
    """
    try:
        # If the instance itself is copyable, we can just deep copy it.
        # Otherwise we will have to create a new instance and copy over the attributes one by one.
        return copy.deepcopy(self)
    except Exception:
        pass

    # Create an empty instance.
    new_instance = self.__class__.__new__(self.__class__)
    # Set attributes of the copied instance.
    for attr, value in self.__dict__.items():
        if isinstance(value, BaseModule):
            setattr(new_instance, attr, value.deepcopy())
        else:
            try:
                # Try to deep copy the attribute
                setattr(new_instance, attr, copy.deepcopy(value))
            except Exception:
                logging.warning(
                    f"Failed to deep copy attribute '{attr}' of {self.__class__.__name__}, "
                    "falling back to shallow copy or reference copy."
                )
                try:
                    # Fallback to shallow copy if deep copy fails
                    setattr(new_instance, attr, copy.copy(value))
                except Exception:
                    # If even the shallow copy fails, we just copy over the reference.
                    setattr(new_instance, attr, value)

    return new_instance

dump_state()

Source code in dspy/primitives/module.py
def dump_state(self):
    return {name: param.dump_state() for name, param in self.named_parameters()}

forward(**kwargs)

Source code in dspy/predict/best_of_n.py
def forward(self, **kwargs):
    lm = self.module.get_lm() or dspy.settings.lm
    temps = [lm.kwargs["temperature"]] + [0.5 + i * (0.5 / self.N) for i in range(self.N)]
    temps = list(dict.fromkeys(temps))[: self.N]
    best_pred, best_trace, best_reward = None, None, -float("inf")

    for idx, t in enumerate(temps):
        lm_ = lm.copy(temperature=t)
        mod = self.module.deepcopy()
        mod.set_lm(lm_)

        try:
            with dspy.context(trace=[]):
                pred = mod(**kwargs)
                trace = dspy.settings.trace.copy()

                # NOTE: Not including the trace of reward_fn.
                reward = self.reward_fn(kwargs, pred)

            if reward > best_reward:
                best_reward, best_pred, best_trace = reward, pred, trace

            if reward >= self.threshold:
                break

        except Exception as e:
            print(f"BestOfN: Attempt {idx + 1} failed with temperature {t}: {e}")
            if idx > self.fail_count:
                raise e
            self.fail_count -= 1

    if best_trace:
        dspy.settings.trace.extend(best_trace)
    return best_pred

get_lm()

Source code in dspy/primitives/program.py
def get_lm(self):
    all_used_lms = [param.lm for _, param in self.named_predictors()]

    if len(set(all_used_lms)) == 1:
        return all_used_lms[0]

    raise ValueError("Multiple LMs are being used in the module. There's no unique LM to return.")

load(path)

Load the saved module. You may also want to check out dspy.load if you want to load an entire program, not just the state for an existing program.

Parameters

path (str): Path to the saved state file, which should be a .json or a .pkl file. Required.
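
A minimal sketch; "qa_state.json" is a hypothetical path, assumed to have been produced by a matching save() call on a module with the same architecture:

# Rebuild the program, then load its saved state
qa = dspy.ChainOfThought("question -> answer")
qa.load("qa_state.json")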
Source code in dspy/primitives/module.py
def load(self, path):
    """Load the saved module. You may also want to check out dspy.load, if you want to
    load an entire program, not just the state for an existing program.

    Args:
        path (str): Path to the saved state file, which should be a .json or a .pkl file
    """
    path = Path(path)

    if path.suffix == ".json":
        with open(path) as f:
            state = ujson.loads(f.read())
    elif path.suffix == ".pkl":
        with open(path, "rb") as f:
            state = cloudpickle.load(f)
    else:
        raise ValueError(f"`path` must end with `.json` or `.pkl`, but received: {path}")

    dependency_versions = get_dependency_versions()
    saved_dependency_versions = state["metadata"]["dependency_versions"]
    for key, saved_version in saved_dependency_versions.items():
        if dependency_versions[key] != saved_version:
            logger.warning(
                f"There is a mismatch of {key} version between saved model and current environment. "
                f"You saved with `{key}=={saved_version}`, but now you have "
                f"`{key}=={dependency_versions[key]}`. This might cause errors or performance downgrade "
                "on the loaded model, please consider loading the model in the same environment as the "
                "saving environment."
            )
    self.load_state(state)

load_state(state)

Source code in dspy/primitives/module.py
def load_state(self, state):
    for name, param in self.named_parameters():
        param.load_state(state[name])

map_named_predictors(func)

Applies a function to all named predictors.
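
A minimal sketch, assuming the best_of_3 module from the example above; the attach_lm helper is hypothetical:

# Assign an LM to each named predictor; the function must return the predictor,
# since its return value replaces the attribute
def attach_lm(predictor):
    predictor.lm = dspy.LM("openai/gpt-4o-mini")
    return predictor

best_of_3.map_named_predictors(attach_lm)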

Source code in dspy/primitives/program.py
def map_named_predictors(self, func):
    """Applies a function to all named predictors."""
    for name, predictor in self.named_predictors():
        set_attribute_by_name(self, name, func(predictor))
    return self

named_parameters()

Unlike PyTorch, this also handles (non-recursive) lists of parameters.

Source code in dspy/primitives/module.py
def named_parameters(self):
    """
    Unlike PyTorch, handles (non-recursive) lists of parameters too.
    """

    import dspy
    from dspy.predict.parameter import Parameter

    visited = set()
    named_parameters = []

    def add_parameter(param_name, param_value):
        if isinstance(param_value, Parameter):
            if id(param_value) not in visited:
                visited.add(id(param_value))
                param_name = postprocess_parameter_name(param_name, param_value)
                named_parameters.append((param_name, param_value))

        elif isinstance(param_value, dspy.Module):
            # When a sub-module is pre-compiled, keep it frozen.
            if not getattr(param_value, "_compiled", False):
                for sub_name, param in param_value.named_parameters():
                    add_parameter(f"{param_name}.{sub_name}", param)

    if isinstance(self, Parameter):
        add_parameter("self", self)

    for name, value in self.__dict__.items():
        if isinstance(value, Parameter):
            add_parameter(name, value)

        elif isinstance(value, dspy.Module):
            # When a sub-module is pre-compiled, keep it frozen.
            if not getattr(value, "_compiled", False):
                for sub_name, param in value.named_parameters():
                    add_parameter(f"{name}.{sub_name}", param)

        elif isinstance(value, (list, tuple)):
            for idx, item in enumerate(value):
                add_parameter(f"{name}[{idx}]", item)

        elif isinstance(value, dict):
            for key, item in value.items():
                add_parameter(f"{name}['{key}']", item)

    return named_parameters

named_predictors()

Source code in dspy/primitives/program.py
def named_predictors(self):
    from dspy.predict.predict import Predict

    return [(name, param) for name, param in self.named_parameters() if isinstance(param, Predict)]

named_sub_modules(type_=None, skip_compiled=False) -> Generator[tuple[str, BaseModule], None, None]

Find all sub-modules in the module, as well as their names.

Say self.children[4]['key'].sub_module is a sub-module. Then its name will be 'children[4][key].sub_module'. But if a sub-module is accessible via different paths, only one of those paths is returned.
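
A minimal sketch, assuming the best_of_3 module from the example above:

# Enumerate all sub-modules and their dotted/indexed names
for name, sub_module in best_of_3.named_sub_modules():
    print(name, type(sub_module).__name__)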

Source code in dspy/primitives/module.py
def named_sub_modules(self, type_=None, skip_compiled=False) -> Generator[tuple[str, "BaseModule"], None, None]:
    """Find all sub-modules in the module, as well as their names.

    Say self.children[4]['key'].sub_module is a sub-module. Then the name will be
    'children[4][key].sub_module'. But if the sub-module is accessible at different
    paths, only one of the paths will be returned.
    """
    if type_ is None:
        type_ = BaseModule

    queue = deque([("self", self)])
    seen = {id(self)}

    def add_to_queue(name, item):
        name = postprocess_parameter_name(name, item)

        if id(item) not in seen:
            seen.add(id(item))
            queue.append((name, item))

    while queue:
        name, item = queue.popleft()

        if isinstance(item, type_):
            yield name, item

        if isinstance(item, BaseModule):
            if skip_compiled and getattr(item, "_compiled", False):
                continue
            for sub_name, sub_item in item.__dict__.items():
                add_to_queue(f"{name}.{sub_name}", sub_item)

        elif isinstance(item, (list, tuple)):
            for i, sub_item in enumerate(item):
                add_to_queue(f"{name}[{i}]", sub_item)

        elif isinstance(item, dict):
            for key, sub_item in item.items():
                add_to_queue(f"{name}[{key}]", sub_item)

parameters()

Source code in dspy/primitives/module.py
def parameters(self):
    return [param for _, param in self.named_parameters()]

predictors()

Source code in dspy/primitives/program.py
def predictors(self):
    return [param for _, param in self.named_predictors()]

reset_copy()

Deep copy the module and reset all parameters.
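
A minimal sketch, assuming the qa module from the example above:

# Fresh copy of the module with all parameters reset
fresh_qa = qa.reset_copy()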

Source code in dspy/primitives/module.py
def reset_copy(self):
    """Deep copy the module and reset all parameters."""
    new_instance = self.deepcopy()

    for param in new_instance.parameters():
        param.reset()

    return new_instance

save(path, save_program=False)

Save the module.

Save the module to a directory or a file. There are two modes:
- save_program=False: Save only the state of the module to a json or pickle file, based on the file extension.
- save_program=True: Save the whole module to a directory via cloudpickle, which contains both the state and the architecture of the model.

We also save the dependency versions, so that the loaded model can check whether there is a version mismatch on critical dependencies or the DSPy version.

Parameters

path (str): Path to the saved state file, which should be a .json or .pkl file when save_program=False, and a directory when save_program=True. Required.

save_program (bool): If True, save the whole module to a directory via cloudpickle; otherwise only save the state. Default: False.
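
A minimal sketch, assuming the best_of_3 module from the example above; the paths are hypothetical:

# State-only saving (.json or .pkl)
best_of_3.save("best_of_3_state.json")

# Whole-program saving to a directory via cloudpickle
best_of_3.save("best_of_3_program/", save_program=True)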
Source code in dspy/primitives/module.py
def save(self, path, save_program=False):
    """Save the module.

    Save the module to a directory or a file. There are two modes:
    - `save_program=False`: Save only the state of the module to a json or pickle file, based on the value of
        the file extension.
    - `save_program=True`: Save the whole module to a directory via cloudpickle, which contains both the state and
        architecture of the model.

    We also save the dependency versions, so that the loaded model can check if there is a version mismatch on
    critical dependencies or DSPy version.

    Args:
        path (str): Path to the saved state file, which should be a .json or .pkl file when `save_program=False`,
            and a directory when `save_program=True`.
        save_program (bool): If True, save the whole module to a directory via cloudpickle, otherwise only save
            the state.
    """
    metadata = {}
    metadata["dependency_versions"] = get_dependency_versions()
    path = Path(path)

    if save_program:
        if path.suffix:
            raise ValueError(
                f"`path` must point to a directory without a suffix when `save_program=True`, but received: {path}"
            )
        if path.exists() and not path.is_dir():
            raise NotADirectoryError(f"The path '{path}' exists but is not a directory.")

        if not path.exists():
            # Create the directory (and any parent directories)
            path.mkdir(parents=True)

        try:
            with open(path / "program.pkl", "wb") as f:
                cloudpickle.dump(self, f)
        except Exception as e:
            raise RuntimeError(
                f"Saving failed with error: {e}. Please remove the non-picklable attributes from your DSPy program, "
                "or consider using state-only saving by setting `save_program=False`."
            )
        with open(path / "metadata.json", "w") as f:
            ujson.dump(metadata, f, indent=2)

        return

    state = self.dump_state()
    state["metadata"] = metadata
    if path.suffix == ".json":
        try:
            with open(path, "w") as f:
                f.write(ujson.dumps(state, indent=2))
        except Exception as e:
            raise RuntimeError(
                f"Failed to save state to {path} with error: {e}. Your DSPy program may contain non "
                "json-serializable objects, please consider saving the state in .pkl by using `path` ending "
                "with `.pkl`, or saving the whole program by setting `save_program=True`."
            )
    elif path.suffix == ".pkl":
        with open(path, "wb") as f:
            cloudpickle.dump(state, f)
    else:
        raise ValueError(f"`path` must end with `.json` or `.pkl` when `save_program=False`, but received: {path}")

set_lm(lm)

Source code in dspy/primitives/program.py
def set_lm(self, lm):
    for _, param in self.named_predictors():
        param.lm = lm