Skip to content

Commit edde570

Browse files
committed
merge main
2 parents f276b28 + e653ef6 commit edde570

File tree

7 files changed

+230
-22
lines changed

7 files changed

+230
-22
lines changed

openml/_api/resources/base/resources.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
if TYPE_CHECKING:
1313
import pandas as pd
1414

15-
from openml import OpenMLEvaluation
15+
from openml.estimation_procedures import OpenMLEstimationProcedure
16+
from openml.evaluations import OpenMLEvaluation
1617
from openml.flows.flow import OpenMLFlow
1718
from openml.setups.setup import OpenMLSetup
1819
from openml.tasks.task import OpenMLTask, TaskType
@@ -87,6 +88,9 @@ class EstimationProcedureAPI(ResourceAPI):
8788

8889
resource_type: ResourceType = ResourceType.ESTIMATION_PROCEDURE
8990

91+
@abstractmethod
92+
def list(self) -> list[OpenMLEstimationProcedure]: ...
93+
9094

9195
class EvaluationAPI(ResourceAPI):
9296
"""Abstract API interface for evaluation resources."""
Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,84 @@
11
from __future__ import annotations
22

3+
import warnings
4+
5+
import xmltodict
6+
7+
from openml.estimation_procedures.estimation_procedure import OpenMLEstimationProcedure
8+
from openml.tasks.task import TaskType
9+
310
from .base import EstimationProcedureAPI, ResourceV1API, ResourceV2API
411

512

613
class EstimationProcedureV1API(ResourceV1API, EstimationProcedureAPI):
7-
"""Version 1 API implementation for estimation procedure resources."""
14+
"""V1 API implementation for estimation procedures.
15+
16+
Fetches estimation procedures from the v1 XML API endpoint.
17+
"""
18+
19+
def list(self) -> list[OpenMLEstimationProcedure]:
20+
"""Return a list of all estimation procedures which are on OpenML.
21+
22+
Returns
23+
-------
24+
procedures : list
25+
A list of all estimation procedures. Every procedure is represented by
26+
a dictionary containing the following information: id, task type id,
27+
name, type, repeats, folds, stratified.
28+
"""
29+
path = "estimationprocedure/list"
30+
response = self._http.get(path)
31+
xml_content = response.text
32+
33+
procs_dict = xmltodict.parse(xml_content)
34+
35+
# Minimalistic check if the XML is useful
36+
if "oml:estimationprocedures" not in procs_dict:
37+
raise ValueError("Error in return XML, does not contain tag oml:estimationprocedures.")
38+
39+
if "@xmlns:oml" not in procs_dict["oml:estimationprocedures"]:
40+
raise ValueError(
41+
"Error in return XML, does not contain tag "
42+
"@xmlns:oml as a child of oml:estimationprocedures.",
43+
)
44+
45+
if procs_dict["oml:estimationprocedures"]["@xmlns:oml"] != "http://openml.org/openml":
46+
raise ValueError(
47+
"Error in return XML, value of "
48+
"oml:estimationprocedures/@xmlns:oml is not "
49+
"http://openml.org/openml, but {}".format(
50+
str(procs_dict["oml:estimationprocedures"]["@xmlns:oml"])
51+
),
52+
)
53+
54+
procs: list[OpenMLEstimationProcedure] = []
55+
for proc_ in procs_dict["oml:estimationprocedures"]["oml:estimationprocedure"]:
56+
task_type_int = int(proc_["oml:ttid"])
57+
try:
58+
task_type_id = TaskType(task_type_int)
59+
procs.append(
60+
OpenMLEstimationProcedure(
61+
id=int(proc_["oml:id"]),
62+
task_type_id=task_type_id,
63+
name=proc_["oml:name"],
64+
type=proc_["oml:type"],
65+
)
66+
)
67+
except ValueError as e:
68+
warnings.warn(
69+
f"Could not create task type id for {task_type_int} due to error {e}",
70+
RuntimeWarning,
71+
stacklevel=2,
72+
)
73+
74+
return procs
875

