跳到内容

dspy.Predict

dspy.Predict(signature, callbacks=None, **config)

基类: Module, Parameter

源代码位于 dspy/predict/predict.py
def __init__(self, signature, callbacks=None, **config):
    """Create a predictor bound to ``signature``.

    Args:
        signature: Any value accepted by ``ensure_signature``; the normalized result
            is stored on ``self.signature``.
        callbacks: Optional callbacks; ``None`` (or any other falsy value) is replaced
            by a new empty list.
        **config: Extra keyword configuration, stored as-is on ``self.config``.
    """
    # Random 16-hex-character tag identifying this instance.
    self.stage = random.randbytes(8).hex()
    self.signature = ensure_signature(signature)
    self.config = config
    self.callbacks = callbacks if callbacks else []
    # Initialize the per-predictor state via reset().
    self.reset()

方法

__call__(*args, **kwargs)

源代码位于 dspy/predict/predict.py
def __call__(self, *args, **kwargs):
    """Invoke the predictor; only keyword arguments are accepted.

    Raises:
        ValueError: If any positional argument is given. The message comes from
            ``self._get_positional_args_error_message()``.
    """
    if not args:
        return super().__call__(**kwargs)
    raise ValueError(self._get_positional_args_error_message())

acall(*args, **kwargs) 异步

源代码位于 dspy/predict/predict.py
async def acall(self, *args, **kwargs):
    """Async counterpart of ``__call__``; only keyword arguments are accepted.

    Raises:
        ValueError: If any positional argument is given. The message comes from
            ``self._get_positional_args_error_message()``.
    """
    if not args:
        return await super().acall(**kwargs)
    raise ValueError(self._get_positional_args_error_message())

aforward(**kwargs) 异步

源代码位于 dspy/predict/predict.py
async def aforward(self, **kwargs):
    """Asynchronously run the predictor on keyword inputs.

    Inputs are normalized by ``_forward_preprocess``. The configured adapter
    (``settings.adapter``, falling back to ``ChatAdapter()``) is awaited via
    ``acall``, and the completions are turned into the result by
    ``_forward_postprocess``.
    """
    lm, config, signature, demos, kwargs = self._forward_preprocess(**kwargs)

    adapter = settings.adapter or ChatAdapter()

    # When streaming, expose this predictor as ``caller_predict``;
    # otherwise run the adapter with ``send_stream`` set to None.
    if self._should_stream():
        call_context = settings.context(caller_predict=self)
    else:
        call_context = settings.context(send_stream=None)

    with call_context:
        completions = await adapter.acall(lm, lm_kwargs=config, signature=signature, demos=demos, inputs=kwargs)

    return self._forward_postprocess(completions, signature, **kwargs)

batch(examples, num_threads: Optional[int] = None, max_errors: int = 10, return_failed_examples: bool = False, provide_traceback: Optional[bool] = None, disable_progress_bar: bool = False)

使用 Parallel 模块并行处理 dspy.Example 实例列表。

参数:

- examples:要处理的 dspy.Example 实例列表。
- num_threads:用于并行处理的线程数。
- max_errors:在停止执行前允许的最大错误数。
- return_failed_examples:是否同时返回失败的例子及其异常。
- provide_traceback:是否在错误日志中包含追溯信息。
- disable_progress_bar:是否禁用进度条。

返回:结果列表;当 return_failed_examples 为 True 时,返回 (结果, 失败的例子, 异常) 三元组。

