diff --git a/dataflow/serving/api_google_vertexai_serving.py b/dataflow/serving/api_google_vertexai_serving.py
index a5620fd3..eec77d93 100644
--- a/dataflow/serving/api_google_vertexai_serving.py
+++ b/dataflow/serving/api_google_vertexai_serving.py
@@ -312,7 +312,7 @@ def generate_from_input(
         self,
         user_inputs: List[str],
         system_prompt: str = "",
-        response_schema: Optional[Union[type[BaseModel], dict]] = None,
+        json_schema: Optional[Union[type[BaseModel], dict]] = None,
         use_function_call: Optional[bool] = None,
         use_batch: Optional[bool] = None,
         batch_wait: Optional[bool] = None,
@@ -326,7 +326,7 @@ def generate_from_input(
         Args:
             user_inputs: List of user input strings to process.
             system_prompt: System prompt for the model.
-            response_schema: Optional Pydantic BaseModel or dict for structured output.
+            json_schema: Optional Pydantic BaseModel or dict for structured output.
             use_batch: If True, use batch processing via BigQuery. If False, use parallel real-time generation.
             batch_wait: If True (and use_batch=True), wait for batch job to complete and return results. If False, return the batch job name immediately for later retrieval.
@@ -357,7 +357,7 @@ def generate_from_input(
             return self._generate_with_batch(
                 user_inputs=user_inputs,
                 system_prompt=system_prompt,
-                response_schema=response_schema,
+                response_schema=json_schema,
                 use_function_call=use_function_call,
                 wait_for_completion=batch_wait,
                 dataset_name=batch_dataset,
@@ -368,7 +368,7 @@ def generate_from_input(
             return self._generate_with_parallel(
                 user_inputs=user_inputs,
                 system_prompt=system_prompt,
-                response_schema=response_schema,
+                response_schema=json_schema,
                 use_function_call=use_function_call,
             )
@@ -947,7 +947,7 @@ class UserDetails(BaseModel):
         "John Doe is 30 years old and lives in New York.",
         "My name is Jane Smith, I am 25, and I reside in London."
     ]
-    results_json = gemini_server_json.generate_from_input(user_prompts_json, system_prompt_json, response_schema=UserDetails) # Pass the schema here
+    results_json = gemini_server_json.generate_from_input(user_prompts_json, system_prompt_json, json_schema=UserDetails) # Pass the schema here
     print("--- Generation Complete ---")
     for i, (prompt, result) in enumerate(zip(user_prompts_json, results_json)):
         print(f"\n[Prompt {i+1}]: {prompt}")
@@ -975,7 +975,7 @@ class UserDetails(BaseModel):
         "Alice Johnson is 28 years old and lives in San Francisco.",
         "Bob Brown, aged 35, resides in Toronto."
     ]
-    results_json_schema = gemini_server_json_schema.generate_from_input(user_prompts_json_schema, system_prompt_json_schema, response_schema=json_schema)
+    results_json_schema = gemini_server_json_schema.generate_from_input(user_prompts_json_schema, system_prompt_json_schema, json_schema=json_schema)
     print("--- Generation Complete ---")
     for i, (prompt, result) in enumerate(zip(user_prompts_json_schema, results_json_schema)):
         print(f"\n[Prompt {i+1}]: {prompt}")
@@ -1007,7 +1007,7 @@ class Capital(BaseModel):
     batch_job_name = gemini_server_batch.generate_from_input(
         user_inputs=user_prompts_batch,
         system_prompt=system_prompt_batch,
-        response_schema=Capital,
+        json_schema=Capital,
         use_batch=True,
         batch_wait=False # Don't wait for completion
     )
@@ -1027,7 +1027,7 @@ class Capital(BaseModel):
     results_batch = gemini_server_batch.generate_from_input(
         user_inputs=user_prompts_batch,
         system_prompt=system_prompt_batch,
-        response_schema=Capital,
+        json_schema=Capital,
        use_batch=True,
        batch_wait=True # Wait for completion
     )
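
Migration note: the public keyword argument of generate_from_input changes from response_schema to json_schema; as the hunks at lines 357 and 368 show, it is still forwarded internally as response_schema to _generate_with_batch and _generate_with_parallel. A minimal before/after sketch for callers follows (the gemini_server_json instance is assumed to be constructed as in the __main__ demo of this file, and the UserDetails fields below are illustrative, not necessarily the demo's exact schema):

    from pydantic import BaseModel

    class UserDetails(BaseModel):  # illustrative fields
        name: str
        age: int
        city: str

    # Before this change:
    # results = gemini_server_json.generate_from_input(prompts, system_prompt, response_schema=UserDetails)

    # After this change, the structured-output schema is passed as json_schema:
    results = gemini_server_json.generate_from_input(
        user_inputs=["John Doe is 30 years old and lives in New York."],
        system_prompt="Extract the user's details as JSON.",
        json_schema=UserDetails,  # a Pydantic BaseModel subclass or a plain dict JSON schema
    )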