跳到内容

dspy.streaming.StreamListener

dspy.streaming.StreamListener(signature_field_name: str, predict: Any = None, predict_name: Optional[str] = None)

监听流以捕获预测器特定输出字段流式传输的类。

参数

名称 类型 描述 默认值
signature_field_name str

要监听的字段名称。

必需
predict Any

要监听的预测器。如果为 None,在调用 streamify() 时,它将自动查找其签名中包含 signature_field_name 的预测器。

None
predict_name Optional[str]

要监听的预测器的名称。如果为 None,在调用 streamify() 时,它将自动查找其签名中包含 signature_field_name 的预测器。

None
源代码位于 dspy/streaming/streaming_listener.py
def __init__(self, signature_field_name: str, predict: Any = None, predict_name: Optional[str] = None):
    """Listener that watches an LM stream for one predictor output field.

    Args:
        signature_field_name: The name of the field to listen to.
        predict: The predictor to listen to. If None, when calling `streamify()` it will automatically look for
            the predictor that has the `signature_field_name` in its signature.
        predict_name: The name of the predictor to listen to. If None, when calling `streamify()` it will
            automatically look for the predictor that has the `signature_field_name` in its signature.
    """
    # What we are listening for, and on which predictor.
    self.signature_field_name = signature_field_name
    self.predict = predict
    self.predict_name = predict_name

    # Adapter-specific markers that bracket the target field in the raw stream.
    self.json_adapter_start_identifier = '"%s":' % signature_field_name
    self.json_adapter_end_identifier = re.compile(r"\w*\"(,|\s*})")
    self.chat_adapter_start_identifier = "[[ ## %s ## ]]" % signature_field_name
    self.chat_adapter_end_identifier = re.compile(r"\[\[ ## (\w+) ## \]\]")

    # Scanning state: tokens buffered while matching the start marker,
    # tokens buffered while watching for the end marker, and progress flags.
    self.field_start_queue = []
    self.field_end_queue = Queue()
    self.stream_start = self.stream_end = self.cache_hit = False

函数

flush() -> str

刷新字段结束队列中的所有 tokens。

当流结束时,调用此方法来刷新(flush)最后几个 tokens。这些 tokens 保留在缓冲区中,因为我们不会直接 yield 流监听器接收到的 tokens,目的是避免 yield end_identifier tokens,例如 ChatAdapter 的 "[[ ## ... ## ]]"。

源代码位于 dspy/streaming/streaming_listener.py
def flush(self) -> str:
    """Flush all tokens buffered in the field end queue.

    This method is called to flush out the last few tokens when the stream ends. These tokens
    are held back in the buffer because we don't directly yield the tokens received by the
    stream listener, in order to avoid yielding the end_identifier tokens, e.g.,
    "[[ ## ... ## ]]" for ChatAdapter.

    Returns:
        The buffered text with any trailing end-identifier fragment stripped.

    Raises:
        ValueError: If the active adapter is neither ChatAdapter nor JSONAdapter.
    """
    last_tokens = "".join(self.field_end_queue.queue)
    # Reset the buffer so a subsequent flush does not re-emit the same tokens.
    self.field_end_queue = Queue()
    if isinstance(settings.adapter, JSONAdapter):
        # Cut at the JSON value terminator ('",' or '"}') if one is present;
        # otherwise keep the whole buffer.
        match = re.search(r'",|"\s*}', last_tokens)
        boundary_index = match.start() if match else len(last_tokens)
        return last_tokens[:boundary_index]
    elif isinstance(settings.adapter, ChatAdapter) or settings.adapter is None:
        boundary_index = last_tokens.find("[[")
        if boundary_index == -1:
            # Bug fix: str.find returns -1 when "[[" is absent, and slicing with
            # [:-1] would silently drop the final character of the flushed text.
            # Mirror the JSON branch: no marker means keep everything.
            boundary_index = len(last_tokens)
        return last_tokens[:boundary_index]
    else:
        raise ValueError(
            f"Unsupported adapter for streaming: {settings.adapter}, please use either ChatAdapter or "
            "JSONAdapter for streaming purposes."
        )

receive(chunk: ModelResponseStream)