源代码位于 dspy/primitives/program.py
def batch(
    self,
    examples,
    num_threads: Optional[int] = None,
    max_errors: int = 10,
    return_failed_examples: bool = False,
    provide_traceback: Optional[bool] = None,
    disable_progress_bar: bool = False,
):
    """Run this module over many examples in parallel using ``Parallel``.

    Args:
        examples: Sequence of ``dspy.Example`` objects; each contributes only
            ``example.inputs()`` to the call.
        num_threads: Number of worker threads, forwarded to ``Parallel``.
        max_errors: Maximum number of errors tolerated before stopping.
        return_failed_examples: Whether to also return failed examples and their exceptions.
        provide_traceback: Whether error logs should include tracebacks.
        disable_progress_bar: Forwarded to ``Parallel`` unchanged.

    Returns:
        The list of results. When ``return_failed_examples`` is True, returns the
        tuple ``(results, failed_examples, exceptions)`` instead.
    """
    # Each task pairs this module with the input fields of one example.
    tasks = [(self, example.inputs()) for example in examples]

    executor = Parallel(
        num_threads=num_threads,
        max_errors=max_errors,
        return_failed_examples=return_failed_examples,
        provide_traceback=provide_traceback,
        disable_progress_bar=disable_progress_bar,
    )
    outcome = executor.forward(tasks)

    if not return_failed_examples:
        return outcome

    # Parallel returns a 3-part result in this mode; unpack it to enforce that shape.
    results, failed_examples, exceptions = outcome
    return results, failed_examples, exceptions

deepcopy()

深度复制该模块。

它首先尝试对整个模块执行 copy.deepcopy;若失败,则创建一个新实例并逐个复制属性:子模块(BaseModule)递归调用自身的 deepcopy(),其他属性先尝试深度复制,失败时退回浅复制,若浅复制也失败则直接复用原引用。

源代码位于 dspy/primitives/module.py
def deepcopy(self):
    """Return a deep copy of this module.

    The whole object is first copied with ``copy.deepcopy``. If that fails, a new
    instance is created via ``__new__`` (``__init__`` is not run), and each entry of
    ``self.__dict__`` is copied onto it as follows:

    - ``BaseModule`` attributes are copied with their own ``deepcopy()``.
    - Other attributes are deep-copied.
    - If a deep copy fails, a warning is logged and a shallow copy is tried.
    - If the shallow copy also fails, the original reference is reused.
    """
    try:
        # Fast path: the whole object is deep-copyable.
        return copy.deepcopy(self)
    except Exception:
        pass  # Fall through to attribute-by-attribute copying.

    cls = self.__class__
    clone = cls.__new__(cls)
    for attr_name, attr_value in self.__dict__.items():
        if isinstance(attr_value, BaseModule):
            setattr(clone, attr_name, attr_value.deepcopy())
            continue

        try:
            setattr(clone, attr_name, copy.deepcopy(attr_value))
        except Exception:
            logging.warning(
                f"Failed to deep copy attribute '{attr_name}' of {self.__class__.__name__}, "
                "falling back to shallow copy or reference copy."
            )
            try:
                setattr(clone, attr_name, copy.copy(attr_value))
            except Exception:
                # Last resort: share the reference with the original.
                setattr(clone, attr_name, attr_value)

    return clone

dump_state()

源代码位于 dspy/predict/predict.py
def dump_state(self):
    """Serialize this predictor's state into a plain dictionary.

    The resulting keys, in order, are:

    - ``traces`` and ``train``: the attribute values, taken as-is.
    - ``demos``: a copy of each demo with every field passed through ``serialize_object``.
    - ``signature``: the output of ``self.signature.dump_state()``.
    - ``lm``: ``self.lm.dump_state()``, or ``None`` when ``self.lm`` is falsy.
    """
    state = {key: getattr(self, key) for key in ("traces", "train")}

    serialized_demos = []
    for original_demo in self.demos:
        # Work on a copy so the live demo is not mutated by serialization.
        demo_copy = original_demo.copy()
        for field_name in demo_copy:
            # FIXME: Saving BaseModels as strings in examples doesn't matter because you never re-access as an object
            demo_copy[field_name] = serialize_object(demo_copy[field_name])
        serialized_demos.append(demo_copy)
    state["demos"] = serialized_demos

    state["signature"] = self.signature.dump_state()
    state["lm"] = self.lm.dump_state() if self.lm else None
    return state

forward(**kwargs)

