跳到内容

dspy.ChainOfThought

dspy.ChainOfThought(signature: Type[Signature], rationale_field: Optional[Union[OutputField, FieldInfo]] = None, rationale_field_type: Type = str, **config)

基类: Module

一个通过逐步推理来预测任务输出的模块。

参数

名称 类型 描述 默认值
signature Type[Signature]

模块的签名。

必需
rationale_field Optional[Union[OutputField, FieldInfo]]

包含推理内容的字段。

None
rationale_field_type Type

推理字段的类型。

str
**config

模块的配置。

{}
源代码位于 dspy/predict/chain_of_thought.py
def __init__(
    self,
    signature: Type[Signature],
    rationale_field: Optional[Union[OutputField, FieldInfo]] = None,
    rationale_field_type: Type = str,
    **config,
):
    """
    A module that reasons step by step in order to predict the output of a task.

    The given signature is extended by prepending a ``reasoning`` field, and the
    extended signature is wrapped in a ``dspy.Predict`` stored as ``self.predict``.

    Args:
        signature (Type[dspy.Signature]): The signature of the module.
        rationale_field (Optional[Union[dspy.OutputField, pydantic.fields.FieldInfo]]): The field that will
            contain the reasoning. Defaults to an output field prefixed with
            "Reasoning: Let's think step by step in order to".
        rationale_field_type (Type): The type of the rationale field. Used when ``rationale_field`` is not
            given, or when the given field carries no annotation of its own.
        **config: The configuration for the module, forwarded to ``dspy.Predict``.
    """
    super().__init__()
    signature = ensure_signature(signature)

    if rationale_field is None:
        prefix = "Reasoning: Let's think step by step in order to"
        desc = "${reasoning}"
        rationale_field = dspy.OutputField(prefix=prefix, desc=desc)
    elif rationale_field.annotation is not None:
        # A field that declares its own annotation wins over ``rationale_field_type``.
        rationale_field_type = rationale_field.annotation
    # Otherwise (e.g. a field built with ``dspy.OutputField(...)``, whose annotation is None) keep the
    # explicitly requested ``rationale_field_type`` instead of silently discarding it.

    extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type)
    self.predict = dspy.Predict(extended_signature, **config)

函数

__call__(*args, **kwargs)

源代码位于 dspy/primitives/program.py
@with_callbacks
def __call__(self, *args, **kwargs):
    """Run ``forward``, recording LM usage on the output when usage tracking is enabled.

    A new ``track_usage()`` context is opened only when ``settings.track_usage`` is set and no
    ``settings.usage_tracker`` exists yet; otherwise ``forward`` is called directly.
    """
    should_track = settings.track_usage and settings.usage_tracker is None
    if not should_track:
        return self.forward(*args, **kwargs)

    with track_usage() as tracker:
        prediction = self.forward(*args, **kwargs)
        prediction.set_lm_usage(tracker.get_total_tokens())
        return prediction

acall(*args, **kwargs) 异步

源代码位于 dspy/primitives/program.py
@with_callbacks
async def acall(self, *args, **kwargs):
    """Async counterpart of ``__call__``: await ``aforward``, recording LM usage when tracking is enabled.

    A new ``track_usage()`` context is opened only when ``settings.track_usage`` is set and no
    ``settings.usage_tracker`` exists yet; otherwise ``aforward`` is awaited directly.
    """
    should_track = settings.track_usage and settings.usage_tracker is None
    if not should_track:
        return await self.aforward(*args, **kwargs)

    with track_usage() as tracker:
        prediction = await self.aforward(*args, **kwargs)
        prediction.set_lm_usage(tracker.get_total_tokens())
        return prediction

aforward(**kwargs) 异步

源代码位于 dspy/predict/chain_of_thought.py
async def aforward(self, **kwargs):
    """Asynchronously run the wrapped ``self.predict`` on the given keyword inputs."""
    predictor = self.predict
    result = await predictor.acall(**kwargs)
    return result

batch(examples, num_threads: Optional[int] = None, max_errors: int = 10, return_failed_examples: bool = False, provide_traceback: Optional[bool] = None, disable_progress_bar: bool = False)

使用 Parallel 模块并行处理 dspy.Example 实例列表。

