-
Notifications
You must be signed in to change notification settings - Fork 0
Add model management to RAGFlow HTTP API #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add model management to RAGFlow HTTP API #6
Conversation
…d support for individual models and added tests for the add_model API.
… required fields across various factories. Added extensive test cases to ensure proper error handling for missing parameters.
| - Factory-level: requires llm_factory and api_key (or special auth fields) | ||
| - Individual model: requires llm_factory, llm_name, model_type, and api_base (for local models) or api_key | ||
| """ | ||
| from common.constants import LLMType |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
only use module level imports
| raise PydanticCustomError("field_required", "model_type is required when adding an individual model") | ||
|
|
||
| # Validate model_type | ||
| valid_types = [LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK, LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS, LLMType.OCR] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
List is unnecessary. To check if model type is valid use self.model_type in LLMType. To get list of LLMTypes use [t.value for t in LLMType]
| raise PydanticCustomError("field_required", "spark_app_id, spark_api_secret, and spark_api_key are required for XunFei Spark TTS models") | ||
|
|
||
| # Factory-level addition mode | ||
| else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When adding all models of a factory only check factories for which this is actually supported. For all other's check that the API key is provided
|
|
||
| # For local models, api_base is typically required | ||
| if is_local and not self.api_base and not self.api_key: | ||
| raise PydanticCustomError("field_required", "api_base is required for local/self-hosted models when api_key is not provided") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For local models the base URL should always be required even if an API key is provided
| if is_local and not self.api_base and not self.api_key: | ||
| raise PydanticCustomError("field_required", "api_base is required for local/self-hosted models when api_key is not provided") | ||
|
|
||
| # For individual model mode, validate factory-specific required parameters |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use match-case construct
| # For individual model mode, validate factory-specific required parameters | ||
| if self.llm_factory == "VolcEngine": | ||
| if not self.ark_api_key or not self.endpoint_id: | ||
| raise PydanticCustomError("field_required", "ark_api_key and endpoint_id are required for VolcEngine individual model addition") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- create helper function for these checks to reduce code duplication
- should also check that the fields don't just contain whitespace
| ) | ||
| if not self.api_key and not has_special_auth: | ||
| raise PydanticCustomError("field_required", "api_key or appropriate authentication fields are required for factory-level addition") | ||
| # For local factories, empty api_key is allowed in factory-level mode (validation skipped above) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment is misleading. Local models aren't supported for factory-level adding
|
|
||
| @manager.route("/models", methods=["POST"]) # noqa: F821 | ||
| @token_required | ||
| async def add_model(tenant_id: str) -> Response: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is full of duplicate validations that is already handled by the Pydantic request validator. Please rework it to remove the redundant code.
|
|
||
| @pytest.fixture(scope="class") | ||
| def cleanup_added_models(request: FixtureRequest, HttpApiAuth): | ||
| """Fixture to clean up models added during tests (for TestAddModel class)""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is used for all tests not just add_model
| def cleanup_added_models(request: FixtureRequest, HttpApiAuth): | ||
| """Fixture to clean up models added during tests (for TestAddModel class)""" | ||
| # Track factories that might be added during tests | ||
| factories_to_cleanup: List[str] = ["Builtin", "LocalAI", "Ollama", "Xinference", "LM-Studio", "GPUStack", "FastEmbed"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is not an accurate list of the factories that are added in the tests
| return session_ids | ||
|
|
||
|
|
||
| # USER MODELS MANAGEMENT |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The headers and data parameters of the functions here are used. Remove them.
| @manager.route("/models", methods=["POST"]) # noqa: F821 | ||
| @token_required | ||
| async def add_model(tenant_id: str) -> Response: | ||
| """ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The information in this docstring is wrong. Use the same docstring format as the other HTTP API endpoints. Additional details should be in the RAGFlow documentation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are a lot of problems with the tests in this module, so I'll only focus on the parameter validation tests for now:
- when testing a string parameter that needs to be present test 4 case: not providing the parameter at all, passing
None, an empty string, and a string with only whitespace - make sure to construct your requests so that the reason they fail is related to what you are trying to test and not e.g. some unrelated missing parameter
- The tests that test factories with special parameters are disorganized and partially redundant. Organize them into one
TestSpecialParameterValidationclass containing the following parameterized tests:- for each factory that has special parameters and can be added in "factory-mode" test the following when adding in factory mode (no
llm_nameormodel_type):- providing all required parameters passes the validation (but request is still expected to fail when testing if model can be accessed)
- for each of the special parameter test not providing it
- for each special parameter test that providing it which a factory where it is not expected fails
- for all factories that have special parameters add the same tests as described above also for adding individual models
- for each factory without special parameters test that validation passes for a minimal valid config with any special parameter when adding in both factory-mode (for factories that support it) and individual mode (for all factories)
- for each factory that has special parameters and can be added in "factory-mode" test the following when adding in factory mode (no
| """Validation model for adding models (factory-level or individual model).""" | ||
|
|
||
| llm_factory: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1), Field(...)] | ||
| api_key: Annotated[str | dict[str, Any] | None, Field(default=None)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why can this be a dict?
| class ListDatasetReq(BaseListReq): ... | ||
|
|
||
|
|
||
| class AddModelReq(Base): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
strip whitespace for all string fields
| assert "model_type is required when adding an individual model" in res["message"], res | ||
|
|
||
| @pytest.mark.p1 | ||
| def test_individual_model_invalid_model_type(self, HttpApiAuth: RAGFlowHttpApiAuth) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also test the success case for all of valid model types that should work
|
|
||
|
|
||
| @pytest.mark.usefixtures("cleanup_added_models") | ||
| class TestAddModelParameterValidation: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also test:
llm_namemaximum lengthapi_keyrequired for adding all modelsbase_urlrequired for model without default base URLmax_tokensrange of valid values
No description provided.