跳到内容

dspy.Refine

dspy.Refine(module: Module, N: int, reward_fn: Callable[[dict, Prediction], float], threshold: float, fail_count: Optional[int] = None)

基类: Module

通过以不同的温度运行模块最多 N 次来优化它,并返回最佳预测。

此模块以不同的温度设置多次运行提供的模块,并选择第一个超过指定阈值的预测或奖励最高的预测。如果没有预测达到阈值,它会自动生成反馈以改进未来的预测。

参数

名称 类型 描述 默认值
module Module

要优化的模块。

必需
N int

运行模块的最大次数(模块最多会被运行 N 次)。

必需
reward_fn Callable

奖励函数。

必需
threshold float

奖励函数的阈值。

必需
fail_count Optional[int]

模块在引发错误之前可以失败的次数。未提供(None)时默认为 N。

None
Example
import dspy

# Register the language model in the global DSPy settings.
dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))

# Define a QA module with chain of thought
qa = dspy.ChainOfThought("question -> answer")

# Define a reward function that checks for one-word answers.
# Per the Refine signature it takes (dict of inputs, prediction) and returns a float score.
def one_word_answer(args, pred):
    return 1.0 if len(pred.answer.split()) == 1 else 0.0

# Create a refined module that tries up to 3 times
best_of_3 = dspy.Refine(module=qa, N=3, reward_fn=one_word_answer, threshold=1.0)

# Use the refined module
result = best_of_3(question="What is the capital of Belgium?").answer
# Returns: Brussels
源代码位于 dspy/predict/refine.py
def __init__(
    self,
    module: Module,
    N: int,  # noqa: N803
    reward_fn: Callable[[dict, Prediction], float],
    threshold: float,
    fail_count: Optional[int] = None,
):
    """
    Refines a module by running it up to N times with different temperatures and returns the best prediction.

    This module runs the provided module multiple times with varying temperature settings and selects
    either the first prediction that exceeds the specified threshold or the one with the highest reward.
    If no prediction meets the threshold, it automatically generates feedback to improve future predictions.


    Args:
        module (Module): The module to refine.
        N (int): The maximum number of times to run the module.
        reward_fn (Callable): The reward function. It takes a dict and a prediction and returns a float.
        threshold (float): The threshold for the reward function.
        fail_count (Optional[int], optional): The number of times the module can fail before raising an error.
            Defaults to N when not provided (None). An explicit 0 is honored.

    Example:
        ```python
        import dspy

        dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"))

        # Define a QA module with chain of thought
        qa = dspy.ChainOfThought("question -> answer")

        # Define a reward function that checks for one-word answers
        def one_word_answer(args, pred):
            return 1.0 if len(pred.answer.split()) == 1 else 0.0

        # Create a refined module that tries up to 3 times
        best_of_3 = dspy.Refine(module=qa, N=3, reward_fn=one_word_answer, threshold=1.0)

        # Use the refined module
        result = best_of_3(question="What is the capital of Belgium?").answer
        # Returns: Brussels
        ```
    """
    self.module = module
    self.reward_fn = lambda *args: reward_fn(*args)  # to prevent this from becoming a parameter
    self.threshold = threshold
    self.N = N
    # Only None means "not provided"; `fail_count or N` would silently turn an explicit 0 into N.
    self.fail_count = N if fail_count is None else fail_count
    # Source snapshots of the module class and the reward function are stored for later use.
    self.module_code = inspect.getsource(module.__class__)
    try:
        self.reward_fn_code = inspect.getsource(reward_fn)
    except TypeError:
        # inspect.getsource rejected the object itself; fall back to the source of its class.
        self.reward_fn_code = inspect.getsource(reward_fn.__class__)

函数

__call__(*args, **kwargs)

源代码位于 dspy/primitives/program.py
@with_callbacks
def __call__(self, *args, **kwargs):
    """Invoke ``self.forward`` with the given arguments.

    When usage tracking is enabled in ``settings`` and no tracker is active yet, the
    call runs inside ``track_usage()`` and the collected token totals are attached to
    the output via ``set_lm_usage``.
    """
    tracker_needed = settings.track_usage and settings.usage_tracker is None
    if not tracker_needed:
        return self.forward(*args, **kwargs)

    with track_usage() as tracker:
        result = self.forward(*args, **kwargs)
        result.set_lm_usage(tracker.get_total_tokens())
        return result