976

1077
class EstimationProcedureV2API(ResourceV2API, EstimationProcedureAPI):
11-
"""Version 2 API implementation for estimation procedure resources."""
78+
"""V2 API implementation for estimation procedures.
79+
80+
Fetches estimation procedures from the v2 JSON API endpoint.
81+
"""
82+
83+
def list(self) -> list[OpenMLEstimationProcedure]:
84+
self._not_supported(method="list")
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# License: BSD 3-Clause
2+
3+
from .estimation_procedure import OpenMLEstimationProcedure
4+
5+
__all__ = ["OpenMLEstimationProcedure"]
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# License: BSD 3-Clause
2+
from __future__ import annotations
3+
4+
from dataclasses import asdict, dataclass
5+
from typing import TYPE_CHECKING
6+
7+
if TYPE_CHECKING:
8+
from openml.tasks import TaskType
9+
10+
11+
@dataclass
12+
class OpenMLEstimationProcedure:
13+
"""
14+
Contains all meta-information about a run / evaluation combination,
15+
according to the evaluation/list function
16+
17+
Parameters
18+
----------
19+
id : int
20+
ID of estimation procedure
21+
task_type_id : TaskType
22+
Assosiated task type
23+
name : str
24+
Name of estimation procedure
25+
type : str
26+
Type of estimation procedure
27+
"""
28+
29+
id: int
30+
task_type_id: TaskType
31+
name: str
32+
type: str
33+
34+
def _to_dict(self) -> dict:
35+
return asdict(self)
36+
37+
def __repr__(self) -> str:
38+
header = "OpenML Estimation Procedure"
39+
header = f"{header}\n{'=' * len(header)}\n"
40+
41+
fields = {
42+
"ID": self.id,
43+
"Name": self.name,
44+
"Type": self.type,
45+
"Task Type": self.task_type_id,
46+
}
47+
longest_field_name_length = max(len(name) for name in fields)
48+
field_line_format = f"{{:.<{longest_field_name_length}}}: {{}}"
49+
body = "\n".join(field_line_format.format(name, value) for name, value in fields.items())
50+
return header + body

openml/evaluations/functions.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import numpy as np
1111
import pandas as pd
12-
import xmltodict
1312

