class StudyV1API(ResourceV1API, StudyAPI):
    """Version 1 API implementation for study resources."""

    def list(  # noqa: PLR0913
        self,
        limit: int | None = None,
        offset: int | None = None,
        status: str | None = None,
        main_entity_type: str | None = None,
        uploader: builtins.list[int] | None = None,
        benchmark_suite: int | None = None,
    ) -> pd.DataFrame:
        """List studies using V1 API.

        Builds the ``study/list`` call from the given filters, performs a GET
        through ``self._http`` and parses the UTF-8 decoded XML body.

        Parameters
        ----------
        limit : int, optional
            Maximum number of studies to return.
        offset : int, optional
            Number of studies to skip.
        status : str, optional
            Filter by status (active, in_preparation, deactivated, all).
        main_entity_type : str, optional
            Filter by main entity type (run, task).
        uploader : list[int], optional
            Filter by uploader IDs. An empty list is treated as "no filter".
        benchmark_suite : int, optional
            Filter by benchmark suite ID.

        Returns
        -------
        pd.DataFrame
            DataFrame containing study information, indexed by study id.
        """
        api_call = self._build_url(
            limit=limit,
            offset=offset,
            status=status,
            main_entity_type=main_entity_type,
            uploader=uploader,
            benchmark_suite=benchmark_suite,
        )
        response = self._http.get(api_call)
        xml_string = response.content.decode("utf-8")
        return self._parse_list_xml(xml_string)

    @staticmethod
    def _build_url(  # noqa: PLR0913
        limit: int | None = None,
        offset: int | None = None,
        status: str | None = None,
        main_entity_type: str | None = None,
        uploader: builtins.list[int] | None = None,
        benchmark_suite: int | None = None,
    ) -> str:
        """Build the V1 API URL for listing studies.

        Each filter that is not ``None`` is appended as a ``/<name>/<value>``
        path segment, always in the order limit, offset, status,
        main_entity_type, uploader, benchmark_suite.

        Parameters
        ----------
        limit : int, optional
            Maximum number of studies to return.
        offset : int, optional
            Number of studies to skip.
        status : str, optional
            Filter by status (active, in_preparation, deactivated, all).
        main_entity_type : str, optional
            Filter by main entity type (run, task).
        uploader : list[int], optional
            Filter by uploader IDs; joined with commas in the URL. An empty
            list adds no segment.
        benchmark_suite : int, optional
            Filter by benchmark suite ID.

        Returns
        -------
        str
            The API call string with all filters applied.
        """
        # An empty uploader list would otherwise yield a dangling "/uploader/"
        # segment (and "//" when another filter follows), so it is skipped.
        uploader_filter = ",".join(str(u) for u in uploader) if uploader else None

        # Order matters: it is the segment order of the resulting URL.
        filters = (
            ("limit", limit),
            ("offset", offset),
            ("status", status),
            ("main_entity_type", main_entity_type),
            ("uploader", uploader_filter),
            ("benchmark_suite", benchmark_suite),
        )
        segments = "".join(f"/{name}/{value}" for name, value in filters if value is not None)
        return "study/list" + segments

    @staticmethod
    def _parse_list_xml(xml_string: str) -> pd.DataFrame:
        """Parse the XML response from study list API.

        Parameters
        ----------
        xml_string : str
            The XML response from the API.

        Returns
        -------
        pd.DataFrame
            DataFrame containing study information. Rows are indexed by study
            id; a column is only filled for a study when the corresponding
            XML element is present in that study's entry.

        Raises
        ------
        KeyError
            If ``oml:study_list``, its ``oml:study`` children, the
            ``@xmlns:oml`` attribute, or a study's ``oml:id`` is missing.
        AssertionError
            If the ``oml`` namespace is not ``http://openml.org/openml``.
        """
        # force_list guarantees a list even when exactly one study is returned.
        study_dict = xmltodict.parse(xml_string, force_list=("oml:study",))
        study_list = study_dict["oml:study_list"]
        study_entries = study_list["oml:study"]

        # Minimalistic check if the XML is useful
        assert isinstance(study_entries, list), type(study_entries)
        assert study_list["@xmlns:oml"] == "http://openml.org/openml", study_list["@xmlns:oml"]

        # maps from xml name to a tuple of (dict name, casting fn)
        expected_fields = {
            "oml:id": ("id", int),
            "oml:alias": ("alias", str),
            "oml:main_entity_type": ("main_entity_type", str),
            "oml:benchmark_suite": ("benchmark_suite", int),
            "oml:name": ("name", str),
            "oml:status": ("status", str),
            "oml:creation_date": ("creation_date", str),
            "oml:creator": ("creator", int),
        }

        studies = {}
        for study_ in study_entries:
            # Indexing (rather than .get) keeps the KeyError for entries without an id.
            study_id = int(study_["oml:id"])
            studies[study_id] = {
                real_field_name: cast_fn(study_[oml_field_name])
                for oml_field_name, (real_field_name, cast_fn) in expected_fields.items()
                if oml_field_name in study_
            }

        return pd.DataFrame.from_dict(studies, orient="index")
