Skip to content

Commit 5f866f0

Browse files
committed
feature: migrate DatashareTaskClient to datashare-python
1 parent 4ccc79e commit 5f866f0

File tree

5 files changed

+325
-26
lines changed

5 files changed

+325
-26
lines changed

datashare_python/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def to_es_client(self, address: str | None = None) -> "ESClient":
5454
)
5555
return client
5656

57-
def to_task_client(self) -> "DSTaskClient":
58-
from datashare_python.utils import DSTaskClient
57+
def to_task_client(self) -> "DatashareTaskClient":
58+
from datashare_python.task_client import DatashareTaskClient
5959

60-
return DSTaskClient(self.ds_url)
60+
return DatashareTaskClient(self.ds_url)

datashare_python/task_client.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import uuid
2+
from typing import Any, Dict, Optional
3+
4+
from icij_common.pydantic_utils import jsonable_encoder
5+
from icij_worker import Task, TaskError, TaskState
6+
from icij_worker.exceptions import UnknownTask
7+
from icij_worker.utils.http import AiohttpClient
8+
9+
# TODO: maxRetries is not supported by java, it's automatically set to 3
10+
_TASK_UNSUPPORTED = {"max_retries"}
11+
12+
13+
class DatashareTaskClient(AiohttpClient):
14+
def __init__(self, datashare_url: str, api_key: str | None = None) -> None:
15+
headers = None
16+
if api_key is not None:
17+
headers = {"Authorization": f"Bearer {api_key}"}
18+
super().__init__(datashare_url, headers=headers)
19+
20+
async def __aenter__(self):
21+
await super().__aenter__()
22+
if "Authorization" not in self._headers:
23+
async with self._get("/settings") as res:
24+
# SimpleCookie doesn't seem to parse DS cookie so we perform some dirty
25+
# hack here
26+
session_id = [
27+
item
28+
for item in res.headers["Set-Cookie"].split("; ")
29+
if "session_id" in item
30+
]
31+
if len(session_id) != 1:
32+
raise ValueError("Invalid cookie")
33+
k, v = session_id[0].split("=")
34+
self._session.cookie_jar.update_cookies({k: v})
35+
36+
async def create_task(
37+
self,
38+
name: str,
39+
args: Dict[str, Any],
40+
*,
41+
id_: Optional[str] = None,
42+
group: Optional[str] = None,
43+
) -> str:
44+
if id_ is None:
45+
id_ = _generate_task_id(name)
46+
task = Task.create(task_id=id_, task_name=name, args=args)
47+
task = jsonable_encoder(task, exclude=_TASK_UNSUPPORTED, exclude_unset=True)
48+
task.pop("createdAt")
49+
url = f"/api/task/{id_}"
50+
if group is not None:
51+
if not isinstance(group, str):
52+
raise TypeError(f"expected group to be a string found {group}")
53+
url += f"?group={group}"
54+
async with self._put(url, json=task) as res:
55+
task_res = await res.json()
56+
return task_res["taskId"]
57+
58+
async def get_task(self, id_: str) -> Task:
59+
url = f"/api/task/{id_}"
60+
async with self._get(url) as res:
61+
task = await res.json()
62+
if task is None:
63+
raise UnknownTask(id_)
64+
# TODO: align Java on Python here... it's not a good idea to store results
65+
# inside tasks since result can be quite large and we may want to get the task
66+
# metadata without having to deal with the large task results...
67+
task = _ds_to_icij_worker_task(task)
68+
task = Task(**task)
69+
return task
70+
71+
async def get_tasks(self) -> list[Task]:
72+
url = "/api/task/all"
73+
async with self._get(url) as res:
74+
tasks = await res.json()
75+
# TODO: align Java on Python here... it's not a good idea to store results
76+
# inside tasks since result can be quite large and we may want to get the task
77+
# metadata without having to deal with the large task results...
78+
tasks = (_ds_to_icij_worker_task(t) for t in tasks)
79+
tasks = [Task(**task) for task in tasks]
80+
return tasks
81+
82+
async def get_task_state(self, id_: str) -> TaskState:
83+
return (await self.get_task(id_)).state
84+
85+
async def get_task_result(self, id_: str) -> Any:
86+
url = f"/api/task/{id_}/results"
87+
async with self._get(url) as res:
88+
task_res = await res.json()
89+
return task_res
90+
91+
async def get_task_error(self, id_: str) -> TaskError:
92+
url = f"/api/task/{id_}"
93+
async with self._get(url) as res:
94+
task = await res.json()
95+
if task is None:
96+
raise UnknownTask(id_)
97+
task_state = TaskState[task["state"]]
98+
if task_state != TaskState.ERROR:
99+
msg = f"can't find error for task {id_} in state {task_state}"
100+
raise ValueError(msg)
101+
error = TaskError(**task["error"])
102+
return error
103+
104+
async def delete(self, id_: str):
105+
url = f"/api/task/{id_}"
106+
async with self._delete(url):
107+
pass
108+
109+
async def delete_all_tasks(self):
110+
for t in await self.get_tasks():
111+
await self.delete(t.id)
112+
113+
114+
def _generate_task_id(task_name: str) -> str:
115+
return f"{task_name}-{uuid.uuid4()}"
116+
117+
118+
_JAVA_TASK_ATTRIBUTES = ["result", "error"]
119+
120+
121+
def _ds_to_icij_worker_task(task: dict) -> dict:
122+
for k in _JAVA_TASK_ATTRIBUTES:
123+
task.pop(k, None)
124+
return task