acall(*args, **kwargs) 异步

源代码位于 dspy/primitives/program.py
@with_callbacks
async def acall(self, *args, **kwargs):
    """Asynchronously invoke ``self.aforward`` with the given arguments.

    When usage tracking is enabled in ``settings`` and no tracker is active yet, the
    call runs inside ``track_usage()`` and the collected token totals are attached to
    the output via ``set_lm_usage``.
    """
    tracker_needed = settings.track_usage and settings.usage_tracker is None
    if not tracker_needed:
        return await self.aforward(*args, **kwargs)

    with track_usage() as tracker:
        result = await self.aforward(*args, **kwargs)
        result.set_lm_usage(tracker.get_total_tokens())
        return result

batch(examples, num_threads: Optional[int] = None, max_errors: int = 10, return_failed_examples: bool = False, provide_traceback: Optional[bool] = None, disable_progress_bar: bool = False)

使用 Parallel 模块并行处理 dspy.Example 实例列表。

参数:
- examples:要处理的 dspy.Example 实例列表。
- num_threads:用于并行处理的线程数。
- max_errors:停止执行前允许的最大错误数。
- return_failed_examples:是否返回失败的示例和异常。
- provide_traceback:是否在错误日志中包含追溯信息。

返回:结果列表,以及可选的失败示例和异常。

源代码位于 dspy/primitives/program.py
def batch(
    self,
    examples,
    num_threads: Optional[int] = None,
    max_errors: int = 10,
    return_failed_examples: bool = False,
    provide_traceback: Optional[bool] = None,
    disable_progress_bar: bool = False,
):
    """
    Run this module over a list of dspy.Example instances through the Parallel module.

    :param examples: List of dspy.Example instances to process.
    :param num_threads: Number of threads to use for parallel processing.
    :param max_errors: Maximum number of errors allowed before stopping execution.
    :param return_failed_examples: Whether to return failed examples and exceptions.
    :param provide_traceback: Whether to include traceback information in error logs.
    :param disable_progress_bar: Forwarded unchanged to Parallel.
    :return: List of results, and optionally failed examples and exceptions.
    """
    # Pair this module with the inputs of every example.
    work_items = []
    for example in examples:
        work_items.append((self, example.inputs()))

    executor = Parallel(
        num_threads=num_threads,
        max_errors=max_errors,
        return_failed_examples=return_failed_examples,
        provide_traceback=provide_traceback,
        disable_progress_bar=disable_progress_bar,
    )

    outcome = executor.forward(work_items)
    if not return_failed_examples:
        return outcome

    # With return_failed_examples the executor yields three values; hand them back as a tuple.
    results, failed_examples, exceptions = outcome
    return results, failed_examples, exceptions

deepcopy()

深拷贝此模块。

这是对默认 Python 深拷贝的一个调整,它只深拷贝 self.parameters(),对于其他属性,我们只进行浅拷贝。

源代码位于 dspy/primitives/module.py
def deepcopy(self):
    """Deep copy the module.

    First tries a plain ``copy.deepcopy`` of the whole instance. If that fails, a new
    instance is created without running ``__init__`` and every attribute is copied
    individually: ``BaseModule`` attributes via their own ``deepcopy()``, everything
    else by deep copy, then shallow copy, then by reference as a last resort.
    """
    try:
        # Fast path: the whole instance can be deep copied as is.
        return copy.deepcopy(self)
    except Exception:
        # Fall through to the attribute-by-attribute copy below.
        pass

    # Build an empty instance, bypassing __init__.
    clone = self.__class__.__new__(self.__class__)
    # Populate the attributes of the new instance one at a time.
    for attr_name, attr_value in self.__dict__.items():
        if isinstance(attr_value, BaseModule):
            setattr(clone, attr_name, attr_value.deepcopy())
            continue

        try:
            # Preferred: an independent deep copy of the attribute.
            setattr(clone, attr_name, copy.deepcopy(attr_value))
        except Exception:
            logging.warning(
                f"Failed to deep copy attribute '{attr_name}' of {self.__class__.__name__}, "
                "falling back to shallow copy or reference copy."
            )
            try:
                # Second choice: a shallow copy.
                setattr(clone, attr_name, copy.copy(attr_value))
            except Exception:
                # Last resort: share the very same object.
                setattr(clone, attr_name, attr_value)

    return clone

