diff --git a/.gitignore b/.gitignore index 7fa2022..e81a1db 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ .idea venv .venv -*.db \ No newline at end of file +*.db +src/inputs/*.pdf \ No newline at end of file diff --git a/api/db/repositories.py b/api/db/repositories.py index 6608718..7686510 100644 --- a/api/db/repositories.py +++ b/api/db/repositories.py @@ -16,4 +16,10 @@ def create_form(session: Session, form: FormSubmission) -> FormSubmission: session.add(form) session.commit() session.refresh(form) - return form \ No newline at end of file + return form + +def get_all_templates(session: Session, limit: int = 100, offset: int = 0) -> list[Template]: + return session.exec(select(Template).offset(offset).limit(limit)).all() + +def get_form(session: Session, submission_id: int) -> FormSubmission | None: + return session.get(FormSubmission, submission_id) \ No newline at end of file diff --git a/api/main.py b/api/main.py index d0b8c79..0a7d8e7 100644 --- a/api/main.py +++ b/api/main.py @@ -1,7 +1,25 @@ -from fastapi import FastAPI +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse from api.routes import templates, forms +from api.errors.base import AppError +from typing import Union app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], +) + +@app.exception_handler(AppError) +def app_error_handler(request: Request, exc: AppError): + return JSONResponse( + status_code=exc.status_code, + content={"detail": exc.message} + ) + app.include_router(templates.router) app.include_router(forms.router) \ No newline at end of file diff --git a/api/routes/forms.py b/api/routes/forms.py index f3430ed..3491d4e 100644 --- a/api/routes/forms.py +++ b/api/routes/forms.py @@ -1,25 +1,82 @@ +import os from fastapi import APIRouter, Depends +from fastapi.responses import FileResponse from sqlmodel import Session from api.deps import get_db from api.schemas.forms import FormFill, FormFillResponse -from api.db.repositories import create_form, get_template +from api.db.repositories import create_form, get_template, get_form from api.db.models import FormSubmission from api.errors.base import AppError from src.controller import Controller router = APIRouter(prefix="/forms", tags=["forms"]) + @router.post("/fill", response_model=FormFillResponse) def fill_form(form: FormFill, db: Session = Depends(get_db)): - if not get_template(db, form.template_id): + # Single DB query (fixes issue #149 - redundant query) + template = get_template(db, form.template_id) + if not template: raise AppError("Template not found", status_code=404) - fetched_template = get_template(db, form.template_id) + try: + controller = Controller() + # FileManipulator.fill_form expects fields as a list of key strings + fields_list = list(template.fields.keys()) if isinstance(template.fields, dict) else template.fields + path = controller.fill_form( + user_input=form.input_text, + fields=fields_list, + pdf_form_path=template.pdf_path + ) + except ConnectionError: + raise AppError( + "Could not connect to Ollama. Make sure ollama serve is running.", + status_code=503 + ) + except Exception as e: + raise AppError(f"PDF filling failed: {str(e)}", status_code=500) + + # Guard: controller returned None instead of a file path + if not path: + raise AppError( + "PDF generation failed — no output file was produced. " + "Check that the PDF template is a valid fillable form and Ollama is running.", + status_code=500 + ) - controller = Controller() - path = controller.fill_form(user_input=form.input_text, fields=fetched_template.fields, pdf_form_path=fetched_template.pdf_path) + if not os.path.exists(path): + raise AppError( + f"PDF was generated but file not found at: {path}", + status_code=500 + ) - submission = FormSubmission(**form.model_dump(), output_pdf_path=path) + submission = FormSubmission( + **form.model_dump(), + output_pdf_path=path + ) return create_form(db, submission) +@router.get("/{submission_id}", response_model=FormFillResponse) +def get_submission(submission_id: int, db: Session = Depends(get_db)): + submission = get_form(db, submission_id) + if not submission: + raise AppError("Submission not found", status_code=404) + return submission + + +@router.get("/download/{submission_id}") +def download_filled_pdf(submission_id: int, db: Session = Depends(get_db)): + submission = get_form(db, submission_id) + if not submission: + raise AppError("Submission not found", status_code=404) + + file_path = submission.output_pdf_path + if not os.path.exists(file_path): + raise AppError("PDF file not found on server", status_code=404) + + return FileResponse( + path=file_path, + media_type="application/pdf", + filename=os.path.basename(file_path) + ) \ No newline at end of file diff --git a/api/routes/templates.py b/api/routes/templates.py index 5c2281b..9419ae6 100644 --- a/api/routes/templates.py +++ b/api/routes/templates.py @@ -1,16 +1,89 @@ -from fastapi import APIRouter, Depends +import os +import shutil +import uuid +from fastapi import APIRouter, Depends, UploadFile, File, Form from sqlmodel import Session from api.deps import get_db -from api.schemas.templates import TemplateCreate, TemplateResponse -from api.db.repositories import create_template +from api.schemas.templates import TemplateResponse +from api.db.repositories import create_template, get_all_templates from api.db.models import Template -from src.controller import Controller +from api.errors.base import AppError router = APIRouter(prefix="/templates", tags=["templates"]) +# Save directly into src/inputs/ — stable location, won't get wiped +TEMPLATES_DIR = os.path.join("src", "inputs") +os.makedirs(TEMPLATES_DIR, exist_ok=True) + + @router.post("/create", response_model=TemplateResponse) -def create(template: TemplateCreate, db: Session = Depends(get_db)): - controller = Controller() - template_path = controller.create_template(template.pdf_path) - tpl = Template(**template.model_dump(exclude={"pdf_path"}), pdf_path=template_path) - return create_template(db, tpl) \ No newline at end of file +async def create( + name: str = Form(...), + file: UploadFile = File(...), + db: Session = Depends(get_db) +): + # Validate PDF + if not file.filename.endswith(".pdf"): + raise AppError("Only PDF files are allowed", status_code=400) + + # Save uploaded file with unique name into src/inputs/ + unique_name = f"{uuid.uuid4().hex}_{file.filename}" + save_path = os.path.join(TEMPLATES_DIR, unique_name) + + with open(save_path, "wb") as f: + shutil.copyfileobj(file.file, f) + + # Extract fields using commonforms + pypdf + # Store as simple list of field name strings — what Filler expects + try: + from commonforms import prepare_form + from pypdf import PdfReader + + # Read real field names directly from original PDF + # Use /T (internal name) as both key and label + # Real names like "JobTitle", "Phone Number" are already human-readable + reader = PdfReader(save_path) + raw_fields = reader.get_fields() or {} + + fields = {} + for internal_name, field_data in raw_fields.items(): + # Use /TU tooltip if available, otherwise prettify /T name + label = None + if isinstance(field_data, dict): + label = field_data.get("/TU") + if not label: + # Prettify: "JobTitle" → "Job Title", "DATE7_af_date" → "Date" + import re + label = re.sub(r'([a-z])([A-Z])', r'\1 \2', internal_name) + label = re.sub(r'_af_.*$', '', label) # strip "_af_date" suffix + label = label.replace('_', ' ').strip().title() + fields[internal_name] = label + + except Exception as e: + print(f"Field extraction failed: {e}") + fields = [] + + # Save to DB + tpl = Template(name=name, pdf_path=save_path, fields=fields) + return create_template(db, tpl) + + +@router.get("", response_model=list[TemplateResponse]) +def list_templates( + limit: int = 100, + offset: int = 0, + db: Session = Depends(get_db) +): + return get_all_templates(db, limit=limit, offset=offset) + + +@router.get("/{template_id}", response_model=TemplateResponse) +def get_template_by_id( + template_id: int, + db: Session = Depends(get_db) +): + from api.db.repositories import get_template + tpl = get_template(db, template_id) + if not tpl: + raise AppError("Template not found", status_code=404) + return tpl \ No newline at end of file diff --git a/docs/SETUP.md b/docs/SETUP.md new file mode 100644 index 0000000..cf47642 --- /dev/null +++ b/docs/SETUP.md @@ -0,0 +1,228 @@ +# 🔥 FireForm — Setup & Usage Guide + +This guide covers how to install, run, and use FireForm locally on Windows, Linux, and macOS. + +--- + +## 📋 Prerequisites + +| Tool | Version | Purpose | +|------|---------|---------| +| Python | 3.11+ | Backend runtime | +| Ollama | 0.17.7+ | Local LLM server | +| Mistral 7B | latest | AI extraction model | +| Git | any | Clone the repository | + +--- + +## 🪟 Windows + +### 1. Clone the repository +```cmd +git clone https://github.com/fireform-core/FireForm.git +cd FireForm +``` + +### 2. Create and activate virtual environment +```cmd +python -m venv venv +venv\Scripts\activate +``` + +### 3. Install dependencies +```cmd +pip install -r requirements.txt +``` + +### 4. Install and start Ollama +Download Ollama from https://ollama.com/download/windows + +Then pull the Mistral model: +```cmd +ollama pull mistral +ollama serve +``` + +> Ollama runs on `http://localhost:11434` by default. Keep this terminal open. + +### 5. Initialize the database +```cmd +python -m api.db.init_db +``` + +### 6. Start the API server +```cmd +uvicorn api.main:app --reload +``` + +API is now running at `http://127.0.0.1:8000` + +### 7. Start the frontend +Open a new terminal: +```cmd +cd frontend +python -m http.server 3000 +``` + +Open `http://localhost:3000` in your browser. + +--- + + +## 🍎 macOS + +### 1. Clone and enter the repository +```bash +git clone https://github.com/fireform-core/FireForm.git +cd FireForm +``` + +### 2. Create and activate virtual environment +```bash +python3 -m venv venv +source venv/bin/activate +``` + +### 3. Install dependencies +```bash +pip install -r requirements.txt +``` + +### 4. Install and start Ollama +Download from https://ollama.com/download/mac or: +```bash +brew install ollama +ollama pull mistral +ollama serve & +``` + +### 5. Initialize the database +```bash +python -m api.db.init_db +``` + +### 6. Start the API server +```bash +uvicorn api.main:app --reload +``` + +### 7. Start the frontend +```bash +cd frontend +python3 -m http.server 3000 +``` + +--- + +## 🖥️ Using the Frontend + +Once everything is running, open `http://localhost:3000` in your browser. + +### Step 1 — Upload a PDF template +- Click **"Choose File"** and select any fillable PDF form +- Enter a name for the template +- Click **"Upload Template"** + +FireForm will automatically extract all form field names and their human-readable labels. + +### Step 2 — Fill the form +- Select your uploaded template from the dropdown +- In the text box, describe the incident or enter the information in natural language: + +``` +Employee name is John Smith. Employee ID is EMP-2024-789. +Job title is Firefighter Paramedic. Location is Station 12 Sacramento. +Department is Emergency Medical Services. Supervisor is Captain Rodriguez. +Phone number is 916-555-0147. +``` + +- Click **"Fill Form"** + +FireForm sends one request to Ollama (Mistral) which extracts all fields at once and returns structured JSON. + +### Step 3 — Download the filled PDF +- Click **"Download PDF"** to save the completed form + +--- + +## 🤖 How AI Extraction Works + +FireForm uses a **batch extraction** approach: + +``` +Traditional approach (slow): FireForm approach (fast): + Field 1 → Ollama call All fields → 1 Ollama call + Field 2 → Ollama call Mistral returns JSON with all values + Field 3 → Ollama call Parse → fill PDF + ...N calls total 1 call total (O(1)) +``` + +Field names are automatically read from the PDF's annotations and converted to human-readable labels before being sent to Mistral — so the model understands what each field means regardless of internal PDF naming conventions like `textbox_0_0`. + +**Example extraction:** +```json +{ + "NAME/SID": "John Smith", + "JobTitle": "Firefighter Paramedic", + "Department": "Emergency Medical Services", + "Phone Number": "916-555-0147", + "email": null +} +``` + +--- + +## 🧪 Running Tests + +```bash +python -m pytest tests/ -v +``` + +Expected output: **52 passed** + +See [TESTING.md](TESTING.md) for full test coverage details. + +--- + +## 🔧 Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `OLLAMA_HOST` | `http://localhost:11434` | Ollama server URL | + +To use a remote Ollama instance: +```bash +export OLLAMA_HOST=http://your-server:11434 # Linux/Mac +set OLLAMA_HOST=http://your-server:11434 # Windows +``` + +--- + +## 🐳 Docker (Coming Soon) + +Docker support is in progress. See [docker.md](docker.md) for current status. + +--- + +## ❓ Troubleshooting + +**`Form data requires python-multipart`** +```bash +pip install python-multipart +``` + +**`ModuleNotFoundError: No module named 'pypdf'`** +```bash +pip install pypdf +``` + +**`Could not connect to Ollama`** +- Make sure `ollama serve` is running +- Check Ollama is on port 11434: `curl http://localhost:11434` + +**`NameError: name 'Union' is not defined`** +- Pull latest changes: `git pull origin main` +- This bug is fixed in the current version + +**Tests fail with `ModuleNotFoundError: No module named 'api'`** +- Use `python -m pytest` instead of `pytest` \ No newline at end of file diff --git a/docs/TESTING.md b/docs/TESTING.md new file mode 100644 index 0000000..386763d --- /dev/null +++ b/docs/TESTING.md @@ -0,0 +1,64 @@ +# 🧪 Testing + +This document describes how to run the FireForm test suite locally. + +## Prerequisites + +Make sure you have installed all dependencies: + +```bash +pip install -r requirements.txt +``` + +## Running Tests + +From the project root directory: + +```bash +python -m pytest tests/ -v +``` + +> **Note:** Use `python -m pytest` instead of `pytest` directly to ensure the project root is on the Python path. + +## Test Coverage + +| File | Tests | What it covers | +|------|-------|----------------| +| `tests/test_llm.py` | 40 | LLM class — batch prompt, field extraction, plural handling, schema validation | +| `tests/test_templates.py` | 10 | `POST /templates/create`, `GET /templates`, `GET /templates/{id}` | +| `tests/test_forms.py` | 7 | `POST /forms/fill`, `GET /forms/{id}`, `GET /forms/download/{id}` | + +**Total: 57 tests** + +## Test Design + +- All tests use an **in-memory SQLite database** — your local `fireform.db` is never touched +- Each test gets a **fresh empty database** — no data leaks between tests +- Ollama is **never called** during tests — all LLM calls are mocked + +## Key Test Cases + +**LLM extraction (`test_llm.py`)** +- Batch prompt contains all field keys and human-readable labels +- `main_loop()` makes exactly **1 Ollama call** regardless of field count (O(1) assertion) +- Graceful fallback when Mistral returns invalid JSON +- `-1` responses stored as `None`, not as the string `"-1"` + +**Template endpoints (`test_templates.py`)** +- Valid PDF upload returns 200 with field data +- Non-PDF upload returns 400 +- Missing file returns 422 +- Non-existent template returns 404 + +**Form endpoints (`test_forms.py`)** +- Non-existent template returns 404 +- Ollama connection failure returns 503 +- Missing filled PDF on disk returns 404 +- Non-existent submission returns 404 + +**Schema validation (`test_llm.py::TestSchemaValidation`)** +- Valid extraction returns no warnings +- Invalid email (missing @) is flagged +- Same value in 3+ fields flagged as hallucination +- None values are skipped (no false positives) +- Warnings accessible via `get_validation_warnings()` \ No newline at end of file diff --git a/docs/demo/filled_form_output.pdf b/docs/demo/filled_form_output.pdf new file mode 100644 index 0000000..6587e43 Binary files /dev/null and b/docs/demo/filled_form_output.pdf differ diff --git a/docs/demo/frontend_ui.png b/docs/demo/frontend_ui.png new file mode 100644 index 0000000..856c696 Binary files /dev/null and b/docs/demo/frontend_ui.png differ diff --git a/docs/demo/frontend_ui02.png b/docs/demo/frontend_ui02.png new file mode 100644 index 0000000..ca84a72 Binary files /dev/null and b/docs/demo/frontend_ui02.png differ diff --git a/docs/frontend.md b/docs/frontend.md new file mode 100644 index 0000000..22d2b55 --- /dev/null +++ b/docs/frontend.md @@ -0,0 +1,218 @@ +# Frontend UI Guide + +This guide explains how to set up and use the FireForm browser-based frontend interface. + +## Overview + +The FireForm frontend is a single-page web application (`frontend/index.html`) that provides a user-friendly interface for non-technical first responders to: + +- Upload and save fillable PDF form templates +- Describe incidents in plain language +- Auto-fill forms using local AI (Mistral via Ollama) +- Download completed PDF forms instantly + +> [!IMPORTANT] +> The frontend communicates with the FastAPI backend at `http://127.0.0.1:8000`. Make sure both Ollama and the API server are running before opening the frontend. + +--- + +## Prerequisites + +Before running the frontend, ensure the following are set up: + +> [!IMPORTANT] +> Complete the database setup described in [db.md](db.md) first. + +1. **Ollama** installed and running — [https://ollama.com/download](https://ollama.com/download) +2. **Mistral model** pulled: + ```bash + ollama pull mistral + ``` +3. **Dependencies** installed: + ```bash + pip install -r requirements.txt + ``` + +--- + +## Running the Frontend + +### Step 1 — Start Ollama + +In a terminal, run: + +```bash +ollama serve +``` + +> [!TIP] +> Leave this terminal open. Ollama must stay running for AI extraction to work. + +### Step 2 — Initialize the Database + +```bash +python -m api.db.init_db +``` + +### Step 3 — Start the API Server + +In a new terminal, from the project root: + +```bash +uvicorn api.main:app --reload +``` + +If successful, you will see: +`INFO: Uvicorn running on http://127.0.0.1:8000` + +### Step 4 — Open the Frontend + +Open `frontend/index.html` directly in your browser by double-clicking it, or navigate to it in your file explorer. + +> [!NOTE] +> No additional server is required for the frontend. It is a static HTML file that communicates directly with the FastAPI backend. + +--- + +## Using the Frontend + +The interface guides you through 4 steps: + +### Step 1 — Upload a Template + +1. Click **"Click to upload"** or drag and drop a fillable PDF form +2. Enter a name for the template (e.g. `Cal Fire Incident Report`) +3. Click **"SAVE TEMPLATE →"** + +The template is saved to the database and will appear in the **Saved Templates** list. + +> [!TIP] +> Any fillable PDF form works. The system automatically detects all form fields. + +### Step 2 — Select a Template + +Click any saved template from the **Saved Templates** list in the sidebar. The selected template will be highlighted in red. + +### Step 3 — Describe the Incident + +Type or paste a plain-language description of the incident in the text area. For best results, include all relevant details that match your form's fields. + +**Example for an employee form:** +``` +The employee's name is John Smith. His employee ID is EMP-2024-789. +His job title is Firefighter Paramedic. His location is Station 12, +Sacramento. His department is Emergency Medical Services. His supervisor +is Captain Jane Rodriguez. His phone number is 916-555-0147. +His email is jsmith@calfire.ca.gov. +``` + +**Example for an incident report form:** +``` +Officer Hernandez responding to a structure fire at 742 Evergreen Terrace. +Two occupants evacuated safely. Minor smoke inhalation treated on scene +by EMS. Unit 7 on scene at 14:32, cleared at 16:45. +Handed off to Deputy Martinez. +``` + +### Step 4 — Fill and Download + +Click **"⚡ FILL FORM"**. The system will: + +1. Send the description to Mistral (running locally via Ollama) +2. Extract all relevant field values +3. Fill the PDF template automatically +4. Provide a **"⬇ Download PDF"** button + +> [!NOTE] +> Processing time depends on your hardware. Typically 10–30 seconds with Mistral on a standard machine. + +--- + +## API Endpoints + +The frontend uses the following API endpoints: + +| Method | Endpoint | Description | +|--------|----------|-------------| +| `POST` | `/templates/create` | Upload a new PDF template | +| `GET` | `/templates` | List all saved templates | +| `GET` | `/templates/{id}` | Get a specific template | +| `POST` | `/forms/fill` | Fill a form with incident text | +| `GET` | `/forms/{id}` | Get a submission record | +| `GET` | `/forms/download/{id}` | Download a filled PDF | + +For full API documentation, visit [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs) while the server is running. + +--- + +## API Status Indicator + +The top-right corner of the frontend shows the API connection status: + +- 🟢 **api online** — Backend is reachable, ready to use +- 🔴 **api offline** — Backend is not running, check uvicorn + +--- + +## Troubleshooting + +### "api offline" shown in the top bar + +The FastAPI server is not running. Start it with: +```bash +uvicorn api.main:app --reload +``` + +### Form fills with null or incorrect values + +This happens when the incident description does not contain information matching the PDF form fields. Ensure your description includes the specific data your form requires (names, dates, locations, etc.). + +See [Issue #113](https://github.com/fireform-core/FireForm/issues/113) for context on matching input to templates. + +### "Could not connect to Ollama" error + +Ollama is not running. Start it with: +```bash +ollama serve +``` + +Then verify Mistral is available: +```bash +ollama list +``` + +If Mistral is not listed, pull it: +```bash +ollama pull mistral +``` + +### Port conflict on 11434 + +Something else is using Ollama's port. On Linux/Mac: +```bash +sudo lsof -i :11434 +``` +On Windows: +```cmd +netstat -ano | findstr :11434 +``` + +--- + +## Privacy + +> [!IMPORTANT] +> FireForm is designed to be fully private. All AI processing happens locally via Ollama. No incident data, form content, or personal information is ever sent to external servers. + +--- + +## Docker Usage + +To run the full stack including the frontend API via Docker: + +```bash +chmod +x container-init.sh +./container-init.sh +``` + +See [docker.md](docker.md) for full Docker setup instructions. diff --git a/frontend/index.html b/frontend/index.html new file mode 100644 index 0000000..a3b0083 --- /dev/null +++ b/frontend/index.html @@ -0,0 +1,467 @@ + + + + + +FireForm — Report Once, File Everywhere + + + + +
+
+ + + +
+ + +
+
+
UN Digital Public Good · GSoC 2026
+

REPORT
ONCE.

+

Describe any incident in plain language. FireForm uses a locally-running AI to extract every relevant detail and auto-fill all required agency forms — instantly and privately.

+
+ +
+
+
1
+
Upload Template
Any fillable PDF form
+
+
+
2
+
Select Template
Choose from saved forms
+
+
+
3
+
Describe Incident
Plain language report
+
+
+
4
+
Download PDF
All fields auto-filled
+
+
+ +
+
← Select a template from the sidebar
+
+ Incident Description * + 0 chars +
+ +
+ +
Runs via Ollama locally.
No data leaves your machine.
+
+
+
+
+
Mistral is extracting data and filling your form...
+
+
+
+
✓ FORM FILLED SUCCESSFULLY
+ ⬇ Download PDF +
+
+
+
+
+
+ +
+
+
Session History
+
0 submissions
+
+
+
No submissions yet this session.
+
+
+
+
+ + + + \ No newline at end of file diff --git a/src/llm.py b/src/llm.py index 70937f9..950b1b6 100644 --- a/src/llm.py +++ b/src/llm.py @@ -1,15 +1,37 @@ import json import os +import re import requests +# ── Field-type patterns for schema validation ───────────────────────────────── +FIELD_PATTERNS = { + "phone": re.compile(r"[\d\s\-\+\(\)\.]{7,20}"), + "email": re.compile(r"[^@\s]+@[^@\s]+\.[^@\s]+"), + "date": re.compile(r"\d{1,2}[\/\-\.]\d{1,2}[\/\-\.]\d{2,4}|\d{4}[\/\-]\d{2}[\/\-]\d{2}"), + "id": re.compile(r"[A-Z0-9\-]{3,}"), +} + +FIELD_TYPE_HINTS = { + "phone": ["phone", "tel", "contact", "number"], + "email": ["email", "mail"], + "date": ["date", "time", "when", "dob"], + "id": ["id", "badge", "sid", "identifier", "emp"], +} + + class LLM: def __init__(self, transcript_text=None, target_fields=None, json=None): + """ + target_fields: dict or list containing the template field names to extract + (dict format: {"field_name": "human_label"}, list format: ["field_name1", "field_name2"]) + """ if json is None: json = {} self._transcript_text = transcript_text # str - self._target_fields = target_fields # List, contains the template field. + self._target_fields = target_fields # dict or list self._json = json # dictionary + self._validation_warnings = [] # list of validation issues found def type_check_all(self): if type(self._transcript_text) is not str: @@ -17,64 +39,270 @@ def type_check_all(self): f"ERROR in LLM() attributes ->\ Transcript must be text. Input:\n\ttranscript_text: {self._transcript_text}" ) - elif type(self._target_fields) is not list: + if not isinstance(self._target_fields, (list, dict)): raise TypeError( f"ERROR in LLM() attributes ->\ - Target fields must be a list. Input:\n\ttarget_fields: {self._target_fields}" + Target fields must be a list or dict. Input:\n\ttarget_fields: {self._target_fields}" + ) + + def validate_extracted_fields(self) -> list: + """ + Schema validation — checks extracted values match expected field types. + + Validates: + - Phone numbers contain digits in expected format + - Emails contain @ and a domain + - Dates match common date patterns + - No field value exceeds 500 chars (hallucination indicator) + - No field is suspiciously repeated across multiple fields + + Returns a list of warning strings. Empty list = all valid. + Never raises — validation issues are warnings, not hard failures. + + Closes Issue #114. + """ + warnings = [] + values_seen = {} # track repeated values across fields + + for field, value in self._json.items(): + if value is None: + continue + + str_value = str(value).strip() + field_lower = field.lower() + + # ── 1. Length check — long values suggest hallucination ────────── + if len(str_value) > 500: + warnings.append( + f"[SCHEMA] '{field}': value suspiciously long " + f"({len(str_value)} chars) — possible hallucination" + ) + + # ── 2. Repeated value check — same value in 3+ fields = hallucination ── + if str_value not in values_seen: + values_seen[str_value] = [] + values_seen[str_value].append(field) + + # ── 3. Field-type pattern validation ───────────────────────────── + detected_type = None + for ftype, hints in FIELD_TYPE_HINTS.items(): + if any(hint in field_lower for hint in hints): + detected_type = ftype + break + + if detected_type and detected_type in FIELD_PATTERNS: + pattern = FIELD_PATTERNS[detected_type] + if not pattern.search(str_value): + warnings.append( + f"[SCHEMA] '{field}': expected {detected_type} format, " + f"got '{str_value}' — may be incorrectly extracted" + ) + + # ── 4. Email-specific check ─────────────────────────────────────── + if "email" in field_lower and value is not None: + if "@" not in str_value: + warnings.append( + f"[SCHEMA] '{field}': value '{str_value}' does not " + f"look like a valid email address" + ) + + # ── 5. Global repeated-value check ─────────────────────────────────── + for val, fields in values_seen.items(): + if len(fields) >= 3: + warnings.append( + f"[SCHEMA] Possible hallucination — value '{val}' " + f"appears in {len(fields)} fields: {fields}" + ) + + self._validation_warnings = warnings + + if warnings: + print("\t[SCHEMA VALIDATION] Issues found:") + for w in warnings: + print(f"\t {w}") + else: + print("\t[SCHEMA VALIDATION] All fields passed validation ✓") + + return warnings + + def get_validation_warnings(self) -> list: + """Return validation warnings from last validate_extracted_fields() call.""" + return self._validation_warnings + + def build_batch_prompt(self) -> str: + """ + Build a single prompt that extracts ALL fields at once. + Sends human-readable labels as context so Mistral understands + what each internal field name means. + Fixes Issue #196 — reduces N Ollama calls to 1. + """ + if isinstance(self._target_fields, dict): + fields_lines = "\n".join( + f' "{k}": null // {v if v and v != k else k}' + for k, v in self._target_fields.items() ) + else: + fields_lines = "\n".join( + f' "{f}": null' + for f in self._target_fields + ) + + prompt = f"""You are filling out an official form. Extract values from the transcript below. + +FORM FIELDS (each line: "internal_key": null // visible label on form): +{{ +{fields_lines} +}} + +RULES: +1. Return ONLY a valid JSON object — no explanation, no markdown, no extra text +2. Use the visible label (after //) to understand what each field means +3. Fill each key with the matching value from the transcript +4. If a value is not found in the transcript, use null +5. Never invent or guess values not present in the transcript +6. For multiple values (e.g. multiple victims), use a semicolon-separated string: "Name1; Name2" +7. Distinguish roles carefully: Officer/Employee is NOT the same as Victim or Suspect - def build_prompt(self, current_field): - """ - This method is in charge of the prompt engineering. It creates a specific prompt for each target field. - @params: current_field -> represents the current element of the json that is being prompted. - """ - prompt = f""" - SYSTEM PROMPT: - You are an AI assistant designed to help fillout json files with information extracted from transcribed voice recordings. - You will receive the transcription, and the name of the JSON field whose value you have to identify in the context. Return - only a single string containing the identified value for the JSON field. - If the field name is plural, and you identify more than one possible value in the text, return both separated by a ";". - If you don't identify the value in the provided text, return "-1". - --- - DATA: - Target JSON field to find in text: {current_field} - - TEXT: {self._transcript_text} - """ +TRANSCRIPT: +{self._transcript_text} + +JSON:""" + + return prompt + + def build_prompt(self, current_field: str) -> str: + """ + Legacy single-field prompt — kept for backward compatibility. + Used as fallback if batch parsing fails. + """ + field_lower = current_field.lower() + is_plural = current_field.endswith('s') and not current_field.lower().endswith('ss') + + if any(w in field_lower for w in ['officer', 'employee', 'dispatcher', 'caller', 'reporting', 'supervisor']): + role_guidance = """ +ROLE: Extract the PRIMARY OFFICER/EMPLOYEE/DISPATCHER +- This is typically the person speaking or reporting the incident +- DO NOT extract victims, witnesses, or members of the public +- Example: "Officer Smith reporting... victims are John and Jane" → extract "Smith" +""" + elif any(w in field_lower for w in ['victim', 'injured', 'affected', 'casualty', 'patient']): + role_guidance = f""" +ROLE: Extract VICTIM/AFFECTED PERSON(S) +- Focus on people who experienced harm +- Ignore officers, dispatchers, and witnesses +{'- Return ALL names separated by ";"' if is_plural else '- Return the FIRST/PRIMARY victim'} +""" + elif any(w in field_lower for w in ['location', 'address', 'street', 'place', 'where']): + role_guidance = """ +ROLE: Extract LOCATION/ADDRESS +- Extract WHERE the incident occurred +- Return only the incident location, not other addresses mentioned +""" + elif any(w in field_lower for w in ['date', 'time', 'when', 'occurred', 'reported']): + role_guidance = """ +ROLE: Extract DATE/TIME +- Extract WHEN the incident occurred +- Return in the format it appears in the text +""" + elif any(w in field_lower for w in ['phone', 'number', 'contact', 'tel']): + role_guidance = "ROLE: Extract PHONE NUMBER — return exactly as it appears in text" + elif any(w in field_lower for w in ['email', 'mail']): + role_guidance = "ROLE: Extract EMAIL ADDRESS" + elif any(w in field_lower for w in ['department', 'unit', 'division']): + role_guidance = "ROLE: Extract DEPARTMENT/UNIT name" + elif any(w in field_lower for w in ['title', 'job', 'role', 'rank', 'position']): + role_guidance = "ROLE: Extract JOB TITLE or RANK" + elif any(w in field_lower for w in ['id', 'badge', 'identifier']): + role_guidance = "ROLE: Extract ID or BADGE NUMBER" + elif any(w in field_lower for w in ['description', 'incident', 'detail', 'nature', 'summary']): + role_guidance = "ROLE: Extract a brief INCIDENT DESCRIPTION" + else: + role_guidance = f""" +ROLE: Generic extraction for field "{current_field}" +{'- Return MULTIPLE values separated by ";" if applicable' if is_plural else '- Return the PRIMARY matching value'} +""" + + prompt = f""" +SYSTEM: You are extracting specific information from an incident report transcript. + +FIELD TO EXTRACT: {current_field} +{'[SINGULAR - Extract ONE value]' if not is_plural else '[PLURAL - Extract MULTIPLE values separated by semicolon]'} + +EXTRACTION RULES: +{role_guidance} + +CRITICAL RULES: +1. Read the ENTIRE text before answering +2. Extract ONLY what belongs to this specific field +3. Return values exactly as they appear in the text +4. If not found, return: -1 + +TRANSCRIPT: +{self._transcript_text} + +ANSWER: Return ONLY the extracted value(s), nothing else.""" return prompt def main_loop(self): - # self.type_check_all() - for field in self._target_fields.keys(): - prompt = self.build_prompt(field) - # print(prompt) - # ollama_url = "http://localhost:11434/api/generate" - ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/") - ollama_url = f"{ollama_host}/api/generate" - - payload = { - "model": "mistral", - "prompt": prompt, - "stream": False, # don't really know why --> look into this later. - } + """ + Single batch Ollama call — extracts ALL fields in one request. + Falls back to per-field extraction if JSON parsing fails. + Runs schema validation after extraction. + Fixes Issue #196 (O(N) → O(1) LLM calls). + """ + ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/") + ollama_url = f"{ollama_host}/api/generate" - try: - response = requests.post(ollama_url, json=payload) - response.raise_for_status() - except requests.exceptions.ConnectionError: - raise ConnectionError( - f"Could not connect to Ollama at {ollama_url}. " - "Please ensure Ollama is running and accessible." - ) - except requests.exceptions.HTTPError as e: - raise RuntimeError(f"Ollama returned an error: {e}") + # Get field keys for result mapping + if isinstance(self._target_fields, dict): + field_keys = list(self._target_fields.keys()) + else: + field_keys = list(self._target_fields) + + # ── Single batch call ───────────────────────────────────── + prompt = self.build_batch_prompt() + payload = {"model": "mistral", "prompt": prompt, "stream": False} + + try: + response = requests.post(ollama_url, json=payload) + response.raise_for_status() + except requests.exceptions.ConnectionError: + raise ConnectionError( + f"Could not connect to Ollama at {ollama_url}. " + "Please ensure Ollama is running and accessible." + ) + except requests.exceptions.HTTPError as e: + raise RuntimeError(f"Ollama returned an error: {e}") + + raw = response.json()["response"].strip() - # parse response - json_data = response.json() - parsed_response = json_data["response"] - # print(parsed_response) - self.add_response_to_json(field, parsed_response) + # Strip markdown code fences if Mistral wraps in ```json ... ``` + raw = raw.replace("```json", "").replace("```", "").strip() + + print("----------------------------------") + print("\t[LOG] Raw Mistral batch response:") + print(raw) + + # ── Parse JSON response ─────────────────────────────────── + try: + extracted = json.loads(raw) + for key in field_keys: + val = extracted.get(key) + if val and str(val).lower() not in ("null", "none", ""): + self._json[key] = val + else: + self._json[key] = None + + print("\t[LOG] Batch extraction successful.") + + except json.JSONDecodeError: + print("\t[WARN] Batch JSON parse failed — falling back to per-field extraction") + self._json = {} + self._fallback_per_field(ollama_url, field_keys) + + # ── Schema validation ───────────────────────────────────── + self.validate_extracted_fields() print("----------------------------------") print("\t[LOG] Resulting JSON created from the input text:") @@ -83,10 +311,36 @@ def main_loop(self): return self + def _fallback_per_field(self, ollama_url: str, field_keys: list): + """ + Legacy per-field extraction — used only when batch JSON parse fails. + """ + print("\t[LOG] Running fallback per-field extraction...") + + for field in field_keys: + if isinstance(self._target_fields, dict): + label = self._target_fields.get(field, field) + if not label or label == field: + label = field + else: + label = field + + prompt = self.build_prompt(label) + payload = {"model": "mistral", "prompt": prompt, "stream": False} + + try: + response = requests.post(ollama_url, json=payload) + response.raise_for_status() + parsed_response = response.json()["response"] + self.add_response_to_json(field, parsed_response) + except Exception as e: + print(f"\t[WARN] Failed to extract field '{field}': {e}") + self._json[field] = None + def add_response_to_json(self, field, value): """ - this method adds the following value under the specified field, - or under a new field if the field doesn't exist, to the json dict + Add extracted value under field name. + Handles plural (semicolon-separated) values. """ value = value.strip().replace('"', "") parsed_value = None @@ -94,42 +348,35 @@ def add_response_to_json(self, field, value): if value != "-1": parsed_value = value - if ";" in value: - parsed_value = self.handle_plural_values(value) + if parsed_value and ";" in parsed_value: + parsed_value = self.handle_plural_values(parsed_value) - if field in self._json.keys(): - self._json[field].append(parsed_value) + if field in self._json: + existing = self._json[field] + if isinstance(existing, list): + if isinstance(parsed_value, list): + existing.extend(parsed_value) + else: + existing.append(parsed_value) + else: + self._json[field] = [existing, parsed_value] else: self._json[field] = parsed_value - return - def handle_plural_values(self, plural_value): """ - This method handles plural values. - Takes in strings of the form 'value1; value2; value3; ...; valueN' - returns a list with the respective values -> [value1, value2, value3, ..., valueN] + Split semicolon-separated values into a list. + "Mark Smith; Jane Doe" → ["Mark Smith", "Jane Doe"] """ if ";" not in plural_value: raise ValueError( f"Value is not plural, doesn't have ; separator, Value: {plural_value}" ) - print( - f"\t[LOG]: Formating plural values for JSON, [For input {plural_value}]..." - ) - values = plural_value.split(";") - - # Remove trailing leading whitespace - for i in range(len(values)): - current = i + 1 - if current < len(values): - clean_value = values[current].lstrip() - values[current] = clean_value - + print(f"\t[LOG]: Formatting plural values for JSON, [For input {plural_value}]...") + values = [v.strip() for v in plural_value.split(";") if v.strip()] print(f"\t[LOG]: Resulting formatted list of values: {values}") - return values def get_data(self): - return self._json + return self._json \ No newline at end of file diff --git a/src/main.py b/src/main.py index 5bb632b..e07578b 100644 --- a/src/main.py +++ b/src/main.py @@ -1,5 +1,6 @@ import os # from backend import Fill +from typing import Union from commonforms import prepare_form from pypdf import PdfReader from controller import Controller diff --git a/tests/conftest.py b/tests/conftest.py index 7cb4db3..ff92c19 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,12 +3,10 @@ from sqlalchemy.pool import StaticPool import pytest - from api.main import app from api.deps import get_db from api.db.models import Template, FormSubmission -# In-memory SQLite database for tests TEST_DATABASE_URL = "sqlite://" engine = create_engine( @@ -23,12 +21,12 @@ def override_get_db(): yield session -# Apply dependency override app.dependency_overrides[get_db] = override_get_db -@pytest.fixture(scope="session", autouse=True) -def create_test_db(): +@pytest.fixture(autouse=True) +def reset_db(): + SQLModel.metadata.drop_all(engine) SQLModel.metadata.create_all(engine) yield SQLModel.metadata.drop_all(engine) @@ -37,3 +35,10 @@ def create_test_db(): @pytest.fixture def client(): return TestClient(app) + + +@pytest.fixture +def db_session(): + """Direct DB session for test setup.""" + with Session(engine) as session: + yield session diff --git a/tests/test_forms.py b/tests/test_forms.py index 8f432bf..5e32755 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -1,25 +1,107 @@ -def test_submit_form(client): - pass - # First create a template - # form_payload = { - # "template_id": 3, - # "input_text": "Hi. The employee's name is John Doe. His job title is managing director. His department supervisor is Jane Doe. His phone number is 123456. His email is jdoe@ucsc.edu. The signature is , and the date is 01/02/2005", - # } - - # template_res = client.post("/templates/", json=template_payload) - # template_id = template_res.json()["id"] - - # # Submit a form - # form_payload = { - # "template_id": template_id, - # "data": {"rating": 5, "comment": "Great service"}, - # } - - # response = client.post("/forms/", json=form_payload) - - # assert response.status_code == 200 - - # data = response.json() - # assert data["id"] is not None - # assert data["template_id"] == template_id - # assert data["data"] == form_payload["data"] +""" +Tests for /forms endpoints. +Closes #165, #205, #163 +""" + +import pytest +from unittest.mock import patch +from api.db.models import Template, FormSubmission +from datetime import datetime + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +def make_template(db_session): + t = Template( + name="Test Form", + fields={"JobTitle": "Job Title"}, + pdf_path="/tmp/test.pdf", + created_at=datetime.utcnow(), + ) + db_session.add(t) + db_session.commit() + db_session.refresh(t) + return t.id + + +def make_submission(db_session, template_id, output_path="/tmp/filled.pdf"): + s = FormSubmission( + template_id=template_id, + input_text="John Smith is a firefighter.", + output_pdf_path=output_path, + created_at=datetime.utcnow(), + ) + db_session.add(s) + db_session.commit() + db_session.refresh(s) + return s.id + + +# ── POST /forms/fill ────────────────────────────────────────────────────────── + +class TestFillForm: + + def test_fill_form_template_not_found(self, client): + """Returns 404 when template_id does not exist.""" + response = client.post("/forms/fill", json={ + "template_id": 999999, + "input_text": "John Smith is a firefighter.", + }) + assert response.status_code == 404 + + def test_fill_form_missing_fields_returns_422(self, client): + """Returns 422 when required fields are missing.""" + response = client.post("/forms/fill", json={ + "template_id": 1, + }) + assert response.status_code == 422 + + def test_fill_form_ollama_down_returns_503(self, client, db_session): + """Returns 503 when Ollama is not reachable.""" + template_id = make_template(db_session) + + with patch("src.controller.Controller.fill_form", + side_effect=ConnectionError("Ollama not running")): + response = client.post("/forms/fill", json={ + "template_id": template_id, + "input_text": "John Smith is a firefighter.", + }) + + assert response.status_code == 503 + + +# ── GET /forms/{submission_id} ──────────────────────────────────────────────── + +class TestGetSubmission: + + def test_get_submission_not_found(self, client): + """Returns 404 for non-existent submission ID.""" + response = client.get("/forms/999999") + assert response.status_code == 404 + + def test_get_submission_invalid_id(self, client): + """Returns 422 for non-integer submission ID.""" + response = client.get("/forms/not-an-id") + assert response.status_code == 422 + + +# ── GET /forms/download/{submission_id} ─────────────────────────────────────── + +class TestDownloadSubmission: + + def test_download_not_found_submission(self, client): + """Returns 404 when submission does not exist.""" + response = client.get("/forms/download/999999") + assert response.status_code == 404 + + def test_download_file_missing_on_disk(self, client, db_session): + """Returns 404 when submission exists but PDF missing on disk.""" + template_id = make_template(db_session) + submission_id = make_submission( + db_session, template_id, "/nonexistent/filled.pdf" + ) + + with patch("os.path.exists", return_value=False): + response = client.get(f"/forms/download/{submission_id}") + + assert response.status_code == 404 diff --git a/tests/test_llm.py b/tests/test_llm.py new file mode 100644 index 0000000..cfe483b --- /dev/null +++ b/tests/test_llm.py @@ -0,0 +1,278 @@ +""" +Unit tests for src/llm.py — LLM class. + +Closes: #186 (Unit tests for LLM class methods) +Covers: batch prompt, per-field prompt, add_response_to_json, + handle_plural_values, type_check_all, main_loop (mocked) +""" + +import json +import pytest +from unittest.mock import patch, MagicMock +from src.llm import LLM + + +# ── Fixtures ───────────────────────────────────────────────────────────────── + +@pytest.fixture +def dict_fields(): + """Realistic dict fields: {internal_name: human_label}""" + return { + "NAME/SID": "Employee Or Student Name", + "JobTitle": "Job Title", + "Department": "Department", + "Phone Number": "Phone Number", + "email": "Email", + } + +@pytest.fixture +def list_fields(): + """Legacy list fields: [internal_name, ...]""" + return ["officer_name", "location", "incident_date"] + +@pytest.fixture +def transcript(): + return ( + "Employee name is John Smith. Employee ID is EMP-2024-789. " + "Job title is Firefighter Paramedic. Department is Emergency Medical Services. " + "Phone number is 916-555-0147." + ) + +@pytest.fixture +def llm_dict(dict_fields, transcript): + return LLM(transcript_text=transcript, target_fields=dict_fields) + +@pytest.fixture +def llm_list(list_fields, transcript): + return LLM(transcript_text=transcript, target_fields=list_fields) + + +# ── type_check_all ──────────────────────────────────────────────────────────── + +class TestTypeCheckAll: + + def test_raises_on_non_string_transcript(self, dict_fields): + llm = LLM(transcript_text=12345, target_fields=dict_fields) + with pytest.raises(TypeError, match="Transcript must be text"): + llm.type_check_all() + + def test_raises_on_none_transcript(self, dict_fields): + llm = LLM(transcript_text=None, target_fields=dict_fields) + with pytest.raises(TypeError): + llm.type_check_all() + + def test_raises_on_invalid_fields_type(self, transcript): + llm = LLM(transcript_text=transcript, target_fields="not_a_list_or_dict") + with pytest.raises(TypeError, match="list or dict"): + llm.type_check_all() + + def test_passes_with_dict_fields(self, llm_dict): + # Should not raise + llm_dict.type_check_all() + + def test_passes_with_list_fields(self, llm_list): + # Should not raise + llm_list.type_check_all() + + +# ── build_batch_prompt ──────────────────────────────────────────────────────── + +class TestBuildBatchPrompt: + + def test_contains_all_field_keys(self, llm_dict, dict_fields): + prompt = llm_dict.build_batch_prompt() + for key in dict_fields.keys(): + assert key in prompt, f"Field key '{key}' missing from batch prompt" + + def test_contains_human_labels(self, llm_dict, dict_fields): + prompt = llm_dict.build_batch_prompt() + for label in dict_fields.values(): + assert label in prompt, f"Label '{label}' missing from batch prompt" + + def test_contains_transcript(self, llm_dict, transcript): + prompt = llm_dict.build_batch_prompt() + assert transcript in prompt + + def test_contains_json_instruction(self, llm_dict): + prompt = llm_dict.build_batch_prompt() + assert "JSON" in prompt + + def test_list_fields_batch_prompt(self, llm_list, list_fields): + prompt = llm_list.build_batch_prompt() + for field in list_fields: + assert field in prompt + + def test_labels_used_as_comments(self, llm_dict): + """Human labels should appear after // in the prompt""" + prompt = llm_dict.build_batch_prompt() + assert "//" in prompt + + +# ── build_prompt (legacy per-field) ────────────────────────────────────────── + +class TestBuildPrompt: + + def test_officer_field_gets_officer_guidance(self, llm_dict): + prompt = llm_dict.build_prompt("officer_name") + assert "OFFICER" in prompt.upper() or "EMPLOYEE" in prompt.upper() + + def test_location_field_gets_location_guidance(self, llm_dict): + prompt = llm_dict.build_prompt("incident_location") + assert "LOCATION" in prompt.upper() or "ADDRESS" in prompt.upper() + + def test_victim_field_gets_victim_guidance(self, llm_dict): + prompt = llm_dict.build_prompt("victim_name") + assert "VICTIM" in prompt.upper() + + def test_phone_field_gets_phone_guidance(self, llm_dict): + prompt = llm_dict.build_prompt("Phone Number") + assert "PHONE" in prompt.upper() + + def test_prompt_contains_transcript(self, llm_dict, transcript): + prompt = llm_dict.build_prompt("some_field") + assert transcript in prompt + + def test_generic_field_still_builds_prompt(self, llm_dict): + prompt = llm_dict.build_prompt("textbox_0_0") + assert len(prompt) > 50 + + +# ── handle_plural_values ────────────────────────────────────────────────────── + +class TestHandlePluralValues: + + def test_splits_on_semicolon(self, llm_dict): + result = llm_dict.handle_plural_values("Mark Smith;Jane Doe") + assert "Mark Smith" in result + assert "Jane Doe" in result + + def test_strips_whitespace(self, llm_dict): + result = llm_dict.handle_plural_values("Mark Smith; Jane Doe; Bob") + assert all(v == v.strip() for v in result) + + def test_returns_list(self, llm_dict): + result = llm_dict.handle_plural_values("A;B;C") + assert isinstance(result, list) + + def test_raises_without_semicolon(self, llm_dict): + with pytest.raises(ValueError, match="separator"): + llm_dict.handle_plural_values("no semicolon here") + + def test_three_values(self, llm_dict): + result = llm_dict.handle_plural_values("Alice;Bob;Charlie") + assert len(result) == 3 + + +# ── add_response_to_json ────────────────────────────────────────────────────── + +class TestAddResponseToJson: + + def test_stores_value_under_field(self, llm_dict): + llm_dict.add_response_to_json("NAME/SID", "John Smith") + assert llm_dict._json["NAME/SID"] == "John Smith" + + def test_ignores_minus_one(self, llm_dict): + llm_dict.add_response_to_json("email", "-1") + assert llm_dict._json["email"] is None + + def test_strips_quotes(self, llm_dict): + llm_dict.add_response_to_json("JobTitle", '"Firefighter"') + assert llm_dict._json["JobTitle"] == "Firefighter" + + def test_strips_whitespace(self, llm_dict): + llm_dict.add_response_to_json("Department", " EMS ") + assert llm_dict._json["Department"] == "EMS" + + def test_plural_value_becomes_list(self, llm_dict): + llm_dict.add_response_to_json("victims", "Mark Smith;Jane Doe") + assert isinstance(llm_dict._json["victims"], list) + + def test_existing_field_becomes_list(self, llm_dict): + """Adding to existing field should not overwrite silently.""" + llm_dict._json["NAME/SID"] = "John" + llm_dict.add_response_to_json("NAME/SID", "Jane") + assert isinstance(llm_dict._json["NAME/SID"], list) + + +# ── get_data ────────────────────────────────────────────────────────────────── + +class TestGetData: + + def test_returns_dict(self, llm_dict): + assert isinstance(llm_dict.get_data(), dict) + + def test_returns_same_reference_as_internal_json(self, llm_dict): + llm_dict._json["test_key"] = "test_value" + assert llm_dict.get_data()["test_key"] == "test_value" + + +# ── main_loop (mocked Ollama) ───────────────────────────────────────────────── + +class TestMainLoop: + + def _mock_response(self, json_body: dict): + """Build a mock requests.Response returning a valid Mistral JSON reply.""" + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.json.return_value = { + "response": json.dumps(json_body) + } + return mock_resp + + def test_batch_success_fills_all_fields(self, llm_dict, dict_fields): + expected = { + "NAME/SID": "John Smith", + "JobTitle": "Firefighter Paramedic", + "Department": "Emergency Medical Services", + "Phone Number": "916-555-0147", + "email": None, + } + with patch("requests.post", return_value=self._mock_response(expected)): + llm_dict.main_loop() + + result = llm_dict.get_data() + assert result["NAME/SID"] == "John Smith" + assert result["JobTitle"] == "Firefighter Paramedic" + assert result["Department"] == "Emergency Medical Services" + assert result["Phone Number"] == "916-555-0147" + + def test_batch_makes_exactly_one_ollama_call(self, llm_dict, dict_fields): + """Core performance requirement — O(1) not O(N).""" + expected = {k: "value" for k in dict_fields.keys()} + with patch("requests.post", return_value=self._mock_response(expected)) as mock_post: + llm_dict.main_loop() + + assert mock_post.call_count == 1, ( + f"Expected 1 Ollama call, got {mock_post.call_count}. " + "main_loop() must use batch extraction, not per-field." + ) + + def test_fallback_on_invalid_json(self, llm_dict, dict_fields): + """If Mistral returns non-JSON, fallback per-field runs without crash.""" + bad_response = MagicMock() + bad_response.raise_for_status = MagicMock() + bad_response.json.return_value = {"response": "This is not JSON at all."} + + good_response = MagicMock() + good_response.raise_for_status = MagicMock() + good_response.json.return_value = {"response": "John Smith"} + + # First call returns bad JSON, rest return single values + with patch("requests.post", side_effect=[bad_response] + [good_response] * len(dict_fields)): + llm_dict.main_loop() # should not raise + + def test_connection_error_raises_connection_error(self, llm_dict): + import requests as req + with patch("requests.post", side_effect=req.exceptions.ConnectionError): + with pytest.raises(ConnectionError, match="Ollama"): + llm_dict.main_loop() + + def test_null_values_stored_as_none(self, llm_dict, dict_fields): + """Mistral returning null should be stored as None, not the string 'null'.""" + response_with_nulls = {k: None for k in dict_fields.keys()} + with patch("requests.post", return_value=self._mock_response(response_with_nulls)): + llm_dict.main_loop() + + result = llm_dict.get_data() + for key in dict_fields.keys(): + assert result[key] is None, f"Expected None for '{key}', got {result[key]!r}" diff --git a/tests/test_templates.py b/tests/test_templates.py index bbced2b..9b7cf8e 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,18 +1,126 @@ -def test_create_template(client): - payload = { - "name": "Template 1", - "pdf_path": "src/inputs/file.pdf", - "fields": { - "Employee's name": "string", - "Employee's job title": "string", - "Employee's department supervisor": "string", - "Employee's phone number": "string", - "Employee's email": "string", - "Signature": "string", - "Date": "string", - }, - } - - response = client.post("/templates/create", json=payload) - - assert response.status_code == 200 +""" +Tests for /templates endpoints. +Closes #162, #160, #163 +""" + +import io +import pytest +from unittest.mock import patch, MagicMock +from api.db.models import Template +from datetime import datetime + + +# ── POST /templates/create ──────────────────────────────────────────────────── + +class TestCreateTemplate: + + def test_create_template_success(self, client): + """Uploading a valid PDF returns 200 with template data.""" + pdf_bytes = ( + b"%PDF-1.4\n1 0 obj<>endobj\n" + b"2 0 obj<>endobj\n" + b"3 0 obj<>endobj\n" + b"xref\n0 4\n0000000000 65535 f\n" + b"trailer<>\nstartxref\n0\n%%EOF" + ) + + mock_fields = { + "JobTitle": {"/T": "JobTitle", "/FT": "/Tx"}, + "Department": {"/T": "Department", "/FT": "/Tx"}, + } + + with patch("commonforms.prepare_form"), \ + patch("pypdf.PdfReader") as mock_reader, \ + patch("shutil.copyfileobj"), \ + patch("builtins.open", MagicMock()), \ + patch("os.path.exists", return_value=True), \ + patch("os.remove"): + + mock_reader.return_value.get_fields.return_value = mock_fields + + response = client.post( + "/templates/create", + files={"file": ("form.pdf", io.BytesIO(pdf_bytes), "application/pdf")}, + data={"name": "Vaccine Form"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["name"] == "Vaccine Form" + assert "id" in data + assert "fields" in data + + def test_create_template_without_file_returns_422(self, client): + """Missing file field returns 422 Unprocessable Entity.""" + response = client.post( + "/templates/create", + data={"name": "No File"}, + ) + assert response.status_code == 422 + + def test_create_template_non_pdf_returns_400(self, client): + """Uploading a non-PDF returns 400.""" + with patch("shutil.copyfileobj"), \ + patch("builtins.open", MagicMock()): + response = client.post( + "/templates/create", + files={"file": ("data.csv", io.BytesIO(b"a,b,c"), "text/csv")}, + data={"name": "CSV attempt"}, + ) + assert response.status_code == 400 + + +# ── GET /templates ──────────────────────────────────────────────────────────── + +class TestListTemplates: + + def test_list_templates_returns_200(self, client): + """GET /templates returns 200.""" + response = client.get("/templates") + assert response.status_code == 200 + + def test_list_templates_returns_list(self, client): + """Response is always a list.""" + response = client.get("/templates") + assert isinstance(response.json(), list) + + def test_list_templates_empty_on_fresh_db(self, client): + """Fresh DB returns empty list.""" + response = client.get("/templates") + assert response.json() == [] + + def test_list_templates_pagination_accepted(self, client): + """Pagination params accepted without error.""" + response = client.get("/templates?limit=5&offset=0") + assert response.status_code == 200 + + +# ── GET /templates/{template_id} ────────────────────────────────────────────── + +class TestGetTemplate: + + def test_get_template_not_found(self, client): + """Returns 404 for non-existent ID.""" + response = client.get("/templates/999999") + assert response.status_code == 404 + + def test_get_template_invalid_id_type(self, client): + """Returns 422 for non-integer ID.""" + response = client.get("/templates/not-an-id") + assert response.status_code == 422 + + def test_get_template_by_id(self, client, db_session): + """Returns correct template for valid ID.""" + t = Template( + name="Cal Fire Form", + fields={"officer_name": "Officer Name"}, + pdf_path="/tmp/cal_fire.pdf", + created_at=datetime.utcnow(), + ) + db_session.add(t) + db_session.commit() + db_session.refresh(t) + + response = client.get(f"/templates/{t.id}") + assert response.status_code == 200 + assert response.json()["name"] == "Cal Fire Form"