Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/openai/lib/_parsing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
validate_input_tools as validate_input_tools,
parse_chat_completion as parse_chat_completion,
get_input_tool_by_name as get_input_tool_by_name,
solve_response_format_t as solve_response_format_t,
parse_function_tool_arguments as parse_function_tool_arguments,
type_to_response_format_param as type_to_response_format_param,
)
31 changes: 7 additions & 24 deletions src/openai/lib/_parsing/_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def parse_chat_completion(

choices.append(
construct_type_unchecked(
type_=cast(Any, ParsedChoice)[solve_response_format_t(response_format)],
type_=ParsedChoice[ResponseFormatT],
value={
**choice.to_dict(),
"message": {
Expand All @@ -153,15 +153,12 @@ def parse_chat_completion(
)
)

return cast(
ParsedChatCompletion[ResponseFormatT],
construct_type_unchecked(
type_=cast(Any, ParsedChatCompletion)[solve_response_format_t(response_format)],
value={
**chat_completion.to_dict(),
"choices": choices,
},
),
return construct_type_unchecked(
type_=ParsedChatCompletion[ResponseFormatT],
value={
**chat_completion.to_dict(),
"choices": choices,
},
)


Expand Down Expand Up @@ -201,20 +198,6 @@ def maybe_parse_content(
return None


def solve_response_format_t(
response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
) -> type[ResponseFormatT]:
"""Return the runtime type for the given response format.

If no response format is given, or if we won't auto-parse the response format
then we default to `None`.
"""
if has_rich_response_format(response_format):
return response_format

return cast("type[ResponseFormatT]", _default_response_format)


def has_parseable_input(
*,
response_format: type | ResponseFormatParam | Omit,
Expand Down
24 changes: 10 additions & 14 deletions src/openai/lib/_parsing/_responses.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, List, Iterable, cast
from typing import TYPE_CHECKING, List, Iterable, cast
from typing_extensions import TypeVar, assert_never

import pydantic
Expand All @@ -12,7 +12,7 @@
from ..._compat import PYDANTIC_V1, model_parse_json
from ..._models import construct_type_unchecked
from .._pydantic import is_basemodel_type, is_dataclass_like_type
from ._completions import solve_response_format_t, type_to_response_format_param
from ._completions import type_to_response_format_param
from ...types.responses import (
Response,
ToolParam,
Expand Down Expand Up @@ -56,7 +56,6 @@ def parse_response(
input_tools: Iterable[ToolParam] | Omit | None,
response: Response | ParsedResponse[object],
) -> ParsedResponse[TextFormatT]:
solved_t = solve_response_format_t(text_format)
output_list: List[ParsedResponseOutputItem[TextFormatT]] = []

for output in response.output:
Expand All @@ -69,7 +68,7 @@ def parse_response(

content_list.append(
construct_type_unchecked(
type_=cast(Any, ParsedResponseOutputText)[solved_t],
type_=ParsedResponseOutputText[TextFormatT],
value={
**item.to_dict(),
"parsed": parse_text(item.text, text_format=text_format),
Expand All @@ -79,7 +78,7 @@ def parse_response(

output_list.append(
construct_type_unchecked(
type_=cast(Any, ParsedResponseOutputMessage)[solved_t],
type_=ParsedResponseOutputMessage[TextFormatT],
value={
**output.to_dict(),
"content": content_list,
Expand Down Expand Up @@ -123,15 +122,12 @@ def parse_response(
else:
output_list.append(output)

return cast(
ParsedResponse[TextFormatT],
construct_type_unchecked(
type_=cast(Any, ParsedResponse)[solved_t],
value={
**response.to_dict(),
"output": output_list,
},
),
return construct_type_unchecked(
type_=ParsedResponse[TextFormatT],
value={
**response.to_dict(),
"output": output_list,
},
)


Expand Down
3 changes: 1 addition & 2 deletions src/openai/lib/streaming/chat/_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
maybe_parse_content,
parse_chat_completion,
get_input_tool_by_name,
solve_response_format_t,
parse_function_tool_arguments,
)
from ...._streaming import Stream, AsyncStream
Expand Down Expand Up @@ -663,7 +662,7 @@ def _content_done_events(
# type variable, e.g. `ContentDoneEvent[MyModelType]`
cast( # pyright: ignore[reportUnnecessaryCast]
"type[ContentDoneEvent[ResponseFormatT]]",
cast(Any, ContentDoneEvent)[solve_response_format_t(response_format)],
cast(Any, ContentDoneEvent),
),
type="content.done",
content=choice_snapshot.message.content,
Expand Down
72 changes: 36 additions & 36 deletions tests/lib/chat/test_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ def test_parse_nothing(client: OpenAI, respx_mock: MockRouter, monkeypatch: pyte

assert print_obj(completion, monkeypatch) == snapshot(
"""\
ParsedChatCompletion[NoneType](
ParsedChatCompletion(
choices=[
ParsedChoice[NoneType](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[NoneType](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content="I'm unable to provide real-time weather updates. To get the current weather in San Francisco, I
Expand Down Expand Up @@ -120,13 +120,13 @@ class Location(BaseModel):

assert print_obj(completion, monkeypatch) == snapshot(
"""\
ParsedChatCompletion[Location](
ParsedChatCompletion(
choices=[
ParsedChoice[Location](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[Location](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"city":"San Francisco","temperature":65,"units":"f"}',
Expand Down Expand Up @@ -191,13 +191,13 @@ class Location(BaseModel):

assert print_obj(completion, monkeypatch) == snapshot(
"""\
ParsedChatCompletion[Location](
ParsedChatCompletion(
choices=[
ParsedChoice[Location](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[Location](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"city":"San Francisco","temperature":65,"units":"f"}',
Expand Down Expand Up @@ -266,11 +266,11 @@ class ColorDetection(BaseModel):

assert print_obj(completion.choices[0], monkeypatch) == snapshot(
"""\
ParsedChoice[ColorDetection](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[ColorDetection](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"color":"red","hex_color_code":"#FF0000"}',
Expand Down Expand Up @@ -317,11 +317,11 @@ class Location(BaseModel):
assert print_obj(completion.choices, monkeypatch) == snapshot(
"""\
[
ParsedChoice[Location](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[Location](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"city":"San Francisco","temperature":64,"units":"f"}',
Expand All @@ -332,11 +332,11 @@ class Location(BaseModel):
tool_calls=None
)
),
ParsedChoice[Location](
ParsedChoice(
finish_reason='stop',
index=1,
logprobs=None,
message=ParsedChatCompletionMessage[Location](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"city":"San Francisco","temperature":65,"units":"f"}',
Expand All @@ -347,11 +347,11 @@ class Location(BaseModel):
tool_calls=None
)
),
ParsedChoice[Location](
ParsedChoice(
finish_reason='stop',
index=2,
logprobs=None,
message=ParsedChatCompletionMessage[Location](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"city":"San Francisco","temperature":63.0,"units":"f"}',
Expand Down Expand Up @@ -397,13 +397,13 @@ class CalendarEvent:

assert print_obj(completion, monkeypatch) == snapshot(
"""\
ParsedChatCompletion[CalendarEvent](
ParsedChatCompletion(
choices=[
ParsedChoice[CalendarEvent](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[CalendarEvent](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"name":"Science Fair","date":"Friday","participants":["Alice","Bob"]}',
Expand Down Expand Up @@ -462,11 +462,11 @@ def test_pydantic_tool_model_all_types(client: OpenAI, respx_mock: MockRouter, m

assert print_obj(completion.choices[0], monkeypatch) == snapshot(
"""\
ParsedChoice[Query](
ParsedChoice(
finish_reason='tool_calls',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[Query](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content=None,
Expand Down Expand Up @@ -576,11 +576,11 @@ class Location(BaseModel):
assert print_obj(completion.choices, monkeypatch) == snapshot(
"""\
[
ParsedChoice[Location](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[Location](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content=None,
Expand Down Expand Up @@ -627,11 +627,11 @@ class GetWeatherArgs(BaseModel):
assert print_obj(completion.choices, monkeypatch) == snapshot(
"""\
[
ParsedChoice[NoneType](
ParsedChoice(
finish_reason='tool_calls',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[NoneType](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content=None,
Expand Down Expand Up @@ -701,11 +701,11 @@ class GetStockPrice(BaseModel):
assert print_obj(completion.choices, monkeypatch) == snapshot(
"""\
[
ParsedChoice[NoneType](
ParsedChoice(
finish_reason='tool_calls',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[NoneType](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content=None,
Expand Down Expand Up @@ -784,11 +784,11 @@ def test_parse_strict_tools(client: OpenAI, respx_mock: MockRouter, monkeypatch:
assert print_obj(completion.choices, monkeypatch) == snapshot(
"""\
[
ParsedChoice[NoneType](
ParsedChoice(
finish_reason='tool_calls',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[NoneType](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content=None,
Expand Down Expand Up @@ -866,13 +866,13 @@ class Location(BaseModel):
assert isinstance(message.parsed.city, str)
assert print_obj(completion, monkeypatch) == snapshot(
"""\
ParsedChatCompletion[Location](
ParsedChatCompletion(
choices=[
ParsedChoice[Location](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[Location](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"city":"San Francisco","temperature":58,"units":"f"}',
Expand Down Expand Up @@ -943,13 +943,13 @@ class Location(BaseModel):
assert isinstance(message.parsed.city, str)
assert print_obj(completion, monkeypatch) == snapshot(
"""\
ParsedChatCompletion[Location](
ParsedChatCompletion(
choices=[
ParsedChoice[Location](
ParsedChoice(
finish_reason='stop',
index=0,
logprobs=None,
message=ParsedChatCompletionMessage[Location](
message=ParsedChatCompletionMessage(
annotations=None,
audio=None,
content='{"city":"San Francisco","temperature":65,"units":"f"}',
Expand Down
Loading
Loading