dump_state()

源代码位于 dspy/primitives/module.py
def dump_state(self):
    """Return a dict mapping each named parameter to the result of its ``dump_state()``."""
    state = {}
    for param_name, parameter in self.named_parameters():
        state[param_name] = parameter.dump_state()
    return state

forward(**kwargs)

源代码位于 dspy/predict/refine.py
def forward(self, **kwargs):
    """Run the wrapped module up to ``self.N`` times and return the best-scoring prediction.

    Each attempt uses a fresh deep copy of the module and a copy of the LM with a different
    temperature: the LM's configured temperature first, then ``0.5 + i * 0.5 / N``; duplicates
    are dropped and at most ``N`` temperatures are kept. After an attempt whose reward is below
    the threshold, feedback is requested from ``dspy.Predict(OfferFeedback)`` and injected into
    the next attempt as an extra ``hint_`` input field.

    Args:
        **kwargs: Inputs forwarded to the wrapped module; also passed to ``reward_fn`` as its
            first argument.

    Returns:
        The prediction with the highest reward among the successful attempts, or ``None`` if
        no attempt completed successfully.

    Raises:
        Exception: Re-raises an attempt's exception once the attempt index exceeds the
            remaining failure budget. The budget starts at ``self.fail_count`` on every call.
    """
    lm = self.module.get_lm() or dspy.settings.lm
    temps = [lm.kwargs["temperature"]] + [0.5 + i * (0.5 / self.N) for i in range(self.N)]
    # dict.fromkeys de-duplicates while preserving order.
    temps = list(dict.fromkeys(temps))[: self.N]
    best_pred, best_trace, best_reward = None, None, -float("inf")
    advice = None
    adapter = dspy.settings.adapter or dspy.ChatAdapter()
    # Track the failure budget per call. Decrementing self.fail_count would leak the budget
    # across calls and be shared between threads when the module is used through batch().
    remaining_failures = self.fail_count

    for idx, t in enumerate(temps):
        lm_ = lm.copy(temperature=t)
        mod = self.module.deepcopy()
        mod.set_lm(lm_)

        predictor2name = {predictor: name for name, predictor in mod.named_predictors()}
        signature2name = {predictor.signature: name for name, predictor in mod.named_predictors()}
        module_names = [name for name, _ in mod.named_predictors()]

        try:
            with dspy.context(trace=[]):
                if not advice:
                    outputs = mod(**kwargs)
                else:

                    # Wraps the active adapter so every predictor call receives the advice
                    # recorded for it (or "N/A") as an additional `hint_` input.
                    class WrapperAdapter(adapter.__class__):
                        def __call__(self, lm, lm_kwargs, signature, demos, inputs):
                            inputs["hint_"] = advice.get(signature2name[signature], "N/A")  # noqa: B023
                            signature = signature.append(
                                "hint_", InputField(desc="A hint to the module from an earlier run")
                            )
                            return adapter(lm, lm_kwargs, signature, demos, inputs)

                    with dspy.context(adapter=WrapperAdapter()):
                        outputs = mod(**kwargs)

                trace = dspy.settings.trace.copy()

                # TODO: Remove the hint from the trace, if it's there.

                # NOTE: Not including the trace of reward_fn.
                reward = self.reward_fn(kwargs, outputs)

            if reward > best_reward:
                best_reward, best_pred, best_trace = reward, outputs, trace

            # Stop as soon as the threshold is reached.
            if self.threshold is not None and reward >= self.threshold:
                break

            # Last allowed attempt: no point in generating feedback nobody will consume.
            if idx == self.N - 1:
                break

            modules = {"program_code": self.module_code, "modules_defn": inspect_modules(mod)}
            trajectory = [{"module_name": predictor2name[p], "inputs": i, "outputs": dict(o)} for p, i, o in trace]
            trajectory = {
                "program_inputs": kwargs,
                "program_trajectory": trajectory,
                "program_outputs": dict(outputs),
            }
            reward = {
                "reward_code": self.reward_fn_code,
                "target_threshold": self.threshold,
                "reward_value": reward,
            }

            advise_kwargs = dict(**modules, **trajectory, **reward, module_names=module_names)
            # Strings are passed through unchanged; everything else is masked and JSON-encoded.
            advise_kwargs = {
                k: v if isinstance(v, str) else ujson.dumps(recursive_mask(v), indent=2)
                for k, v in advise_kwargs.items()
            }
            advice = dspy.Predict(OfferFeedback)(**advise_kwargs).advice

        except Exception as e:
            print(f"Refine: Attempt failed with temperature {t}: {e}")
            if idx > remaining_failures:
                raise
            remaining_failures -= 1
    if best_trace:
        # Expose only the winning attempt's trace to the surrounding context.
        dspy.settings.trace.extend(best_trace)
    return best_pred