:param examples: 要处理的 dspy.Example 实例列表。
:param num_threads: 用于并行处理的线程数。
:param max_errors: 在停止执行前允许的最大错误数。
:param return_failed_examples: 是否返回失败的实例和异常。
:param provide_traceback: 是否在错误日志中包含回溯信息。
:param disable_progress_bar: 是否禁用进度条。
:return: 结果列表;当 return_failed_examples 为 True 时,返回 (结果, 失败实例, 异常) 三元组。

源代码位于 dspy/primitives/program.py
def batch(
    self,
    examples,
    num_threads: Optional[int] = None,
    max_errors: int = 10,
    return_failed_examples: bool = False,
    provide_traceback: Optional[bool] = None,
    disable_progress_bar: bool = False,
):
    """
    Processes a list of dspy.Example instances in parallel using the Parallel module.

    :param examples: List of dspy.Example instances to process; each contributes its ``inputs()``.
    :param num_threads: Number of threads to use for parallel processing.
    :param max_errors: Maximum number of errors allowed before stopping execution.
    :param return_failed_examples: Whether to return failed examples and exceptions.
    :param provide_traceback: Whether to include traceback information in error logs.
    :param disable_progress_bar: Whether to disable the progress bar (forwarded to ``Parallel``).
    :return: List of results, or a ``(results, failed_examples, exceptions)`` tuple when
        ``return_failed_examples`` is True.
    """
    # Each task pairs this module with one example's inputs.
    tasks = [(self, ex.inputs()) for ex in examples]

    runner = Parallel(
        num_threads=num_threads,
        max_errors=max_errors,
        return_failed_examples=return_failed_examples,
        provide_traceback=provide_traceback,
        disable_progress_bar=disable_progress_bar,
    )
    outcome = runner.forward(tasks)

    if not return_failed_examples:
        return outcome

    results, failed_examples, exceptions = outcome
    return results, failed_examples, exceptions

deepcopy()

深度复制模块。

这是对默认 Python deepcopy 的一个调整:首先尝试对整个模块执行 copy.deepcopy;如果失败,则创建一个不调用 __init__ 的新实例并逐个复制属性——子模块递归调用其 deepcopy(),其他属性先尝试深度复制,失败时退回浅复制,仍失败则直接复用原引用。

源代码位于 dspy/primitives/module.py
def deepcopy(self):
    """Deep copy the module.

    First tries ``copy.deepcopy`` on the whole module. If that fails, a new instance is created
    without calling ``__init__`` and the attributes in ``self.__dict__`` are copied one by one:
    ``BaseModule`` attributes recurse through their own ``deepcopy()``; other attributes are
    deep-copied, falling back to a shallow copy, and finally to sharing the original reference
    when even a shallow copy fails.
    """
    try:
        # If the instance itself is copyable, we can just deep copy it.
        # Otherwise we will have to create a new instance and copy over the attributes one by one.
        return copy.deepcopy(self)
    except Exception:
        # Deliberate best-effort: fall through to the attribute-by-attribute copy below.
        pass

    # Create an empty instance (``__init__`` is intentionally not run).
    new_instance = self.__class__.__new__(self.__class__)
    # Set attributes of the copied instance.
    for attr, value in self.__dict__.items():
        if isinstance(value, BaseModule):
            setattr(new_instance, attr, value.deepcopy())
        else:
            try:
                # Try to deep copy the attribute
                setattr(new_instance, attr, copy.deepcopy(value))
            except Exception:
                # Use the module logger (as ``load`` does): the root-level ``logging.warning`` implicitly
                # calls ``logging.basicConfig()``, reconfiguring the host application's root logger.
                logger.warning(
                    "Failed to deep copy attribute '%s' of %s, falling back to shallow copy or reference copy.",
                    attr,
                    self.__class__.__name__,
                )
                try:
                    # Fallback to shallow copy if deep copy fails
                    setattr(new_instance, attr, copy.copy(value))
                except Exception:
                    # If even the shallow copy fails, we just copy over the reference.
                    setattr(new_instance, attr, value)

    return new_instance

dump_state()

源代码位于 dspy/primitives/module.py
def dump_state(self):
    """Return a dict mapping each named parameter to the result of its own ``dump_state()``."""
    state = {}
    for param_name, parameter in self.named_parameters():
        state[param_name] = parameter.dump_state()
    return state

forward(**kwargs)

