diff --git a/Dockerfile b/Dockerfile
index b2949405..0f241e7c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -20,6 +20,7 @@ COPY collector_manager ./collector_manager
 COPY core ./core
 COPY html_tag_collector ./html_tag_collector
 COPY hugging_face/url_relevance ./hugging_face/url_relevance
+COPY hugging_face/url_record_type_labeling ./hugging_face/url_record_type_labeling
 COPY hugging_face/HuggingFaceInterface.py ./hugging_face/HuggingFaceInterface.py
 COPY source_collectors ./source_collectors
 COPY util ./util
@@ -28,6 +29,8 @@ COPY apply_migrations.py ./apply_migrations.py
 COPY security_manager ./security_manager
 COPY execute.sh ./execute.sh
 COPY .project-root ./.project-root
+COPY tests ./tests
+COPY llm_api_logic ./llm_api_logic
 
 # Expose the application port
 EXPOSE 80
@@ -35,4 +38,4 @@ EXPOSE 80
 RUN chmod +x execute.sh
 # Use the below for ease of local development, but remove when pushing to GitHub
 # Because there is no .env file in the repository (for security reasons)
-#COPY .env ./.env
+COPY .env ./.env
diff --git a/ENV.md b/ENV.md
index a8210fb9..943ad293 100644
--- a/ENV.md
+++ b/ENV.md
@@ -16,3 +16,4 @@ Please ensure these are properly defined in a `.env` file in the root directory.
 |`POSTGRES_PORT` | The port for the test database | `5432` |
 |`DS_APP_SECRET_KEY`| The secret key used for decoding JWT tokens produced by the Data Sources App. Must match the secret token that is used in the Data Sources App for encoding. |`abc123`|
 |`DEV`| Set to any value to run the application in development mode. |`true`|
+|`DEEPSEEK_API_KEY`| The API key required for accessing the DeepSeek API. |`abc123`|
diff --git a/api/main.py b/api/main.py
index 356467af..0a9c0249 100644
--- a/api/main.py
+++ b/api/main.py
@@ -6,6 +6,7 @@
 from api.routes.batch import batch_router
 from api.routes.collector import collector_router
 from api.routes.root import root_router
+from api.routes.task import task_router
 from api.routes.url import url_router
 from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
 from collector_db.DatabaseClient import DatabaseClient
@@ -71,8 +72,15 @@ async def setup_database(db_client):
     lifespan=lifespan
 )
 
-app.include_router(root_router)
-app.include_router(collector_router)
-app.include_router(batch_router)
-app.include_router(annotate_router)
-app.include_router(url_router)
\ No newline at end of file
+routers = [
+    root_router,
+    collector_router,
+    batch_router,
+    annotate_router,
+    url_router,
+    task_router
+]
+for router in routers:
+    app.include_router(router)
+
+
diff --git a/api/routes/task.py b/api/routes/task.py
new file mode 100644
index 00000000..d9cdbeac
--- /dev/null
+++ b/api/routes/task.py
@@ -0,0 +1,49 @@
+from typing import Optional
+
+from fastapi import APIRouter, Depends, Query, Path
+
+from api.dependencies import get_async_core
+from collector_db.DTOs.TaskInfo import TaskInfo
+from collector_db.enums import TaskType
+from core.AsyncCore import AsyncCore
+from core.enums import BatchStatus
+from security_manager.SecurityManager import AccessInfo, get_access_info
+
+task_router = APIRouter(
+    prefix="/task",
+    tags=["Task"],
+    responses={404: {"description": "Not found"}},
+)
+
+
+@task_router.get("")
+async def get_tasks(
+    page: int = Query(
+        description="The page number",
+        default=1
+    ),
+    task_status: Optional[BatchStatus] = Query(
+        description="Filter by task status",
+        default=None
+    ),
+    task_type: Optional[TaskType] = Query(
+        description="Filter by task type",
+        default=None
+    ),
+    async_core: AsyncCore = Depends(get_async_core),
+    access_info: AccessInfo = Depends(get_access_info)
+):
+    return await async_core.get_tasks(
+        page=page,
+        task_type=task_type,
+        task_status=task_status
+    )
+
+
+@task_router.get("/{task_id}")
+async def get_task_info(
+    task_id: int = Path(description="The task id"),
+    async_core: AsyncCore = Depends(get_async_core),
+    access_info: AccessInfo = Depends(get_access_info)
+) -> TaskInfo:
+    return await async_core.get_task_info(task_id)
\ No newline at end of file
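
For a quick sanity check of the new routes, a hedged client sketch (assumes a local deployment on port 80 and that `get_access_info` accepts a bearer token; the token value is a placeholder):

```python
import asyncio

import httpx


async def main() -> None:
    headers = {"Authorization": "Bearer <token>"}  # placeholder credentials
    async with httpx.AsyncClient(base_url="http://localhost:80") as client:
        # Page 1 of in-process tasks
        response = await client.get(
            "/task",
            params={"page": 1, "task_status": "in-process"},
            headers=headers,
        )
        response.raise_for_status()
        print(response.json())

        # Detail view for a single task
        response = await client.get("/task/1", headers=headers)
        print(response.json())


asyncio.run(main())
```
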
diff --git a/collector_db/AsyncDatabaseClient.py b/collector_db/AsyncDatabaseClient.py
index db94a8d5..07f1cc10 100644
--- a/collector_db/AsyncDatabaseClient.py
+++ b/collector_db/AsyncDatabaseClient.py
@@ -1,24 +1,31 @@
 from functools import wraps
+from typing import Optional
 
-from sqlalchemy import select, exists
+from sqlalchemy import select, exists, func
 from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
 from sqlalchemy.orm import selectinload
 
 from collector_db.ConfigManager import ConfigManager
 from collector_db.DTOs.MetadataAnnotationInfo import MetadataAnnotationInfo
+from collector_db.DTOs.TaskInfo import TaskInfo
 from collector_db.DTOs.URLAnnotationInfo import URLAnnotationInfo
 from collector_db.DTOs.URLErrorInfos import URLErrorPydanticInfo
 from collector_db.DTOs.URLHTMLContentInfo import URLHTMLContentInfo
+from collector_db.DTOs.URLInfo import URLInfo
 from collector_db.DTOs.URLMetadataInfo import URLMetadataInfo
 from collector_db.DTOs.URLWithHTML import URLWithHTML
-from collector_db.enums import URLMetadataAttributeType, ValidationStatus, ValidationSource
+from collector_db.StatementComposer import StatementComposer
+from collector_db.enums import URLMetadataAttributeType, ValidationStatus, ValidationSource, TaskType
 from collector_db.helper_functions import get_postgres_connection_string
 from collector_db.models import URLMetadata, URL, URLErrorInfo, URLHTMLContent, Base, MetadataAnnotation, \
-    RootURL
+    RootURL, Task, TaskError, LinkTaskURL
 from collector_manager.enums import URLStatus
+from core.DTOs.GetTasksResponse import GetTasksResponse, GetTasksResponseTaskInfo
 from core.DTOs.GetURLsResponseInfo import GetURLsResponseInfo, GetURLsResponseMetadataInfo, GetURLsResponseErrorInfo, \
     GetURLsResponseInnerInfo
 from core.DTOs.RelevanceAnnotationInfo import RelevanceAnnotationPostInfo
+from core.enums import BatchStatus
+
 
 def add_standard_limit_and_offset(statement, page, limit=100):
     offset = (page - 1) * limit
@@ -31,6 +38,14 @@ def __init__(self, db_url: str = get_postgres_connection_string(is_async=True)):
             echo=ConfigManager.get_sqlalchemy_echo(),
         )
         self.session_maker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
+        self.statement_composer = StatementComposer()
+
+    @staticmethod
+    def _add_models(session: AsyncSession, model_class, models):
+        for model in models:
+            instance = model_class(**model.model_dump())
+            session.add(instance)
+
 
     @staticmethod
     def session_manager(method):
@@ -66,14 +81,11 @@ async def get_url_metadata_by_status(
 
     @session_manager
     async def add_url_metadata(self, session: AsyncSession, url_metadata_info: URLMetadataInfo):
-        url_metadata = URLMetadata(**url_metadata_info.model_dump())
-        session.add(url_metadata)
+        self._add_models(session, URLMetadata, [url_metadata_info])
 
     @session_manager
     async def add_url_metadatas(self, session: AsyncSession, url_metadata_infos: list[URLMetadataInfo]):
-        for url_metadata_info in url_metadata_infos:
-            url_metadata = URLMetadata(**url_metadata_info.model_dump())
-            session.add(url_metadata)
+        self._add_models(session, URLMetadata, url_metadata_infos)
 
     @session_manager
     async def add_url_error_infos(self, session: AsyncSession, url_error_infos: list[URLErrorPydanticInfo]):
@@ -88,59 +100,58 @@ async def add_url_error_infos(self, session: AsyncSession, url_error_infos: list
 
     @session_manager
     async def get_urls_with_errors(self, session: AsyncSession) -> list[URLErrorPydanticInfo]:
-        statement = (select(URL, URLErrorInfo.error, URLErrorInfo.updated_at)
+        statement = (select(URL, URLErrorInfo.error, URLErrorInfo.updated_at, URLErrorInfo.task_id)
                      .join(URLErrorInfo)
                      .where(URL.outcome == URLStatus.ERROR.value)
                      .order_by(URL.id))
         scalar_result = await session.execute(statement)
         results = scalar_result.all()
         final_results = []
-        for url, error, updated_at in results:
-            final_results.append(URLErrorPydanticInfo(url_id=url.id, error=error, updated_at=updated_at))
+        for url, error, updated_at, task_id in results:
+            final_results.append(URLErrorPydanticInfo(
+                url_id=url.id,
+                error=error,
+                updated_at=updated_at,
+                task_id=task_id
+            ))
         return final_results
 
     @session_manager
     async def add_html_content_infos(self, session: AsyncSession, html_content_infos: list[URLHTMLContentInfo]):
-        for html_content_info in html_content_infos:
-            # Add HTML Content Info to database
-            db_html_content_info = URLHTMLContent(**html_content_info.model_dump())
-            session.add(db_html_content_info)
+        self._add_models(session, URLHTMLContent, html_content_infos)
+
+    @session_manager
+    async def has_pending_urls_without_html_data(self, session: AsyncSession) -> bool:
+        statement = self.statement_composer.pending_urls_without_html_data()
+        statement = statement.limit(1)
+        scalar_result = await session.scalars(statement)
+        return bool(scalar_result.first())
 
     @session_manager
     async def get_pending_urls_without_html_data(self, session: AsyncSession):
         # TODO: Add test that includes some urls WITH html data. Check they're not returned
-        statement = (select(URL).
-                     outerjoin(URLHTMLContent).
-                     where(URLHTMLContent.id == None).
-                     where(URL.outcome == URLStatus.PENDING.value).
-                     limit(100).
-                     order_by(URL.id))
+        statement = self.statement_composer.pending_urls_without_html_data()
+        statement = statement.limit(100).order_by(URL.id)
         scalar_result = await session.scalars(statement)
         return scalar_result.all()
 
     @session_manager
-    async def get_urls_with_html_data_and_no_relevancy_metadata(
+    async def get_urls_with_html_data_and_without_metadata_type(
             self,
-            session: AsyncSession
+            session: AsyncSession,
+            without_metadata_type: URLMetadataAttributeType = URLMetadataAttributeType.RELEVANT
     ) -> list[URLWithHTML]:
+        # Get URLs with no metadata of the given type
         statement = (select(URL.id, URL.url, URLHTMLContent).
                      join(URLHTMLContent).
-                     where(URL.outcome == URLStatus.PENDING.value)
-                     # No relevancy metadata
-                     .where(
-                        ~exists(
-                            select(URLMetadata.id).
-                            where(
-                                URLMetadata.url_id == URL.id,
-                                URLMetadata.attribute == URLMetadataAttributeType.RELEVANT.value
-                            )
-                        )
-                     )
-                     .limit(100)
-                     .order_by(URL.id)
+                     where(URL.outcome == URLStatus.PENDING.value))
+        statement = self.statement_composer.exclude_urls_with_select_metadata(
+            statement=statement,
+            attribute=without_metadata_type
         )
+        statement = statement.limit(100).order_by(URL.id)
         raw_result = await session.execute(statement)
         result = raw_result.all()
         url_ids_to_urls = {url_id: url for url_id, url, _ in result}
@@ -163,6 +174,26 @@ async def get_urls_with_html_data_and_no_relevancy_metadata(
 
         return final_results
 
+    @session_manager
+    async def has_pending_urls_with_html_data_and_without_metadata_type(
+            self,
+            session: AsyncSession,
+            without_metadata_type: URLMetadataAttributeType = URLMetadataAttributeType.RELEVANT
+    ) -> bool:
+        # Check for pending URLs with HTML data
+        # and no metadata of the given type
+        statement = (select(URL.id, URL.url, URLHTMLContent).
+                     join(URLHTMLContent).
+                     where(URL.outcome == URLStatus.PENDING.value))
+        statement = self.statement_composer.exclude_urls_with_select_metadata(
+            statement=statement,
+            attribute=without_metadata_type
+        )
+        statement = statement.limit(1)
+        raw_result = await session.execute(statement)
+        result = raw_result.all()
+        return len(result) > 0
+
     @session_manager
     async def get_urls_with_metadata(
         self,
@@ -377,5 +408,156 @@ async def get_urls(self, session: AsyncSession, page: int, errors: bool) -> GetURLsResponseInfo:
             count=len(final_results)
         )
 
+    @session_manager
+    async def initiate_task(
+            self,
+            session: AsyncSession,
+            task_type: TaskType
+    ) -> int:
+        # Create Task
+        task = Task(
+            task_type=task_type,
+            task_status=BatchStatus.IN_PROCESS.value
+        )
+        session.add(task)
+        # Return Task ID
+        await session.flush()
+        await session.refresh(task)
+        return task.id
+
+    @session_manager
+    async def update_task_status(self, session: AsyncSession, task_id: int, status: BatchStatus):
+        task = await session.get(Task, task_id)
+        task.task_status = status.value
+        await session.commit()
+
+    @session_manager
+    async def add_task_error(self, session: AsyncSession, task_id: int, error: str):
+        task_error = TaskError(
+            task_id=task_id,
+            error=error
+        )
+        session.add(task_error)
+        await session.commit()
+
+    @session_manager
+    async def get_task_info(self, session: AsyncSession, task_id: int) -> TaskInfo:
+        # Get Task, eagerly loading its URLs, error, and errored URLs
+        result = await session.execute(
+            select(Task)
+            .where(Task.id == task_id)
+            .options(
+                selectinload(Task.urls),
+                selectinload(Task.error),
+                selectinload(Task.errored_urls)
+            )
+        )
+        task = result.scalars().first()
+        # Get error info, if any
+        error = task.error[0].error if len(task.error) > 0 else None
+        # Get URLs
+        urls = task.urls
+        url_infos = []
+        for url in urls:
+            url_info = URLInfo(
+                id=url.id,
+                batch_id=url.batch_id,
+                url=url.url,
+                collector_metadata=url.collector_metadata,
+                outcome=URLStatus(url.outcome),
+                updated_at=url.updated_at
+            )
+            url_infos.append(url_info)
+
+        errored_urls = []
+        for url in task.errored_urls:
+            url_error_info = URLErrorPydanticInfo(
+                task_id=url.task_id,
+                url_id=url.url_id,
+                error=url.error,
+                updated_at=url.updated_at
+            )
+            errored_urls.append(url_error_info)
+        return TaskInfo(
+            task_type=TaskType(task.task_type),
+            task_status=BatchStatus(task.task_status),
+            error_info=error,
+            updated_at=task.updated_at,
+            urls=url_infos,
+            url_errors=errored_urls
+        )
+
+    @session_manager
+    async def get_html_content_info(self, session: AsyncSession, url_id: int) -> list[URLHTMLContentInfo]:
+        session_result = await session.execute(
+            select(URLHTMLContent)
+            .where(URLHTMLContent.url_id == url_id)
+        )
+        results = session_result.scalars().all()
+        return [URLHTMLContentInfo(**result.__dict__) for result in results]
+
+    @session_manager
+    async def link_urls_to_task(self, session: AsyncSession, task_id: int, url_ids: list[int]):
+        for url_id in url_ids:
+            link = LinkTaskURL(
+                url_id=url_id,
+                task_id=task_id
+            )
+            session.add(link)
+
+    @session_manager
+    async def get_tasks(
+            self,
+            session: AsyncSession,
+            task_type: Optional[TaskType] = None,
+            task_status: Optional[BatchStatus] = None,
+            page: int = 1
+    ) -> GetTasksResponse:
+        url_count_subquery = self.statement_composer.simple_count_subquery(
+            LinkTaskURL,
+            'task_id',
+            'url_count'
+        )
+
+        url_error_count_subquery = self.statement_composer.simple_count_subquery(
+            URLErrorInfo,
+            'task_id',
+            'url_error_count'
+        )
+
+        statement = select(
+            Task,
+            url_count_subquery.c.url_count,
+            url_error_count_subquery.c.url_error_count
+        ).outerjoin(
+            url_count_subquery,
+            Task.id == url_count_subquery.c.task_id
+        ).outerjoin(
+            url_error_count_subquery,
+            Task.id == url_error_count_subquery.c.task_id
+        )
+        if task_type is not None:
+            statement = statement.where(Task.task_type == task_type.value)
+        if task_status is not None:
+            statement = statement.where(Task.task_status == task_status.value)
+        statement = add_standard_limit_and_offset(statement, page)
+
+        execute_result = await session.execute(statement)
+        all_results = execute_result.all()
+        final_results = []
+        for task, url_count, url_error_count in all_results:
+            final_results.append(
+                GetTasksResponseTaskInfo(
+                    task_id=task.id,
+                    type=TaskType(task.task_type),
+                    status=BatchStatus(task.task_status),
+                    url_count=url_count if url_count is not None else 0,
+                    url_error_count=url_error_count if url_error_count is not None else 0,
+                    updated_at=task.updated_at
+                )
+            )
+        return GetTasksResponse(
+            tasks=final_results
+        )
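
One point worth internalizing from `get_tasks`: SQLAlchemy `Select` objects are immutable and generative, so `add_standard_limit_and_offset` must hand back a new statement that the caller reassigns (calling it purely for side effects does nothing, which is why the reassignment above matters). A minimal sketch of the pattern, assuming the helper's elided body applies `.limit()` and `.offset()`:

```python
from sqlalchemy import Select, select

from collector_db.models import Task


def add_standard_limit_and_offset(statement: Select, page: int, limit: int = 100) -> Select:
    offset = (page - 1) * limit
    # Returns a NEW Select; the statement passed in is untouched
    return statement.limit(limit).offset(offset)


statement = select(Task)
statement = add_standard_limit_and_offset(statement, page=2)  # rows 101-200
```
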
diff --git a/collector_db/DTOs/RelevanceLabelStudioInputCycleInfo.py b/collector_db/DTOs/RelevanceLabelStudioInputCycleInfo.py
deleted file mode 100644
index 644e0e27..00000000
--- a/collector_db/DTOs/RelevanceLabelStudioInputCycleInfo.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from pydantic import BaseModel
-
-from collector_db.DTOs.URLHTMLContentInfo import URLHTMLContentInfo
-
-
-class RelevanceLabelStudioInputCycleInfo(BaseModel):
-    url: str
-    metadata_id: int
-    html_content_info: list[URLHTMLContentInfo]
\ No newline at end of file
diff --git a/collector_db/DTOs/TaskInfo.py b/collector_db/DTOs/TaskInfo.py
new file mode 100644
index 00000000..e8d8090d
--- /dev/null
+++ b/collector_db/DTOs/TaskInfo.py
@@ -0,0 +1,18 @@
+import datetime
+from typing import Optional
+
+from pydantic import BaseModel
+
+from collector_db.DTOs.URLErrorInfos import URLErrorPydanticInfo
+from collector_db.DTOs.URLInfo import URLInfo
+from collector_db.enums import TaskType
+from core.enums import BatchStatus
+
+
+class TaskInfo(BaseModel):
+    task_type: TaskType
+    task_status: BatchStatus
+    updated_at: datetime.datetime
+    error_info: Optional[str] = None
+    urls: list[URLInfo]
+    url_errors: list[URLErrorPydanticInfo]
\ No newline at end of file
diff --git a/collector_db/DTOs/URLErrorInfos.py b/collector_db/DTOs/URLErrorInfos.py
index cf73a6dc..46f5b9fa 100644
--- a/collector_db/DTOs/URLErrorInfos.py
+++ b/collector_db/DTOs/URLErrorInfos.py
@@ -5,6 +5,7 @@
 
 class URLErrorPydanticInfo(BaseModel):
+    task_id: int
     url_id: int
     error: str
     updated_at: Optional[datetime.datetime] = None
\ No newline at end of file
diff --git a/collector_db/DTOs/URLHTMLContentInfo.py b/collector_db/DTOs/URLHTMLContentInfo.py
index ffd82724..f8b24eb0 100644
--- a/collector_db/DTOs/URLHTMLContentInfo.py
+++ b/collector_db/DTOs/URLHTMLContentInfo.py
@@ -18,4 +18,4 @@ class HTMLContentType(Enum):
 class URLHTMLContentInfo(BaseModel):
     url_id: Optional[int] = None
     content_type: HTMLContentType
-    content: str
\ No newline at end of file
+    content: str | list[str]
\ No newline at end of file
diff --git a/collector_db/DTOs/URLMetadataInfo.py b/collector_db/DTOs/URLMetadataInfo.py
index 9cbc7dca..461d16e9 100644
--- a/collector_db/DTOs/URLMetadataInfo.py
+++ b/collector_db/DTOs/URLMetadataInfo.py
@@ -12,6 +12,7 @@ class URLMetadataInfo(BaseModel):
     attribute: Optional[URLMetadataAttributeType] = None
     # TODO: May need to add validation here depending on the type of attribute
     value: Optional[str] = None
+    notes: Optional[str] = None
     validation_status: Optional[ValidationStatus] = None
     validation_source: Optional[ValidationSource] = None
     created_at: Optional[datetime] = None
diff --git a/collector_db/StatementComposer.py b/collector_db/StatementComposer.py
new file mode 100644
index 00000000..dc756fb3
--- /dev/null
+++ b/collector_db/StatementComposer.py
@@ -0,0 +1,43 @@
+
+from sqlalchemy import Select, select, exists, func, Subquery
+
+from collector_db.enums import URLMetadataAttributeType
+from collector_db.models import URL, URLHTMLContent, URLMetadata
+from collector_manager.enums import URLStatus
+
+
+class StatementComposer:
+    """
+    Assists in the composition of SQLAlchemy statements
+    """
+
+    @staticmethod
+    def pending_urls_without_html_data() -> Select:
+        return (select(URL).
+                outerjoin(URLHTMLContent).
+                where(URLHTMLContent.id == None).
+                where(URL.outcome == URLStatus.PENDING.value))
+
+    @staticmethod
+    def exclude_urls_with_select_metadata(
+            statement: Select,
+            attribute: URLMetadataAttributeType
+    ) -> Select:
+        return (statement.where(
+            ~exists(
+                select(URLMetadata.id).
+                where(
+                    URLMetadata.url_id == URL.id,
+                    URLMetadata.attribute == attribute.value
+                )
+            )
+        ))
+
+    @staticmethod
+    def simple_count_subquery(model, attribute: str, label: str) -> Subquery:
+        attr_value = getattr(model, attribute)
+        return select(
+            attr_value,
+            func.count(attr_value).label(label)
+        ).group_by(attr_value).subquery()
+
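
For orientation, the composer methods chain together the way `AsyncDatabaseClient` uses them; an illustrative combination (sketch):

```python
from collector_db.StatementComposer import StatementComposer
from collector_db.enums import URLMetadataAttributeType
from collector_db.models import URL

composer = StatementComposer()

# Pending URLs that still lack HTML content...
statement = composer.pending_urls_without_html_data()
# ...excluding any URL that already has a RECORD_TYPE metadata row,
# capped and ordered for batch processing.
statement = composer.exclude_urls_with_select_metadata(
    statement=statement,
    attribute=URLMetadataAttributeType.RECORD_TYPE,
)
statement = statement.limit(100).order_by(URL.id)
```
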
diff --git a/collector_db/alembic/versions/072b32a45b1c_add_task_tables_and_linking_logic.py b/collector_db/alembic/versions/072b32a45b1c_add_task_tables_and_linking_logic.py
new file mode 100644
index 00000000..b2174484
--- /dev/null
+++ b/collector_db/alembic/versions/072b32a45b1c_add_task_tables_and_linking_logic.py
@@ -0,0 +1,81 @@
+"""Add Task Tables and linking logic
+
+Revision ID: 072b32a45b1c
+Revises: dae00e5aa8dd
+Create Date: 2025-01-27 15:48:02.713484
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+from collector_db.enums import PGEnum
+
+# revision identifiers, used by Alembic.
+revision: str = '072b32a45b1c'
+down_revision: Union[str, None] = 'dae00e5aa8dd'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+task_type = PGEnum(
+    'HTML',
+    'Relevancy',
+    'Record Type',
+    name='task_type',
+)
+
+
+def upgrade() -> None:
+    op.create_table('tasks',
+        sa.Column('id', sa.Integer(), nullable=False),
+        sa.Column('task_type', task_type, nullable=False),
+        sa.Column(
+            'task_status',
+            PGEnum(
+                'complete', 'error', 'in-process', 'aborted',
+                name='batch_status',
+                create_type=False
+            ),
+            nullable=False
+        ),
+        sa.Column('updated_at', sa.TIMESTAMP(), server_default=sa.text('now()'), nullable=False),
+        sa.PrimaryKeyConstraint('id'),
+    )
+    op.create_table('task_errors',
+        sa.Column('id', sa.Integer(), nullable=False),
+        sa.Column('task_id', sa.Integer(), nullable=False),
+        sa.Column('error', sa.Text(), nullable=False),
+        sa.Column('updated_at', sa.TIMESTAMP(), server_default=sa.text('now()'), nullable=False),
+        sa.ForeignKeyConstraint(['task_id'], ['tasks.id'], ondelete='CASCADE'),
+        sa.PrimaryKeyConstraint('id'),
+        sa.UniqueConstraint('task_id', 'error', name='uq_task_id_error')
+    )
+    op.create_table('link_task_urls',
+        sa.Column('task_id', sa.Integer(), nullable=False),
+        sa.Column('url_id', sa.Integer(), nullable=False),
+        sa.ForeignKeyConstraint(['task_id'], ['tasks.id'], ondelete='CASCADE'),
+        sa.ForeignKeyConstraint(['url_id'], ['urls.id'], ondelete='CASCADE'),
+        sa.PrimaryKeyConstraint('task_id', 'url_id'),
+        sa.UniqueConstraint('task_id', 'url_id', name='uq_task_id_url_id')
+    )
+    # Change to URL Error Info requires deleting prior data
+    op.execute("DELETE FROM url_error_info;")
+
+    op.add_column('url_error_info', sa.Column('task_id', sa.Integer(), nullable=False))
+    op.add_column("url_metadata", sa.Column('notes', sa.Text(), nullable=True))
+    op.create_unique_constraint('uq_url_id_error', 'url_error_info', ['url_id', 'task_id'])
+    op.create_foreign_key("url_error_info_task_id_fkey", 'url_error_info', 'tasks', ['task_id'], ['id'])
+
+
+def downgrade() -> None:
+
+    op.drop_constraint("url_error_info_task_id_fkey", 'url_error_info', type_='foreignkey')
+    op.drop_constraint('uq_url_id_error', 'url_error_info', type_='unique')
+    op.drop_column('url_error_info', 'task_id')
+    op.drop_column('url_metadata', 'notes')
+    op.drop_table('link_task_urls')
+    op.drop_table('task_errors')
+    op.drop_table('tasks')
+
+    task_type.drop(op.get_bind(), checkfirst=True)
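
The `DELETE FROM url_error_info` above is forced by adding `task_id` as non-nullable to a table that may already hold rows. Where existing rows had to survive, the usual Alembic pattern is add-nullable, backfill, then tighten; a hypothetical sketch (the placeholder backfill value is an assumption):

```python
import sqlalchemy as sa
from alembic import op

# Data-preserving variant (NOT what this migration does):
op.add_column('url_error_info', sa.Column('task_id', sa.Integer(), nullable=True))
op.execute("UPDATE url_error_info SET task_id = 1")  # backfill against a placeholder task
op.alter_column('url_error_info', 'task_id', nullable=False)
```
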
diff --git a/collector_db/enums.py b/collector_db/enums.py
index fa66aac4..a6f3c95e 100644
--- a/collector_db/enums.py
+++ b/collector_db/enums.py
@@ -32,6 +32,10 @@ class URLHTMLContentType(PyEnum):
     H6 = "H6"
     DIV = "Div"
 
+class TaskType(PyEnum):
+    HTML = "HTML"
+    RELEVANCY = "Relevancy"
+    RECORD_TYPE = "Record Type"
 
 class PGEnum(TypeDecorator):
     impl = postgresql.ENUM
diff --git a/collector_db/models.py b/collector_db/models.py
index 273d956f..aa33d41e 100644
--- a/collector_db/models.py
+++ b/collector_db/models.py
@@ -16,6 +16,7 @@
 
 CURRENT_TIME_SERVER_DEFAULT = func.now()
 
+batch_status_enum = PGEnum('complete', 'error', 'in-process', 'aborted', name='batch_status')
 
 class Batch(Base):
     __tablename__ = 'batches'
@@ -29,9 +30,7 @@ class Batch(Base):
     user_id = Column(Integer, nullable=False)
     # Gives the status of the batch
     status = Column(
-        postgresql.ENUM(
-            'complete', 'error', 'in-process', 'aborted',
-            name='batch_status'),
+        batch_status_enum,
         nullable=False
     )
     # The number of URLs in the batch
@@ -86,6 +85,11 @@ class URL(Base):
     url_metadata = relationship("URLMetadata", back_populates="url", cascade="all, delete-orphan")
     html_content = relationship("URLHTMLContent", back_populates="url", cascade="all, delete-orphan")
     error_info = relationship("URLErrorInfo", back_populates="url", cascade="all, delete-orphan")
+    tasks = relationship(
+        "Task",
+        secondary="link_task_urls",
+        back_populates="urls",
+    )
 
 
 # URL Metadata table definition
@@ -110,6 +114,8 @@ class URLMetadata(Base):
         PGEnum('Machine Learning', 'Label Studio', 'Manual', name='validation_source'),
         nullable=False
     )
+    notes = Column(Text, nullable=True)
+
     # Timestamps
     created_at = Column(TIMESTAMP, nullable=False, server_default=func.now())
 
@@ -153,14 +159,21 @@ class RootURL(Base):
 
 class URLErrorInfo(Base):
     __tablename__ = 'url_error_info'
+    __table_args__ = (UniqueConstraint(
+        "url_id",
+        "task_id",
+        name="uq_url_id_error"),
+    )
 
     id = Column(Integer, primary_key=True)
     url_id = Column(Integer, ForeignKey('urls.id'), nullable=False)
     error = Column(Text, nullable=False)
     updated_at = Column(TIMESTAMP, nullable=False, server_default=CURRENT_TIME_SERVER_DEFAULT)
+    task_id = Column(Integer, ForeignKey('tasks.id'), nullable=False)
 
     # Relationships
     url = relationship("URL", back_populates="error_info")
+    task = relationship("Task", back_populates="errored_urls")
 
 
 class URLHTMLContent(Base):
     __tablename__ = 'url_html_content'
@@ -231,3 +244,51 @@
 
     # Relationships
     batch = relationship("Batch", back_populates="missings")
+
+class Task(Base):
+    __tablename__ = 'tasks'
+
+    id = Column(Integer, primary_key=True)
+    task_type = Column(
+        PGEnum(
+            'HTML', 'Relevancy', 'Record Type', name='task_type'
+        ), nullable=False)
+    task_status = Column(batch_status_enum, nullable=False)
+    updated_at = Column(TIMESTAMP, nullable=False, server_default=CURRENT_TIME_SERVER_DEFAULT)
+
+    # Relationships
+    urls = relationship(
+        "URL",
+        secondary="link_task_urls",
+        back_populates="tasks"
+    )
+    error = relationship("TaskError", back_populates="task")
+    errored_urls = relationship("URLErrorInfo", back_populates="task")
+
+class LinkTaskURL(Base):
+    __tablename__ = 'link_task_urls'
+    __table_args__ = (UniqueConstraint(
+        "task_id",
+        "url_id",
+        name="uq_task_id_url_id"),
+    )
+
+    task_id = Column(Integer, ForeignKey('tasks.id', ondelete="CASCADE"), primary_key=True)
+    url_id = Column(Integer, ForeignKey('urls.id', ondelete="CASCADE"), primary_key=True)
+
+class TaskError(Base):
+    __tablename__ = 'task_errors'
+    __table_args__ = (UniqueConstraint(
+        "task_id",
+        "error",
+        name="uq_task_id_error"),
+    )
+
+    id = Column(Integer, primary_key=True)
+    task_id = Column(Integer, ForeignKey('tasks.id', ondelete="CASCADE"), nullable=False)
+    error = Column(Text, nullable=False)
+    updated_at = Column(TIMESTAMP, nullable=False, server_default=CURRENT_TIME_SERVER_DEFAULT)
+
+    # Relationships
+    task = relationship("Task", back_populates="error")
\ No newline at end of file
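
The `secondary="link_task_urls"` wiring gives a bidirectional many-to-many, so `task.urls` and `url.tasks` are both available. `get_task_info` leans on eager loading to avoid lazy-load surprises under asyncio; the query pattern, isolated as a sketch (assumes an open `AsyncSession`):

```python
from sqlalchemy import select
from sqlalchemy.orm import selectinload

from collector_db.models import Task


async def load_task(session, task_id: int):
    # selectinload issues one follow-up SELECT per relationship,
    # which async sessions require in place of implicit lazy loads
    result = await session.execute(
        select(Task)
        .where(Task.id == task_id)
        .options(
            selectinload(Task.urls),          # via link_task_urls
            selectinload(Task.error),         # TaskError rows
            selectinload(Task.errored_urls),  # URLErrorInfo rows
        )
    )
    return result.scalars().first()
```
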
diff --git a/core/AsyncCore.py b/core/AsyncCore.py
index 67f134b1..afa5c7ab 100644
--- a/core/AsyncCore.py
+++ b/core/AsyncCore.py
@@ -1,17 +1,23 @@
 import logging
 
 from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
+from collector_db.DTOs.TaskInfo import TaskInfo
 from collector_db.DTOs.URLAnnotationInfo import URLAnnotationInfo
+from collector_db.enums import TaskType
 from core.DTOs.GetNextURLForRelevanceAnnotationResponse import GetNextURLForRelevanceAnnotationResponse
+from core.DTOs.GetTasksResponse import GetTasksResponse
 from core.DTOs.GetURLsResponseInfo import GetURLsResponseInfo
 from core.DTOs.RelevanceAnnotationInfo import RelevanceAnnotationPostInfo
 from core.DTOs.RelevanceAnnotationRequestInfo import RelevanceAnnotationRequestInfo
-from core.classes.URLHTMLCycler import URLHTMLCycler
-from core.classes.URLRelevanceHuggingfaceCycler import URLRelevanceHuggingfaceCycler
+from core.classes.URLHTMLTaskOperator import URLHTMLTaskOperator
+from core.classes.URLRecordTypeTaskOperator import URLRecordTypeTaskOperator
+from core.classes.URLRelevanceHuggingfaceTaskOperator import URLRelevanceHuggingfaceTaskOperator
+from core.enums import BatchStatus
 from html_tag_collector.DataClassTags import convert_to_response_html_info
 from html_tag_collector.ResponseParser import HTMLResponseParser
 from html_tag_collector.URLRequestInterface import URLRequestInterface
 from hugging_face.HuggingFaceInterface import HuggingFaceInterface
+from llm_api_logic.OpenAIRecordClassifier import OpenAIRecordClassifier
 
 
 class AsyncCore:
@@ -30,26 +36,35 @@ def __init__(
         self.logger = logging.getLogger(__name__)
         self.logger.setLevel(logging.INFO)
 
-    async def run_url_html_cycle(self):
-        self.logger.info("Running URL HTML Cycle")
-        cycler = URLHTMLCycler(
+    async def run_url_html_task(self):
+        self.logger.info("Running URL HTML Task")
+        operator = URLHTMLTaskOperator(
             adb_client=self.adb_client,
             url_request_interface=self.url_request_interface,
             html_parser=self.html_parser
         )
-        await cycler.cycle()
+        await operator.run_task()
 
-    async def run_url_relevance_huggingface_cycle(self):
-        self.logger.info("Running URL Relevance Huggingface Cycle")
-        cycler = URLRelevanceHuggingfaceCycler(
+    async def run_url_relevance_huggingface_task(self):
+        self.logger.info("Running URL Relevance Huggingface Task")
+        operator = URLRelevanceHuggingfaceTaskOperator(
             adb_client=self.adb_client,
             huggingface_interface=self.huggingface_interface
         )
-        await cycler.cycle()
+        await operator.run_task()
 
-    async def run_cycles(self):
-        await self.run_url_html_cycle()
-        await self.run_url_relevance_huggingface_cycle()
+    async def run_url_record_type_task(self):
+        self.logger.info("Running URL Record Type Task")
+        operator = URLRecordTypeTaskOperator(
+            adb_client=self.adb_client,
+            classifier=OpenAIRecordClassifier()
+        )
+        await operator.run_task()
+
+    async def run_tasks(self):
+        await self.run_url_html_task()
+        await self.run_url_relevance_huggingface_task()
+        await self.run_url_record_type_task()
 
     async def convert_to_relevance_annotation_request_info(self, url_info: URLAnnotationInfo) -> RelevanceAnnotationRequestInfo:
         response_html_info = convert_to_response_html_info(
@@ -87,3 +102,9 @@ async def submit_url_relevance_annotation(
 
     async def get_urls(self, page: int, errors: bool) -> GetURLsResponseInfo:
         return await self.adb_client.get_urls(page=page, errors=errors)
+
+    async def get_task_info(self, task_id: int) -> TaskInfo:
+        return await self.adb_client.get_task_info(task_id=task_id)
+
+    async def get_tasks(self, page: int, task_type: TaskType | None, task_status: BatchStatus | None) -> GetTasksResponse:
+        return await self.adb_client.get_tasks(page=page, task_type=task_type, task_status=task_status)
diff --git a/core/DTOs/GetTasksResponse.py b/core/DTOs/GetTasksResponse.py
new file mode 100644
index 00000000..42b3d954
--- /dev/null
+++ b/core/DTOs/GetTasksResponse.py
@@ -0,0 +1,19 @@
+import datetime
+
+from pydantic import BaseModel
+
+from collector_db.enums import TaskType
+from core.enums import BatchStatus
+
+
+class GetTasksResponseTaskInfo(BaseModel):
+    task_id: int
+    type: TaskType
+    status: BatchStatus
+    url_count: int
+    url_error_count: int
+    updated_at: datetime.datetime
+
+
+class GetTasksResponse(BaseModel):
+    tasks: list[GetTasksResponseTaskInfo]
diff --git a/core/DTOs/LabelStudioExportResponseInfo.py b/core/DTOs/LabelStudioExportResponseInfo.py
deleted file mode 100644
index fae94096..00000000
--- a/core/DTOs/LabelStudioExportResponseInfo.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from typing import Annotated
-
-from fastapi.param_functions import Doc
-from pydantic import BaseModel
-
-
-class LabelStudioExportResponseInfo(BaseModel):
-    label_studio_import_id: Annotated[int, Doc("The ID of the Label Studio import")]
-    num_urls_imported: Annotated[int, Doc("The number of URLs imported")]
\ No newline at end of file
diff --git a/core/DTOs/LabelStudioTaskInfo.py b/core/DTOs/LabelStudioTaskInfo.py
deleted file mode 100644
index 5c277c8a..00000000
--- a/core/DTOs/LabelStudioTaskInfo.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from pydantic import BaseModel
-
-from collector_db.enums import URLMetadataAttributeType
-from core.enums import LabelStudioTaskStatus
-
-
-class LabelStudioTaskInfo(BaseModel):
-    metadata_id: int
-    attribute: URLMetadataAttributeType
-    task_id: int
-    task_status: LabelStudioTaskStatus
\ No newline at end of file
diff --git a/core/DTOs/task_data_objects/README.md b/core/DTOs/task_data_objects/README.md
new file mode 100644
index 00000000..3d2fc5ae
--- /dev/null
+++ b/core/DTOs/task_data_objects/README.md
@@ -0,0 +1 @@
+Task Data Objects (or TDOs) are data transfer objects (DTOs) used within a given task operation. Each Task type has one type of TDO.
\ No newline at end of file
diff --git a/core/DTOs/task_data_objects/URLRecordTypeTDO.py b/core/DTOs/task_data_objects/URLRecordTypeTDO.py
new file mode 100644
index 00000000..34bbc233
--- /dev/null
+++ b/core/DTOs/task_data_objects/URLRecordTypeTDO.py
@@ -0,0 +1,15 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+from collector_db.DTOs.URLWithHTML import URLWithHTML
+from core.enums import RecordType
+
+
+class URLRecordTypeTDO(BaseModel):
+    url_with_html: URLWithHTML
+    record_type: Optional[RecordType] = None
+    error: Optional[str] = None
+
+    def is_errored(self):
+        return self.error is not None
\ No newline at end of file
diff --git a/core/DTOs/URLRelevanceHuggingfaceCycleInfo.py b/core/DTOs/task_data_objects/URLRelevanceHuggingfaceTDO.py
similarity index 78%
rename from core/DTOs/URLRelevanceHuggingfaceCycleInfo.py
rename to core/DTOs/task_data_objects/URLRelevanceHuggingfaceTDO.py
index 19318e6a..33311a9b 100644
--- a/core/DTOs/URLRelevanceHuggingfaceCycleInfo.py
+++ b/core/DTOs/task_data_objects/URLRelevanceHuggingfaceTDO.py
@@ -5,6 +5,6 @@
 from collector_db.DTOs.URLWithHTML import URLWithHTML
 
 
-class URLRelevanceHuggingfaceCycleInfo(BaseModel):
+class URLRelevanceHuggingfaceTDO(BaseModel):
     url_with_html: URLWithHTML
     relevant: Optional[bool] = None
diff --git a/core/DTOs/URLHTMLCycleInfo.py b/core/DTOs/task_data_objects/UrlHtmlTDO.py
similarity index 94%
rename from core/DTOs/URLHTMLCycleInfo.py
rename to core/DTOs/task_data_objects/UrlHtmlTDO.py
index 1d739375..05e9caf2 100644
--- a/core/DTOs/URLHTMLCycleInfo.py
+++ b/core/DTOs/task_data_objects/UrlHtmlTDO.py
@@ -7,7 +7,7 @@
 
 
 @dataclass
-class URLHTMLCycleInfo:
+class UrlHtmlTDO:
     url_info: URLInfo
     url_response_info: Optional[URLResponseInfo] = None
     html_tag_info: Optional[ResponseHTMLInfo] = None
diff --git a/tests/test_automated/integration/cycles/__init__.py b/core/DTOs/task_data_objects/__init__.py
similarity index 100%
rename from tests/test_automated/integration/cycles/__init__.py
rename to core/DTOs/task_data_objects/__init__.py
diff --git a/core/README.md b/core/README.md
index c9095c41..25b1cde3 100644
--- a/core/README.md
+++ b/core/README.md
@@ -2,4 +2,13 @@ The Source Collector Core is a directory which integrates:
 1. The Collector Manager
 2. The Source Collector Database
 3. The API (to be developed)
-4. The PDAP API Client (to be developed)
\ No newline at end of file
+4. The PDAP API Client (to be developed)
+
+# Nomenclature
+
+- **Collector**: A submodule for collecting URLs. Different collectors utilize different sources and methods for gathering URLs.
+- **Batch**: URLs are collected in Collector Batches, with different collectors producing different Batches.
+- **Cycle**: The overall lifecycle of each URL -- from initial retrieval in a Batch to either disposal or incorporation into the Data Sources App Database.
+- **Task**: A semi-independent operation performed on a set of URLs. Tasks include collection, retrieving HTML data, getting metadata via machine learning, and so on.
+- **Task Set**: A group of URLs operated on together as part of a single task. The URLs in a set are not necessarily all from the same batch, and should only be operated on once per task.
+- **Task Operator**: A class which performs a single task on a set of URLs.
\ No newline at end of file
diff --git a/core/ScheduledTaskManager.py b/core/ScheduledTaskManager.py
index 590690d1..5b2ff0a7 100644
--- a/core/ScheduledTaskManager.py
+++ b/core/ScheduledTaskManager.py
@@ -52,11 +52,12 @@
 
     def add_scheduled_tasks(self):
         self.run_cycles_job = self.scheduler.add_job(
-            self.async_core.run_cycles,
+            self.async_core.run_tasks,
             trigger=IntervalTrigger(
                 hours=1,
                 start_date=datetime.now() + timedelta(minutes=1)
-            )
+            ),
+            misfire_grace_time=60
         )
 
     def shutdown(self):
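
The `misfire_grace_time` addition deserves a note: if the event loop was busy when the interval fired (say, a prior run overran), the job still executes as long as it is no more than 60 seconds late; beyond that, the run is skipped instead of piling up. The same pattern in isolation (hypothetical job function):

```python
from datetime import datetime, timedelta

from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.interval import IntervalTrigger


async def run_tasks() -> None:  # stand-in for AsyncCore.run_tasks
    ...


scheduler = AsyncIOScheduler()
scheduler.add_job(
    run_tasks,
    trigger=IntervalTrigger(hours=1, start_date=datetime.now() + timedelta(minutes=1)),
    misfire_grace_time=60,  # run if <= 60s late; skip the occurrence otherwise
)
scheduler.start()
```
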
Skipping...") + return + self.task_id = await self.initiate_task_in_db() + + try: + await self.inner_task_logic() + await self.conclude_task_in_db() + except Exception as e: + await self.handle_task_error(e) + + @abstractmethod + async def inner_task_logic(self): + raise NotImplementedError + + async def handle_task_error(self, e): + await self.adb_client.update_task_status(task_id=self.task_id, status=BatchStatus.ERROR) + await self.adb_client.add_task_error( + task_id=self.task_id, + error=str(e) + ) diff --git a/core/classes/URLHTMLCycler.py b/core/classes/URLHTMLCycler.py deleted file mode 100644 index 73344a9c..00000000 --- a/core/classes/URLHTMLCycler.py +++ /dev/null @@ -1,95 +0,0 @@ -from collector_db.AsyncDatabaseClient import AsyncDatabaseClient -from collector_db.DTOs.URLErrorInfos import URLErrorPydanticInfo -from collector_db.DTOs.URLInfo import URLInfo -from core.DTOs.URLHTMLCycleInfo import URLHTMLCycleInfo -from core.classes.HTMLContentInfoGetter import HTMLContentInfoGetter -from html_tag_collector.ResponseParser import HTMLResponseParser -from html_tag_collector.URLRequestInterface import URLRequestInterface - - -class URLHTMLCycler: - - def __init__( - self, - url_request_interface: URLRequestInterface, - adb_client: AsyncDatabaseClient, - html_parser: HTMLResponseParser - ): - self.url_request_interface = url_request_interface - self.adb_client = adb_client - self.html_parser = html_parser - - async def cycle(self): - print("Running URL HTML Cycle...") - cycle_infos = await self.get_pending_urls_without_html_data() - await self.get_raw_html_data_for_urls(cycle_infos) - success_cycles, error_cycles = await self.separate_success_and_error_cycles(cycle_infos) - await self.update_errors_in_database(error_cycles) - await self.process_html_data(success_cycles) - await self.update_html_data_in_database(success_cycles) - - - async def get_just_urls(self, cycle_infos: list[URLHTMLCycleInfo]): - return [cycle_info.url_info.url for cycle_info in cycle_infos] - - async def get_pending_urls_without_html_data(self): - pending_urls: list[URLInfo] = await self.adb_client.get_pending_urls_without_html_data() - cycle_infos = [ - URLHTMLCycleInfo( - url_info=url_info, - ) for url_info in pending_urls - ] - return cycle_infos - - async def get_raw_html_data_for_urls(self, cycle_infos: list[URLHTMLCycleInfo]): - just_urls = await self.get_just_urls(cycle_infos) - url_response_infos = await self.url_request_interface.make_requests(just_urls) - for cycle_info, url_response_info in zip(cycle_infos, url_response_infos): - cycle_info.url_response_info = url_response_info - - async def separate_success_and_error_cycles( - self, - cycle_infos: list[URLHTMLCycleInfo] - ) -> tuple[ - list[URLHTMLCycleInfo], # Successful - list[URLHTMLCycleInfo] # Error - ]: - errored_cycle_infos = [] - successful_cycle_infos = [] - for cycle_info in cycle_infos: - if not cycle_info.url_response_info.success: - errored_cycle_infos.append(cycle_info) - else: - successful_cycle_infos.append(cycle_info) - return successful_cycle_infos, errored_cycle_infos - - async def update_errors_in_database(self, errored_cycle_infos: list[URLHTMLCycleInfo]): - error_infos = [] - for errored_cycle_info in errored_cycle_infos: - error_info = URLErrorPydanticInfo( - url_id=errored_cycle_info.url_info.id, - error=str(errored_cycle_info.url_response_info.exception), - ) - error_infos.append(error_info) - await self.adb_client.add_url_error_infos(error_infos) - - async def process_html_data(self, cycle_infos: list[URLHTMLCycleInfo]): - 
diff --git a/core/classes/URLHTMLCycler.py b/core/classes/URLHTMLCycler.py
deleted file mode 100644
index 73344a9c..00000000
--- a/core/classes/URLHTMLCycler.py
+++ /dev/null
@@ -1,95 +0,0 @@
-from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
-from collector_db.DTOs.URLErrorInfos import URLErrorPydanticInfo
-from collector_db.DTOs.URLInfo import URLInfo
-from core.DTOs.URLHTMLCycleInfo import URLHTMLCycleInfo
-from core.classes.HTMLContentInfoGetter import HTMLContentInfoGetter
-from html_tag_collector.ResponseParser import HTMLResponseParser
-from html_tag_collector.URLRequestInterface import URLRequestInterface
-
-
-class URLHTMLCycler:
-
-    def __init__(
-        self,
-        url_request_interface: URLRequestInterface,
-        adb_client: AsyncDatabaseClient,
-        html_parser: HTMLResponseParser
-    ):
-        self.url_request_interface = url_request_interface
-        self.adb_client = adb_client
-        self.html_parser = html_parser
-
-    async def cycle(self):
-        print("Running URL HTML Cycle...")
-        cycle_infos = await self.get_pending_urls_without_html_data()
-        await self.get_raw_html_data_for_urls(cycle_infos)
-        success_cycles, error_cycles = await self.separate_success_and_error_cycles(cycle_infos)
-        await self.update_errors_in_database(error_cycles)
-        await self.process_html_data(success_cycles)
-        await self.update_html_data_in_database(success_cycles)
-
-
-    async def get_just_urls(self, cycle_infos: list[URLHTMLCycleInfo]):
-        return [cycle_info.url_info.url for cycle_info in cycle_infos]
-
-    async def get_pending_urls_without_html_data(self):
-        pending_urls: list[URLInfo] = await self.adb_client.get_pending_urls_without_html_data()
-        cycle_infos = [
-            URLHTMLCycleInfo(
-                url_info=url_info,
-            ) for url_info in pending_urls
-        ]
-        return cycle_infos
-
-    async def get_raw_html_data_for_urls(self, cycle_infos: list[URLHTMLCycleInfo]):
-        just_urls = await self.get_just_urls(cycle_infos)
-        url_response_infos = await self.url_request_interface.make_requests(just_urls)
-        for cycle_info, url_response_info in zip(cycle_infos, url_response_infos):
-            cycle_info.url_response_info = url_response_info
-
-    async def separate_success_and_error_cycles(
-        self,
-        cycle_infos: list[URLHTMLCycleInfo]
-    ) -> tuple[
-        list[URLHTMLCycleInfo],  # Successful
-        list[URLHTMLCycleInfo]  # Error
-    ]:
-        errored_cycle_infos = []
-        successful_cycle_infos = []
-        for cycle_info in cycle_infos:
-            if not cycle_info.url_response_info.success:
-                errored_cycle_infos.append(cycle_info)
-            else:
-                successful_cycle_infos.append(cycle_info)
-        return successful_cycle_infos, errored_cycle_infos
-
-    async def update_errors_in_database(self, errored_cycle_infos: list[URLHTMLCycleInfo]):
-        error_infos = []
-        for errored_cycle_info in errored_cycle_infos:
-            error_info = URLErrorPydanticInfo(
-                url_id=errored_cycle_info.url_info.id,
-                error=str(errored_cycle_info.url_response_info.exception),
-            )
-            error_infos.append(error_info)
-        await self.adb_client.add_url_error_infos(error_infos)
-
-    async def process_html_data(self, cycle_infos: list[URLHTMLCycleInfo]):
-        for cycle_info in cycle_infos:
-            html_tag_info = await self.html_parser.parse(
-                url=cycle_info.url_info.url,
-                html_content=cycle_info.url_response_info.html,
-                content_type=cycle_info.url_response_info.content_type
-            )
-            cycle_info.html_tag_info = html_tag_info
-
-    async def update_html_data_in_database(self, cycle_infos: list[URLHTMLCycleInfo]):
-        html_content_infos = []
-        for cycle_info in cycle_infos:
-            hcig = HTMLContentInfoGetter(
-                response_html_info=cycle_info.html_tag_info,
-                url_id=cycle_info.url_info.id
-            )
-            results = hcig.get_all_html_content()
-            html_content_infos.extend(results)
-
-        await self.adb_client.add_html_content_infos(html_content_infos)
diff --git a/core/classes/URLHTMLTaskOperator.py b/core/classes/URLHTMLTaskOperator.py
new file mode 100644
index 00000000..63321635
--- /dev/null
+++ b/core/classes/URLHTMLTaskOperator.py
@@ -0,0 +1,107 @@
+from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
+from collector_db.DTOs.URLErrorInfos import URLErrorPydanticInfo
+from collector_db.DTOs.URLInfo import URLInfo
+from collector_db.enums import TaskType
+from core.DTOs.task_data_objects.UrlHtmlTDO import UrlHtmlTDO
+from core.classes.HTMLContentInfoGetter import HTMLContentInfoGetter
+from core.classes.TaskOperatorBase import TaskOperatorBase
+from html_tag_collector.ResponseParser import HTMLResponseParser
+from html_tag_collector.URLRequestInterface import URLRequestInterface
+
+
+class URLHTMLTaskOperator(TaskOperatorBase):
+
+    def __init__(
+        self,
+        url_request_interface: URLRequestInterface,
+        adb_client: AsyncDatabaseClient,
+        html_parser: HTMLResponseParser
+    ):
+        super().__init__(adb_client)
+        self.url_request_interface = url_request_interface
+        self.html_parser = html_parser
+
+    @property
+    def task_type(self):
+        return TaskType.HTML
+
+    async def meets_task_prerequisites(self):
+        return await self.adb_client.has_pending_urls_without_html_data()
+
+    async def inner_task_logic(self):
+        print("Running URL HTML Task...")
+        tdos = await self.get_pending_urls_without_html_data()
+        url_ids = [tdo.url_info.id for tdo in tdos]
+        await self.link_urls_to_task(url_ids=url_ids)
+        await self.get_raw_html_data_for_urls(tdos)
+        success_subset, error_subset = await self.separate_success_and_error_subsets(tdos)
+        await self.update_errors_in_database(error_subset)
+        await self.process_html_data(success_subset)
+        await self.update_html_data_in_database(success_subset)
+
+
+    async def get_just_urls(self, tdos: list[UrlHtmlTDO]):
+        return [tdo.url_info.url for tdo in tdos]
+
+    async def get_pending_urls_without_html_data(self):
+        pending_urls: list[URLInfo] = await self.adb_client.get_pending_urls_without_html_data()
+        tdos = [
+            UrlHtmlTDO(
+                url_info=url_info,
+            ) for url_info in pending_urls
+        ]
+        return tdos
+
+    async def get_raw_html_data_for_urls(self, tdos: list[UrlHtmlTDO]):
+        just_urls = await self.get_just_urls(tdos)
+        url_response_infos = await self.url_request_interface.make_requests(just_urls)
+        for tdo, url_response_info in zip(tdos, url_response_infos):
+            tdo.url_response_info = url_response_info
+
+    async def separate_success_and_error_subsets(
+        self,
+        tdos: list[UrlHtmlTDO]
+    ) -> tuple[
+        list[UrlHtmlTDO],  # Successful
+        list[UrlHtmlTDO]  # Error
+    ]:
+        errored_tdos = []
+        successful_tdos = []
+        for tdo in tdos:
+            if not tdo.url_response_info.success:
+                errored_tdos.append(tdo)
+            else:
+                successful_tdos.append(tdo)
+        return successful_tdos, errored_tdos
+
+    async def update_errors_in_database(self, error_tdos: list[UrlHtmlTDO]):
+        error_infos = []
+        for error_tdo in error_tdos:
+            error_info = URLErrorPydanticInfo(
+                task_id=self.task_id,
+                url_id=error_tdo.url_info.id,
+                error=str(error_tdo.url_response_info.exception),
+            )
+            error_infos.append(error_info)
+        await self.adb_client.add_url_error_infos(error_infos)
+
+    async def process_html_data(self, tdos: list[UrlHtmlTDO]):
+        for tdo in tdos:
+            html_tag_info = await self.html_parser.parse(
+                url=tdo.url_info.url,
+                html_content=tdo.url_response_info.html,
+                content_type=tdo.url_response_info.content_type
+            )
+            tdo.html_tag_info = html_tag_info
+
+    async def update_html_data_in_database(self, tdos: list[UrlHtmlTDO]):
+        html_content_infos = []
+        for tdo in tdos:
+            hcig = HTMLContentInfoGetter(
+                response_html_info=tdo.html_tag_info,
+                url_id=tdo.url_info.id
+            )
+            results = hcig.get_all_html_content()
+            html_content_infos.extend(results)
+
+        await self.adb_client.add_html_content_infos(html_content_infos)
diff --git a/core/classes/URLRecordTypeTaskOperator.py b/core/classes/URLRecordTypeTaskOperator.py
new file mode 100644
index 00000000..6287bcae
--- /dev/null
+++ b/core/classes/URLRecordTypeTaskOperator.py
@@ -0,0 +1,85 @@
+from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
+from collector_db.DTOs.URLErrorInfos import URLErrorPydanticInfo
+from collector_db.DTOs.URLMetadataInfo import URLMetadataInfo
+from collector_db.enums import URLMetadataAttributeType, TaskType, ValidationStatus, ValidationSource
+from core.DTOs.task_data_objects.URLRecordTypeTDO import URLRecordTypeTDO
+from core.classes.TaskOperatorBase import TaskOperatorBase
+from core.enums import RecordType
+from llm_api_logic.LLMRecordClassifierBase import RecordClassifierBase
+
+
+class URLRecordTypeTaskOperator(TaskOperatorBase):
+
+    def __init__(
+        self,
+        adb_client: AsyncDatabaseClient,
+        classifier: RecordClassifierBase
+    ):
+        super().__init__(adb_client)
+        self.classifier = classifier
+
+    @property
+    def task_type(self):
+        return TaskType.RECORD_TYPE
+
+    async def meets_task_prerequisites(self):
+        return await self.adb_client.has_pending_urls_with_html_data_and_without_metadata_type(
+            without_metadata_type=URLMetadataAttributeType.RECORD_TYPE
+        )
+
+    async def get_tdos(self) -> list[URLRecordTypeTDO]:
+        urls_with_html = await self.adb_client.get_urls_with_html_data_and_without_metadata_type(
+            without_metadata_type=URLMetadataAttributeType.RECORD_TYPE
+        )
+        tdos = [URLRecordTypeTDO(url_with_html=url_with_html) for url_with_html in urls_with_html]
+        return tdos
+
+    async def inner_task_logic(self):
+        # Get pending urls from Source Collector
+        # with HTML data and without Record Type Metadata
+        tdos = await self.get_tdos()
+        url_ids = [tdo.url_with_html.url_id for tdo in tdos]
+        await self.link_urls_to_task(url_ids=url_ids)
+
+        await self.get_ml_classifications(tdos)
+        success_subset, error_subset = await self.separate_success_and_error_subsets(tdos)
+        await self.put_results_into_database(success_subset)
+        await self.update_errors_in_database(error_subset)
+
+    async def update_errors_in_database(self, tdos: list[URLRecordTypeTDO]):
+        error_infos = []
+        for tdo in tdos:
+            error_info = URLErrorPydanticInfo(
+                task_id=self.task_id,
+                url_id=tdo.url_with_html.url_id,
+                error=tdo.error
+            )
+            error_infos.append(error_info)
+        await self.adb_client.add_url_error_infos(error_infos)
+
+    async def put_results_into_database(self, tdos: list[URLRecordTypeTDO]):
+        url_metadatas = []
+        for tdo in tdos:
+            url_metadata = URLMetadataInfo(
+                url_id=tdo.url_with_html.url_id,
+                attribute=URLMetadataAttributeType.RECORD_TYPE,
+                value=str(tdo.record_type.value),
+                validation_status=ValidationStatus.PENDING_VALIDATION,
+                validation_source=ValidationSource.MACHINE_LEARNING,
+                notes=self.classifier.model_name
+            )
+            url_metadatas.append(url_metadata)
+        await self.adb_client.add_url_metadatas(url_metadatas)
+
+    async def separate_success_and_error_subsets(self, tdos: list[URLRecordTypeTDO]):
+        success_subset = [tdo for tdo in tdos if not tdo.is_errored()]
+        error_subset = [tdo for tdo in tdos if tdo.is_errored()]
+        return success_subset, error_subset
+
+    async def get_ml_classifications(self, tdos: list[URLRecordTypeTDO]):
+        for tdo in tdos:
+            try:
+                record_type_str = await self.classifier.classify_url(tdo.url_with_html.html_infos)
+                tdo.record_type = RecordType(record_type_str)
+            except Exception as e:
+                tdo.error = str(e)
\ No newline at end of file
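
A rough sketch of exercising the classifier dependency outside the task loop (assumes `OPENAI_API_KEY` is set; the `HTMLContentType` member names used here are assumptions for illustration):

```python
import asyncio

from collector_db.DTOs.URLHTMLContentInfo import HTMLContentType, URLHTMLContentInfo
from llm_api_logic.OpenAIRecordClassifier import OpenAIRecordClassifier


async def main() -> None:
    html_infos = [
        # content_type members are assumed; consult HTMLContentType for the real ones
        URLHTMLContentInfo(content_type=HTMLContentType.TITLE, content="Daily Arrest Log"),
        URLHTMLContentInfo(content_type=HTMLContentType.H1, content=["Arrests for January"]),
    ]
    classifier = OpenAIRecordClassifier()
    record_type_str = await classifier.classify_url(html_infos)
    print(record_type_str)  # e.g. "Arrest Records"


asyncio.run(main())
```
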
diff --git a/core/classes/URLRelevanceHuggingfaceCycler.py b/core/classes/URLRelevanceHuggingfaceCycler.py
deleted file mode 100644
index 8ffdb705..00000000
--- a/core/classes/URLRelevanceHuggingfaceCycler.py
+++ /dev/null
@@ -1,56 +0,0 @@
-from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
-from collector_db.DTOs.URLMetadataInfo import URLMetadataInfo
-from collector_db.DTOs.URLWithHTML import URLWithHTML
-from collector_db.enums import URLMetadataAttributeType, ValidationStatus, ValidationSource
-from core.DTOs.URLRelevanceHuggingfaceCycleInfo import URLRelevanceHuggingfaceCycleInfo
-from hugging_face.HuggingFaceInterface import HuggingFaceInterface
-
-
-class URLRelevanceHuggingfaceCycler:
-
-    def __init__(
-        self,
-        adb_client: AsyncDatabaseClient,
-        huggingface_interface: HuggingFaceInterface
-    ):
-        self.adb_client = adb_client
-        self.huggingface_interface = huggingface_interface
-
-    async def cycle(self):
-        # Get pending urls from Source Collector
-        # with HTML data and without Relevancy Metadata
-        cycle_infos = await self.get_pending_url_info()
-        # Pipe into Huggingface
-        await self.add_huggingface_relevancy(cycle_infos)
-
-        # Put results into Database
-        await self.put_results_into_database(cycle_infos)
-
-    async def put_results_into_database(self, cycle_infos):
-        url_metadatas = []
-        for cycle_info in cycle_infos:
-            url_metadata = URLMetadataInfo(
-                url_id=cycle_info.url_with_html.url_id,
-                attribute=URLMetadataAttributeType.RELEVANT,
-                value=str(cycle_info.relevant),
-                validation_status=ValidationStatus.PENDING_VALIDATION,
-                validation_source=ValidationSource.MACHINE_LEARNING
-            )
-            url_metadatas.append(url_metadata)
-        await self.adb_client.add_url_metadatas(url_metadatas)
-
-    async def add_huggingface_relevancy(self, cycle_infos: list[URLRelevanceHuggingfaceCycleInfo]):
-        urls_with_html = [cycle_info.url_with_html for cycle_info in cycle_infos]
-        results = self.huggingface_interface.get_url_relevancy(urls_with_html)
-        for cycle_info, result in zip(cycle_infos, results):
-            cycle_info.relevant = result
-
-    async def get_pending_url_info(self) -> list[URLRelevanceHuggingfaceCycleInfo]:
-        cycle_infos = []
-        pending_urls: list[URLWithHTML] = await self.adb_client.get_urls_with_html_data_and_no_relevancy_metadata()
-        for url_with_html in pending_urls:
-            cycle_info = URLRelevanceHuggingfaceCycleInfo(
-                url_with_html=url_with_html
-            )
-            cycle_infos.append(cycle_info)
-        return cycle_infos
diff --git a/core/classes/URLRelevanceHuggingfaceTaskOperator.py b/core/classes/URLRelevanceHuggingfaceTaskOperator.py
new file mode 100644
index 00000000..2d54a856
--- /dev/null
+++ b/core/classes/URLRelevanceHuggingfaceTaskOperator.py
@@ -0,0 +1,73 @@
+from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
+from collector_db.DTOs.URLMetadataInfo import URLMetadataInfo
+from collector_db.DTOs.URLWithHTML import URLWithHTML
+from collector_db.enums import URLMetadataAttributeType, ValidationStatus, ValidationSource, TaskType
+from core.DTOs.task_data_objects.URLRelevanceHuggingfaceTDO import URLRelevanceHuggingfaceTDO
+from core.classes.TaskOperatorBase import TaskOperatorBase
+from hugging_face.HuggingFaceInterface import HuggingFaceInterface
+
+
+class URLRelevanceHuggingfaceTaskOperator(TaskOperatorBase):
+
+    def __init__(
+        self,
+        adb_client: AsyncDatabaseClient,
+        huggingface_interface: HuggingFaceInterface
+    ):
+        super().__init__(adb_client)
+        self.huggingface_interface = huggingface_interface
+
+    @property
+    def task_type(self):
+        return TaskType.RELEVANCY
+
+    async def meets_task_prerequisites(self):
+        return await self.adb_client.has_pending_urls_with_html_data_and_without_metadata_type()
+
+    async def inner_task_logic(self):
+        # Get pending urls from Source Collector
+        # with HTML data and without Relevancy Metadata
+        tdos = await self.get_pending_url_info(
+            without_metadata_attribute=URLMetadataAttributeType.RELEVANT
+        )
+        url_ids = [tdo.url_with_html.url_id for tdo in tdos]
+        await self.link_urls_to_task(url_ids=url_ids)
+        # Pipe into Huggingface
+        await self.add_huggingface_relevancy(tdos)
+
+        # Put results into Database
+        await self.put_results_into_database(tdos)
+
+    async def put_results_into_database(self, tdos):
+        url_metadatas = []
+        for tdo in tdos:
+            url_metadata = URLMetadataInfo(
+                url_id=tdo.url_with_html.url_id,
+                attribute=URLMetadataAttributeType.RELEVANT,
+                value=str(tdo.relevant),
+                validation_status=ValidationStatus.PENDING_VALIDATION,
+                validation_source=ValidationSource.MACHINE_LEARNING
+            )
+            url_metadatas.append(url_metadata)
+        await self.adb_client.add_url_metadatas(url_metadatas)
+
+    async def add_huggingface_relevancy(self, tdos: list[URLRelevanceHuggingfaceTDO]):
+        urls_with_html = [tdo.url_with_html for tdo in tdos]
+        results = self.huggingface_interface.get_url_relevancy(urls_with_html)
+        for tdo, result in zip(tdos, results):
+            tdo.relevant = result
+
+    async def get_pending_url_info(
+        self,
+        without_metadata_attribute: URLMetadataAttributeType
+    ) -> list[URLRelevanceHuggingfaceTDO]:
+        tdos = []
+        pending_urls: list[URLWithHTML] = await self.adb_client.get_urls_with_html_data_and_without_metadata_type(
+            without_metadata_type=without_metadata_attribute
+        )
+        for url_with_html in pending_urls:
+            tdo = URLRelevanceHuggingfaceTDO(
+                url_with_html=url_with_html
+            )
+            tdos.append(tdo)
+        return tdos
diff --git a/core/enums.py b/core/enums.py
index 69505406..605e49e5 100644
--- a/core/enums.py
+++ b/core/enums.py
@@ -9,4 +9,42 @@ class BatchStatus(Enum):
 
 class LabelStudioTaskStatus(Enum):
     PENDING = "pending"
-    COMPLETED = "completed"
\ No newline at end of file
+    COMPLETED = "completed"
+
+class RecordType(Enum):
+    ACCIDENT_REPORTS = "Accident Reports"
+    ARREST_RECORDS = "Arrest Records"
+    CALLS_FOR_SERVICE = "Calls for Service"
+    CAR_GPS = "Car GPS"
+    CITATIONS = "Citations"
+    DISPATCH_LOGS = "Dispatch Logs"
+    DISPATCH_RECORDINGS = "Dispatch Recordings"
+    FIELD_CONTACTS = "Field Contacts"
+    INCIDENT_REPORTS = "Incident Reports"
+    MISC_POLICE_ACTIVITY = "Misc Police Activity"
+    OFFICER_INVOLVED_SHOOTINGS = "Officer Involved Shootings"
+    STOPS = "Stops"
+    SURVEYS = "Surveys"
+    USE_OF_FORCE_REPORTS = "Use of Force Reports"
+    VEHICLE_PURSUITS = "Vehicle Pursuits"
+    COMPLAINTS_AND_MISCONDUCT = "Complaints & Misconduct"
+    DAILY_ACTIVITY_LOGS = "Daily Activity Logs"
+    TRAINING_AND_HIRING_INFO = "Training & Hiring Info"
+    PERSONNEL_RECORDS = "Personnel Records"
+    ANNUAL_AND_MONTHLY_REPORTS = "Annual & Monthly Reports"
+    BUDGETS_AND_FINANCES = "Budgets & Finances"
+    CONTACT_INFO_AND_AGENCY_META = "Contact Info & Agency Meta"
+    GEOGRAPHIC = "Geographic"
+    LIST_OF_DATA_SOURCES = "List of Data Sources"
+    POLICIES_AND_CONTRACTS = "Policies & Contracts"
+    CRIME_MAPS_AND_REPORTS = "Crime Maps & Reports"
+    CRIME_STATISTICS = "Crime Statistics"
+    MEDIA_BULLETINS = "Media Bulletins"
+    RECORDS_REQUEST_INFO = "Records Request Info"
+    RESOURCES = "Resources"
+    SEX_OFFENDER_REGISTRY = "Sex Offender Registry"
+    WANTED_PERSONS = "Wanted Persons"
+    BOOKING_REPORTS = "Booking Reports"
+    COURT_CASES = "Court Cases"
+    INCARCERATION_RECORDS = "Incarceration Records"
+    OTHER = "Other"
diff --git a/html_tag_collector/RootURLCache.py b/html_tag_collector/RootURLCache.py
index be670475..e306b6e1 100644
--- a/html_tag_collector/RootURLCache.py
+++ b/html_tag_collector/RootURLCache.py
@@ -26,7 +26,7 @@ async def save_to_cache(self, url: str, title: str):
         self.cache[url] = title
         await self.adb_client.add_to_root_url_cache(url=url, page_title=title)
 
-    async def get_from_cache(self, url: str):
+    async def get_from_cache(self, url: str) -> Optional[str]:
         if self.cache is None:
             self.cache = await self.adb_client.load_root_url_cache()
diff --git a/html_tag_collector/url_cache.json b/html_tag_collector/url_cache.json
deleted file mode 100644
index d4a340e1..00000000
--- a/html_tag_collector/url_cache.json
+++ /dev/null
@@ -1,5 +0,0 @@
-{
-    "http://www.example.com": "Example Domain",
-    "http://www.google.com": "Google",
-    "https://books.toscrape.com": "\n    All products | Books to Scrape - Sandbox\n"
-}
\ No newline at end of file
diff --git a/html_tag_collector/urls.json b/html_tag_collector/urls.json
deleted file mode 100644
index 79574f93..00000000
--- a/html_tag_collector/urls.json
+++ /dev/null
@@ -1,17 +0,0 @@
-[{
-    "id": 1,
-    "url": "https://pdap.io",
-    "label": "Label"
-}, {
-    "id": 2,
-    "url": "https://pdapio.io",
-    "label": "Label"
-}, {
-    "id": 3,
-    "url": "https://pdap.dev",
-    "label": "Label"
-}, {
-    "id": 4,
-    "url": "https://pdap.io/404test",
-    "label": "Label"
-}]
diff --git a/hugging_face/HuggingFaceInterface.py b/hugging_face/HuggingFaceInterface.py
index efb54b75..4e37e9c4 100644
--- a/hugging_face/HuggingFaceInterface.py
+++ b/hugging_face/HuggingFaceInterface.py
@@ -6,7 +6,7 @@
 
 class HuggingFaceInterface:
     def __init__(self):
-        self.pipe = pipeline("text-classification", model="PDAP/url-relevance")
+        self.relevance_pipe = pipeline("text-classification", model="PDAP/url-relevance")
 
     def get_url_relevancy(
         self,
@@ -14,7 +14,7 @@
         threshold: float = 0.5
     ) -> list[bool]:
         urls = [url_with_html.url for url_with_html in urls_with_html]
-        results: list[dict] = self.pipe(urls)
+        results: list[dict] = self.relevance_pipe(urls)
 
         bool_results = []
         for result in results:
b/hugging_face/url_record_type_labeling/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/hugging_face/url_relevance/__init__.py b/hugging_face/url_relevance/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/llm_api_logic/DeepSeekRecordClassifier.py b/llm_api_logic/DeepSeekRecordClassifier.py
new file mode 100644
index 00000000..67f6fa09
--- /dev/null
+++ b/llm_api_logic/DeepSeekRecordClassifier.py
@@ -0,0 +1,37 @@
+import json
+import os
+from typing import Any
+
+from llm_api_logic.LLMRecordClassifierBase import RecordClassifierBase
+
+
+class DeepSeekRecordClassifier(RecordClassifierBase):
+
+    @property
+    def api_key(self):
+        return os.getenv("DEEPSEEK_API_KEY")
+
+    @property
+    def model_name(self):
+        return "deepseek-chat"
+
+    @property
+    def base_url(self):
+        return "https://api.deepseek.com"
+
+    @property
+    def response_format(self):
+        return {
+            'type': 'json_object'
+        }
+
+    @property
+    def completions_func(self) -> callable:
+        # Must be the bound method on the configured client, not on the AsyncOpenAI class
+        return self.client.chat.completions.create
+
+    def post_process_response(self, response: Any) -> str:
+        # DeepSeek's json_object mode returns the JSON payload as a plain string
+        result_str = response.choices[0].message.content
+        result_dict = json.loads(result_str)
+        return result_dict["record_type"]
\ No newline at end of file
diff --git a/llm_api_logic/LLMRecordClassifierBase.py b/llm_api_logic/LLMRecordClassifierBase.py
new file mode 100644
index 00000000..85142aea
--- /dev/null
+++ b/llm_api_logic/LLMRecordClassifierBase.py
@@ -0,0 +1,70 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+from openai import AsyncOpenAI
+
+from collector_db.DTOs.URLHTMLContentInfo import URLHTMLContentInfo
+from llm_api_logic.RecordTypeStructuredOutput import RecordTypeStructuredOutput
+from llm_api_logic.constants import RECORD_CLASSIFICATION_QUERY_CONTENT
+from llm_api_logic.helpers import dictify_html_info
+
+
+class RecordClassifierBase(ABC):
+
+    def __init__(self):
+        self.client = AsyncOpenAI(
+            api_key=self.api_key,
+            base_url=self.base_url
+        )
+
+    @property
+    @abstractmethod
+    def api_key(self) -> str:
+        pass
+
+    @property
+    @abstractmethod
+    def model_name(self) -> str:
+        pass
+
+    @property
+    @abstractmethod
+    def base_url(self) -> str:
+        pass
+
+    @property
+    @abstractmethod
+    def response_format(self) -> dict | type[RecordTypeStructuredOutput]:
+        pass
+
+    @property
+    @abstractmethod
+    def completions_func(self) -> callable:
+        pass
+
+    def build_query_messages(self, content_infos: list[URLHTMLContentInfo]) -> list[dict[str, str]]:
+        insert_content = dictify_html_info(content_infos)
+        return [
+            {
+                "role": "system",
+                "content": RECORD_CLASSIFICATION_QUERY_CONTENT
+            },
+            {
+                "role": "user",
+                "content": str(insert_content)
+            }
+        ]
+
+    @abstractmethod
+    def post_process_response(self, response: Any) -> str:
+        pass
+
+    async def classify_url(self, content_infos: list[URLHTMLContentInfo]) -> str:
+        func = self.completions_func
+        response = await func(
+            model=self.model_name,
+            messages=self.build_query_messages(content_infos),
+            # stream=False is shown in DeepSeek's examples, but non-streaming is already the default
+            response_format=self.response_format
+        )
+        return self.post_process_response(response)
\ No newline at end of file
diff --git a/llm_api_logic/OpenAIRecordClassifier.py b/llm_api_logic/OpenAIRecordClassifier.py
new file mode 100644
index 00000000..fc20a0e2
--- /dev/null
+++ b/llm_api_logic/OpenAIRecordClassifier.py
@@ -0,0 +1,32 @@
+from openai.types.chat import ParsedChatCompletion
+
+from llm_api_logic.LLMRecordClassifierBase import RecordClassifierBase
+from llm_api_logic.RecordTypeStructuredOutput import RecordTypeStructuredOutput
+from util.helper_functions import get_from_env
+
+
+class OpenAIRecordClassifier(RecordClassifierBase):
+
+    @property
+    def api_key(self):
+        return get_from_env("OPENAI_API_KEY")
+
+    @property
+    def model_name(self):
+        return "gpt-4o-mini-2024-07-18"
+
+    @property
+    def base_url(self):
+        return None
+
+    @property
+    def response_format(self):
+        return RecordTypeStructuredOutput
+
+    @property
+    def completions_func(self) -> callable:
+        return self.client.beta.chat.completions.parse
+
+    def post_process_response(self, response: ParsedChatCompletion) -> str:
+        output: RecordTypeStructuredOutput = response.choices[0].message.parsed
+        return output.record_type.value
\ No newline at end of file
diff --git a/llm_api_logic/RecordTypeStructuredOutput.py b/llm_api_logic/RecordTypeStructuredOutput.py
new file mode 100644
index 00000000..a5993ae9
--- /dev/null
+++ b/llm_api_logic/RecordTypeStructuredOutput.py
@@ -0,0 +1,13 @@
+"""
+Used per the guidance given in OpenAI's documentation on structured outputs:
+https://platform.openai.com/docs/guides/structured-outputs
+"""
+
+from pydantic import BaseModel
+
+from core.enums import RecordType
+
+
+
+class RecordTypeStructuredOutput(BaseModel):
+    record_type: RecordType
\ No newline at end of file
diff --git a/llm_api_logic/__init__.py b/llm_api_logic/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/llm_api_logic/constants.py b/llm_api_logic/constants.py
new file mode 100644
index 00000000..55133abf
--- /dev/null
+++ b/llm_api_logic/constants.py
@@ -0,0 +1,48 @@
+RECORD_CLASSIFICATION_QUERY_CONTENT = """
+    You will be provided with structured data from a web page and must determine
+    the record type.
+
+    The record types are as follows:
+
+    "Accident Reports": Records of vehicle accidents.
+    "Arrest Records": Records of each arrest made in the agency's jurisdiction.
+    "Calls for Service": Records of officers initiating activity or responding to requests for police response. Often called "Dispatch Logs" or "Incident Reports" when published.
+    "Car GPS": Records of police car location. Not generally posted online.
+    "Citations": Records of low-level criminal offenses where a police officer issued a citation instead of an arrest.
+    "Dispatch Logs": Records of calls or orders made by police dispatchers.
+    "Dispatch Recordings": Audio feeds and/or archives of municipal dispatch channels.
+    "Field Contacts": Reports of contact between police and civilians. May include uses of force, incidents, arrests, or contacts where nothing notable happened.
+    "Incident Reports": Reports made by police officers after responding to a call which may or may not be criminal in nature. Not generally posted online.
+    "Misc Police Activity": Records or descriptions of police activity not covered by other record types.
+    "Officer Involved Shootings": Case files of gun violence where a police officer was involved, typically as the shooter. Detailed, often containing references to records like Media Bulletins and Use of Force Reports.
+    "Stops": Records of pedestrian or traffic stops made by police.
+    "Surveys": Information captured from a sample of some population, like incarcerated people or magistrate judges. Often generated independently.
+    "Use of Force Reports": Records of use of force against civilians by police officers.
+ "Vehicle Pursuits": Records of cases where police pursued a person fleeing in a vehicle. + "Complaints & Misconduct": Records, statistics, or summaries of complaints and misconduct investigations into law enforcement officers. + "Daily Activity Logs": Officer-created reports or time sheets of what happened on a shift. Not generally posted online. + "Training & Hiring Info": Records and descriptions of additional training for police officers. + "Personnel Records": Records of hiring and firing, certification, discipline, and other officer-specific events. Not generally posted online. + "Annual & Monthly Reports": Often in PDF form, featuring summaries or high-level updates about the police force. Can contain versions of other record types, especially summaries. + "Budgets & Finances": Budgets, finances, grants, or other financial documents. + "Contact Info & Agency Meta": Information about organizational structure, including department structure and contact info. + "Geographic": Maps or geographic data about how land is divided up into municipal sectors, zones, and jurisdictions. + "List of Data Sources": Places on the internet, often data portal homepages, where many links to potential data sources can be found. + "Policies & Contracts": Policies or contracts related to agency procedure. + "Crime Maps & Reports": Records of individual crimes in map or table form for a given jurisdiction. + "Crime Statistics": Summarized information about crime in a given jurisdiction. + "Media Bulletins": Press releases, blotters, or blogs intended to broadly communicate alerts, requests, or other timely information. + "Records Request Info": Portals, forms, policies, or other resources for making public records requests. + "Resources": Agency-provided information or guidance about services, prices, best practices, etc. + "Sex Offender Registry": Index of people registered, usually by law, with the government as sex offenders. + "Wanted Persons": Names, descriptions, images, and associated information about people with outstanding arrest warrants. + "Booking Reports": Records of booking or intake into corrections institutions. + "Court Cases": Records such as dockets about individual court cases. + "Incarceration Records": Records of current inmates, often with full names and features for notification upon inmate release. + "Other": Other record types not otherwise described. 
+
+    Output the record type in the following JSON format:
+    {
+        "record_type": "<record type>"
+    }
+    """
diff --git a/llm_api_logic/helpers.py b/llm_api_logic/helpers.py
new file mode 100644
index 00000000..3d5bde11
--- /dev/null
+++ b/llm_api_logic/helpers.py
@@ -0,0 +1,8 @@
+from collector_db.DTOs.URLHTMLContentInfo import URLHTMLContentInfo
+
+
+def dictify_html_info(html_infos: list[URLHTMLContentInfo]) -> dict[str, str]:
+    d = {}
+    for html_info in html_infos:
+        d[html_info.content_type.value] = html_info.content
+    return d
diff --git a/local_database/DataDumper/dump.sh b/local_database/DataDumper/dump.sh
index 6f1954c4..fd63c65f 100644
--- a/local_database/DataDumper/dump.sh
+++ b/local_database/DataDumper/dump.sh
@@ -1,6 +1,5 @@
 #!/bin/bash
 set -e
-
 # Variables (customize these or pass them as environment variables)
 DB_HOST=${DUMP_HOST:-"postgres_container"}
 DB_USER=${DUMP_USER:-"your_user"}
@@ -8,12 +7,9 @@ DB_PORT=${DUMP_PORT:-"5432"} # Default to 5432 if not provided
 DB_PASSWORD=${DUMP_PASSWORD:-"your_password"}
 DB_NAME=${DUMP_NAME:-"your_database"}
 DUMP_FILE=${DUMP_FILE:-"/dump/db_dump.sql"}
-
 # Export password for pg_dump
 export PGPASSWORD=$DB_PASSWORD
-
 # Dump the database
 echo "Dumping database $DB_NAME from $DB_HOST:$DB_PORT..."
 pg_dump -h $DB_HOST -p $DB_PORT -U $DB_USER -d $DB_NAME --no-owner --no-acl -F c -f $DUMP_FILE
-
 echo "Dump completed. File saved to $DUMP_FILE."
\ No newline at end of file
diff --git a/local_database/DataDumper/restore.sh b/local_database/DataDumper/restore.sh
index d2046fb0..ff62349e 100644
--- a/local_database/DataDumper/restore.sh
+++ b/local_database/DataDumper/restore.sh
@@ -1,6 +1,5 @@
 #!/bin/bash
 set -e
-
 # Variables (customize these or pass them as environment variables)
 DB_HOST=${RESTORE_HOST:-"postgres_container"}
 DB_USER=${RESTORE_USER:-"your_user"}
@@ -8,15 +7,11 @@ DB_PORT=${RESTORE_PORT:-"5432"} # Default to 5432 if not provided
 DB_PASSWORD=${RESTORE_PASSWORD:-"your_password"}
 NEW_DB_NAME=${RESTORE_DB_NAME:-"new_database"} # Name of the database to restore into
 DUMP_FILE=${DUMP_FILE:-"/dump/db_dump.sql"}
-
 MAINTENANCE_DB="postgres"
-
 # Export password for pg_restore
 export PGPASSWORD=$DB_PASSWORD
-
 CONNECTION_STRING="postgresql://$DB_USER:$DB_PASSWORD@$DB_HOST:$DB_PORT/$NEW_DB_NAME"
 MAINT_CONNECTION_STRING="postgresql://$DB_USER:$DB_PASSWORD@$DB_HOST:$DB_PORT/$MAINTENANCE_DB"
-
 echo "Checking if database $NEW_DB_NAME exists on $DB_HOST:$DB_PORT..."
 psql -d $MAINT_CONNECTION_STRING -tc "SELECT 1 FROM pg_database WHERE datname = '$NEW_DB_NAME';" | grep -q 1 && {
   echo "Database $NEW_DB_NAME exists. Dropping it..."
@@ -25,16 +20,13 @@ psql -d $MAINT_CONNECTION_STRING -tc "SELECT 1 FROM pg_database WHERE datname =
   # Drop the database
   psql -d $MAINT_CONNECTION_STRING -c "DROP DATABASE $NEW_DB_NAME;"
 }
-
 # Create the new database
 echo "Creating new database $NEW_DB_NAME on $DB_HOST:$DB_PORT..."
 psql -d $MAINT_CONNECTION_STRING -c "CREATE DATABASE $NEW_DB_NAME;" || {
   echo "Failed to create database $NEW_DB_NAME. It might already exist."
   exit 1
 }
-
 # Restore the dump into the new database
 echo "Restoring dump from $DUMP_FILE into database $NEW_DB_NAME..."
 pg_restore -d $CONNECTION_STRING --no-owner --no-acl -F c $DUMP_FILE
-
-echo "Database restoration completed."
+echo "Database restoration completed."
\ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c05cfbfa..2cc28614 100644 --- a/requirements.txt +++ b/requirements.txt @@ -45,4 +45,4 @@ PyJWT~=2.10.1 # Tests pytest-timeout~=2.3.1 - +openai~=1.60.1 diff --git a/tests/helpers/DBDataCreator.py b/tests/helpers/DBDataCreator.py index c7fce247..0041fad5 100644 --- a/tests/helpers/DBDataCreator.py +++ b/tests/helpers/DBDataCreator.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from collector_db.AsyncDatabaseClient import AsyncDatabaseClient from collector_db.DTOs.BatchInfo import BatchInfo @@ -9,7 +9,7 @@ from collector_db.DTOs.URLInfo import URLInfo from collector_db.DTOs.URLMetadataInfo import URLMetadataInfo from collector_db.DatabaseClient import DatabaseClient -from collector_db.enums import URLMetadataAttributeType, ValidationStatus, ValidationSource +from collector_db.enums import URLMetadataAttributeType, ValidationStatus, ValidationSource, TaskType from collector_manager.enums import CollectorType from core.enums import BatchStatus from tests.helpers.simple_test_data_functions import generate_test_urls @@ -21,7 +21,7 @@ class DBDataCreator: """ def __init__(self, db_client: DatabaseClient = DatabaseClient()): self.db_client = db_client - self.adb_client = AsyncDatabaseClient() + self.adb_client: AsyncDatabaseClient = AsyncDatabaseClient() def batch(self): return self.db_client.insert_batch( @@ -34,6 +34,12 @@ def batch(self): ) ) + async def task(self, url_ids: Optional[list[int]] = None) -> int: + task_id = await self.adb_client.initiate_task(task_type=TaskType.HTML) + if url_ids is not None: + await self.adb_client.link_urls_to_task(task_id=task_id, url_ids=url_ids) + return task_id + def urls(self, batch_id: int, url_count: int) -> InsertURLsInfo: raw_urls = generate_test_urls(url_count) url_infos: List[URLInfo] = [] @@ -99,12 +105,19 @@ async def metadata( ) ) - async def error_info(self, url_ids: list[int]): + async def error_info( + self, + url_ids: list[int], + task_id: Optional[int] = None + ): + if task_id is None: + task_id = await self.task() error_infos = [] for url_id in url_ids: url_error_info = URLErrorPydanticInfo( url_id=url_id, error="test error", + task_id=task_id ) error_infos.append(url_error_info) await self.adb_client.add_url_error_infos(error_infos) diff --git a/tests/helpers/assert_functions.py b/tests/helpers/assert_functions.py new file mode 100644 index 00000000..ef379d3e --- /dev/null +++ b/tests/helpers/assert_functions.py @@ -0,0 +1,7 @@ +from collector_db.AsyncDatabaseClient import AsyncDatabaseClient +from collector_db.models import Task + + +async def assert_database_has_no_tasks(adb_client: AsyncDatabaseClient): + tasks = await adb_client.get_all(Task) + assert len(tasks) == 0 \ No newline at end of file diff --git a/tests/manual/core/lifecycle/test_auto_googler_lifecycle.py b/tests/manual/core/lifecycle/test_auto_googler_lifecycle.py index 2489d17f..c962e1e7 100644 --- a/tests/manual/core/lifecycle/test_auto_googler_lifecycle.py +++ b/tests/manual/core/lifecycle/test_auto_googler_lifecycle.py @@ -1,12 +1,12 @@ import os import dotenv -from tests.automated.core.helpers.common_test_procedures import run_collector_and_wait_for_completion import api.dependencies from collector_db.DTOs.BatchInfo import BatchInfo from collector_manager.enums import CollectorType from core.enums import BatchStatus +from test_automated.integration.core.helpers.common_test_procedures import run_collector_and_wait_for_completion def 
test_auto_googler_collector_lifecycle(test_core): diff --git a/tests/manual/core/lifecycle/test_ckan_lifecycle.py b/tests/manual/core/lifecycle/test_ckan_lifecycle.py index 10802c77..575dedfa 100644 --- a/tests/manual/core/lifecycle/test_ckan_lifecycle.py +++ b/tests/manual/core/lifecycle/test_ckan_lifecycle.py @@ -1,10 +1,10 @@ -from tests.automated.core.helpers.common_test_procedures import run_collector_and_wait_for_completion import api.dependencies from collector_db.DTOs.BatchInfo import BatchInfo from collector_manager.enums import CollectorType from core.enums import BatchStatus from source_collectors.ckan.search_terms import group_search, package_search, organization_search +from test_automated.integration.core.helpers.common_test_procedures import run_collector_and_wait_for_completion def test_ckan_lifecycle(test_core): diff --git a/tests/manual/core/lifecycle/test_muckrock_lifecycles.py b/tests/manual/core/lifecycle/test_muckrock_lifecycles.py index d92fa0be..b688b0a8 100644 --- a/tests/manual/core/lifecycle/test_muckrock_lifecycles.py +++ b/tests/manual/core/lifecycle/test_muckrock_lifecycles.py @@ -1,10 +1,10 @@ -from tests.automated.core.helpers.common_test_procedures import run_collector_and_wait_for_completion -from tests.automated.core.helpers.constants import ALLEGHENY_COUNTY_TOWN_NAMES, ALLEGHENY_COUNTY_MUCKROCK_ID import api.dependencies from collector_db.DTOs.BatchInfo import BatchInfo from collector_manager.enums import CollectorType from core.enums import BatchStatus +from test_automated.integration.core.helpers.common_test_procedures import run_collector_and_wait_for_completion +from test_automated.integration.core.helpers.constants import ALLEGHENY_COUNTY_MUCKROCK_ID, ALLEGHENY_COUNTY_TOWN_NAMES def test_muckrock_simple_search_collector_lifecycle(test_core): diff --git a/tests/manual/html_collector/test_html_tag_collector_integration.py b/tests/manual/html_collector/test_html_tag_collector_integration.py index cb803e96..1673ca42 100644 --- a/tests/manual/html_collector/test_html_tag_collector_integration.py +++ b/tests/manual/html_collector/test_html_tag_collector_integration.py @@ -3,7 +3,7 @@ from collector_db.AsyncDatabaseClient import AsyncDatabaseClient from collector_db.DTOs.URLInfo import URLInfo -from core.classes.URLHTMLCycler import URLHTMLCycler +from core.classes.URLHTMLTaskOperator import URLHTMLTaskOperator from helpers.DBDataCreator import DBDataCreator from html_tag_collector.ResponseParser import HTMLResponseParser from html_tag_collector.RootURLCache import RootURLCache @@ -43,14 +43,14 @@ async def test_url_html_cycle_live_data( """ Tests the cycle on whatever exists in the DB """ - cycler = URLHTMLCycler( + operator = URLHTMLTaskOperator( adb_client=AsyncDatabaseClient(), url_request_interface=URLRequestInterface(), html_parser=HTMLResponseParser( root_url_cache=RootURLCache() ) ) - await cycler.cycle() + await operator.run_task() @pytest.mark.asyncio async def test_url_html_cycle( @@ -64,11 +64,11 @@ async def test_url_html_cycle( db_client.insert_urls(url_infos=url_infos, batch_id=batch_id) - cycler = URLHTMLCycler( + operator = URLHTMLTaskOperator( adb_client=AsyncDatabaseClient(), url_request_interface=URLRequestInterface(), html_parser=HTMLResponseParser( root_url_cache=RootURLCache() ) ) - await cycler.cycle() \ No newline at end of file + await operator.run_task() \ No newline at end of file diff --git a/tests/manual/llm_api_logic/__init__.py b/tests/manual/llm_api_logic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/tests/manual/llm_api_logic/test_deepseek_record_classifier.py b/tests/manual/llm_api_logic/test_deepseek_record_classifier.py new file mode 100644 index 00000000..b0a6c1fb --- /dev/null +++ b/tests/manual/llm_api_logic/test_deepseek_record_classifier.py @@ -0,0 +1,25 @@ +import pytest + +from collector_db.DTOs.URLHTMLContentInfo import URLHTMLContentInfo +from llm_api_logic.DeepSeekRecordClassifier import DeepSeekRecordClassifier + + +@pytest.mark.asyncio +async def test_deepseek_record_classifier(): + from collector_db.DTOs.URLHTMLContentInfo import HTMLContentType as hct + + d = { + hct.TITLE: "Oath of Office for Newly Promoted Corporal Lumpkin with Acworth Police – City of Acworth, GA", + hct.DESCRIPTION: "At the Thursday, November 2 regular city council meeting, Chief Evans administered the oath of office and swearing in of Corporal Cody Lumpkin. Corporal Lumpkin was surrounded by his family and members of the Acworth Police Department for the occasion. Corporal Lumpkin began employment with the Acworth Police Department on June 8,", + hct.H3: ["Oath of Office for Newly Promoted Corporal Lumpkin with Acworth Police"], + hct.H4: ["Share this on Social Media"], + hct.DIV: "PHONE DIRECTORY RESOURCES Search for: Search Button NEWS DEPARTMENTS GOVERNANCE & DEVELOPMENT Administration Development Clerks Office Court Services DDA, Tourism, and Historic Preservation OPERATIONS Parks, Recreation, and Community Resources Power, Public Works, and Stormwater SUPPORT SERVICES Customer Service Human Resources Finance Information Technology PUBLIC SAFETY Acworth Police RESIDENTS Public Art Master Plan Application for Boards & Commissions Board of Aldermen Customer Service Parks, Recreation, and Community Resources Historic Acworth Master Fee Schedule E-News Sign Up Online Payments BUSINESS Bids & Projects E-Verify Permits, Applications, & Ordinances City Code of Ordinances Master Fee Schedule Start a Business EVENTS VISIT ACWORTH NEWS DEPARTMENTS GOVERNANCE & DEVELOPMENT Administration Development Clerks Office Court Services DDA, Tourism, and Historic Preservation OPERATIONS Parks, Recreation, and Community Resources Power, Public Works, and Stormwater SUPPORT SERVICES Customer Service Human Resources Finance Information Technology PUBLIC SAFETY Acworth Police RESIDENTS Public Art Master Plan Application for Boards & Commissions Board of Aldermen Customer Service Parks, Recreation, and Community Resources Historic Acworth Master Fee Schedule E-News Sign Up Online Payments BUSINESS Bids & Projects E-Verify Permits, Applications, & Ordinances City Code of Ordinances Master Fee Schedule Start a Business EVENTS VISIT ACWORTH Oath of Office for Newly Promoted Corporal Lumpkin with Acworth Police Published On: November 3, 2023 At the Thursday, November 2 regular city council meeting, Chief Evans administered the oath of office and swearing in of Corporal Cody Lumpkin.  Corporal Lumpkin was surrounded by his family and members of the Acworth Police Department for the occasion.  Corporal Lumpkin began employment with the Acworth Police Department on June 8 , 2015, and has served as a patrol officer in addition to time spent time in Special Operations prior to his recent promotion. 
Share this on Social Media 4415 Center Street, Acworth GA 30101 Phone Directory Contact Us © 2025 City of Acworth Acworth is located in the foothills of the North Georgia mountains and is nestled along the banks of Lake Acworth and Lake Allatoona, hence its nickname “The Lake City.” The city boasts a rich history, a charming downtown, abundant outdoor recreational activities, a vibrant restaurant scene, and an active festival and events calendar. Acworth is one of the best, family-friendly destinations in the Atlanta region. Come discover why You’re Welcome in Acworth! ESS | Webmail | Handbook | Peak | Laserfiche | Login ", + } + content_infos = [] + for content_type, value in d.items(): + content_info = URLHTMLContentInfo(content_type=content_type, content=value) + content_infos.append(content_info) + + classifier = DeepSeekRecordClassifier() + result = await classifier.classify_url(content_infos) + print(result) \ No newline at end of file diff --git a/tests/manual/llm_api_logic/test_openai_record_classifier.py b/tests/manual/llm_api_logic/test_openai_record_classifier.py new file mode 100644 index 00000000..72d474d2 --- /dev/null +++ b/tests/manual/llm_api_logic/test_openai_record_classifier.py @@ -0,0 +1,26 @@ +import pytest + +from collector_db.DTOs.URLHTMLContentInfo import URLHTMLContentInfo +from llm_api_logic.OpenAIRecordClassifier import OpenAIRecordClassifier + + +@pytest.mark.asyncio +async def test_openai_record_classifier(): + from collector_db.DTOs.URLHTMLContentInfo import HTMLContentType as hct + + d = { + hct.TITLE: "Oath of Office for Newly Promoted Corporal Lumpkin with Acworth Police – City of Acworth, GA", + hct.DESCRIPTION: "At the Thursday, November 2 regular city council meeting, Chief Evans administered the oath of office and swearing in of Corporal Cody Lumpkin. Corporal Lumpkin was surrounded by his family and members of the Acworth Police Department for the occasion. 
Corporal Lumpkin began employment with the Acworth Police Department on June 8,", + hct.H3: ["Oath of Office for Newly Promoted Corporal Lumpkin with Acworth Police"], + hct.H4: ["Share this on Social Media"], + hct.DIV: "PHONE DIRECTORY RESOURCES Search for: Search Button NEWS DEPARTMENTS GOVERNANCE & DEVELOPMENT Administration Development Clerks Office Court Services DDA, Tourism, and Historic Preservation OPERATIONS Parks, Recreation, and Community Resources Power, Public Works, and Stormwater SUPPORT SERVICES Customer Service Human Resources Finance Information Technology PUBLIC SAFETY Acworth Police RESIDENTS Public Art Master Plan Application for Boards & Commissions Board of Aldermen Customer Service Parks, Recreation, and Community Resources Historic Acworth Master Fee Schedule E-News Sign Up Online Payments BUSINESS Bids & Projects E-Verify Permits, Applications, & Ordinances City Code of Ordinances Master Fee Schedule Start a Business EVENTS VISIT ACWORTH NEWS DEPARTMENTS GOVERNANCE & DEVELOPMENT Administration Development Clerks Office Court Services DDA, Tourism, and Historic Preservation OPERATIONS Parks, Recreation, and Community Resources Power, Public Works, and Stormwater SUPPORT SERVICES Customer Service Human Resources Finance Information Technology PUBLIC SAFETY Acworth Police RESIDENTS Public Art Master Plan Application for Boards & Commissions Board of Aldermen Customer Service Parks, Recreation, and Community Resources Historic Acworth Master Fee Schedule E-News Sign Up Online Payments BUSINESS Bids & Projects E-Verify Permits, Applications, & Ordinances City Code of Ordinances Master Fee Schedule Start a Business EVENTS VISIT ACWORTH Oath of Office for Newly Promoted Corporal Lumpkin with Acworth Police Published On: November 3, 2023 At the Thursday, November 2 regular city council meeting, Chief Evans administered the oath of office and swearing in of Corporal Cody Lumpkin.  Corporal Lumpkin was surrounded by his family and members of the Acworth Police Department for the occasion.  Corporal Lumpkin began employment with the Acworth Police Department on June 8 , 2015, and has served as a patrol officer in addition to time spent time in Special Operations prior to his recent promotion. Share this on Social Media 4415 Center Street, Acworth GA 30101 Phone Directory Contact Us © 2025 City of Acworth Acworth is located in the foothills of the North Georgia mountains and is nestled along the banks of Lake Acworth and Lake Allatoona, hence its nickname “The Lake City.” The city boasts a rich history, a charming downtown, abundant outdoor recreational activities, a vibrant restaurant scene, and an active festival and events calendar. Acworth is one of the best, family-friendly destinations in the Atlanta region. Come discover why You’re Welcome in Acworth! 
ESS | Webmail | Handbook | Peak | Laserfiche | Login ", + } + content_infos = [] + for content_type, value in d.items(): + content_info = URLHTMLContentInfo(content_type=content_type, content=value) + content_infos.append(content_info) + + classifier = OpenAIRecordClassifier() + result = await classifier.classify_url(content_infos) + print(type(result)) + print(result) \ No newline at end of file diff --git a/tests/manual/source_collectors/test_muckrock_collectors.py b/tests/manual/source_collectors/test_muckrock_collectors.py index 00e1d57e..4689dbab 100644 --- a/tests/manual/source_collectors/test_muckrock_collectors.py +++ b/tests/manual/source_collectors/test_muckrock_collectors.py @@ -7,7 +7,7 @@ from source_collectors.muckrock.classes.MuckrockCollector import MuckrockSimpleSearchCollector, \ MuckrockCountyLevelSearchCollector, MuckrockAllFOIARequestsCollector from source_collectors.muckrock.schemas import MuckrockURLInfoSchema -from test_automated.integration.core.helpers import ALLEGHENY_COUNTY_MUCKROCK_ID, ALLEGHENY_COUNTY_TOWN_NAMES +from test_automated.integration.core.helpers.constants import ALLEGHENY_COUNTY_MUCKROCK_ID, ALLEGHENY_COUNTY_TOWN_NAMES def test_muckrock_simple_search_collector(): diff --git a/tests/test_alembic/helpers.py b/tests/test_alembic/helpers.py index d66854f2..d6b2bea4 100644 --- a/tests/test_alembic/helpers.py +++ b/tests/test_alembic/helpers.py @@ -1,7 +1,9 @@ +from typing import Optional + from sqlalchemy import text from sqlalchemy.orm import Session -from tests.test_alembic.AlembicRunner import AlembicRunner +from tests.test_alembic.AlembicRunner import AlembicRunner def get_enum_values(enum_name: str, session: Session) -> list[str]: @@ -9,12 +11,24 @@ def get_enum_values(enum_name: str, session: Session) -> list[str]: def table_creation_check( alembic_runner: AlembicRunner, - table_name: str, - start_revision: str, - end_revision: str + tables: list[str], + end_revision: str, + start_revision: Optional[str] = None, + ): - alembic_runner.upgrade(start_revision) - assert table_name not in alembic_runner.inspector.get_table_names() + if start_revision is not None: + alembic_runner.upgrade(start_revision) + for table_name in tables: + assert table_name not in alembic_runner.inspector.get_table_names() alembic_runner.upgrade(end_revision) alembic_runner.reflect() - assert table_name in alembic_runner.inspector.get_table_names() \ No newline at end of file + for table_name in tables: + assert table_name in alembic_runner.inspector.get_table_names() + +def columns_in_table( + alembic_runner: AlembicRunner, + table_name: str, + columns_to_check: list[str], +) -> bool: + current_columns = [col["name"] for col in alembic_runner.inspector.get_columns(table_name)] + return all(column in current_columns for column in columns_to_check) diff --git a/tests/test_alembic/test_revisions.py b/tests/test_alembic/test_revisions.py index 75df5f0c..22a83496 100644 --- a/tests/test_alembic/test_revisions.py +++ b/tests/test_alembic/test_revisions.py @@ -15,6 +15,7 @@ from sqlalchemy import text +from tests.test_alembic.helpers import columns_in_table from tests.test_alembic.helpers import get_enum_values, table_creation_check @@ -298,7 +299,39 @@ def test_add_in_label_studio_metadata_status(alembic_runner): def test_create_metadata_annotation_table(alembic_runner): table_creation_check( alembic_runner, - "metadata_annotations", + ["metadata_annotations"], start_revision="108dac321086", end_revision="dcd158092de0" + ) + +def 
test_add_task_tables_and_linking_logic(alembic_runner): + alembic_runner.upgrade("dcd158092de0") + assert not columns_in_table( + alembic_runner, + table_name="url_error_info", + columns_to_check=["task_id"], + ) + assert not columns_in_table( + alembic_runner, + table_name="url_metadata", + columns_to_check=["notes"], + ) + table_creation_check( + alembic_runner, + tables=[ + "tasks", + "task_errors", + "link_task_urls" + ], + end_revision="072b32a45b1c" + ) + assert columns_in_table( + alembic_runner, + table_name="url_error_info", + columns_to_check=["task_id"], + ) + assert columns_in_table( + alembic_runner, + table_name="url_metadata", + columns_to_check=["notes"], ) \ No newline at end of file diff --git a/tests/test_automated/integration/api/helpers/RequestValidator.py b/tests/test_automated/integration/api/helpers/RequestValidator.py index 7a0e9a6a..220b6645 100644 --- a/tests/test_automated/integration/api/helpers/RequestValidator.py +++ b/tests/test_automated/integration/api/helpers/RequestValidator.py @@ -5,15 +5,17 @@ from starlette.testclient import TestClient from collector_db.DTOs.BatchInfo import BatchInfo +from collector_db.DTOs.TaskInfo import TaskInfo +from collector_db.enums import TaskType from collector_manager.DTOs.ExampleInputDTO import ExampleInputDTO from collector_manager.enums import CollectorType from core.DTOs.GetBatchLogsResponse import GetBatchLogsResponse from core.DTOs.GetBatchStatusResponse import GetBatchStatusResponse from core.DTOs.GetDuplicatesByBatchResponse import GetDuplicatesByBatchResponse from core.DTOs.GetNextURLForRelevanceAnnotationResponse import GetNextURLForRelevanceAnnotationResponse +from core.DTOs.GetTasksResponse import GetTasksResponse from core.DTOs.GetURLsByBatchResponse import GetURLsByBatchResponse from core.DTOs.GetURLsResponseInfo import GetURLsResponseInfo -from core.DTOs.LabelStudioExportResponseInfo import LabelStudioExportResponseInfo from core.DTOs.MessageCountResponse import MessageCountResponse from core.DTOs.MessageResponse import MessageResponse from core.DTOs.RelevanceAnnotationInfo import RelevanceAnnotationPostInfo @@ -160,12 +162,6 @@ def get_batch_logs(self, batch_id: int) -> GetBatchLogsResponse: ) return GetBatchLogsResponse(**data) - def export_batch_to_label_studio(self, batch_id: int) -> LabelStudioExportResponseInfo: - data = self.post( - url=f"/label-studio/export-batch/{batch_id}" - ) - return LabelStudioExportResponseInfo(**data) - def abort_batch(self, batch_id: int) -> MessageResponse: data = self.post( url=f"/batch/{batch_id}/abort" @@ -201,4 +197,30 @@ def get_urls(self, page: int = 1, errors: bool = False) -> GetURLsResponseInfo: url=f"/url", params={"page": page, "errors": errors} ) - return GetURLsResponseInfo(**data) \ No newline at end of file + return GetURLsResponseInfo(**data) + + def get_task_info(self, task_id: int) -> TaskInfo: + data = self.get( + url=f"/task/{task_id}" + ) + return TaskInfo(**data) + + def get_tasks( + self, + page: int = 1, + task_type: Optional[TaskType] = None, + task_status: Optional[BatchStatus] = None + ) -> GetTasksResponse: + params = {"page": page} + update_if_not_none( + target=params, + source={ + "task_type": task_type.value if task_type else None, + "task_status": task_status.value if task_status else None + } + ) + data = self.get( + url=f"/task", + params=params + ) + return GetTasksResponse(**data) \ No newline at end of file diff --git a/tests/test_automated/integration/api/test_task.py b/tests/test_automated/integration/api/test_task.py new file mode 
100644 index 00000000..d6e13b1f --- /dev/null +++ b/tests/test_automated/integration/api/test_task.py @@ -0,0 +1,41 @@ +import pytest + +from collector_db.enums import TaskType +from tests.test_automated.integration.api.conftest import APITestHelper + + +async def task_setup(ath: APITestHelper) -> int: + iui = ath.db_data_creator.urls(batch_id=ath.db_data_creator.batch(), url_count=3) + url_ids = [url.url_id for url in iui.url_mappings] + + task_id = await ath.db_data_creator.task(url_ids=url_ids) + await ath.db_data_creator.error_info(url_ids=[url_ids[0]], task_id=task_id) + + return task_id + +@pytest.mark.asyncio +async def test_get_task_info(api_test_helper): + ath = api_test_helper + + task_id = await task_setup(ath) + + task_info = ath.request_validator.get_task_info(task_id=task_id) + + assert len(task_info.urls) == 3 + assert len(task_info.url_errors) == 1 + + assert task_info.task_type == TaskType.HTML + +@pytest.mark.asyncio +async def test_get_tasks(api_test_helper): + ath = api_test_helper + for i in range(2): + await task_setup(ath) + + response = ath.request_validator.get_tasks(page=1, task_type=None, task_status=None) + + assert len(response.tasks) == 2 + for task in response.tasks: + assert task.type == TaskType.HTML + assert task.url_count == 3 + assert task.url_error_count == 1 diff --git a/tests/test_automated/integration/collector_db/test_database_structure.py b/tests/test_automated/integration/collector_db/test_database_structure.py index 926a6ed8..272f3de2 100644 --- a/tests/test_automated/integration/collector_db/test_database_structure.py +++ b/tests/test_automated/integration/collector_db/test_database_structure.py @@ -325,4 +325,4 @@ def test_root_url(db_data_creator: DBDataCreator): engine=db_data_creator.db_client.engine ) - table_tester.run_column_tests() \ No newline at end of file + table_tester.run_column_tests() diff --git a/tests/test_automated/integration/collector_db/test_db_client.py b/tests/test_automated/integration/collector_db/test_db_client.py index feadf57f..fa3b7110 100644 --- a/tests/test_automated/integration/collector_db/test_db_client.py +++ b/tests/test_automated/integration/collector_db/test_db_client.py @@ -136,12 +136,14 @@ async def test_add_url_error_info(db_data_creator: DBDataCreator): url_ids = [url_mapping.url_id for url_mapping in url_mappings] adb_client = AsyncDatabaseClient() + task_id = await db_data_creator.task() error_infos = [] for url_mapping in url_mappings: uei = URLErrorPydanticInfo( url_id=url_mapping.url_id, error="test error", + task_id=task_id ) error_infos.append(uei) @@ -167,7 +169,9 @@ async def test_get_urls_with_html_data_and_no_relevancy_metadata( url_ids = [url_info.url_id for url_info in url_mappings] await db_data_creator.html_data(url_ids) await db_data_creator.metadata([url_ids[0]]) - results = await db_data_creator.adb_client.get_urls_with_html_data_and_no_relevancy_metadata() + results = await db_data_creator.adb_client.get_urls_with_html_data_and_without_metadata_type( + without_metadata_type=URLMetadataAttributeType.RELEVANT + ) permitted_url_ids = [url_id for url_id in url_ids if url_id != url_ids[0]] assert len(results) == 2 diff --git a/tests/test_automated/integration/tasks/__init__.py b/tests/test_automated/integration/tasks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_automated/integration/tasks/test_example_task.py b/tests/test_automated/integration/tasks/test_example_task.py new file mode 100644 index 00000000..f6f56521 --- /dev/null +++ 
b/tests/test_automated/integration/tasks/test_example_task.py
@@ -0,0 +1,87 @@
+import types
+
+import pytest
+
+from collector_db.enums import TaskType
+from core.classes.TaskOperatorBase import TaskOperatorBase
+from core.enums import BatchStatus
+from tests.helpers.DBDataCreator import DBDataCreator
+
+class ExampleTaskOperator(TaskOperatorBase):
+
+    @property
+    def task_type(self) -> TaskType:
+        # Use TaskType.HTML so we don't have to add a test enum value to the db
+        return TaskType.HTML
+
+    async def inner_task_logic(self):
+        raise NotImplementedError
+
+    async def meets_task_prerequisites(self):
+        return True
+
+@pytest.mark.asyncio
+async def test_example_task_success(db_data_creator: DBDataCreator):
+    batch_id = db_data_creator.batch()
+    url_mappings = db_data_creator.urls(
+        batch_id=batch_id,
+        url_count=3
+    ).url_mappings
+    url_ids = [url_info.url_id for url_info in url_mappings]
+
+    async def mock_inner_task_logic(self):
+        # Add link to 3 urls
+        await self.adb_client.link_urls_to_task(task_id=self.task_id, url_ids=url_ids)
+        self.tasks_linked = True
+
+    operator = ExampleTaskOperator(adb_client=db_data_creator.adb_client)
+    operator.inner_task_logic = types.MethodType(mock_inner_task_logic, operator)
+
+    await operator.run_task()
+
+    # Get Task Info
+    task_info = await db_data_creator.adb_client.get_task_info(task_id=operator.task_id)
+
+    # Check that 3 urls were linked to the task
+    assert len(task_info.urls) == 3
+
+    # Check that error info is empty
+    assert task_info.error_info is None
+
+    # Check that the task was marked as complete
+    assert task_info.task_status == BatchStatus.COMPLETE
+
+    # Check that the task type is HTML
+    assert task_info.task_type == TaskType.HTML
+
+
+    # Check that updated_at is not null
+    assert task_info.updated_at is not None
+
+@pytest.mark.asyncio
+async def test_example_task_failure(db_data_creator: DBDataCreator):
+    operator = ExampleTaskOperator(adb_client=db_data_creator.adb_client)
+
+    async def mock_inner_task_logic(self):
+        raise ValueError("test error")
+
+    operator.inner_task_logic = types.MethodType(mock_inner_task_logic, operator)
+    await operator.run_task()
+
+    # Get Task Info
+    task_info = await db_data_creator.adb_client.get_task_info(task_id=operator.task_id)
+
+    # Check that there are no URLs associated
+    assert len(task_info.urls) == 0
+
+    # Check that the task was marked as errored
+    assert task_info.task_status == BatchStatus.ERROR
+
+    # Check that the task type is HTML
+    assert task_info.task_type == TaskType.HTML
+
+    # Check error
+    assert "test error" in task_info.error_info
+
+
+
diff --git a/tests/test_automated/integration/tasks/test_url_html_task.py b/tests/test_automated/integration/tasks/test_url_html_task.py
new file mode 100644
index 00000000..7674113f
--- /dev/null
+++ b/tests/test_automated/integration/tasks/test_url_html_task.py
@@ -0,0 +1,104 @@
+import types
+from typing import Optional
+
+import pytest
+
+from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
+from collector_db.enums import TaskType
+from core.classes.URLHTMLTaskOperator import URLHTMLTaskOperator
+from core.enums import BatchStatus
+from tests.helpers.DBDataCreator import DBDataCreator
+from tests.helpers.assert_functions import assert_database_has_no_tasks
+from html_tag_collector.DataClassTags import ResponseHTMLInfo
+from html_tag_collector.ResponseParser import HTMLResponseParser
+from html_tag_collector.RootURLCache import RootURLCache
+from html_tag_collector.URLRequestInterface import URLRequestInterface, URLResponseInfo
+
+
+@pytest.mark.asyncio +async def test_url_html_task(db_data_creator: DBDataCreator): + + mock_html_content = "" + mock_content_type = "text/html" + + async def mock_make_requests(self, urls: list[str]) -> list[URLResponseInfo]: + results = [] + for idx, url in enumerate(urls): + if idx == 2: + results.append( + URLResponseInfo( + success=False, + exception=ValueError("test error"), + content_type=mock_content_type + )) + else: + results.append(URLResponseInfo( + html=mock_html_content, success=True, content_type=mock_content_type)) + return results + + async def mock_parse(self, url: str, html_content: str, content_type: str) -> ResponseHTMLInfo: + assert html_content == mock_html_content + assert content_type == mock_content_type + return ResponseHTMLInfo( + url=url, + title="fake title", + description="fake description", + ) + + async def mock_get_from_cache(self, url: str) -> Optional[str]: + return None + + # Add mock methods or mock classes + url_request_interface = URLRequestInterface() + url_request_interface.make_requests = types.MethodType(mock_make_requests, url_request_interface) + + mock_root_url_cache = RootURLCache() + mock_root_url_cache.get_from_cache = types.MethodType(mock_get_from_cache, mock_root_url_cache) + + html_parser = HTMLResponseParser( + root_url_cache=mock_root_url_cache + ) + html_parser.parse = types.MethodType(mock_parse, html_parser) + + operator = URLHTMLTaskOperator( + adb_client=AsyncDatabaseClient(), + url_request_interface=url_request_interface, + html_parser=html_parser + ) + await operator.run_task() + + # Check that, because no URLs were created, the task did not run + await assert_database_has_no_tasks(db_data_creator.adb_client) + + batch_id = db_data_creator.batch() + url_mappings = db_data_creator.urls(batch_id=batch_id, url_count=3).url_mappings + url_ids = [url_info.url_id for url_info in url_mappings] + + await operator.run_task() + + + # Check in database that + # - task is listed as complete + # - task type is listed as 'HTML' + # - task has 3 urls + # - task has one errored url with error "ValueError" + task_info = await db_data_creator.adb_client.get_task_info( + task_id=operator.task_id + ) + + assert task_info.error_info is None + assert task_info.task_status == BatchStatus.COMPLETE + assert task_info.task_type == TaskType.HTML + + assert len(task_info.urls) == 3 + assert len(task_info.url_errors) == 1 + assert task_info.url_errors[0].error == "test error" + + adb = db_data_creator.adb_client + # Check that both success urls have two rows of HTML data + hci = await adb.get_html_content_info(url_id=task_info.urls[0].id) + assert len(hci) == 2 + hci = await adb.get_html_content_info(url_id=task_info.urls[1].id) + assert len(hci) == 2 + + # Check that errored url has error info diff --git a/tests/test_automated/integration/tasks/test_url_record_type_task.py b/tests/test_automated/integration/tasks/test_url_record_type_task.py new file mode 100644 index 00000000..ee624dae --- /dev/null +++ b/tests/test_automated/integration/tasks/test_url_record_type_task.py @@ -0,0 +1,54 @@ +from unittest.mock import MagicMock + +import pytest + +from collector_db.enums import TaskType +from collector_db.models import URLMetadata +from core.classes.URLRecordTypeTaskOperator import URLRecordTypeTaskOperator +from core.enums import RecordType, BatchStatus +from tests.helpers.DBDataCreator import DBDataCreator +from tests.helpers.assert_functions import assert_database_has_no_tasks +from llm_api_logic.DeepSeekRecordClassifier import DeepSeekRecordClassifier 
+ +@pytest.mark.asyncio +async def test_url_record_type_task(db_data_creator: DBDataCreator): + + mock_classifier = MagicMock(spec=DeepSeekRecordClassifier) + mock_classifier.classify_url.side_effect = [RecordType.ACCIDENT_REPORTS, "Error"] + mock_classifier.model_name = "test_notes" + + operator = URLRecordTypeTaskOperator( + adb_client=db_data_creator.adb_client, + classifier=mock_classifier + ) + await operator.run_task() + + # No task should have been created due to not meeting prerequisites + await assert_database_has_no_tasks(db_data_creator.adb_client) + + batch_id = db_data_creator.batch() + iui = db_data_creator.urls(batch_id=batch_id, url_count=2) + url_ids = [iui.url_mappings[0].url_id, iui.url_mappings[1].url_id] + await db_data_creator.html_data(url_ids) + + await operator.run_task() + + # Task should have been created + task_info = await db_data_creator.adb_client.get_task_info(task_id=operator.task_id) + assert task_info.error_info is None + assert task_info.task_status == BatchStatus.COMPLETE + + response = await db_data_creator.adb_client.get_tasks() + tasks = response.tasks + assert len(tasks) == 1 + task = tasks[0] + assert task.type == TaskType.RECORD_TYPE + assert task.url_count == 2 + assert task.url_error_count == 1 + + # Get metadata + metadata_results = await db_data_creator.adb_client.get_all(URLMetadata) + for metadata_row in metadata_results: + assert metadata_row.notes == "test_notes" + assert metadata_row.value == RecordType.ACCIDENT_REPORTS.value + diff --git a/tests/test_automated/integration/cycles/test_url_relevancy_huggingface_cycle.py b/tests/test_automated/integration/tasks/test_url_relevancy_huggingface_task.py similarity index 77% rename from tests/test_automated/integration/cycles/test_url_relevancy_huggingface_cycle.py rename to tests/test_automated/integration/tasks/test_url_relevancy_huggingface_task.py index 064eff51..abf86cda 100644 --- a/tests/test_automated/integration/cycles/test_url_relevancy_huggingface_cycle.py +++ b/tests/test_automated/integration/tasks/test_url_relevancy_huggingface_task.py @@ -5,18 +5,15 @@ from collector_db.AsyncDatabaseClient import AsyncDatabaseClient from collector_db.DTOs.URLWithHTML import URLWithHTML from collector_db.enums import ValidationStatus, ValidationSource -from collector_db.models import URLMetadata -from core.classes.URLRelevanceHuggingfaceCycler import URLRelevanceHuggingfaceCycler +from collector_db.models import URLMetadata, Task +from core.classes.URLRelevanceHuggingfaceTaskOperator import URLRelevanceHuggingfaceTaskOperator +from tests.helpers.assert_functions import assert_database_has_no_tasks from hugging_face.HuggingFaceInterface import HuggingFaceInterface @pytest.mark.asyncio -async def test_url_relevancy_huggingface_cycle(db_data_creator): - batch_id = db_data_creator.batch() - url_mappings = db_data_creator.urls(batch_id=batch_id, url_count=3).url_mappings - url_ids = [url_info.url_id for url_info in url_mappings] - await db_data_creator.html_data(url_ids) - await db_data_creator.metadata([url_ids[0]]) +async def test_url_relevancy_huggingface_task(db_data_creator): + def num_to_bool(num: int) -> bool: if num == 0: @@ -38,11 +35,21 @@ def mock_get_url_relevancy( mock_hf_interface = MagicMock(spec=HuggingFaceInterface) mock_hf_interface.get_url_relevancy = mock_get_url_relevancy - cycler = URLRelevanceHuggingfaceCycler( + task_operator = URLRelevanceHuggingfaceTaskOperator( adb_client=AsyncDatabaseClient(), huggingface_interface=mock_hf_interface ) - await cycler.cycle() + await 
task_operator.run_task() + + await assert_database_has_no_tasks(db_data_creator.adb_client) + + batch_id = db_data_creator.batch() + url_mappings = db_data_creator.urls(batch_id=batch_id, url_count=3).url_mappings + url_ids = [url_info.url_id for url_info in url_mappings] + await db_data_creator.html_data(url_ids) + await db_data_creator.metadata([url_ids[0]]) + + await task_operator.run_task() results = await db_data_creator.adb_client.get_all(URLMetadata)