diff --git a/openml/_api/resources/base/resources.py b/openml/_api/resources/base/resources.py
index 0c60e69de..af5b5d752 100644
--- a/openml/_api/resources/base/resources.py
+++ b/openml/_api/resources/base/resources.py
@@ -10,6 +10,8 @@
from .base import ResourceAPI
if TYPE_CHECKING:
+ import pandas as pd
+
from openml.estimation_procedures import OpenMLEstimationProcedure
from openml.evaluations import OpenMLEvaluation
from openml.flows.flow import OpenMLFlow
@@ -80,6 +82,17 @@ class StudyAPI(ResourceAPI):
resource_type: ResourceType = ResourceType.STUDY
+ @abstractmethod
+ def list( # noqa: PLR0913
+ self,
+ limit: int | None = None,
+ offset: int | None = None,
+ status: str | None = None,
+ main_entity_type: str | None = None,
+ uploader: list[int] | None = None,
+ benchmark_suite: int | None = None,
+ ) -> pd.DataFrame: ...
+
class RunAPI(ResourceAPI):
"""Abstract API interface for run resources."""
diff --git a/openml/_api/resources/base/versions.py b/openml/_api/resources/base/versions.py
index bba59b869..0447bfd2f 100644
--- a/openml/_api/resources/base/versions.py
+++ b/openml/_api/resources/base/versions.py
@@ -25,6 +25,7 @@
ResourceType.DATASET,
ResourceType.TASK,
ResourceType.FLOW,
+ ResourceType.STUDY,
ResourceType.SETUP,
ResourceType.RUN,
]
diff --git a/openml/_api/resources/study.py b/openml/_api/resources/study.py
index fb073555c..fa8714b3e 100644
--- a/openml/_api/resources/study.py
+++ b/openml/_api/resources/study.py
@@ -1,11 +1,162 @@
from __future__ import annotations
-from .base import ResourceV1API, ResourceV2API, StudyAPI
+import builtins
+
+import pandas as pd
+import xmltodict
+
+from openml._api.resources.base import ResourceV1API, ResourceV2API, StudyAPI
class StudyV1API(ResourceV1API, StudyAPI):
- """Version 1 API implementation for study resources."""
+ def list( # noqa: PLR0913
+ self,
+ limit: int | None = None,
+ offset: int | None = None,
+ status: str | None = None,
+ main_entity_type: str | None = None,
+ uploader: builtins.list[int] | None = None,
+ benchmark_suite: int | None = None,
+ ) -> pd.DataFrame:
+ """List studies using V1 API.
+
+ Parameters
+ ----------
+ limit : int, optional
+ Maximum number of studies to return.
+ offset : int, optional
+ Number of studies to skip.
+ status : str, optional
+ Filter by status (active, in_preparation, deactivated, all).
+ main_entity_type : str, optional
+ Filter by main entity type (run, task).
+ uploader : list[int], optional
+ Filter by uploader IDs.
+ benchmark_suite : int, optional
+ Filter by benchmark suite ID.
+
+ Returns
+ -------
+ pd.DataFrame
+ DataFrame containing study information.
+ """
+ api_call = self._build_url(
+ limit=limit,
+ offset=offset,
+ status=status,
+ main_entity_type=main_entity_type,
+ uploader=uploader,
+ benchmark_suite=benchmark_suite,
+ )
+ response = self._http.get(api_call)
+ xml_string = response.content.decode("utf-8")
+ return self._parse_list_xml(xml_string)
+
+ @staticmethod
+ def _build_url( # noqa: PLR0913
+ limit: int | None = None,
+ offset: int | None = None,
+ status: str | None = None,
+ main_entity_type: str | None = None,
+ uploader: builtins.list[int] | None = None,
+ benchmark_suite: int | None = None,
+ ) -> str:
+ """Build the V1 API URL for listing studies.
+
+ Parameters
+ ----------
+ limit : int, optional
+ Maximum number of studies to return.
+ offset : int, optional
+ Number of studies to skip.
+ status : str, optional
+ Filter by status (active, in_preparation, deactivated, all).
+ main_entity_type : str, optional
+ Filter by main entity type (run, task).
+ uploader : list[int], optional
+ Filter by uploader IDs.
+ benchmark_suite : int, optional
+ Filter by benchmark suite ID.
+
+ Returns
+ -------
+ str
+ The API call string with all filters applied.
+ """
+ api_call = "study/list"
+
+ if limit is not None:
+ api_call += f"/limit/{limit}"
+ if offset is not None:
+ api_call += f"/offset/{offset}"
+ if status is not None:
+ api_call += f"/status/{status}"
+ if main_entity_type is not None:
+ api_call += f"/main_entity_type/{main_entity_type}"
+ if uploader is not None:
+ api_call += f"/uploader/{','.join(str(u) for u in uploader)}"
+ if benchmark_suite is not None:
+ api_call += f"/benchmark_suite/{benchmark_suite}"
+
+ return api_call
+
+ @staticmethod
+ def _parse_list_xml(xml_string: str) -> pd.DataFrame:
+ """Parse the XML response from study list API.
+
+ Parameters
+ ----------
+ xml_string : str
+ The XML response from the API.
+
+ Returns
+ -------
+ pd.DataFrame
+ DataFrame containing study information.
+ """
+ study_dict = xmltodict.parse(xml_string, force_list=("oml:study",))
+
+ # Minimalistic check if the XML is useful
+ assert isinstance(study_dict["oml:study_list"]["oml:study"], list), type(
+ study_dict["oml:study_list"],
+ )
+ assert study_dict["oml:study_list"]["@xmlns:oml"] == "http://openml.org/openml", study_dict[
+ "oml:study_list"
+ ]["@xmlns:oml"]
+
+ studies = {}
+ for study_ in study_dict["oml:study_list"]["oml:study"]:
+ # maps from xml name to a tuple of (dict name, casting fn)
+ expected_fields = {
+ "oml:id": ("id", int),
+ "oml:alias": ("alias", str),
+ "oml:main_entity_type": ("main_entity_type", str),
+ "oml:benchmark_suite": ("benchmark_suite", int),
+ "oml:name": ("name", str),
+ "oml:status": ("status", str),
+ "oml:creation_date": ("creation_date", str),
+ "oml:creator": ("creator", int),
+ }
+ study_id = int(study_["oml:id"])
+ current_study = {}
+ for oml_field_name, (real_field_name, cast_fn) in expected_fields.items():
+ if oml_field_name in study_:
+ current_study[real_field_name] = cast_fn(study_[oml_field_name])
+ current_study["id"] = int(current_study["id"])
+ studies[study_id] = current_study
+
+ return pd.DataFrame.from_dict(studies, orient="index")
class StudyV2API(ResourceV2API, StudyAPI):
- """Version 2 API implementation for study resources."""
+ def list( # noqa: PLR0913
+ self,
+ limit: int | None = None, # noqa: ARG002
+ offset: int | None = None, # noqa: ARG002
+ status: str | None = None, # noqa: ARG002
+ main_entity_type: str | None = None, # noqa: ARG002
+ uploader: builtins.list[int] | None = None, # noqa: ARG002
+ benchmark_suite: int | None = None, # noqa: ARG002
+ ) -> pd.DataFrame:
+ """V2 API for listing studies is not yet available."""
+ self._not_supported(method="list")
diff --git a/openml/study/functions.py b/openml/study/functions.py
index 7268ea97c..02140efb6 100644
--- a/openml/study/functions.py
+++ b/openml/study/functions.py
@@ -3,7 +3,7 @@
import warnings
from functools import partial
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
import pandas as pd
import xmltodict
@@ -336,7 +336,8 @@ def delete_study(study_id: int) -> bool:
bool
True iff the deletion was successful. False otherwise
"""
- return openml.utils._delete_entity("study", study_id)
+ result: bool = openml._backend.study.delete(study_id)
+ return result
def attach_to_suite(suite_id: int, task_ids: list[int]) -> int:
@@ -465,7 +466,7 @@ def list_suites(
- creation_date
"""
listing_call = partial(
- _list_studies,
+ openml._backend.study.list,
main_entity_type="task",
status=status,
uploader=uploader,
@@ -481,7 +482,7 @@ def list_studies(
offset: int | None = None,
size: int | None = None,
status: str | None = None,
- uploader: list[str] | None = None,
+ uploader: list[int] | None = None,
benchmark_suite: int | None = None,
) -> pd.DataFrame:
"""
@@ -516,7 +517,7 @@ def list_studies(
these are also returned.
"""
listing_call = partial(
- _list_studies,
+ openml._backend.study.list,
main_entity_type="run",
status=status,
uploader=uploader,
@@ -527,81 +528,3 @@ def list_studies(
return pd.DataFrame()
return pd.concat(batches)
-
-
-def _list_studies(limit: int, offset: int, **kwargs: Any) -> pd.DataFrame:
- """Perform api call to return a list of studies.
-
- Parameters
- ----------
- limit: int
- The maximum number of studies to return.
- offset: int
- The number of studies to skip, starting from the first.
- kwargs : dict, optional
- Legal filter operators (keys in the dict):
- status, main_entity_type, uploader, benchmark_suite
-
- Returns
- -------
- studies : dataframe
- """
- api_call = "study/list"
- if limit is not None:
- api_call += f"/limit/{limit}"
- if offset is not None:
- api_call += f"/offset/{offset}"
- if kwargs is not None:
- for operator, value in kwargs.items():
- if value is not None:
- api_call += f"/{operator}/{value}"
- return __list_studies(api_call=api_call)
-
-
-def __list_studies(api_call: str) -> pd.DataFrame:
- """Retrieves the list of OpenML studies and
- returns it in a dictionary or a Pandas DataFrame.
-
- Parameters
- ----------
- api_call : str
- The API call for retrieving the list of OpenML studies.
-
- Returns
- -------
- pd.DataFrame
- A Pandas DataFrame of OpenML studies
- """
- xml_string = openml._api_calls._perform_api_call(api_call, "get")
- study_dict = xmltodict.parse(xml_string, force_list=("oml:study",))
-
- # Minimalistic check if the XML is useful
- assert isinstance(study_dict["oml:study_list"]["oml:study"], list), type(
- study_dict["oml:study_list"],
- )
- assert study_dict["oml:study_list"]["@xmlns:oml"] == "http://openml.org/openml", study_dict[
- "oml:study_list"
- ]["@xmlns:oml"]
-
- studies = {}
- for study_ in study_dict["oml:study_list"]["oml:study"]:
- # maps from xml name to a tuple of (dict name, casting fn)
- expected_fields = {
- "oml:id": ("id", int),
- "oml:alias": ("alias", str),
- "oml:main_entity_type": ("main_entity_type", str),
- "oml:benchmark_suite": ("benchmark_suite", int),
- "oml:name": ("name", str),
- "oml:status": ("status", str),
- "oml:creation_date": ("creation_date", str),
- "oml:creator": ("creator", int),
- }
- study_id = int(study_["oml:id"])
- current_study = {}
- for oml_field_name, (real_field_name, cast_fn) in expected_fields.items():
- if oml_field_name in study_:
- current_study[real_field_name] = cast_fn(study_[oml_field_name])
- current_study["id"] = int(current_study["id"])
- studies[study_id] = current_study
-
- return pd.DataFrame.from_dict(studies, orient="index")
diff --git a/tests/test_api/test_study.py b/tests/test_api/test_study.py
new file mode 100644
index 000000000..e704a6742
--- /dev/null
+++ b/tests/test_api/test_study.py
@@ -0,0 +1,309 @@
+# License: BSD 3-Clause
+from __future__ import annotations
+
+import pytest
+from requests import Session, Response
+from unittest.mock import patch
+import pandas as pd
+
+from openml._api.resources import StudyV1API, StudyV2API
+from openml.exceptions import OpenMLNotSupportedError
+import openml
+
+
+@pytest.fixture
+def study_v1(http_client_v1, minio_client) -> StudyV1API:
+ """Fixture for V1 Study API instance."""
+ return StudyV1API(http=http_client_v1, minio=minio_client)
+
+
+@pytest.fixture
+def study_v2(http_client_v2, minio_client) -> StudyV2API:
+ """Fixture for V2 Study API instance."""
+ return StudyV2API(http=http_client_v2, minio=minio_client)
+
+def test_v1_list_basic(study_v1, test_server_v1, test_apikey_v1):
+ """Test V1 list basic functionality with limit and offset."""
+ # Mock response with study list
+    mock_response = """
+    <oml:study_list xmlns:oml="http://openml.org/openml">
+      <oml:study>
+        <oml:id>1</oml:id>
+        <oml:alias>test-study-1</oml:alias>
+        <oml:main_entity_type>task</oml:main_entity_type>
+        <oml:name>Test Study 1</oml:name>
+        <oml:status>active</oml:status>
+      </oml:study>
+      <oml:study>
+        <oml:id>2</oml:id>
+        <oml:alias>test-study-2</oml:alias>
+        <oml:main_entity_type>run</oml:main_entity_type>
+        <oml:name>Test Study 2</oml:name>
+        <oml:status>active</oml:status>
+      </oml:study>
+    </oml:study_list>
+    """
+
+ with patch.object(Session, "request") as mock_request:
+ mock_request.return_value = Response()
+ mock_request.return_value.status_code = 200
+ mock_request.return_value._content = mock_response.encode("utf-8")
+
+ studies_df = study_v1.list(limit=5, offset=0)
+
+ mock_request.assert_called_once()
+ assert studies_df is not None
+ assert isinstance(studies_df, pd.DataFrame)
+ assert len(studies_df) == 2
+
+ expected_columns = {"id", "alias", "main_entity_type", "name", "status"}
+ assert expected_columns.issubset(set(studies_df.columns))
+
+
+def test_v1_list_with_status_filter(study_v1, test_server_v1, test_apikey_v1):
+ """Test V1 list with status filter."""
+    mock_response = """
+    <oml:study_list xmlns:oml="http://openml.org/openml">
+      <oml:study>
+        <oml:id>1</oml:id>
+        <oml:alias>active-study</oml:alias>
+        <oml:main_entity_type>task</oml:main_entity_type>
+        <oml:name>Active Study</oml:name>
+        <oml:status>active</oml:status>
+      </oml:study>
+    </oml:study_list>
+    """
+
+ with patch.object(Session, "request") as mock_request:
+ mock_request.return_value = Response()
+ mock_request.return_value.status_code = 200
+ mock_request.return_value._content = mock_response.encode("utf-8")
+
+ studies_df = study_v1.list(limit=10, offset=0, status="active")
+
+ assert studies_df is not None
+ assert all(studies_df["status"] == "active")
+
+ mock_request.assert_called_once()
+ call_args = mock_request.call_args
+ assert "/status/active" in call_args.kwargs.get("url", "")
+
+
+def test_v1_list_pagination(study_v1, test_server_v1, test_apikey_v1):
+ """Test V1 list pagination with offset and limit."""
+    page1_response = """
+    <oml:study_list xmlns:oml="http://openml.org/openml">
+      <oml:study>
+        <oml:id>1</oml:id>
+        <oml:alias>study-1</oml:alias>
+        <oml:main_entity_type>task</oml:main_entity_type>
+        <oml:name>Study 1</oml:name>
+        <oml:status>active</oml:status>
+      </oml:study>
+      <oml:study>
+        <oml:id>2</oml:id>
+        <oml:alias>study-2</oml:alias>
+        <oml:main_entity_type>task</oml:main_entity_type>
+        <oml:name>Study 2</oml:name>
+        <oml:status>active</oml:status>
+      </oml:study>
+      <oml:study>
+        <oml:id>3</oml:id>
+        <oml:alias>study-3</oml:alias>
+        <oml:main_entity_type>task</oml:main_entity_type>
+        <oml:name>Study 3</oml:name>
+        <oml:status>active</oml:status>
+      </oml:study>
+    </oml:study_list>
+    """
+
+    page2_response = """
+    <oml:study_list xmlns:oml="http://openml.org/openml">
+      <oml:study>
+        <oml:id>4</oml:id>
+        <oml:alias>study-4</oml:alias>
+        <oml:main_entity_type>run</oml:main_entity_type>
+        <oml:name>Study 4</oml:name>
+        <oml:status>active</oml:status>
+      </oml:study>
+      <oml:study>
+        <oml:id>5</oml:id>
+        <oml:alias>study-5</oml:alias>
+        <oml:main_entity_type>run</oml:main_entity_type>
+        <oml:name>Study 5</oml:name>
+        <oml:status>active</oml:status>
+      </oml:study>
+    </oml:study_list>
+    """
+
+ with patch.object(Session, "request") as mock_request:
+ page1_response_obj = Response()
+ page1_response_obj.status_code = 200
+ page1_response_obj._content = page1_response.encode("utf-8")
+
+ page2_response_obj = Response()
+ page2_response_obj.status_code = 200
+ page2_response_obj._content = page2_response.encode("utf-8")
+
+ mock_request.side_effect = [page1_response_obj, page2_response_obj]
+
+ page1 = study_v1.list(limit=3, offset=0)
+ page2 = study_v1.list(limit=3, offset=3)
+
+ assert len(page1) == 3
+ assert len(page2) == 2
+
+ assert mock_request.call_count == 2
+ call_args_list = mock_request.call_args_list
+ assert "/limit/3/offset/0" in call_args_list[0].kwargs.get("url", "")
+ assert "/limit/3/offset/3" in call_args_list[1].kwargs.get("url", "")
+
+
+def test_v1_publish(study_v1, test_server_v1, test_apikey_v1):
+ """Test V1 publish a new study."""
+ study_id = 999
+ study_files = {"description": "Test Study Description"}
+
+ with patch.object(Session, "request") as mock_request:
+ mock_request.return_value = Response()
+ mock_request.return_value.status_code = 200
+        mock_request.return_value._content = (
+            f'<oml:study_upload xmlns:oml="http://openml.org/openml">\n'
+            f"\t<oml:id>{study_id}</oml:id>\n"
+            f"</oml:study_upload>\n"
+        ).encode("utf-8")
+
+ published_id = study_v1.publish("study", files=study_files)
+
+ assert published_id == study_id
+
+ mock_request.assert_called_once_with(
+ method="POST",
+ url=test_server_v1 + "study",
+ params={},
+ data={"api_key": test_apikey_v1},
+ headers=openml.config._HEADERS,
+ files=study_files,
+ )
+
+
+def test_v1_tag(study_v1, test_server_v1, test_apikey_v1):
+ """Test V1 tag a study."""
+ study_id = 100
+ tag_name = "important-tag"
+
+ with patch.object(Session, "request") as mock_request:
+ mock_request.return_value = Response()
+ mock_request.return_value.status_code = 200
+        mock_request.return_value._content = (
+            f'<oml:study_tag xmlns:oml="http://openml.org/openml">'
+            f"<oml:id>{study_id}</oml:id>"
+            f"<oml:tag>{tag_name}</oml:tag>"
+            f"</oml:study_tag>"
+        ).encode("utf-8")
+
+ tags = study_v1.tag(study_id, tag_name)
+
+ assert tag_name in tags
+
+ mock_request.assert_called_once_with(
+ method="POST",
+ url=test_server_v1 + "study/tag",
+ params={},
+ data={
+ "api_key": test_apikey_v1,
+ "study_id": study_id,
+ "tag": tag_name,
+ },
+ headers=openml.config._HEADERS,
+ files=None,
+ )
+
+
+def test_v1_untag(study_v1, test_server_v1, test_apikey_v1):
+ """Test V1 untag a study."""
+ study_id = 100
+ tag_name = "important-tag"
+
+ with patch.object(Session, "request") as mock_request:
+ mock_request.return_value = Response()
+ mock_request.return_value.status_code = 200
+        mock_request.return_value._content = (
+            f'<oml:study_untag xmlns:oml="http://openml.org/openml">'
+            f"<oml:id>{study_id}</oml:id>"
+            f"</oml:study_untag>"
+        ).encode("utf-8")
+
+ tags = study_v1.untag(study_id, tag_name)
+
+ assert tag_name not in tags
+
+ mock_request.assert_called_once_with(
+ method="POST",
+ url=test_server_v1 + "study/untag",
+ params={},
+ data={
+ "api_key": test_apikey_v1,
+ "study_id": study_id,
+ "tag": tag_name,
+ },
+ headers=openml.config._HEADERS,
+ files=None,
+ )
+
+
+def test_v1_delete(study_v1, test_server_v1, test_apikey_v1):
+ """Test V1 delete a study."""
+ study_id = 100
+
+ with patch.object(Session, "request") as mock_request:
+ mock_request.return_value = Response()
+ mock_request.return_value.status_code = 200
+        mock_request.return_value._content = (
+            f'<oml:study_delete xmlns:oml="http://openml.org/openml">\n'
+            f"  <oml:id>{study_id}</oml:id>\n"
+            f"</oml:study_delete>\n"
+        ).encode("utf-8")
+
+ result = study_v1.delete(study_id)
+
+ assert result
+
+ mock_request.assert_called_once_with(
+ method="DELETE",
+ url=test_server_v1 + "study/" + str(study_id),
+ params={"api_key": test_apikey_v1},
+ data={},
+ headers=openml.config._HEADERS,
+ files=None,
+ )
+
+def test_v2_list_not_supported(study_v2):
+ """Test that V2 list raises OpenMLNotSupportedError."""
+ with pytest.raises(OpenMLNotSupportedError):
+ study_v2.list(limit=5, offset=0)
+
+
+def test_v2_publish_not_supported(study_v2):
+ """Test that V2 publish raises OpenMLNotSupportedError."""
+ with pytest.raises(OpenMLNotSupportedError):
+ study_v2.publish(path="study", files=None)
+
+
+def test_v2_delete_not_supported(study_v2):
+ """Test that V2 delete raises OpenMLNotSupportedError."""
+ with pytest.raises(OpenMLNotSupportedError):
+ study_v2.delete(resource_id=100)
+
+
+def test_v2_tag_not_supported(study_v2):
+ """Test that V2 tag raises OpenMLNotSupportedError."""
+ with pytest.raises(OpenMLNotSupportedError):
+ study_v2.tag(resource_id=100, tag="test-tag")
+
+
+def test_v2_untag_not_supported(study_v2):
+ """Test that V2 untag raises OpenMLNotSupportedError."""
+ with pytest.raises(OpenMLNotSupportedError):
+ study_v2.untag(resource_id=100, tag="test-tag")
+