class StudyV2API(ResourceV2API, StudyAPI):
    """Version 2 API implementation for study resources."""

    def list(  # noqa: PLR0913
        self,
        limit: int | None = None,  # noqa: ARG002
        offset: int | None = None,  # noqa: ARG002
        status: str | None = None,  # noqa: ARG002
        main_entity_type: str | None = None,  # noqa: ARG002
        uploader: builtins.list[int] | None = None,  # noqa: ARG002
        benchmark_suite: int | None = None,  # noqa: ARG002
    ) -> pd.DataFrame:
        """Signal that listing studies is not yet available in the V2 API.

        The signature mirrors the V1 ``list`` so both versions are
        interchangeable for callers, but every argument is ignored here: the
        method only delegates to ``self._not_supported(method="list")``.

        Parameters
        ----------
        limit : int, optional
            Ignored.
        offset : int, optional
            Ignored.
        status : str, optional
            Ignored.
        main_entity_type : str, optional
            Ignored.
        uploader : list[int], optional
            Ignored.
        benchmark_suite : int, optional
            Ignored.
        """
        self._not_supported(method="list")
""" listing_call = partial( - _list_studies, + openml._backend.study.list, main_entity_type="run", status=status, uploader=uploader, @@ -527,81 +528,3 @@ def list_studies( return pd.DataFrame() return pd.concat(batches) - - -def _list_studies(limit: int, offset: int, **kwargs: Any) -> pd.DataFrame: - """Perform api call to return a list of studies. - - Parameters - ---------- - limit: int - The maximum number of studies to return. - offset: int - The number of studies to skip, starting from the first. - kwargs : dict, optional - Legal filter operators (keys in the dict): - status, main_entity_type, uploader, benchmark_suite - - Returns - ------- - studies : dataframe - """ - api_call = "study/list" - if limit is not None: - api_call += f"/limit/{limit}" - if offset is not None: - api_call += f"/offset/{offset}" - if kwargs is not None: - for operator, value in kwargs.items(): - if value is not None: - api_call += f"/{operator}/{value}" - return __list_studies(api_call=api_call) - - -def __list_studies(api_call: str) -> pd.DataFrame: - """Retrieves the list of OpenML studies and - returns it in a dictionary or a Pandas DataFrame. - - Parameters - ---------- - api_call : str - The API call for retrieving the list of OpenML studies. 
- - Returns - ------- - pd.DataFrame - A Pandas DataFrame of OpenML studies - """ - xml_string = openml._api_calls._perform_api_call(api_call, "get") - study_dict = xmltodict.parse(xml_string, force_list=("oml:study",)) - - # Minimalistic check if the XML is useful - assert isinstance(study_dict["oml:study_list"]["oml:study"], list), type( - study_dict["oml:study_list"], - ) - assert study_dict["oml:study_list"]["@xmlns:oml"] == "http://openml.org/openml", study_dict[ - "oml:study_list" - ]["@xmlns:oml"] - - studies = {} - for study_ in study_dict["oml:study_list"]["oml:study"]: - # maps from xml name to a tuple of (dict name, casting fn) - expected_fields = { - "oml:id": ("id", int), - "oml:alias": ("alias", str), - "oml:main_entity_type": ("main_entity_type", str), - "oml:benchmark_suite": ("benchmark_suite", int), - "oml:name": ("name", str), - "oml:status": ("status", str), - "oml:creation_date": ("creation_date", str), - "oml:creator": ("creator", int), - } - study_id = int(study_["oml:id"]) - current_study = {} - for oml_field_name, (real_field_name, cast_fn) in expected_fields.items(): - if oml_field_name in study_: - current_study[real_field_name] = cast_fn(study_[oml_field_name]) - current_study["id"] = int(current_study["id"]) - studies[study_id] = current_study - - return pd.DataFrame.from_dict(studies, orient="index") diff --git a/tests/test_api/test_study.py b/tests/test_api/test_study.py new file mode 100644 index 000000000..e704a6742 --- /dev/null +++ b/tests/test_api/test_study.py @@ -0,0 +1,309 @@ +# License: BSD 3-Clause +from __future__ import annotations + +import pytest +from requests import Session, Response +from unittest.mock import patch +import pandas as pd + +from openml._api.resources import StudyV1API, StudyV2API +from openml.exceptions import OpenMLNotSupportedError +import openml + + +@pytest.fixture +def study_v1(http_client_v1, minio_client) -> StudyV1API: + """Fixture for V1 Study API instance.""" + return 
StudyV1API(http=http_client_v1, minio=minio_client) + + +@pytest.fixture +def study_v2(http_client_v2, minio_client) -> StudyV2API: + """Fixture for V2 Study API instance.""" + return StudyV2API(http=http_client_v2, minio=minio_client) + +def test_v1_list_basic(study_v1, test_server_v1, test_apikey_v1): + """Test V1 list basic functionality with limit and offset.""" + # Mock response with study list + mock_response = """ + + + 1 + test-study-1 + task + Test Study 1 + active + + + 2 + test-study-2 + run + Test Study 2 + active + + + """ + + with patch.object(Session, "request") as mock_request: + mock_request.return_value = Response() + mock_request.return_value.status_code = 200 + mock_request.return_value._content = mock_response.encode("utf-8") + + studies_df = study_v1.list(limit=5, offset=0) + + mock_request.assert_called_once() + assert studies_df is not None + assert isinstance(studies_df, pd.DataFrame) + assert len(studies_df) == 2 + + expected_columns = {"id", "alias", "main_entity_type", "name", "status"} + assert expected_columns.issubset(set(studies_df.columns)) + + +def test_v1_list_with_status_filter(study_v1, test_server_v1, test_apikey_v1): + """Test V1 list with status filter.""" + mock_response = """ + + + 1 + active-study + task + Active Study + active + + + """ + + with patch.object(Session, "request") as mock_request: + mock_request.return_value = Response() + mock_request.return_value.status_code = 200 + mock_request.return_value._content = mock_response.encode("utf-8") + + studies_df = study_v1.list(limit=10, offset=0, status="active") + + assert studies_df is not None + assert all(studies_df["status"] == "active") + + mock_request.assert_called_once() + call_args = mock_request.call_args + assert "/status/active" in call_args.kwargs.get("url", "") + + +def test_v1_list_pagination(study_v1, test_server_v1, test_apikey_v1): + """Test V1 list pagination with offset and limit.""" + page1_response = """ + + + 1 + study-1 + task + Study 1 + active 
+ + + 2 + study-2 + task + Study 2 + active + + + 3 + study-3 + task + Study 3 + active + + + """ + + page2_response = """ + + + 4 + study-4 + run + Study 4 + active + + + 5 + study-5 + run + Study 5 + active + + + """ + + with patch.object(Session, "request") as mock_request: + page1_response_obj = Response() + page1_response_obj.status_code = 200 + page1_response_obj._content = page1_response.encode("utf-8") + + page2_response_obj = Response() + page2_response_obj.status_code = 200 + page2_response_obj._content = page2_response.encode("utf-8") + + mock_request.side_effect = [page1_response_obj, page2_response_obj] + + page1 = study_v1.list(limit=3, offset=0) + page2 = study_v1.list(limit=3, offset=3) + + assert len(page1) == 3 + assert len(page2) == 2 + + assert mock_request.call_count == 2 + call_args_list = mock_request.call_args_list + assert "/limit/3/offset/0" in call_args_list[0].kwargs.get("url", "") + assert "/limit/3/offset/3" in call_args_list[1].kwargs.get("url", "") + + +def test_v1_publish(study_v1, test_server_v1, test_apikey_v1): + """Test V1 publish a new study.""" + study_id = 999 + study_files = {"description": "Test Study Description"} + + with patch.object(Session, "request") as mock_request: + mock_request.return_value = Response() + mock_request.return_value.status_code = 200 + mock_request.return_value._content = ( + f'\n' + f"\t{study_id}\n" + f"\n" + ).encode("utf-8") + + published_id = study_v1.publish("study", files=study_files) + + assert published_id == study_id + + mock_request.assert_called_once_with( + method="POST", + url=test_server_v1 + "study", + params={}, + data={"api_key": test_apikey_v1}, + headers=openml.config._HEADERS, + files=study_files, + ) + + +def test_v1_tag(study_v1, test_server_v1, test_apikey_v1): + """Test V1 tag a study.""" + study_id = 100 + tag_name = "important-tag" + + with patch.object(Session, "request") as mock_request: + mock_request.return_value = Response() + mock_request.return_value.status_code = 
200 + mock_request.return_value._content = ( + f'' + f"{study_id}" + f"{tag_name}" + f"" + ).encode("utf-8") + + tags = study_v1.tag(study_id, tag_name) + + assert tag_name in tags + + mock_request.assert_called_once_with( + method="POST", + url=test_server_v1 + "study/tag", + params={}, + data={ + "api_key": test_apikey_v1, + "study_id": study_id, + "tag": tag_name, + }, + headers=openml.config._HEADERS, + files=None, + ) + + +def test_v1_untag(study_v1, test_server_v1, test_apikey_v1): + """Test V1 untag a study.""" + study_id = 100 + tag_name = "important-tag" + + with patch.object(Session, "request") as mock_request: + mock_request.return_value = Response() + mock_request.return_value.status_code = 200 + mock_request.return_value._content = ( + f'' + f"{study_id}" + f"" + ).encode("utf-8") + + tags = study_v1.untag(study_id, tag_name) + + assert tag_name not in tags + + mock_request.assert_called_once_with( + method="POST", + url=test_server_v1 + "study/untag", + params={}, + data={ + "api_key": test_apikey_v1, + "study_id": study_id, + "tag": tag_name, + }, + headers=openml.config._HEADERS, + files=None, + ) + + +def test_v1_delete(study_v1, test_server_v1, test_apikey_v1): + """Test V1 delete a study.""" + study_id = 100 + + with patch.object(Session, "request") as mock_request: + mock_request.return_value = Response() + mock_request.return_value.status_code = 200 + mock_request.return_value._content = ( + f'\n' + f" {study_id}\n" + f"\n" + ).encode("utf-8") + + result = study_v1.delete(study_id) + + assert result + + mock_request.assert_called_once_with( + method="DELETE", + url=test_server_v1 + "study/" + str(study_id), + params={"api_key": test_apikey_v1}, + data={}, + headers=openml.config._HEADERS, + files=None, + ) + +def test_v2_list_not_supported(study_v2): + """Test that V2 list raises OpenMLNotSupportedError.""" + with pytest.raises(OpenMLNotSupportedError): + study_v2.list(limit=5, offset=0) + + +def test_v2_publish_not_supported(study_v2): + 
"""Test that V2 publish raises OpenMLNotSupportedError.""" + with pytest.raises(OpenMLNotSupportedError): + study_v2.publish(path="study", files=None) + + +def test_v2_delete_not_supported(study_v2): + """Test that V2 delete raises OpenMLNotSupportedError.""" + with pytest.raises(OpenMLNotSupportedError): + study_v2.delete(resource_id=100) + + +def test_v2_tag_not_supported(study_v2): + """Test that V2 tag raises OpenMLNotSupportedError.""" + with pytest.raises(OpenMLNotSupportedError): + study_v2.tag(resource_id=100, tag="test-tag") + + +def test_v2_untag_not_supported(study_v2): + """Test that V2 untag raises OpenMLNotSupportedError.""" + with pytest.raises(OpenMLNotSupportedError): + study_v2.untag(resource_id=100, tag="test-tag") +