Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 150 additions & 16 deletions autogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@
),
]

# Endpoints that return bool based on HTTP status: 404 -> False, 2xx -> True, else raise.
# Each entry is (METHOD, openapi_path_template).
BOOLEAN_STATUS_ROUTES: list[tuple[str, str]] = [
("GET", "/virtual_tag_configs/async/{request_id}"),
]


@dataclass
class Parameter:
Expand Down Expand Up @@ -61,6 +67,7 @@ class Endpoint:
is_multipart: bool = False
response_handler: str | None = None # internal client method to call, if not the default
response_handler_return_type: str | None = None
boolean_status: bool = False # 404->False, 2xx->True, else raise VantageAPIError


@dataclass
Expand Down Expand Up @@ -196,10 +203,16 @@ def preprocess_inline_models(schemas: dict[str, Any]) -> None:
existing_names.add(model_name)


def openapi_type_to_python(schema: dict[str, Any], schemas: dict[str, Any]) -> str:
def openapi_type_to_python(
schema: dict[str, Any],
schemas: dict[str, Any],
name_map: dict[str, str] | None = None,
) -> str:
"""Convert OpenAPI type to Python type hint."""
if "$ref" in schema:
ref_name = schema["$ref"].split("/")[-1]
if name_map and ref_name in name_map:
return name_map[ref_name]
return to_pascal_case(ref_name)

schema_type = schema.get("type", "any")
Expand All @@ -216,12 +229,12 @@ def openapi_type_to_python(schema: dict[str, Any], schemas: dict[str, Any]) -> s
return "bool"
elif schema_type == "array":
items = schema.get("items", {})
item_type = openapi_type_to_python(items, schemas)
item_type = openapi_type_to_python(items, schemas, name_map)
return f"List[{item_type}]"
elif schema_type == "object":
additional = schema.get("additionalProperties")
if additional:
value_type = openapi_type_to_python(additional, schemas)
value_type = openapi_type_to_python(additional, schemas, name_map)
return f"Dict[str, {value_type}]"
# Check if inline properties match an existing named schema
inline_props = schema.get("properties")
Expand All @@ -230,14 +243,18 @@ def openapi_type_to_python(schema: dict[str, Any], schemas: dict[str, Any]) -> s
for schema_name, schema_def in schemas.items():
defined_keys = sorted(schema_def.get("properties", {}).keys())
if defined_keys and inline_keys == defined_keys:
if name_map and schema_name in name_map:
return name_map[schema_name]
return to_pascal_case(schema_name)
return "Dict[str, Any]"
else:
return "Any"


