Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions src/openai/lib/streaming/responses/_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,16 @@ def handle_event(self, event: RawResponseStreamEvent) -> List[ResponseStreamEven
events: List[ResponseStreamEvent[TextFormatT]] = []

if event.type == "response.output_text.delta":
if event.output_index >= len(snapshot.output):
return events
output = snapshot.output[event.output_index]
assert output.type == "message"

if output.type != "message":
return events
if event.content_index >= len(output.content):
return events
content = output.content[event.content_index]
assert content.type == "output_text"
if content.type != "output_text":
return events

events.append(
build(
Expand All @@ -270,11 +275,16 @@ def handle_event(self, event: RawResponseStreamEvent) -> List[ResponseStreamEven
)
)
elif event.type == "response.output_text.done":
if event.output_index >= len(snapshot.output):
return events
output = snapshot.output[event.output_index]
assert output.type == "message"

if output.type != "message":
return events
if event.content_index >= len(output.content):
return events
content = output.content[event.content_index]
assert content.type == "output_text"
if content.type != "output_text":
return events

events.append(
build(
Expand All @@ -290,8 +300,11 @@ def handle_event(self, event: RawResponseStreamEvent) -> List[ResponseStreamEven
)
)
elif event.type == "response.function_call_arguments.delta":
if event.output_index >= len(snapshot.output):
return events
output = snapshot.output[event.output_index]
assert output.type == "function_call"
if output.type != "function_call":
return events

events.append(
build(
Expand Down Expand Up @@ -341,18 +354,26 @@ def accumulate_event(self, event: RawResponseStreamEvent) -> ParsedResponseSnaps
else:
snapshot.output.append(event.item)
elif event.type == "response.content_part.added":
if event.output_index >= len(snapshot.output):
return snapshot
output = snapshot.output[event.output_index]
if output.type == "message":
output.content.append(
construct_type_unchecked(type_=cast(Any, ParsedContent), value=event.part.to_dict())
)
elif event.type == "response.output_text.delta":
if event.output_index >= len(snapshot.output):
return snapshot
output = snapshot.output[event.output_index]
if output.type == "message":
if event.content_index >= len(output.content):
return snapshot
content = output.content[event.content_index]
assert content.type == "output_text"
content.text += event.delta
elif event.type == "response.function_call_arguments.delta":
if event.output_index >= len(snapshot.output):
return snapshot
output = snapshot.output[event.output_index]
if output.type == "function_call":
output.arguments += event.delta
Expand Down
Empty file.
Empty file.
47 changes: 47 additions & 0 deletions tests/lib/streaming/responses/test_responses_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

from openai._types import omit
from openai.lib.streaming.responses._responses import ResponseStreamState
from openai.types.responses import Response
from openai.types.responses.response_created_event import ResponseCreatedEvent
from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent


def _minimal_response_created_event() -> ResponseCreatedEvent:
response = Response.model_construct(
id="resp-test",
created_at=0.0,
model="gpt-4o",
object="response",
output=[],
parallel_tool_calls=False,
tool_choice="auto",
tools=[],
)
return ResponseCreatedEvent(
response=response,
sequence_number=0,
type="response.created",
)


def _delta_event_before_output_item_added() -> ResponseTextDeltaEvent:
return ResponseTextDeltaEvent(
content_index=0,
delta="x",
item_id="item-1",
logprobs=[],
output_index=0,
sequence_number=1,
type="response.output_text.delta",
)


def test_responses_stream_accumulate_handles_out_of_range_output_index() -> None:
state = ResponseStreamState(input_tools=omit, text_format=omit)

state.handle_event(_minimal_response_created_event())

events = state.handle_event(_delta_event_before_output_item_added())

assert isinstance(events, list)