源代码位于 dspy/predict/chain_of_thought.py
def forward(self, **kwargs):
    """Run the wrapped ``self.predict`` on the given keyword inputs and return its result."""
    predictor = self.predict
    result = predictor(**kwargs)
    return result

get_lm()

源代码位于 dspy/primitives/program.py
def get_lm(self):
    """Return the single LM shared by every predictor in the module.

    Raises:
        ValueError: If the module has no predictors, or if its predictors use different LMs.
    """
    all_used_lms = [param.lm for _, param in self.named_predictors()]

    # Without this guard an empty module fell through to the "Multiple LMs" error, which is misleading.
    if not all_used_lms:
        raise ValueError("No predictors found in the module. There's no LM to return.")

    if len(set(all_used_lms)) == 1:
        return all_used_lms[0]

    raise ValueError("Multiple LMs are being used in the module. There's no unique LM to return.")

load(path)

加载已保存的模块。如果想加载整个程序而不仅仅是现有程序的状态,也可以查看 dspy.load。

参数

名称 类型 描述 默认值
path str

已保存状态文件的路径,应为 .json 或 .pkl 文件

必需
源代码位于 dspy/primitives/module.py
def load(self, path):
    """Load the saved module. You may also want to check out dspy.load, if you want to
    load an entire program, not just the state for an existing program.

    Args:
        path (str): Path to the saved state file, which should be a .json or a .pkl file

    Raises:
        ValueError: If ``path`` does not end with ``.json`` or ``.pkl``.
    """
    path = Path(path)

    if path.suffix == ".json":
        with open(path, encoding="utf-8") as f:
            state = ujson.loads(f.read())
    elif path.suffix == ".pkl":
        # NOTE(security): unpickling can execute arbitrary code; only load .pkl files from trusted sources.
        with open(path, "rb") as f:
            state = cloudpickle.load(f)
    else:
        raise ValueError(f"`path` must end with `.json` or `.pkl`, but received: {path}")

    dependency_versions = get_dependency_versions()
    # The version check below only emits warnings, so a state saved without metadata (or listing a
    # dependency the current environment doesn't report) should still load instead of raising KeyError.
    saved_dependency_versions = state.get("metadata", {}).get("dependency_versions", {})
    for key, saved_version in saved_dependency_versions.items():
        current_version = dependency_versions.get(key)
        if current_version != saved_version:
            logger.warning(
                f"There is a mismatch of {key} version between saved model and current environment. "
                f"You saved with `{key}=={saved_version}`, but now you have "
                f"`{key}=={current_version}`. This might cause errors or performance downgrade "
                "on the loaded model, please consider loading the model in the same environment as the "
                "saving environment."
            )
    self.load_state(state)

load_state(state)

源代码位于 dspy/primitives/module.py
def load_state(self, state):
    """Restore every named parameter from ``state[name]``.

    A missing entry for a parameter propagates whatever ``state[name]`` raises (KeyError for a dict).
    Extra entries in ``state`` are ignored.
    """
    named = self.named_parameters()
    for param_name, parameter in named:
        parameter.load_state(state[param_name])

map_named_predictors(func)

对所有命名的预测器应用一个函数。

源代码位于 dspy/primitives/program.py
def map_named_predictors(self, func):
    """Applies a function to all named predictors.

    Each predictor is replaced, at its attribute path, by ``func(predictor)``. Returns ``self``.
    """
    pairs = self.named_predictors()
    for attr_path, predictor in pairs:
        replacement = func(predictor)
        set_attribute_by_name(self, attr_path, replacement)
    return self

named_parameters()

与 PyTorch 不同,它也处理(非递归的)参数列表。

