diff --git a/doc/changelog.rst b/doc/changelog.rst index a3981e6..099323e 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -4,9 +4,19 @@ Changelog [0.x.x] - Unreleased -------------------- -Changed +Added ^^^^^ -- Introduce a Path object to handle paths. :issue:`111` +- Resources define their schema URN with a ``__schema__`` classvar instead of a ``schemas`` default value. :issue:`110` +- Validation that the base schema is present in ``schemas`` during SCIM context validation. +- Validation that extension schemas are known during SCIM context validation. + +Changed +^^^^^^^ +- Introduce a :class:`~scim2_models.Path` object to handle paths. :issue:`111` + +Deprecated +^^^^^^^^^^ +- Defining ``schemas`` with a default value is deprecated. Use ``__schema__ = URN("...")`` instead. [0.5.2] - 2026-01-22 -------------------- diff --git a/doc/tutorial.rst b/doc/tutorial.rst index 4c883e7..aa32eb6 100644 --- a/doc/tutorial.rst +++ b/doc/tutorial.rst @@ -305,8 +305,8 @@ Use :class:`~scim2_models.ComplexAttribute` as base class for complex attributes .. code-block:: python - >>> from typing import Annotated, Optional, List - >>> from scim2_models import Resource, Returned, Mutability, ComplexAttribute + >>> from typing import Annotated, Optional + >>> from scim2_models import Resource, Returned, Mutability, ComplexAttribute, URN >>> from enum import Enum >>> class PetType(ComplexAttribute): @@ -317,7 +317,7 @@ Use :class:`~scim2_models.ComplexAttribute` as base class for complex attributes ... """The pet color.""" >>> class Pet(Resource): - ... schemas: List[str] = ["example:schemas:Pet"] + ... __schema__ = URN("urn:example:schemas:Pet") ... ... name: Annotated[Optional[str], Mutability.immutable, Returned.always] ... """The name of the pet.""" @@ -351,10 +351,12 @@ This is useful for server implementations, so custom models or models provided b .. code-block:: python + >>> from scim2_models import Resource, URN + >>> class MyCustomResource(Resource): ... """My awesome custom schema.""" ... - ... schemas: List[str] = ["example:schemas:MyCustomResource"] + ... __schema__ = URN("urn:example:schemas:MyCustomResource") ... ... foobar: Optional[str] ... @@ -362,7 +364,7 @@ This is useful for server implementations, so custom models or models provided b >>> dump = schema.model_dump() >>> assert dump == { ... "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"], - ... "id": "example:schemas:MyCustomResource", + ... "id": "urn:example:schemas:MyCustomResource", ... "name": "MyCustomResource", ... "description": "My awesome custom schema.", ... "attributes": [ diff --git a/scim2_models/base.py b/scim2_models/base.py index 479ebe7..621fd80 100644 --- a/scim2_models/base.py +++ b/scim2_models/base.py @@ -406,7 +406,9 @@ def _set_complex_attribute_urns(self) -> None: main_schema = self._attribute_urn separator = "." else: - main_schema = self.__class__.model_fields["schemas"].default[0] + main_schema = getattr(self.__class__, "__schema__", None) + if main_schema is None: + return separator = ":" for field_name in self.__class__.model_fields: @@ -540,13 +542,12 @@ def get_attribute_urn(self, field_name: str) -> str: """ from scim2_models.resources.resource import Extension - main_schema = self.__class__.model_fields["schemas"].default[0] + main_schema = getattr(self.__class__, "__schema__", None) field = self.__class__.model_fields[field_name] alias = field.serialization_alias or field_name field_type = self.get_field_root_type(field_name) - full_urn = ( - alias - if isclass(field_type) and issubclass(field_type, Extension) - else f"{main_schema}:{alias}" - ) - return full_urn + if isclass(field_type) and issubclass(field_type, Extension): + return alias + if main_schema is None: + return alias + return f"{main_schema}:{alias}" diff --git a/scim2_models/messages/bulk.py b/scim2_models/messages/bulk.py index cf68a86..7c2f19f 100644 --- a/scim2_models/messages/bulk.py +++ b/scim2_models/messages/bulk.py @@ -5,8 +5,8 @@ from pydantic import Field from pydantic import PlainSerializer -from ..annotations import Required from ..attributes import ComplexAttribute +from ..path import URN from ..utils import _int_to_str from .message import Message @@ -53,9 +53,7 @@ class BulkRequest(Message): The models for Bulk operations are defined, but their behavior is not implemented nor tested yet. """ - schemas: Annotated[list[str], Required.true] = [ - "urn:ietf:params:scim:api:messages:2.0:BulkRequest" - ] + __schema__ = URN("urn:ietf:params:scim:api:messages:2.0:BulkRequest") fail_on_errors: int | None = None """An integer specifying the number of errors that the service provider @@ -76,9 +74,7 @@ class BulkResponse(Message): The models for Bulk operations are defined, but their behavior is not implemented nor tested yet. """ - schemas: Annotated[list[str], Required.true] = [ - "urn:ietf:params:scim:api:messages:2.0:BulkResponse" - ] + __schema__ = URN("urn:ietf:params:scim:api:messages:2.0:BulkResponse") operations: list[BulkOperation] | None = Field( None, serialization_alias="Operations" diff --git a/scim2_models/messages/error.py b/scim2_models/messages/error.py index 7ed465e..5ad13dd 100644 --- a/scim2_models/messages/error.py +++ b/scim2_models/messages/error.py @@ -2,7 +2,7 @@ from pydantic import PlainSerializer -from ..annotations import Required +from ..path import URN from ..utils import _int_to_str from .message import Message @@ -10,9 +10,7 @@ class Error(Message): """Representation of SCIM API errors.""" - schemas: Annotated[list[str], Required.true] = [ - "urn:ietf:params:scim:api:messages:2.0:Error" - ] + __schema__ = URN("urn:ietf:params:scim:api:messages:2.0:Error") status: Annotated[int | None, PlainSerializer(_int_to_str)] = None """The HTTP status code (see Section 6 of [RFC7231]) expressed as a JSON diff --git a/scim2_models/messages/list_response.py b/scim2_models/messages/list_response.py index 402773e..41ca070 100644 --- a/scim2_models/messages/list_response.py +++ b/scim2_models/messages/list_response.py @@ -1,4 +1,3 @@ -from typing import Annotated from typing import Any from typing import Generic @@ -9,17 +8,15 @@ from pydantic_core import PydanticCustomError from typing_extensions import Self -from ..annotations import Required from ..context import Context +from ..path import URN from ..resources.resource import AnyResource from .message import Message from .message import _GenericMessageMetaclass class ListResponse(Message, Generic[AnyResource], metaclass=_GenericMessageMetaclass): - schemas: Annotated[list[str], Required.true] = [ - "urn:ietf:params:scim:api:messages:2.0:ListResponse" - ] + __schema__ = URN("urn:ietf:params:scim:api:messages:2.0:ListResponse") total_results: int | None = None """The total number of results returned by the list or query operation.""" diff --git a/scim2_models/messages/message.py b/scim2_models/messages/message.py index f22f274..fee05ae 100644 --- a/scim2_models/messages/message.py +++ b/scim2_models/messages/message.py @@ -7,11 +7,11 @@ from pydantic import Discriminator from pydantic import Tag -from pydantic._internal._model_construction import ModelMetaclass from scim2_models.resources.resource import Resource from ..base import BaseModel +from ..scim_object import ScimMetaclass from ..scim_object import ScimObject from ..utils import UNION_TYPES @@ -56,7 +56,7 @@ def _get_tag(resource_type: type[BaseModel]) -> Tag: :param resource_type: SCIM resource type :return: Pydantic Tag for discrimination """ - return Tag(resource_type.model_fields["schemas"].default[0]) + return Tag(getattr(resource_type, "__schema__", None) or "") def _create_tagged_resource_union(resource_union: Any) -> Any: @@ -75,7 +75,7 @@ def _create_tagged_resource_union(resource_union: Any) -> Any: # Set up schemas for the discriminator function resource_types_schemas = [ - resource_type.model_fields["schemas"].default[0] + getattr(resource_type, "__schema__", None) or "" for resource_type in resource_types ] @@ -92,7 +92,7 @@ def _create_tagged_resource_union(resource_union: Any) -> Any: return Annotated[union, discriminator] -class _GenericMessageMetaclass(ModelMetaclass): +class _GenericMessageMetaclass(ScimMetaclass): """Metaclass for SCIM generic types with discriminated unions.""" def __new__( diff --git a/scim2_models/messages/patch_op.py b/scim2_models/messages/patch_op.py index 8c7aea0..fe67c0f 100644 --- a/scim2_models/messages/patch_op.py +++ b/scim2_models/messages/patch_op.py @@ -15,6 +15,7 @@ from ..annotations import Required from ..attributes import ComplexAttribute from ..context import Context +from ..path import URN from ..path import InvalidPathError from ..path import Path from ..path import PathNotFoundError @@ -199,9 +200,7 @@ def __class_getitem__( return super().__class_getitem__(typevar_values) - schemas: Annotated[list[str], Required.true] = [ - "urn:ietf:params:scim:api:messages:2.0:PatchOp" - ] + __schema__ = URN("urn:ietf:params:scim:api:messages:2.0:PatchOp") operations: Annotated[list[PatchOperation[ResourceT]] | None, Required.true] = ( Field(None, serialization_alias="Operations", min_length=1) diff --git a/scim2_models/messages/search_request.py b/scim2_models/messages/search_request.py index 1e5707c..5a717d0 100644 --- a/scim2_models/messages/search_request.py +++ b/scim2_models/messages/search_request.py @@ -1,11 +1,10 @@ from enum import Enum -from typing import Annotated from typing import Any from pydantic import field_validator from pydantic import model_validator -from ..annotations import Required +from ..path import URN from ..path import Path from .message import Message @@ -16,9 +15,7 @@ class SearchRequest(Message): https://datatracker.ietf.org/doc/html/rfc7644#section-3.4.3 """ - schemas: Annotated[list[str], Required.true] = [ - "urn:ietf:params:scim:api:messages:2.0:SearchRequest" - ] + __schema__ = URN("urn:ietf:params:scim:api:messages:2.0:SearchRequest") attributes: list[Path[Any]] | None = None """A multi-valued list of strings indicating the names of resource diff --git a/scim2_models/path.py b/scim2_models/path.py index 800f409..6198ae5 100644 --- a/scim2_models/path.py +++ b/scim2_models/path.py @@ -68,7 +68,13 @@ class _Resolution(NamedTuple): is_explicit_schema_path: bool = False -class URN(UserString): +class URN(str): + """URN string type with validation.""" + + def __new__(cls, urn: str) -> "URN": + cls.check_syntax(urn) + return super().__new__(cls, urn) + @classmethod def __get_pydantic_core_schema__( cls, @@ -83,10 +89,6 @@ def __get_pydantic_core_schema__( ), ) - def __init__(self, urn: str): - self.check_syntax(urn) - self.data = urn - @classmethod def check_syntax(cls, path: str) -> None: """Validate URN-based path format. @@ -330,7 +332,7 @@ def urn(self) -> str | None: schema = self.schema if not schema and issubclass(self.__scim_model__, Resource): - schema = self.__scim_model__.model_fields["schemas"].default[0] + schema = self.__scim_model__.__schema__ if not self.attr: return schema if schema else None @@ -348,13 +350,12 @@ def _resolve_model(self) -> tuple[type[BaseModel], str | None] | None: attr_path = self.attr if ":" in self and isclass(model) and issubclass(model, Resource | Extension): - model_schema = model.model_fields["schemas"].default[0] path_lower = str(self).lower() - if path_lower == model_schema.lower(): + if model.__schema__ and path_lower == model.__schema__.lower(): return model, None - elif path_lower.startswith(model_schema.lower()): - attr_path = str(self)[len(model_schema) :].lstrip(":") + elif model.__schema__ and path_lower.startswith(model.__schema__.lower()): + attr_path = str(self)[len(model.__schema__) :].lstrip(":") elif issubclass(model, Resource): for ( extension_schema, @@ -414,7 +415,7 @@ def _resolve_instance( if ":" not in path_str: return _Resolution(resource, path_str) - model_schema = type(resource).model_fields["schemas"].default[0] + model_schema = getattr(type(resource), "__schema__", "") or "" path_lower = path_str.lower() if isinstance(resource, Resource | Extension) and path_lower.startswith( @@ -708,10 +709,11 @@ def iter_model_paths( field_type = target_model.get_field_root_type(field_name) + urn: str if isclass(field_type) and issubclass(field_type, Extension): if not include_extensions: continue - urn = field_type.model_fields["schemas"].default[0] + urn = field_type.__schema__ or "" elif isclass(target_model) and issubclass(target_model, Extension): urn = target_model().get_attribute_urn(field_name) else: diff --git a/scim2_models/resources/enterprise_user.py b/scim2_models/resources/enterprise_user.py index 9b47afa..acca4fd 100644 --- a/scim2_models/resources/enterprise_user.py +++ b/scim2_models/resources/enterprise_user.py @@ -6,6 +6,7 @@ from ..annotations import Mutability from ..annotations import Required from ..attributes import ComplexAttribute +from ..path import URN from ..reference import Reference from .resource import Extension @@ -25,9 +26,7 @@ class Manager(ComplexAttribute): class EnterpriseUser(Extension): - schemas: Annotated[list[str], Required.true] = [ - "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User" - ] + __schema__ = URN("urn:ietf:params:scim:schemas:extension:enterprise:2.0:User") employee_number: str | None = None """Numeric or alphanumeric identifier assigned to a person, typically based diff --git a/scim2_models/resources/group.py b/scim2_models/resources/group.py index aef5cc9..116f4be 100644 --- a/scim2_models/resources/group.py +++ b/scim2_models/resources/group.py @@ -6,8 +6,8 @@ from pydantic import Field from ..annotations import Mutability -from ..annotations import Required from ..attributes import ComplexAttribute +from ..path import URN from ..reference import Reference from .resource import Resource @@ -32,9 +32,7 @@ class GroupMember(ComplexAttribute): class Group(Resource[Any]): - schemas: Annotated[list[str], Required.true] = [ - "urn:ietf:params:scim:schemas:core:2.0:Group" - ] + __schema__ = URN("urn:ietf:params:scim:schemas:core:2.0:Group") display_name: str | None = None """A human-readable name for the Group.""" diff --git a/scim2_models/resources/resource.py b/scim2_models/resources/resource.py index 7eb8ccb..e2ec152 100644 --- a/scim2_models/resources/resource.py +++ b/scim2_models/resources/resource.py @@ -11,8 +11,13 @@ from pydantic import Field from pydantic import SerializationInfo from pydantic import SerializerFunctionWrapHandler +from pydantic import ValidationInfo +from pydantic import ValidatorFunctionWrapHandler from pydantic import WrapSerializer from pydantic import field_serializer +from pydantic import model_validator +from pydantic_core import PydanticCustomError +from typing_extensions import Self from ..annotations import CaseExact from ..annotations import Mutability @@ -178,7 +183,7 @@ def __class_getitem__(cls, item: Any) -> type["Resource[Any]"]: class_attrs = {"__scim_extension_metadata__": valid_extensions} for extension in valid_extensions: - schema = extension.model_fields["schemas"].default[0] + schema = extension.__schema__ class_attrs[extension.__name__] = Field( default=None, # type: ignore[arg-type] serialization_alias=schema, @@ -220,7 +225,7 @@ def __getitem__(self, item: Any) -> Any: user["name.familyName"] # Get nested attribute """ if isinstance(item, type) and issubclass(item, Extension): - item = item.model_fields["schemas"].default[0] + item = item.__schema__ bound_path = Path.__class_getitem__(type(self)) path = item if isinstance(item, Path) else bound_path(str(item)) @@ -243,7 +248,7 @@ def __setitem__(self, item: Any, value: Any) -> None: user["name.familyName"] = "Doe" """ if isinstance(item, type) and issubclass(item, Extension): - item = item.model_fields["schemas"].default[0] + item = item.__schema__ bound_path = Path.__class_getitem__(type(self)) path = item if isinstance(item, Path) else bound_path(str(item)) @@ -264,7 +269,7 @@ def __delitem__(self, item: Any) -> None: del user["displayName"] # Remove attribute """ if isinstance(item, type) and issubclass(item, Extension): - item = item.model_fields["schemas"].default[0] + item = item.__schema__ bound_path = Path.__class_getitem__(type(self)) path = item if isinstance(item, Path) else bound_path(str(item)) @@ -277,8 +282,8 @@ def __delitem__(self, item: Any) -> None: def get_extension_models(cls) -> dict[str, type[Extension]]: """Return extension a dict associating extension models with their schemas.""" extension_models = getattr(cls, "__scim_extension_metadata__", []) - by_schema = { - ext.model_fields["schemas"].default[0]: ext for ext in extension_models + by_schema: dict[str, type[Extension]] = { + ext.__schema__: ext for ext in extension_models } return by_schema @@ -300,7 +305,7 @@ def get_by_schema( ) -> type["Resource[Any]"] | type["Extension"] | None: """Given a resource type list and a schema, find the matching resource type.""" by_schema: dict[str, type[Resource[Any]] | type[Extension]] = { - resource_type.model_fields["schemas"].default[0].lower(): resource_type + getattr(resource_type, "__schema__", "").lower(): resource_type for resource_type in (resource_types or []) } if with_extensions: @@ -338,6 +343,35 @@ def set_extension_schemas( ] return schemas + @model_validator(mode="wrap") + @classmethod + def _validate_extension_schemas( + cls, value: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo + ) -> Self: + """Validate that extension schemas are known.""" + obj: Self = handler(value) + + scim_ctx = info.context.get("scim") if info.context else None + if scim_ctx is None or scim_ctx == Context.DEFAULT: + return obj + + base_schema = getattr(cls, "__schema__", None) + if not base_schema: + return obj + + allowed_extensions = set(cls.get_extension_models().keys()) + provided_schemas = set(obj.schemas) - {base_schema} + + unknown = provided_schemas - allowed_extensions + if unknown: + raise PydanticCustomError( + "unknown_extension_schema", + "Unknown extension schemas: {schemas}", + {"schemas": ", ".join(sorted(unknown))}, + ) + + return obj + @classmethod def to_schema(cls) -> "Schema": """Build a :class:`~scim2_models.Schema` from the current resource class.""" @@ -458,7 +492,7 @@ def compare_field_infos(fi1: Any, fi2: Any) -> bool: def _model_to_schema(model: type[BaseModel]) -> "Schema": from scim2_models.resources.schema import Schema - schema_urn = model.model_fields["schemas"].default[0] + schema_urn = getattr(model, "__schema__", "") or "" field_infos = _dedicated_attributes(model, [Resource]) attributes = [ _model_attribute_to_scim_attribute(model, attribute_name) diff --git a/scim2_models/resources/resource_type.py b/scim2_models/resources/resource_type.py index 0b6f521..ce3b785 100644 --- a/scim2_models/resources/resource_type.py +++ b/scim2_models/resources/resource_type.py @@ -9,6 +9,7 @@ from ..annotations import Required from ..annotations import Returned from ..attributes import ComplexAttribute +from ..path import URN from ..reference import Reference from ..reference import URIReference from .resource import Resource @@ -35,9 +36,7 @@ class SchemaExtension(ComplexAttribute): class ResourceType(Resource[Any]): - schemas: Annotated[list[str], Required.true] = [ - "urn:ietf:params:scim:schemas:core:2.0:ResourceType" - ] + __schema__ = URN("urn:ietf:params:scim:schemas:core:2.0:ResourceType") name: Annotated[str | None, Mutability.read_only, Required.true] = None """The resource type name. @@ -80,10 +79,11 @@ class ResourceType(Resource[Any]): @classmethod def from_resource(cls, resource_model: type[Resource[Any]]) -> Self: """Build a naive ResourceType from a resource model.""" - schema = resource_model.model_fields["schemas"].default[0] + schema = resource_model.__schema__ + if schema is None: + raise ValueError(f"{resource_model.__name__} has no __schema__ defined") name = schema.split(":")[-1] - # Get extensions from the metadata system extensions = getattr(resource_model, "__scim_extension_metadata__", []) return cls( @@ -91,10 +91,11 @@ def from_resource(cls, resource_model: type[Resource[Any]]) -> Self: name=name, description=name, endpoint=Reference[URIReference](f"/{name}s"), - schema_=schema, + schema_=Reference[URIReference](schema), schema_extensions=[ SchemaExtension( - schema_=extension.model_fields["schemas"].default[0], required=False + schema_=Reference[URIReference](extension.__schema__), + required=False, ) for extension in extensions ], diff --git a/scim2_models/resources/schema.py b/scim2_models/resources/schema.py index c6f2ba1..2a7522a 100644 --- a/scim2_models/resources/schema.py +++ b/scim2_models/resources/schema.py @@ -27,6 +27,7 @@ from ..attributes import is_complex_attribute from ..base import BaseModel from ..constants import RESERVED_WORDS +from ..path import URN from ..reference import ExternalReference from ..reference import Reference from ..reference import URIReference @@ -65,10 +66,6 @@ def _make_python_model( for attr in (obj.attributes or []) if attr.name } - pydantic_attributes["schemas"] = ( - Annotated[list[str], Required.true], - Field(default=[obj.id]), - ) if not obj.name: raise ValueError("Schema or Attribute 'name' must be defined") @@ -76,8 +73,9 @@ def _make_python_model( model_name = to_pascal(to_snake(obj.name)) model: type[T] = create_model(model_name, __base__=base, **pydantic_attributes) # type: ignore[call-overload] - # Set the ComplexType class as a member of the model - # e.g. make Member an attribute of Group + if isinstance(obj, Schema) and obj.id: + model.__schema__ = URN(obj.id) # type: ignore[attr-defined] + for attr_name in model.model_fields: attr_type = model.get_field_root_type(attr_name) if attr_type and is_complex_attribute(attr_type): @@ -254,9 +252,7 @@ def __getitem__(self, name: str) -> "Attribute": class Schema(Resource[Any]): - schemas: Annotated[list[str], Required.true] = [ - "urn:ietf:params:scim:schemas:core:2.0:Schema" - ] + __schema__ = URN("urn:ietf:params:scim:schemas:core:2.0:Schema") id: Annotated[str | None, Mutability.read_only, Required.true] = None """The unique URI of the schema.""" diff --git a/scim2_models/resources/service_provider_config.py b/scim2_models/resources/service_provider_config.py index 59f1952..54d7677 100644 --- a/scim2_models/resources/service_provider_config.py +++ b/scim2_models/resources/service_provider_config.py @@ -9,6 +9,7 @@ from ..annotations import Returned from ..annotations import Uniqueness from ..attributes import ComplexAttribute +from ..path import URN from ..reference import ExternalReference from ..reference import Reference from .resource import Resource @@ -92,9 +93,7 @@ class Type(str, Enum): class ServiceProviderConfig(Resource[Any]): - schemas: Annotated[list[str], Required.true] = [ - "urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig" - ] + __schema__ = URN("urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig") id: Annotated[ str | None, Mutability.read_only, Returned.default, Uniqueness.global_ diff --git a/scim2_models/resources/user.py b/scim2_models/resources/user.py index 0926213..2a08a23 100644 --- a/scim2_models/resources/user.py +++ b/scim2_models/resources/user.py @@ -13,6 +13,7 @@ from ..annotations import Returned from ..annotations import Uniqueness from ..attributes import ComplexAttribute +from ..path import URN from ..reference import ExternalReference from ..reference import Reference from .resource import AnyExtension @@ -245,9 +246,7 @@ class X509Certificate(ComplexAttribute): class User(Resource[AnyExtension]): - schemas: Annotated[list[str], Required.true] = [ - "urn:ietf:params:scim:schemas:core:2.0:User" - ] + __schema__ = URN("urn:ietf:params:scim:schemas:core:2.0:User") user_name: Annotated[str | None, Uniqueness.server, Required.true] = None """Unique identifier for the User, typically used by the user to directly diff --git a/scim2_models/scim_object.py b/scim2_models/scim_object.py index 3453e58..abe014f 100644 --- a/scim2_models/scim_object.py +++ b/scim2_models/scim_object.py @@ -1,24 +1,107 @@ """Base SCIM object classes with schema identification.""" +import warnings from typing import TYPE_CHECKING from typing import Annotated from typing import Any +from typing import ClassVar + +from pydantic import ValidationInfo +from pydantic import ValidatorFunctionWrapHandler +from pydantic import model_validator +from pydantic._internal._model_construction import ModelMetaclass +from pydantic_core import PydanticCustomError +from typing_extensions import Self from .annotations import Required from .base import BaseModel from .context import Context +from .path import URN if TYPE_CHECKING: pass -class ScimObject(BaseModel): +class ScimMetaclass(ModelMetaclass): + """Metaclass for SCIM objects that handles __schema__ backward compatibility.""" + + def __new__( + mcs, + name: str, + bases: tuple[type, ...], + namespace: dict[str, Any], + **kwargs: Any, + ) -> type: + cls = super().__new__(mcs, name, bases, namespace, **kwargs) + + if name in ("ScimObject", "Resource", "Extension"): + return cls + + if getattr(cls, "__schema__", None) is None: + schemas_field = cls.model_fields.get("schemas") # type: ignore[attr-defined] + if ( + schemas_field + and schemas_field.default + and isinstance(schemas_field.default, list) + and schemas_field.default + ): + schema_value = schemas_field.default[0] + try: + cls.__schema__ = URN(schema_value) # type: ignore[attr-defined] + warnings.warn( + f"{name}: Defining schemas with a default value is deprecated " + f"and will be removed in version 0.7. " + f'Use __schema__ = URN("{schema_value}") instead.', + DeprecationWarning, + stacklevel=2, + ) + except ValueError: + pass + + return cls + + +class ScimObject(BaseModel, metaclass=ScimMetaclass): + __schema__: ClassVar[URN | None] = None + schemas: Annotated[list[str], Required.true] """The "schemas" attribute is a REQUIRED attribute and is an array of Strings containing URIs that are used to indicate the namespaces of the SCIM schemas that define the attributes present in the current JSON structure.""" + @model_validator(mode="before") + @classmethod + def _populate_schemas_default(cls, data: Any) -> Any: + """Auto-generate schemas from __schema__ if not provided.""" + if isinstance(data, dict) and "schemas" not in data: + schema = getattr(cls, "__schema__", None) + if schema: + data = {**data, "schemas": [schema]} + return data + + @model_validator(mode="wrap") + @classmethod + def _validate_schemas_attribute( + cls, value: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo + ) -> Self: + """Validate that the base schema is present in schemas attribute.""" + obj: Self = handler(value) + + scim_ctx = info.context.get("scim") if info.context else None + if scim_ctx is None or scim_ctx == Context.DEFAULT: + return obj + + schema = getattr(cls, "__schema__", None) + if schema and schema not in obj.schemas: + raise PydanticCustomError( + "schema_error", + "schemas must contain the base schema '{schema}'", + {"schema": schema}, + ) + + return obj + def _prepare_model_dump( self, scim_ctx: Context | None = Context.DEFAULT, diff --git a/tests/test_dynamic_resources.py b/tests/test_dynamic_resources.py index 2c997b6..c43ec9a 100644 --- a/tests/test_dynamic_resources.py +++ b/tests/test_dynamic_resources.py @@ -23,9 +23,7 @@ def test_make_group_model_from_schema(load_sample): schema = Schema.model_validate(payload) Group = Resource.from_schema(schema) - assert Group.model_fields["schemas"].default == [ - "urn:ietf:params:scim:schemas:core:2.0:Group" - ] + assert Group.__schema__ == "urn:ietf:params:scim:schemas:core:2.0:Group" # displayName assert Group.get_field_root_type("display_name") is str @@ -151,9 +149,7 @@ def test_make_user_model_from_schema(load_sample): schema = Schema.model_validate(payload) User = Resource.from_schema(schema) - assert User.model_fields["schemas"].default == [ - "urn:ietf:params:scim:schemas:core:2.0:User" - ] + assert User.__schema__ == "urn:ietf:params:scim:schemas:core:2.0:User" # user_name assert User.get_field_root_type("user_name") is str @@ -1243,9 +1239,10 @@ def test_make_enterprise_user_model_from_schema(load_sample): schema = Schema.model_validate(payload) EnterpriseUser = Extension.from_schema(schema) - assert EnterpriseUser.model_fields["schemas"].default == [ - "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User" - ] + assert ( + EnterpriseUser.__schema__ + == "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User" + ) # employee_number assert EnterpriseUser.get_field_root_type("employee_number") is str @@ -1432,9 +1429,9 @@ def test_make_resource_type_model_from_schema(load_sample): schema = Schema.model_validate(payload) ResourceType = Resource.from_schema(schema) - assert ResourceType.model_fields["schemas"].default == [ - "urn:ietf:params:scim:schemas:core:2.0:ResourceType" - ] + assert ( + ResourceType.__schema__ == "urn:ietf:params:scim:schemas:core:2.0:ResourceType" + ) # id assert ResourceType.get_field_root_type("id") is str @@ -1625,9 +1622,10 @@ def test_make_service_provider_config_model_from_schema(load_sample): schema = Schema.model_validate(payload) ServiceProviderConfig = Resource.from_schema(schema) - assert ServiceProviderConfig.model_fields["schemas"].default == [ - "urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig" - ] + assert ( + ServiceProviderConfig.__schema__ + == "urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig" + ) # documentation_uri assert ( @@ -2161,9 +2159,7 @@ def test_make_schema_model_from_schema(load_sample): schema = Schema.model_validate(payload) Schema_ = Resource.from_schema(schema) - assert Schema_.model_fields["schemas"].default == [ - "urn:ietf:params:scim:schemas:core:2.0:Schema" - ] + assert Schema_.__schema__ == "urn:ietf:params:scim:schemas:core:2.0:Schema" # id assert Schema_.get_field_root_type("id") is str diff --git a/tests/test_schema_validation.py b/tests/test_schema_validation.py new file mode 100644 index 0000000..0fca0ec --- /dev/null +++ b/tests/test_schema_validation.py @@ -0,0 +1,189 @@ +"""Tests for schema validation using __schema__ classvar.""" + +import pytest +from pydantic import ValidationError + +from scim2_models.base import BaseModel +from scim2_models.context import Context +from scim2_models.path import URN +from scim2_models.resources.enterprise_user import EnterpriseUser +from scim2_models.resources.resource import Extension +from scim2_models.resources.resource import Resource +from scim2_models.resources.user import User + + +def test_validation_missing_base_schema(): + """Validation fails when base schema is missing from schemas.""" + with pytest.raises(ValidationError, match="schemas must contain"): + User.model_validate( + {"schemas": ["wrong:schema"], "userName": "foo"}, + context={"scim": Context.RESOURCE_CREATION_REQUEST}, + ) + + +def test_validation_unknown_extension_schema(): + """Validation fails when unknown extension schema is provided.""" + with pytest.raises(ValidationError, match="Unknown extension"): + User.model_validate( + { + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User", + "urn:unknown:extension", + ], + "userName": "foo", + }, + context={"scim": Context.RESOURCE_CREATION_REQUEST}, + ) + + +def test_validation_valid_extension_schema(): + """Validation succeeds with valid extension schema.""" + user = User[EnterpriseUser].model_validate( + { + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User", + "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User", + ], + "userName": "foo", + }, + context={"scim": Context.RESOURCE_CREATION_REQUEST}, + ) + assert len(user.schemas) == 2 + + +def test_schemas_auto_populated(): + """Schemas is auto-populated from __schema__ when not provided.""" + user = User(user_name="foo") + assert user.schemas == ["urn:ietf:params:scim:schemas:core:2.0:User"] + + +def test_deprecation_warning_old_style(): + """DeprecationWarning is raised for old-style schema definition.""" + with pytest.warns(DeprecationWarning, match="removed in version 0.7"): + + class OldStyleResource(Resource): + schemas: list[str] = ["urn:test:old"] + + +def test_no_validation_without_context(): + """No schema validation without SCIM context.""" + user = User.model_validate({"schemas": ["wrong:schema"], "userName": "foo"}) + assert user.schemas == ["wrong:schema"] + + +def test_no_validation_with_default_context(): + """No schema validation with DEFAULT context.""" + user = User.model_validate( + {"schemas": ["wrong:schema"], "userName": "foo"}, + context={"scim": Context.DEFAULT}, + ) + assert user.schemas == ["wrong:schema"] + + +def test_schema_classvar_defined(): + """Resources have __schema__ classvar with correct URN.""" + assert User.__schema__ == URN("urn:ietf:params:scim:schemas:core:2.0:User") + assert EnterpriseUser.__schema__ == URN( + "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User" + ) + + +def test_dynamic_class_inherits_schema(): + """Dynamically created resource classes inherit __schema__ from parent.""" + + class TestExtension(Extension): + __schema__ = URN("urn:test:extension") + + UserWithExt = User[TestExtension] + assert UserWithExt.__schema__ == User.__schema__ + + +def test_extension_schema_validation_multiple_valid(): + """Validation succeeds with multiple valid extension schemas.""" + + class Ext1(Extension): + __schema__ = URN("urn:test:ext1") + + class Ext2(Extension): + __schema__ = URN("urn:test:ext2") + + UserWithExts = User[Ext1 | Ext2] + user = UserWithExts.model_validate( + { + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User", + "urn:test:ext1", + "urn:test:ext2", + ], + "userName": "foo", + }, + context={"scim": Context.RESOURCE_CREATION_REQUEST}, + ) + assert len(user.schemas) == 3 + + +def test_extension_schema_validation_partial(): + """Validation succeeds when only some extension schemas are provided.""" + + class Ext1(Extension): + __schema__ = URN("urn:test:ext1") + + class Ext2(Extension): + __schema__ = URN("urn:test:ext2") + + UserWithExts = User[Ext1 | Ext2] + user = UserWithExts.model_validate( + { + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User", + "urn:test:ext1", + ], + "userName": "foo", + }, + context={"scim": Context.RESOURCE_CREATION_REQUEST}, + ) + assert len(user.schemas) == 2 + + +def test_extension_schema_validation_rejects_unknown_with_valid(): + """Validation fails when unknown schema is mixed with valid extensions.""" + + class TestExt(Extension): + __schema__ = URN("urn:test:ext") + + UserWithExt = User[TestExt] + with pytest.raises(ValidationError, match="Unknown extension"): + UserWithExt.model_validate( + { + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User", + "urn:test:ext", + "urn:unknown:bad", + ], + "userName": "foo", + }, + context={"scim": Context.RESOURCE_CREATION_REQUEST}, + ) + + +def test_get_attribute_urn_without_schema(): + """get_attribute_urn returns field name when model has no __schema__.""" + + class ModelWithoutSchema(BaseModel): + foo: str | None = None + + model = ModelWithoutSchema(foo="bar") + assert model.get_attribute_urn("foo") == "foo" + + +def test_from_resource_without_schema(): + """from_resource raises ValueError when resource has no __schema__.""" + from scim2_models.resources.resource_type import ResourceType + + class NoSchemaResource(Resource): + pass + + NoSchemaResource.__schema__ = None # type: ignore[assignment] + + with pytest.raises(ValueError, match="has no __schema__ defined"): + ResourceType.from_resource(NoSchemaResource)