源代码位于 dspy/predict/predict.py
def forward(self, **kwargs):
    """Run the predictor synchronously on keyword inputs.

    Inputs are normalized by ``_forward_preprocess``. The configured adapter
    (``settings.adapter``, falling back to ``ChatAdapter()``) is called, and the
    completions are turned into the result by ``_forward_postprocess``.
    """
    lm, config, signature, demos, kwargs = self._forward_preprocess(**kwargs)

    adapter = settings.adapter or ChatAdapter()

    # When streaming, expose this predictor as ``caller_predict``;
    # otherwise run the adapter with ``send_stream`` set to None.
    if self._should_stream():
        call_context = settings.context(caller_predict=self)
    else:
        call_context = settings.context(send_stream=None)

    with call_context:
        completions = adapter(lm, lm_kwargs=config, signature=signature, demos=demos, inputs=kwargs)

    return self._forward_postprocess(completions, signature, **kwargs)

get_config()

源代码位于 dspy/predict/predict.py
def get_config(self):
    """Return the keyword configuration stored on this predictor.

    The stored object itself is returned, not a copy.
    """
    config = self.config
    return config

get_lm()

源代码位于 dspy/primitives/program.py
def get_lm(self):
    """Return the single LM shared by every predictor in this module.

    Returns:
        The ``lm`` attribute common to all predictors from ``named_predictors()``.

    Raises:
        ValueError: If the module has no predictors, or if its predictors use more
            than one distinct LM.
    """
    all_used_lms = [param.lm for _, param in self.named_predictors()]

    if not all_used_lms:
        # Without this guard an empty module fell through to the "Multiple LMs"
        # error below, which misdescribes the problem.
        raise ValueError("No predictors found in the module, so there is no LM to return.")

    if len(set(all_used_lms)) == 1:
        return all_used_lms[0]

    raise ValueError("Multiple LMs are being used in the module. There's no unique LM to return.")

load(path)

加载已保存的模块状态。如果想加载整个程序(而不仅仅是为现有程序加载状态),请参阅 dspy.load。

参数

| 名称 | 类型 | 描述 | 默认值 |
| --- | --- | --- | --- |
| path | str | 保存状态文件的路径,应为 .json 或 .pkl 文件 | 必需 |

源代码位于 dspy/primitives/module.py
def load(self, path):
    """Load the saved module. You may also want to check out dspy.load, if you want to
    load an entire program, not just the state for an existing program.

    If the saved dependency versions differ from the current environment, a warning
    is logged for each mismatch. The state is then applied with ``self.load_state``.

    Args:
        path (str): Path to the saved state file, which should be a .json or a .pkl file

    Raises:
        ValueError: If ``path`` does not end with ``.json`` or ``.pkl``.
    """
    path = Path(path)

    if path.suffix == ".json":
        with open(path) as f:
            state = ujson.loads(f.read())
    elif path.suffix == ".pkl":
        # NOTE: unpickling can execute arbitrary code; only load .pkl files from trusted sources.
        with open(path, "rb") as f:
            state = cloudpickle.load(f)
    else:
        raise ValueError(f"`path` must end with `.json` or `.pkl`, but received: {path}")

    dependency_versions = get_dependency_versions()
    # Tolerate state files without metadata, or recording a dependency that this
    # environment does not report, instead of failing with KeyError.
    metadata = state.get("metadata") or {}
    saved_dependency_versions = metadata.get("dependency_versions") or {}
    for key, saved_version in saved_dependency_versions.items():
        current_version = dependency_versions.get(key)
        if current_version != saved_version:
            logger.warning(
                f"There is a mismatch of {key} version between saved model and current environment. "
                f"You saved with `{key}=={saved_version}`, but now you have "
                f"`{key}=={current_version}`. This might cause errors or performance downgrade "
                "on the loaded model, please consider loading the model in the same environment as the "
                "saving environment."
            )
    self.load_state(state)

load_state(state)

加载 Predict 对象已保存的状态。

参数

| 名称 | 类型 | 描述 | 默认值 |
| --- | --- | --- | --- |
| state | dict | Predict 对象已保存的状态。 | 必需 |


返回值

| 名称 | 类型 | 描述 |
| --- | --- | --- |
| self | | 返回 self 以允许方法链式调用 |