def extract_request_body_type(
request_body: dict[str, Any] | None, schemas: dict[str, Any]
request_body: dict[str, Any] | None,
schemas: dict[str, Any],
name_map: dict[str, str] | None = None,
) -> tuple[str | None, bool]:
"""Extract request body type and whether it's multipart."""
if not request_body:
Expand All @@ -252,13 +269,15 @@ def extract_request_body_type(
# Check for JSON
if "application/json" in content:
schema = content["application/json"].get("schema", {})
return openapi_type_to_python(schema, schemas), False
return openapi_type_to_python(schema, schemas, name_map), False

return None, False


def extract_response_type(
responses: dict[str, Any], schemas: dict[str, Any]
responses: dict[str, Any],
schemas: dict[str, Any],
name_map: dict[str, str] | None = None,
) -> str | None:
"""Extract successful response type."""
for code in ["200", "201", "202", "203"]:
Expand All @@ -268,15 +287,59 @@ def extract_response_type(
content = response.get("content", {})
if "application/json" in content:
schema = content["application/json"].get("schema", {})
return openapi_type_to_python(schema, schemas)
return openapi_type_to_python(schema, schemas, name_map)
return None


def find_request_body_schemas(schema: dict[str, Any]) -> set[str]:
"""Return the set of schema names referenced as request bodies in any endpoint."""
result = set()
paths = schema.get("paths", {})
for path_item in paths.values():
for method, spec in path_item.items():
if method in ("parameters", "servers", "summary", "description"):
continue
request_body = spec.get("requestBody", {})
content = request_body.get("content", {})
for media_type in content.values():
ref_schema = media_type.get("schema", {})
if "$ref" in ref_schema:
name = ref_schema["$ref"].split("/")[-1]
result.add(name)
return result


def build_class_name_map(schemas: dict[str, Any], request_body_schemas: set[str]) -> dict[str, str]:
"""Build a mapping from raw schema names to Python class names, resolving conflicts.
If two schema names map to the same PascalCase name, the one used as a
request body is suffixed with 'Request'.
"""
initial = {name: to_pascal_case(name) for name in schemas}

by_class_name: dict[str, list[str]] = {}
for raw_name, class_name in initial.items():
by_class_name.setdefault(class_name, []).append(raw_name)

result: dict[str, str] = {}
for class_name, raw_names in by_class_name.items():
if len(raw_names) == 1:
result[raw_names[0]] = class_name
else:
for raw_name in raw_names:
if raw_name in request_body_schemas:
result[raw_name] = class_name + "Request"
else:
result[raw_name] = class_name
return result


def parse_endpoints(schema: dict[str, Any]) -> list[Endpoint]:
"""Parse all endpoints from OpenAPI schema."""
endpoints = []
paths = schema.get("paths", {})
schemas = schema.get("components", {}).get("schemas", {})
name_map = build_class_name_map(schemas, find_request_body_schemas(schema))

for path, methods in paths.items():
for method, spec in methods.items():
Expand All @@ -288,7 +351,7 @@ def parse_endpoints(schema: dict[str, Any]) -> list[Endpoint]:
parameters = []
for param in spec.get("parameters", []):
param_schema = param.get("schema", {})
param_type = openapi_type_to_python(param_schema, schemas)
param_type = openapi_type_to_python(param_schema, schemas, name_map)
parameters.append(
Parameter(
name=param["name"],
Expand All @@ -301,9 +364,9 @@ def parse_endpoints(schema: dict[str, Any]) -> list[Endpoint]:
)

request_body = spec.get("requestBody")
body_type, is_multipart = extract_request_body_type(request_body, schemas)
body_type, is_multipart = extract_request_body_type(request_body, schemas, name_map)

response_type = extract_response_type(spec.get("responses", {}), schemas)
response_type = extract_response_type(spec.get("responses", {}), schemas, name_map)

description = spec.get("description")

Expand All @@ -319,6 +382,10 @@ def parse_endpoints(schema: dict[str, Any]) -> list[Endpoint]:
if response_handler:
break

boolean_status = (method.upper(), path) in {
(m.upper(), p) for m, p in BOOLEAN_STATUS_ROUTES
}

endpoints.append(
Endpoint(
path=path,
Expand All @@ -336,6 +403,7 @@ def parse_endpoints(schema: dict[str, Any]) -> list[Endpoint]:
is_multipart=is_multipart,
response_handler=response_handler,
response_handler_return_type=response_handler_return_type,
boolean_status=boolean_status,
)
)

Expand Down Expand Up @@ -469,6 +537,7 @@ def _append_response_mapping(lines: list[str], return_type: str, data_var: str)
def generate_pydantic_models(schema: dict[str, Any]) -> str:
"""Generate Pydantic models from OpenAPI schemas."""
schemas = schema.get("components", {}).get("schemas", {})
name_map = build_class_name_map(schemas, find_request_body_schemas(schema))
lines = [
'"""Auto-generated Pydantic models from OpenAPI schema."""',
"",
Expand All @@ -482,7 +551,7 @@ def generate_pydantic_models(schema: dict[str, Any]) -> str:
]

for name, spec in schemas.items():
class_name = to_pascal_case(name)
class_name = name_map.get(name, to_pascal_case(name))
description = spec.get("description", "")

lines.append(f"class {class_name}(BaseModel):")
Expand All @@ -508,7 +577,7 @@ def generate_pydantic_models(schema: dict[str, Any]) -> str:
python_name = python_name + "_"
needs_alias = True

prop_type = openapi_type_to_python(prop_spec, schemas)
prop_type = openapi_type_to_python(prop_spec, schemas, name_map)

# Handle nullable
if prop_spec.get("x-nullable") or prop_spec.get("nullable"):
Expand Down Expand Up @@ -564,6 +633,21 @@ def _collect_handler_routes(resources: dict[str, Resource]) -> dict[str, list[tu
return handler_routes


def _collect_boolean_status_prefixes(resources: dict[str, Resource]) -> list[tuple[str, str]]:
"""Collect (method, path_prefix) pairs for boolean-status endpoints.
The prefix is derived by taking everything before the first path parameter
so it can be matched with str.startswith() at runtime.
"""
result = []
for resource in resources.values():
for endpoint in resource.endpoints:
if endpoint.boolean_status:
prefix = endpoint.path.split("{")[0]
result.append((endpoint.method, prefix))
return result


def generate_sync_client(resources: dict[str, Resource]) -> str:
"""Generate synchronous client code."""
lines = [
Expand Down Expand Up @@ -652,6 +736,29 @@ def generate_sync_client(resources: dict[str, Resource]) -> str:
" json=body,",
" )",
"",
]
)

# Inject boolean-status path checks (before the generic error check)
boolean_prefixes = _collect_boolean_status_prefixes(resources)
for method, prefix in boolean_prefixes:
lines.extend([
f' if method.upper() == "{method}" and path.startswith("{prefix}"):',
" if response.status_code == 404:",
" return False",
" elif response.is_success:",
" return True",
" else:",
" raise VantageAPIError(",
" status=response.status_code,",
" status_text=response.reason_phrase,",
" body=response.text,",
" )",
"",
])

lines.extend(
[
" if not response.is_success:",
" raise VantageAPIError(",
" status=response.status_code,",
Expand Down Expand Up @@ -757,7 +864,9 @@ def generate_sync_method(endpoint: Endpoint, method_name: str) -> list[str]:

# Method signature
param_str = ", ".join(["self"] + params) if params else "self"
if endpoint.response_handler:
if endpoint.boolean_status:
return_type = "bool"
elif endpoint.response_handler:
return_type = endpoint.response_handler_return_type or "Any"
else:
return_type = endpoint.response_type or "None"
Expand Down Expand Up @@ -801,7 +910,7 @@ def generate_sync_method(endpoint: Endpoint, method_name: str) -> list[str]:
lines.append(" body_data = None")

# Make request and coerce response payload into typed models where possible
if endpoint.response_handler:
if endpoint.boolean_status or endpoint.response_handler:
lines.append(
f' return self._client.request("{endpoint.method}", path, params=params, body=body_data)'
)
Expand Down Expand Up @@ -907,6 +1016,29 @@ def generate_async_client(resources: dict[str, Resource]) -> str:
" json=body,",
" )",
"",
]
)

# Inject boolean-status path checks (before the generic error check)
boolean_prefixes = _collect_boolean_status_prefixes(resources)
for method, prefix in boolean_prefixes:
lines.extend([
f' if method.upper() == "{method}" and path.startswith("{prefix}"):',
" if response.status_code == 404:",
" return False",
" elif response.is_success:",
" return True",
" else:",
" raise VantageAPIError(",
" status=response.status_code,",
" status_text=response.reason_phrase,",
" body=response.text,",
" )",
"",
])

lines.extend(
[
" if not response.is_success:",
" raise VantageAPIError(",
" status=response.status_code,",
Expand Down Expand Up @@ -1012,7 +1144,9 @@ def generate_async_method(endpoint: Endpoint, method_name: str) -> list[str]:

# Method signature
param_str = ", ".join(["self"] + params) if params else "self"
if endpoint.response_handler:
if endpoint.boolean_status:
return_type = "bool"
elif endpoint.response_handler:
return_type = endpoint.response_handler_return_type or "Any"
else:
return_type = endpoint.response_type or "None"
Expand Down Expand Up @@ -1056,7 +1190,7 @@ def generate_async_method(endpoint: Endpoint, method_name: str) -> list[str]:
lines.append(" body_data = None")

# Make request and coerce response payload into typed models where possible
if endpoint.response_handler:
if endpoint.boolean_status or endpoint.response_handler:
lines.append(
f' return await self._client.request("{endpoint.method}", path, params=params, body=body_data)'
)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "vantage-python"
version = "0.3.2"
version = "0.3.3"
description = "Python SDK for the Vantage API"
readme = "README.md"
license = "MIT"
Expand Down
Loading