datashare_python/tests/tasks/test_translate_docs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from icij_common.es import ESClient
77

88
from datashare_python.objects import Document
9+
from datashare_python.task_client import DatashareTaskClient
910
from datashare_python.tasks import create_translation_tasks
1011
from datashare_python.tests.conftest import TEST_PROJECT
11-
from datashare_python.utils import DSTaskClient
1212

1313
logger = logging.getLogger(__name__)
1414

@@ -21,7 +21,7 @@ async def _progress(p: float):
2121
async def test_create_translation_tasks_integration(
2222
populate_es: List[Document], # pylint: disable=unused-argument
2323
test_es_client: ESClient,
24-
test_task_client: DSTaskClient,
24+
test_task_client: DatashareTaskClient,
2525
):
2626
# Given
2727
es_client = test_es_client
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import uuid
2+
from contextlib import asynccontextmanager
3+
from datetime import datetime
4+
from typing import Any
5+
from unittest.mock import AsyncMock
6+
7+
from aiohttp.typedefs import StrOrURL
8+
from icij_worker import Task, TaskError, TaskState
9+
from icij_worker.objects import StacktraceItem
10+
11+
from datashare_python.task_client import DatashareTaskClient
12+
13+
14+
async def test_task_client_create_task(monkeypatch):
15+
# Given
16+
datashare_url = "http://some-url"
17+
api_key = "some-api-key"
18+
task_name = "hello"
19+
task_id = f"{task_name}-{uuid.uuid4()}"
20+
args = {"greeted": "world"}
21+
group = "PYTHON"
22+
23+
@asynccontextmanager
24+
async def _put_and_assert(_, url: StrOrURL, *, data: Any = None, **kwargs: Any):
25+
assert url == f"/api/task/{task_id}?group={group}"
26+
expected_task = {
27+
"@type": "Task",
28+
"id": task_id,
29+
"state": "CREATED",
30+
"name": "hello",
31+
"args": {"greeted": "world"},
32+
}
33+
expected_data = expected_task
34+
assert data is None
35+
json_data = kwargs.pop("json")
36+
assert not kwargs
37+
assert json_data == expected_data
38+
mocked_res = AsyncMock()
39+
mocked_res.json.return_value = {"taskId": task_id}
40+
yield mocked_res
41+
42+
monkeypatch.setattr("icij_worker.utils.http.AiohttpClient._put", _put_and_assert)
43+
44+
task_client = DatashareTaskClient(datashare_url, api_key=api_key)
45+
async with task_client:
46+
# When
47+
t_id = await task_client.create_task(task_name, args, id_=task_id, group=group)
48+
assert t_id == task_id
49+
50+
51+
async def test_task_client_get_task(monkeypatch):
52+
# Given
53+
datashare_url = "http://some-url"
54+
api_key = "some-api-key"
55+
task_name = "hello"
56+
task_id = f"{task_name}-{uuid.uuid4()}"
57+
58+
@asynccontextmanager
59+
async def _get_and_assert(
60+
_, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
61+
):
62+
assert url == f"/api/task/{task_id}"
63+
task = {
64+
"@type": "Task",
65+
"id": task_id,
66+
"state": "CREATED",
67+
"createdAt": datetime.now(),
68+
"name": "hello",
69+
"args": {"greeted": "world"},
70+
}
71+
assert allow_redirects
72+
assert not kwargs
73+
mocked_res = AsyncMock()
74+
mocked_res.json.return_value = task
75+
yield mocked_res
76+
77+
monkeypatch.setattr("icij_worker.utils.http.AiohttpClient._get", _get_and_assert)
78+
79+
task_client = DatashareTaskClient(datashare_url, api_key=api_key)
80+
async with task_client:
81+
# When
82+
task = await task_client.get_task(task_id)
83+
assert isinstance(task, Task)
84+
85+
86+
async def test_task_client_get_task_state(monkeypatch):
87+
# Given
88+
datashare_url = "http://some-url"
89+
api_key = "some-api-key"
90+
task_name = "hello"
91+
task_id = f"{task_name}-{uuid.uuid4()}"
92+
93+
@asynccontextmanager
94+
async def _get_and_assert(
95+
_, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
96+
):
97+
assert url == f"/api/task/{task_id}"
98+
task = {
99+
"@type": "Task",
100+
"id": task_id,
101+
"state": "DONE",
102+
"createdAt": datetime.now(),
103+
"completedAt": datetime.now(),
104+
"name": "hello",
105+
"args": {"greeted": "world"},
106+
"result": "hellow world",
107+
}
108+
assert allow_redirects
109+
assert not kwargs
110+
mocked_res = AsyncMock()
111+
mocked_res.json.return_value = task
112+
yield mocked_res
113+
114+
monkeypatch.setattr("icij_worker.utils.http.AiohttpClient._get", _get_and_assert)
115+
116+
task_client = DatashareTaskClient(datashare_url, api_key=api_key)
117+
async with task_client:
118+
# When
119+
res = await task_client.get_task_state(task_id)
120+
assert res == TaskState.DONE
121+
122+
123+
async def test_task_client_get_task_result(monkeypatch):
124+
# Given
125+
datashare_url = "http://some-url"
126+
api_key = "some-api-key"
127+
task_name = "hello"
128+
task_id = f"{task_name}-{uuid.uuid4()}"
129+
130+
@asynccontextmanager
131+
async def _get_and_assert(
132+
_, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
133+
):
134+
assert url == f"/api/task/{task_id}/results"
135+
assert allow_redirects
136+
assert not kwargs
137+
mocked_res = AsyncMock()
138+
mocked_res.json.return_value = "hellow world"
139+
yield mocked_res
140+
141+
monkeypatch.setattr("icij_worker.utils.http.AiohttpClient._get", _get_and_assert)
142+
143+
task_client = DatashareTaskClient(datashare_url, api_key=api_key)
144+
async with task_client:
145+
# When
146+
res = await task_client.get_task_result(task_id)
147+
assert res == "hellow world"
148+
149+
150+
async def test_task_client_get_task_error(monkeypatch):
151+
# Given
152+
datashare_url = "http://some-url"
153+
api_key = "some-api-key"
154+
task_name = "hello"
155+
task_id = f"{task_name}-{uuid.uuid4()}"
156+
157+
@asynccontextmanager
158+
async def _get_and_assert(
159+
_, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
160+
):
161+
assert url == f"/api/task/{task_id}"
162+
task = {
163+
"@type": "Task",
164+
"id": task_id,
165+
"state": "ERROR",
166+
"createdAt": datetime.now(),
167+
"completedAt": datetime.now(),
168+
"name": "hello",
169+
"args": {"greeted": "world"},
170+
"error": {
171+
"@type": "TaskError",
172+
"name": "SomeError",
173+
"message": "some error found",
174+
"cause": "i'm the culprit",
175+
"stacktrace": [{"lineno": 666, "file": "some_file.py", "name": "err"}],
176+
},
177+
}
178+
assert allow_redirects
179+
assert not kwargs
180+
mocked_res = AsyncMock()
181+
mocked_res.json.return_value = task
182+
yield mocked_res
183+
184+
monkeypatch.setattr("icij_worker.utils.http.AiohttpClient._get", _get_and_assert)
185+
186+
task_client = DatashareTaskClient(datashare_url, api_key=api_key)
187+
async with task_client:
188+
# When
189+
error = await task_client.get_task_error(task_id)
190+
expected_error = TaskError(
191+
name="SomeError",
192+
message="some error found",
193+
cause="i'm the culprit",
194+
stacktrace=[StacktraceItem(name="err", file="some_file.py", lineno=666)],
195+
)
196+
assert error == expected_error

datashare_python/utils.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
from itertools import islice
44
from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterable, TypeVar
55

6-
from icij_worker.ds_task_client import DatashareTaskClient
7-
86
T = TypeVar("T")
97

108
Predicate = Callable[[T], bool] | Callable[[T], Awaitable[bool]]
@@ -69,22 +67,3 @@ async def remainder_iterator():
6967
yield elm
7068

7169
return true_iterator(), remainder_iterator()
72-
73-
74-
class DSTaskClient(DatashareTaskClient):
75-
76-
async def __aenter__(self):
77-
await super().__aenter__()
78-
79-
async with self._get("/settings") as res:
80-
# SimpleCookie doesn't seem to parse DS cookie so we perform some dirty
81-
# hack here
82-
session_id = [
83-
item
84-
for item in res.headers["Set-Cookie"].split("; ")
85-
if "session_id" in item
86-
]
87-
if len(session_id) != 1:
88-
raise ValueError("Invalid cookie")
89-
k, v = session_id[0].split("=")
90-
self._session.cookie_jar.update_cookies({k: v})

0 commit comments

Comments
 (0)