源代码位于 dspy/predict/predict.py
def load_state(self, state):
    """Load the saved state of a `Predict` object.

    Every key except ``signature`` and ``lm`` is set directly as an attribute.
    ``signature`` is restored through ``self.signature.load_state``. ``lm`` is
    rebuilt as ``LM(**state["lm"])`` when present and truthy, and set to ``None``
    otherwise.

    Args:
        state (dict): The saved state of a `Predict` object.

    Returns:
        self: Returns self to allow method chaining

    Raises:
        NotImplementedError: If ``state`` is a legacy (DSPy <= 2.5) CoT state that
            contains ``extended_signature``. Nothing on ``self`` is modified in that case.
    """
    # Reject legacy state up front so `self` is not left half-updated.
    if "extended_signature" in state:  # legacy, up to and including 2.5, for CoT.
        raise NotImplementedError("Loading extended_signature is no longer supported in DSPy 2.6+")

    # These fields go through special handling below.
    excluded_keys = {"signature", "lm"}
    for name, value in state.items():
        if name not in excluded_keys:
            setattr(self, name, value)

    self.signature = self.signature.load_state(state["signature"])
    lm_state = state.get("lm")
    self.lm = LM(**lm_state) if lm_state else None

    return self

map_named_predictors(func)

将一个函数应用于所有命名的 Predictor。

源代码位于 dspy/primitives/program.py
def map_named_predictors(self, func):
    """Replace every named predictor with ``func(predictor)``.

    Each replacement is written back under the predictor's name using
    ``set_attribute_by_name``.

    Returns:
        ``self``, to allow chaining.
    """
    for predictor_name, predictor in self.named_predictors():
        replacement = func(predictor)
        set_attribute_by_name(self, predictor_name, replacement)
    return self

named_parameters()

与 PyTorch 不同,它也处理(非递归的)参数列表。

源代码位于 dspy/primitives/module.py
def named_parameters(self):
    """
    Unlike PyTorch, handles (non-recursive) lists of parameters too.

    Returns a list of ``(name, parameter)`` pairs.

    - ``self`` comes first, under the name ``"self"``, when it is itself a ``Parameter``.
    - Attributes are then scanned in ``__dict__`` order. ``Parameter`` values are
      collected directly.
    - ``dspy.Module`` values whose ``_compiled`` flag is not set are searched through
      their own ``named_parameters()``, with names joined by ``.``.
    - For list, tuple, and dict attributes, only the direct items are considered.

    Each parameter object appears at most once (deduplicated by ``id``), and every
    name is passed through ``postprocess_parameter_name``.
    """

    import dspy
    from dspy.predict.parameter import Parameter

    seen_ids = set()
    collected = []

    def add_parameter(param_name, param_value):
        # Only Parameters and Modules are handled here; anything else is ignored.
        if isinstance(param_value, Parameter):
            if id(param_value) in seen_ids:
                return
            seen_ids.add(id(param_value))
            collected.append((postprocess_parameter_name(param_name, param_value), param_value))
        elif isinstance(param_value, dspy.Module):
            # When a sub-module is pre-compiled, keep it frozen.
            if getattr(param_value, "_compiled", False):
                return
            for sub_name, sub_param in param_value.named_parameters():
                add_parameter(f"{param_name}.{sub_name}", sub_param)

    if isinstance(self, Parameter):
        add_parameter("self", self)

    for attr_name, attr_value in self.__dict__.items():
        if isinstance(attr_value, (Parameter, dspy.Module)):
            # Direct attributes get exactly the same treatment as nested items.
            add_parameter(attr_name, attr_value)
        elif isinstance(attr_value, (list, tuple)):
            for idx, item in enumerate(attr_value):
                add_parameter(f"{attr_name}[{idx}]", item)
        elif isinstance(attr_value, dict):
            for key, item in attr_value.items():
                add_parameter(f"{attr_name}['{key}']", item)

    return collected

named_predictors()

源代码位于 dspy/primitives/program.py
def named_predictors(self):
    """Return the ``(name, predictor)`` pairs from ``named_parameters()`` whose value is a ``Predict``."""
    from dspy.predict.predict import Predict

    predictors = []
    for name, param in self.named_parameters():
        if isinstance(param, Predict):
            predictors.append((name, param))
    return predictors

named_sub_modules(type_=None, skip_compiled=False) -> Generator[tuple[str, BaseModule], None, None]

查找模块中的所有子模块及其名称。

