diff --git a/Makefile b/Makefile index 53eb56a..90fae26 100644 --- a/Makefile +++ b/Makefile @@ -55,7 +55,7 @@ pull-model: docker compose exec ollama ollama pull mistral test: - docker compose exec app python3 -m pytest src/test/ + docker compose exec app python3 -m pytest tests/ -v clean: docker compose down -v diff --git a/api/db/models.py b/api/db/models.py index f76c93b..61b4d90 100644 --- a/api/db/models.py +++ b/api/db/models.py @@ -1,6 +1,8 @@ from sqlmodel import SQLModel, Field from sqlalchemy import Column, JSON from datetime import datetime +import uuid + class Template(SQLModel, table=True): id: int | None = Field(default=None, primary_key=True) @@ -15,4 +17,51 @@ class FormSubmission(SQLModel, table=True): template_id: int input_text: str output_pdf_path: str + created_at: datetime = Field(default_factory=datetime.utcnow) + + +class FillJob(SQLModel, table=True): + """ + Tracks an asynchronous form-fill job submitted via POST /forms/fill/async. + Clients poll GET /forms/jobs/{id} to check status and retrieve results. + """ + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) + status: str = Field(default="pending") # pending | running | complete | failed + template_id: int + input_text: str + output_pdf_path: str | None = None + partial_results: dict | None = Field(default=None, sa_column=Column(JSON)) + field_confidence: dict | None = Field(default=None, sa_column=Column(JSON)) + error_message: str | None = None + created_at: datetime = Field(default_factory=datetime.utcnow) + + +class BatchSubmission(SQLModel, table=True): + """ + Tracks a multi-template batch fill submitted via POST /forms/fill/batch. + + A single BatchSubmission represents one incident transcript filled into + N agency forms simultaneously using a single canonical LLM extraction pass. + The canonical_extraction JSON column stores the full per-field evidence + record (value + verbatim transcript quote + confidence) for audit purposes. + + submission_ids links to the individual FormSubmission records created for + each template so clients can retrieve per-template output PDF paths. + errors stores per-template error messages for partial failure cases. + """ + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) + status: str = Field(default="complete") # complete | partial | failed + input_text: str + # Full canonical extraction: category -> {value, evidence, confidence} + canonical_extraction: dict | None = Field(default=None, sa_column=Column(JSON)) + # Evidence report: only categories with non-null extracted values + evidence_report: dict | None = Field(default=None, sa_column=Column(JSON)) + # List of template IDs that were requested + template_ids: list = Field(sa_column=Column(JSON)) + # List of FormSubmission integer IDs created (one per template) + submission_ids: list | None = Field(default=None, sa_column=Column(JSON)) + # Per-template output paths: {template_id: output_pdf_path} + output_paths: dict | None = Field(default=None, sa_column=Column(JSON)) + # Per-template errors: {template_id: error_message} for partial failures + errors: dict | None = Field(default=None, sa_column=Column(JSON)) created_at: datetime = Field(default_factory=datetime.utcnow) \ No newline at end of file diff --git a/api/db/repositories.py b/api/db/repositories.py index 6608718..077d89d 100644 --- a/api/db/repositories.py +++ b/api/db/repositories.py @@ -1,5 +1,5 @@ from sqlmodel import Session, select -from api.db.models import Template, FormSubmission +from api.db.models import Template, FormSubmission, FillJob, BatchSubmission # Templates def create_template(session: Session, template: Template) -> Template: @@ -16,4 +16,35 @@ def create_form(session: Session, form: FormSubmission) -> FormSubmission: session.add(form) session.commit() session.refresh(form) - return form \ No newline at end of file + return form + +# Fill Jobs +def create_job(session: Session, job: FillJob) -> FillJob: + session.add(job) + session.commit() + session.refresh(job) + return job + +def get_job(session: Session, job_id: str) -> FillJob | None: + return session.get(FillJob, job_id) + +def update_job(session: Session, job_id: str, **kwargs) -> FillJob | None: + job = session.get(FillJob, job_id) + if not job: + return None + for key, value in kwargs.items(): + setattr(job, key, value) + session.add(job) + session.commit() + session.refresh(job) + return job + +# Batch Submissions +def create_batch(session: Session, batch: BatchSubmission) -> BatchSubmission: + session.add(batch) + session.commit() + session.refresh(batch) + return batch + +def get_batch(session: Session, batch_id: str) -> BatchSubmission | None: + return session.get(BatchSubmission, batch_id) \ No newline at end of file diff --git a/api/main.py b/api/main.py index d0b8c79..64057b9 100644 --- a/api/main.py +++ b/api/main.py @@ -1,7 +1,15 @@ from fastapi import FastAPI -from api.routes import templates, forms +from api.routes import templates, forms, batch +from api.errors.handlers import register_exception_handlers -app = FastAPI() +app = FastAPI( + title="FireForm", + description="Report once, file everywhere — multi-agency incident form filling.", + version="0.2.0", +) + +register_exception_handlers(app) app.include_router(templates.router) -app.include_router(forms.router) \ No newline at end of file +app.include_router(forms.router) +app.include_router(batch.router) \ No newline at end of file diff --git a/api/routes/batch.py b/api/routes/batch.py new file mode 100644 index 0000000..562ef4a --- /dev/null +++ b/api/routes/batch.py @@ -0,0 +1,242 @@ +""" +Batch fill endpoint — the "report once, file everywhere" API. + +POST /forms/fill/batch + +This endpoint is the architectural completion of FireForm's core promise. +A firefighter records one incident transcript. This endpoint: + + 1. Extracts a canonical incident record from the transcript in a SINGLE + LLM call (all N agency forms share this extraction). + + 2. Maps the canonical record to each agency template's field schema + CONCURRENTLY via asyncio.gather() — N fast mapping calls in parallel + instead of N * F sequential full-transcript extractions. + + 3. Fills all N PDFs concurrently in a ThreadPoolExecutor (pdfrw is + synchronous; offloading prevents event loop blocking). + + 4. Persists a BatchSubmission record with the full canonical extraction + including per-field evidence attribution (verbatim transcript quotes) + alongside individual FormSubmission records per template. + + 5. Returns everything in a single response: per-template PDF paths, + success/failure per template, and the complete evidence report. + +Time complexity improvement: + Sequential per-form extraction: O(T * F) LLM calls + Batch canonical + mapping: O(1 + T) LLM calls + Example (5 agencies, 10 fields): 50 calls → 6 calls + +GET /forms/batches/{batch_id} — lightweight status + output paths +GET /forms/batches/{batch_id}/audit — full evidence trail for legal compliance +""" + +import asyncio + +from fastapi import APIRouter, Depends +from sqlmodel import Session + +from api.deps import get_db +from api.schemas.batch import ( + BatchFill, + BatchFillResponse, + BatchStatusResponse, + AuditResponse, + TemplateResult, + EvidenceField, +) +from api.db.repositories import get_template, create_form, create_batch, get_batch +from api.db.models import FormSubmission, BatchSubmission +from api.errors.base import AppError +from src.extractor import IncidentExtractor + +router = APIRouter(prefix="/forms", tags=["batch"]) + + +@router.post("/fill/batch", response_model=BatchFillResponse) +async def batch_fill(body: BatchFill, db: Session = Depends(get_db)): + """ + Fill multiple agency-specific PDF forms from a single incident transcript. + + Extraction runs once (canonical pass) then maps to each template concurrently. + Partial success is tolerated — if one template fails (bad PDF path, mapping + error), the others still complete and the batch status is reported as "partial". + """ + # ── Validate all templates exist upfront ────────────────────────────────── + templates = {} + for tid in body.template_ids: + tpl = get_template(db, tid) + if not tpl: + raise AppError(f"Template {tid} not found", status_code=404) + templates[tid] = tpl + + # ── Pass 1: single canonical extraction ─────────────────────────────────── + extractor = IncidentExtractor(body.input_text) + canonical = await extractor.async_extract_canonical() + evidence_report = IncidentExtractor.build_evidence_report(canonical) + + # ── Pass 2: concurrent mapping to each template ─────────────────────────── + import httpx + + async with httpx.AsyncClient(timeout=120.0) as client: + mapping_tasks = [ + extractor.async_map_to_template(client, canonical, tpl.fields) + for tpl in templates.values() + ] + mappings = await asyncio.gather(*mapping_tasks, return_exceptions=True) + + # mappings[i] corresponds to templates.values()[i] + template_list = list(templates.values()) + template_id_list = list(templates.keys()) + + # ── Pass 3: concurrent PDF fill in thread pool ──────────────────────────── + loop = asyncio.get_running_loop() + + async def _fill_one(tpl, data: dict) -> str: + from src.filler import Filler + filler = Filler() + return await loop.run_in_executor( + None, + lambda: filler.fill_form_with_data(tpl.pdf_path, data), + ) + + fill_tasks = [] + failed_at_mapping: dict[int, str] = {} + + for i, (tpl, mapping) in enumerate(zip(template_list, mappings)): + if isinstance(mapping, Exception): + failed_at_mapping[template_id_list[i]] = str(mapping) + fill_tasks.append(asyncio.sleep(0)) # placeholder + else: + fill_tasks.append(_fill_one(tpl, mapping)) + + fill_results = await asyncio.gather(*fill_tasks, return_exceptions=True) + + # ── Persist FormSubmission per template + collect results ───────────────── + results: list[TemplateResult] = [] + submission_ids: list[int] = [] + output_paths: dict[str, str | None] = {} + errors: dict[str, str] = {} + + for i, tpl in enumerate(template_list): + tid = template_id_list[i] + + if tid in failed_at_mapping: + err = failed_at_mapping[tid] + results.append(TemplateResult( + template_id=tid, status="failed", + submission_id=None, output_pdf_path=None, error=err, + )) + errors[str(tid)] = err + output_paths[str(tid)] = None + continue + + pdf_result = fill_results[i] + if isinstance(pdf_result, Exception): + err = str(pdf_result) + results.append(TemplateResult( + template_id=tid, status="failed", + submission_id=None, output_pdf_path=None, error=err, + )) + errors[str(tid)] = err + output_paths[str(tid)] = None + continue + + submission = FormSubmission( + template_id=tid, + input_text=body.input_text, + output_pdf_path=pdf_result, + ) + saved = create_form(db, submission) + submission_ids.append(saved.id) + output_paths[str(tid)] = pdf_result + results.append(TemplateResult( + template_id=tid, status="complete", + submission_id=saved.id, output_pdf_path=pdf_result, error=None, + )) + + # ── Determine overall batch status ──────────────────────────────────────── + total_succeeded = sum(1 for r in results if r.status == "complete") + total_failed = len(results) - total_succeeded + + if total_failed == 0: + status = "complete" + elif total_succeeded == 0: + status = "failed" + else: + status = "partial" + + # ── Persist BatchSubmission ─────────────────────────────────────────────── + batch = BatchSubmission( + status=status, + input_text=body.input_text, + canonical_extraction=canonical, + evidence_report=evidence_report, + template_ids=body.template_ids, + submission_ids=submission_ids if submission_ids else None, + output_paths=output_paths, + errors=errors if errors else None, + ) + saved_batch = create_batch(db, batch) + + # ── Build response ──────────────────────────────────────────────────────── + # Convert evidence_report to EvidenceField instances for schema validation + typed_evidence = { + k: EvidenceField( + value=v.get("value"), + evidence=v.get("evidence"), + confidence=v.get("confidence", "low"), + ) + for k, v in evidence_report.items() + } if evidence_report else None + + return BatchFillResponse( + batch_id=saved_batch.id, + status=status, + input_text=body.input_text, + template_ids=body.template_ids, + results=results, + evidence_report=typed_evidence, + total_requested=len(body.template_ids), + total_succeeded=total_succeeded, + total_failed=total_failed, + created_at=saved_batch.created_at, + ) + + +@router.get("/batches/{batch_id}", response_model=BatchStatusResponse) +def get_batch_status(batch_id: str, db: Session = Depends(get_db)): + """ + Lightweight status check for a completed batch submission. + Returns per-template output_paths and any errors without the full + canonical extraction payload. Use /audit for the full evidence trail. + """ + batch = get_batch(db, batch_id) + if not batch: + raise AppError("Batch not found", status_code=404) + return batch + + +@router.get("/batches/{batch_id}/audit", response_model=AuditResponse) +def get_batch_audit(batch_id: str, db: Session = Depends(get_db)): + """ + Returns the full evidence trail for a batch submission. + + For each canonical incident category that was extracted, the response + includes the extracted value, the verbatim transcript quote used as + evidence, and the confidence level. This endpoint exists specifically + for legal compliance and chain-of-custody verification: supervisors and + legal teams can confirm that every value in every filed form is traceable + to a specific statement in the original incident transcript. + """ + batch = get_batch(db, batch_id) + if not batch: + raise AppError("Batch not found", status_code=404) + return AuditResponse( + batch_id=batch.id, + input_text=batch.input_text, + canonical_extraction=batch.canonical_extraction, + evidence_report=batch.evidence_report, + created_at=batch.created_at, + ) diff --git a/api/schemas/batch.py b/api/schemas/batch.py new file mode 100644 index 0000000..c1a3b21 --- /dev/null +++ b/api/schemas/batch.py @@ -0,0 +1,86 @@ +from pydantic import BaseModel, ConfigDict, field_validator +from datetime import datetime + + +class BatchFill(BaseModel): + """Request body for POST /forms/fill/batch.""" + input_text: str + template_ids: list[int] + + @field_validator("template_ids") + @classmethod + def must_have_at_least_one(cls, v: list[int]) -> list[int]: + if not v: + raise ValueError("template_ids must contain at least one template ID") + if len(v) > 20: + raise ValueError("Batch size is limited to 20 templates per request") + if len(v) != len(set(v)): + raise ValueError("template_ids must not contain duplicates") + return v + + +class TemplateResult(BaseModel): + """Per-template result within a BatchFillResponse.""" + template_id: int + status: str # complete | failed + submission_id: int | None # FormSubmission.id if successful + output_pdf_path: str | None # path to filled PDF if successful + error: str | None # error message if failed + + +class EvidenceField(BaseModel): + """Evidence attribution for a single canonical incident field.""" + value: str | list | None + evidence: str | None # verbatim transcript quote + confidence: str # high | medium | low + + +class BatchFillResponse(BaseModel): + """Response from POST /forms/fill/batch.""" + model_config = ConfigDict(from_attributes=True) + + batch_id: str + status: str # complete | partial | failed + input_text: str + template_ids: list[int] + results: list[TemplateResult] + # Evidence report: canonical category -> {value, evidence, confidence} + # Only includes categories that were successfully extracted + evidence_report: dict[str, EvidenceField] | None + total_requested: int + total_succeeded: int + total_failed: int + created_at: datetime + + +class BatchStatusResponse(BaseModel): + """Response from GET /forms/batches/{batch_id} — lightweight status check.""" + model_config = ConfigDict(from_attributes=True) + + id: str + status: str + template_ids: list + submission_ids: list | None + output_paths: dict | None + errors: dict | None + created_at: datetime + + +class AuditResponse(BaseModel): + """ + Response from GET /forms/batches/{batch_id}/audit. + + Returns the full canonical extraction with per-field evidence attribution. + This endpoint is specifically designed for legal compliance and chain-of-custody + verification in emergency services contexts. Each extracted value is paired + with the exact verbatim transcript quote that supports it, allowing supervisors + and legal teams to verify that every value in every filed form is traceable + back to a specific statement in the original incident transcript. + """ + model_config = ConfigDict(from_attributes=True) + + batch_id: str + input_text: str + canonical_extraction: dict | None + evidence_report: dict | None + created_at: datetime diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..4584de7 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +testpaths = tests +pythonpath = . diff --git a/src/extractor.py b/src/extractor.py new file mode 100644 index 0000000..8465052 --- /dev/null +++ b/src/extractor.py @@ -0,0 +1,271 @@ +""" +IncidentExtractor — single-pass canonical extraction with evidence attribution. + +The canonical extraction pipeline is the foundation of the batch fill system. +Instead of asking the LLM to extract specific template fields N times (once per +agency form), the extractor runs ONE LLM call to extract a rich, template-agnostic +incident record from the transcript. Each extracted value carries an "evidence" +field containing the verbatim transcript quote that supports it, which is required +for chain-of-custody and legal compliance in emergency services reporting. + +A second, much faster LLM call then maps the pre-extracted canonical record to +each specific template's field schema. Because the data is already structured, +the mapping call only needs to match field names — it does not re-read or +re-interpret the transcript. This makes the mapping calls fast and parallelizable. + +Time complexity: + Old (per-form extraction): O(T * F) LLM calls — T templates × F fields each + New (canonical + mapping): O(1 + T) LLM calls — 1 extraction + T mappings + For 5 agency forms at 10 fields each: 50 calls → 6 calls. +""" + +import json +import os +import requests + + +# ── Canonical incident categories ───────────────────────────────────────────── +# These represent the full universe of information that may appear in an +# emergency incident transcript. Template-specific field names are mapped +# from these during the per-template mapping pass. + +CANONICAL_CATEGORIES = [ + "reporting_officer", + "badge_number", + "unit_number", + "case_number", + "incident_type", + "incident_date", + "incident_time", + "incident_location", + "city", + "jurisdiction", + "narrative", + "victim_names", + "victim_ages", + "victim_injuries", + "suspect_names", + "suspect_descriptions", + "witness_names", + "assisting_officers", + "assisting_agencies", + "actions_taken", + "property_damage", + "weapons_involved", + "vehicle_descriptions", + "medical_response", + "hospital_transported_to", + "follow_up_required", +] + + +class IncidentExtractor: + """ + Extracts a canonical, template-agnostic incident record from an incident + transcript in a single LLM call, then maps it to any number of agency-specific + form templates without re-reading the original transcript. + + Each canonical field carries three sub-fields: + value — the extracted value (string, list, or null) + evidence — the verbatim transcript quote that supports this extraction + confidence — "high" if clearly stated, "medium" if inferred, "low" if uncertain + + Usage (synchronous / sync batch): + extractor = IncidentExtractor(transcript) + canonical = extractor.extract_canonical() + mapped = extractor.map_to_template(canonical, template.fields) + + Usage (async batch — preferred for POST /forms/fill/batch): + canonical = await extractor.async_extract_canonical() + results = await asyncio.gather(*[ + extractor.async_map_to_template(client, canonical, t.fields) + for t in templates + ]) + """ + + def __init__(self, transcript: str): + self._transcript = transcript + self._ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/") + self._ollama_url = f"{self._ollama_host}/api/generate" + + # ── Prompt builders ──────────────────────────────────────────────────────── + + def _build_canonical_prompt(self) -> str: + categories = json.dumps(CANONICAL_CATEGORIES, indent=2) + return f"""You are an AI assistant specializing in emergency incident report analysis for law enforcement and fire services. + +Your task is to extract every identifiable piece of information from the incident transcript below. + +For every piece of information you extract, you MUST return three things: + 1. value — the extracted value (use a JSON list if there are multiple values) + 2. evidence — the exact verbatim quote from the transcript that supports this extraction + 3. confidence — "high" if the value is explicitly and clearly stated, "medium" if reasonably inferred from context, "low" if uncertain + +Return ONLY a valid JSON object. No markdown, no explanation, no code fences. +Use null for the entire field object if a category is not mentioned in the transcript at all. + +Categories to extract: +{categories} + +Output format (example): +{{ + "reporting_officer": {{ + "value": "Officer Smith", + "evidence": "Officer Smith reporting from unit 4", + "confidence": "high" + }}, + "victim_names": {{ + "value": ["Jane Doe", "Mark Smith"], + "evidence": "victims Jane Doe and Mark Smith were treated on scene", + "confidence": "high" + }}, + "case_number": null +}} + +Transcript: +{self._transcript} +""" + + def _build_mapping_prompt(self, canonical: dict, template_fields: dict) -> str: + # Only pass the values (not evidence/confidence) to the mapping prompt + # to keep it focused and fast + canonical_values = { + k: (v["value"] if isinstance(v, dict) and "value" in v else v) + for k, v in canonical.items() + if v is not None + } + return f"""You are mapping a pre-extracted incident record to a specific agency form template. +The incident data below was already extracted from a transcript — do NOT re-interpret anything. +Your ONLY task is to match the most semantically relevant value from the incident record to each template field. + +Rules: +- Use only values from the provided incident record. Do not invent or infer new values. +- If a template field has no matching data in the incident record, use null. +- If a template field maps to a list value (e.g. multiple victims), join with "; ". +- Return ONLY a valid JSON object. No markdown, no explanation. + +Pre-extracted incident record: +{json.dumps(canonical_values, indent=2)} + +Template fields to fill (field name -> description/type): +{json.dumps(template_fields, indent=2)} + +Output: +{{ + "template_field_name": "matched value or null", + ... +}} +""" + + # ── JSON parsing helper ──────────────────────────────────────────────────── + + def _parse_json(self, raw: str) -> dict: + raw = raw.strip() + if raw.startswith("```"): + parts = raw.split("```") + raw = parts[1].lstrip("json").strip() + return json.loads(raw) + + # ── Synchronous API ──────────────────────────────────────────────────────── + + def _post_ollama_sync(self, prompt: str) -> str: + resp = requests.post( + self._ollama_url, + json={"model": "mistral", "prompt": prompt, "stream": False}, + timeout=120, + ) + resp.raise_for_status() + return resp.json()["response"].strip() + + def extract_canonical(self) -> dict: + """ + Synchronous canonical extraction. + + Returns a dict of category -> {value, evidence, confidence} or None. + All categories in CANONICAL_CATEGORIES that appear in the transcript + are populated. Missing categories are null. + """ + raw = self._post_ollama_sync(self._build_canonical_prompt()) + try: + return self._parse_json(raw) + except json.JSONDecodeError: + return {} + + def map_to_template(self, canonical: dict, template_fields: dict) -> dict: + """ + Maps a canonical extraction to a specific template. + Returns field -> value dict ready for Filler.fill_form_with_data(). + """ + raw = self._post_ollama_sync(self._build_mapping_prompt(canonical, template_fields)) + try: + return self._parse_json(raw) + except json.JSONDecodeError: + return {f: None for f in template_fields} + + # ── Async API (used by POST /forms/fill/batch) ──────────────────────────── + + async def _post_ollama_async(self, client, prompt: str) -> str: + import httpx + resp = await client.post( + self._ollama_url, + json={"model": "mistral", "prompt": prompt, "stream": False}, + ) + resp.raise_for_status() + return resp.json()["response"].strip() + + async def async_extract_canonical(self) -> dict: + """ + Async canonical extraction via httpx.AsyncClient. + Identical semantics to extract_canonical() but non-blocking. + """ + import httpx + async with httpx.AsyncClient(timeout=180.0) as client: + raw = await self._post_ollama_async(client, self._build_canonical_prompt()) + try: + return self._parse_json(raw) + except json.JSONDecodeError: + return {} + + async def async_map_to_template( + self, client, canonical: dict, template_fields: dict + ) -> dict: + """ + Async template mapping. Designed to be called concurrently with + asyncio.gather() across multiple templates after a single canonical + extraction, so M agency forms are filled in O(1 + M) LLM calls + instead of O(M * F) where F is the number of fields per form. + """ + raw = await self._post_ollama_async( + client, self._build_mapping_prompt(canonical, template_fields) + ) + try: + return self._parse_json(raw) + except json.JSONDecodeError: + return {f: None for f in template_fields} + + # ── Evidence report ──────────────────────────────────────────────────────── + + @staticmethod + def build_evidence_report(canonical: dict) -> dict: + """ + Transforms the raw canonical extraction into a clean evidence report + keyed by canonical category. Only includes fields where a value was + successfully extracted. Used by GET /forms/batches/{id}/audit. + + Returns: + { + "reporting_officer": { + "value": "Officer Smith", + "evidence": "Officer Smith reporting from unit 4", + "confidence": "high" + }, + ... + } + """ + return { + k: v + for k, v in canonical.items() + if v is not None + and isinstance(v, dict) + and v.get("value") is not None + } diff --git a/src/file_manipulator.py b/src/file_manipulator.py index b7815cc..63b1cb0 100644 --- a/src/file_manipulator.py +++ b/src/file_manipulator.py @@ -1,7 +1,6 @@ import os from src.filler import Filler from src.llm import LLM -from commonforms import prepare_form class FileManipulator: @@ -12,7 +11,9 @@ def __init__(self): def create_template(self, pdf_path: str): """ By using commonforms, we create an editable .pdf template and we store it. + Lazy import prevents ultralytics/YOLO from loading during test collection. """ + from commonforms import prepare_form # lazy import template_path = pdf_path[:-4] + "_template.pdf" prepare_form(pdf_path, template_path) return template_path diff --git a/src/filler.py b/src/filler.py index e31e535..3c51607 100644 --- a/src/filler.py +++ b/src/filler.py @@ -19,8 +19,11 @@ def fill_form(self, pdf_form: str, llm: LLM): + "_filled.pdf" ) - # Generate dictionary of answers from your original function - t2j = llm.main_loop() + # Generate dictionary of answers from your original function. + # main_loop_batch() extracts all fields in a single LLM call instead of + # one call per field, significantly reducing latency for large forms. + # Falls back to the sequential main_loop() if the LLM returns invalid JSON. + t2j = llm.main_loop_batch() textbox_answers = t2j.get_data() # This is a dictionary answers_list = list(textbox_answers.values()) @@ -50,3 +53,52 @@ def fill_form(self, pdf_form: str, llm: LLM): # Your main.py expects this function to return the path return output_pdf + + def fill_form_with_data(self, pdf_form: str, data: dict) -> str: + """ + Fill a PDF form using a pre-extracted field → value mapping. + + Unlike fill_form(), which calls the LLM itself and fills fields + positionally, this method accepts an already-extracted ``data`` dict + and matches values to form annotations by field name (the ``T`` + annotation key). This makes it suitable for the batch pipeline where + extraction has already happened upstream (canonical pass + mapping + pass) and only the PDF write step remains. + + Field names that appear in ``data`` but not in the PDF are silently + skipped. Fields present in the PDF but absent from ``data`` are left + blank. + + Args: + pdf_form: Path to the blank PDF template. + data: A ``{field_name: value}`` dict as returned by + ``IncidentExtractor.async_map_to_template()``. + + Returns: + Path to the written output PDF. + """ + output_pdf = ( + pdf_form[:-4] + + "_" + + datetime.now().strftime("%Y%m%d_%H%M%S") + + "_filled.pdf" + ) + + pdf = PdfReader(pdf_form) + + for page in pdf.pages: + if page.Annots: + for annot in page.Annots: + if annot.Subtype == "/Widget" and annot.T: + # Strip surrounding parentheses that pdfrw adds to strings + field_name = annot.T.strip("()") + if field_name in data and data[field_name] is not None: + value = data[field_name] + if isinstance(value, list): + value = ", ".join(str(v) for v in value) + annot.V = str(value) + annot.AP = None + + PdfWriter().write(output_pdf, pdf) + return output_pdf + diff --git a/src/llm.py b/src/llm.py index 70937f9..2be2417 100644 --- a/src/llm.py +++ b/src/llm.py @@ -131,5 +131,95 @@ def handle_plural_values(self, plural_value): return values + def build_batch_prompt(self): + """ + Builds a single prompt that asks the LLM to extract ALL target fields + at once and return them as a JSON object. + This replaces N sequential API calls with a single round-trip. + """ + fields_list = json.dumps(list(self._target_fields.keys()), indent=2) + prompt = f""" +SYSTEM PROMPT: +You are an AI assistant that extracts structured data from incident transcriptions. +Extract values for ALL of the following JSON fields from the text below. +Return ONLY a valid JSON object with no extra explanation, commentary, or markdown fences. +If a field is plural and multiple values exist in the text, use a list of strings. +If a value cannot be found in the text, use null. + +FIELDS TO EXTRACT: +{fields_list} + +TEXT: +{self._transcript_text} + +OUTPUT FORMAT: +{{ + "field_name": "extracted value or null", + ... +}} +""" + return prompt + + def main_loop_batch(self): + """ + Single-call extraction — replaces the N sequential calls in main_loop(). + Sends one prompt containing all target fields and parses the JSON response. + Falls back to main_loop() if the LLM does not return valid JSON. + """ + prompt = self.build_batch_prompt() + ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/") + ollama_url = f"{ollama_host}/api/generate" + + payload = { + "model": "mistral", + "prompt": prompt, + "stream": False, + } + + try: + response = requests.post(ollama_url, json=payload) + response.raise_for_status() + except requests.exceptions.ConnectionError: + raise ConnectionError( + f"Could not connect to Ollama at {ollama_url}. " + "Please ensure Ollama is running and accessible." + ) + except requests.exceptions.HTTPError as e: + raise RuntimeError(f"Ollama returned an error: {e}") + + raw = response.json()["response"].strip() + + # Strip markdown code fences if the model wrapped the output + if raw.startswith("```"): + parts = raw.split("```") + # parts[1] is the fenced block; drop a leading "json" language tag if present + raw = parts[1].lstrip("json").strip() + + try: + extracted = json.loads(raw) + except json.JSONDecodeError as e: + print( + f"\t[WARN] main_loop_batch: LLM did not return valid JSON ({e}). " + "Falling back to sequential main_loop()." + ) + return self.main_loop() + + # Populate self._json using the existing add_response_to_json logic + for field in self._target_fields.keys(): + value = extracted.get(field) + if value is None: + self.add_response_to_json(field, "-1") + elif isinstance(value, list): + self.add_response_to_json(field, "; ".join(str(v) for v in value)) + else: + self.add_response_to_json(field, str(value)) + + print("----------------------------------") + print("\t[LOG] Resulting JSON created from the input text (batch mode):") + print(json.dumps(self._json, indent=2)) + print("--------- extracted data ---------") + + return self + def get_data(self): return self._json diff --git a/tests/conftest.py b/tests/conftest.py index 7cb4db3..260a88a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ from api.main import app from api.deps import get_db -from api.db.models import Template, FormSubmission +from api.db.models import Template, FormSubmission, FillJob, BatchSubmission # In-memory SQLite database for tests TEST_DATABASE_URL = "sqlite://" diff --git a/tests/test_batch.py b/tests/test_batch.py new file mode 100644 index 0000000..86d1a1f --- /dev/null +++ b/tests/test_batch.py @@ -0,0 +1,494 @@ +""" +Tests for POST /forms/fill/batch, GET /forms/batches/{id}, +and GET /forms/batches/{id}/audit. + +Mocking strategy: + - IncidentExtractor.async_extract_canonical → returns a minimal canonical dict + - IncidentExtractor.async_map_to_template → returns a minimal field-value dict + - Filler.fill_form_with_data → returns a deterministic output path + - IncidentExtractor.build_evidence_report → returns filtered canonical fields + +Templates are created through the real /templates/create endpoint (with +Controller mocked) so test IDs are stable and foreign-key constraints hold. +""" + +from unittest.mock import patch, AsyncMock, MagicMock + + +# ── Shared canonical fixture ─────────────────────────────────────────────────── + +CANONICAL = { + "reporting_officer": { + "value": "Officer Jane Smith", + "evidence": "Officer Jane Smith reporting.", + "confidence": "high", + }, + "incident_location": { + "value": "123 Main St", + "evidence": "incident occurred at 123 Main St", + "confidence": "high", + }, + "victim_names": { + "value": ["Alice", "Bob"], + "evidence": "victims Alice and Bob", + "confidence": "high", + }, + "incident_type": { + "value": None, + "evidence": None, + "confidence": "low", + }, +} + +EVIDENCE_REPORT = {k: v for k, v in CANONICAL.items() if v["value"] is not None} + +MAPPED_FIELDS = { + "reporting_officer": "Officer Jane Smith", + "incident_location": "123 Main St", +} + +TRANSCRIPT = ( + "Officer Jane Smith reporting. Incident occurred at 123 Main St. " + "Two victims Alice and Bob sustained minor injuries." +) + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +def create_template(client, name="Agency A", pdf_path="src/inputs/file.pdf"): + """Create a template and return its id. Mocks Controller to avoid FS access.""" + with patch("api.routes.templates.Controller") as MockCtrl: + MockCtrl.return_value.create_template.return_value = pdf_path + payload = { + "name": name, + "pdf_path": pdf_path, + "fields": {"reporting_officer": "string", "incident_location": "string"}, + } + res = client.post("/templates/create", json=payload) + assert res.status_code == 200, res.text + return res.json()["id"] + + +def _mock_extractor(): + """ + Return a context-manager-compatible patch for IncidentExtractor + that simulates a successful canonical extraction + mapping. + """ + mock_instance = MagicMock() + mock_instance.async_extract_canonical = AsyncMock(return_value=CANONICAL) + mock_instance.async_map_to_template = AsyncMock(return_value=MAPPED_FIELDS) + return mock_instance + + +# ── Batch fill — success paths ───────────────────────────────────────────────── + +def test_batch_fill_single_template_complete(client): + tid = create_template(client, "Agency A") + mock_inst = _mock_extractor() + + with ( + patch("api.routes.batch.IncidentExtractor", return_value=mock_inst), + patch("api.routes.batch.IncidentExtractor.build_evidence_report", + return_value=EVIDENCE_REPORT), + patch("src.filler.Filler.fill_form_with_data", return_value="src/outputs/out.pdf"), + ): + res = client.post("/forms/fill/batch", json={ + "input_text": TRANSCRIPT, + "template_ids": [tid], + }) + + assert res.status_code == 200 + data = res.json() + assert data["status"] == "complete" + assert data["total_requested"] == 1 + assert data["total_succeeded"] == 1 + assert data["total_failed"] == 0 + assert len(data["results"]) == 1 + assert data["results"][0]["status"] == "complete" + assert data["results"][0]["template_id"] == tid + assert "batch_id" in data + assert data["batch_id"] + + +def test_batch_fill_multiple_templates_complete(client): + tid1 = create_template(client, "Agency B", "src/inputs/b.pdf") + tid2 = create_template(client, "Agency C", "src/inputs/c.pdf") + mock_inst = _mock_extractor() + + with ( + patch("api.routes.batch.IncidentExtractor", return_value=mock_inst), + patch("api.routes.batch.IncidentExtractor.build_evidence_report", + return_value=EVIDENCE_REPORT), + patch("src.filler.Filler.fill_form_with_data", return_value="src/outputs/out.pdf"), + ): + res = client.post("/forms/fill/batch", json={ + "input_text": TRANSCRIPT, + "template_ids": [tid1, tid2], + }) + + assert res.status_code == 200 + data = res.json() + assert data["status"] == "complete" + assert data["total_requested"] == 2 + assert data["total_succeeded"] == 2 + assert data["total_failed"] == 0 + result_ids = {r["template_id"] for r in data["results"]} + assert result_ids == {tid1, tid2} + + +def test_batch_fill_returns_evidence_report(client): + tid = create_template(client, "Agency Evidence") + mock_inst = _mock_extractor() + + with ( + patch("api.routes.batch.IncidentExtractor", return_value=mock_inst), + patch("api.routes.batch.IncidentExtractor.build_evidence_report", + return_value=EVIDENCE_REPORT), + patch("src.filler.Filler.fill_form_with_data", return_value="src/outputs/out.pdf"), + ): + res = client.post("/forms/fill/batch", json={ + "input_text": TRANSCRIPT, + "template_ids": [tid], + }) + + data = res.json() + er = data.get("evidence_report") + assert er is not None + # reporting_officer has a non-null value → should appear in evidence_report + assert "reporting_officer" in er + assert er["reporting_officer"]["value"] == "Officer Jane Smith" + assert er["reporting_officer"]["evidence"] == "Officer Jane Smith reporting." + # incident_type has value=None → must NOT appear in evidence_report + assert "incident_type" not in er + + +def test_batch_fill_each_result_has_submission_id(client): + tid = create_template(client, "Agency Sub") + mock_inst = _mock_extractor() + + with ( + patch("api.routes.batch.IncidentExtractor", return_value=mock_inst), + patch("api.routes.batch.IncidentExtractor.build_evidence_report", + return_value=EVIDENCE_REPORT), + patch("src.filler.Filler.fill_form_with_data", return_value="src/outputs/out.pdf"), + ): + res = client.post("/forms/fill/batch", json={ + "input_text": TRANSCRIPT, + "template_ids": [tid], + }) + + result = res.json()["results"][0] + assert result["submission_id"] is not None + assert isinstance(result["submission_id"], int) + + +# ── Batch fill — validation / error paths ──────────────────────────────────── + +def test_batch_fill_404_on_unknown_template(client): + res = client.post("/forms/fill/batch", json={ + "input_text": TRANSCRIPT, + "template_ids": [999999], + }) + assert res.status_code == 404 + + +def test_batch_fill_422_empty_template_ids(client): + res = client.post("/forms/fill/batch", json={ + "input_text": TRANSCRIPT, + "template_ids": [], + }) + assert res.status_code == 422 + + +def test_batch_fill_422_duplicate_template_ids(client): + tid = create_template(client, "Agency Dup") + res = client.post("/forms/fill/batch", json={ + "input_text": TRANSCRIPT, + "template_ids": [tid, tid], + }) + assert res.status_code == 422 + + +def test_batch_fill_422_too_many_template_ids(client): + # 21 IDs > hard limit of 20 + res = client.post("/forms/fill/batch", json={ + "input_text": TRANSCRIPT, + "template_ids": list(range(1, 22)), + }) + assert res.status_code == 422 + + +def test_batch_fill_422_missing_input_text(client): + tid = create_template(client, "Agency Miss") + res = client.post("/forms/fill/batch", json={"template_ids": [tid]}) + assert res.status_code == 422 + + +# ── Partial failure ──────────────────────────────────────────────────────────── + +def test_batch_fill_partial_failure_when_one_pdf_fill_fails(client): + """When one PDF fill raises, the other succeeds and status is 'partial'.""" + tid1 = create_template(client, "Agency Partial-Good", "src/inputs/good.pdf") + tid2 = create_template(client, "Agency Partial-Bad", "src/inputs/bad.pdf") + mock_inst = _mock_extractor() + + call_count = {"n": 0} + + def _fill_side_effect(pdf_path, data): + call_count["n"] += 1 + if "bad" in pdf_path: + raise RuntimeError("PDF fill failed") + return "src/outputs/good_out.pdf" + + with ( + patch("api.routes.batch.IncidentExtractor", return_value=mock_inst), + patch("api.routes.batch.IncidentExtractor.build_evidence_report", + return_value=EVIDENCE_REPORT), + patch("src.filler.Filler.fill_form_with_data", side_effect=_fill_side_effect), + ): + res = client.post("/forms/fill/batch", json={ + "input_text": TRANSCRIPT, + "template_ids": [tid1, tid2], + }) + + assert res.status_code == 200 + data = res.json() + assert data["status"] == "partial" + assert data["total_succeeded"] == 1 + assert data["total_failed"] == 1 + + statuses = {r["template_id"]: r["status"] for r in data["results"]} + # The bad template should have status "failed" + assert statuses[tid2] == "failed" + assert statuses[tid1] == "complete" + + +def test_batch_fill_all_failed_status(client): + """When every PDF fill raises, status must be 'failed' (not 'partial').""" + tid = create_template(client, "Agency AllFail") + mock_inst = _mock_extractor() + + with ( + patch("api.routes.batch.IncidentExtractor", return_value=mock_inst), + patch("api.routes.batch.IncidentExtractor.build_evidence_report", + return_value=EVIDENCE_REPORT), + patch("src.filler.Filler.fill_form_with_data", + side_effect=RuntimeError("disk full")), + ): + res = client.post("/forms/fill/batch", json={ + "input_text": TRANSCRIPT, + "template_ids": [tid], + }) + + assert res.status_code == 200 + data = res.json() + assert data["status"] == "failed" + assert data["total_succeeded"] == 0 + assert data["total_failed"] == 1 + + +# ── GET /forms/batches/{id} ──────────────────────────────────────────────────── + +def test_get_batch_status_200(client): + tid = create_template(client, "Agency Status") + mock_inst = _mock_extractor() + + with ( + patch("api.routes.batch.IncidentExtractor", return_value=mock_inst), + patch("api.routes.batch.IncidentExtractor.build_evidence_report", + return_value=EVIDENCE_REPORT), + patch("src.filler.Filler.fill_form_with_data", return_value="src/outputs/out.pdf"), + ): + fill_res = client.post("/forms/fill/batch", json={ + "input_text": TRANSCRIPT, + "template_ids": [tid], + }) + + batch_id = fill_res.json()["batch_id"] + status_res = client.get(f"/forms/batches/{batch_id}") + assert status_res.status_code == 200 + data = status_res.json() + assert data["id"] == batch_id + assert data["status"] == "complete" + assert isinstance(data["template_ids"], list) + + +def test_get_batch_status_404_unknown(client): + res = client.get("/forms/batches/nonexistent-batch-id") + assert res.status_code == 404 + + +def test_get_batch_status_has_output_paths(client): + tid = create_template(client, "Agency Paths") + mock_inst = _mock_extractor() + + with ( + patch("api.routes.batch.IncidentExtractor", return_value=mock_inst), + patch("api.routes.batch.IncidentExtractor.build_evidence_report", + return_value=EVIDENCE_REPORT), + patch("src.filler.Filler.fill_form_with_data", return_value="src/outputs/paths_out.pdf"), + ): + fill_res = client.post("/forms/fill/batch", json={ + "input_text": TRANSCRIPT, + "template_ids": [tid], + }) + + batch_id = fill_res.json()["batch_id"] + data = client.get(f"/forms/batches/{batch_id}").json() + assert data["output_paths"] is not None + assert str(tid) in data["output_paths"] + + +# ── GET /forms/batches/{id}/audit ───────────────────────────────────────────── + +def test_get_batch_audit_200(client): + tid = create_template(client, "Agency Audit") + mock_inst = _mock_extractor() + + with ( + patch("api.routes.batch.IncidentExtractor", return_value=mock_inst), + patch("api.routes.batch.IncidentExtractor.build_evidence_report", + return_value=EVIDENCE_REPORT), + patch("src.filler.Filler.fill_form_with_data", return_value="src/outputs/out.pdf"), + ): + fill_res = client.post("/forms/fill/batch", json={ + "input_text": TRANSCRIPT, + "template_ids": [tid], + }) + + batch_id = fill_res.json()["batch_id"] + audit_res = client.get(f"/forms/batches/{batch_id}/audit") + assert audit_res.status_code == 200 + data = audit_res.json() + assert data["batch_id"] == batch_id + assert data["input_text"] == TRANSCRIPT + assert data["canonical_extraction"] is not None + assert data["evidence_report"] is not None + + +def test_get_batch_audit_canonical_has_evidence_fields(client): + """Every key in canonical_extraction must carry value/evidence/confidence.""" + tid = create_template(client, "Agency AuditEvidence") + mock_inst = _mock_extractor() + + with ( + patch("api.routes.batch.IncidentExtractor", return_value=mock_inst), + patch("api.routes.batch.IncidentExtractor.build_evidence_report", + return_value=EVIDENCE_REPORT), + patch("src.filler.Filler.fill_form_with_data", return_value="src/outputs/out.pdf"), + ): + fill_res = client.post("/forms/fill/batch", json={ + "input_text": TRANSCRIPT, + "template_ids": [tid], + }) + + batch_id = fill_res.json()["batch_id"] + data = client.get(f"/forms/batches/{batch_id}/audit").json() + canonical = data["canonical_extraction"] + + for field, content in canonical.items(): + assert "value" in content, f"Missing 'value' in canonical field '{field}'" + assert "evidence" in content, f"Missing 'evidence' in canonical field '{field}'" + assert "confidence" in content, f"Missing 'confidence' in canonical field '{field}'" + + +def test_get_batch_audit_evidence_report_excludes_null_fields(client): + """evidence_report must contain only fields where value is not None.""" + tid = create_template(client, "Agency AuditNull") + mock_inst = _mock_extractor() + + with ( + patch("api.routes.batch.IncidentExtractor", return_value=mock_inst), + patch("api.routes.batch.IncidentExtractor.build_evidence_report", + return_value=EVIDENCE_REPORT), + patch("src.filler.Filler.fill_form_with_data", return_value="src/outputs/out.pdf"), + ): + fill_res = client.post("/forms/fill/batch", json={ + "input_text": TRANSCRIPT, + "template_ids": [tid], + }) + + batch_id = fill_res.json()["batch_id"] + data = client.get(f"/forms/batches/{batch_id}/audit").json() + evidence = data["evidence_report"] + + # incident_type has value=None in our CANONICAL fixture → must not be in evidence + assert "incident_type" not in evidence + # reporting_officer has value → must be present + assert "reporting_officer" in evidence + + +def test_get_batch_audit_404_unknown(client): + res = client.get("/forms/batches/no-such-batch/audit") + assert res.status_code == 404 + + +# ── Unit tests for BatchFill validator ──────────────────────────────────────── + +def test_batch_fill_schema_rejects_empty_list(): + from pydantic import ValidationError + from api.schemas.batch import BatchFill + import pytest + + with pytest.raises(ValidationError): + BatchFill(input_text="x", template_ids=[]) + + +def test_batch_fill_schema_rejects_duplicates(): + from pydantic import ValidationError + from api.schemas.batch import BatchFill + import pytest + + with pytest.raises(ValidationError): + BatchFill(input_text="x", template_ids=[1, 1]) + + +def test_batch_fill_schema_rejects_over_limit(): + from pydantic import ValidationError + from api.schemas.batch import BatchFill + import pytest + + with pytest.raises(ValidationError): + BatchFill(input_text="x", template_ids=list(range(1, 22))) + + +def test_batch_fill_schema_accepts_valid(): + from api.schemas.batch import BatchFill + + b = BatchFill(input_text="test", template_ids=[1, 2, 3]) + assert b.template_ids == [1, 2, 3] + + +# ── Unit tests for IncidentExtractor.build_evidence_report ─────────────────── + +def test_build_evidence_report_filters_nulls(): + from src.extractor import IncidentExtractor + + canonical = { + "field_a": {"value": "present", "evidence": "quote", "confidence": "high"}, + "field_b": {"value": None, "evidence": None, "confidence": "low"}, + "field_c": {"value": ["Alice", "Bob"], "evidence": "seen Alice and Bob", "confidence": "medium"}, + } + report = IncidentExtractor.build_evidence_report(canonical) + # field_a has a non-None value → should be included + assert "field_a" in report + # field_b value is None → excluded + assert "field_b" not in report + # field_c value is a non-None list → should be included + assert "field_c" in report + + +def test_build_evidence_report_preserves_structure(): + from src.extractor import IncidentExtractor + + canonical = { + "reporting_officer": { + "value": "Sgt. Davis", + "evidence": "Sgt. Davis at the scene.", + "confidence": "high", + } + } + report = IncidentExtractor.build_evidence_report(canonical) + assert report["reporting_officer"]["value"] == "Sgt. Davis" + assert report["reporting_officer"]["evidence"] == "Sgt. Davis at the scene." + assert report["reporting_officer"]["confidence"] == "high" diff --git a/tests/test_forms.py b/tests/test_forms.py index 8f432bf..756bb1b 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -1,25 +1,69 @@ +from unittest.mock import patch + + def test_submit_form(client): - pass - # First create a template - # form_payload = { - # "template_id": 3, - # "input_text": "Hi. The employee's name is John Doe. His job title is managing director. His department supervisor is Jane Doe. His phone number is 123456. His email is jdoe@ucsc.edu. The signature is , and the date is 01/02/2005", - # } - - # template_res = client.post("/templates/", json=template_payload) - # template_id = template_res.json()["id"] - - # # Submit a form - # form_payload = { - # "template_id": template_id, - # "data": {"rating": 5, "comment": "Great service"}, - # } - - # response = client.post("/forms/", json=form_payload) - - # assert response.status_code == 200 - - # data = response.json() - # assert data["id"] is not None - # assert data["template_id"] == template_id - # assert data["data"] == form_payload["data"] + # Step 1: Create a template first + with patch("api.routes.templates.Controller") as MockController: + MockController.return_value.create_template.return_value = "src/inputs/file_template.pdf" + + template_payload = { + "name": "Test Template", + "pdf_path": "src/inputs/file.pdf", + "fields": { + "reporting_officer": "string", + "incident_location": "string", + "amount_of_victims": "string", + "victim_name_s": "string", + "assisting_officer": "string", + }, + } + template_res = client.post("/templates/create", json=template_payload) + assert template_res.status_code == 200 + template_id = template_res.json()["id"] + + # Step 2: Fill form using that template + with patch("api.routes.forms.Controller") as MockController: + MockController.return_value.fill_form.return_value = "src/outputs/filled_test.pdf" + + form_payload = { + "template_id": template_id, + "input_text": ( + "Officer Voldemort here, at an incident reported at 456 Oak Street. " + "Two victims, Mark Smith and Jane Doe. " + "Handed off to Sheriff's Deputy Alvarez. End of transmission." + ), + } + + response = client.post("/forms/fill", json=form_payload) + + assert response.status_code == 200 + data = response.json() + assert data["template_id"] == template_id + assert data["output_pdf_path"] == "src/outputs/filled_test.pdf" + assert data["input_text"] == form_payload["input_text"] + assert "id" in data + + +def test_submit_form_invalid_template(client): + with patch("api.routes.forms.Controller") as MockController: + MockController.return_value.fill_form.return_value = "src/outputs/filled_test.pdf" + + form_payload = { + "template_id": 99999, + "input_text": "Some random incident text here.", + } + + response = client.post("/forms/fill", json=form_payload) + assert response.status_code == 404 + + +def test_submit_form_missing_input_text(client): + with patch("api.routes.forms.Controller") as MockController: + MockController.return_value.fill_form.return_value = "src/outputs/filled_test.pdf" + + form_payload = { + "template_id": 1, + } + + response = client.post("/forms/fill", json=form_payload) + assert response.status_code == 422 diff --git a/tests/test_llm.py b/tests/test_llm.py new file mode 100644 index 0000000..bfd1b05 --- /dev/null +++ b/tests/test_llm.py @@ -0,0 +1,177 @@ +import json +from unittest.mock import patch, MagicMock +from src.llm import LLM + + +SAMPLE_TRANSCRIPT = ( + "Officer Voldemort here, at an incident reported at 456 Oak Street. " + "Two victims, Mark Smith and Jane Doe. " + "Handed off to Sheriff's Deputy Alvarez. End of transmission." +) + +SAMPLE_FIELDS = { + "reporting_officer": "string", + "incident_location": "string", + "victim_name_s": "string", + "assisting_officer": "string", +} + + +def _make_mock_response(payload: dict) -> MagicMock: + """Helper: build a mock requests.Response that returns payload as JSON.""" + mock_resp = MagicMock() + mock_resp.json.return_value = {"response": json.dumps(payload)} + mock_resp.raise_for_status = MagicMock() + return mock_resp + + +# --------------------------------------------------------------------------- +# build_batch_prompt +# --------------------------------------------------------------------------- + +def test_build_batch_prompt_contains_all_fields(): + llm = LLM(transcript_text=SAMPLE_TRANSCRIPT, target_fields=SAMPLE_FIELDS) + prompt = llm.build_batch_prompt() + + for field in SAMPLE_FIELDS: + assert field in prompt, f"Expected field '{field}' in batch prompt" + + assert SAMPLE_TRANSCRIPT in prompt + + +def test_build_batch_prompt_contains_transcript(): + llm = LLM(transcript_text=SAMPLE_TRANSCRIPT, target_fields=SAMPLE_FIELDS) + prompt = llm.build_batch_prompt() + assert SAMPLE_TRANSCRIPT in prompt + + +# --------------------------------------------------------------------------- +# main_loop_batch — happy path +# --------------------------------------------------------------------------- + +def test_main_loop_batch_single_api_call(): + """main_loop_batch must call the Ollama API exactly once, regardless of field count.""" + llm_response = { + "reporting_officer": "Officer Voldemort", + "incident_location": "456 Oak Street", + "victim_name_s": ["Mark Smith", "Jane Doe"], + "assisting_officer": "Deputy Alvarez", + } + + with patch("requests.post", return_value=_make_mock_response(llm_response)) as mock_post: + llm = LLM(transcript_text=SAMPLE_TRANSCRIPT, target_fields=SAMPLE_FIELDS) + llm.main_loop_batch() + + assert mock_post.call_count == 1, ( + f"Expected exactly 1 API call, got {mock_post.call_count}. " + "main_loop_batch should not loop per-field." + ) + + +def test_main_loop_batch_populates_all_fields(): + llm_response = { + "reporting_officer": "Officer Voldemort", + "incident_location": "456 Oak Street", + "victim_name_s": None, # missing value + "assisting_officer": "Deputy Alvarez", + } + + with patch("requests.post", return_value=_make_mock_response(llm_response)): + llm = LLM(transcript_text=SAMPLE_TRANSCRIPT, target_fields=SAMPLE_FIELDS) + result = llm.main_loop_batch().get_data() + + assert result["reporting_officer"] == "Officer Voldemort" + assert result["incident_location"] == "456 Oak Street" + assert result["victim_name_s"] is None # null maps to None + assert result["assisting_officer"] == "Deputy Alvarez" + + +def test_main_loop_batch_handles_list_values(): + """Plural values returned as a JSON list should be joined into '; ' separated string.""" + llm_response = { + "reporting_officer": "Officer Voldemort", + "incident_location": "456 Oak Street", + "victim_name_s": ["Mark Smith", "Jane Doe"], + "assisting_officer": "Deputy Alvarez", + } + + with patch("requests.post", return_value=_make_mock_response(llm_response)): + llm = LLM(transcript_text=SAMPLE_TRANSCRIPT, target_fields=SAMPLE_FIELDS) + result = llm.main_loop_batch().get_data() + + assert result["victim_name_s"] == ["Mark Smith", "Jane Doe"] + + +# --------------------------------------------------------------------------- +# main_loop_batch — markdown code-fence stripping +# --------------------------------------------------------------------------- + +def test_main_loop_batch_strips_markdown_fences(): + raw_with_fences = ( + "```json\n" + + json.dumps({ + "reporting_officer": "Officer Voldemort", + "incident_location": "456 Oak Street", + "victim_name_s": None, + "assisting_officer": "Deputy Alvarez", + }) + + "\n```" + ) + + mock_resp = MagicMock() + mock_resp.json.return_value = {"response": raw_with_fences} + mock_resp.raise_for_status = MagicMock() + + with patch("requests.post", return_value=mock_resp): + llm = LLM(transcript_text=SAMPLE_TRANSCRIPT, target_fields=SAMPLE_FIELDS) + result = llm.main_loop_batch().get_data() + + assert result["reporting_officer"] == "Officer Voldemort" + + +# --------------------------------------------------------------------------- +# main_loop_batch — fallback to sequential main_loop on bad JSON +# --------------------------------------------------------------------------- + +def test_main_loop_batch_falls_back_on_invalid_json(): + """If the LLM returns garbage instead of JSON, fall back to main_loop().""" + bad_resp = MagicMock() + bad_resp.json.return_value = {"response": "Sorry, I cannot help with that."} + bad_resp.raise_for_status = MagicMock() + + with patch("requests.post", return_value=bad_resp): + with patch.object(LLM, "main_loop", return_value=MagicMock()) as mock_fallback: + llm = LLM(transcript_text=SAMPLE_TRANSCRIPT, target_fields=SAMPLE_FIELDS) + llm.main_loop_batch() + mock_fallback.assert_called_once() + + +# --------------------------------------------------------------------------- +# main_loop_batch vs main_loop — call count comparison +# --------------------------------------------------------------------------- + +def test_main_loop_batch_fewer_calls_than_main_loop(): + """ + Explicitly show that main_loop_batch makes 1 call while main_loop + makes len(fields) calls — the core performance improvement. + """ + n_fields = len(SAMPLE_FIELDS) + llm_response = {k: "value" for k in SAMPLE_FIELDS} + + with patch("requests.post", return_value=_make_mock_response(llm_response)) as mock_post: + llm = LLM(transcript_text=SAMPLE_TRANSCRIPT, target_fields=SAMPLE_FIELDS) + llm.main_loop_batch() + batch_calls = mock_post.call_count + + single_resp = MagicMock() + single_resp.json.return_value = {"response": "some value"} + single_resp.raise_for_status = MagicMock() + + with patch("requests.post", return_value=single_resp) as mock_post: + llm2 = LLM(transcript_text=SAMPLE_TRANSCRIPT, target_fields=SAMPLE_FIELDS) + llm2.main_loop() + sequential_calls = mock_post.call_count + + assert batch_calls == 1 + assert sequential_calls == n_fields + assert batch_calls < sequential_calls diff --git a/tests/test_templates.py b/tests/test_templates.py index bbced2b..fef0f86 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,18 +1,54 @@ +from unittest.mock import patch + + def test_create_template(client): - payload = { - "name": "Template 1", - "pdf_path": "src/inputs/file.pdf", - "fields": { - "Employee's name": "string", - "Employee's job title": "string", - "Employee's department supervisor": "string", - "Employee's phone number": "string", - "Employee's email": "string", - "Signature": "string", - "Date": "string", - }, - } - - response = client.post("/templates/create", json=payload) - - assert response.status_code == 200 + with patch("api.routes.templates.Controller") as MockController: + MockController.return_value.create_template.return_value = "src/inputs/file_template.pdf" + + payload = { + "name": "Template 1", + "pdf_path": "src/inputs/file.pdf", + "fields": { + "Employee's name": "string", + "Employee's job title": "string", + "Employee's department supervisor": "string", + "Employee's phone number": "string", + "Employee's email": "string", + "Signature": "string", + "Date": "string", + }, + } + + response = client.post("/templates/create", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["name"] == "Template 1" + assert data["pdf_path"] == "src/inputs/file_template.pdf" + assert "id" in data + + +def test_create_template_missing_name(client): + with patch("api.routes.templates.Controller") as MockController: + MockController.return_value.create_template.return_value = "src/inputs/file_template.pdf" + + payload = { + "pdf_path": "src/inputs/file.pdf", + "fields": {"Employee's name": "string"}, + } + + response = client.post("/templates/create", json=payload) + assert response.status_code == 422 + + +def test_create_template_missing_fields(client): + with patch("api.routes.templates.Controller") as MockController: + MockController.return_value.create_template.return_value = "src/inputs/file_template.pdf" + + payload = { + "name": "Bad Template", + "pdf_path": "src/inputs/file.pdf", + } + + response = client.post("/templates/create", json=payload) + assert response.status_code == 422