跳到内容

dspy.ReAct

dspy.ReAct(signature, tools: list[Callable], max_iters=5)

基类:Module

tools 可以是函数列表、可调用类或 dspy.Tool 实例。

源代码位于 dspy/predict/react.py
def __init__(self, signature, tools: list[Callable], max_iters=5):
    """Build a ReAct agent around a task signature and a set of tools.

    Args:
        signature: The task signature; it is normalized through `ensure_signature`.
        tools: A list of functions, callable classes, or `dspy.Tool` instances. Entries that are not
            already `Tool` instances are wrapped in `Tool`.
        max_iters: Stored on `self.max_iters`; `forward`/`aforward` read it as the default number of
            loop iterations.
    """

    self.signature = signature = ensure_signature(signature)
    self.max_iters = max_iters

    # Normalize every entry to a Tool, then index the tools by name. Because this is a plain dict,
    # a later tool with the same name replaces an earlier one.
    tools = [t if isinstance(t, Tool) else Tool(t) for t in tools]
    tools = {tool.name: tool for tool in tools}

    # Backtick-quoted, comma-separated field names that get interpolated into the instructions.
    inputs = ", ".join([f"`{k}`" for k in signature.input_fields.keys()])
    outputs = ", ".join([f"`{k}`" for k in signature.output_fields.keys()])
    # Start from the signature's own instructions (when it has any), then add the agent protocol.
    instr = [f"{signature.instructions}\n"] if signature.instructions else []

    instr.extend(
        [
            f"You are an Agent. In each episode, you will be given the fields {inputs} as input. And you can see your past trajectory so far.",
            f"Your goal is to use one or more of the supplied tools to collect any necessary information for producing {outputs}.\n",
            "To do this, you will interleave next_thought, next_tool_name, and next_tool_args in each turn, and also when finishing the task.",
            "After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n",
            "When writing next_thought, you may reason about the current situation and plan for future steps.",
            "When selecting the next_tool_name and its next_tool_args, the tool must be one of:\n",
        ]
    )

    # Built-in zero-argument "finish" tool. It is registered before the tool list is rendered below,
    # so it appears in the instructions; assigning to the "finish" key also replaces any user tool
    # that was registered under that name.
    tools["finish"] = Tool(
        func=lambda: "Completed.",
        name="finish",
        desc=f"Marks the task as complete. That is, signals that all information for producing the outputs, i.e. {outputs}, are now available to be extracted.",
        args={},
    )

    # Append a 1-based numbered entry per tool, using each Tool's own string formatting.
    for idx, tool in enumerate(tools.values()):
        instr.append(f"({idx + 1}) {tool}")

    # Per-step signature: the task inputs plus the trajectory go in; the next thought, the tool name
    # (restricted to the registered tool names via Literal) and the tool arguments come out.
    react_signature = (
        dspy.Signature({**signature.input_fields}, "\n".join(instr))
        .append("trajectory", dspy.InputField(), type_=str)
        .append("next_thought", dspy.OutputField(), type_=str)
        .append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())])
        .append("next_tool_args", dspy.OutputField(), type_=dict[str, Any])
    )

    # Extraction signature: the original inputs and outputs with the original instructions, plus the
    # trajectory as an extra input field.
    fallback_signature = dspy.Signature(
        {**signature.input_fields, **signature.output_fields},
        signature.instructions,
    ).append("trajectory", dspy.InputField(), type_=str)

    self.tools = tools
    self.react = dspy.Predict(react_signature)
    self.extract = dspy.ChainOfThought(fallback_signature)

函数

__call__(*args, **kwargs)

源代码位于 dspy/primitives/program.py
@with_callbacks
def __call__(self, *args, **kwargs):
    """Invoke `self.forward`, recording LM usage on the output when tracking is requested.

    Usage is recorded only when `settings.track_usage` is truthy and `settings.usage_tracker` is
    None. In that case `forward` runs inside a `track_usage()` context and the tracker's total token
    count is attached to the output through `set_lm_usage`. Otherwise `forward` is called directly.
    """
    needs_tracker = settings.track_usage and settings.usage_tracker is None
    if not needs_tracker:
        return self.forward(*args, **kwargs)

    with track_usage() as tracker:
        result = self.forward(*args, **kwargs)
        result.set_lm_usage(tracker.get_total_tokens())
    return result

acall(*args, **kwargs) 异步

