Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright (c) Microsoft. All rights reserved.

from abc import ABC
from typing import ClassVar
from typing import Any, ClassVar

from google.genai import Client

from semantic_kernel.connectors.ai.google.google_ai.google_ai_settings import GoogleAISettings
from semantic_kernel.const import USER_AGENT
from semantic_kernel.kernel_pydantic import KernelBaseModel
from semantic_kernel.utils.telemetry.user_agent import IS_TELEMETRY_ENABLED, SEMANTIC_KERNEL_USER_AGENT


class GoogleAIBase(KernelBaseModel, ABC):
Expand All @@ -17,3 +19,14 @@ class GoogleAIBase(KernelBaseModel, ABC):
service_settings: GoogleAISettings

client: Client | None = None

def _get_http_options(self) -> dict[str, Any] | None:
    """Build the HTTP options passed to the Google AI client.

    Returns:
        A dict carrying the Semantic Kernel user-agent header when telemetry
        is enabled, otherwise None (no extra HTTP options are applied).
    """
    if IS_TELEMETRY_ENABLED:
        return {"headers": {USER_AGENT: SEMANTIC_KERNEL_USER_AGENT}}
    return None
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,14 @@ async def _generate_content(client: Client) -> GenerateContentResponse:
vertexai=True,
project=self.service_settings.cloud_project_id,
location=self.service_settings.cloud_region,
http_options=self._get_http_options(),
) as client:
response: GenerateContentResponse = await _generate_content(client) # type: ignore[no-redef]
else:
with Client(api_key=self.service_settings.api_key.get_secret_value()) as client: # type: ignore[union-attr]
with Client(
api_key=self.service_settings.api_key.get_secret_value(),
http_options=self._get_http_options(),
) as client: # type: ignore[union-attr]
response: GenerateContentResponse = await _generate_content(client) # type: ignore[no-redef]

return [self._create_chat_message_content(response, candidate) for candidate in response.candidates] # type: ignore
Expand Down Expand Up @@ -218,14 +222,18 @@ async def _generate_content_stream(client: Client) -> AsyncGenerator[GenerateCon
vertexai=True,
project=self.service_settings.cloud_project_id,
location=self.service_settings.cloud_region,
http_options=self._get_http_options(),
) as client:
async for chunk in _generate_content_stream(client):
yield [
self._create_streaming_chat_message_content(chunk, candidate, function_invoke_attempt)
for candidate in chunk.candidates # type: ignore
]
else:
with Client(api_key=self.service_settings.api_key.get_secret_value()) as client: # type: ignore[union-attr]
with Client(
api_key=self.service_settings.api_key.get_secret_value(),
http_options=self._get_http_options(),
) as client: # type: ignore[union-attr]
async for chunk in _generate_content_stream(client):
yield [
self._create_streaming_chat_message_content(chunk, candidate, function_invoke_attempt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,14 @@ async def _generate_content(client: Client) -> GenerateContentResponse:
vertexai=True,
project=self.service_settings.cloud_project_id,
location=self.service_settings.cloud_region,
http_options=self._get_http_options(),
) as client:
response: GenerateContentResponse = await _generate_content(client) # type: ignore[no-redef]
else:
with Client(api_key=self.service_settings.api_key.get_secret_value()) as client: # type: ignore[union-attr]
with Client(
api_key=self.service_settings.api_key.get_secret_value(),
http_options=self._get_http_options(),
) as client: # type: ignore[union-attr]
response: GenerateContentResponse = await _generate_content(client) # type: ignore[no-redef]

return [self._create_text_content(response, candidate) for candidate in response.candidates] # type: ignore
Expand Down Expand Up @@ -175,11 +179,15 @@ async def _generate_content_stream(client: Client) -> AsyncGenerator[GenerateCon
vertexai=True,
project=self.service_settings.cloud_project_id,
location=self.service_settings.cloud_region,
http_options=self._get_http_options(),
) as client:
async for chunk in _generate_content_stream(client):
yield [self._create_streaming_text_content(chunk, candidate) for candidate in chunk.candidates] # type: ignore
else:
with Client(api_key=self.service_settings.api_key.get_secret_value()) as client: # type: ignore[union-attr]
with Client(
api_key=self.service_settings.api_key.get_secret_value(),
http_options=self._get_http_options(),
) as client: # type: ignore[union-attr]
async for chunk in _generate_content_stream(client):
yield [self._create_streaming_text_content(chunk, candidate) for candidate in chunk.candidates] # type: ignore

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,14 @@ async def _embed_content(client: Client) -> EmbedContentResponse:
vertexai=True,
project=self.service_settings.cloud_project_id,
location=self.service_settings.cloud_region,
http_options=self._get_http_options(),
) as client:
response: EmbedContentResponse = await _embed_content(client) # type: ignore[no-redef]
else:
with Client(api_key=self.service_settings.api_key.get_secret_value()) as client: # type: ignore[union-attr]
with Client(
api_key=self.service_settings.api_key.get_secret_value(),
http_options=self._get_http_options(),
) as client: # type: ignore[union-attr]
response: EmbedContentResponse = await _embed_content(client) # type: ignore[no-redef]