源代码位于 dspy/primitives/module.py
def named_parameters(self):
    """
    Unlike PyTorch, handles (non-recursive) lists of parameters too.

    Returns a list of ``(name, parameter)`` pairs:
    - ``Parameter`` attributes are listed under their attribute name;
    - parameters of ``dspy.Module`` attributes are listed as ``"<attr>.<sub_name>"``;
    - items of a list/tuple attribute as ``"<attr>[<idx>]"``, items of a dict attribute as
      ``"<attr>['<key>']"``.
    Sub-modules whose ``_compiled`` flag is truthy are skipped (kept frozen). A parameter object
    reachable through several paths is reported once, under the first path visited. Names are passed
    through ``postprocess_parameter_name`` before being recorded.
    """

    import dspy
    from dspy.predict.parameter import Parameter

    seen_ids = set()
    collected = []

    def add_parameter(param_name, param_value):
        # Parameters are leaves: record each distinct object only once.
        if isinstance(param_value, Parameter):
            if id(param_value) in seen_ids:
                return
            seen_ids.add(id(param_value))
            collected.append((postprocess_parameter_name(param_name, param_value), param_value))
            return

        # Recurse into sub-modules, unless they are pre-compiled (those stay frozen).
        if isinstance(param_value, dspy.Module) and not getattr(param_value, "_compiled", False):
            for sub_name, sub_param in param_value.named_parameters():
                add_parameter(f"{param_name}.{sub_name}", sub_param)

    if isinstance(self, Parameter):
        add_parameter("self", self)

    for name, value in self.__dict__.items():
        if isinstance(value, (Parameter, dspy.Module)):
            # add_parameter already distinguishes Parameters from sub-modules.
            add_parameter(name, value)
        elif isinstance(value, (list, tuple)):
            for idx, item in enumerate(value):
                add_parameter(f"{name}[{idx}]", item)
        elif isinstance(value, dict):
            for key, item in value.items():
                add_parameter(f"{name}['{key}']", item)

    return collected

named_predictors()

源代码位于 dspy/primitives/program.py
def named_predictors(self):
    """Return the ``(name, parameter)`` pairs from ``named_parameters()`` whose parameter is a ``Predict``."""
    from dspy.predict.predict import Predict

    predictors = []
    for param_name, param in self.named_parameters():
        if isinstance(param, Predict):
            predictors.append((param_name, param))
    return predictors

named_sub_modules(type_=None, skip_compiled=False) -> Generator[tuple[str, BaseModule], None, None]

查找模块中的所有子模块及其名称。

例如 self.children[4]['key'].sub_module 是一个子模块。那么名称将是 'children[4][key].sub_module'。但如果子模块可以通过不同的路径访问,则只返回其中一条路径。

源代码位于 dspy/primitives/module.py
def named_sub_modules(self, type_=None, skip_compiled=False) -> Generator[tuple[str, "BaseModule"], None, None]:
    """Find all sub-modules in the module, as well as their names.

    Say self.children[4]['key'].sub_module is a sub-module. Then the name will be
    'children[4][key].sub_module'. But if the sub-module is accessible at different
    paths, only one of the paths will be returned.

    The traversal is breadth-first (FIFO queue) starting from ``"self"``. Only objects that are
    instances of ``type_`` (default ``BaseModule``) are yielded. With ``skip_compiled=True``, a module
    whose ``_compiled`` flag is truthy is still yielded but its attributes are not explored.
    """
    wanted = BaseModule if type_ is None else type_

    pending = deque()
    pending.append(("self", self))
    visited_ids = {id(self)}

    def enqueue(path, obj):
        path = postprocess_parameter_name(path, obj)
        if id(obj) in visited_ids:
            return
        visited_ids.add(id(obj))
        pending.append((path, obj))

    while pending:
        path, obj = pending.popleft()

        if isinstance(obj, wanted):
            yield path, obj

        if isinstance(obj, BaseModule):
            if skip_compiled and getattr(obj, "_compiled", False):
                continue
            for attr_name, child in obj.__dict__.items():
                enqueue(f"{path}.{attr_name}", child)
        elif isinstance(obj, (list, tuple)):
            for position, child in enumerate(obj):
                enqueue(f"{path}[{position}]", child)
        elif isinstance(obj, dict):
            for dict_key, child in obj.items():
                enqueue(f"{path}[{dict_key}]", child)

parameters()

源代码位于 dspy/primitives/module.py
def parameters(self):
    """Return the module's parameters (the values of ``named_parameters()``), without their names."""
    params = []
    for _name, param in self.named_parameters():
        params.append(param)
    return params

predictors()

源代码位于 dspy/primitives/program.py
def predictors(self):
    """Return the module's predictors (the values of ``named_predictors()``), without their names."""
    found = []
    for _name, predictor in self.named_predictors():
        found.append(predictor)
    return found

reset_copy()

深度复制模块并重置所有参数。