源代码位于 dspy/primitives/program.py
@with_callbacks
async def acall(self, *args, **kwargs):
    """Await `self.aforward`, recording LM usage on the output when tracking is requested.

    Usage is recorded only when `settings.track_usage` is truthy and `settings.usage_tracker` is
    None. In that case `aforward` is awaited inside a `track_usage()` context and the tracker's total
    token count is attached to the output through `set_lm_usage`. Otherwise `aforward` is awaited
    directly.
    """
    needs_tracker = settings.track_usage and settings.usage_tracker is None
    if not needs_tracker:
        return await self.aforward(*args, **kwargs)

    with track_usage() as tracker:
        result = await self.aforward(*args, **kwargs)
        result.set_lm_usage(tracker.get_total_tokens())
    return result

aforward(**input_args) 异步

源代码位于 dspy/predict/react.py
async def aforward(self, **input_args):
    """Run the agent loop asynchronously and extract the final outputs.

    Each iteration asks `self.react` for the next thought, tool name and tool arguments, awaits the
    chosen tool's `acall`, and stores all four values in the trajectory under keys suffixed with the
    iteration index. The loop ends when the iteration budget is used up, when the "finish" tool has
    been selected, or when the `self.react` call raises `ValueError`. Any exception raised while
    running a tool is recorded as that step's observation instead of propagating.

    Args:
        **input_args: Inputs forwarded to the predictors. An optional `max_iters` entry is removed
            from them and overrides `self.max_iters` for this call.

    Returns:
        A `dspy.Prediction` holding the trajectory together with the fields produced by
        `self.extract`.
    """
    iteration_budget = input_args.pop("max_iters", self.max_iters)
    trajectory = {}

    for step in range(iteration_budget):
        try:
            pred = await self._async_call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)
        except ValueError as err:
            logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}")
            break

        thought, tool_name, tool_args = pred.next_thought, pred.next_tool_name, pred.next_tool_args
        trajectory[f"thought_{step}"] = thought
        trajectory[f"tool_name_{step}"] = tool_name
        trajectory[f"tool_args_{step}"] = tool_args

        # Tool lookup and execution share one guard, so an unknown tool name is reported as an
        # observation just like a failing tool.
        try:
            observation = await self.tools[tool_name].acall(**tool_args)
        except Exception as err:
            observation = f"Execution error in {tool_name}: {_fmt_exc(err)}"
        trajectory[f"observation_{step}"] = observation

        if tool_name == "finish":
            break

    extracted = await self._async_call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args)
    return dspy.Prediction(trajectory=trajectory, **extracted)

batch(examples, num_threads: Optional[int] = None, max_errors: int = 10, return_failed_examples: bool = False, provide_traceback: Optional[bool] = None, disable_progress_bar: bool = False)

使用 Parallel 模块并行处理 dspy.Example 实例列表。

:param examples: 要处理的 dspy.Example 实例列表。
:param num_threads: 用于并行处理的线程数。
:param max_errors: 在停止执行前允许的最大错误数。
:param return_failed_examples: 是否返回失败的示例和异常。
:param provide_traceback: 是否在错误日志中包含堆栈跟踪信息。
:param disable_progress_bar: 是否禁用进度条(该值会原样传递给 Parallel)。
:return: 结果列表,以及可选的失败示例和异常。

源代码位于 dspy/primitives/program.py
def batch(
    self,
    examples,
    num_threads: Optional[int] = None,
    max_errors: int = 10,
    return_failed_examples: bool = False,
    provide_traceback: Optional[bool] = None,
    disable_progress_bar: bool = False,
):
    """
    Run this module over a list of dspy.Example instances through the Parallel module.

    :param examples: List of dspy.Example instances to process; `inputs()` is called on each one.
    :param num_threads: Number of threads to use for parallel processing.
    :param max_errors: Maximum number of errors allowed before stopping execution.
    :param return_failed_examples: Whether to return failed examples and exceptions.
    :param provide_traceback: Whether to include traceback information in error logs.
    :param disable_progress_bar: Passed through unchanged to Parallel.
    :return: List of results, and optionally failed examples and exceptions.
    """
    # Pair this module with the inputs of every example.
    exec_pairs = []
    for example in examples:
        exec_pairs.append((self, example.inputs()))

    executor_options = {
        "num_threads": num_threads,
        "max_errors": max_errors,
        "return_failed_examples": return_failed_examples,
        "provide_traceback": provide_traceback,
        "disable_progress_bar": disable_progress_bar,
    }
    parallel_executor = Parallel(**executor_options)

    outcome = parallel_executor.forward(exec_pairs)
    if not return_failed_examples:
        return outcome

    # With failures requested, the executor's result is unpacked into exactly three parts.
    results, failed_examples, exceptions = outcome
    return results, failed_examples, exceptions