get_lm()

源代码位于 dspy/primitives/program.py
def get_lm(self):
    """Return the single LM shared by every predictor of this module.

    Returns:
        The common ``lm`` attribute of all predictors. This is ``None`` when every
        predictor has ``lm`` set to ``None``.

    Raises:
        ValueError: If the module has no predictors, or if its predictors do not all
            use the same LM.
    """
    all_used_lms = [param.lm for _, param in self.named_predictors()]

    # Without this guard an empty module would be reported as using "multiple LMs".
    if not all_used_lms:
        raise ValueError("No predictors found in the module. There's no LM to return.")

    if len(set(all_used_lms)) == 1:
        return all_used_lms[0]

    raise ValueError("Multiple LMs are being used in the module. There's no unique LM to return.")

load(path)

加载已保存的模块。如果想加载整个程序,而不仅仅是现有程序的状态,您可能还需要查看 dspy.load。

参数

名称 类型 描述 默认值
path str

保存状态文件的路径,应为 .json 或 .pkl 文件

必需
源代码位于 dspy/primitives/module.py
def load(self, path):
    """Load the saved module. You may also want to check out dspy.load, if you want to
    load an entire program, not just the state for an existing program.

    The dependency versions recorded in the file are compared with the current
    environment and a warning is logged for every mismatch before the state is applied.

    Args:
        path (str): Path to the saved state file, which should be a .json or a .pkl file

    Raises:
        ValueError: If ``path`` ends with neither ``.json`` nor ``.pkl``.
    """
    path = Path(path)
    suffix = path.suffix

    # Reject unsupported extensions before touching the file system.
    if suffix not in (".json", ".pkl"):
        raise ValueError(f"`path` must end with `.json` or `.pkl`, but received: {path}")

    if suffix == ".json":
        with open(path) as f:
            state = ujson.loads(f.read())
    else:
        # NOTE: unpickling can execute arbitrary code; only load .pkl files from trusted sources.
        with open(path, "rb") as f:
            state = cloudpickle.load(f)

    dependency_versions = get_dependency_versions()
    recorded_versions = state["metadata"]["dependency_versions"]
    for key, saved_version in recorded_versions.items():
        if dependency_versions[key] == saved_version:
            continue
        logger.warning(
            f"There is a mismatch of {key} version between saved model and current environment. "
            f"You saved with `{key}=={saved_version}`, but now you have "
            f"`{key}=={dependency_versions[key]}`. This might cause errors or performance downgrade "
            "on the loaded model, please consider loading the model in the same environment as the "
            "saving environment."
        )
    self.load_state(state)

load_state(state)

源代码位于 dspy/primitives/module.py
def load_state(self, state):
    """Hand each named parameter its own entry of ``state`` via ``param.load_state``.

    Args:
        state: Mapping indexed by parameter name; every named parameter must have an entry.
    """
    for param_name, parameter in self.named_parameters():
        saved_entry = state[param_name]
        parameter.load_state(saved_entry)

map_named_predictors(func)

对所有命名预测器应用函数。

源代码位于 dspy/primitives/program.py
def map_named_predictors(self, func):
    """Applies a function to all named predictors.

    The result of ``func(predictor)`` is stored under the predictor's name through
    ``set_attribute_by_name``. Returns ``self`` so calls can be chained.
    """
    for predictor_name, current in self.named_predictors():
        replacement = func(current)
        set_attribute_by_name(self, predictor_name, replacement)
    return self

named_parameters()

与 PyTorch 不同,这也处理(非递归)参数列表。

