Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/Api/Providers/AbstractClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,14 @@
use Hyperf\Odin\Api\Transport\SSEClient;
use Hyperf\Odin\Contract\Api\ClientInterface;
use Hyperf\Odin\Contract\Api\ConfigInterface;
use Hyperf\Odin\Event\AfterChatCompletionsEvent;
use Hyperf\Odin\Event\AfterChatCompletionsStreamEvent;
use Hyperf\Odin\Event\AfterEmbeddingsEvent;
use Hyperf\Odin\Exception\LLMException;
use Hyperf\Odin\Exception\LLMException\ErrorHandlerInterface;
use Hyperf\Odin\Exception\LLMException\ErrorMappingManager;
use Hyperf\Odin\Exception\LLMException\LLMErrorHandler;
use Hyperf\Odin\Utils\EventUtil;
use Psr\Log\LoggerInterface;
use Throwable;

Expand Down Expand Up @@ -86,6 +90,8 @@ public function chatCompletions(ChatCompletionRequest $chatRequest): ChatComplet
'content' => $chatCompletionResponse->getContent(),
]);

EventUtil::dispatch(new AfterChatCompletionsEvent($chatRequest, $chatCompletionResponse, $duration));

return $chatCompletionResponse;
} catch (Throwable $e) {
throw $this->convertException($e, [
Expand Down Expand Up @@ -125,6 +131,7 @@ public function chatCompletionsStream(ChatCompletionRequest $chatRequest): ChatC
);

$chatCompletionStreamResponse = new ChatCompletionStreamResponse($response, $this->logger, $sseClient);
$chatCompletionStreamResponse->setAfterChatCompletionsStreamEvent(new AfterChatCompletionsStreamEvent($chatRequest, $firstResponseDuration));

$this->logger?->debug('ChatCompletionsStreamResponse', [
'first_response_ms' => $firstResponseDuration,
Expand Down Expand Up @@ -164,6 +171,8 @@ public function embeddings(EmbeddingRequest $embeddingRequest): EmbeddingRespons
'data' => $embeddingResponse->toArray(),
]);

EventUtil::dispatch(new AfterEmbeddingsEvent($embeddingRequest, $embeddingResponse, $duration));

return $embeddingResponse;
} catch (Throwable $e) {
throw $this->convertException($e, [
Expand Down
29 changes: 29 additions & 0 deletions src/Api/Providers/AwsBedrock/AwsBedrockConverseFormatConverter.php
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ public function getIterator(): Generator
break;
case 'contentBlockStop':
case 'metadata':
if (isset($event['usage'])) {
yield $this->formatUsageEvent($created, $event['usage']);
}
break;
case 'messageStop':
yield $this->formatMessageStopEvent($created, $event['stopReason'] ?? 'stop');
Expand Down Expand Up @@ -152,6 +155,32 @@ public function getModel(): string
return $this->model;
}

/**
 * Build an OpenAI-style `chat.completion.chunk` event that carries only token usage.
 *
 * The Bedrock usage keys (`inputTokens`, `outputTokens`, `totalTokens`,
 * `cacheWriteInputTokens`, `cacheReadInputTokens`) are mapped onto the OpenAI
 * `usage` layout. Any key missing from $usage is reported as 0.
 *
 * @param int $created value placed in the event's `created` field
 * @param array $usage the `usage` section of the Bedrock stream event
 * @return string the event as rendered by formatOpenAiEvent()
 */
private function formatUsageEvent(int $created, array $usage): string
{
    $cacheWriteTokens = $usage['cacheWriteInputTokens'] ?? 0;
    $cacheReadTokens = $usage['cacheReadInputTokens'] ?? 0;

    return $this->formatOpenAiEvent([
        'id' => $this->messageId ?? ('bedrock-' . uniqid()),
        'object' => 'chat.completion.chunk',
        'created' => $created,
        'model' => $this->model ?: 'aws.bedrock',
        // A usage-only chunk carries no choices.
        'choices' => null,
        'usage' => [
            'prompt_tokens' => $usage['inputTokens'] ?? 0,
            'completion_tokens' => $usage['outputTokens'] ?? 0,
            'total_tokens' => $usage['totalTokens'] ?? 0,
            'prompt_tokens_details' => [
                'cache_write_input_tokens' => $cacheWriteTokens,
                'cache_read_input_tokens' => $cacheReadTokens,
                // Kept for compatibility with the legacy parameter names.
                'audio_tokens' => 0,
                // OpenAI's `cached_tokens` counts prompt tokens served from the cache,
                // so it must mirror the cache *read* counter, not the write counter.
                'cached_tokens' => $cacheReadTokens,
            ],
            'completion_tokens_details' => [
                'reasoning_tokens' => 0,
            ],
        ],
    ]);
}

/**
* 格式化消息开始事件.
*
Expand Down
31 changes: 29 additions & 2 deletions src/Api/Request/ChatCompletionRequest.php
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class ChatCompletionRequest implements RequestInterface

private float $presencePenalty = 0.0;

private bool $includeBusinessParams = false;

private array $businessParams = [];

private bool $toolsCache = false;
Expand All @@ -47,6 +49,8 @@ class ChatCompletionRequest implements RequestInterface
*/
private ?int $totalTokenEstimate = null;

private bool $streamIncludeUsage = false;

public function __construct(
/** @var MessageInterface[] $messages */
protected array $messages,
Expand Down Expand Up @@ -98,9 +102,14 @@ public function createOptions(): array
if ($this->presencePenalty > 0) {
$json['presence_penalty'] = $this->presencePenalty;
}
if (! empty($this->businessParams)) {
if ($this->includeBusinessParams && ! empty($this->businessParams)) {
$json['business_params'] = $this->businessParams;
}
if ($this->stream && $this->streamIncludeUsage) {
$json['stream_options'] = [
'include_usage' => true,
];
}

return [
RequestOptions::JSON => $json,
Expand All @@ -116,7 +125,10 @@ public function createOptions(): array
*/
public function calculateTokenEstimates(): int
{
$estimator = new TokenEstimator($model ?? $this->model);
if ($this->totalTokenEstimate) {
return $this->totalTokenEstimate;
}
$estimator = new TokenEstimator($this->model);
$totalTokens = 0;

// 为每个消息计算token
Expand Down Expand Up @@ -161,6 +173,16 @@ public function setBusinessParams(array $businessParams): void
$this->businessParams = $businessParams;
}

/**
 * Return the business parameters stored on this request.
 *
 * @return array the value last passed to setBusinessParams(), or an empty array if none was set
 */
public function getBusinessParams(): array
{
return $this->businessParams;
}

/**
 * Toggle whether the business parameters are sent with the request.
 *
 * createOptions() adds the `business_params` key to the JSON body only when
 * this flag is true and the business parameters are not empty.
 *
 * @param bool $includeBusinessParams true to include `business_params` in the request body
 */
public function setIncludeBusinessParams(bool $includeBusinessParams): void
{
$this->includeBusinessParams = $includeBusinessParams;
}

public function setStream(bool $stream): void
{
$this->stream = $stream;
Expand All @@ -181,6 +203,11 @@ public function setStreamContentEnabled(bool $streamContentEnabled): void
$this->streamContentEnabled = $streamContentEnabled;
}

/**
 * Toggle whether usage data is requested for streaming calls.
 *
 * When both this flag and the stream flag are true, createOptions() adds
 * `stream_options` with `include_usage` set to true to the JSON body.
 *
 * @param bool $streamIncludeUsage true to request usage data on streamed responses
 */
public function setStreamIncludeUsage(bool $streamIncludeUsage): void
{
$this->streamIncludeUsage = $streamIncludeUsage;
}

/**
* 获取消息列表.
*
Expand Down
9 changes: 8 additions & 1 deletion src/Api/Request/CompletionRequest.php
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class CompletionRequest implements RequestInterface

private array $businessParams = [];

private bool $includeBusinessParams = false;

public function __construct(
protected string $model,
protected string $prompt,
Expand Down Expand Up @@ -65,7 +67,7 @@ public function createOptions(): array
if ($this->presencePenalty > 0) {
$json['presence_penalty'] = $this->presencePenalty;
}
if (! empty($this->businessParams)) {
if ($this->includeBusinessParams && ! empty($this->businessParams)) {
$json['business_params'] = $this->businessParams;
}

Expand All @@ -89,6 +91,11 @@ public function setBusinessParams(array $businessParams): void
$this->businessParams = $businessParams;
}

/**
 * Toggle whether the business parameters are sent with the request.
 *
 * createOptions() adds the `business_params` key to the JSON body only when
 * this flag is true and the business parameters are not empty.
 *
 * @param bool $includeBusinessParams true to include `business_params` in the request body
 */
public function setIncludeBusinessParams(bool $includeBusinessParams): void
{
$this->includeBusinessParams = $includeBusinessParams;
}

/**
* 获取模型名称.
*
Expand Down
68 changes: 63 additions & 5 deletions src/Api/Request/EmbeddingRequest.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@
use GuzzleHttp\RequestOptions;
use Hyperf\Odin\Contract\Api\Request\RequestInterface;
use Hyperf\Odin\Exception\InvalidArgumentException;
use Hyperf\Odin\Utils\TokenEstimator;

class EmbeddingRequest implements RequestInterface
{
private array $businessParams = [];

private bool $includeBusinessParams = false;

private ?int $totalTokenEstimate = null;

/**
* @param string|string[] $input 需要嵌入的文本,可以是字符串或字符串数组
* @param string $model 使用的嵌入模型ID
Expand Down Expand Up @@ -53,12 +60,17 @@ public function createOptions(): array
{
$this->validate();

$body = [
'model' => $this->model,
'input' => $this->input,
'encoding_format' => $this->encoding_format,
];
if ($this->includeBusinessParams && ! empty($this->businessParams)) {
$body['business_params'] = $this->businessParams;
}

$options = [
RequestOptions::JSON => [
'input' => $this->input,
'model' => $this->model,
'encoding_format' => $this->encoding_format,
],
RequestOptions::JSON => $body,
];

if ($this->user !== null) {
Expand Down Expand Up @@ -111,4 +123,50 @@ public function getDimensions(): ?array
{
return $this->dimensions;
}

/**
 * Return the business parameters stored on this request.
 *
 * @return array the value last passed to setBusinessParams(), or an empty array if none was set
 */
public function getBusinessParams(): array
{
return $this->businessParams;
}

/**
 * Replace the business parameters stored on this request.
 *
 * They are only sent in the request body when setIncludeBusinessParams(true)
 * has also been called and the array is not empty.
 *
 * @param array $businessParams parameters to store
 */
public function setBusinessParams(array $businessParams): void
{
$this->businessParams = $businessParams;
}

/**
 * Report whether the business parameters are flagged for inclusion in the request body.
 *
 * @return bool the current flag value; false by default
 */
public function isIncludeBusinessParams(): bool
{
return $this->includeBusinessParams;
}

/**
 * Toggle whether the business parameters are sent with the request.
 *
 * createOptions() adds the `business_params` key to the JSON body only when
 * this flag is true and the business parameters are not empty.
 *
 * @param bool $includeBusinessParams true to include `business_params` in the request body
 */
public function setIncludeBusinessParams(bool $includeBusinessParams): void
{
$this->includeBusinessParams = $includeBusinessParams;
}

/**
 * Return the cached token estimate without computing it.
 *
 * @return null|int the value stored by calculateTokenEstimates(), or null if it has not been called yet
 */
public function getTotalTokenEstimate(): ?int
{
return $this->totalTokenEstimate;
}

/**
 * Estimate the total number of tokens in the embedding input.
 *
 * A string input is treated as a single item; an array input is estimated
 * item by item and summed. The result is computed once and cached, so later
 * calls return the stored value.
 *
 * @return int sum of the per-item token estimates
 */
public function calculateTokenEstimates(): int
{
    // Compare against null rather than relying on truthiness, so that a
    // cached estimate of 0 is reused instead of being recomputed every call.
    if ($this->totalTokenEstimate !== null) {
        return $this->totalTokenEstimate;
    }

    $estimator = new TokenEstimator($this->model);
    $inputs = is_array($this->input) ? $this->input : [$this->input];

    $totalTokens = 0;
    foreach ($inputs as $item) {
        // Estimate the token count of each input item.
        $totalTokens += $estimator->estimateTokens($item);
    }
    $this->totalTokenEstimate = $totalTokens;

    return $this->totalTokenEstimate;
}
}
5 changes: 5 additions & 0 deletions src/Api/Response/AbstractResponse.php
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,10 @@ public function setUsage(?Usage $usage): self
return $this;
}

/**
 * Drop the references to the origin response and the logger held by this object.
 *
 * NOTE(review): unset() removes the properties rather than nulling them. If they
 * are declared as typed properties, reading them afterwards throws an Error until
 * they are assigned again — confirm nothing touches them after this call.
 */
public function removeBigObject(): void
{
unset($this->originResponse, $this->logger);
}

abstract protected function parseContent(): self;
}
15 changes: 15 additions & 0 deletions src/Api/Response/ChatCompletionResponse.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
namespace Hyperf\Odin\Api\Response;

use Hyperf\Odin\Exception\LLMException\LLMApiException;
use Hyperf\Odin\Utils\TokenEstimator;
use Stringable;

class ChatCompletionResponse extends AbstractResponse implements Stringable
Expand All @@ -30,6 +31,8 @@ class ChatCompletionResponse extends AbstractResponse implements Stringable
*/
protected ?array $choices = [];

private ?int $totalTokenEstimate = null;

public function __toString(): string
{
return trim($this->getChoices()[0]?->getMessage()?->getContent() ?: '');
Expand Down Expand Up @@ -98,6 +101,18 @@ public function setChoices(?array $choices): self
return $this;
}

/**
 * Estimate the number of tokens in the content of the first choice's message.
 *
 * A missing choice, message or content is estimated as an empty string. The
 * result is computed once and cached, so later calls return the stored value
 * even if the choices change afterwards.
 *
 * @return int estimated token count of the first choice's content
 */
public function calculateTokenEstimates(): int
{
    // Compare against null rather than relying on truthiness, so that a
    // cached estimate of 0 is reused instead of being recomputed every call.
    if ($this->totalTokenEstimate !== null) {
        return $this->totalTokenEstimate;
    }

    $estimator = new TokenEstimator($this->model);
    $content = $this->getFirstChoice()?->getMessage()?->getContent() ?? '';

    $this->totalTokenEstimate = $estimator->estimateTokens($content);

    return $this->totalTokenEstimate;
}

protected function parseContent(): self
{
$this->content = $this->originResponse->getBody()->getContents();
Expand Down
Loading
Loading