源代码位于 dspy/streaming/streaming_listener.py
def receive(self, chunk: ModelResponseStream):
    """Consume one streamed chunk; emit tokens of the target field once it has started.

    State machine over incoming chunks: buffer tokens until the adapter-specific
    start identifier for `self.signature_field_name` is seen, then yield tokens
    (with a small lookahead buffer) until the end identifier appears.

    Args:
        chunk: A single streaming chunk from the LM.

    Returns:
        A StreamResponse carrying the next displayable token, or None when there
        is nothing to emit for this chunk (field not started, already ended, or
        the token is still held in the lookahead buffer).
    """
    # Select the start/end markers for the active adapter.
    if isinstance(settings.adapter, JSONAdapter):
        start_identifier = self.json_adapter_start_identifier
        end_identifier = self.json_adapter_end_identifier

        start_indicator = "{"
    elif isinstance(settings.adapter, ChatAdapter) or settings.adapter is None:
        start_identifier = self.chat_adapter_start_identifier
        end_identifier = self.chat_adapter_end_identifier

        start_indicator = "["
    else:
        raise ValueError(
            f"Unsupported adapter for streaming: {settings.adapter}, please use either ChatAdapter or "
            "JSONAdapter for streaming purposes."
        )

    # The field already finished streaming; ignore any further chunks.
    if self.stream_end:
        return

    # Extract the text delta; skip chunks with no content (e.g. role/None deltas).
    try:
        chunk_message = chunk.choices[0].delta.content
        if chunk_message is None:
            return
    except Exception:
        return

    if chunk_message and start_identifier in chunk_message:
        # If the cache is hit, the chunk_message could be the full response. When it happens we can
        # directly end the stream listening. In some models like gemini, each stream chunk can be multiple
        # tokens, so it's possible that the response only has one chunk; we also fall back to this logic.
        message_after_start_identifier = chunk_message[
            chunk_message.find(start_identifier) + len(start_identifier) :
        ]
        if re.search(end_identifier, message_after_start_identifier):
            self.cache_hit = True
            self.stream_start = True
            self.stream_end = True
            return

    if len(self.field_start_queue) == 0 and not self.stream_start and start_indicator in chunk_message:
        # We look for the pattern of start_identifier, i.e., "[[ ## {self.signature_field_name} ## ]]" for
        # ChatAdapter to identify the start of the stream of our target field. Once the start_indicator, i.e., "[["
        # for ChatAdapter, is found, we start checking the next tokens
        self.field_start_queue.append(chunk_message)
        return

    if len(self.field_start_queue) > 0 and not self.stream_start:
        # We keep appending the tokens to the queue until we have a full identifier or the concatenated
        # tokens no longer match our expected identifier.
        self.field_start_queue.append(chunk_message)
        concat_message = "".join(self.field_start_queue).strip()

        if start_identifier in concat_message:
            # We have a full identifier, we can start the stream.
            self.stream_start = True
            self.field_start_queue = []
            # Keep the part after the start_identifier from the concat_message, we need to write it to the buffer.
            value_start_index = concat_message.find(start_identifier) + len(start_identifier)
            chunk_message = concat_message[value_start_index:].lstrip()
            if isinstance(settings.adapter, JSONAdapter) and chunk_message.startswith('"'):
                # For JSONAdapter, we need to remove the leading ". We cannot do this with the start_identifier
                # because there could be a few splitters between ':' and '"', e.g., '"name": "value"'.
                chunk_message = chunk_message[1:]

        elif self._buffered_message_end_with_start_identifier(concat_message, start_identifier):
            # The buffer is still a prefix of the start_identifier; keep accumulating.
            return
        else:
            # Doesn't match the expected identifier, reset the queue.
            self.field_start_queue = []
            return

    if self.stream_start:
        # The stream is started, we keep returning the token until we see the start of the next field.
        token = None
        self.field_end_queue.put(chunk_message)
        if self.field_end_queue.qsize() > 10:
            # We keep the last 10 tokens in the buffer to check if they form a valid identifier for end_identifier,
            # i.e., "[[ ## {next_field_name} ## ]]" for ChatAdapter to identify the end of the current field.
            # In most cases 10 tokens are enough to cover the end_identifier for all adapters.
            token = self.field_end_queue.get()
        concat_message = "".join(self.field_end_queue.queue).strip()
        if re.search(end_identifier, concat_message):
            # The next field is identified, we can end the stream and flush out all tokens in the buffer.
            self.stream_end = True
            last_token = self.flush()
            token = token + last_token if token else last_token
            token = token.rstrip()  # Remove the trailing \n\n
        # Emit only when a token has left the lookahead buffer (or the stream just ended).
        if token:
            return StreamResponse(self.predict_name, self.signature_field_name, token)