16 changes: 8 additions & 8 deletions dataflow/serving/api_google_vertexai_serving.py
@@ -312,7 +312,7 @@ def generate_from_input(
self,
user_inputs: List[str],
system_prompt: str = "",
- response_schema: Optional[Union[type[BaseModel], dict]] = None,
+ json_schema: Optional[Union[type[BaseModel], dict]] = None,
use_function_call: Optional[bool] = None,
use_batch: Optional[bool] = None,
batch_wait: Optional[bool] = None,
@@ -326,7 +326,7 @@ def generate_from_input(
Args:
user_inputs: List of user input strings to process.
system_prompt: System prompt for the model.
- response_schema: Optional Pydantic BaseModel or dict for structured output.
+ json_schema: Optional Pydantic BaseModel or dict for structured output.
use_batch: If True, use batch processing via BigQuery. If False, use parallel real-time generation.
batch_wait: If True (and use_batch=True), wait for batch job to complete and return results.
If False, return the batch job name immediately for later retrieval.
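To make the docstring's parameter interplay concrete for review, here is a minimal sketch of the three dispatch outcomes. `gemini_server`, `prompts`, and `CitySchema` are placeholders for an already configured instance of this serving class and caller-side data, not names taken from this PR:

```python
from pydantic import BaseModel

class CitySchema(BaseModel):
    name: str
    country: str

prompts = ["Tell me about Paris.", "Tell me about Tokyo."]

# Real-time parallel generation: returns one response per prompt.
results = gemini_server.generate_from_input(prompts, "Extract the city.", json_schema=CitySchema, use_batch=False)

# Batch via BigQuery, blocking until the job completes: returns the results.
results = gemini_server.generate_from_input(prompts, "Extract the city.", json_schema=CitySchema, use_batch=True, batch_wait=True)

# Batch via BigQuery, fire-and-forget: returns the batch job name for later retrieval.
job_name = gemini_server.generate_from_input(prompts, "Extract the city.", json_schema=CitySchema, use_batch=True, batch_wait=False)
```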
@@ -357,7 +357,7 @@ def generate_from_input(
return self._generate_with_batch(
user_inputs=user_inputs,
system_prompt=system_prompt,
- response_schema=response_schema,
+ response_schema=json_schema,
use_function_call=use_function_call,
wait_for_completion=batch_wait,
dataset_name=batch_dataset,
@@ -368,7 +368,7 @@ def generate_from_input(
return self._generate_with_parallel(
user_inputs=user_inputs,
system_prompt=system_prompt,
- response_schema=response_schema,
+ response_schema=json_schema,
use_function_call=use_function_call,
)
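Only the public keyword changes here; the internal `_generate_with_batch` and `_generate_with_parallel` helpers still take `response_schema`, which the public method now forwards from `json_schema`. Existing callers therefore need a one-keyword update. A hypothetical before/after (caller variables are placeholders; assuming the full signature has no `**kwargs` catch-all, the old keyword now raises a `TypeError`):

```python
# Before this PR (would now raise TypeError: unexpected keyword argument 'response_schema'):
# results = gemini_server.generate_from_input(prompts, system_prompt, response_schema=UserDetails)

# After this PR:
results = gemini_server.generate_from_input(prompts, system_prompt, json_schema=UserDetails)
```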

@@ -947,7 +947,7 @@ class UserDetails(BaseModel):
"John Doe is 30 years old and lives in New York.",
"My name is Jane Smith, I am 25, and I reside in London."
]
- results_json = gemini_server_json.generate_from_input(user_prompts_json, system_prompt_json, response_schema=UserDetails) # Pass the schema here
+ results_json = gemini_server_json.generate_from_input(user_prompts_json, system_prompt_json, json_schema=UserDetails) # Pass the schema here
print("--- Generation Complete ---")
for i, (prompt, result) in enumerate(zip(user_prompts_json, results_json)):
print(f"\n[Prompt {i+1}]: {prompt}")
@@ -975,7 +975,7 @@ class UserDetails(BaseModel):
"Alice Johnson is 28 years old and lives in San Francisco.",
"Bob Brown, aged 35, resides in Toronto."
]
- results_json_schema = gemini_server_json_schema.generate_from_input(user_prompts_json_schema, system_prompt_json_schema, response_schema=json_schema)
+ results_json_schema = gemini_server_json_schema.generate_from_input(user_prompts_json_schema, system_prompt_json_schema, json_schema=json_schema)
print("--- Generation Complete ---")
for i, (prompt, result) in enumerate(zip(user_prompts_json_schema, results_json_schema)):
print(f"\n[Prompt {i+1}]: {prompt}")
@@ -1007,7 +1007,7 @@ class Capital(BaseModel):
batch_job_name = gemini_server_batch.generate_from_input(
user_inputs=user_prompts_batch,
system_prompt=system_prompt_batch,
- response_schema=Capital,
+ json_schema=Capital,
use_batch=True,
batch_wait=False # Don't wait for completion
)
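With `batch_wait=False` the call above returns only the batch job name; the results are fetched later through whatever retrieval helper this class exposes. The method name below is hypothetical and only illustrates the fire-and-forget flow:

```python
# Hypothetical retrieval call; check the class for the actual method name.
results_batch = gemini_server_batch.retrieve_batch_results(batch_job_name)
```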
@@ -1027,7 +1027,7 @@ class Capital(BaseModel):
results_batch = gemini_server_batch.generate_from_input(
user_inputs=user_prompts_batch,
system_prompt=system_prompt_batch,
- response_schema=Capital,
+ json_schema=Capital,
use_batch=True,
batch_wait=True # Wait for completion
)