例如,如果 self.children[4]['key'].sub_module 是一个子模块,则名称将是 'children[4][key].sub_module'。但如果子模块可以通过不同的路径访问,则只会返回其中一个路径。

源代码位于 dspy/primitives/module.py
def named_sub_modules(self, type_=None, skip_compiled=False) -> Generator[tuple[str, "BaseModule"], None, None]:
    """Find all sub-modules in the module, as well as their names.

    Say self.children[4]['key'].sub_module is a sub-module. Then the name will be
    'children[4][key].sub_module'. But if the sub-module is accessible at different
    paths, only one of the paths will be returned.

    The search is breadth-first and starts from ``self`` (named ``"self"``). It
    expands each ``BaseModule``'s ``__dict__`` and the items of list, tuple, and
    dict containers. Each object is enqueued at most once, keyed by ``id()``.

    Args:
        type_: Only objects that are instances of this type are yielded.
            Defaults to ``BaseModule`` when ``None``.
        skip_compiled: If True, do not expand modules whose ``_compiled``
            attribute is truthy. Such a module can still be yielded itself.

    Yields:
        ``(name, item)`` tuples, starting with ``("self", self)`` when ``self``
        matches ``type_``.

    NOTE(review): child names are built with a leading ``self.`` and then passed
    through ``postprocess_parameter_name``. Whether that prefix survives in the
    final name depends on that helper; confirm against the example above.
    """
    if type_ is None:
        type_ = BaseModule

    # BFS queue of (name, object) pairs. `seen` holds the ids of everything already
    # enqueued, so shared or cyclic references are visited only once.
    queue = deque([("self", self)])
    seen = {id(self)}

    def add_to_queue(name, item):
        # The name is post-processed even when the item turns out to be a duplicate.
        name = postprocess_parameter_name(name, item)

        if id(item) not in seen:
            seen.add(id(item))
            queue.append((name, item))

    while queue:
        name, item = queue.popleft()

        if isinstance(item, type_):
            yield name, item

        if isinstance(item, BaseModule):
            # With skip_compiled, a compiled module is yielded above but not expanded.
            if skip_compiled and getattr(item, "_compiled", False):
                continue
            for sub_name, sub_item in item.__dict__.items():
                add_to_queue(f"{name}.{sub_name}", sub_item)

        elif isinstance(item, (list, tuple)):
            for i, sub_item in enumerate(item):
                add_to_queue(f"{name}[{i}]", sub_item)

        elif isinstance(item, dict):
            # Dict keys are rendered unquoted, unlike named_parameters().
            for key, sub_item in item.items():
                add_to_queue(f"{name}[{key}]", sub_item)

parameters()

源代码位于 dspy/primitives/module.py
def parameters(self):
    """Return just the parameter objects from ``named_parameters()``, in the same order."""
    return list(param for _name, param in self.named_parameters())

predictors()

源代码位于 dspy/primitives/program.py
def predictors(self):
    """Return just the predictor objects from ``named_predictors()``, in the same order."""
    return list(predictor for _name, predictor in self.named_predictors())

reset()

源代码位于 dspy/predict/predict.py
def reset(self):
    """Detach the LM and clear collected state.

    Sets ``lm`` to ``None`` and gives ``traces``, ``train`` and ``demos`` fresh,
    independent empty lists.
    """
    self.lm = None
    for attr_name in ("traces", "train", "demos"):
        setattr(self, attr_name, [])

reset_copy()

深度复制该模块并重置所有参数。

源代码位于 dspy/primitives/module.py
def reset_copy(self):
    """Return a ``deepcopy()`` of this module with ``reset()`` called on each of its parameters.

    The original module is left untouched.
    """
    fresh = self.deepcopy()

    for parameter in fresh.parameters():
        parameter.reset()

    return fresh

save(path, save_program=False)

保存模块。

将模块保存到目录或文件。有两种模式:

- save_program=False:仅将模块状态保存到 json 或 pickle 文件中,具体格式取决于文件扩展名。
- save_program=True:通过 cloudpickle 将整个模块保存到目录中,其中同时包含模型的状态和架构。

我们还保存依赖版本,以便加载的模型可以检查关键依赖或 DSPy 版本是否存在不匹配。

