Skip to content

Commit d90af41

Browse files
authored
Add AWS S3 ResourcesRepository
1 parent 61d9f03 commit d90af41

File tree

9 files changed

+662
-6
lines changed

9 files changed

+662
-6
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ build-backend = "uv_build"
55
[project]
66
name = "draive"
77
description = "Framework designed to simplify and accelerate the development of LLM-based applications."
8-
version = "0.93.4"
8+
version = "0.94.0"
99
readme = "README.md"
1010
maintainers = [
1111
{ name = "Kacper Kaliński", email = "[email protected]" },
@@ -41,6 +41,7 @@ mistral = ["mistralai~=1.9"]
4141
gemini = ["google-genai~=1.50", "google-api-core"]
4242
ollama = ["ollama~=0.6.0"]
4343
bedrock = ["boto3~=1.40"]
44+
aws = ["boto3~=1.40"]
4445
vllm = ["openai~=2.8"]
4546
mcp = ["mcp~=1.21"]
4647
opentelemetry = [

src/draive/aws/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from draive.aws.client import AWS
2+
from draive.aws.types import AWSAccessDenied, AWSError, AWSResourceNotFound
3+
4+
__all__ = (
5+
"AWS",
6+
"AWSAccessDenied",
7+
"AWSError",
8+
"AWSResourceNotFound",
9+
)

src/draive/aws/api.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from typing import Any
2+
3+
from boto3 import Session # pyright: ignore[reportMissingTypeStubs]
4+
from haiway import asynchronous
5+
6+
__all__ = [
7+
"AWSAPI",
8+
]
9+
10+
11+
class AWSAPI:
12+
"""Low-level AWS session and client management.
13+
14+
Provides an asynchronous S3 client initializer that other mixins
15+
can rely on without duplicating boto3 session wiring.
16+
"""
17+
18+
__slots__ = (
19+
"_s3_client",
20+
"_session",
21+
)
22+
23+
def __init__(
24+
self,
25+
region_name: str | None = None,
26+
access_key_id: str | None = None,
27+
secret_access_key: str | None = None,
28+
) -> None:
29+
"""Create an AWS session.
30+
31+
Parameters
32+
----------
33+
region_name
34+
Preferred AWS region for the session. Defaults to the
35+
region configured in the environment or AWS profiles.
36+
access_key_id
37+
Optional access key identifier used for credential override.
38+
secret_access_key
39+
Optional secret key paired with ``access_key_id`` for override.
40+
"""
41+
# using dict as kwargs since existence of some
42+
# arguments when initializing session changes
43+
# the session configuration, even if using None
44+
kwargs: dict[str, object] = {}
45+
46+
if key_id := access_key_id:
47+
kwargs["aws_access_key_id"] = key_id
48+
49+
if key := secret_access_key:
50+
kwargs["aws_secret_access_key"] = key
51+
52+
if region := region_name:
53+
kwargs["region_name"] = region
54+
55+
self._session: Session = Session(**kwargs)
56+
self._s3_client: Any
57+
58+
@asynchronous
59+
def _prepare_client(self) -> None:
60+
self._s3_client = self._session.client( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
61+
service_name="s3",
62+
)
63+
64+
@property
65+
def region(self) -> str | None:
66+
"""Currently configured AWS region for the session."""
67+
68+
return self._session.region_name # pyright: ignore[reportUnknownMemberType, reportReturnType, reportUnknownVariableType]

src/draive/aws/client.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from collections.abc import Collection, Iterable
2+
from types import TracebackType
3+
from typing import final
4+
5+
from haiway import State
6+
7+
from draive.aws.api import AWSAPI
8+
from draive.aws.s3 import AWSS3Mixin
9+
from draive.resources import ResourcesRepository
10+
11+
__all__ = ("AWS",)
12+
13+
14+
@final
15+
class AWS(
16+
AWSS3Mixin,
17+
AWSAPI,
18+
):
19+
"""AWS service facade bundling S3 and repository integrations.
20+
21+
Parameters
22+
----------
23+
region_name
24+
Preferred AWS region. Falls back to profile or environment
25+
configuration when omitted.
26+
access_key_id
27+
Optional access key identifier used to override ambient
28+
credentials.
29+
secret_access_key
30+
Secret key paired with ``access_key_id`` when overriding
31+
credentials.
32+
features
33+
Collection of repository feature classes to activate while the
34+
client is bound in a context manager.
35+
"""
36+
__slots__ = ("_features",)
37+
38+
def __init__(
39+
self,
40+
region_name: str | None = None,
41+
access_key_id: str | None = None,
42+
secret_access_key: str | None = None,
43+
features: Collection[type[ResourcesRepository]] | None = None,
44+
) -> None:
45+
super().__init__(
46+
region_name=region_name,
47+
access_key_id=access_key_id,
48+
secret_access_key=secret_access_key,
49+
)
50+
51+
self._features: Collection[type[ResourcesRepository]]
52+
if features is not None:
53+
self._features = features
54+
55+
else:
56+
self._features = (ResourcesRepository,)
57+
58+
async def __aenter__(self) -> Iterable[State]:
59+
"""Prepare the AWS client and bind selected features to context."""
60+
await self._prepare_client()
61+
62+
if ResourcesRepository in self._features:
63+
return (
64+
ResourcesRepository(
65+
fetching=self.fetch,
66+
uploading=self.upload,
67+
),
68+
)
69+
70+
return ()
71+
72+
async def __aexit__(
73+
self,
74+
exc_type: type[BaseException] | None,
75+
exc_val: BaseException | None,
76+
exc_tb: TracebackType | None,
77+
) -> None:
78+
"""No-op cleanup to satisfy the async context manager protocol."""

0 commit comments

Comments
 (0)