1413
import openml
1514
import openml._api_calls
@@ -167,24 +166,8 @@ def list_estimation_procedures() -> list[str]:
167166
-------
168167
list
169168
"""
170-
api_call = "estimationprocedure/list"
171-
xml_string = openml._api_calls._perform_api_call(api_call, "get")
172-
api_results = xmltodict.parse(xml_string)
173-
174-
# Minimalistic check if the XML is useful
175-
if "oml:estimationprocedures" not in api_results:
176-
raise ValueError('Error in return XML, does not contain "oml:estimationprocedures"')
177-
178-
if "oml:estimationprocedure" not in api_results["oml:estimationprocedures"]:
179-
raise ValueError('Error in return XML, does not contain "oml:estimationprocedure"')
180-
181-
if not isinstance(api_results["oml:estimationprocedures"]["oml:estimationprocedure"], list):
182-
raise TypeError('Error in return XML, does not contain "oml:estimationprocedure" as a list')
183-
184-
return [
185-
prod["oml:name"]
186-
for prod in api_results["oml:estimationprocedures"]["oml:estimationprocedure"]
187-
]
169+
result = openml._backend.estimation_procedure.list()
170+
return [i.name for i in result]
188171

189172

190173
def list_evaluations_setups(

openml/tasks/functions.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
# License: BSD 3-Clause
22
from __future__ import annotations
33

4+
import os
5+
import re
46
import warnings
57
from functools import partial
68
from typing import TYPE_CHECKING, Any
79

810
import pandas as pd
911

1012
import openml.utils
13+
from openml._api.resources.task import _create_task_from_xml
1114
from openml.datasets import get_dataset
15+
from openml.exceptions import OpenMLCacheException
1216

1317
from .task import (
1418
OpenMLClassificationTask,
@@ -23,6 +27,63 @@
2327
from .task import (
2428
OpenMLTask,
2529
)
30+
TASKS_CACHE_DIR_NAME = "tasks"
31+
32+
33+
def _get_cached_tasks() -> dict[int, OpenMLTask]:
34+
"""Return a dict of all the tasks which are cached locally.
35+
36+
Returns
37+
-------
38+
tasks : OrderedDict
39+
A dict of all the cached tasks. Each task is an instance of
40+
OpenMLTask.
41+
"""
42+
task_cache_dir = openml.utils._create_cache_directory(TASKS_CACHE_DIR_NAME)
43+
directory_content = os.listdir(task_cache_dir) # noqa: PTH208
44+
directory_content.sort()
45+
46+
# Find all dataset ids for which we have downloaded the dataset
47+
# description
48+
tids = (int(did) for did in directory_content if re.match(r"[0-9]*", did))
49+
return {tid: _get_cached_task(tid) for tid in tids}
50+
51+
52+
def _get_cached_task(tid: int) -> OpenMLTask:
53+
"""Return a cached task based on the given id.
54+
55+
Parameters
56+
----------
57+
tid : int
58+
Id of the task.
59+
60+
Returns
61+
-------
62+
OpenMLTask
63+
"""
64+
tid_cache_dir = openml.utils._create_cache_directory_for_id(TASKS_CACHE_DIR_NAME, tid)
65+
66+
task_xml_path = tid_cache_dir / "task.xml"
67+
try:
68+
with task_xml_path.open(encoding="utf8") as fh:
69+
return _create_task_from_xml(fh.read())
70+
except OSError as e:
71+
openml.utils._remove_cache_dir_for_id(TASKS_CACHE_DIR_NAME, tid_cache_dir)
72+
raise OpenMLCacheException(f"Task file for tid {tid} not cached") from e
73+
74+
75+
def _get_estimation_procedure_list() -> list[dict[str, Any]]:
76+
"""Return a list of all estimation procedures which are on OpenML.
77+
78+
Returns
79+
-------
80+
procedures : list
81+
A list of all estimation procedures. Every procedure is represented by
82+
a dictionary containing the following information: id, task type id,
83+
name, type, repeats, folds, stratified.
84+
"""
85+
result = openml._backend.estimation_procedure.list()
86+
return [i._to_dict() for i in result]
2687

2788

2889
def list_tasks( # noqa: PLR0913
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# License: BSD 3-Clause
2+
from __future__ import annotations
3+
4+
import pytest
5+
from openml._api import EstimationProcedureV1API, EstimationProcedureV2API
6+
from openml.exceptions import OpenMLNotSupportedError
7+
from openml.estimation_procedures import OpenMLEstimationProcedure
8+
9+
10+
@pytest.fixture
11+
def estimation_procedure_v1(http_client_v1, minio_client) -> EstimationProcedureV1API:
12+
return EstimationProcedureV1API(http=http_client_v1, minio=minio_client)
13+
14+
15+
@pytest.fixture
16+
def estimation_procedure_v2(http_client_v2, minio_client) -> EstimationProcedureV2API:
17+
return EstimationProcedureV2API(http=http_client_v2, minio=minio_client)
18+
19+
20+
@pytest.mark.test_server()
21+
def test_v1_list(estimation_procedure_v1):
22+
details = estimation_procedure_v1.list()
23+
24+
assert isinstance(details, list)
25+
assert len(details) > 0
26+
assert all(isinstance(d, OpenMLEstimationProcedure) for d in details)
27+
28+
29+
@pytest.mark.test_server()
30+
def test_v2_list(estimation_procedure_v2):
31+
with pytest.raises(OpenMLNotSupportedError):
32+
estimation_procedure_v2.list()

0 commit comments

Comments
 (0)