deepcopy()

深度复制模块。

这是对默认 Python 深度复制的调整,仅深度复制 self.parameters(),对于其他属性,我们只进行浅复制。

源代码位于 dspy/primitives/module.py
def deepcopy(self):
    """Deep copy the module, degrading gracefully when parts of it cannot be copied.

    A plain `copy.deepcopy` of the whole instance is tried first. If that raises, a new instance is
    created without calling `__init__` and filled attribute by attribute: `BaseModule` attributes
    are copied through their own `deepcopy()`, and every other attribute is deep copied, falling
    back to a shallow copy and finally to sharing the reference (a warning is logged when the deep
    copy of an attribute fails).
    """
    try:
        # Fast path: the instance as a whole is deep-copyable.
        return copy.deepcopy(self)
    except Exception:
        # Deliberate best-effort: fall through to the attribute-by-attribute copy below.
        pass

    cls = self.__class__
    clone = cls.__new__(cls)

    # Set attributes of the copied instance.
    for attr_name, attr_value in self.__dict__.items():
        if isinstance(attr_value, BaseModule):
            setattr(clone, attr_name, attr_value.deepcopy())
            continue

        try:
            setattr(clone, attr_name, copy.deepcopy(attr_value))
        except Exception:
            logging.warning(
                f"Failed to deep copy attribute '{attr_name}' of {self.__class__.__name__}, "
                "falling back to shallow copy or reference copy."
            )
            try:
                # Second choice: a shallow copy of the attribute.
                setattr(clone, attr_name, copy.copy(attr_value))
            except Exception:
                # Last resort: share the original object.
                setattr(clone, attr_name, attr_value)

    return clone

dump_state()

源代码位于 dspy/primitives/module.py
def dump_state(self):
    """Return a dict mapping each named parameter's name to that parameter's `dump_state()` result."""
    state = {}
    for param_name, parameter in self.named_parameters():
        state[param_name] = parameter.dump_state()
    return state

forward(**input_args)

源代码位于 dspy/predict/react.py
def forward(self, **input_args):
    """Run the agent loop and extract the final outputs.

    Each iteration asks `self.react` for the next thought, tool name and tool arguments, calls the
    chosen tool, and stores all four values in the trajectory under keys suffixed with the iteration
    index. The loop ends when the iteration budget is used up, when the "finish" tool has been
    selected, or when the `self.react` call raises `ValueError`. Any exception raised while running a
    tool is recorded as that step's observation instead of propagating.

    Args:
        **input_args: Inputs forwarded to the predictors. An optional `max_iters` entry is removed
            from them and overrides `self.max_iters` for this call.

    Returns:
        A `dspy.Prediction` holding the trajectory together with the fields produced by
        `self.extract`.
    """
    iteration_budget = input_args.pop("max_iters", self.max_iters)
    trajectory = {}

    for step in range(iteration_budget):
        try:
            pred = self._call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)
        except ValueError as err:
            logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}")
            break

        thought, tool_name, tool_args = pred.next_thought, pred.next_tool_name, pred.next_tool_args
        trajectory[f"thought_{step}"] = thought
        trajectory[f"tool_name_{step}"] = tool_name
        trajectory[f"tool_args_{step}"] = tool_args

        # Tool lookup and execution share one guard, so an unknown tool name is reported as an
        # observation just like a failing tool.
        try:
            observation = self.tools[tool_name](**tool_args)
        except Exception as err:
            observation = f"Execution error in {tool_name}: {_fmt_exc(err)}"
        trajectory[f"observation_{step}"] = observation

        if tool_name == "finish":
            break

    extracted = self._call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args)
    return dspy.Prediction(trajectory=trajectory, **extracted)

get_lm()