源代码位于 dspy/primitives/module.py
def named_parameters(self):
    """
    Collect ``(name, parameter)`` pairs reachable from this object's attributes.

    Unlike PyTorch, handles (non-recursive) lists of parameters too. Each parameter is
    reported once (tracked by ``id``), and sub-modules flagged ``_compiled`` are skipped.
    """

    import dspy
    from dspy.predict.parameter import Parameter

    seen_ids = set()
    collected = []

    def register(path, candidate):
        # A Parameter is recorded directly, the first time it is encountered.
        if isinstance(candidate, Parameter):
            if id(candidate) in seen_ids:
                return
            seen_ids.add(id(candidate))
            path = postprocess_parameter_name(path, candidate)
            collected.append((path, candidate))
            return

        # When a sub-module is pre-compiled, keep it frozen.
        if isinstance(candidate, dspy.Module) and not getattr(candidate, "_compiled", False):
            for child_name, child in candidate.named_parameters():
                register(f"{path}.{child_name}", child)

    if isinstance(self, Parameter):
        register("self", self)

    for attr_name, attr_value in self.__dict__.items():
        if isinstance(attr_value, (Parameter, dspy.Module)):
            # Parameters and sub-modules share the handling implemented in register().
            register(attr_name, attr_value)

        elif isinstance(attr_value, (list, tuple)):
            for position, element in enumerate(attr_value):
                register(f"{attr_name}[{position}]", element)

        elif isinstance(attr_value, dict):
            for dict_key, element in attr_value.items():
                register(f"{attr_name}['{dict_key}']", element)

    return collected

named_predictors()

源代码位于 dspy/primitives/program.py
def named_predictors(self):
    """Return the ``(name, parameter)`` pairs of ``named_parameters()`` that are ``Predict`` instances."""
    from dspy.predict.predict import Predict

    matches = []
    for param_name, parameter in self.named_parameters():
        if isinstance(parameter, Predict):
            matches.append((param_name, parameter))
    return matches

named_sub_modules(type_=None, skip_compiled=False) -> Generator[tuple[str, BaseModule], None, None]

查找模块中的所有子模块及其名称。

例如,self.children[4]['key'].sub_module 是一个子模块。则名称将是 'children[4][key].sub_module'。但如果子模块可以通过不同的路径访问,则只返回其中一条路径。

源代码位于 dspy/primitives/module.py
def named_sub_modules(self, type_=None, skip_compiled=False) -> Generator[tuple[str, "BaseModule"], None, None]:
    """Yield ``(name, item)`` pairs for the sub-modules reachable from this module.

    The traversal is breadth-first starting at ``("self", self)`` and walks through module
    attributes, lists, tuples and dicts. Names are built from the access path, e.g. an
    attribute adds ``.attr``, a list index adds ``[4]`` and a dict key adds ``[key]``.
    Every object is visited once, so if a sub-module is accessible at different paths,
    only one of the paths will be returned.

    Args:
        type_: Only items that are instances of this type are yielded. Defaults to ``BaseModule``.
        skip_compiled: If True, the attributes of modules flagged ``_compiled`` are not explored.
    """
    wanted_type = BaseModule if type_ is None else type_

    pending = deque([("self", self)])
    visited_ids = {id(self)}

    def enqueue(path, obj):
        path = postprocess_parameter_name(path, obj)

        if id(obj) in visited_ids:
            return
        visited_ids.add(id(obj))
        pending.append((path, obj))

    while pending:
        path, current = pending.popleft()

        if isinstance(current, wanted_type):
            yield path, current

        if isinstance(current, BaseModule):
            # A compiled module is still yielded above, but its attributes are not explored.
            if skip_compiled and getattr(current, "_compiled", False):
                continue
            for attr_name, child in current.__dict__.items():
                enqueue(f"{path}.{attr_name}", child)

        elif isinstance(current, (list, tuple)):
            for position, child in enumerate(current):
                enqueue(f"{path}[{position}]", child)

        elif isinstance(current, dict):
            for dict_key, child in current.items():
                enqueue(f"{path}[{dict_key}]", child)

parameters()

源代码位于 dspy/primitives/module.py
def parameters(self):
    """Return the parameters from ``named_parameters()`` without their names."""
    collected = []
    for _name, parameter in self.named_parameters():
        collected.append(parameter)
    return collected

predictors()

源代码位于 dspy/primitives/program.py
def predictors(self):
    """Return the predictors from ``named_predictors()`` without their names."""
    collected = []
    for _name, predictor in self.named_predictors():
        collected.append(predictor)
    return collected

