diff --git a/config/prism.php b/config/prism.php index 99f3ea584..cf4b6a0a4 100644 --- a/config/prism.php +++ b/config/prism.php @@ -65,6 +65,12 @@ 'api_key' => env('PERPLEXITY_API_KEY', ''), 'url' => env('PERPLEXITY_URL', 'https://api.perplexity.ai'), ], + 'vertex' => [ + 'project_id' => env('VERTEX_PROJECT_ID', ''), + 'region' => env('VERTEX_REGION', 'us-central1'), + 'access_token' => env('VERTEX_ACCESS_TOKEN', null), + 'credentials_path' => env('VERTEX_CREDENTIALS_PATH', null), + ], 'z' => [ 'url' => env('Z_URL', 'https://api.z.ai/api/paas/v4'), 'api_key' => env('Z_API_KEY', ''), diff --git a/docs/providers/vertex.md b/docs/providers/vertex.md new file mode 100644 index 000000000..2706af6f2 --- /dev/null +++ b/docs/providers/vertex.md @@ -0,0 +1,266 @@ +# Vertex AI + +Google Vertex AI provides enterprise-grade access to Google's Gemini models with enhanced security, compliance, and integration with Google Cloud services. + +## Configuration + +```php +'vertex' => [ + 'project_id' => env('VERTEX_PROJECT_ID', ''), + 'region' => env('VERTEX_REGION', 'us-central1'), + 'access_token' => env('VERTEX_ACCESS_TOKEN', null), + 'credentials_path' => env('VERTEX_CREDENTIALS_PATH', null), +], +``` + +### Authentication + +Vertex AI supports multiple authentication methods: + +#### 1. Access Token (Recommended for development) + +Provide an access token directly: + +```env +VERTEX_ACCESS_TOKEN=your-access-token +``` + +You can obtain an access token using the Google Cloud CLI: + +```bash +gcloud auth print-access-token +``` + +#### 2. Service Account Credentials (Recommended for production) + +Provide the path to your service account JSON key file: + +```env +VERTEX_CREDENTIALS_PATH=/path/to/service-account.json +``` + +#### 3. Application Default Credentials + +If no credentials are provided, Prism will attempt to use Application Default Credentials (ADC). Set up ADC by running: + +```bash +gcloud auth application-default login +``` + +Or by setting the `GOOGLE_APPLICATION_CREDENTIALS` environment variable: + +```bash +export GOOGLE_APPLICATION_CREDENTIALS=/path/to/service-account.json +``` + +## Basic Usage + +### Text Generation + +```php +use Prism\Prism\Facades\Prism; +use Prism\Prism\Enums\Provider; + +$response = Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withPrompt('Explain quantum computing in simple terms.') + ->asText(); + +echo $response->text; +``` + +### With System Prompt + +```php +$response = Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withSystemPrompt('You are a helpful coding assistant.') + ->withPrompt('Write a Python function to calculate fibonacci numbers.') + ->asText(); +``` + +## Structured Output + +Vertex AI supports structured output, allowing you to define schemas that constrain the model's responses to match your exact data structure requirements. + +```php +use Prism\Prism\Facades\Prism; +use Prism\Prism\Enums\Provider; +use Prism\Prism\Schema\ObjectSchema; +use Prism\Prism\Schema\StringSchema; +use Prism\Prism\Schema\NumberSchema; + +$schema = new ObjectSchema( + name: 'user_profile', + description: 'A user profile object', + properties: [ + new StringSchema('name', 'The user\'s full name'), + new NumberSchema('age', 'The user\'s age'), + new StringSchema('email', 'The user\'s email address'), + ], + requiredFields: ['name', 'age', 'email'] +); + +$response = Prism::structured() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withSchema($schema) + ->withPrompt('Generate a profile for a fictional user named John Doe.') + ->generate(); + +// Access structured data +$profile = $response->structured; +echo $profile['name']; // "John Doe" +echo $profile['age']; // 30 +echo $profile['email']; // "john.doe@example.com" +``` + +## Tool Usage + +Vertex AI supports function calling (tools) to extend the model's capabilities: + +```php +use Prism\Prism\Facades\Prism; +use Prism\Prism\Enums\Provider; +use Prism\Prism\Tool; + +$weatherTool = (new Tool) + ->as('get_weather') + ->for('Get the current weather for a location') + ->withStringParameter('location', 'The city and state, e.g. San Francisco, CA') + ->using(function (string $location): string { + // Your weather API call here + return "The weather in {$location} is 72°F and sunny."; + }); + +$response = Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withTools([$weatherTool]) + ->withMaxSteps(3) + ->withPrompt('What is the weather like in San Francisco?') + ->asText(); + +echo $response->text; +``` + +## Embeddings + +Vertex AI supports text embeddings for semantic search, clustering, and other ML tasks: + +```php +use Prism\Prism\Facades\Prism; +use Prism\Prism\Enums\Provider; + +$response = Prism::embeddings() + ->using(Provider::Vertex, 'text-embedding-004') + ->fromInput('The quick brown fox jumps over the lazy dog.') + ->generate(); + +// Access the embedding vector +$embedding = $response->embeddings[0]->embedding; +``` + +## Streaming + +Vertex AI supports streaming responses for real-time output: + +```php +use Prism\Prism\Facades\Prism; +use Prism\Prism\Enums\Provider; + +$stream = Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withPrompt('Write a short story about a robot.') + ->asStream(); + +foreach ($stream as $event) { + if ($event instanceof \Prism\Prism\Streaming\Events\TextDeltaEvent) { + echo $event->delta; + } +} +``` + +## Image Understanding + +Vertex AI supports multimodal inputs including images: + +```php +use Prism\Prism\Facades\Prism; +use Prism\Prism\Enums\Provider; +use Prism\Prism\ValueObjects\Messages\UserMessage; +use Prism\Prism\ValueObjects\Media\Image; + +$response = Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withMessages([ + new UserMessage( + 'What do you see in this image?', + additionalContent: [ + Image::fromLocalPath('/path/to/image.png'), + ], + ), + ]) + ->asText(); + +echo $response->text; +``` + +## Thinking Mode + +For models that support it (like Gemini 2.5), you can configure thinking mode: + +```php +$response = Prism::text() + ->using(Provider::Vertex, 'gemini-2.5-flash-preview') + ->withPrompt('Solve this complex math problem...') + ->withProviderOptions([ + 'thinkingBudget' => 2048, + ]) + ->asText(); + +// Access thinking token usage +echo $response->usage->thoughtTokens; +``` + +## Available Models + +Vertex AI provides access to Google's Gemini model family: + +| Model | Description | +|-------|-------------| +| `gemini-1.5-flash` | Fast and efficient for most tasks | +| `gemini-1.5-pro` | Most capable model for complex tasks | +| `gemini-2.0-flash` | Latest generation with improved capabilities | +| `gemini-2.5-flash-preview` | Preview of next generation with thinking support | +| `text-embedding-004` | Text embeddings model | + +## Regions + +Vertex AI is available in multiple regions. Common options include: + +- `us-central1` (default) +- `us-east1` +- `us-west1` +- `europe-west1` +- `europe-west4` +- `asia-northeast1` +- `asia-southeast1` + +Configure your region in the `.env` file: + +```env +VERTEX_REGION=us-central1 +``` + +## Differences from Gemini API + +While Vertex AI uses the same underlying Gemini models, there are key differences: + +| Feature | Gemini API | Vertex AI | +|---------|------------|-----------| +| Authentication | API Key | OAuth 2.0 / Service Account | +| Pricing | Pay-as-you-go | Google Cloud billing | +| Data residency | Global | Regional control | +| Enterprise features | Limited | Full (VPC, audit logs, etc.) | +| SLA | None | Enterprise SLA available | + +Choose Vertex AI when you need enterprise-grade security, compliance, or integration with other Google Cloud services. diff --git a/src/Enums/Provider.php b/src/Enums/Provider.php index 0b0fecb71..343a10928 100644 --- a/src/Enums/Provider.php +++ b/src/Enums/Provider.php @@ -18,5 +18,6 @@ enum Provider: string case VoyageAI = 'voyageai'; case ElevenLabs = 'elevenlabs'; case Perplexity = 'perplexity'; + case Vertex = 'vertex'; case Z = 'z'; } diff --git a/src/PrismManager.php b/src/PrismManager.php index 31c0f4f1d..3e885e9ae 100644 --- a/src/PrismManager.php +++ b/src/PrismManager.php @@ -19,6 +19,7 @@ use Prism\Prism\Providers\OpenRouter\OpenRouter; use Prism\Prism\Providers\Perplexity\Perplexity; use Prism\Prism\Providers\Provider; +use Prism\Prism\Providers\Vertex\Vertex; use Prism\Prism\Providers\VoyageAI\VoyageAI; use Prism\Prism\Providers\XAI\XAI; use Prism\Prism\Providers\Z\Z; @@ -226,6 +227,19 @@ protected function createElevenlabsProvider(array $config): ElevenLabs ); } + /** + * @param array $config + */ + protected function createVertexProvider(array $config): Vertex + { + return new Vertex( + projectId: $config['project_id'] ?? '', + region: $config['region'] ?? 'us-central1', + accessToken: $config['access_token'] ?? null, + credentialsPath: $config['credentials_path'] ?? null, + ); + } + /** * @param array $config */ diff --git a/src/Providers/Vertex/Concerns/ValidatesResponse.php b/src/Providers/Vertex/Concerns/ValidatesResponse.php new file mode 100644 index 000000000..ff015c4db --- /dev/null +++ b/src/Providers/Vertex/Concerns/ValidatesResponse.php @@ -0,0 +1,26 @@ +json(); + + if (! $data || data_get($data, 'error')) { + throw PrismException::providerResponseError(vsprintf( + 'Vertex Error: [%s] %s', + [ + data_get($data, 'error.code', 'unknown'), + data_get($data, 'error.message', 'unknown'), + ] + )); + } + } +} diff --git a/src/Providers/Vertex/Handlers/Embeddings.php b/src/Providers/Vertex/Handlers/Embeddings.php new file mode 100644 index 000000000..2912b72fa --- /dev/null +++ b/src/Providers/Vertex/Handlers/Embeddings.php @@ -0,0 +1,74 @@ +inputs()) > 1) { + throw new PrismException('Vertex Error: Prism currently only supports one input at a time with Vertex AI.'); + } + + $response = $this->sendRequest($request); + + $data = $response->json(); + + if (! isset($data['predictions'][0]['embeddings']['values'])) { + throw PrismException::providerResponseError( + 'Vertex Error: Invalid response format or missing embedding data' + ); + } + + return new EmbeddingsResponse( + embeddings: [Embedding::fromArray(data_get($data, 'predictions.0.embeddings.values', []))], + usage: new EmbeddingsUsage( + data_get($data, 'metadata.billableCharacterCount', 0) + ), + meta: new Meta( + id: '', + model: $this->model, + ), + raw: $data, + ); + } + + protected function sendRequest(Request $request): Response + { + $providerOptions = $request->providerOptions(); + + /** @var Response $response */ + $response = $this->client->post( + "{$this->model}:predict", + Arr::whereNotNull([ + 'instances' => [ + [ + 'content' => $request->inputs()[0], + ], + ], + 'parameters' => Arr::whereNotNull([ + 'outputDimensionality' => $providerOptions['outputDimensionality'] ?? null, + ]) ?: null, + ]) + ); + + return $response; + } +} diff --git a/src/Providers/Vertex/Handlers/Stream.php b/src/Providers/Vertex/Handlers/Stream.php new file mode 100644 index 000000000..6e39721e3 --- /dev/null +++ b/src/Providers/Vertex/Handlers/Stream.php @@ -0,0 +1,517 @@ +state = new StreamState; + } + + /** + * @return Generator + */ + public function handle(Request $request): Generator + { + $this->state->reset(); + $this->currentThoughtSignature = null; + $response = $this->sendRequest($request); + + yield from $this->processStream($response, $request); + } + + /** + * @return Generator + */ + protected function processStream(Response $response, Request $request, int $depth = 0): Generator + { + if ($depth >= $request->maxSteps()) { + throw new PrismException('Maximum tool call chain depth exceeded'); + } + + while (! $response->getBody()->eof()) { + $data = $this->parseNextDataLine($response->getBody()); + + if ($data === null) { + continue; + } + + if ($this->state->shouldEmitStreamStart()) { + $this->state->withMessageId(EventID::generate()); + + yield new StreamStartEvent( + id: EventID::generate(), + timestamp: time(), + model: data_get($data, 'modelVersion', 'unknown'), + provider: 'vertex' + ); + $this->state->markStreamStarted(); + } + + if ($this->state->shouldEmitStepStart()) { + $this->state->markStepStarted(); + + yield new StepStartEvent( + id: EventID::generate(), + timestamp: time() + ); + } + + $this->state->withUsage($this->extractUsage($data)); + + if ($this->hasToolCalls($data)) { + $existingIndices = array_keys($this->state->toolCalls()); + + $toolCalls = $this->extractToolCalls($data, $this->state->toolCalls()); + foreach ($toolCalls as $index => $toolCall) { + $this->state->addToolCall($index, $toolCall); + } + + foreach ($this->state->toolCalls() as $index => $toolCallData) { + if (! in_array($index, $existingIndices, true)) { + yield new ToolCallEvent( + id: EventID::generate(), + timestamp: time(), + toolCall: $this->mapToolCall($toolCallData), + messageId: $this->state->messageId() + ); + } + } + + if ($this->mapFinishReason($data) === FinishReason::ToolCalls) { + yield from $this->handleToolCalls($request, $depth, $data); + + return; + } + + continue; + } + + $parts = data_get($data, 'candidates.0.content.parts', []); + + foreach ($parts as $part) { + if (isset($part['thought']) && $part['thought'] === true) { + $thinkingContent = $part['text'] ?? ''; + + if ($thinkingContent !== '') { + if ($this->state->reasoningId() === '') { + $this->state->withReasoningId(EventID::generate()); + + yield new ThinkingStartEvent( + id: EventID::generate(), + timestamp: time(), + reasoningId: $this->state->reasoningId() + ); + } + + $this->state->appendThinking($thinkingContent); + + yield new ThinkingEvent( + id: EventID::generate(), + timestamp: time(), + delta: $thinkingContent, + reasoningId: $this->state->reasoningId() + ); + } + } elseif (isset($part['text']) && (! isset($part['thought']) || $part['thought'] === false)) { + $content = $part['text']; + + if ($content !== '') { + if ($this->state->shouldEmitTextStart()) { + yield new TextStartEvent( + id: EventID::generate(), + timestamp: time(), + messageId: $this->state->messageId() + ); + $this->state->markTextStarted(); + } + + $this->state->appendText($content); + + yield new TextDeltaEvent( + id: EventID::generate(), + timestamp: time(), + delta: $content, + messageId: $this->state->messageId() + ); + } + } + } + + $finishReason = $this->mapFinishReason($data); + + if ($finishReason !== FinishReason::Unknown) { + if ($this->state->reasoningId() !== '') { + yield new ThinkingCompleteEvent( + id: EventID::generate(), + timestamp: time(), + reasoningId: $this->state->reasoningId() + ); + } + + if ($this->state->hasTextStarted()) { + $this->state->markTextCompleted(); + + yield new TextCompleteEvent( + id: EventID::generate(), + timestamp: time(), + messageId: $this->state->messageId() + ); + } + + $this->state->withFinishReason($finishReason); + $this->state->withMetadata([ + 'grounding_metadata' => $this->extractGroundingMetadata($data), + ]); + } + } + + if ($this->state->hasToolCalls()) { + yield from $this->handleToolCalls($request, $depth); + + return; + } + + $this->state->markStepFinished(); + yield new StepFinishEvent( + id: EventID::generate(), + timestamp: time() + ); + + yield $this->emitStreamEndEvent(); + } + + protected function emitStreamEndEvent(): StreamEndEvent + { + return new StreamEndEvent( + id: EventID::generate(), + timestamp: time(), + finishReason: $this->state->finishReason() ?? FinishReason::Stop, + usage: $this->state->usage(), + additionalContent: Arr::whereNotNull([ + 'grounding_metadata' => $this->state->metadata()['grounding_metadata'] ?? null, + 'thoughtSummaries' => $this->state->thinkingSummaries() === [] ? null : $this->state->thinkingSummaries(), + ]) + ); + } + + /** + * @return array|null + */ + protected function parseNextDataLine(StreamInterface $stream): ?array + { + $line = $this->readLine($stream); + + if (! str_starts_with($line, 'data:')) { + return null; + } + + $line = trim(substr($line, strlen('data: '))); + + if ($line === '' || $line === '[DONE]') { + return null; + } + + try { + return json_decode($line, true, flags: JSON_THROW_ON_ERROR); + } catch (Throwable $e) { + throw new PrismStreamDecodeException('Vertex', $e); + } + } + + /** + * @param array $data + * @param array> $toolCalls + * @return array> + */ + protected function extractToolCalls(array $data, array $toolCalls): array + { + $parts = data_get($data, 'candidates.0.content.parts', []); + $nextIndex = $toolCalls === [] ? 0 : max(array_keys($toolCalls)) + 1; + + foreach ($parts as $part) { + if (isset($part['functionCall'])) { + if (isset($part['thoughtSignature'])) { + $this->currentThoughtSignature = $part['thoughtSignature']; + } + + $toolCalls[$nextIndex] = [ + 'id' => EventID::generate('vx'), + 'name' => data_get($part, 'functionCall.name'), + 'arguments' => data_get($part, 'functionCall.args', []), + 'reasoningId' => $part['thoughtSignature'] ?? $this->currentThoughtSignature, + ]; + $nextIndex++; + } + } + + return $toolCalls; + } + + /** + * @param array $data + * @return Generator + */ + protected function handleToolCalls( + Request $request, + int $depth, + array $data = [] + ): Generator { + $mappedToolCalls = []; + + foreach ($this->state->toolCalls() as $toolCallData) { + $mappedToolCalls[] = $this->mapToolCall($toolCallData); + } + + $toolResults = []; + yield from $this->callToolsAndYieldEvents($request->tools(), $mappedToolCalls, $this->state->messageId(), $toolResults); + + if ($toolResults !== []) { + $this->state->markStepFinished(); + yield new StepFinishEvent( + id: EventID::generate(), + timestamp: time() + ); + + $request->addMessage(new AssistantMessage($this->state->currentText(), $mappedToolCalls)); + $request->addMessage(new ToolResultMessage($toolResults)); + $request->resetToolChoice(); + + $depth++; + if ($depth < $request->maxSteps()) { + $previousUsage = $this->state->usage(); + $this->state->reset(); + $this->currentThoughtSignature = null; + $nextResponse = $this->sendRequest($request); + yield from $this->processStream($nextResponse, $request, $depth); + + if ($previousUsage instanceof Usage && $this->state->usage() instanceof Usage) { + $this->state->withUsage(new Usage( + promptTokens: $previousUsage->promptTokens + $this->state->usage()->promptTokens, + completionTokens: $previousUsage->completionTokens + $this->state->usage()->completionTokens, + cacheWriteInputTokens: ($previousUsage->cacheWriteInputTokens ?? 0) + ($this->state->usage()->cacheWriteInputTokens ?? 0), + cacheReadInputTokens: ($previousUsage->cacheReadInputTokens ?? 0) + ($this->state->usage()->cacheReadInputTokens ?? 0), + thoughtTokens: ($previousUsage->thoughtTokens ?? 0) + ($this->state->usage()->thoughtTokens ?? 0) + )); + } + } else { + yield $this->emitStreamEndEvent(); + } + } + } + + /** + * @param array $toolCallData + */ + protected function mapToolCall(array $toolCallData): ToolCall + { + $arguments = data_get($toolCallData, 'arguments', []); + + if (is_string($arguments) && $arguments !== '') { + $decoded = json_decode($arguments, true); + $arguments = json_last_error() === JSON_ERROR_NONE ? $decoded : ['input' => $arguments]; + } + + return new ToolCall( + id: empty($toolCallData['id']) ? EventID::generate('vx') : $toolCallData['id'], + name: data_get($toolCallData, 'name', 'unknown'), + arguments: $arguments, + reasoningId: data_get($toolCallData, 'reasoningId') + ); + } + + /** + * @param array $data + */ + protected function hasToolCalls(array $data): bool + { + $parts = data_get($data, 'candidates.0.content.parts', []); + + foreach ($parts as $part) { + if (isset($part['functionCall'])) { + return true; + } + } + + return false; + } + + /** + * @param array $data + */ + protected function extractUsage(array $data): Usage + { + return new Usage( + promptTokens: data_get($data, 'usageMetadata.promptTokenCount', 0), + completionTokens: data_get($data, 'usageMetadata.candidatesTokenCount', 0), + cacheReadInputTokens: data_get($data, 'usageMetadata.cachedContentTokenCount'), + thoughtTokens: data_get($data, 'usageMetadata.thoughtsTokenCount'), + ); + } + + /** + * @param array $data + */ + protected function mapFinishReason(array $data): FinishReason + { + $finishReason = data_get($data, 'candidates.0.finishReason'); + + if (! $finishReason) { + return FinishReason::Unknown; + } + + $isToolCall = $this->hasToolCalls($data); + + return FinishReasonMap::map($finishReason, $isToolCall); + } + + protected function sendRequest(Request $request): Response + { + $providerOptions = $request->providerOptions(); + + if ($request->tools() !== [] && $request->providerTools() !== []) { + throw new PrismException('Use of provider tools with custom tools is not currently supported by Vertex.'); + } + + if ($request->tools() !== [] && ($providerOptions['searchGrounding'] ?? false)) { + throw new PrismException('Use of search grounding with custom tools is not currently supported by Prism.'); + } + + $tools = []; + + if ($request->providerTools() !== []) { + $tools = array_map( + fn ($providerTool): array => [ + $providerTool->type => $providerTool->options !== [] ? $providerTool->options : (object) [], + ], + $request->providerTools() + ); + } elseif ($providerOptions['searchGrounding'] ?? false) { + $tools = [ + [ + 'google_search' => (object) [], + ], + ]; + } elseif ($request->tools() !== []) { + $tools = ['function_declarations' => ToolMap::map($request->tools())]; + } + + $thinkingConfig = $providerOptions['thinkingConfig'] ?? null; + + if (isset($providerOptions['thinkingBudget'])) { + $thinkingConfig = [ + 'thinkingBudget' => $providerOptions['thinkingBudget'], + 'includeThoughts' => true, + ]; + } + + if (isset($providerOptions['thinkingLevel'])) { + $thinkingConfig = [ + 'thinkingLevel' => $providerOptions['thinkingLevel'], + 'includeThoughts' => true, + ]; + } + + /** @var Response $response */ + $response = $this->client + ->withOptions(['stream' => true]) + ->post( + "{$this->model}:streamGenerateContent?alt=sse", + Arr::whereNotNull([ + ...(new MessageMap($request->messages(), $request->systemPrompts()))(), + 'generationConfig' => Arr::whereNotNull([ + 'temperature' => $request->temperature(), + 'topP' => $request->topP(), + 'maxOutputTokens' => $request->maxTokens(), + 'thinkingConfig' => $thinkingConfig, + ]) ?: null, + 'tools' => $tools !== [] ? $tools : null, + 'tool_config' => $request->toolChoice() ? ToolChoiceMap::map($request->toolChoice()) : null, + 'safetySettings' => $providerOptions['safetySettings'] ?? null, + ]) + ); + + return $response; + } + + protected function readLine(StreamInterface $stream): string + { + $buffer = ''; + + while (! $stream->eof()) { + $byte = $stream->read(1); + + if ($byte === '') { + return $buffer; + } + + $buffer .= $byte; + + if ($byte === "\n") { + break; + } + } + + return $buffer; + } + + /** + * @param array $data + * @return array|null + */ + protected function extractGroundingMetadata(array $data): ?array + { + $groundingMetadata = data_get($data, 'candidates.0.groundingMetadata'); + + if (! $groundingMetadata) { + return null; + } + + return $groundingMetadata; + } +} diff --git a/src/Providers/Vertex/Handlers/Structured.php b/src/Providers/Vertex/Handlers/Structured.php new file mode 100644 index 000000000..685a61035 --- /dev/null +++ b/src/Providers/Vertex/Handlers/Structured.php @@ -0,0 +1,312 @@ +responseBuilder = new ResponseBuilder; + } + + public function handle(Request $request): StructuredResponse + { + $data = $this->sendRequest($request); + + $this->validateResponse($data); + + $isToolCall = $this->hasToolCalls($data); + + $responseMessage = new AssistantMessage( + $this->extractTextContent($data), + $isToolCall ? ToolCallMap::map(data_get($data, 'candidates.0.content.parts', [])) : [], + ); + + $request->addMessage($responseMessage); + + $finishReason = FinishReasonMap::map( + data_get($data, 'candidates.0.finishReason'), + $isToolCall + ); + + return match ($finishReason) { + FinishReason::ToolCalls => $this->handleToolCalls($data, $request), + FinishReason::Stop, FinishReason::Length => $this->handleStop($data, $request, $finishReason), + default => throw new PrismException('Vertex: unhandled finish reason'), + }; + } + + /** + * @return array + */ + public function sendRequest(Request $request): array + { + $providerOptions = $request->providerOptions(); + + if ($request->tools() !== [] && $request->providerTools() !== []) { + throw new PrismException('Use of provider tools with custom tools is not currently supported by Vertex.'); + } + + $tools = []; + + if ($request->providerTools() !== []) { + $tools = [ + Arr::mapWithKeys( + $request->providerTools(), + fn (ProviderTool $providerTool): array => [ + $providerTool->type => $providerTool->options !== [] ? $providerTool->options : (object) [], + ] + ), + ]; + } + + if ($request->tools() !== []) { + $tools = [ + [ + 'function_declarations' => ToolMap::map($request->tools()), + ], + ]; + } + + $thinkingConfig = $providerOptions['thinkingConfig'] ?? null; + + if (isset($providerOptions['thinkingBudget'])) { + $thinkingConfig = Arr::whereNotNull([ + 'thinkingBudget' => $providerOptions['thinkingBudget'], + 'includeThoughts' => $providerOptions['includeThoughts'] ?? null, + ]); + } + + if (isset($providerOptions['thinkingLevel'])) { + $thinkingConfig = Arr::whereNotNull([ + 'thinkingLevel' => $providerOptions['thinkingLevel'], + 'includeThoughts' => $providerOptions['includeThoughts'] ?? null, + ]); + } + + /** @var Response $response */ + $response = $this->client->post( + "{$this->model}:generateContent", + Arr::whereNotNull([ + ...(new MessageMap($request->messages(), $request->systemPrompts()))(), + 'generationConfig' => Arr::whereNotNull([ + 'response_mime_type' => 'application/json', + 'response_schema' => (new SchemaMap($request->schema()))->toArray(), + 'temperature' => $request->temperature(), + 'topP' => $request->topP(), + 'maxOutputTokens' => $request->maxTokens(), + 'thinkingConfig' => $thinkingConfig, + ]), + 'tools' => $tools !== [] ? $tools : null, + 'tool_config' => $request->toolChoice() ? ToolChoiceMap::map($request->toolChoice()) : null, + 'safetySettings' => $providerOptions['safetySettings'] ?? null, + ]) + ); + + return $response->json(); + } + + /** + * @param array $data + */ + protected function validateResponse(array $data): void + { + if (! $data || data_get($data, 'error')) { + throw PrismException::providerResponseError(vsprintf( + 'Vertex Error: [%s] %s', + [ + data_get($data, 'error.code', 'unknown'), + data_get($data, 'error.message', 'unknown'), + ] + )); + } + + $finishReason = data_get($data, 'candidates.0.finishReason'); + $content = $this->extractTextContent($data); + $thoughtTokens = data_get($data, 'usageMetadata.thoughtsTokenCount', 0); + + if ($finishReason === 'MAX_TOKENS') { + $promptTokens = data_get($data, 'usageMetadata.promptTokenCount', 0); + $candidatesTokens = data_get($data, 'usageMetadata.candidatesTokenCount', 0); + $totalTokens = data_get($data, 'usageMetadata.totalTokenCount', 0); + $outputTokens = $candidatesTokens - $thoughtTokens; + + $isEmpty = in_array(trim($content), ['', '0'], true); + $isInvalidJson = $content !== '' && $content !== '0' && json_decode($content) === null; + $contentLength = strlen($content); + + if (($isEmpty || $isInvalidJson) && $thoughtTokens > 0) { + $errorDetail = $isEmpty + ? 'no tokens remained for structured output' + : "output was truncated at {$contentLength} characters resulting in invalid JSON"; + + throw PrismException::providerResponseError( + 'Vertex hit token limit with high thinking token usage. '. + "Token usage: {$promptTokens} prompt + {$thoughtTokens} thinking + {$outputTokens} output = {$totalTokens} total. ". + "The {$errorDetail}. ". + 'Try increasing maxTokens to at least '.($totalTokens + 1000).' (suggested: '.($totalTokens * 2).' for comfortable margin).' + ); + } + } + } + + /** + * @param array $data + */ + protected function handleStop(array $data, Request $request, FinishReason $finishReason): StructuredResponse + { + $this->addStep($data, $request, $finishReason); + + return $this->responseBuilder->toResponse(); + } + + /** + * @param array $data + */ + protected function handleToolCalls(array $data, Request $request): StructuredResponse + { + $toolResults = $this->callTools( + $request->tools(), + ToolCallMap::map(data_get($data, 'candidates.0.content.parts', [])) + ); + + $request->addMessage(new ToolResultMessage($toolResults)); + $request->resetToolChoice(); + + $this->addStep($data, $request, FinishReason::ToolCalls, $toolResults); + + if ($this->shouldContinue($request)) { + return $this->handle($request); + } + + return $this->responseBuilder->toResponse(); + } + + /** + * @param array $data + * @param ToolResult[] $toolResults + */ + protected function addStep(array $data, Request $request, FinishReason $finishReason, array $toolResults = []): void + { + $isStructuredStep = $finishReason !== FinishReason::ToolCalls; + $thoughtSummaries = $this->extractThoughtSummaries($data); + $textContent = $this->extractTextContent($data); + + $this->responseBuilder->addStep( + new Step( + text: $textContent, + finishReason: $finishReason, + usage: new Usage( + promptTokens: data_get($data, 'usageMetadata.promptTokenCount', 0), + completionTokens: data_get($data, 'usageMetadata.candidatesTokenCount', 0), + cacheReadInputTokens: data_get($data, 'usageMetadata.cachedContentTokenCount'), + thoughtTokens: data_get($data, 'usageMetadata.thoughtsTokenCount'), + ), + meta: new Meta( + id: data_get($data, 'id', ''), + model: data_get($data, 'modelVersion', ''), + ), + messages: $request->messages(), + systemPrompts: $request->systemPrompts(), + additionalContent: Arr::whereNotNull([ + 'thoughtSummaries' => $thoughtSummaries !== [] ? $thoughtSummaries : null, + 'citations' => CitationMapper::mapFromGemini(data_get($data, 'candidates.0', [])) ?: null, + 'searchEntryPoint' => data_get($data, 'candidates.0.groundingMetadata.searchEntryPoint'), + 'searchQueries' => data_get($data, 'candidates.0.groundingMetadata.webSearchQueries'), + 'urlMetadata' => data_get($data, 'candidates.0.urlContextMetadata.urlMetadata'), + ]), + structured: $isStructuredStep ? $this->extractStructuredData(data_get($data, 'candidates.0.content.parts.0.text') ?? '') : [], + toolCalls: $finishReason === FinishReason::ToolCalls ? ToolCallMap::map(data_get($data, 'candidates.0.content.parts', [])) : [], + toolResults: $toolResults, + raw: $data, + ) + ); + } + + /** + * @param array $data + */ + protected function extractTextContent(array $data): string + { + $parts = data_get($data, 'candidates.0.content.parts', []); + $textParts = []; + + foreach ($parts as $part) { + if (isset($part['text']) && (! isset($part['thought']) || $part['thought'] === false)) { + $textParts[] = $part['text']; + } + } + + return implode('', $textParts); + } + + /** + * @param array $data + * @return array + */ + protected function extractThoughtSummaries(array $data): array + { + $parts = data_get($data, 'candidates.0.content.parts', []); + $thoughtSummaries = []; + + foreach ($parts as $part) { + if (isset($part['thought']) && $part['thought'] === true && isset($part['text'])) { + $thoughtSummaries[] = $part['text']; + } + } + + return $thoughtSummaries; + } + + /** + * @param array $data + */ + protected function hasToolCalls(array $data): bool + { + $parts = data_get($data, 'candidates.0.content.parts', []); + + foreach ($parts as $part) { + if (isset($part['functionCall'])) { + return true; + } + } + + return false; + } +} diff --git a/src/Providers/Vertex/Handlers/Text.php b/src/Providers/Vertex/Handlers/Text.php new file mode 100644 index 000000000..9f2e458a9 --- /dev/null +++ b/src/Providers/Vertex/Handlers/Text.php @@ -0,0 +1,255 @@ +responseBuilder = new ResponseBuilder; + } + + public function handle(Request $request): TextResponse + { + $response = $this->sendRequest($request); + + $this->validateResponse($response); + + $data = $response->json(); + + $isToolCall = $this->hasToolCalls($data); + + $finishReason = FinishReasonMap::map( + data_get($data, 'candidates.0.finishReason'), + $isToolCall + ); + + return match ($finishReason) { + FinishReason::ToolCalls => $this->handleToolCalls($data, $request), + FinishReason::Stop, FinishReason::Length => $this->handleStop($data, $request, $finishReason), + default => throw new PrismException('Vertex: unhandled finish reason'), + }; + } + + protected function sendRequest(Request $request): ClientResponse + { + $providerOptions = $request->providerOptions(); + + $thinkingConfig = $providerOptions['thinkingConfig'] ?? null; + + if (isset($providerOptions['thinkingBudget'])) { + $thinkingConfig = Arr::whereNotNull([ + 'thinkingBudget' => $providerOptions['thinkingBudget'], + 'includeThoughts' => $providerOptions['includeThoughts'] ?? null, + ]); + } + + if (isset($providerOptions['thinkingLevel'])) { + $thinkingConfig = Arr::whereNotNull([ + 'thinkingLevel' => $providerOptions['thinkingLevel'], + 'includeThoughts' => $providerOptions['includeThoughts'] ?? null, + ]); + } + + $generationConfig = Arr::whereNotNull([ + 'temperature' => $request->temperature(), + 'topP' => $request->topP(), + 'maxOutputTokens' => $request->maxTokens(), + 'thinkingConfig' => $thinkingConfig, + ]); + + if ($request->tools() !== [] && $request->providerTools() !== []) { + throw new PrismException('Use of provider tools with custom tools is not currently supported by Vertex.'); + } + + $tools = []; + + if ($request->providerTools() !== []) { + $tools = array_map( + fn (ProviderTool $providerTool): array => [ + $providerTool->type => $providerTool->options !== [] ? $providerTool->options : (object) [], + ], + $request->providerTools() + ); + } + + if ($request->tools() !== []) { + $tools['function_declarations'] = ToolMap::map($request->tools()); + } + + /** @var ClientResponse $response */ + $response = $this->client->post( + "{$this->model}:generateContent", + Arr::whereNotNull([ + ...(new MessageMap($request->messages(), $request->systemPrompts()))(), + 'generationConfig' => $generationConfig !== [] ? $generationConfig : null, + 'tools' => $tools !== [] ? $tools : null, + 'tool_config' => $request->toolChoice() ? ToolChoiceMap::map($request->toolChoice()) : null, + 'safetySettings' => $providerOptions['safetySettings'] ?? null, + ]) + ); + + return $response; + } + + /** + * @param array $data + */ + protected function handleStop(array $data, Request $request, FinishReason $finishReason): TextResponse + { + $this->addStep($data, $request, $finishReason); + + return $this->responseBuilder->toResponse(); + } + + /** + * @param array $data + */ + protected function handleToolCalls(array $data, Request $request): TextResponse + { + $toolCalls = ToolCallMap::map(data_get($data, 'candidates.0.content.parts', [])); + + $toolResults = $this->callTools($request->tools(), $toolCalls); + + $this->addStep($data, $request, FinishReason::ToolCalls, $toolResults); + + $request->addMessage(new AssistantMessage( + $this->extractTextContent($data), + $toolCalls, + )); + $request->addMessage(new ToolResultMessage($toolResults)); + $request->resetToolChoice(); + + if ($this->shouldContinue($request)) { + return $this->handle($request); + } + + return $this->responseBuilder->toResponse(); + } + + protected function shouldContinue(Request $request): bool + { + return $this->responseBuilder->steps->count() < $request->maxSteps(); + } + + /** + * @param array $data + * @param ToolResult[] $toolResults + */ + protected function addStep(array $data, Request $request, FinishReason $finishReason, array $toolResults = []): void + { + $thoughtSummaries = $this->extractThoughtSummaries($data); + + $this->responseBuilder->addStep(new Step( + text: $this->extractTextContent($data), + finishReason: $finishReason, + toolCalls: $finishReason === FinishReason::ToolCalls ? ToolCallMap::map(data_get($data, 'candidates.0.content.parts', [])) : [], + toolResults: $toolResults, + providerToolCalls: [], + usage: new Usage( + promptTokens: data_get($data, 'usageMetadata.promptTokenCount', 0), + completionTokens: data_get($data, 'usageMetadata.candidatesTokenCount', 0), + cacheReadInputTokens: data_get($data, 'usageMetadata.cachedContentTokenCount'), + thoughtTokens: data_get($data, 'usageMetadata.thoughtsTokenCount'), + ), + meta: new Meta( + id: data_get($data, 'id', ''), + model: data_get($data, 'modelVersion', ''), + ), + messages: $request->messages(), + systemPrompts: $request->systemPrompts(), + additionalContent: Arr::whereNotNull([ + 'citations' => CitationMapper::mapFromGemini(data_get($data, 'candidates.0', [])) ?: null, + 'searchEntryPoint' => data_get($data, 'candidates.0.groundingMetadata.searchEntryPoint'), + 'searchQueries' => data_get($data, 'candidates.0.groundingMetadata.webSearchQueries'), + 'urlMetadata' => data_get($data, 'candidates.0.urlContextMetadata.urlMetadata'), + 'thoughtSummaries' => $thoughtSummaries !== [] ? $thoughtSummaries : null, + ]), + raw: $data, + )); + } + + /** + * @param array $data + */ + protected function extractTextContent(array $data): string + { + $parts = data_get($data, 'candidates.0.content.parts', []); + $textParts = []; + + foreach ($parts as $part) { + if (isset($part['text']) && (! isset($part['thought']) || $part['thought'] === false)) { + $textParts[] = $part['text']; + } + } + + return implode('', $textParts); + } + + /** + * @param array $data + * @return array + */ + protected function extractThoughtSummaries(array $data): array + { + $parts = data_get($data, 'candidates.0.content.parts', []); + $thoughtSummaries = []; + + foreach ($parts as $part) { + if (isset($part['thought']) && $part['thought'] === true && isset($part['text'])) { + $thoughtSummaries[] = $part['text']; + } + } + + return $thoughtSummaries; + } + + /** + * @param array $data + */ + protected function hasToolCalls(array $data): bool + { + $parts = data_get($data, 'candidates.0.content.parts', []); + + foreach ($parts as $part) { + if (isset($part['functionCall'])) { + return true; + } + } + + return false; + } +} diff --git a/src/Providers/Vertex/Vertex.php b/src/Providers/Vertex/Vertex.php new file mode 100644 index 000000000..523e1e21b --- /dev/null +++ b/src/Providers/Vertex/Vertex.php @@ -0,0 +1,259 @@ +client($request->clientOptions(), $request->clientRetry()), + $request->model() + ); + + return $handler->handle($request); + } + + #[\Override] + public function structured(StructuredRequest $request): StructuredResponse + { + $handler = new Structured( + $this->client($request->clientOptions(), $request->clientRetry()), + $request->model() + ); + + return $handler->handle($request); + } + + #[\Override] + public function embeddings(EmbeddingRequest $request): EmbeddingResponse + { + $handler = new Embeddings( + $this->client($request->clientOptions(), $request->clientRetry()), + $request->model() + ); + + return $handler->handle($request); + } + + #[\Override] + public function stream(TextRequest $request): Generator + { + $handler = new Stream( + $this->client($request->clientOptions(), $request->clientRetry()), + $request->model() + ); + + return $handler->handle($request); + } + + public function handleRequestException(string $model, RequestException $e): never + { + match ($e->response->getStatusCode()) { + 429 => throw PrismRateLimitedException::make([]), + 503 => throw PrismProviderOverloadedException::make(class_basename($this)), + default => $this->handleResponseErrors($e), + }; + } + + protected function handleResponseErrors(RequestException $e): never + { + $data = $e->response->json() ?? []; + + throw PrismException::providerRequestErrorWithDetails( + provider: 'Vertex', + statusCode: $e->response->getStatusCode(), + errorType: data_get($data, 'error.status'), + errorMessage: data_get($data, 'error.message'), + previous: $e + ); + } + + /** + * @param array $options + * @param array $retry + */ + protected function client(array $options = [], array $retry = []): PendingRequest + { + $accessToken = $this->resolveAccessToken(); + + return $this->baseClient() + ->withToken($accessToken) + ->withOptions($options) + ->when($retry !== [], fn ($client) => $client->retry(...$retry)) + ->baseUrl($this->buildBaseUrl()); + } + + protected function buildBaseUrl(): string + { + if ($this->region === 'global') { + return sprintf( + 'https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models', + $this->projectId + ); + } + + return sprintf( + 'https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models', + $this->region, + $this->projectId, + $this->region + ); + } + + protected function resolveAccessToken(): string + { + if ($this->accessToken !== null && $this->accessToken !== '') { + return $this->accessToken; + } + + if ($this->credentialsPath !== null && $this->credentialsPath !== '') { + return $this->getAccessTokenFromServiceAccount(); + } + + return $this->getAccessTokenFromApplicationDefaultCredentials(); + } + + protected function getAccessTokenFromServiceAccount(): string + { + if (! file_exists($this->credentialsPath)) { + throw new PrismException("Vertex AI credentials file not found: {$this->credentialsPath}"); + } + + $credentials = json_decode(file_get_contents($this->credentialsPath), true); + + if (! isset($credentials['client_email'], $credentials['private_key'])) { + throw new PrismException('Invalid Vertex AI service account credentials file'); + } + + return $this->generateJwtToken($credentials); + } + + /** + * @param array $credentials + */ + protected function generateJwtToken(array $credentials): string + { + $header = [ + 'alg' => 'RS256', + 'typ' => 'JWT', + ]; + + $now = time(); + $payload = [ + 'iss' => $credentials['client_email'], + 'sub' => $credentials['client_email'], + 'aud' => 'https://aiplatform.googleapis.com/', + 'iat' => $now, + 'exp' => $now + 3600, + ]; + + $headerEncoded = $this->base64UrlEncode(json_encode($header)); + $payloadEncoded = $this->base64UrlEncode(json_encode($payload)); + + $signatureInput = $headerEncoded.'.'.$payloadEncoded; + + $privateKey = openssl_pkey_get_private($credentials['private_key']); + if ($privateKey === false) { + throw new PrismException('Failed to parse Vertex AI private key'); + } + + $signature = ''; + if (! openssl_sign($signatureInput, $signature, $privateKey, OPENSSL_ALGO_SHA256)) { + throw new PrismException('Failed to sign Vertex AI JWT token'); + } + + return $headerEncoded.'.'.$payloadEncoded.'.'.$this->base64UrlEncode($signature); + } + + protected function base64UrlEncode(string $data): string + { + return rtrim(strtr(base64_encode($data), '+/', '-_'), '='); + } + + protected function getAccessTokenFromApplicationDefaultCredentials(): string + { + $adcPath = getenv('GOOGLE_APPLICATION_CREDENTIALS'); + + if ($adcPath !== false && $adcPath !== '' && file_exists($adcPath)) { + $this->credentialsPath !== null ?: $adcPath; + + return $this->getAccessTokenFromServiceAccount(); + } + + $defaultPath = $this->getDefaultAdcPath(); + if (file_exists($defaultPath)) { + $credentials = json_decode(file_get_contents($defaultPath), true); + + if (isset($credentials['type']) && $credentials['type'] === 'authorized_user') { + return $this->refreshAccessToken($credentials); + } + } + + throw new PrismException( + 'Vertex AI requires authentication. Provide an access_token, credentials_path, '. + 'or set up Application Default Credentials (run: gcloud auth application-default login)' + ); + } + + protected function getDefaultAdcPath(): string + { + if (PHP_OS_FAMILY === 'Windows') { + return getenv('APPDATA').'/gcloud/application_default_credentials.json'; + } + + return getenv('HOME').'/.config/gcloud/application_default_credentials.json'; + } + + /** + * @param array $credentials + */ + protected function refreshAccessToken(array $credentials): string + { + $response = $this->baseClient() + ->asForm() + ->post('https://oauth2.googleapis.com/token', [ + 'client_id' => $credentials['client_id'], + 'client_secret' => $credentials['client_secret'], + 'refresh_token' => $credentials['refresh_token'], + 'grant_type' => 'refresh_token', + ]); + + if (! $response->successful()) { + throw new PrismException('Failed to refresh Vertex AI access token: '.$response->body()); + } + + return $response->json('access_token'); + } +} diff --git a/tests/Fixtures/vertex/embeddings-1.json b/tests/Fixtures/vertex/embeddings-1.json new file mode 100644 index 000000000..850b6c836 --- /dev/null +++ b/tests/Fixtures/vertex/embeddings-1.json @@ -0,0 +1,27 @@ +{ + "predictions": [ + { + "embeddings": { + "statistics": { + "truncated": false, + "token_count": 6 + }, + "values": [ + 0.0123456789, + -0.0234567890, + 0.0345678901, + -0.0456789012, + 0.0567890123, + -0.0678901234, + 0.0789012345, + -0.0890123456, + 0.0901234567, + -0.1012345678 + ] + } + } + ], + "metadata": { + "billableCharacterCount": 20 + } +} diff --git a/tests/Fixtures/vertex/generate-text-with-a-prompt-1.json b/tests/Fixtures/vertex/generate-text-with-a-prompt-1.json new file mode 100644 index 000000000..badbb1f62 --- /dev/null +++ b/tests/Fixtures/vertex/generate-text-with-a-prompt-1.json @@ -0,0 +1,22 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "I am a large language model, trained by Google. I am an AI, and I don't have a name, feelings, or personal experiences. My purpose is to process information and respond to a wide range of prompts and questions in a helpful and informative way.\n" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -0.12800796408402293 + } + ], + "usageMetadata": { + "promptTokenCount": 4, + "candidatesTokenCount": 57, + "totalTokenCount": 61 + }, + "modelVersion": "gemini-1.5-flash" +} diff --git a/tests/Fixtures/vertex/generate-text-with-multiple-tools-1.json b/tests/Fixtures/vertex/generate-text-with-multiple-tools-1.json new file mode 100644 index 000000000..08f8a20a2 --- /dev/null +++ b/tests/Fixtures/vertex/generate-text-with-multiple-tools-1.json @@ -0,0 +1,35 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "search_games", + "args": { + "city": "Detroit" + } + } + }, + { + "functionCall": { + "name": "get_weather", + "args": { + "city": "Detroit" + } + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -0.15 + } + ], + "usageMetadata": { + "promptTokenCount": 150, + "candidatesTokenCount": 20, + "totalTokenCount": 170 + }, + "modelVersion": "gemini-1.5-flash" +} diff --git a/tests/Fixtures/vertex/generate-text-with-multiple-tools-2.json b/tests/Fixtures/vertex/generate-text-with-multiple-tools-2.json new file mode 100644 index 000000000..f39e52748 --- /dev/null +++ b/tests/Fixtures/vertex/generate-text-with-multiple-tools-2.json @@ -0,0 +1,22 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "The tigers game is at 3pm today in Detroit. The weather will be 45° and cold, so you should wear a coat." + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -0.12 + } + ], + "usageMetadata": { + "promptTokenCount": 200, + "candidatesTokenCount": 22, + "totalTokenCount": 222 + }, + "modelVersion": "gemini-1.5-flash" +} diff --git a/tests/Fixtures/vertex/generate-text-with-system-prompt-1.json b/tests/Fixtures/vertex/generate-text-with-system-prompt-1.json new file mode 100644 index 000000000..2b5c0966e --- /dev/null +++ b/tests/Fixtures/vertex/generate-text-with-system-prompt-1.json @@ -0,0 +1,22 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "I am Prism, a helpful AI assistant created by echo labs." + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -0.08123456789 + } + ], + "usageMetadata": { + "promptTokenCount": 17, + "candidatesTokenCount": 14, + "totalTokenCount": 31 + }, + "modelVersion": "gemini-1.5-flash" +} diff --git a/tests/Fixtures/vertex/image-detection-1.json b/tests/Fixtures/vertex/image-detection-1.json new file mode 100644 index 000000000..62029bdc2 --- /dev/null +++ b/tests/Fixtures/vertex/image-detection-1.json @@ -0,0 +1,22 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "That's an illustration of a **diamond**. More specifically, it's a stylized, geometric representation of a diamond, often used as an icon or symbol" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -0.11 + } + ], + "usageMetadata": { + "promptTokenCount": 263, + "candidatesTokenCount": 35, + "totalTokenCount": 298 + }, + "modelVersion": "gemini-1.5-flash" +} diff --git a/tests/Fixtures/vertex/stream-basic-1.sse b/tests/Fixtures/vertex/stream-basic-1.sse new file mode 100644 index 000000000..2ea3610ca --- /dev/null +++ b/tests/Fixtures/vertex/stream-basic-1.sse @@ -0,0 +1,8 @@ +data: {"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"finishReason":null}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":1,"totalTokenCount":11},"modelVersion":"gemini-1.5-flash"} + +data: {"candidates":[{"content":{"parts":[{"text":", I am"}],"role":"model"},"finishReason":null}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":3,"totalTokenCount":13},"modelVersion":"gemini-1.5-flash"} + +data: {"candidates":[{"content":{"parts":[{"text":" a helpful AI assistant."}],"role":"model"},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":8,"totalTokenCount":18},"modelVersion":"gemini-1.5-flash"} + +data: [DONE] + diff --git a/tests/Fixtures/vertex/structured-response-1.json b/tests/Fixtures/vertex/structured-response-1.json new file mode 100644 index 000000000..d5d861271 --- /dev/null +++ b/tests/Fixtures/vertex/structured-response-1.json @@ -0,0 +1,22 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "{\"name\":\"John Doe\",\"age\":30,\"email\":\"john.doe@example.com\"}" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -0.08 + } + ], + "usageMetadata": { + "promptTokenCount": 50, + "candidatesTokenCount": 25, + "totalTokenCount": 75 + }, + "modelVersion": "gemini-1.5-flash" +} diff --git a/tests/Providers/Vertex/VertexEmbeddingsTest.php b/tests/Providers/Vertex/VertexEmbeddingsTest.php new file mode 100644 index 000000000..496e3aa70 --- /dev/null +++ b/tests/Providers/Vertex/VertexEmbeddingsTest.php @@ -0,0 +1,84 @@ +set('prism.providers.vertex.project_id', 'test-project'); + config()->set('prism.providers.vertex.region', 'us-central1'); + config()->set('prism.providers.vertex.access_token', 'test-access-token'); +}); + +describe('Embeddings for Vertex', function (): void { + it('can generate embeddings', function (): void { + FixtureResponse::fakeResponseSequence('*', 'vertex/embeddings'); + + $response = Prism::embeddings() + ->using(Provider::Vertex, 'text-embedding-004') + ->fromInput('Hello, world!') + ->generate(); + + expect($response->embeddings)->toHaveCount(1) + ->and($response->embeddings[0]->embedding)->toBeArray() + ->and($response->embeddings[0]->embedding)->toHaveCount(10) + ->and($response->usage->tokens)->toBe(20) + ->and($response->meta->model)->toBe('text-embedding-004'); + }); + + it('sends requests to the correct Vertex AI endpoint for embeddings', function (): void { + FixtureResponse::fakeResponseSequence('*', 'vertex/embeddings'); + + Prism::embeddings() + ->using(Provider::Vertex, 'text-embedding-004') + ->fromInput('Hello, world!') + ->generate(); + + Http::assertSent(function (Request $request): bool { + expect($request->url())->toContain('us-central1-aiplatform.googleapis.com') + ->and($request->url())->toContain('projects/test-project') + ->and($request->url())->toContain('locations/us-central1') + ->and($request->url())->toContain('publishers/google/models') + ->and($request->url())->toContain('text-embedding-004:predict') + ->and($request->hasHeader('Authorization'))->toBeTrue() + ->and($request->header('Authorization')[0])->toBe('Bearer test-access-token'); + + return true; + }); + }); + + it('includes content in request body', function (): void { + FixtureResponse::fakeResponseSequence('*', 'vertex/embeddings'); + + Prism::embeddings() + ->using(Provider::Vertex, 'text-embedding-004') + ->fromInput('Hello, world!') + ->generate(); + + Http::assertSent(function (Request $request): bool { + $data = $request->data(); + + expect($data['instances'])->toHaveCount(1) + ->and($data['instances'][0]['content'])->toBe('Hello, world!'); + + return true; + }); + }); + + it('throws exception for multiple inputs', function (): void { + FixtureResponse::fakeResponseSequence('*', 'vertex/embeddings'); + + Prism::embeddings() + ->using(Provider::Vertex, 'text-embedding-004') + ->fromInput('Hello') + ->fromInput('World') + ->generate(); + })->throws(PrismException::class, 'Vertex Error: Prism currently only supports one input at a time with Vertex AI.'); +}); diff --git a/tests/Providers/Vertex/VertexExceptionHandlingTest.php b/tests/Providers/Vertex/VertexExceptionHandlingTest.php new file mode 100644 index 000000000..9bd115d53 --- /dev/null +++ b/tests/Providers/Vertex/VertexExceptionHandlingTest.php @@ -0,0 +1,87 @@ +set('prism.providers.vertex.project_id', 'test-project'); + config()->set('prism.providers.vertex.region', 'us-central1'); + config()->set('prism.providers.vertex.access_token', 'test-access-token'); +}); + +describe('Exception handling for Vertex', function (): void { + it('throws PrismRateLimitedException on 429 response', function (): void { + Http::fake([ + '*' => Http::response([ + 'error' => [ + 'code' => 429, + 'message' => 'Resource has been exhausted', + 'status' => 'RESOURCE_EXHAUSTED', + ], + ], 429), + ]); + + Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withPrompt('Hello') + ->asText(); + })->throws(PrismRateLimitedException::class); + + it('throws PrismException on error response', function (): void { + Http::fake([ + '*' => Http::response([ + 'error' => [ + 'code' => 400, + 'message' => 'Invalid request', + 'status' => 'INVALID_ARGUMENT', + ], + ], 400), + ]); + + Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withPrompt('Hello') + ->asText(); + })->throws(PrismException::class); + + it('throws PrismException on authentication error', function (): void { + Http::fake([ + '*' => Http::response([ + 'error' => [ + 'code' => 401, + 'message' => 'Request had invalid authentication credentials', + 'status' => 'UNAUTHENTICATED', + ], + ], 401), + ]); + + Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withPrompt('Hello') + ->asText(); + })->throws(PrismException::class); + + it('throws PrismException on permission denied', function (): void { + Http::fake([ + '*' => Http::response([ + 'error' => [ + 'code' => 403, + 'message' => 'Permission denied on resource', + 'status' => 'PERMISSION_DENIED', + ], + ], 403), + ]); + + Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withPrompt('Hello') + ->asText(); + })->throws(PrismException::class); +}); diff --git a/tests/Providers/Vertex/VertexStreamTest.php b/tests/Providers/Vertex/VertexStreamTest.php new file mode 100644 index 000000000..e78c4b5ac --- /dev/null +++ b/tests/Providers/Vertex/VertexStreamTest.php @@ -0,0 +1,106 @@ +set('prism.providers.vertex.project_id', 'test-project'); + config()->set('prism.providers.vertex.region', 'us-central1'); + config()->set('prism.providers.vertex.access_token', 'test-access-token'); +}); + +describe('Streaming for Vertex', function (): void { + it('can stream text responses', function (): void { + FixtureResponse::fakeStreamResponses('*', 'vertex/stream-basic'); + + $events = []; + $fullText = ''; + + $stream = Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withPrompt('Hello') + ->asStream(); + + foreach ($stream as $event) { + $events[] = $event; + + if ($event instanceof TextDeltaEvent) { + $fullText .= $event->delta; + } + } + + // Should have stream start and end events + expect(collect($events)->first())->toBeInstanceOf(StreamStartEvent::class); + expect(collect($events)->last())->toBeInstanceOf(StreamEndEvent::class); + + // Should have text delta events + $textDeltas = collect($events)->filter(fn ($e): bool => $e instanceof TextDeltaEvent); + expect($textDeltas)->toHaveCount(3); + + // Full text should be concatenated correctly + expect($fullText)->toBe('Hello, I am a helpful AI assistant.'); + + // Stream end should have finish reason + $endEvent = collect($events)->last(); + expect($endEvent->finishReason)->toBe(FinishReason::Stop); + }); + + it('sends streaming requests to the correct Vertex AI endpoint', function (): void { + FixtureResponse::fakeStreamResponses('*', 'vertex/stream-basic'); + + $stream = Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withPrompt('Hello') + ->asStream(); + + // Consume the stream + foreach ($stream as $event) { + // Just iterate to trigger the request + } + + Http::assertSent(function (Request $request): bool { + expect($request->url())->toContain('us-central1-aiplatform.googleapis.com') + ->and($request->url())->toContain('projects/test-project') + ->and($request->url())->toContain('locations/us-central1') + ->and($request->url())->toContain('publishers/google/models') + ->and($request->url())->toContain('gemini-1.5-flash:streamGenerateContent') + ->and($request->url())->toContain('alt=sse') + ->and($request->hasHeader('Authorization'))->toBeTrue() + ->and($request->header('Authorization')[0])->toBe('Bearer test-access-token'); + + return true; + }); + }); + + it('stream start event contains model and provider info', function (): void { + FixtureResponse::fakeStreamResponses('*', 'vertex/stream-basic'); + + $stream = Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withPrompt('Hello') + ->asStream(); + + $startEvent = null; + foreach ($stream as $event) { + if ($event instanceof StreamStartEvent) { + $startEvent = $event; + break; + } + } + + expect($startEvent)->not->toBeNull() + ->and($startEvent->provider)->toBe('vertex') + ->and($startEvent->model)->toBe('gemini-1.5-flash'); + }); +}); diff --git a/tests/Providers/Vertex/VertexStructuredTest.php b/tests/Providers/Vertex/VertexStructuredTest.php new file mode 100644 index 000000000..132fd0c30 --- /dev/null +++ b/tests/Providers/Vertex/VertexStructuredTest.php @@ -0,0 +1,109 @@ +set('prism.providers.vertex.project_id', 'test-project'); + config()->set('prism.providers.vertex.region', 'us-central1'); + config()->set('prism.providers.vertex.access_token', 'test-access-token'); +}); + +describe('Structured output for Vertex', function (): void { + it('can generate structured output', function (): void { + FixtureResponse::fakeResponseSequence('*', 'vertex/structured-response'); + + $schema = new ObjectSchema( + name: 'user', + description: 'A user object', + properties: [ + new StringSchema('name', 'The user\'s name'), + new NumberSchema('age', 'The user\'s age'), + new StringSchema('email', 'The user\'s email'), + ], + requiredFields: ['name', 'age', 'email'] + ); + + $response = Prism::structured() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withSchema($schema) + ->withPrompt('Generate a user object for John Doe, age 30, with email john.doe@example.com') + ->generate(); + + expect($response->structured)->toBe([ + 'name' => 'John Doe', + 'age' => 30, + 'email' => 'john.doe@example.com', + ]) + ->and($response->finishReason)->toBe(FinishReason::Stop) + ->and($response->usage->promptTokens)->toBe(50) + ->and($response->usage->completionTokens)->toBe(25); + }); + + it('sends requests with response schema to Vertex AI', function (): void { + FixtureResponse::fakeResponseSequence('*', 'vertex/structured-response'); + + $schema = new ObjectSchema( + name: 'user', + description: 'A user object', + properties: [ + new StringSchema('name', 'The user\'s name'), + new NumberSchema('age', 'The user\'s age'), + ], + requiredFields: ['name', 'age'] + ); + + Prism::structured() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withSchema($schema) + ->withPrompt('Generate a user') + ->generate(); + + Http::assertSent(function (Request $request): bool { + $data = $request->data(); + + expect($data['generationConfig'])->toHaveKey('response_mime_type', 'application/json') + ->and($data['generationConfig'])->toHaveKey('response_schema'); + + return true; + }); + }); + + it('sends requests to the correct Vertex AI endpoint for structured output', function (): void { + FixtureResponse::fakeResponseSequence('*', 'vertex/structured-response'); + + $schema = new ObjectSchema( + name: 'test', + description: 'Test', + properties: [ + new StringSchema('value', 'Test value'), + ], + requiredFields: ['value'] + ); + + Prism::structured() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withSchema($schema) + ->withPrompt('Test') + ->generate(); + + Http::assertSent(function (Request $request): bool { + expect($request->url())->toContain('us-central1-aiplatform.googleapis.com') + ->and($request->url())->toContain('projects/test-project') + ->and($request->url())->toContain('gemini-1.5-flash:generateContent'); + + return true; + }); + }); +}); diff --git a/tests/Providers/Vertex/VertexTextTest.php b/tests/Providers/Vertex/VertexTextTest.php new file mode 100644 index 000000000..ba6841781 --- /dev/null +++ b/tests/Providers/Vertex/VertexTextTest.php @@ -0,0 +1,259 @@ +set('prism.providers.vertex.project_id', 'test-project'); + config()->set('prism.providers.vertex.region', 'us-central1'); + config()->set('prism.providers.vertex.access_token', 'test-access-token'); +}); + +describe('Text generation for Vertex', function (): void { + it('can generate text with a prompt', function (): void { + FixtureResponse::fakeResponseSequence('*', 'vertex/generate-text-with-a-prompt'); + + $response = Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withPrompt('Who are you?') + ->withMaxTokens(10) + ->asText(); + + expect($response->text)->toBe( + "I am a large language model, trained by Google. I am an AI, and I don't have a name, feelings, or personal experiences. My purpose is to process information and respond to a wide range of prompts and questions in a helpful and informative way.\n" + ) + ->and($response->usage->promptTokens)->toBe(4) + ->and($response->usage->completionTokens)->toBe(57) + ->and($response->meta->id)->toBe('') + ->and($response->meta->model)->toBe('gemini-1.5-flash') + ->and($response->finishReason)->toBe(FinishReason::Stop); + }); + + it('can generate text with a system prompt', function (): void { + FixtureResponse::fakeResponseSequence('*', 'vertex/generate-text-with-system-prompt'); + + $response = Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withSystemPrompt('You are a helpful AI assistant named Prism generated by echolabs') + ->withPrompt('Who are you?') + ->asText(); + + expect($response->text)->toBe('I am Prism, a helpful AI assistant created by echo labs.') + ->and($response->usage->promptTokens)->toBe(17) + ->and($response->usage->completionTokens)->toBe(14) + ->and($response->meta->id)->toBe('') + ->and($response->meta->model)->toBe('gemini-1.5-flash') + ->and($response->finishReason)->toBe(FinishReason::Stop); + }); + + it('can generate text using multiple tools and multiple steps', function (): void { + FixtureResponse::fakeResponseSequence('*', 'vertex/generate-text-with-multiple-tools'); + + $tools = [ + (new Tool) + ->as('get_weather') + ->for('use this tool when you need to get weather for the city') + ->withStringParameter('city', 'The city that you want the weather for') + ->using(fn (string $city): string => 'The weather will be 45° and cold'), + (new Tool) + ->as('search_games') + ->for('useful for searching current games times in the city') + ->withStringParameter('city', 'The city that you want the game times for') + ->using(fn (string $city): string => 'The tigers game is at 3pm in detroit'), + ]; + + $response = Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withTools($tools) + ->withMaxSteps(5) + ->withPrompt('What time is the tigers game today in Detroit and should I wear a coat? please check all the details from tools') + ->asText(); + + // Assert tool calls in the first step + $firstStep = $response->steps[0]; + expect($firstStep->toolCalls)->toHaveCount(2); + expect($firstStep->toolCalls[0]->name)->toBe('search_games'); + expect($firstStep->toolCalls[0]->arguments())->toBe([ + 'city' => 'Detroit', + ]); + expect($firstStep->toolCalls[1]->name)->toBe('get_weather'); + expect($firstStep->toolCalls[1]->arguments())->toBe([ + 'city' => 'Detroit', + ]); + + // Verify the assistant message from step 1 is present in step 2's input messages + $secondStep = $response->steps[1]; + expect($secondStep->messages)->toHaveCount(3); + expect($secondStep->messages[0])->toBeInstanceOf(UserMessage::class); + expect($secondStep->messages[1])->toBeInstanceOf(AssistantMessage::class); + expect($secondStep->messages[1]->toolCalls)->toHaveCount(2); + expect($secondStep->messages[1]->toolCalls[0]->name)->toBe('search_games'); + expect($secondStep->messages[1]->toolCalls[1]->name)->toBe('get_weather'); + expect($secondStep->messages[2])->toBeInstanceOf(ToolResultMessage::class); + + // Assert usage (combined from both responses) + expect($response->usage->promptTokens)->toBe(350) + ->and($response->usage->completionTokens)->toBe(42); + + // Assert response + expect($response->meta->id)->toBe('') + ->and($response->meta->model)->toBe('gemini-1.5-flash') + ->and($response->text)->toBe('The tigers game is at 3pm today in Detroit. The weather will be 45° and cold, so you should wear a coat.'); + }); +}); + +describe('Image support with Vertex', function (): void { + it('can send images from path', function (): void { + FixtureResponse::fakeResponseSequence('*', 'vertex/image-detection'); + + $response = Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withMessages([ + new UserMessage( + 'What is this image', + additionalContent: [ + Image::fromLocalPath('tests/Fixtures/diamond.png'), + ], + ), + ]) + ->asText(); + + // Assert response + expect($response->text)->toBe("That's an illustration of a **diamond**. More specifically, it's a stylized, geometric representation of a diamond, often used as an icon or symbol") + ->and($response->usage->promptTokens)->toBe(263) + ->and($response->usage->completionTokens)->toBe(35) + ->and($response->meta->id)->toBe('') + ->and($response->meta->model)->toBe('gemini-1.5-flash') + ->and($response->finishReason)->toBe(FinishReason::Stop); + + // Assert request format + Http::assertSent(function (Request $request): bool { + $message = $request->data()['contents'][0]['parts']; + + expect($message[0])->toBe([ + 'text' => 'What is this image', + ]); + + expect($message[1]['inline_data'])->toHaveKeys(['mime_type', 'data']); + expect($message[1]['inline_data']['mime_type'])->toBe('image/png'); + expect($message[1]['inline_data']['data'])->toBe( + base64_encode(file_get_contents('tests/Fixtures/diamond.png')) + ); + + return true; + }); + }); + + it('can send images from base64', function (): void { + FixtureResponse::fakeResponseSequence('*', 'vertex/image-detection'); + + $response = Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withMessages([ + new UserMessage( + 'What is this image', + additionalContent: [ + Image::fromBase64( + base64_encode(file_get_contents('tests/Fixtures/diamond.png')), + 'image/png' + ), + ], + ), + ]) + ->asText(); + + Http::assertSent(function (Request $request): bool { + $message = $request->data()['contents'][0]['parts']; + + expect($message[0])->toBe([ + 'text' => 'What is this image', + ]); + + expect($message[1]['inline_data'])->toHaveKeys(['mime_type', 'data']); + expect($message[1]['inline_data']['mime_type'])->toBe('image/png'); + expect($message[1]['inline_data']['data'])->toBe( + base64_encode(file_get_contents('tests/Fixtures/diamond.png')) + ); + + return true; + }); + }); +}); + +describe('Request format for Vertex', function (): void { + it('sends requests to the correct Vertex AI endpoint', function (): void { + FixtureResponse::fakeResponseSequence('*', 'vertex/generate-text-with-a-prompt'); + + Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withPrompt('Hello') + ->asText(); + + Http::assertSent(function (Request $request): bool { + expect($request->url())->toContain('us-central1-aiplatform.googleapis.com') + ->and($request->url())->toContain('projects/test-project') + ->and($request->url())->toContain('locations/us-central1') + ->and($request->url())->toContain('publishers/google/models') + ->and($request->url())->toContain('gemini-1.5-flash:generateContent') + ->and($request->hasHeader('Authorization'))->toBeTrue() + ->and($request->header('Authorization')[0])->toBe('Bearer test-access-token'); + + return true; + }); + }); + + it('uses global hostname when region is global', function (): void { + config()->set('prism.providers.vertex.region', 'global'); + FixtureResponse::fakeResponseSequence('*', 'vertex/generate-text-with-a-prompt'); + + Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withPrompt('Hello') + ->asText(); + + Http::assertSent(function (Request $request): bool { + expect($request->url())->toContain('https://aiplatform.googleapis.com') + ->and($request->url())->not->toContain('global-aiplatform.googleapis.com') + ->and($request->url())->toContain('projects/test-project') + ->and($request->url())->toContain('locations/global') + ->and($request->url())->toContain('publishers/google/models'); + + return true; + }); + }); + + it('includes generation config in request', function (): void { + FixtureResponse::fakeResponseSequence('*', 'vertex/generate-text-with-a-prompt'); + + Prism::text() + ->using(Provider::Vertex, 'gemini-1.5-flash') + ->withPrompt('Hello') + ->withMaxTokens(100) + ->usingTemperature(0.7) + ->usingTopP(0.9) + ->asText(); + + Http::assertSent(function (Request $request): bool { + $data = $request->data(); + + expect($data['generationConfig'])->toHaveKey('maxOutputTokens', 100) + ->and($data['generationConfig'])->toHaveKey('temperature', 0.7) + ->and($data['generationConfig'])->toHaveKey('topP', 0.9); + + return true; + }); + }); +});