源代码位于 dspy/primitives/program.py
def get_lm(self):
    """Return the single LM shared by all predictors of this module.

    Returns:
        The `lm` attribute common to every predictor returned by `self.named_predictors()`. This
        may be None when every predictor's `lm` is None.

    Raises:
        ValueError: If the module has no predictors, or if its predictors do not all share the same
            `lm` value.
    """
    all_used_lms = [param.lm for _, param in self.named_predictors()]

    # Without this guard an empty module would fall through to the "Multiple LMs" error below,
    # which misreports the problem.
    if not all_used_lms:
        raise ValueError("No predictors were found in the module, so there is no LM to return.")

    if len(set(all_used_lms)) == 1:
        return all_used_lms[0]

    raise ValueError("Multiple LMs are being used in the module. There's no unique LM to return.")

load(path)

加载已保存的模块。如果您想加载整个程序,而不仅仅是现有程序的状态,您可能还需要查看 dspy.load。

参数

| 名称 | 类型 | 描述 | 默认值 |
| --- | --- | --- | --- |
| path | str | 保存状态文件的路径,应为 .json 或 .pkl 文件 | 必需 |

源代码位于 dspy/primitives/module.py
def load(self, path):
    """Load the saved module state from `path` into this module.

    You may also want to check out dspy.load, if you want to load an entire program, not just the
    state for an existing program.

    The saved dependency versions, when present in the state, are compared with the current
    environment and a warning is logged for every mismatch. The state is then handed to
    `self.load_state`.

    Args:
        path (str): Path to the saved state file, which should be a .json or a .pkl file

    Raises:
        ValueError: If `path` does not end with `.json` or `.pkl`.
    """
    path = Path(path)

    if path.suffix == ".json":
        with open(path) as f:
            state = ujson.loads(f.read())
    elif path.suffix == ".pkl":
        # NOTE(security): unpickling can execute arbitrary code; only load .pkl files you trust.
        with open(path, "rb") as f:
            state = cloudpickle.load(f)
    else:
        raise ValueError(f"`path` must end with `.json` or `.pkl`, but received: {path}")

    dependency_versions = get_dependency_versions()
    # The version check is advisory, so a state without a metadata / dependency section, or one
    # listing a dependency this environment does not track, must not abort loading with a KeyError.
    saved_dependency_versions = (state.get("metadata") or {}).get("dependency_versions") or {}
    for key, saved_version in saved_dependency_versions.items():
        current_version = dependency_versions.get(key)
        if current_version != saved_version:
            logger.warning(
                f"There is a mismatch of {key} version between saved model and current environment. "
                f"You saved with `{key}=={saved_version}`, but now you have "
                f"`{key}=={current_version}`. This might cause errors or performance downgrade "
                "on the loaded model, please consider loading the model in the same environment as the "
                "saving environment."
            )
    self.load_state(state)

load_state(state)

源代码位于 dspy/primitives/module.py
def load_state(self, state):
    """Feed each named parameter its own entry of `state`, looked up by the parameter's name.

    A name that is missing from `state` surfaces as the lookup error raised by `state[name]`.
    """
    for entry in self.named_parameters():
        param_name, parameter = entry
        parameter.load_state(state[param_name])

map_named_predictors(func)

对所有命名预测器应用一个函数。

源代码位于 dspy/primitives/program.py
def map_named_predictors(self, func):
    """Replace every named predictor with `func(predictor)` and return this module."""
    for predictor_name, current in self.named_predictors():
        replacement = func(current)
        set_attribute_by_name(self, predictor_name, replacement)
    return self

named_parameters()

与 PyTorch 不同,也处理(非递归的)参数列表。

源代码位于 dspy/primitives/module.py
def named_parameters(self):
    """
    Collect `(name, parameter)` pairs reachable from this module's attributes.

    Unlike PyTorch, handles (non-recursive) lists of parameters too: items of list, tuple and dict
    attributes are inspected one level deep. Sub-modules are traversed through their own
    `named_parameters()` unless they are marked `_compiled`. Each parameter object is reported once.
    """

    import dspy
    from dspy.predict.parameter import Parameter

    seen_ids = set()
    collected = []

    def add_parameter(param_name, param_value):
        if isinstance(param_value, Parameter):
            if id(param_value) in seen_ids:
                return
            seen_ids.add(id(param_value))
            collected.append((postprocess_parameter_name(param_name, param_value), param_value))
            return

        if not isinstance(param_value, dspy.Module):
            return

        # When a sub-module is pre-compiled, keep it frozen.
        if getattr(param_value, "_compiled", False):
            return

        for sub_name, sub_param in param_value.named_parameters():
            add_parameter(f"{param_name}.{sub_name}", sub_param)

    if isinstance(self, Parameter):
        add_parameter("self", self)

    for attr_name, attr_value in self.__dict__.items():
        if isinstance(attr_value, (Parameter, dspy.Module)):
            # The helper applies the Parameter check first and the sub-module traversal second.
            add_parameter(attr_name, attr_value)

        elif isinstance(attr_value, (list, tuple)):
            for position, element in enumerate(attr_value):
                add_parameter(f"{attr_name}[{position}]", element)

        elif isinstance(attr_value, dict):
            for key, element in attr_value.items():
                add_parameter(f"{attr_name}['{key}']", element)

    return collected