reset_copy()

深拷贝此模块并重置所有参数。

源代码位于 dspy/primitives/module.py
def reset_copy(self):
    """Return a deep copy of this module after calling ``reset()`` on each of the copy's parameters."""
    clone = self.deepcopy()

    for parameter in clone.parameters():
        parameter.reset()

    return clone

save(path, save_program=False)

保存模块。

将模块保存到目录或文件。有两种模式:
- save_program=False:仅保存模块状态,并根据文件扩展名保存为 json 或 pickle 文件。
- save_program=True:通过 cloudpickle 将整个模块保存到目录,其中包含模型的状态和架构。

我们还保存了依赖版本,以便加载的模型可以检查关键依赖项或 DSPy 版本是否存在版本不匹配。

参数

名称 类型 描述 默认值
path str

保存状态文件的路径,当 save_program=False 时应为 .json 或 .pkl 文件,当 save_program=True 时应为目录。

必需
save_program bool

如果为 True,则通过 cloudpickle 将整个模块保存到目录,否则仅保存状态。

False
源代码位于 dspy/primitives/module.py
def save(self, path, save_program=False):
    """Save the module.

    Save the module to a directory or a file. There are two modes:
    - `save_program=False`: Save only the state of the module to a json or pickle file, based on the value of
        the file extension.
    - `save_program=True`: Save the whole module to a directory via cloudpickle, which contains both the state and
        architecture of the model.

    We also save the dependency versions, so that the loaded model can check if there is a version mismatch on
    critical dependencies or DSPy version.

    Args:
        path (str): Path to the saved state file, which should be a .json or .pkl file when `save_program=False`,
            and a directory when `save_program=True`.
        save_program (bool): If True, save the whole module to a directory via cloudpickle, otherwise only save
            the state.

    Raises:
        ValueError: If `path` has a suffix when `save_program=True`, or does not end with `.json` or `.pkl`
            when `save_program=False`.
        NotADirectoryError: If `save_program=True` and `path` exists but is not a directory.
        RuntimeError: If pickling the program or JSON-serializing the state fails; the original error is
            chained as the cause.
    """
    metadata = {}
    metadata["dependency_versions"] = get_dependency_versions()
    path = Path(path)

    if save_program:
        if path.suffix:
            raise ValueError(
                f"`path` must point to a directory without a suffix when `save_program=True`, but received: {path}"
            )
        if path.exists() and not path.is_dir():
            raise NotADirectoryError(f"The path '{path}' exists but is not a directory.")

        # Create the directory (and any parent directories). exist_ok avoids failing when the
        # directory already exists or is created concurrently between a check and the mkdir.
        path.mkdir(parents=True, exist_ok=True)

        try:
            with open(path / "program.pkl", "wb") as f:
                cloudpickle.dump(self, f)
        except Exception as e:
            # Chain explicitly so the traceback marks the original error as the direct cause.
            raise RuntimeError(
                f"Saving failed with error: {e}. Please remove the non-picklable attributes from your DSPy program, "
                "or consider using state-only saving by setting `save_program=False`."
            ) from e
        with open(path / "metadata.json", "w") as f:
            ujson.dump(metadata, f, indent=2)

        return

    # State-only mode: the metadata is stored next to the parameter states.
    state = self.dump_state()
    state["metadata"] = metadata
    if path.suffix == ".json":
        try:
            with open(path, "w") as f:
                f.write(ujson.dumps(state, indent=2))
        except Exception as e:
            raise RuntimeError(
                f"Failed to save state to {path} with error: {e}. Your DSPy program may contain non "
                "json-serializable objects, please consider saving the state in .pkl by using `path` ending "
                "with `.pkl`, or saving the whole program by setting `save_program=True`."
            ) from e
    elif path.suffix == ".pkl":
        with open(path, "wb") as f:
            cloudpickle.dump(state, f)
    else:
        raise ValueError(f"`path` must end with `.json` or `.pkl` when `save_program=False`, but received: {path}")

set_lm(lm)

源代码位于 dspy/primitives/program.py
def set_lm(self, lm):
    """Assign ``lm`` to the ``lm`` attribute of every named predictor."""
    for _name, predictor in self.named_predictors():
        predictor.lm = lm