参数

| 名称 | 类型 | 描述 | 默认值 |
| --- | --- | --- | --- |
| path | str | 保存状态文件的路径:当 save_program=False 时应为 .json 或 .pkl 文件,当 save_program=True 时应为目录。 | 必需 |
| save_program | bool | 如果为 True,则通过 cloudpickle 将整个模块保存到目录中,否则仅保存状态。 | False |

源代码位于 dspy/primitives/module.py
def save(self, path, save_program=False):
    """Save the module.

    Save the module to a directory or a file. There are two modes:
    - `save_program=False`: Save only the state of the module to a json or pickle file, based on the value of
        the file extension.
    - `save_program=True`: Save the whole module to a directory via cloudpickle, which contains both the state and
        architecture of the model.

    We also save the dependency versions, so that the loaded model can check if there is a version mismatch on
    critical dependencies or DSPy version.

    Serialization happens in memory before the target file is opened. A
    serialization failure therefore never truncates an existing file at the
    destination.

    Args:
        path (str): Path to the saved state file, which should be a .json or .pkl file when `save_program=False`,
            and a directory when `save_program=True`.
        save_program (bool): If True, save the whole module to a directory via cloudpickle, otherwise only save
            the state.

    Raises:
        ValueError: If ``path`` has the wrong form for the selected mode.
        NotADirectoryError: If ``save_program=True`` and ``path`` exists but is not a directory.
        RuntimeError: If the program (``save_program=True``) or the JSON state cannot be written.
    """
    metadata = {"dependency_versions": get_dependency_versions()}
    path = Path(path)

    if save_program:
        if path.suffix:
            raise ValueError(
                f"`path` must point to a directory without a suffix when `save_program=True`, but received: {path}"
            )
        if path.exists() and not path.is_dir():
            raise NotADirectoryError(f"The path '{path}' exists but is not a directory.")

        # Create the directory (and any parent directories) if it is missing.
        path.mkdir(parents=True, exist_ok=True)

        try:
            # Pickle into memory first so a failure cannot leave a truncated program.pkl.
            program_bytes = cloudpickle.dumps(self)
            with open(path / "program.pkl", "wb") as f:
                f.write(program_bytes)
        except Exception as e:
            raise RuntimeError(
                f"Saving failed with error: {e}. Please remove the non-picklable attributes from your DSPy program, "
                "or consider using state-only saving by setting `save_program=False`."
            ) from e
        with open(path / "metadata.json", "w") as f:
            ujson.dump(metadata, f, indent=2)

        return

    # Validate the destination before doing the (possibly expensive) state dump.
    if path.suffix not in (".json", ".pkl"):
        raise ValueError(f"`path` must end with `.json` or `.pkl` when `save_program=False`, but received: {path}")

    state = self.dump_state()
    state["metadata"] = metadata
    if path.suffix == ".json":
        try:
            # Serialize before opening so a failure does not clobber an existing file at `path`.
            serialized = ujson.dumps(state, indent=2)
            with open(path, "w") as f:
                f.write(serialized)
        except Exception as e:
            raise RuntimeError(
                f"Failed to save state to {path} with error: {e}. Your DSPy program may contain non "
                "json-serializable objects, please consider saving the state in .pkl by using `path` ending "
                "with `.pkl`, or saving the whole program by setting `save_program=True`."
            ) from e
    else:
        # Same reasoning as above: pickle in memory, then write in one step.
        state_bytes = cloudpickle.dumps(state)
        with open(path, "wb") as f:
            f.write(state_bytes)

set_lm(lm)

源代码位于 dspy/primitives/program.py
def set_lm(self, lm):
    """Assign ``lm`` as the ``lm`` attribute of every predictor returned by ``named_predictors()``."""
    for _name, predictor in self.named_predictors():
        predictor.lm = lm

update_config(**kwargs)

源代码位于 dspy/predict/predict.py
def update_config(self, **kwargs):
    """Merge ``kwargs`` into the configuration, with ``kwargs`` winning on conflicts.

    A new dict is assigned to ``self.config``; the previous dict object is not mutated.
    """
    merged = {**self.config}
    merged.update(kwargs)
    self.config = merged