named_predictors()

源代码位于 dspy/primitives/program.py
def named_predictors(self):
    """Return the `(name, parameter)` pairs from `named_parameters()` whose parameter is a `Predict`."""
    from dspy.predict.predict import Predict

    matches = []
    for entry_name, candidate in self.named_parameters():
        if isinstance(candidate, Predict):
            matches.append((entry_name, candidate))
    return matches

named_sub_modules(type_=None, skip_compiled=False) -> Generator[tuple[str, BaseModule], None, None]

查找模块中的所有子模块及其名称。

例如,self.children[4]['key'].sub_module 是一个子模块。那么名称将是 'children[4][key].sub_module'。但如果子模块可以通过不同的路径访问,则只返回其中一条路径。

源代码位于 dspy/primitives/module.py
def named_sub_modules(self, type_=None, skip_compiled=False) -> Generator[tuple[str, "BaseModule"], None, None]:
    """Yield `(name, sub_module)` pairs for the sub-modules reachable from this module.

    The traversal is breadth-first and starts at this module under the name "self". Names grow by
    `.attr` for attributes of a `BaseModule`, `[i]` for list/tuple items and `[key]` for dict items,
    and each name is passed through `postprocess_parameter_name`. An object reachable through
    several paths is visited once, so only one of its paths is reported.

    Args:
        type_: Only objects that are instances of this type are yielded. Defaults to `BaseModule`.
        skip_compiled: When true, the attributes of a module marked `_compiled` are not traversed;
            the module itself is still yielded if it matches `type_`.
    """
    wanted_type = BaseModule if type_ is None else type_

    visited_ids = {id(self)}
    pending = deque([("self", self)])

    def enqueue(path, obj):
        path = postprocess_parameter_name(path, obj)
        if id(obj) in visited_ids:
            return
        visited_ids.add(id(obj))
        pending.append((path, obj))

    while pending:
        path, obj = pending.popleft()

        if isinstance(obj, wanted_type):
            yield path, obj

        if isinstance(obj, BaseModule):
            frozen = skip_compiled and getattr(obj, "_compiled", False)
            if not frozen:
                for attr_name, attr_value in obj.__dict__.items():
                    enqueue(f"{path}.{attr_name}", attr_value)

        elif isinstance(obj, (list, tuple)):
            for position, element in enumerate(obj):
                enqueue(f"{path}[{position}]", element)

        elif isinstance(obj, dict):
            for key, element in obj.items():
                enqueue(f"{path}[{key}]", element)

parameters()

源代码位于 dspy/primitives/module.py
def parameters(self):
    """Return the parameters from `named_parameters()` as a list, without their names."""
    values = []
    for _name, parameter in self.named_parameters():
        values.append(parameter)
    return values

predictors()

源代码位于 dspy/primitives/program.py
def predictors(self):
    """Return the predictors from `named_predictors()` as a list, without their names."""
    values = []
    for _name, predictor in self.named_predictors():
        values.append(predictor)
    return values

reset_copy()

深度复制模块并重置所有参数。

源代码位于 dspy/primitives/module.py
def reset_copy(self):
    """Return a copy made with `self.deepcopy()` after calling `reset()` on each of its parameters."""
    clone = self.deepcopy()
    for parameter in clone.parameters():
        parameter.reset()
    return clone

save(path, save_program=False)

保存模块。

将模块保存到目录或文件。有两种模式:
- save_program=False:仅将模块的状态保存到 json 或 pickle 文件,具体取决于文件扩展名。
- save_program=True:通过 cloudpickle 将整个模块保存到目录,其中包含模型的状态和架构。

我们还保存了依赖版本,以便加载的模型可以检查关键依赖项或 DSPy 版本是否存在版本不匹配。

参数

