17 changes: 7 additions & 10 deletions src/openai/lib/_parsing/_completions.py
@@ -138,7 +138,7 @@ def parse_chat_completion(

choices.append(
construct_type_unchecked(
type_=cast(Any, ParsedChoice)[solve_response_format_t(response_format)],
type_=ParsedChoice[ResponseFormatT],
value={
**choice.to_dict(),
"message": {
@@ -153,15 +153,12 @@ def parse_chat_completion(
)
)

return cast(
ParsedChatCompletion[ResponseFormatT],
construct_type_unchecked(
type_=cast(Any, ParsedChatCompletion)[solve_response_format_t(response_format)],
value={
**chat_completion.to_dict(),
"choices": choices,
},
),
return construct_type_unchecked(
type_=ParsedChatCompletion[ResponseFormatT],
value={
**chat_completion.to_dict(),
"choices": choices,
},
)


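The change above swaps the runtime `cast(Any, ParsedChoice)[solve_response_format_t(response_format)]` subscription for a plain `ParsedChoice[ResponseFormatT]` alias. As a minimal standalone sketch (a toy class, not the library's `ParsedChoice`), subscripting a generic class with its own TypeVar is valid at runtime and needs no cast:

```python
from typing import Generic, TypeVar

ResponseFormatT = TypeVar("ResponseFormatT")

class ParsedChoiceSketch(Generic[ResponseFormatT]):
    # Toy stand-in for a parsed-choice container.
    def __init__(self, parsed: ResponseFormatT) -> None:
        self.parsed = parsed

# Subscripting with the TypeVar itself produces an ordinary generic alias.
alias = ParsedChoiceSketch[ResponseFormatT]
print(alias)                    # __main__.ParsedChoiceSketch[~ResponseFormatT]
print(alias(parsed=42).parsed)  # 42 -- the alias still constructs the class
```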
24 changes: 10 additions & 14 deletions src/openai/lib/_parsing/_responses.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, List, Iterable, cast
from typing import TYPE_CHECKING, List, Iterable, cast
from typing_extensions import TypeVar, assert_never

import pydantic
@@ -12,7 +12,7 @@
from ..._compat import PYDANTIC_V1, model_parse_json
from ..._models import construct_type_unchecked
from .._pydantic import is_basemodel_type, is_dataclass_like_type
from ._completions import solve_response_format_t, type_to_response_format_param
from ._completions import type_to_response_format_param
from ...types.responses import (
Response,
ToolParam,
@@ -56,7 +56,6 @@ def parse_response(
input_tools: Iterable[ToolParam] | Omit | None,
response: Response | ParsedResponse[object],
) -> ParsedResponse[TextFormatT]:
solved_t = solve_response_format_t(text_format)
output_list: List[ParsedResponseOutputItem[TextFormatT]] = []

for output in response.output:
@@ -69,7 +68,7 @@

content_list.append(
construct_type_unchecked(
type_=cast(Any, ParsedResponseOutputText)[solved_t],
type_=ParsedResponseOutputText[TextFormatT],
value={
**item.to_dict(),
"parsed": parse_text(item.text, text_format=text_format),
@@ -79,7 +78,7 @@

output_list.append(
construct_type_unchecked(
type_=cast(Any, ParsedResponseOutputMessage)[solved_t],
type_=ParsedResponseOutputMessage[TextFormatT],
value={
**output.to_dict(),
"content": content_list,
@@ -118,15 +117,12 @@ def parse_response(
else:
output_list.append(output)

return cast(
ParsedResponse[TextFormatT],
construct_type_unchecked(
type_=cast(Any, ParsedResponse)[solved_t],
value={
**response.to_dict(),
"output": output_list,
},
),
return construct_type_unchecked(
type_=ParsedResponse[TextFormatT],
value={
**response.to_dict(),
"output": output_list,
},
)


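For orientation, _parsing/_responses.py sits behind the structured-output path of the Responses API. A hedged usage sketch of what parse_response ultimately powers, with parameter names taken from the current openai-python surface and an illustrative model name:

```python
from pydantic import BaseModel
from openai import OpenAI

class Location(BaseModel):
    city: str
    temperature: float
    units: str

client = OpenAI()

# client.responses.parse() routes its result through parse_response(), which
# attaches the validated Location instance to the parsed output fields.
response = client.responses.parse(
    model="gpt-4o-2024-08-06",
    input="What's the weather like in San Francisco?",
    text_format=Location,
)
print(response.output_parsed)  # Location(city='San Francisco', ...)
```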
11 changes: 2 additions & 9 deletions src/openai/lib/streaming/chat/_completions.py
@@ -2,7 +2,7 @@

import inspect
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, Callable, Iterable, Awaitable, AsyncIterator, cast
from typing import TYPE_CHECKING, Generic, Callable, Iterable, Awaitable, AsyncIterator, cast
from typing_extensions import Self, Iterator, assert_never

from jiter import from_json
@@ -33,7 +33,6 @@
maybe_parse_content,
parse_chat_completion,
get_input_tool_by_name,
solve_response_format_t,
parse_function_tool_arguments,
)
from ...._streaming import Stream, AsyncStream
@@ -658,13 +657,7 @@ def _content_done_events(

events_to_fire.append(
build(
# we do this dance so that when the `ContentDoneEvent` instance
# is printed at runtime the class name will include the solved
# type variable, e.g. `ContentDoneEvent[MyModelType]`
cast( # pyright: ignore[reportUnnecessaryCast]
"type[ContentDoneEvent[ResponseFormatT]]",
cast(Any, ContentDoneEvent)[solve_response_format_t(response_format)],
),
ContentDoneEvent[ResponseFormatT],
type="content.done",
content=choice_snapshot.message.content,
parsed=parsed,
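The comment deleted above explains why the cast dance existed at all: it made the event's repr show the solved type variable. A minimal pydantic v2 sketch (toy classes, assumed names) of that repr difference, which also explains the snapshot changes in the tests below:

```python
from typing import Generic, TypeVar
import pydantic  # assumes pydantic v2

T = TypeVar("T")

class ContentDoneSketch(pydantic.BaseModel, Generic[T]):
    parsed: T

class City(pydantic.BaseModel):
    name: str

# Concretely parameterized: the class name carries the type argument.
print(ContentDoneSketch[City](parsed=City(name="SF")))
# -> ContentDoneSketch[City](parsed=City(name='SF'))

# Unparameterized (what the simplified code now builds): bare class name.
print(ContentDoneSketch(parsed=City(name="SF")))
# -> ContentDoneSketch(parsed=City(name='SF'))
```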
72 changes: 36 additions & 36 deletions tests/lib/chat/test_completions.py
@@ -50,13 +50,13 @@ def test_parse_nothing(client: OpenAI, respx_mock: MockRouter, monkeypatch: pyte

assert print_obj(completion, monkeypatch) == snapshot(
"""\
ParsedChatCompletion[NoneType](
ParsedChatCompletion(
choices=[
ParsedChoice[NoneType](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[NoneType](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content="I'm unable to provide real-time weather updates. To get the current weather in San Francisco, I
@@ -120,13 +120,13 @@ class Location(BaseModel):

assert print_obj(completion, monkeypatch) == snapshot(
"""\
ParsedChatCompletion[Location](
ParsedChatCompletion(
choices=[
ParsedChoice[Location](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[Location](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"city":"San Francisco","temperature":65,"units":"f"}',
@@ -191,13 +191,13 @@ class Location(BaseModel):

assert print_obj(completion, monkeypatch) == snapshot(
"""\
ParsedChatCompletion[Location](
ParsedChatCompletion(
choices=[
ParsedChoice[Location](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[Location](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"city":"San Francisco","temperature":65,"units":"f"}',
@@ -266,11 +266,11 @@ class ColorDetection(BaseModel):

assert print_obj(completion.choices[0], monkeypatch) == snapshot(
"""\
ParsedChoice[ColorDetection](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[ColorDetection](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"color":"red","hex_color_code":"#FF0000"}',
@@ -317,11 +317,11 @@ class Location(BaseModel):
assert print_obj(completion.choices, monkeypatch) == snapshot(
"""\
[
ParsedChoice[Location](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[Location](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"city":"San Francisco","temperature":64,"units":"f"}',
@@ -332,11 +332,11 @@ class Location(BaseModel):
tool_calls=None
)
),
ParsedChoice[Location](
ParsedChoice(
finish_reason='stop',
index=1,
logprobs=None,
message=ParsedChatCompletionMessage[Location](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"city":"San Francisco","temperature":65,"units":"f"}',
@@ -347,11 +347,11 @@ class Location(BaseModel):
tool_calls=None
)
),
ParsedChoice[Location](
ParsedChoice(
finish_reason='stop',
index=2,
logprobs=None,
message=ParsedChatCompletionMessage[Location](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"city":"San Francisco","temperature":63.0,"units":"f"}',
@@ -397,13 +397,13 @@ class CalendarEvent:

assert print_obj(completion, monkeypatch) == snapshot(
"""\
ParsedChatCompletion[CalendarEvent](
ParsedChatCompletion(
choices=[
ParsedChoice[CalendarEvent](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[CalendarEvent](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"name":"Science Fair","date":"Friday","participants":["Alice","Bob"]}',
@@ -462,11 +462,11 @@ def test_pydantic_tool_model_all_types(client: OpenAI, respx_mock: MockRouter, m

assert print_obj(completion.choices[0], monkeypatch) == snapshot(
"""\
ParsedChoice[Query](
ParsedChoice(
finish_reason='tool_calls',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[Query](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content=None,
@@ -576,11 +576,11 @@ class Location(BaseModel):
assert print_obj(completion.choices, monkeypatch) == snapshot(
"""\
[
ParsedChoice[Location](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[Location](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content=None,
@@ -627,11 +627,11 @@ class GetWeatherArgs(BaseModel):
assert print_obj(completion.choices, monkeypatch) == snapshot(
"""\
[
ParsedChoice[NoneType](
ParsedChoice(
finish_reason='tool_calls',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[NoneType](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content=None,
@@ -701,11 +701,11 @@ class GetStockPrice(BaseModel):
assert print_obj(completion.choices, monkeypatch) == snapshot(
"""\
[
ParsedChoice[NoneType](
ParsedChoice(
finish_reason='tool_calls',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[NoneType](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content=None,
@@ -784,11 +784,11 @@ def test_parse_strict_tools(client: OpenAI, respx_mock: MockRouter, monkeypatch:
assert print_obj(completion.choices, monkeypatch) == snapshot(
"""\
[
ParsedChoice[NoneType](
ParsedChoice(
finish_reason='tool_calls',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[NoneType](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content=None,
@@ -866,13 +866,13 @@ class Location(BaseModel):
assert isinstance(message.parsed.city, str)
assert print_obj(completion, monkeypatch) == snapshot(
"""\
ParsedChatCompletion[Location](
ParsedChatCompletion(
choices=[
ParsedChoice[Location](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[Location](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"city":"San Francisco","temperature":58,"units":"f"}',
@@ -943,13 +943,13 @@ class Location(BaseModel):
assert isinstance(message.parsed.city, str)
assert print_obj(completion, monkeypatch) == snapshot(
"""\
ParsedChatCompletion[Location](
ParsedChatCompletion(
choices=[
ParsedChoice[Location](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[Location](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"city":"San Francisco","temperature":65,"units":"f"}',
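The snapshot updates in this test file all reduce to the same repr change: parsed objects now print as ParsedChatCompletion(...) rather than ParsedChatCompletion[Location](...). A hedged end-to-end sketch of the API these tests exercise (model name illustrative; older releases expose the same method as client.beta.chat.completions.parse):

```python
from pydantic import BaseModel
from openai import OpenAI

class Location(BaseModel):
    city: str
    temperature: float
    units: str

client = OpenAI()
completion = client.chat.completions.parse(
    model="gpt-4o-2024-08-06",
    messages=[{"role": "user", "content": "What's the weather like in SF?"}],
    response_format=Location,
)

message = completion.choices[0].message
if message.parsed is not None:
    # The parsed field is a validated Location instance.
    print(message.parsed.city, message.parsed.temperature, message.parsed.units)
```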