return [embedding.values for embedding in response.embeddings] # type: ignore
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright (c) Microsoft. All rights reserved.

from unittest.mock import AsyncMock, patch

from semantic_kernel.connectors.ai.google.google_ai.google_ai_prompt_execution_settings import (
GoogleAIChatPromptExecutionSettings,
GoogleAIEmbeddingPromptExecutionSettings,
GoogleAITextPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.google.google_ai.services.google_ai_chat_completion import GoogleAIChatCompletion
from semantic_kernel.connectors.ai.google.google_ai.services.google_ai_text_completion import GoogleAITextCompletion
from semantic_kernel.connectors.ai.google.google_ai.services.google_ai_text_embedding import GoogleAITextEmbedding
from semantic_kernel.const import USER_AGENT
from semantic_kernel.contents.chat_history import ChatHistory


async def test_google_ai_chat_completion_user_agent(google_ai_unit_test_env):
    """Test that GoogleAIChatCompletion sends the User-Agent header."""
    history = ChatHistory()
    history.add_user_message("hi")

    target = "semantic_kernel.connectors.ai.google.google_ai.services.google_ai_chat_completion.Client"
    with patch(target) as client_cls:
        entered = client_cls.return_value.__enter__.return_value
        entered.aio.models.generate_content = AsyncMock()

        service = GoogleAIChatCompletion(gemini_model_id="gemini-3-flash-preview", api_key="AIza-test-key")
        await service.get_chat_message_contents(
            chat_history=history, settings=GoogleAIChatPromptExecutionSettings()
        )

        _, kwargs = client_cls.call_args
        assert "http_options" in kwargs
        http_options = kwargs["http_options"]
        assert http_options is not None
        assert "headers" in http_options
        headers = http_options["headers"]
        assert USER_AGENT in headers
        assert "semantic-kernel-python" in headers[USER_AGENT]


async def test_google_ai_chat_completion_no_telemetry(google_ai_unit_test_env):
    """Test that GoogleAIChatCompletion does not send the User-Agent header when telemetry is disabled."""
    history = ChatHistory()
    history.add_user_message("hi")

    telemetry_target = "semantic_kernel.connectors.ai.google.google_ai.services.google_ai_base.IS_TELEMETRY_ENABLED"
    client_target = "semantic_kernel.connectors.ai.google.google_ai.services.google_ai_chat_completion.Client"
    with patch(telemetry_target, False), patch(client_target) as client_cls:
        entered = client_cls.return_value.__enter__.return_value
        entered.aio.models.generate_content = AsyncMock()

        service = GoogleAIChatCompletion(gemini_model_id="gemini-3-flash-preview", api_key="AIza-test-key")
        await service.get_chat_message_contents(
            chat_history=history, settings=GoogleAIChatPromptExecutionSettings()
        )

        _, kwargs = client_cls.call_args
        assert "http_options" in kwargs
        assert kwargs["http_options"] is None


async def test_google_ai_text_completion_user_agent(google_ai_unit_test_env):
    """Test that GoogleAITextCompletion sends the User-Agent header."""
    target = "semantic_kernel.connectors.ai.google.google_ai.services.google_ai_text_completion.Client"
    with patch(target) as client_cls:
        entered = client_cls.return_value.__enter__.return_value
        entered.aio.models.generate_content = AsyncMock()

        service = GoogleAITextCompletion(gemini_model_id="gemini-3-flash-preview", api_key="AIza-test-key")
        await service.get_text_contents(prompt="hi", settings=GoogleAITextPromptExecutionSettings())

        _, kwargs = client_cls.call_args
        assert "http_options" in kwargs
        http_options = kwargs["http_options"]
        assert http_options is not None
        assert "headers" in http_options
        headers = http_options["headers"]
        assert USER_AGENT in headers
        assert "semantic-kernel-python" in headers[USER_AGENT]


async def test_google_ai_text_completion_no_telemetry(google_ai_unit_test_env):
    """Test that GoogleAITextCompletion does not send the User-Agent header when telemetry is disabled."""
    telemetry_target = "semantic_kernel.connectors.ai.google.google_ai.services.google_ai_base.IS_TELEMETRY_ENABLED"
    client_target = "semantic_kernel.connectors.ai.google.google_ai.services.google_ai_text_completion.Client"
    with patch(telemetry_target, False), patch(client_target) as client_cls:
        entered = client_cls.return_value.__enter__.return_value
        entered.aio.models.generate_content = AsyncMock()

        service = GoogleAITextCompletion(gemini_model_id="gemini-3-flash-preview", api_key="AIza-test-key")
        await service.get_text_contents(prompt="hi", settings=GoogleAITextPromptExecutionSettings())

        _, kwargs = client_cls.call_args
        assert "http_options" in kwargs
        assert kwargs["http_options"] is None


async def test_google_ai_text_embedding_user_agent(google_ai_unit_test_env):
    """Test that GoogleAITextEmbedding sends the User-Agent header."""
    target = "semantic_kernel.connectors.ai.google.google_ai.services.google_ai_text_embedding.Client"
    with patch(target) as client_cls:
        entered = client_cls.return_value.__enter__.return_value
        entered.aio.models.embed_content = AsyncMock()

        service = GoogleAITextEmbedding(embedding_model_id="gemini-embedding-2-preview", api_key="AIza-test-key")
        await service.generate_embeddings(texts=["hi"], settings=GoogleAIEmbeddingPromptExecutionSettings())

        _, kwargs = client_cls.call_args
        assert "http_options" in kwargs
        http_options = kwargs["http_options"]
        assert http_options is not None
        assert "headers" in http_options
        headers = http_options["headers"]
        assert USER_AGENT in headers
        assert "semantic-kernel-python" in headers[USER_AGENT]


async def test_google_ai_text_embedding_no_telemetry(google_ai_unit_test_env):
    """Test that GoogleAITextEmbedding does not send the User-Agent header when telemetry is disabled."""
    telemetry_target = "semantic_kernel.connectors.ai.google.google_ai.services.google_ai_base.IS_TELEMETRY_ENABLED"
    client_target = "semantic_kernel.connectors.ai.google.google_ai.services.google_ai_text_embedding.Client"
    with patch(telemetry_target, False), patch(client_target) as client_cls:
        entered = client_cls.return_value.__enter__.return_value
        entered.aio.models.embed_content = AsyncMock()

        service = GoogleAITextEmbedding(embedding_model_id="gemini-embedding-2-preview", api_key="AIza-test-key")
        await service.generate_embeddings(texts=["hi"], settings=GoogleAIEmbeddingPromptExecutionSettings())

        _, kwargs = client_cls.call_args
        assert "http_options" in kwargs
        assert kwargs["http_options"] is None
Loading