| 名称 | 类型 | 描述 | 默认值 |
| --- | --- | --- | --- |
| path | str | 保存状态文件的路径,当 save_program=False 时应为 .json 或 .pkl 文件,当 save_program=True 时应为目录。 | 必需 |
| save_program | bool | 如果为 True,通过 cloudpickle 将整个模块保存到目录,否则仅保存状态。 | False |

源代码位于 dspy/primitives/module.py
def save(self, path, save_program=False):
    """Persist the module, either as state only or as a whole pickled program.

    There are two modes:
    - `save_program=False`: write only the module state (the `dump_state()` result plus metadata) to
        a json or pickle file, chosen by the file extension.
    - `save_program=True`: write the whole module to a directory via cloudpickle (`program.pkl`),
        next to a `metadata.json` file.

    The metadata holds the dependency versions, so that a later load can detect a version mismatch
    on critical dependencies or the DSPy version.

    Args:
        path (str): Path to the saved state file, which should be a .json or .pkl file when
            `save_program=False`, and a directory when `save_program=True`.
        save_program (bool): If True, save the whole module to a directory via cloudpickle,
            otherwise only save the state.

    Raises:
        ValueError: If `path` has a suffix while `save_program=True`, or does not end with `.json`
            or `.pkl` while `save_program=False`.
        NotADirectoryError: If `save_program=True` and `path` exists but is not a directory.
        RuntimeError: If pickling the program fails, or writing the JSON state fails.
    """
    metadata = {"dependency_versions": get_dependency_versions()}
    path = Path(path)

    if not save_program:
        # State-only mode: the file extension selects the serialization format.
        state = self.dump_state()
        state["metadata"] = metadata
        suffix = path.suffix

        if suffix == ".pkl":
            with open(path, "wb") as f:
                cloudpickle.dump(state, f)
        elif suffix == ".json":
            try:
                with open(path, "w") as f:
                    f.write(ujson.dumps(state, indent=2))
            except Exception as e:
                raise RuntimeError(
                    f"Failed to save state to {path} with error: {e}. Your DSPy program may contain non "
                    "json-serializable objects, please consider saving the state in .pkl by using `path` ending "
                    "with `.pkl`, or saving the whole program by setting `save_program=True`."
                )
        else:
            raise ValueError(f"`path` must end with `.json` or `.pkl` when `save_program=False`, but received: {path}")
        return

    # Whole-program mode: `path` must name a directory.
    if path.suffix:
        raise ValueError(
            f"`path` must point to a directory without a suffix when `save_program=True`, but received: {path}"
        )

    if path.exists():
        if not path.is_dir():
            raise NotADirectoryError(f"The path '{path}' exists but is not a directory.")
    else:
        # Create the directory (and any parent directories)
        path.mkdir(parents=True)

    program_file = path / "program.pkl"
    try:
        with open(program_file, "wb") as f:
            cloudpickle.dump(self, f)
    except Exception as e:
        raise RuntimeError(
            f"Saving failed with error: {e}. Please remove the non-picklable attributes from your DSPy program, "
            "or consider using state-only saving by setting `save_program=False`."
        )

    with open(path / "metadata.json", "w") as f:
        ujson.dump(metadata, f, indent=2)

set_lm(lm)

源代码位于 dspy/primitives/program.py
def set_lm(self, lm):
    """Assign `lm` to the `lm` attribute of every predictor returned by `named_predictors()`."""
    for _name, predictor in self.named_predictors():
        predictor.lm = lm

truncate_trajectory(trajectory)

截断轨迹,使其适应上下文窗口。

用户可以覆盖此方法以实现自己的截断逻辑。

源代码位于 dspy/predict/react.py
def truncate_trajectory(self, trajectory):
    """Drop the oldest tool call from the trajectory so that it fits in the context window.

    The first four keys of `trajectory` are removed in place and the same dict is returned.
    Users can override this method to implement their own truncation logic.

    Raises:
        ValueError: If the trajectory has fewer than four keys.
    """
    ordered_keys = [*trajectory.keys()]

    # Every tool call has 4 keys: thought, tool_name, tool_args, and observation.
    if len(ordered_keys) >= 4:
        for stale_key in ordered_keys[:4]:
            trajectory.pop(stale_key)
        return trajectory

    raise ValueError(
        "The trajectory is too long so your prompt exceeded the context window, but the trajectory cannot be "
        "truncated because it only has one tool call."
    )