|
5 | 5 | import base64 |
6 | 6 | import json |
7 | 7 | import logging |
8 | | -from typing import TYPE_CHECKING, MutableSequence, Optional |
| 8 | +from typing import MutableSequence, Optional |
9 | 9 |
|
10 | 10 | from pyrit.chat_message_normalizer import ChatMessageNop, ChatMessageNormalizer |
11 | 11 | from pyrit.models import ( |
|
17 | 17 |
|
18 | 18 | logger = logging.getLogger(__name__) |
19 | 19 |
|
20 | | -if TYPE_CHECKING: |
21 | | - import boto3 |
22 | | - from botocore.exceptions import ClientError |
23 | | - |
24 | 20 |
|
25 | 21 | class AWSBedrockClaudeChatTarget(PromptChatTarget): |
26 | 22 | """ |
@@ -64,13 +60,6 @@ def __init__( |
64 | 60 |
|
65 | 61 | self._valid_image_types = ["jpeg", "png", "webp", "gif"] |
66 | 62 |
|
67 | | - try: |
68 | | - import boto3 # noqa: F401 |
69 | | - from botocore.exceptions import ClientError # noqa: F401 |
70 | | - except ModuleNotFoundError as e: |
71 | | - logger.error("Could not import boto. You may need to install it via 'pip install pyrit[all] or pyrit[aws]'") |
72 | | - raise e |
73 | | - |
74 | 63 | @limit_requests_per_minute |
75 | 64 | async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse: |
76 | 65 |
|
@@ -101,11 +90,16 @@ def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None: |
101 | 90 |
|
102 | 91 | async def _complete_chat_async(self, messages: list[ChatMessageListDictContent]) -> str: |
103 | 92 | try: |
104 | | - brt = boto3.client( |
105 | | - service_name="bedrock-runtime", region_name="us-east-1", verify=self._enable_ssl_verification |
106 | | - ) |
107 | | - except Exception as e: |
108 | | - raise RuntimeError(f"An error occurred when initializing boto3 client: {str(e)}") from e |
| 93 | + import boto3 # noqa: F401 |
| 94 | + from botocore.exceptions import ClientError # noqa: F401 |
| 95 | + except ModuleNotFoundError as e: |
| 96 | + logger.error("Could not import boto. You may need to install it via 'pip install pyrit[all] or pyrit[aws]'") |
| 97 | + raise e |
| 98 | + |
| 99 | + brt = boto3.client( |
| 100 | + service_name="bedrock-runtime", region_name="us-east-1", verify=self._enable_ssl_verification |
| 101 | + ) |
| 102 | + |
109 | 103 | native_request = self._construct_request_body(messages) |
110 | 104 |
|
111 | 105 | request = json.dumps(native_request) |
|
0 commit comments