源代码位于 dspy/primitives/module.py
def reset_copy(self):
    """Deep copy the module and reset all parameters.

    ``self.deepcopy()`` is called first, and ``reset()`` is then invoked on every parameter of the copy.
    The original module's parameters are left untouched.
    """
    fresh = self.deepcopy()
    for parameter in fresh.parameters():
        parameter.reset()
    return fresh

save(path, save_program=False)

保存模块。

将模块保存到目录或文件。有两种模式:
- save_program=False: 只将模块的状态保存到 json 或 pickle 文件,具体取决于文件扩展名。
- save_program=True: 通过 cloudpickle 将整个模块保存到目录,其中包含模型的状态和架构。

我们还保存了依赖版本,以便加载的模型可以检查是否存在关键依赖项或 DSPy 版本的版本不匹配。

参数

名称 类型 描述 默认值
path str

已保存状态文件的路径,当 save_program=False 时应为 .json 或 .pkl 文件,当 save_program=True 时应为目录。

必需
save_program bool

如果为 True,则通过 cloudpickle 将整个模块保存到目录,否则只保存状态。

False
源代码位于 dspy/primitives/module.py
def save(self, path, save_program=False):
    """Save the module.

    Save the module to a directory or a file. There are two modes:
    - `save_program=False`: Save only the state of the module to a json or pickle file, based on the value of
        the file extension.
    - `save_program=True`: Save the whole module to a directory via cloudpickle, which contains both the state and
        architecture of the model.

    We also save the dependency versions, so that the loaded model can check if there is a version mismatch on
    critical dependencies or DSPy version.

    Args:
        path (str): Path to the saved state file, which should be a .json or .pkl file when `save_program=False`,
            and a directory when `save_program=True`.
        save_program (bool): If True, save the whole module to a directory via cloudpickle, otherwise only save
            the state.

    Raises:
        ValueError: If `path` has a suffix when `save_program=True`, or a suffix other than `.json`/`.pkl`
            when `save_program=False`.
        NotADirectoryError: If `save_program=True` and `path` exists but is not a directory.
        RuntimeError: If pickling/writing the whole program fails, or serializing/writing the `.json` state
            fails. The original exception is chained as the cause.
    """
    metadata = {"dependency_versions": get_dependency_versions()}
    path = Path(path)

    if save_program:
        if path.suffix:
            raise ValueError(
                f"`path` must point to a directory without a suffix when `save_program=True`, but received: {path}"
            )
        if path.exists() and not path.is_dir():
            raise NotADirectoryError(f"The path '{path}' exists but is not a directory.")

        # Create the directory (and any parent directories); exist_ok avoids a race with the check above.
        path.mkdir(parents=True, exist_ok=True)

        try:
            # Serialize before opening the file so a pickling failure doesn't leave a truncated program.pkl.
            payload = cloudpickle.dumps(self)
            with open(path / "program.pkl", "wb") as f:
                f.write(payload)
        except Exception as e:
            raise RuntimeError(
                f"Saving failed with error: {e}. Please remove the non-picklable attributes from your DSPy program, "
                "or consider using state-only saving by setting `save_program=False`."
            ) from e
        with open(path / "metadata.json", "w", encoding="utf-8") as f:
            ujson.dump(metadata, f, indent=2)

        return

    state = self.dump_state()
    state["metadata"] = metadata
    if path.suffix == ".json":
        try:
            # Serialize first so a non-serializable state doesn't clobber an existing file with an empty one.
            serialized = ujson.dumps(state, indent=2)
            with open(path, "w", encoding="utf-8") as f:
                f.write(serialized)
        except Exception as e:
            raise RuntimeError(
                f"Failed to save state to {path} with error: {e}. Your DSPy program may contain non "
                "json-serializable objects, please consider saving the state in .pkl by using `path` ending "
                "with `.pkl`, or saving the whole program by setting `save_program=True`."
            ) from e
    elif path.suffix == ".pkl":
        with open(path, "wb") as f:
            cloudpickle.dump(state, f)
    else:
        raise ValueError(f"`path` must end with `.json` or `.pkl` when `save_program=False`, but received: {path}")

set_lm(lm)

源代码位于 dspy/primitives/program.py
def set_lm(self, lm):
    """Assign ``lm`` as the ``lm`` attribute of every predictor returned by ``named_predictors()``."""
    predictors = [predictor for _name, predictor in self.named_predictors()]
    for predictor in predictors:
        predictor.lm = lm