From cd47b8dd6906926610970b6c6c89b3eb35c2a1a0 Mon Sep 17 00:00:00 2001 From: Andrew Maguire Date: Tue, 14 Jan 2025 17:45:26 +0200 Subject: [PATCH 01/31] Feature: User favourites (#2) * Feature: initial routers / methods for favourites - add env documentation / adminer image for easier inspection of db - create / delete favourite entry in database - fetch list of favourites by user * Chore: update formatting, tests * update test methods / logging - * updates * expand tests / logging --- .env.example | 19 ++ .env.local | 13 ++ .gitignore | 2 + Dockerfile | 22 +- Makefile | 2 +- docker-compose.yaml | 32 ++- main.py | 13 ++ requirements_dev.txt | 3 + setup.py | 13 ++ sql/create_tables.sql | 19 +- sql/init_test_data.sql | 29 +++ sql/insert_test_data.sql | 43 ++++ sql/test_data.json | 104 ++++++++++ src/authentication.py | 27 ++- src/config/logging_config.py | 29 +++ src/database/__init__.py | 1 + src/database/favourites.py | 105 ++++++++++ src/database/signals.py | 29 +++ src/dependencies.py | 6 + src/entities/signal.py | 5 + src/routers/__init__.py | 2 + src/routers/favourites.py | 73 +++++++ src/routers/signals.py | 22 ++ src/routers/users.py | 2 + tests/conftest.py | 100 ++++++++- tests/test_favourites.py | 378 +++++++++++++++++++++++++++++++++++ 26 files changed, 1075 insertions(+), 18 deletions(-) create mode 100644 .env.example create mode 100644 .env.local create mode 100644 setup.py create mode 100644 sql/init_test_data.sql create mode 100644 sql/insert_test_data.sql create mode 100644 sql/test_data.json create mode 100644 src/config/logging_config.py create mode 100644 src/database/favourites.py create mode 100644 src/routers/favourites.py create mode 100644 tests/test_favourites.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..ebc6759 --- /dev/null +++ b/.env.example @@ -0,0 +1,19 @@ +# Authentication +TENANT_ID="" +CLIENT_ID="" +API_KEY="" # for accessing "public" endpoints + +# Database and Storage +DB_CONNECTION="postgresql://:@:5432/" +SAS_URL="https://.blob.core.windows.net/?" + +# Azure OpenAI, only required for `/signals/generation` +AZURE_OPENAI_ENDPOINT="https://.openai.azure.com/" +AZURE_OPENAI_API_KEY="" + +# Testing, only required to run tests, must be a valid token of a regular user +API_JWT="" + +# News API +# https://newsapi.org/account +NEWS_API_KEY="" \ No newline at end of file diff --git a/.env.local b/.env.local new file mode 100644 index 0000000..529fc82 --- /dev/null +++ b/.env.local @@ -0,0 +1,13 @@ +TENANT_ID= +CLIENT_ID= +API_KEY= + +DB_CONNECTION= +SAS_URL= + +AZURE_OPENAI_ENDPOINT= +AZURE_OPENAI_API_KEY= + +API_JWT= + +NEWS_API_KEY= \ No newline at end of file diff --git a/.gitignore b/.gitignore index 749fb3f..febc3c4 100644 --- a/.gitignore +++ b/.gitignore @@ -140,3 +140,5 @@ cython_debug/ # Manually added for this project .idea/ **/.DS_Store +sql/create_test_user.sql +Taskfile.yml diff --git a/Dockerfile b/Dockerfile index 2dab8d0..d43c950 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,10 +1,26 @@ FROM python:3.11.7-slim + +# Install system dependencies RUN apt-get update -y \ - && apt-get install libpq-dev -y \ + && apt-get install -y \ + libpq-dev \ + postgresql-client \ + curl \ && rm -rf /var/lib/apt/lists/* + WORKDIR /app -COPY requirements.txt . -RUN pip install --no-cache-dir --upgrade -r requirements.txt + +# Install Python dependencies including development dependencies +COPY requirements.txt requirements_dev.txt ./ +RUN pip install --no-cache-dir --upgrade -r requirements.txt -r requirements_dev.txt + +# Copy application code COPY . . + EXPOSE 8000 + +# Add healthcheck +HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \ + CMD curl --fail http://localhost:8000/signals/search || exit 1 + CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/Makefile b/Makefile index 898ae03..e4bd553 100644 --- a/Makefile +++ b/Makefile @@ -5,4 +5,4 @@ format: lint: pylint main.py src/ test: - python -m pytest tests/ + python -m pytest tests/ \ No newline at end of file diff --git a/docker-compose.yaml b/docker-compose.yaml index 4adec8a..9771a63 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -5,16 +5,42 @@ services: env_file: .env environment: - DB_CONNECTION=postgresql://postgres:password@db:5432/postgres + - ENV_MODE=local ports: - "8000:8000" + volumes: + - .:/app depends_on: - - db + db: + condition: service_healthy + command: > + sh -c "sleep 5 && uvicorn main:app --host 0.0.0.0 --port 8000 --reload" db: image: postgres:16.4-alpine environment: POSTGRES_PASSWORD: password + POSTGRES_DB: postgres ports: - - "5432:5432" + - 5432:5432 volumes: - - ./sql:/docker-entrypoint-initdb.d + - postgres_data:/var/lib/postgresql/data + - ./sql/create_tables.sql:/docker-entrypoint-initdb.d/1-create_tables.sql + - ./sql/import_data.sql:/docker-entrypoint-initdb.d/2-import_data.sql + - ./sql/init_test_data.sql:/docker-entrypoint-initdb.d/3-init_test_data.sql - ./data:/docker-entrypoint-initdb.d/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U postgres"] + interval: 5s + timeout: 5s + retries: 5 + adminer: + image: adminer + restart: always + ports: + - 4040:8080 + depends_on: + db: + condition: service_healthy + +volumes: + postgres_data: diff --git a/main.py b/main.py index 2756a42..39e3ba0 100644 --- a/main.py +++ b/main.py @@ -5,11 +5,14 @@ from dotenv import load_dotenv from fastapi import Depends, FastAPI +from fastapi.middleware.cors import CORSMiddleware from src import routers from src.authentication import authenticate_user +from src.config.logging_config import setup_logging load_dotenv() +setup_logging() app = FastAPI( debug=False, @@ -42,11 +45,21 @@ {"name": "trends", "description": "CRUD operations on trends."}, {"name": "users", "description": "CRUD operations on users."}, {"name": "choices", "description": "List valid options for forms fields."}, + {"name": "favourites", "description": "Manage user's favorite signals."}, ], docs_url="/", redoc_url=None, ) +# allow cors +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + for router in routers.ALL: app.include_router(router=router, dependencies=[Depends(authenticate_user)]) diff --git a/requirements_dev.txt b/requirements_dev.txt index a0cf82c..070721d 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -4,3 +4,6 @@ isort ~= 5.13.2 pylint ~= 3.3.1 pytest ~= 8.3.3 notebook ~= 7.2.2 +pytest-asyncio==0.21.1 +pytest-cov==4.1.0 +pytest-watch==4.2.0 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..6a28fbe --- /dev/null +++ b/setup.py @@ -0,0 +1,13 @@ +from setuptools import find_packages, setup + +setup( + name="ftss-api", + version="0.1", + packages=find_packages(), + install_requires=[ + "fastapi", + "uvicorn", + "psycopg", + "pydantic", + ], +) \ No newline at end of file diff --git a/sql/create_tables.sql b/sql/create_tables.sql index e6d5b14..42f2540 100644 --- a/sql/create_tables.sql +++ b/sql/create_tables.sql @@ -24,11 +24,13 @@ CREATE TABLE users ( role VARCHAR(255) NOT NULL, name VARCHAR(255), unit VARCHAR(255), - acclab BOOLEAN + acclab BOOLEAN, + api_key VARCHAR(255) UNIQUE ); CREATE INDEX ON users (email); CREATE INDEX ON users (role); +CREATE INDEX ON users (api_key); -- signals table and indices CREATE TABLE signals ( @@ -118,6 +120,19 @@ CREATE TABLE connections ( CONSTRAINT connection_pk PRIMARY KEY (signal_id, trend_id) ); +-- favourites table to track user's favourite signals +CREATE TABLE favourites ( + user_id INT REFERENCES users(id) ON DELETE CASCADE, + signal_id INT REFERENCES signals(id) ON DELETE CASCADE, + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + CONSTRAINT favourites_pk PRIMARY KEY (user_id, signal_id) +); + +CREATE INDEX ON favourites (user_id, created_at); + +CREATE INDEX favourites_user_signal_idx ON favourites (user_id, signal_id); +CREATE INDEX favourites_created_at_idx ON favourites (created_at DESC); + -- locations table and indices CREATE TABLE locations ( id SERIAL PRIMARY KEY, @@ -134,4 +149,4 @@ CREATE TABLE units ( name TEXT NOT NULL, region VARCHAR(255) ); -CREATE INDEX ON units (name, region); +CREATE INDEX ON units (name, region); \ No newline at end of file diff --git a/sql/init_test_data.sql b/sql/init_test_data.sql new file mode 100644 index 0000000..be5726d --- /dev/null +++ b/sql/init_test_data.sql @@ -0,0 +1,29 @@ +/* +The initialisation script to create test data for local development. +This script is automatically executed by docker compose after create_tables.sql +and import_data.sql. +*/ + +-- Create test users +INSERT INTO users ( + id, + created_at, + email, + role, + name, + unit, + acclab, + api_key +) VALUES ( + 1, -- This ID is expected by the test suite + NOW(), + 'test.user@undp.org', + 'ADMIN', + 'Test User', + 'Data Futures Exchange (DFx)', + false, + 'test-key' +); + +-- Reset the sequence to start after our manually inserted IDs +SELECT setval('users_id_seq', (SELECT MAX(id) FROM users)); \ No newline at end of file diff --git a/sql/insert_test_data.sql b/sql/insert_test_data.sql new file mode 100644 index 0000000..3f5afd0 --- /dev/null +++ b/sql/insert_test_data.sql @@ -0,0 +1,43 @@ +CREATE OR REPLACE FUNCTION insert_test_signals() +RETURNS void AS $$ +DECLARE + signals jsonb; + signal_record jsonb; +BEGIN + -- Read the JSON file + signals := (pg_read_file('/docker-entrypoint-initdb.d/test_data.json')::jsonb)->'signals'; + + -- Loop through each signal and insert + FOR signal_record IN SELECT * FROM jsonb_array_elements(signals) LOOP + WITH arrays AS ( + SELECT + array(SELECT * FROM jsonb_array_elements_text(signal_record->'keywords')) as keywords, + array(SELECT * FROM jsonb_array_elements_text(signal_record->'steep_secondary')) as steep_secondary, + array(SELECT * FROM jsonb_array_elements_text(signal_record->'signature_secondary')) as signature_secondary, + array(SELECT * FROM jsonb_array_elements_text(signal_record->'sdgs')) as sdgs + ) + INSERT INTO signals ( + status, created_by, modified_by, headline, description, url, + relevance, keywords, location, steep_primary, steep_secondary, + signature_primary, signature_secondary, sdgs, created_unit + ) + SELECT + signal_record->>'status', + signal_record->>'created_by', + signal_record->>'modified_by', + signal_record->>'headline', + signal_record->>'description', + signal_record->>'url', + signal_record->>'relevance', + keywords, + signal_record->>'location', + signal_record->>'steep_primary', + steep_secondary, + signal_record->>'signature_primary', + signature_secondary, + sdgs, + signal_record->>'created_unit' + FROM arrays; + END LOOP; +END; +$$ LANGUAGE plpgsql; \ No newline at end of file diff --git a/sql/test_data.json b/sql/test_data.json new file mode 100644 index 0000000..f3f68a8 --- /dev/null +++ b/sql/test_data.json @@ -0,0 +1,104 @@ +{ + "signals": [ + { + "status": "Approved", + "created_by": "dev@undp.org", + "modified_by": "dev@undp.org", + "headline": "AI Revolution in Healthcare", + "description": "Artificial Intelligence is transforming healthcare delivery and patient outcomes globally.", + "url": "https://example.com/ai-healthcare", + "relevance": "AI adoption in healthcare is expected to grow by 40% annually through 2025", + "keywords": ["AI", "healthcare", "technology"], + "location": "Global", + "steep_primary": "Technological – Made culture, tools, devices, systems, infrastructure and networks", + "steep_secondary": [ + "Social – Issues related to human culture, demography, communication, movement and migration, work and education", + "Economic – Issues of value, money, financial tools and systems, business and business models, exchanges and transactions" + ], + "signature_primary": "Strategic Innovation", + "signature_secondary": ["Digitalisation", "Development Financing"], + "sdgs": ["GOAL 3: Good Health and Well-being"], + "created_unit": "UNDP Innovation" + }, + { + "status": "Approved", + "created_by": "dev@undp.org", + "modified_by": "dev@undp.org", + "headline": "Climate-Resilient Agriculture", + "description": "New farming techniques helping communities adapt to climate change.", + "url": "https://example.com/climate-agriculture", + "relevance": "Over 500 communities have adopted these techniques, increasing crop yields by 30%", + "keywords": ["agriculture", "climate", "resilience"], + "location": "Africa", + "steep_primary": "Environmental – The natural world, living environment, sustainability, resources, climate and health", + "steep_secondary": [ + "Economic – Issues of value, money, financial tools and systems, business and business models, exchanges and transactions", + "Social – Issues related to human culture, demography, communication, movement and migration, work and education" + ], + "signature_primary": "Resilience", + "signature_secondary": ["Environment", "Poverty and Inequality"], + "sdgs": ["GOAL 2: Zero Hunger", "GOAL 13: Climate Action"], + "created_unit": "UNDP Climate" + }, + { + "status": "Draft", + "created_by": "dev@undp.org", + "modified_by": "dev@undp.org", + "headline": "Digital Financial Inclusion", + "description": "Mobile banking revolutionizing access to financial services in developing nations.", + "url": "https://example.com/digital-finance", + "relevance": "Mobile money users increased by 12.7% globally, reaching 1.2 billion accounts", + "keywords": ["fintech", "inclusion", "mobile"], + "location": "Asia-Pacific", + "steep_primary": "Economic – Issues of value, money, financial tools and systems, business and business models, exchanges and transactions", + "steep_secondary": [ + "Technological – Made culture, tools, devices, systems, infrastructure and networks", + "Social – Issues related to human culture, demography, communication, movement and migration, work and education" + ], + "signature_primary": "Development Financing", + "signature_secondary": ["Digitalisation", "Poverty and Inequality"], + "sdgs": ["GOAL 1: No Poverty", "GOAL 10: Reduced Inequality"], + "created_unit": "UNDP Finance" + }, + { + "status": "Review", + "created_by": "dev@undp.org", + "modified_by": "dev@undp.org", + "headline": "Urban Mobility Revolution", + "description": "Smart city solutions transforming urban transportation systems.", + "url": "https://example.com/urban-mobility", + "relevance": "Electric vehicle adoption in cities grew by 50% year-over-year", + "keywords": ["mobility", "smart-city", "transportation"], + "location": "Latin America", + "steep_primary": "Technological – Made culture, tools, devices, systems, infrastructure and networks", + "steep_secondary": [ + "Environmental – The natural world, living environment, sustainability, resources, climate and health", + "Social – Issues related to human culture, demography, communication, movement and migration, work and education" + ], + "signature_primary": "Strategic Innovation", + "signature_secondary": ["Environment", "Energy"], + "sdgs": ["GOAL 11: Sustainable Cities and Communities"], + "created_unit": "UNDP Cities" + }, + { + "status": "Approved", + "created_by": "dev@undp.org", + "modified_by": "dev@undp.org", + "headline": "Renewable Energy Breakthrough", + "description": "Novel solar technology achieves record efficiency levels.", + "url": "https://example.com/solar-breakthrough", + "relevance": "New solar cells achieve 40% efficiency, marking a significant milestone", + "keywords": ["renewable", "solar", "energy"], + "location": "Global", + "steep_primary": "Technological – Made culture, tools, devices, systems, infrastructure and networks", + "steep_secondary": [ + "Environmental – The natural world, living environment, sustainability, resources, climate and health", + "Economic – Issues of value, money, financial tools and systems, business and business models, exchanges and transactions" + ], + "signature_primary": "Energy", + "signature_secondary": ["Environment", "Strategic Innovation"], + "sdgs": ["GOAL 7: Affordable and Clean Energy"], + "created_unit": "UNDP Energy" + } + ] +} diff --git a/src/authentication.py b/src/authentication.py index 0f748f3..da64752 100644 --- a/src/authentication.py +++ b/src/authentication.py @@ -2,6 +2,7 @@ Dependencies for API authentication using JWT tokens from Microsoft Entra. """ +import logging import os import httpx @@ -129,10 +130,30 @@ async def authenticate_user( user : User Pydantic model for a User object (if authentication succeeded). """ + logging.debug(f"Authenticating user with token") + if os.environ.get("TEST_USER_TOKEN"): + token = os.environ.get("TEST_USER_TOKEN") + if os.environ.get("ENV_MODE") == "local": + # defaul user data + user_data = { + "email": "test.user@undp.org", + "name": "Test User", + "unit": "Data Futures Exchange (DFx)", + "acclab": False, + } + if token == "test-admin-token": + user_data["role"] = Role.ADMIN + return User(**user_data) + elif token == "test-user-token": + user_data["role"] = Role.USER + return User(**user_data) + if token == os.environ.get("API_KEY"): - # dummy user object for anonymous access - user = User(email="name.surname@undp.org", role=Role.VISITOR) - return user + if os.environ.get("ENV") == "dev": + return User(email="name.surname@undp.org", role=Role.ADMIN) + else: + # dummy user object for anonymous access + return User(email="name.surname@undp.org", role=Role.VISITOR) try: payload = await decode_token(token) except jwt.exceptions.PyJWTError as e: diff --git a/src/config/logging_config.py b/src/config/logging_config.py new file mode 100644 index 0000000..ead0300 --- /dev/null +++ b/src/config/logging_config.py @@ -0,0 +1,29 @@ +""" +Logging configuration for the application. +""" + +import logging +import os +import sys +from typing import Optional + + +def setup_logging(level: Optional[str] = None) -> None: + """ + Setup logging configuration for the application. + + Args: + level: The logging level to use. If None, defaults to INFO. + """ + log_level = getattr(logging, (level or os.getenv("LOGGING_LEVEL") or "INFO").upper()) + + # Configure the root logger + logging.basicConfig( + level=log_level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + # Set third-party loggers to WARNING to reduce noise + logging.getLogger("uvicorn.access").setLevel(logging.WARNING) + logging.getLogger("asyncio").setLevel(logging.WARNING) diff --git a/src/database/__init__.py b/src/database/__init__.py index 6a2c1e0..5095709 100644 --- a/src/database/__init__.py +++ b/src/database/__init__.py @@ -4,6 +4,7 @@ from .choices import * from .connection import yield_cursor +from .favourites import * from .signals import * from .trends import * from .users import * diff --git a/src/database/favourites.py b/src/database/favourites.py new file mode 100644 index 0000000..15b1a67 --- /dev/null +++ b/src/database/favourites.py @@ -0,0 +1,105 @@ +""" +Database operations for user favorites. +""" + +import logging +from datetime import datetime +from typing import cast + +from fastapi import HTTPException +from psycopg import AsyncCursor +from psycopg.rows import DictRow + +from ..entities import Signal + +logger = logging.getLogger(__name__) + + +async def create_favourite( + cursor: AsyncCursor[DictRow], user_email: str, signal_id: int +) -> dict: + logger.debug("Creating/removing favourite for signal_id: %s", signal_id) + + # First check if the signal exists + query = """ + SELECT s.*, COALESCE(array_agg(c.trend_id) FILTER (WHERE c.trend_id IS NOT NULL), ARRAY[]::integer[]) as connected_trends + FROM signals s + LEFT JOIN connections c ON s.id = c.signal_id + WHERE s.id = %s + GROUP BY s.id; + """ + + await cursor.execute(query, (signal_id,)) + signal_row = cast(DictRow | None, await cursor.fetchone()) + logger.debug("Found signal: %s", signal_row) + + if signal_row is None: + logger.warning("Signal not found with id: %s", signal_id) + raise HTTPException(status_code=404, detail="Signal not found") + + # Get user_id from email + query = """ + SELECT id FROM users WHERE email = %s; + """ + await cursor.execute(query, (user_email,)) + user_row = cast(DictRow | None, await cursor.fetchone()) + + if user_row is None: + raise HTTPException( + status_code=404, detail="User not found with email " + user_email + ) + user_id = user_row["id"] + + # Check if the favorite already exists + query = """ + SELECT 1 FROM favourites WHERE user_id = %s AND signal_id = %s; + """ + await cursor.execute(query, (user_id, signal_id)) + exists = await cursor.fetchone() + + if exists: + logger.debug("Deleting favourite for signal_id: %s", signal_id) + # Remove the favorite + query = """ + DELETE FROM favourites WHERE user_id = %s AND signal_id = %s; + """ + try: + await cursor.execute(query, (user_id, signal_id)) + logger.debug("Deleted favourite for signal_id: %s", signal_id) + return {"status": "deleted"} + except Exception as e: + logger.error("Error deleting favourite for signal_id: %s", signal_id, exc_info=True) + raise e + else: + logger.debug("Adding favourite for signal_id: %s", signal_id) + # Add to favorites + query = """ + INSERT INTO favourites (user_id, signal_id, created_at) + VALUES (%s, %s, %s) + ON CONFLICT (user_id, signal_id) DO NOTHING; + """ + try: + await cursor.execute(query, (user_id, signal_id, datetime.utcnow())) + logger.debug("Added favourite for signal_id: %s", signal_id) + return {"status": "created"} + except Exception as e: + logger.error("Error adding favourite for signal_id: %s", signal_id, exc_info=True) + raise e + + +async def read_user_favourites(cursor: AsyncCursor[DictRow], user_email: str) -> list[Signal]: + logger.debug("Reading user favourites for user_email: %s", user_email) + query = """ + SELECT s.*, COALESCE(array_agg(c.trend_id) FILTER (WHERE c.trend_id IS NOT NULL), ARRAY[]::integer[]) as connected_trends + FROM signals s + LEFT JOIN connections c ON s.id = c.signal_id + JOIN favourites f ON s.id = f.signal_id + JOIN users u ON f.user_id = u.id + WHERE u.email = %s + GROUP BY s.id, f.created_at + ORDER BY f.created_at DESC; + """ + await cursor.execute(query, (user_email,)) + rows = await cursor.fetchall() + logger.debug("Fetched %s rows", len(rows)) + return [Signal.model_validate(cast(DictRow, row)) for row in rows] \ No newline at end of file diff --git a/src/database/signals.py b/src/database/signals.py index 4e1c591..84b8cb0 100644 --- a/src/database/signals.py +++ b/src/database/signals.py @@ -14,6 +14,7 @@ "update_signal", "delete_signal", "read_user_signals", + "is_signal_favorited", ] @@ -358,3 +359,31 @@ async def read_user_signals( """ await cursor.execute(query, (user_email, status)) return [Signal(**row) async for row in cursor] + + +async def is_signal_favorited(cursor: AsyncCursor, user_email: str, signal_id: int) -> bool: + """ + Check if a signal is favorited by a user. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_email : str + The email of the user to check. + signal_id : int + The ID of the signal to check. + + Returns + ------- + bool + True if the signal is favorited by the user, False otherwise. + """ + query = """ + SELECT 1 + FROM favourites f + JOIN users u ON f.user_id = u.id + WHERE u.email = %s AND f.signal_id = %s; + """ + await cursor.execute(query, (user_email, signal_id)) + return await cursor.fetchone() is not None diff --git a/src/dependencies.py b/src/dependencies.py index df9f7d5..b6b9ccd 100644 --- a/src/dependencies.py +++ b/src/dependencies.py @@ -2,6 +2,7 @@ Functions used for dependency injection for role-based access control. """ +import logging from typing import Annotated from fastapi import Depends, Path @@ -12,6 +13,8 @@ from .authentication import authenticate_user from .entities import User +logger = logging.getLogger(__name__) + __all__ = [ "require_admin", "require_curator", @@ -22,8 +25,11 @@ async def require_admin(user: User = Depends(authenticate_user)) -> User: """Require that the user is assigned an admin role.""" + logger.info(f"Checking admin permissions for user {user.email} with role {user.role}") if not user.is_admin: + logger.warning(f"Permission denied: User {user.email} with role {user.role} attempted admin action") raise exceptions.permission_denied + logger.info(f"Admin permission granted for user {user.email}") return user diff --git a/src/entities/signal.py b/src/entities/signal.py index 5112e46..e68b541 100644 --- a/src/entities/signal.py +++ b/src/entities/signal.py @@ -29,6 +29,10 @@ class Signal(BaseEntity): default=None, description="IDs of trends connected to this signal.", ) + favorite: bool = Field( + default=False, + description="Whether the current user has favorited this signal.", + ) model_config = ConfigDict( json_schema_extra={ @@ -38,6 +42,7 @@ class Signal(BaseEntity): "relevance": "Of the approximately US$13 trillion that governments spend on public spending, up to 25 percent is lost to corruption.", "keywords": ["economy", "governance"], "location": "Global", + "favorite": False, } } ) diff --git a/src/routers/__init__.py b/src/routers/__init__.py index cf23153..ca600c9 100644 --- a/src/routers/__init__.py +++ b/src/routers/__init__.py @@ -3,12 +3,14 @@ """ from .choices import router as choice_router +from .favourites import router as favourites_router from .signals import router as signal_router from .trends import router as trend_router from .users import router as user_router ALL = [ choice_router, + favourites_router, signal_router, trend_router, user_router, diff --git a/src/routers/favourites.py b/src/routers/favourites.py new file mode 100644 index 0000000..9ce81ad --- /dev/null +++ b/src/routers/favourites.py @@ -0,0 +1,73 @@ +""" +A router for managing user's favorite signals. +""" + +import logging +from typing import Literal + +from fastapi import APIRouter, Depends, HTTPException +from psycopg import AsyncCursor +from psycopg.rows import DictRow +from pydantic import BaseModel + +from .. import database as db +from ..dependencies import require_user +from ..entities import Signal, User + +logger = logging.getLogger(__name__) + +# Create router instance +router = APIRouter( + prefix="/favourites", + tags=["favourites"], +) + + +class FavoriteResponse(BaseModel): + status: Literal["created", "deleted"] + + +# Define dependency functions to avoid Trunk linter warnings +def get_cursor(): + return Depends(db.yield_cursor) + + +def get_user(): + return Depends(require_user) + + +@router.post("/{signal_id}", response_model=FavoriteResponse) +async def create_or_remove_favourite( + signal_id: int, + user: User = Depends(require_user), + cursor: AsyncCursor[DictRow] = Depends(db.yield_cursor), +) -> dict: + """ + Add or remove a signal from user's favorites depending on current status. + """ + try: + signal = await db.read_signal(cursor, signal_id) + logger.debug("Found signal for favourite operation: %s", signal) + + if signal: + return await db.create_favourite(cursor, user.email, signal_id) + + logger.warning("Signal not found with id: %s", signal_id) + raise HTTPException(status_code=404, detail="Signal not found") + except Exception as e: + logger.error("Error in create_or_remove_favourite: %s", str(e), exc_info=True) + raise e + + +@router.get("/", response_model=list[Signal]) +async def fetch_user_favourites( + user: User = Depends(require_user), + cursor: AsyncCursor[DictRow] = Depends(db.yield_cursor), +) -> list[Signal]: + """ + Get all signals that the current user has favorited, in chronological order + of when they were favorited. + """ + signals = await db.read_user_favourites(cursor, user.email) + + return signals diff --git a/src/routers/signals.py b/src/routers/signals.py index bfb5057..b078082 100644 --- a/src/routers/signals.py +++ b/src/routers/signals.py @@ -2,6 +2,7 @@ A router for retrieving, submitting and updating signals. """ +import logging from typing import Annotated import pandas as pd @@ -14,6 +15,8 @@ from ..dependencies import require_creator, require_curator, require_user from ..entities import Role, Signal, SignalFilters, SignalPage, Status, User +logger = logging.getLogger(__name__) + router = APIRouter(prefix="/signals", tags=["signals"]) @@ -116,10 +119,29 @@ async def read_signal( Retrieve a signal form the database using an ID. Trends connected to the signal can be retrieved using IDs from the `signal.connected_trends` field. """ + logger.info("Reading signal with ID: %s for user: %s", uid, user.email) + if (signal := await db.read_signal(cursor, uid)) is None: + logger.warning("Signal not found with ID: %s", uid) raise exceptions.not_found + + logger.info("Retrieved signal: %s", signal.model_dump()) + if user.role == Role.VISITOR and signal.status != Status.APPROVED: + logger.warning( + "Permission denied - visitor trying to access non-approved signal. Status: %s", + signal.status + ) raise exceptions.permission_denied + + # Check if the signal is favorited by the user + logger.info("Checking favorite status for signal %s and user %s", uid, user.email) + is_favorite = await db.is_signal_favorited(cursor, user.email, uid) + logger.info("Favorite status result: %s", is_favorite) + + signal.favorite = is_favorite + logger.info("Final signal with favorite status: %s", signal.model_dump()) + return signal diff --git a/src/routers/users.py b/src/routers/users.py index e244458..3215da5 100644 --- a/src/routers/users.py +++ b/src/routers/users.py @@ -2,6 +2,7 @@ A router for creating, reading and updating trends. """ +import logging from typing import Annotated from fastapi import APIRouter, Depends, Path, Query @@ -29,6 +30,7 @@ async def search_users( @router.get("/me", response_model=User) async def read_current_user(user: User = Depends(authenticate_user)): """Read the current user information from a JTW token.""" + logging.debug(f"User: {user}") if user is None: raise exceptions.not_found return user diff --git a/tests/conftest.py b/tests/conftest.py index 47d7588..499dc2a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,20 +3,108 @@ """ import os +from datetime import datetime +from pathlib import Path +from unittest.mock import AsyncMock, patch +import psycopg import pytest from dotenv import load_dotenv +from src.entities import Role, User + +# Load environment variables from .env load_dotenv() +# Set default test values if not in .env +if "API_KEY" not in os.environ: + os.environ["API_KEY"] = "test-key" +if "DB_CONNECTION" not in os.environ: + os.environ["DB_CONNECTION"] = "postgresql://postgres:password@localhost:5432/postgres" +if "SAS_URL" not in os.environ: + os.environ["SAS_URL"] = "https://test.blob.core.windows.net/test" + +# Set development mode for testing +os.environ["ENV_MODE"] = "local" + -@pytest.fixture(scope="session", params=[os.environ["API_KEY"], os.environ["API_JWT"]]) +@pytest.fixture(scope="session", params=[os.environ["API_KEY"]]) def headers(request) -> dict[str, str]: - """Header for authentication with an API key or a JWT for a regular user (not curator or admin).""" + """Header for authentication with an API key.""" return {"access_token": request.param} -@pytest.fixture(scope="session") -def headers_with_jwt(request) -> dict[str, str]: - """Header for authentication with a JWT for a regular user (not curator or admin).""" - return {"access_token": os.environ["API_JWT"]} +@pytest.fixture +def headers_with_jwt(): + """Return headers with a mock JWT token.""" + return {"access_token": "test-admin-token"} + + +@pytest.fixture +def mock_auth(): + """Mock user authentication to always succeed with admin privileges.""" + mock_user = User( + id=1, + created_at="2025-01-12T10:33:39.727968", + email="test.user@undp.org", + name="Test User", + role=Role.ADMIN, + ) + + async def mock_auth_func(): + return mock_user + + with patch("src.authentication.authenticate_user", mock_auth_func), \ + patch("src.dependencies.authenticate_user", mock_auth_func), \ + patch("src.dependencies.require_admin", mock_auth_func): + yield mock_user + + +@pytest.fixture +def mock_user_auth(): + """Mock user authentication for a regular user.""" + def create_mock_user(email: str = "test.user@undp.org", role: Role = Role.USER): + user = User( + id=2, + created_at=datetime.now().isoformat(), + email=email, + name="Test User", + role=role, + ) + + async def mock_auth_func(): + return user + + return user, mock_auth_func + + return create_mock_user + + +@pytest.fixture(autouse=True) +def mock_storage(): + """Mock Azure Storage operations.""" + with patch("src.storage.get_container_client"), \ + patch("src.storage.update_image") as mock_update_image, \ + patch("src.storage.delete_image") as mock_delete_image: + mock_update_image.return_value = None + mock_delete_image.return_value = None + yield + + +@pytest.fixture(autouse=True) +def mock_token_validation(): + """Mock JWT token validation to always succeed.""" + async def mock_decode_token(token: str) -> dict: + if token == "test-admin-token": + return {"unique_name": "test.user@undp.org", "name": "Test User"} + elif token == "test-user-token": + return {"unique_name": "test.regular@undp.org", "name": "Regular User"} + elif token.startswith("mock-jwt-token-"): + parts = token.split("-") + if len(parts) >= 4: + email = parts[2] + return {"unique_name": email, "name": "Test User"} + return {"unique_name": "test.visitor@undp.org", "name": "Test User"} + + with patch("src.authentication.decode_token", mock_decode_token): + yield diff --git a/tests/test_favourites.py b/tests/test_favourites.py new file mode 100644 index 0000000..5c39c40 --- /dev/null +++ b/tests/test_favourites.py @@ -0,0 +1,378 @@ +""" +Tests for user favorites functionality. +""" +from typing import Any, Dict, Generator + +import pytest +from fastapi.testclient import TestClient +from pytest import mark + +from main import app +from src.entities import Goal, Signal, Trend +from src.entities.utils import Horizon, Rating, Signature, Status, Steep + +client = TestClient(app) + +@pytest.fixture +def test_signal() -> Signal: + """Fixture to create a test signal.""" + signal_data = { + "headline": "The cost of corruption", + "description": "Corruption is one of the scourges of modern life. Its costs are staggering.", + "status": Status.NEW.value, + "url": "https://undp.medium.com/the-cost-of-corruption-a827306696fb", + "relevance": "Of the approximately US$13 trillion that governments spend on public spending, up to 25 percent is lost to corruption.", + "keywords": ["economy", "governance"], + "location": "Global", + "steep_primary": Steep.ECONOMIC.value, + "steep_secondary": [Steep.SOCIAL.value], + "signature_primary": Signature.GOVERNANCE.value, + "signature_secondary": [Signature.POVERTY.value, Signature.RESILIENCE.value], + "sdgs": [Goal.G16.value, Goal.G17.value] + } + return Signal.model_validate(signal_data) + +@pytest.fixture +def test_trends(headers_with_jwt: dict) -> Generator[list[Dict[str, Any]], None, None]: + """Fixture that creates test trends.""" + trends = [] + base_trend_data = { + "headline": "Test Trend", + "description": "Test Description", + "status": Status.NEW.value, + "steep_primary": Steep.ECONOMIC.value, + "steep_secondary": [Steep.SOCIAL.value], + "signature_primary": Signature.GOVERNANCE.value, + "signature_secondary": [Signature.POVERTY.value], + "sdgs": [Goal.G16.value], + "time_horizon": Horizon.SHORT.value, + "impact_rating": Rating.HIGH.value, + "impact_description": "Test impact" + } + + for i in range(2): + trend_data = {**base_trend_data, "headline": f"Test Trend {i}"} + trend = Trend.model_validate(trend_data) + response = client.post("/trends", json=trend.model_dump(), headers=headers_with_jwt) + if response.status_code == 201: + trends.append(response.json()) + else: + pytest.skip("User does not have permission to create trends") + + if not trends: + pytest.skip("No trends were created") + + yield trends + + # Cleanup - don't assert status code since trends might be deleted + for trend_data in trends: + endpoint = f"/trends/{trend_data['id']}" + client.delete(endpoint, headers=headers_with_jwt) + +@pytest.fixture +def created_signal(test_signal: Signal, headers_with_jwt: dict) -> Generator[Dict[str, Any], None, None]: + """Fixture that creates and cleans up a test signal.""" + # Create the signal + response = client.post("/signals", json=test_signal.model_dump(), headers=headers_with_jwt) + assert response.status_code == 201 + signal_data = response.json() + + yield signal_data + + # Cleanup - don't assert status code since signal might already be deleted + endpoint = f"/signals/{signal_data['id']}" + client.delete(endpoint, headers=headers_with_jwt) + +@pytest.fixture +def created_signals(headers_with_jwt: dict) -> Generator[list[Dict[str, Any]], None, None]: + """Fixture that creates multiple test signals with different statuses.""" + signals = [] + statuses = [Status.DRAFT.value, Status.NEW.value, Status.APPROVED.value] + base_signal_data = { + "headline": "Test Signal", + "description": "Test Description", + "url": "https://undp.medium.com/test", + "relevance": "Test relevance", + "keywords": ["test"], + "location": "Global", + "steep_primary": Steep.ECONOMIC.value, + "steep_secondary": [Steep.SOCIAL.value], + "signature_primary": Signature.GOVERNANCE.value, + "signature_secondary": [Signature.POVERTY.value], + "sdgs": [Goal.G16.value] + } + + for i, status in enumerate(statuses): + signal_data = {**base_signal_data, "status": status, "headline": f"Test Signal {i}"} + signal = Signal.model_validate(signal_data) + response = client.post("/signals", json=signal.model_dump(), headers=headers_with_jwt) + assert response.status_code == 201 + signals.append(response.json()) + + yield signals + + # Cleanup + for signal_data in signals: + endpoint = f"/signals/{signal_data['id']}" + response = client.delete(endpoint, headers=headers_with_jwt) + assert response.status_code == 200 + +def test_favourite_crud(headers_with_jwt: dict, created_signal: Dict[str, Any]): + """Test basic create, read, and delete operations for favorites.""" + signal_id = created_signal['id'] + + # Add to favorites + endpoint = f"/favourites/{signal_id}" + response = client.post(endpoint, headers=headers_with_jwt) + assert response.status_code == 200 + response_data = response.json() + assert response_data["status"] == "created" + + # Verify it appears in favorites list + response = client.get("/favourites", headers=headers_with_jwt) + assert response.status_code == 200 + favorites = response.json() + assert len(favorites) == 1 + favorite_signal = Signal.model_validate(favorites[0]) + assert favorite_signal.id == signal_id + + # Remove from favorites + response = client.post(endpoint, headers=headers_with_jwt) + assert response.status_code == 200 + response_data = response.json() + assert response_data["status"] == "deleted" + + # Verify it's removed from favorites list + response = client.get("/favourites", headers=headers_with_jwt) + assert response.status_code == 200 + assert len(response.json()) == 0 + +def test_favourite_idempotency(headers_with_jwt: dict, created_signal: Dict[str, Any]): + """Test that favoriting/unfavoriting operations are idempotent.""" + signal_id = created_signal['id'] + endpoint = f"/favourites/{signal_id}" + + # First favorite operation should create + response = client.post(endpoint, headers=headers_with_jwt) + assert response.status_code == 200 + response_data = response.json() + assert response_data["status"] == "created" + + # Second favorite operation should delete + response = client.post(endpoint, headers=headers_with_jwt) + assert response.status_code == 200 + response_data = response.json() + assert response_data["status"] == "deleted" + + # Third favorite operation should create again + response = client.post(endpoint, headers=headers_with_jwt) + assert response.status_code == 200 + response_data = response.json() + assert response_data["status"] == "created" + + # Verify only one favorite exists + response = client.get("/favourites", headers=headers_with_jwt) + assert len(response.json()) == 1 + +@mark.parametrize("invalid_id", [-1, 0, 99999]) +def test_favourite_invalid_signals(headers_with_jwt: dict, invalid_id: int): + """Test favoriting non-existent or invalid signals.""" + endpoint = f"/favourites/{invalid_id}" + response = client.post(endpoint, headers=headers_with_jwt) + assert response.status_code == 404 + response_data = response.json() + assert "not found" in response_data["detail"].lower() + +def test_favourites_ordering(headers_with_jwt: dict): + """Test that favorites are returned in chronological order (most recent first).""" + # Create multiple test signals + signals = [] + base_signal_data = { + "headline": "Test Signal", + "description": "Test Description", + "url": "https://undp.medium.com/test", + "relevance": "Test relevance", + "keywords": ["test"], + "location": "Global", + "status": Status.NEW.value, + "steep_primary": Steep.ECONOMIC.value, + "steep_secondary": [Steep.SOCIAL.value], + "signature_primary": Signature.GOVERNANCE.value, + "signature_secondary": [Signature.POVERTY.value], + "sdgs": [Goal.G16.value] + } + + for i in range(3): + signal_data = {**base_signal_data, "headline": f"Test Signal {i}"} + signal = Signal.model_validate(signal_data) + response = client.post("/signals", json=signal.model_dump(), headers=headers_with_jwt) + assert response.status_code == 201 + signals.append(response.json()) + + # Favorite them in a specific order + for signal_data in signals: + endpoint = f"/favourites/{signal_data['id']}" + response = client.post(endpoint, headers=headers_with_jwt) + assert response.status_code == 200 + + # Verify they're returned in reverse chronological order + response = client.get("/favourites", headers=headers_with_jwt) + assert response.status_code == 200 + favorites = response.json() + assert len(favorites) == 3 + + # Extract headlines from favorites + favorite_signals = [Signal.model_validate(favorite) for favorite in favorites] + favorite_headlines = [signal.headline for signal in favorite_signals] + + assert favorite_headlines == [ + "Test Signal 2", + "Test Signal 1", + "Test Signal 0" + ] + + # Cleanup + for signal_data in signals: + endpoint = f"/signals/{signal_data['id']}" + response = client.delete(endpoint, headers=headers_with_jwt) + assert response.status_code == 200 + +def test_favourites_unauthorized(): + """Test accessing favorites endpoints without authentication.""" + # Try to get favorites without auth + endpoint = "/favourites" + response = client.get(endpoint) + assert response.status_code in {401, 403} # Accept either status code + assert "not" in response.json()["detail"].lower() + + # Try to create favorite without auth + endpoint = "/favourites/1" + response = client.post(endpoint) + assert response.status_code in {401, 403} # Accept either status code + assert "not" in response.json()["detail"].lower() + +def test_favourite_deleted_signal(headers_with_jwt: dict, created_signal: Dict[str, Any]): + """Test behavior when a favorited signal is deleted.""" + signal_id = created_signal['id'] + + # Favorite the signal + endpoint = f"/favourites/{signal_id}" + response = client.post(endpoint, headers=headers_with_jwt) + assert response.status_code == 200 + + # Delete the signal + endpoint = f"/signals/{signal_id}" + response = client.delete(endpoint, headers=headers_with_jwt) + assert response.status_code == 200 + + # Verify the favorite is no longer returned + response = client.get("/favourites", headers=headers_with_jwt) + assert response.status_code == 200 + favorites = response.json() + assert len(favorites) == 0 + +def test_favourite_signal_with_trends(headers_with_jwt: dict, created_signal: Dict[str, Any], test_trends: list[Dict[str, Any]]): + """Test favoriting signals that have connected trends.""" + signal_id = created_signal['id'] + + # Update signal with connected trends + signal_data = created_signal.copy() + signal_data["connected_trends"] = [trend["id"] for trend in test_trends] + endpoint = f"/signals/{signal_id}" + response = client.put(endpoint, json=signal_data, headers=headers_with_jwt) + assert response.status_code == 200 + + # Favorite the signal + endpoint = f"/favourites/{signal_id}" + response = client.post(endpoint, headers=headers_with_jwt) + assert response.status_code == 200 + + # Verify the favorite includes connected trends + response = client.get("/favourites", headers=headers_with_jwt) + assert response.status_code == 200 + favorites = response.json() + assert len(favorites) == 1 + favorite_signal = Signal.model_validate(favorites[0]) + # Compare sets of trend IDs since order doesn't matter + assert favorite_signal.connected_trends is not None + assert set(favorite_signal.connected_trends) == set(trend["id"] for trend in test_trends) + +def test_favourite_signals_with_different_statuses(headers_with_jwt: dict, created_signals: list[Dict[str, Any]]): + """Test favoriting signals with different statuses.""" + # Favorite all signals + for signal_data in created_signals: + endpoint = f"/favourites/{signal_data['id']}" + response = client.post(endpoint, headers=headers_with_jwt) + assert response.status_code == 200 + + # Verify all signals are in favorites regardless of status + response = client.get("/favourites", headers=headers_with_jwt) + assert response.status_code == 200 + favorites = response.json() + assert len(favorites) == 3 + + # Verify all statuses are present + favorite_signals = [Signal.model_validate(favorite) for favorite in favorites] + statuses = {signal.status for signal in favorite_signals} + + expected_statuses = {Status.DRAFT, Status.NEW, Status.APPROVED} + assert statuses == expected_statuses + +def test_favourite_signal_updates(headers_with_jwt: dict, created_signal: Dict[str, Any]): + """Test that favorites reflect signal updates.""" + signal_id = created_signal['id'] + + # Favorite the signal + endpoint = f"/favourites/{signal_id}" + response = client.post(endpoint, headers=headers_with_jwt) + assert response.status_code == 200 + + # Update the signal + signal_data = created_signal.copy() + signal_data["headline"] = "Updated Headline" + signal_data["description"] = "Updated Description" + + endpoint = f"/signals/{signal_id}" + response = client.put(endpoint, json=signal_data, headers=headers_with_jwt) + assert response.status_code == 200 + + # Verify the favorite reflects the updates + response = client.get("/favourites", headers=headers_with_jwt) + assert response.status_code == 200 + favorites = response.json() + assert len(favorites) == 1 + favorite_signal = Signal.model_validate(favorites[0]) + assert favorite_signal.headline == "Updated Headline" + assert favorite_signal.description == "Updated Description" + +def test_favourite_status_in_signal_response(headers_with_jwt: dict, created_signal: Dict[str, Any]): + """Test that signal responses include correct favorite status.""" + signal_id = created_signal['id'] + + # Initially signal should not be favorited + response = client.get(f"/signals/{signal_id}", headers=headers_with_jwt) + assert response.status_code == 200 + signal_data = response.json() + assert signal_data["favorite"] is False + + # Add to favorites + response = client.post(f"/favourites/{signal_id}", headers=headers_with_jwt) + assert response.status_code == 200 + assert response.json()["status"] == "created" + + # Signal should now show as favorited + response = client.get(f"/signals/{signal_id}", headers=headers_with_jwt) + assert response.status_code == 200 + signal_data = response.json() + assert signal_data["favorite"] is True + + # Remove from favorites + response = client.post(f"/favourites/{signal_id}", headers=headers_with_jwt) + assert response.status_code == 200 + assert response.json()["status"] == "deleted" + + # Signal should no longer show as favorited + response = client.get(f"/signals/{signal_id}", headers=headers_with_jwt) + assert response.status_code == 200 + signal_data = response.json() + assert signal_data["favorite"] is False From 29adada0e859f87683e2fe209fcdd386d27176e0 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Thu, 17 Apr 2025 12:19:56 +0300 Subject: [PATCH 02/31] env updates --- .env.example | 6 +----- .gitignore | 4 +++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/.env.example b/.env.example index ebc6759..2f68743 100644 --- a/.env.example +++ b/.env.example @@ -12,8 +12,4 @@ AZURE_OPENAI_ENDPOINT="https://.openai.azure.com/" AZURE_OPENAI_API_KEY="" # Testing, only required to run tests, must be a valid token of a regular user -API_JWT="" - -# News API -# https://newsapi.org/account -NEWS_API_KEY="" \ No newline at end of file +API_JWT="" \ No newline at end of file diff --git a/.gitignore b/.gitignore index febc3c4..41d398a 100644 --- a/.gitignore +++ b/.gitignore @@ -140,5 +140,7 @@ cython_debug/ # Manually added for this project .idea/ **/.DS_Store -sql/create_test_user.sql +create_test_user.sql +.env.production +/.prs Taskfile.yml From 42f6086843893614950ce36d127c49a87be7223b Mon Sep 17 00:00:00 2001 From: Andrew Maguire Date: Tue, 22 Apr 2025 01:13:46 +0300 Subject: [PATCH 03/31] .env updates (#8) * env updates * update example env --- .env.example | 27 ++++++++++++++++++++++++--- .gitignore | 5 ++++- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/.env.example b/.env.example index ebc6759..db22666 100644 --- a/.env.example +++ b/.env.example @@ -13,7 +13,28 @@ AZURE_OPENAI_API_KEY="" # Testing, only required to run tests, must be a valid token of a regular user API_JWT="" +# Email Configuration +MS_FROM_EMAIL=futureofdevelopment@undp.org +EMAIL_SERVICE_TYPE=ms_graph -# News API -# https://newsapi.org/account -NEWS_API_KEY="" \ No newline at end of file +# SendGrid Configuration (if using SendGrid email service) +SENDGRID_API_KEY= +SENDGRID_FROM_EMAIL= + +# Azure Authentication +TENANT_ID= +CLIENT_ID= + +# API Authentication +API_KEY= +API_JWT= + +# Database Connection +DB_CONNECTION= + +# Azure Storage +SAS_URL= + +# Azure OpenAI Configuration +AZURE_OPENAI_ENDPOINT= +AZURE_OPENAI_API_KEY= diff --git a/.gitignore b/.gitignore index febc3c4..d42f3dc 100644 --- a/.gitignore +++ b/.gitignore @@ -140,5 +140,8 @@ cython_debug/ # Manually added for this project .idea/ **/.DS_Store -sql/create_test_user.sql +create_test_user.sql +.env.production +/.prs Taskfile.yml +.env.local From 4b2ca90c718f39fe262770fe93793e9b1125ce41 Mon Sep 17 00:00:00 2001 From: andrew-maguire Date: Mon, 28 Apr 2025 15:47:46 +0300 Subject: [PATCH 04/31] initial routes --- docker-compose.yaml | 1 + sql/add_secondary_location.sql | 24 ++++++++++++++++++++++++ src/database/signals.py | 3 +++ src/entities/signal.py | 5 +++++ 4 files changed, 33 insertions(+) create mode 100644 sql/add_secondary_location.sql diff --git a/docker-compose.yaml b/docker-compose.yaml index 9771a63..d4c9307 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -27,6 +27,7 @@ services: - ./sql/create_tables.sql:/docker-entrypoint-initdb.d/1-create_tables.sql - ./sql/import_data.sql:/docker-entrypoint-initdb.d/2-import_data.sql - ./sql/init_test_data.sql:/docker-entrypoint-initdb.d/3-init_test_data.sql + - ./sql/add_secondary_location.sql:/app/sql/add_secondary_location.sql - ./data:/docker-entrypoint-initdb.d/data healthcheck: test: ["CMD-SHELL", "pg_isready -U postgres"] diff --git a/sql/add_secondary_location.sql b/sql/add_secondary_location.sql new file mode 100644 index 0000000..e86d8ef --- /dev/null +++ b/sql/add_secondary_location.sql @@ -0,0 +1,24 @@ +/* +Migration script to add secondary_location column to signals table. +Run this script to update the database schema. +*/ + +-- Add secondary_location column to signals table +ALTER TABLE signals ADD COLUMN IF NOT EXISTS secondary_location TEXT[]; + +-- Update the index to include the new column +DROP INDEX IF EXISTS signals_idx; +CREATE INDEX ON signals ( + status, + created_by, + created_for, + created_unit, + steep_primary, + steep_secondary, + signature_primary, + signature_secondary, + sdgs, + location, + secondary_location, + score +); \ No newline at end of file diff --git a/src/database/signals.py b/src/database/signals.py index 84b8cb0..e649c7d 100644 --- a/src/database/signals.py +++ b/src/database/signals.py @@ -137,6 +137,7 @@ async def create_signal(cursor: AsyncCursor, signal: Signal) -> int: relevance, keywords, location, + secondary_location, score ) VALUES ( @@ -156,6 +157,7 @@ async def create_signal(cursor: AsyncCursor, signal: Signal) -> int: %(relevance)s, %(keywords)s, %(location)s, + %(secondary_location)s, %(score)s ) RETURNING @@ -263,6 +265,7 @@ async def update_signal(cursor: AsyncCursor, signal: Signal) -> int | None: relevance = COALESCE(%(relevance)s, relevance), keywords = COALESCE(%(keywords)s, keywords), location = COALESCE(%(location)s, location), + secondary_location = COALESCE(%(secondary_location)s, secondary_location), score = COALESCE(%(score)s, score) WHERE id = %(id)s diff --git a/src/entities/signal.py b/src/entities/signal.py index e68b541..2f679da 100644 --- a/src/entities/signal.py +++ b/src/entities/signal.py @@ -24,6 +24,10 @@ class Signal(BaseEntity): default=None, description="Region and/or country for which this signal has greatest relevance.", ) + secondary_location: list[str] | None = Field( + default=None, + description="Additional regions and/or countries for which this signal has relevance.", + ) score: utils.Score | None = Field(default=None) connected_trends: list[int] | None = Field( default=None, @@ -42,6 +46,7 @@ class Signal(BaseEntity): "relevance": "Of the approximately US$13 trillion that governments spend on public spending, up to 25 percent is lost to corruption.", "keywords": ["economy", "governance"], "location": "Global", + "secondary_location": ["Africa", "Asia"], "favorite": False, } } From 501a20f7fc8cea7b3385441f8b3c540774aaa436 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Mon, 28 Apr 2025 19:57:57 +0300 Subject: [PATCH 05/31] update methods --- src/database/signals.py | 26 ++++++++++++++------------ src/entities/signal.py | 9 ++++++--- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/database/signals.py b/src/database/signals.py index e649c7d..30cf780 100644 --- a/src/database/signals.py +++ b/src/database/signals.py @@ -34,7 +34,7 @@ async def search_signals(cursor: AsyncCursor, filters: SignalFilters) -> SignalP page : SignalPage Paginated search results for signals. """ - query = """ + query = f""" SELECT *, COUNT(*) OVER() AS total_count FROM @@ -84,17 +84,13 @@ async def search_signals(cursor: AsyncCursor, filters: SignalFilters) -> SignalP AND (%(unit)s IS NULL OR unit_region = %(unit)s OR unit_name = %(unit)s) AND (%(query)s IS NULL OR text_search_field @@ websearch_to_tsquery('english', %(query)s)) ORDER BY - {} {} + {filters.order_by} {filters.direction} OFFSET %(offset)s LIMIT %(limit)s ; """ - query = sql.SQL(query).format( - sql.Identifier(filters.order_by), - sql.SQL(filters.direction), - ) await cursor.execute(query, filters.model_dump()) rows = await cursor.fetchall() # extract total count of rows matching the WHERE clause @@ -112,7 +108,8 @@ async def create_signal(cursor: AsyncCursor, signal: Signal) -> int: cursor : AsyncCursor An async database cursor. signal : Signal - A signal object to insert. + A signal object to insert. The following fields are supported: + - secondary_location: list[str] | None Returns ------- @@ -223,7 +220,8 @@ async def read_signal(cursor: AsyncCursor, uid: int) -> Signal | None: ; """ await cursor.execute(query, (uid,)) - if (row := await cursor.fetchone()) is None: + row = await cursor.fetchone() + if row is None: return None return Signal(**row) @@ -238,7 +236,8 @@ async def update_signal(cursor: AsyncCursor, signal: Signal) -> int | None: cursor : AsyncCursor An async database cursor. signal : Signal - A signal object to update. + A signal object to update. The following fields are supported: + - secondary_location: list[str] | None Returns ------- @@ -274,7 +273,8 @@ async def update_signal(cursor: AsyncCursor, signal: Signal) -> int | None: ; """ await cursor.execute(query, signal.model_dump()) - if (row := await cursor.fetchone()) is None: + row = await cursor.fetchone() + if row is None: return None signal_id = row["id"] @@ -311,7 +311,8 @@ async def delete_signal(cursor: AsyncCursor, uid: int) -> Signal | None: """ query = "DELETE FROM signals WHERE id = %s RETURNING *;" await cursor.execute(query, (uid,)) - if (row := await cursor.fetchone()) is None: + row = await cursor.fetchone() + if row is None: return None signal = Signal(**row) if signal.attachment is not None: @@ -361,7 +362,8 @@ async def read_user_signals( ; """ await cursor.execute(query, (user_email, status)) - return [Signal(**row) async for row in cursor] + rows = await cursor.fetchall() + return [Signal(**row) for row in rows] async def is_signal_favorited(cursor: AsyncCursor, user_email: str, signal_id: int) -> bool: diff --git a/src/entities/signal.py b/src/entities/signal.py index 2f679da..5833567 100644 --- a/src/entities/signal.py +++ b/src/entities/signal.py @@ -40,14 +40,17 @@ class Signal(BaseEntity): model_config = ConfigDict( json_schema_extra={ - "example": BaseEntity.model_config["json_schema_extra"]["example"] - | { + "example": { + "id": 1, + "created_unit": "HQ", "url": "https://undp.medium.com/the-cost-of-corruption-a827306696fb", "relevance": "Of the approximately US$13 trillion that governments spend on public spending, up to 25 percent is lost to corruption.", "keywords": ["economy", "governance"], "location": "Global", "secondary_location": ["Africa", "Asia"], - "favorite": False, + "score": None, + "connected_trends": [101, 102], + "favorite": False } } ) From 099bb7528ba75a7d2dc9e737d17eddcdae68c513 Mon Sep 17 00:00:00 2001 From: Andrew Maguire Date: Mon, 28 Apr 2025 17:58:46 +0100 Subject: [PATCH 06/31] Add secondary country to signals data (#10) * initial routes * update methods --- docker-compose.yaml | 1 + sql/add_secondary_location.sql | 24 ++++++++++++++++++++++++ src/database/signals.py | 29 +++++++++++++++++------------ src/entities/signal.py | 14 +++++++++++--- 4 files changed, 53 insertions(+), 15 deletions(-) create mode 100644 sql/add_secondary_location.sql diff --git a/docker-compose.yaml b/docker-compose.yaml index 9771a63..d4c9307 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -27,6 +27,7 @@ services: - ./sql/create_tables.sql:/docker-entrypoint-initdb.d/1-create_tables.sql - ./sql/import_data.sql:/docker-entrypoint-initdb.d/2-import_data.sql - ./sql/init_test_data.sql:/docker-entrypoint-initdb.d/3-init_test_data.sql + - ./sql/add_secondary_location.sql:/app/sql/add_secondary_location.sql - ./data:/docker-entrypoint-initdb.d/data healthcheck: test: ["CMD-SHELL", "pg_isready -U postgres"] diff --git a/sql/add_secondary_location.sql b/sql/add_secondary_location.sql new file mode 100644 index 0000000..e86d8ef --- /dev/null +++ b/sql/add_secondary_location.sql @@ -0,0 +1,24 @@ +/* +Migration script to add secondary_location column to signals table. +Run this script to update the database schema. +*/ + +-- Add secondary_location column to signals table +ALTER TABLE signals ADD COLUMN IF NOT EXISTS secondary_location TEXT[]; + +-- Update the index to include the new column +DROP INDEX IF EXISTS signals_idx; +CREATE INDEX ON signals ( + status, + created_by, + created_for, + created_unit, + steep_primary, + steep_secondary, + signature_primary, + signature_secondary, + sdgs, + location, + secondary_location, + score +); \ No newline at end of file diff --git a/src/database/signals.py b/src/database/signals.py index 84b8cb0..30cf780 100644 --- a/src/database/signals.py +++ b/src/database/signals.py @@ -34,7 +34,7 @@ async def search_signals(cursor: AsyncCursor, filters: SignalFilters) -> SignalP page : SignalPage Paginated search results for signals. """ - query = """ + query = f""" SELECT *, COUNT(*) OVER() AS total_count FROM @@ -84,17 +84,13 @@ async def search_signals(cursor: AsyncCursor, filters: SignalFilters) -> SignalP AND (%(unit)s IS NULL OR unit_region = %(unit)s OR unit_name = %(unit)s) AND (%(query)s IS NULL OR text_search_field @@ websearch_to_tsquery('english', %(query)s)) ORDER BY - {} {} + {filters.order_by} {filters.direction} OFFSET %(offset)s LIMIT %(limit)s ; """ - query = sql.SQL(query).format( - sql.Identifier(filters.order_by), - sql.SQL(filters.direction), - ) await cursor.execute(query, filters.model_dump()) rows = await cursor.fetchall() # extract total count of rows matching the WHERE clause @@ -112,7 +108,8 @@ async def create_signal(cursor: AsyncCursor, signal: Signal) -> int: cursor : AsyncCursor An async database cursor. signal : Signal - A signal object to insert. + A signal object to insert. The following fields are supported: + - secondary_location: list[str] | None Returns ------- @@ -137,6 +134,7 @@ async def create_signal(cursor: AsyncCursor, signal: Signal) -> int: relevance, keywords, location, + secondary_location, score ) VALUES ( @@ -156,6 +154,7 @@ async def create_signal(cursor: AsyncCursor, signal: Signal) -> int: %(relevance)s, %(keywords)s, %(location)s, + %(secondary_location)s, %(score)s ) RETURNING @@ -221,7 +220,8 @@ async def read_signal(cursor: AsyncCursor, uid: int) -> Signal | None: ; """ await cursor.execute(query, (uid,)) - if (row := await cursor.fetchone()) is None: + row = await cursor.fetchone() + if row is None: return None return Signal(**row) @@ -236,7 +236,8 @@ async def update_signal(cursor: AsyncCursor, signal: Signal) -> int | None: cursor : AsyncCursor An async database cursor. signal : Signal - A signal object to update. + A signal object to update. The following fields are supported: + - secondary_location: list[str] | None Returns ------- @@ -263,6 +264,7 @@ async def update_signal(cursor: AsyncCursor, signal: Signal) -> int | None: relevance = COALESCE(%(relevance)s, relevance), keywords = COALESCE(%(keywords)s, keywords), location = COALESCE(%(location)s, location), + secondary_location = COALESCE(%(secondary_location)s, secondary_location), score = COALESCE(%(score)s, score) WHERE id = %(id)s @@ -271,7 +273,8 @@ async def update_signal(cursor: AsyncCursor, signal: Signal) -> int | None: ; """ await cursor.execute(query, signal.model_dump()) - if (row := await cursor.fetchone()) is None: + row = await cursor.fetchone() + if row is None: return None signal_id = row["id"] @@ -308,7 +311,8 @@ async def delete_signal(cursor: AsyncCursor, uid: int) -> Signal | None: """ query = "DELETE FROM signals WHERE id = %s RETURNING *;" await cursor.execute(query, (uid,)) - if (row := await cursor.fetchone()) is None: + row = await cursor.fetchone() + if row is None: return None signal = Signal(**row) if signal.attachment is not None: @@ -358,7 +362,8 @@ async def read_user_signals( ; """ await cursor.execute(query, (user_email, status)) - return [Signal(**row) async for row in cursor] + rows = await cursor.fetchall() + return [Signal(**row) for row in rows] async def is_signal_favorited(cursor: AsyncCursor, user_email: str, signal_id: int) -> bool: diff --git a/src/entities/signal.py b/src/entities/signal.py index e68b541..5833567 100644 --- a/src/entities/signal.py +++ b/src/entities/signal.py @@ -24,6 +24,10 @@ class Signal(BaseEntity): default=None, description="Region and/or country for which this signal has greatest relevance.", ) + secondary_location: list[str] | None = Field( + default=None, + description="Additional regions and/or countries for which this signal has relevance.", + ) score: utils.Score | None = Field(default=None) connected_trends: list[int] | None = Field( default=None, @@ -36,13 +40,17 @@ class Signal(BaseEntity): model_config = ConfigDict( json_schema_extra={ - "example": BaseEntity.model_config["json_schema_extra"]["example"] - | { + "example": { + "id": 1, + "created_unit": "HQ", "url": "https://undp.medium.com/the-cost-of-corruption-a827306696fb", "relevance": "Of the approximately US$13 trillion that governments spend on public spending, up to 25 percent is lost to corruption.", "keywords": ["economy", "governance"], "location": "Global", - "favorite": False, + "secondary_location": ["Africa", "Asia"], + "score": None, + "connected_trends": [101, 102], + "favorite": False } } ) From 1966fb395d3756f6a4dc2aabb42e2d2dde5c4f02 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Mon, 28 Apr 2025 22:58:19 +0300 Subject: [PATCH 07/31] Fix: Update Python setup action to v4 and fix cache configuration --- .github/workflows/azure-webapps-python.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/azure-webapps-python.yml b/.github/workflows/azure-webapps-python.yml index 1a0b343..3a6d283 100644 --- a/.github/workflows/azure-webapps-python.yml +++ b/.github/workflows/azure-webapps-python.yml @@ -24,10 +24,11 @@ jobs: - uses: actions/checkout@v4 - name: Set up Python version - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: ${{ env.PYTHON_VERSION }} - cache: 'pip' + cache: pip + cache-dependency-path: 'requirements.txt' - name: Create and start virtual environment run: | From 66dc2487183269a0471bf40cc46264a0aa85971d Mon Sep 17 00:00:00 2001 From: happy-devs Date: Mon, 28 Apr 2025 23:39:05 +0300 Subject: [PATCH 08/31] Update signal.py --- src/entities/signal.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/entities/signal.py b/src/entities/signal.py index 5833567..43bad07 100644 --- a/src/entities/signal.py +++ b/src/entities/signal.py @@ -2,7 +2,7 @@ Entity (model) definitions for signal objects. """ -from pydantic import ConfigDict, Field +from pydantic import ConfigDict, Field, field_validator, model_validator from . import utils from .base import BaseEntity @@ -38,6 +38,15 @@ class Signal(BaseEntity): description="Whether the current user has favorited this signal.", ) + @model_validator(mode='before') + @classmethod + def convert_secondary_location(cls, data): + """Convert string secondary_location to a list before validation.""" + if isinstance(data, dict) and 'secondary_location' in data: + if isinstance(data['secondary_location'], str): + data['secondary_location'] = [data['secondary_location']] + return data + model_config = ConfigDict( json_schema_extra={ "example": { From a0beb023131b09994d5e29d2ed706ecaca818199 Mon Sep 17 00:00:00 2001 From: Andrew Maguire Date: Wed, 30 Apr 2025 12:36:09 +0100 Subject: [PATCH 09/31] Collaborative signal editing (user groups etc) (#9) * full routers / entities / test setup * update methods * Delete signal_collaborators.sql --------- Co-authored-by: amaguire-undp --- sql/user_groups.sql | 16 + src/database/__init__.py | 1 + src/database/signals.py | 196 +++++++++++ src/database/user_groups.py | 361 ++++++++++++++++++++ src/entities/__init__.py | 1 + src/entities/parameters.py | 6 + src/entities/signal.py | 18 +- src/entities/user_groups.py | 47 +++ src/routers/__init__.py | 2 + src/routers/signals.py | 128 ++++++- src/routers/user_groups.py | 224 ++++++++++++ tests/test_collaborator_integration.py | 270 +++++++++++++++ tests/test_database_signal_collaborators.py | 250 ++++++++++++++ tests/test_database_user_groups.py | 244 +++++++++++++ tests/test_user_groups.py | 305 +++++++++++++++++ 15 files changed, 2064 insertions(+), 5 deletions(-) create mode 100644 sql/user_groups.sql create mode 100644 src/database/user_groups.py create mode 100644 src/entities/user_groups.py create mode 100644 src/routers/user_groups.py create mode 100644 tests/test_collaborator_integration.py create mode 100644 tests/test_database_signal_collaborators.py create mode 100644 tests/test_database_user_groups.py create mode 100644 tests/test_user_groups.py diff --git a/sql/user_groups.sql b/sql/user_groups.sql new file mode 100644 index 0000000..ccde56a --- /dev/null +++ b/sql/user_groups.sql @@ -0,0 +1,16 @@ +-- Create a single user groups table with direct arrays for signals and users +CREATE TABLE IF NOT EXISTS user_groups ( + id SERIAL PRIMARY KEY, + name VARCHAR(255) NOT NULL, + signal_ids INTEGER[] NOT NULL DEFAULT '{}', + user_ids INTEGER[] NOT NULL DEFAULT '{}', + -- Store collaborator relationships as JSON + -- Format: {"signal_id": [user_id1, user_id2], ...} + collaborator_map JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() +); + +-- Create GIN indexes for faster array lookups +CREATE INDEX IF NOT EXISTS idx_user_groups_signal_ids ON user_groups USING GIN (signal_ids); +CREATE INDEX IF NOT EXISTS idx_user_groups_user_ids ON user_groups USING GIN (user_ids); +CREATE INDEX IF NOT EXISTS idx_user_groups_collaborator_map ON user_groups USING GIN (collaborator_map); \ No newline at end of file diff --git a/src/database/__init__.py b/src/database/__init__.py index 5095709..bc9b053 100644 --- a/src/database/__init__.py +++ b/src/database/__init__.py @@ -8,3 +8,4 @@ from .signals import * from .trends import * from .users import * +from .user_groups import * diff --git a/src/database/signals.py b/src/database/signals.py index 30cf780..a55bacb 100644 --- a/src/database/signals.py +++ b/src/database/signals.py @@ -15,6 +15,10 @@ "delete_signal", "read_user_signals", "is_signal_favorited", + "add_collaborator", + "remove_collaborator", + "get_signal_collaborators", + "can_user_edit_signal", ] @@ -392,3 +396,195 @@ async def is_signal_favorited(cursor: AsyncCursor, user_email: str, signal_id: i """ await cursor.execute(query, (user_email, signal_id)) return await cursor.fetchone() is not None + + + +async def add_collaborator(cursor: AsyncCursor, signal_id: int, collaborator: str) -> bool: + """ + Add a collaborator to a signal. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + signal_id : int + The ID of the signal. + collaborator : str + The email of the user or "group:{id}" to add as a collaborator. + + Returns + ------- + bool + True if the collaborator was added, False otherwise. + """ + # Check if the signal exists + await cursor.execute("SELECT 1 FROM signals WHERE id = %s;", (signal_id,)) + if await cursor.fetchone() is None: + return False + + # Determine if this is a group or user + if collaborator.startswith("group:"): + group_id = int(collaborator.split(":")[1]) + query = """ + INSERT INTO signal_collaborator_groups (signal_id, group_id) + VALUES (%s, %s) + ON CONFLICT (signal_id, group_id) DO NOTHING + RETURNING signal_id + ; + """ + await cursor.execute(query, (signal_id, group_id)) + else: + # Check if the user exists + await cursor.execute("SELECT 1 FROM users WHERE email = %s;", (collaborator,)) + if await cursor.fetchone() is None: + return False + + query = """ + INSERT INTO signal_collaborators (signal_id, user_email) + VALUES (%s, %s) + ON CONFLICT (signal_id, user_email) DO NOTHING + RETURNING signal_id + ; + """ + await cursor.execute(query, (signal_id, collaborator)) + + return await cursor.fetchone() is not None + + +async def remove_collaborator(cursor: AsyncCursor, signal_id: int, collaborator: str) -> bool: + """ + Remove a collaborator from a signal. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + signal_id : int + The ID of the signal. + collaborator : str + The email of the user or "group:{id}" to remove as a collaborator. + + Returns + ------- + bool + True if the collaborator was removed, False otherwise. + """ + # Determine if this is a group or user + if collaborator.startswith("group:"): + group_id = int(collaborator.split(":")[1]) + query = """ + DELETE FROM signal_collaborator_groups + WHERE signal_id = %s AND group_id = %s + RETURNING signal_id + ; + """ + await cursor.execute(query, (signal_id, group_id)) + else: + query = """ + DELETE FROM signal_collaborators + WHERE signal_id = %s AND user_email = %s + RETURNING signal_id + ; + """ + await cursor.execute(query, (signal_id, collaborator)) + + return await cursor.fetchone() is not None + + +async def get_signal_collaborators(cursor: AsyncCursor, signal_id: int) -> list[str]: + """ + Get all collaborators for a signal. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + signal_id : int + The ID of the signal. + + Returns + ------- + list[str] + A list of user emails and group IDs (as "group:{id}"). + """ + # Get individual collaborators + query1 = """ + SELECT user_email + FROM signal_collaborators + WHERE signal_id = %s + ; + """ + await cursor.execute(query1, (signal_id,)) + user_emails = [row[0] async for row in cursor] + + # Get group collaborators + query2 = """ + SELECT group_id + FROM signal_collaborator_groups + WHERE signal_id = %s + ; + """ + await cursor.execute(query2, (signal_id,)) + group_ids = [f"group:{row[0]}" async for row in cursor] + + return user_emails + group_ids + + +async def can_user_edit_signal(cursor: AsyncCursor, signal_id: int, user_email: str) -> bool: + """ + Check if a user can edit a signal. + + A user can edit a signal if: + 1. They created the signal + 2. They are in the collaborators list + 3. They are part of a group in the collaborators list + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + signal_id : int + The ID of the signal. + user_email : str + The email of the user. + + Returns + ------- + bool + True if the user can edit the signal, False otherwise. + """ + # Check if the user created the signal + query1 = """ + SELECT 1 + FROM signals + WHERE id = %s AND created_by = %s + ; + """ + await cursor.execute(query1, (signal_id, user_email)) + if await cursor.fetchone() is not None: + return True + + # Check if the user is in the collaborators list + query2 = """ + SELECT 1 + FROM signal_collaborators + WHERE signal_id = %s AND user_email = %s + ; + """ + await cursor.execute(query2, (signal_id, user_email)) + if await cursor.fetchone() is not None: + return True + + # Check if the user is part of a group in the collaborators list + query3 = """ + SELECT 1 + FROM signal_collaborator_groups scg + JOIN user_group_members ugm ON scg.group_id = ugm.group_id + WHERE scg.signal_id = %s AND ugm.user_email = %s + ; + """ + await cursor.execute(query3, (signal_id, user_email)) + if await cursor.fetchone() is not None: + return True + + return False diff --git a/src/database/user_groups.py b/src/database/user_groups.py new file mode 100644 index 0000000..fa61837 --- /dev/null +++ b/src/database/user_groups.py @@ -0,0 +1,361 @@ +""" +CRUD operations for user group entities. +""" + +from psycopg import AsyncCursor + +from ..entities import UserGroup + +__all__ = [ + "create_user_group", + "read_user_group", + "update_user_group", + "delete_user_group", + "list_user_groups", + "add_user_to_group", + "remove_user_from_group", + "get_user_groups", + "get_group_users", +] + + +async def create_user_group(cursor: AsyncCursor, group: UserGroup) -> int: + """ + Create a new user group in the database. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + group : UserGroup + The user group to create. + + Returns + ------- + int + The ID of the created user group. + """ + query = """ + INSERT INTO user_groups ( + name, + signal_ids, + user_ids, + collaborator_map + ) + VALUES ( + %(name)s, + %(signal_ids)s, + %(user_ids)s, + %(collaborator_map)s + ) + RETURNING id + ; + """ + await cursor.execute(query, group.model_dump(exclude={"id"})) + row = await cursor.fetchone() + group_id = row[0] + + return group_id + + +async def read_user_group(cursor: AsyncCursor, group_id: int) -> UserGroup | None: + """ + Read a user group from the database. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + group_id : int + The ID of the user group to read. + + Returns + ------- + UserGroup | None + The user group if found, otherwise None. + """ + query = """ + SELECT + id, + name, + signal_ids, + user_ids, + collaborator_map + FROM + user_groups + WHERE + id = %s + ; + """ + await cursor.execute(query, (group_id,)) + if (row := await cursor.fetchone()) is None: + return None + + # Convert row to dict + data = dict(zip(["id", "name", "signal_ids", "user_ids", "collaborator_map"], row)) + + return UserGroup(**data) + + +async def update_user_group(cursor: AsyncCursor, group: UserGroup) -> int | None: + """ + Update a user group in the database. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + group : UserGroup + The user group to update. + + Returns + ------- + int | None + The ID of the updated user group if successful, otherwise None. + """ + query = """ + UPDATE user_groups + SET + name = %(name)s, + signal_ids = %(signal_ids)s, + user_ids = %(user_ids)s, + collaborator_map = %(collaborator_map)s + WHERE id = %(id)s + RETURNING id + ; + """ + await cursor.execute(query, group.model_dump()) + if (row := await cursor.fetchone()) is None: + return None + + return row[0] + + +async def delete_user_group(cursor: AsyncCursor, group_id: int) -> bool: + """ + Delete a user group from the database. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + group_id : int + The ID of the user group to delete. + + Returns + ------- + bool + True if the group was deleted, False otherwise. + """ + query = "DELETE FROM user_groups WHERE id = %s RETURNING id;" + await cursor.execute(query, (group_id,)) + + return await cursor.fetchone() is not None + + +async def list_user_groups(cursor: AsyncCursor) -> list[UserGroup]: + """ + List all user groups from the database. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + + Returns + ------- + list[UserGroup] + A list of all user groups. + """ + query = """ + SELECT + id, + name, + signal_ids, + user_ids, + collaborator_map + FROM + user_groups + ORDER BY + name + ; + """ + await cursor.execute(query) + result = [] + + async for row in cursor: + # Convert row to dict + data = dict(zip(["id", "name", "signal_ids", "user_ids", "collaborator_map"], row)) + result.append(UserGroup(**data)) + + return result + + +async def add_user_to_group(cursor: AsyncCursor, group_id: int, user_id: int) -> bool: + """ + Add a user to a group. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + group_id : int + The ID of the group. + user_id : int + The ID of the user to add. + + Returns + ------- + bool + True if the user was added, False otherwise. + """ + # Check if the user exists + await cursor.execute("SELECT 1 FROM users WHERE id = %s;", (user_id,)) + if await cursor.fetchone() is None: + return False + + # Check if the group exists + await cursor.execute("SELECT user_ids FROM user_groups WHERE id = %s;", (group_id,)) + row = await cursor.fetchone() + if row is None: + return False + + # Add the user to the group (if not already a member) + user_ids = row[0] if row[0] is not None else [] + if user_id not in user_ids: + user_ids.append(user_id) + + query = """ + UPDATE user_groups + SET user_ids = %s + WHERE id = %s + RETURNING id + ; + """ + await cursor.execute(query, (user_ids, group_id)) + return await cursor.fetchone() is not None + + return True + + +async def remove_user_from_group(cursor: AsyncCursor, group_id: int, user_id: int) -> bool: + """ + Remove a user from a group. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + group_id : int + The ID of the group. + user_id : int + The ID of the user to remove. + + Returns + ------- + bool + True if the user was removed, False otherwise. + """ + # Get current user_ids + await cursor.execute("SELECT user_ids, collaborator_map FROM user_groups WHERE id = %s;", (group_id,)) + row = await cursor.fetchone() + if row is None: + return False + + user_ids = row[0] if row[0] is not None else [] + collaborator_map = row[1] if row[1] is not None else {} + + # Remove user from user_ids + if user_id in user_ids: + user_ids.remove(user_id) + + # Remove user from collaborator_map + for signal_id, users in list(collaborator_map.items()): + if user_id in users: + users.remove(user_id) + if not users: # If empty, remove signal from map + del collaborator_map[signal_id] + + query = """ + UPDATE user_groups + SET + user_ids = %s, + collaborator_map = %s + WHERE id = %s + RETURNING id + ; + """ + await cursor.execute(query, (user_ids, collaborator_map, group_id)) + return await cursor.fetchone() is not None + + return False + + +async def get_user_groups(cursor: AsyncCursor, user_id: int) -> list[UserGroup]: + """ + Get all groups that a user is a member of. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_id : int + The ID of the user. + + Returns + ------- + list[UserGroup] + A list of user groups. + """ + query = """ + SELECT + id, + name, + signal_ids, + user_ids, + collaborator_map + FROM + user_groups + WHERE + %s = ANY(user_ids) + ORDER BY + name + ; + """ + await cursor.execute(query, (user_id,)) + result = [] + + async for row in cursor: + # Convert row to dict + data = dict(zip(["id", "name", "signal_ids", "user_ids", "collaborator_map"], row)) + result.append(UserGroup(**data)) + + return result + + +async def get_group_users(cursor: AsyncCursor, group_id: int) -> list[int]: + """ + Get all users in a group. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + group_id : int + The ID of the group. + + Returns + ------- + list[int] + A list of user IDs. + """ + query = """ + SELECT user_ids + FROM user_groups + WHERE id = %s + ; + """ + await cursor.execute(query, (group_id,)) + row = await cursor.fetchone() + + return row[0] if row and row[0] else [] diff --git a/src/entities/__init__.py b/src/entities/__init__.py index 7c421d8..7591f73 100644 --- a/src/entities/__init__.py +++ b/src/entities/__init__.py @@ -11,6 +11,7 @@ from .signal import * from .trend import * from .user import * +from .user_groups import * from .utils import * diff --git a/src/entities/parameters.py b/src/entities/parameters.py index d019efb..50ccc32 100644 --- a/src/entities/parameters.py +++ b/src/entities/parameters.py @@ -84,3 +84,9 @@ class UserFilters(Pagination): roles: list[Role] = Field(default=(Role.VISITOR, Role.CURATOR, Role.ADMIN)) query: str | None = Field(default=None) + +class UserGroupFilters(Pagination): + """Filter parameters for searching user groups.""" + + query: str | None = Field(default=None) + \ No newline at end of file diff --git a/src/entities/signal.py b/src/entities/signal.py index 43bad07..86360c3 100644 --- a/src/entities/signal.py +++ b/src/entities/signal.py @@ -2,6 +2,7 @@ Entity (model) definitions for signal objects. """ +from typing import List, Dict from pydantic import ConfigDict, Field, field_validator, model_validator from . import utils @@ -37,6 +38,18 @@ class Signal(BaseEntity): default=False, description="Whether the current user has favorited this signal.", ) + is_draft: bool = Field( + default=True, + description="Whether the signal is in draft state or published.", + ) + group_ids: List[int] | None = Field( + default=None, + description="List of user group IDs associated with this signal.", + ) + collaborators: List[int] | None = Field( + default=None, + description="List of user IDs who can collaborate on this signal.", + ) @model_validator(mode='before') @classmethod @@ -56,10 +69,13 @@ def convert_secondary_location(cls, data): "relevance": "Of the approximately US$13 trillion that governments spend on public spending, up to 25 percent is lost to corruption.", "keywords": ["economy", "governance"], "location": "Global", + "favorite": False, + "is_draft": True, + "group_ids": [1, 2], + "collaborators": [1, 2, 3], "secondary_location": ["Africa", "Asia"], "score": None, "connected_trends": [101, 102], - "favorite": False } } ) diff --git a/src/entities/user_groups.py b/src/entities/user_groups.py new file mode 100644 index 0000000..20d8f22 --- /dev/null +++ b/src/entities/user_groups.py @@ -0,0 +1,47 @@ +""" +Entity (model) definitions for user group objects. +""" + +from typing import Dict, List +from pydantic import ConfigDict, Field + +from .base import BaseEntity + +__all__ = ["UserGroup"] + + +class UserGroup(BaseEntity): + """The user group entity model used in the database and API endpoints.""" + + name: str = Field( + description="Name of the user group.", + min_length=3, + ) + signal_ids: List[int] = Field( + default_factory=list, + description="List of signal IDs associated with this group." + ) + user_ids: List[int] = Field( + default_factory=list, + description="List of user IDs who are members of this group." + ) + collaborator_map: Dict[str, List[int]] = Field( + default_factory=dict, + description="Map of signal IDs to lists of user IDs that can collaborate on that signal." + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "id": 1, + "name": "CDO", + "signal_ids": [1, 2, 3], + "user_ids": [1, 2, 3], + "collaborator_map": { + "1": [1, 2], + "2": [1, 3], + "3": [2, 3] + } + } + } + ) diff --git a/src/routers/__init__.py b/src/routers/__init__.py index ca600c9..78d60dc 100644 --- a/src/routers/__init__.py +++ b/src/routers/__init__.py @@ -7,6 +7,7 @@ from .signals import router as signal_router from .trends import router as trend_router from .users import router as user_router +from .user_groups import router as user_group_router ALL = [ choice_router, @@ -14,4 +15,5 @@ signal_router, trend_router, user_router, + user_group_router, ] diff --git a/src/routers/signals.py b/src/routers/signals.py index b078082..4bee7e5 100644 --- a/src/routers/signals.py +++ b/src/routers/signals.py @@ -3,17 +3,17 @@ """ import logging -from typing import Annotated +from typing import Annotated, List import pandas as pd -from fastapi import APIRouter, Depends, Path, Query +from fastapi import APIRouter, Depends, Path, Query, HTTPException, status from psycopg import AsyncCursor from .. import database as db from .. import exceptions, genai, utils from ..authentication import authenticate_user -from ..dependencies import require_creator, require_curator, require_user -from ..entities import Role, Signal, SignalFilters, SignalPage, Status, User +from ..dependencies import require_admin, require_creator, require_curator, require_user +from ..entities import Role, Signal, SignalFilters, SignalPage, Status, User, UserGroup logger = logging.getLogger(__name__) @@ -173,3 +173,123 @@ async def delete_signal( if (signal := await db.delete_signal(cursor, uid)) is None: raise exceptions.not_found return signal + + +@router.get("/{uid}/collaborators", response_model=List[int]) +async def get_signal_collaborators( + uid: Annotated[int, Path(description="The ID of the signal")], + user: User = Depends(require_user), + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """ + Get all user IDs who can collaborate on a signal. + + Only signal creators, admins, curators, and current collaborators can access this endpoint. + """ + # Check if signal exists + if await db.read_signal(cursor, uid) is None: + raise exceptions.not_found + + # Check if user is authorized to view collaborators + if not user.is_admin and not user.is_staff and not await db.can_user_edit_signal(cursor, uid, user.id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have permission to view collaborators for this signal", + ) + + collaborators = await db.get_signal_collaborators(cursor, uid) + return collaborators + + +@router.post("/{uid}/collaborators/{user_id}", response_model=bool) +async def add_signal_collaborator( + uid: Annotated[int, Path(description="The ID of the signal")], + user_id: Annotated[int, Path(description="The ID of the user to add as collaborator")], + user: User = Depends(require_user), + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """ + Add a collaborator to a signal. + + Only signal creators, admins, and curators can add collaborators. + """ + # Check if signal exists + signal = await db.read_signal(cursor, uid) + if signal is None: + raise exceptions.not_found + + # Check if user is authorized to add collaborators + if not user.is_admin and not user.is_staff and signal.created_by != user.email: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have permission to add collaborators to this signal", + ) + + # Add collaborator + if not await db.add_collaborator(cursor, uid, user_id): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid collaborator or signal", + ) + + return True + + +@router.delete("/{uid}/collaborators/{user_id}", response_model=bool) +async def remove_signal_collaborator( + uid: Annotated[int, Path(description="The ID of the signal")], + user_id: Annotated[int, Path(description="The ID of the user to remove as collaborator")], + user: User = Depends(require_user), + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """ + Remove a collaborator from a signal. + + Only signal creators, admins, and curators can remove collaborators. + """ + # Check if signal exists + signal = await db.read_signal(cursor, uid) + if signal is None: + raise exceptions.not_found + + # Check if user is authorized to remove collaborators + if not user.is_admin and not user.is_staff and signal.created_by != user.email: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have permission to remove collaborators from this signal", + ) + + # Remove collaborator + if not await db.remove_collaborator(cursor, uid, user_id): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Collaborator not found or signal does not exist", + ) + + return True + + +@router.get("/{uid}/can-edit", response_model=bool) +async def can_user_edit_signal( + uid: Annotated[int, Path(description="The ID of the signal")], + user: User = Depends(authenticate_user), + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """ + Check if the current user can edit a signal. + + A user can edit a signal if: + 1. They created the signal + 2. They are an admin or curator + 3. They are in the collaborators list + 4. They are part of a group that can collaborate on this signal + """ + # Admins and curators can edit any signal + if user.is_admin or user.is_staff: + return True + + # Check if signal exists + if await db.read_signal(cursor, uid) is None: + raise exceptions.not_found + + return await db.can_user_edit_signal(cursor, uid, user.id) diff --git a/src/routers/user_groups.py b/src/routers/user_groups.py new file mode 100644 index 0000000..7a2e00e --- /dev/null +++ b/src/routers/user_groups.py @@ -0,0 +1,224 @@ +""" +A router for managing user groups. +""" + +import logging +from typing import Annotated, List + +from fastapi import APIRouter, Depends, Path, Body +from psycopg import AsyncCursor + +from .. import database as db +from .. import exceptions +from ..dependencies import require_admin +from ..entities import UserGroup + +router = APIRouter(prefix="/user-groups", tags=["user groups"]) + + +@router.get("", response_model=List[UserGroup], dependencies=[Depends(require_admin)]) +async def list_user_groups( + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """List all user groups.""" + groups = await db.list_user_groups(cursor) + return groups + + +@router.post("", response_model=UserGroup, dependencies=[Depends(require_admin)]) +async def create_user_group( + group: UserGroup, + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """Create a new user group.""" + group_id = await db.create_user_group(cursor, group) + return await db.read_user_group(cursor, group_id) + + +@router.get("/{group_id}", response_model=UserGroup, dependencies=[Depends(require_admin)]) +async def read_user_group( + group_id: Annotated[int, Path(description="The ID of the user group to retrieve")], + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """Get a user group by ID.""" + if (group := await db.read_user_group(cursor, group_id)) is None: + raise exceptions.not_found + return group + + +@router.put("/{group_id}", response_model=UserGroup, dependencies=[Depends(require_admin)]) +async def update_user_group( + group_id: Annotated[int, Path(description="The ID of the user group to update")], + group: UserGroup, + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """Update a user group.""" + if group_id != group.id: + raise exceptions.id_mismatch + if (updated_id := await db.update_user_group(cursor, group)) is None: + raise exceptions.not_found + return await db.read_user_group(cursor, updated_id) + + +@router.delete("/{group_id}", response_model=bool, dependencies=[Depends(require_admin)]) +async def delete_user_group( + group_id: Annotated[int, Path(description="The ID of the user group to delete")], + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """Delete a user group.""" + if not await db.delete_user_group(cursor, group_id): + raise exceptions.not_found + return True + + +@router.post("/{group_id}/users/{user_id}", response_model=bool, dependencies=[Depends(require_admin)]) +async def add_user_to_group( + group_id: Annotated[int, Path(description="The ID of the user group")], + user_id: Annotated[int, Path(description="The ID of the user to add")], + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """Add a user to a group.""" + if not await db.add_user_to_group(cursor, group_id, user_id): + raise exceptions.not_found + return True + + +@router.delete("/{group_id}/users/{user_id}", response_model=bool, dependencies=[Depends(require_admin)]) +async def remove_user_from_group( + group_id: Annotated[int, Path(description="The ID of the user group")], + user_id: Annotated[int, Path(description="The ID of the user to remove")], + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """Remove a user from a group.""" + if not await db.remove_user_from_group(cursor, group_id, user_id): + raise exceptions.not_found + return True + + +@router.post("/{group_id}/signals/{signal_id}", response_model=bool, dependencies=[Depends(require_admin)]) +async def add_signal_to_group( + group_id: Annotated[int, Path(description="The ID of the user group")], + signal_id: Annotated[int, Path(description="The ID of the signal to add")], + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """Add a signal to a group.""" + # Get the group + group = await db.read_user_group(cursor, group_id) + if group is None: + raise exceptions.not_found + + # Check if signal exists + if await db.read_signal(cursor, signal_id) is None: + raise exceptions.not_found + + # Add signal to group + signal_ids = group.signal_ids or [] + if signal_id not in signal_ids: + signal_ids.append(signal_id) + group.signal_ids = signal_ids + + if await db.update_user_group(cursor, group) is None: + raise exceptions.not_found + + return True + + +@router.delete("/{group_id}/signals/{signal_id}", response_model=bool, dependencies=[Depends(require_admin)]) +async def remove_signal_from_group( + group_id: Annotated[int, Path(description="The ID of the user group")], + signal_id: Annotated[int, Path(description="The ID of the signal to remove")], + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """Remove a signal from a group.""" + # Get the group + group = await db.read_user_group(cursor, group_id) + if group is None: + raise exceptions.not_found + + # Remove signal from group + signal_ids = group.signal_ids or [] + if signal_id in signal_ids: + signal_ids.remove(signal_id) + group.signal_ids = signal_ids + + # Also remove collaborators for this signal + collaborator_map = group.collaborator_map or {} + signal_key = str(signal_id) + if signal_key in collaborator_map: + del collaborator_map[signal_key] + group.collaborator_map = collaborator_map + + if await db.update_user_group(cursor, group) is None: + raise exceptions.not_found + + return True + + +@router.post("/{group_id}/signals/{signal_id}/collaborators/{user_id}", response_model=bool, dependencies=[Depends(require_admin)]) +async def add_collaborator_to_signal_in_group( + group_id: Annotated[int, Path(description="The ID of the user group")], + signal_id: Annotated[int, Path(description="The ID of the signal")], + user_id: Annotated[int, Path(description="The ID of the user to add as collaborator")], + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """Add a user as a collaborator for a specific signal in a group.""" + # Get the group + group = await db.read_user_group(cursor, group_id) + if group is None: + raise exceptions.not_found + + # Check if signal is in the group + signal_ids = group.signal_ids or [] + if signal_id not in signal_ids: + raise exceptions.not_found + + # Check if user is in the group + user_ids = group.user_ids or [] + if user_id not in user_ids: + raise exceptions.not_found + + # Add collaborator + collaborator_map = group.collaborator_map or {} + signal_key = str(signal_id) + if signal_key not in collaborator_map: + collaborator_map[signal_key] = [] + + if user_id not in collaborator_map[signal_key]: + collaborator_map[signal_key].append(user_id) + group.collaborator_map = collaborator_map + + if await db.update_user_group(cursor, group) is None: + raise exceptions.not_found + + return True + + +@router.delete("/{group_id}/signals/{signal_id}/collaborators/{user_id}", response_model=bool, dependencies=[Depends(require_admin)]) +async def remove_collaborator_from_signal_in_group( + group_id: Annotated[int, Path(description="The ID of the user group")], + signal_id: Annotated[int, Path(description="The ID of the signal")], + user_id: Annotated[int, Path(description="The ID of the user to remove as collaborator")], + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """Remove a user as a collaborator for a specific signal in a group.""" + # Get the group + group = await db.read_user_group(cursor, group_id) + if group is None: + raise exceptions.not_found + + # Check if this collaborator assignment exists + collaborator_map = group.collaborator_map or {} + signal_key = str(signal_id) + if signal_key in collaborator_map and user_id in collaborator_map[signal_key]: + collaborator_map[signal_key].remove(user_id) + + # If no collaborators left for this signal, remove the entry + if not collaborator_map[signal_key]: + del collaborator_map[signal_key] + + group.collaborator_map = collaborator_map + + if await db.update_user_group(cursor, group) is None: + raise exceptions.not_found + + return True \ No newline at end of file diff --git a/tests/test_collaborator_integration.py b/tests/test_collaborator_integration.py new file mode 100644 index 0000000..8d1a6af --- /dev/null +++ b/tests/test_collaborator_integration.py @@ -0,0 +1,270 @@ +""" +Integration tests for signal collaboration functionality. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest_asyncio +from fastapi import FastAPI +from httpx import AsyncClient + +from src.entities import UserGroup, User, Role, Signal +from src.app import create_application + + +@pytest.fixture +def app(): + """Create a test application.""" + return create_application(configure_oauth=False, debug=True) + + +@pytest_asyncio.fixture +async def client(app: FastAPI): + """Create a test client.""" + async with AsyncClient(app=app, base_url="http://test") as client: + yield client + + +@pytest.fixture +def mock_signal_owner(): + """Create a mock owner user.""" + return User( + id=1, + email="owner@undp.org", + role=Role.USER, + name="Signal Owner", + unit="BPPS", + ) + + +@pytest.fixture +def mock_user_group(): + """Create a mock user group.""" + return UserGroup( + id=1, + name="Test Group", + users=["member1@undp.org", "member2@undp.org"], + ) + + +@pytest.fixture +def mock_draft_signal(): + """Create a mock draft signal.""" + return Signal( + id=1, + headline="Draft Signal", + description="This is a draft signal", + created_by="owner@undp.org", + is_draft=True, + collaborators=["collaborator@undp.org", "group:1"], + ) + + +@pytest.mark.asyncio +@patch("src.database.user_groups.create_user_group") +@patch("src.database.user_groups.read_user_group") +@patch("src.authentication.authenticate_user") +async def test_create_group_and_add_members( + mock_auth, mock_read_group, mock_create_group, client, mock_signal_owner, mock_user_group +): + """Test creating a group and adding members.""" + # Configure mocks + mock_auth.return_value = mock_signal_owner + mock_create_group.return_value = 1 + mock_read_group.return_value = mock_user_group + + # Mock signal owner to be an admin to allow group creation + mock_signal_owner.role = Role.ADMIN + + # Create the group + group_response = await client.post( + "/api/user-groups", + json={"name": "Test Group", "users": ["member1@undp.org", "member2@undp.org"]}, + ) + assert group_response.status_code == 200 + group_data = group_response.json() + assert group_data["id"] == 1 + assert "member1@undp.org" in group_data["users"] + assert "member2@undp.org" in group_data["users"] + + # Add another member + mock_user_group.users.append("member3@undp.org") + add_member_response = await client.post("/api/user-groups/1/users/member3@undp.org") + assert add_member_response.status_code == 200 + assert add_member_response.json() is True + + +@pytest.mark.asyncio +@patch("src.database.signals.create_signal") +@patch("src.database.signals.read_signal") +@patch("src.database.signals.add_collaborator") +@patch("src.authentication.authenticate_user") +async def test_create_draft_signal_with_collaborators( + mock_auth, mock_add_collaborator, mock_read_signal, mock_create_signal, + client, mock_signal_owner, mock_draft_signal +): + """Test creating a draft signal with collaborators.""" + # Configure mocks + mock_auth.return_value = mock_signal_owner + mock_create_signal.return_value = 1 + mock_read_signal.return_value = mock_draft_signal + mock_add_collaborator.return_value = True + + # Create the draft signal + signal_data = { + "headline": "Draft Signal", + "description": "This is a draft signal", + "is_draft": True, + "collaborators": ["collaborator@undp.org", "group:1"] + } + + signal_response = await client.post("/api/signals", json=signal_data) + assert signal_response.status_code == 201 # Created + + response_data = signal_response.json() + assert response_data["headline"] == "Draft Signal" + assert response_data["is_draft"] is True + + # Verify collaborators were added + mock_add_collaborator.assert_any_call( + mock_create_signal.return_value, "collaborator@undp.org" + ) + mock_add_collaborator.assert_any_call( + mock_create_signal.return_value, "group:1" + ) + + +@pytest.mark.asyncio +@patch("src.database.signals.read_signal") +@patch("src.database.signals.can_user_edit_signal") +@patch("src.database.signals.get_signal_collaborators") +@patch("src.authentication.authenticate_user") +async def test_non_owner_can_edit_draft_as_collaborator( + mock_auth, mock_get_collaborators, mock_can_edit, mock_read_signal, + client, mock_draft_signal +): + """Test that a collaborator can edit a draft signal created by someone else.""" + # Create a collaborator user + collaborator_user = User( + id=2, + email="collaborator@undp.org", + role=Role.USER, + name="Collaborator", + unit="BPPS", + ) + + # Configure mocks + mock_auth.return_value = collaborator_user + mock_read_signal.return_value = mock_draft_signal + mock_can_edit.return_value = True + mock_get_collaborators.return_value = ["collaborator@undp.org", "group:1"] + + # Check if user can edit + can_edit_response = await client.get("/api/signals/1/can-edit") + assert can_edit_response.status_code == 200 + assert can_edit_response.json() is True + + # Get collaborators + collaborators_response = await client.get("/api/signals/1/collaborators") + assert collaborators_response.status_code == 200 + collaborators = collaborators_response.json() + assert "collaborator@undp.org" in collaborators + assert "group:1" in collaborators + + +@pytest.mark.asyncio +@patch("src.database.signals.read_signal") +@patch("src.database.signals.can_user_edit_signal") +@patch("src.authentication.authenticate_user") +async def test_non_collaborator_cannot_edit_draft( + mock_auth, mock_can_edit, mock_read_signal, + client, mock_draft_signal +): + """Test that a non-collaborator cannot edit a draft signal created by someone else.""" + # Create a non-collaborator user + non_collaborator = User( + id=3, + email="random@undp.org", + role=Role.USER, + name="Random User", + unit="BPPS", + ) + + # Configure mocks + mock_auth.return_value = non_collaborator + mock_read_signal.return_value = mock_draft_signal + mock_can_edit.return_value = False + + # Check if user can edit + can_edit_response = await client.get("/api/signals/1/can-edit") + assert can_edit_response.status_code == 200 + assert can_edit_response.json() is False + + +@pytest.mark.asyncio +@patch("src.database.signals.read_signal") +@patch("src.database.signals.remove_collaborator") +@patch("src.authentication.authenticate_user") +async def test_remove_collaborator_from_signal( + mock_auth, mock_remove_collaborator, mock_read_signal, + client, mock_signal_owner, mock_draft_signal +): + """Test removing a collaborator from a signal.""" + # Configure mocks + mock_auth.return_value = mock_signal_owner + mock_read_signal.return_value = mock_draft_signal + mock_remove_collaborator.return_value = True + + # Remove user collaborator + user_response = await client.delete("/api/signals/1/collaborators/collaborator@undp.org") + assert user_response.status_code == 200 + assert user_response.json() is True + + # Remove group collaborator + group_response = await client.delete("/api/signals/1/collaborators/group:1") + assert group_response.status_code == 200 + assert group_response.json() is True + + # Verify remove collaborator was called with correct parameters + mock_remove_collaborator.assert_any_call(1, "collaborator@undp.org") + mock_remove_collaborator.assert_any_call(1, "group:1") + + +@pytest.mark.asyncio +@patch("src.database.signals.update_signal") +@patch("src.database.signals.read_signal") +@patch("src.authentication.authenticate_user") +async def test_publish_draft_signal( + mock_auth, mock_read_signal, mock_update_signal, + client, mock_signal_owner, mock_draft_signal +): + """Test publishing a draft signal.""" + # Configure mocks + mock_auth.return_value = mock_signal_owner + + # For the first read_signal call (checking if user can edit) + mock_draft_signal_copy = mock_draft_signal.model_copy() + mock_read_signal.return_value = mock_draft_signal_copy + + # For the second read_signal call (after update) + published_signal = mock_draft_signal.model_copy(update={"is_draft": False}) + mock_read_signal.side_effect = [mock_draft_signal_copy, published_signal] + + mock_update_signal.return_value = 1 + + # Update signal to change from draft to published + update_data = { + "id": 1, + "headline": "Draft Signal", + "description": "This is a draft signal", + "is_draft": False, # Now published + "created_by": "owner@undp.org", + "collaborators": ["collaborator@undp.org", "group:1"], + } + + update_response = await client.put("/api/signals/1", json=update_data) + assert update_response.status_code == 200 + + response_data = update_response.json() + assert response_data["is_draft"] is False # Verify it's now published \ No newline at end of file diff --git a/tests/test_database_signal_collaborators.py b/tests/test_database_signal_collaborators.py new file mode 100644 index 0000000..7f519a4 --- /dev/null +++ b/tests/test_database_signal_collaborators.py @@ -0,0 +1,250 @@ +""" +Unit tests for signal collaborator database methods. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from src.entities import Signal +from src.database.signals import ( + add_collaborator, + remove_collaborator, + get_signal_collaborators, + can_user_edit_signal, +) + + +@pytest.fixture +def mock_cursor(): + """Create a mock database cursor.""" + cursor = AsyncMock() + cursor.fetchone = AsyncMock() + cursor.fetchall = AsyncMock() + return cursor + + +@pytest.fixture +def mock_signal(): + """Create a mock signal with collaborators.""" + return Signal( + id=1, + headline="Test Signal", + description="Test Description", + created_by="owner@undp.org", + is_draft=True, + collaborators=["collaborator@undp.org", "group:1"], + ) + + +class TestSignalCollaboratorsDatabaseMethods: + """Tests for signal collaborators database methods.""" + + @pytest.mark.asyncio + async def test_add_collaborator_user(self, mock_cursor): + """Test adding a user collaborator to a signal.""" + # Mock the database response + mock_cursor.fetchone.side_effect = [(1,), (1,), (1,)] + + # Call the function + result = await add_collaborator(mock_cursor, 1, "collaborator@undp.org") + + # Check the result + assert result is True + + # Verify the SQL was executed with the correct parameters + assert mock_cursor.execute.call_count == 3 # Check signal + check user + add + + # Check the insert query + args, kwargs = mock_cursor.execute.call_args_list[2] + assert "INSERT INTO signal_collaborators" in args[0] + assert args[1] == (1, "collaborator@undp.org") + + @pytest.mark.asyncio + async def test_add_collaborator_group(self, mock_cursor): + """Test adding a group collaborator to a signal.""" + # Mock the database response + mock_cursor.fetchone.side_effect = [(1,), (1,)] + + # Call the function + result = await add_collaborator(mock_cursor, 1, "group:2") + + # Check the result + assert result is True + + # Verify the SQL was executed with the correct parameters + assert mock_cursor.execute.call_count == 2 # Check signal + add + + # Check the insert query + args, kwargs = mock_cursor.execute.call_args_list[1] + assert "INSERT INTO signal_collaborator_groups" in args[0] + assert args[1] == (1, 2) + + @pytest.mark.asyncio + async def test_add_collaborator_signal_not_found(self, mock_cursor): + """Test adding a collaborator to a non-existent signal.""" + # Mock the database response + mock_cursor.fetchone.return_value = None + + # Call the function + result = await add_collaborator(mock_cursor, 99, "collaborator@undp.org") + + # Check the result + assert result is False + + # Verify the SQL was executed with the correct parameters + mock_cursor.execute.assert_called_once() + args, kwargs = mock_cursor.execute.call_args + assert "SELECT 1 FROM signals" in args[0] + assert args[1] == (99,) + + @pytest.mark.asyncio + async def test_remove_collaborator_user(self, mock_cursor): + """Test removing a user collaborator from a signal.""" + # Mock the database response + mock_cursor.fetchone.return_value = (1,) + + # Call the function + result = await remove_collaborator(mock_cursor, 1, "collaborator@undp.org") + + # Check the result + assert result is True + + # Verify the SQL was executed with the correct parameters + mock_cursor.execute.assert_called_once() + args, kwargs = mock_cursor.execute.call_args + assert "DELETE FROM signal_collaborators" in args[0] + assert args[1] == (1, "collaborator@undp.org") + + @pytest.mark.asyncio + async def test_remove_collaborator_group(self, mock_cursor): + """Test removing a group collaborator from a signal.""" + # Mock the database response + mock_cursor.fetchone.return_value = (1,) + + # Call the function + result = await remove_collaborator(mock_cursor, 1, "group:2") + + # Check the result + assert result is True + + # Verify the SQL was executed with the correct parameters + mock_cursor.execute.assert_called_once() + args, kwargs = mock_cursor.execute.call_args + assert "DELETE FROM signal_collaborator_groups" in args[0] + assert args[1] == (1, 2) + + @pytest.mark.asyncio + async def test_get_signal_collaborators(self, mock_cursor): + """Test getting collaborators for a signal.""" + # Mock the database cursor behavior for user collaborators + mock_cursor.__aiter__.side_effect = [ + [("user1@undp.org",), ("user2@undp.org",)], # First query result + [(1,), (2,)] # Second query result + ] + + # Call the function + collaborators = await get_signal_collaborators(mock_cursor, 1) + + # Check the result + assert len(collaborators) == 4 + assert "user1@undp.org" in collaborators + assert "user2@undp.org" in collaborators + assert "group:1" in collaborators + assert "group:2" in collaborators + + # Verify the SQL was executed with the correct parameters + assert mock_cursor.execute.call_count == 2 + args1, kwargs1 = mock_cursor.execute.call_args_list[0] + assert "FROM signal_collaborators" in args1[0] + assert args1[1] == (1,) + + args2, kwargs2 = mock_cursor.execute.call_args_list[1] + assert "FROM signal_collaborator_groups" in args2[0] + assert args2[1] == (1,) + + @pytest.mark.asyncio + async def test_can_user_edit_signal_creator(self, mock_cursor): + """Test checking if the creator can edit a signal.""" + # Mock the database response + mock_cursor.fetchone.return_value = (1,) + + # Call the function + result = await can_user_edit_signal(mock_cursor, 1, "owner@undp.org") + + # Check the result + assert result is True + + # Verify the SQL was executed with the correct parameters + mock_cursor.execute.assert_called_once() + args, kwargs = mock_cursor.execute.call_args + assert "WHERE id = %s AND created_by = %s" in args[0] + assert args[1] == (1, "owner@undp.org") + + @pytest.mark.asyncio + async def test_can_user_edit_signal_collaborator(self, mock_cursor): + """Test checking if a collaborator can edit a signal.""" + # Mock the database responses + mock_cursor.fetchone.side_effect = [None, (1,)] + + # Call the function + result = await can_user_edit_signal(mock_cursor, 1, "collaborator@undp.org") + + # Check the result + assert result is True + + # Verify the SQL was executed with the correct parameters + assert mock_cursor.execute.call_count == 2 + + # First check creator + args1, kwargs1 = mock_cursor.execute.call_args_list[0] + assert "WHERE id = %s AND created_by = %s" in args1[0] + assert args1[1] == (1, "collaborator@undp.org") + + # Then check direct collaborator + args2, kwargs2 = mock_cursor.execute.call_args_list[1] + assert "FROM signal_collaborators" in args2[0] + assert args2[1] == (1, "collaborator@undp.org") + + @pytest.mark.asyncio + async def test_can_user_edit_signal_group_member(self, mock_cursor): + """Test checking if a group member can edit a signal.""" + # Mock the database responses + mock_cursor.fetchone.side_effect = [None, None, (1,)] + + # Call the function + result = await can_user_edit_signal(mock_cursor, 1, "group_member@undp.org") + + # Check the result + assert result is True + + # Verify the SQL was executed with the correct parameters + assert mock_cursor.execute.call_count == 3 + + # First check creator + args1, kwargs1 = mock_cursor.execute.call_args_list[0] + assert "WHERE id = %s AND created_by = %s" in args1[0] + + # Then check direct collaborator + args2, kwargs2 = mock_cursor.execute.call_args_list[1] + assert "FROM signal_collaborators" in args2[0] + + # Finally check group member + args3, kwargs3 = mock_cursor.execute.call_args_list[2] + assert "FROM signal_collaborator_groups" in args3[0] + assert "JOIN user_group_members" in args3[0] + assert args3[1] == (1, "group_member@undp.org") + + @pytest.mark.asyncio + async def test_can_user_edit_signal_no_permission(self, mock_cursor): + """Test checking if a user without permission can edit a signal.""" + # Mock the database responses + mock_cursor.fetchone.side_effect = [None, None, None] + + # Call the function + result = await can_user_edit_signal(mock_cursor, 1, "random@undp.org") + + # Check the result + assert result is False + + # Verify all three checks were performed + \ No newline at end of file diff --git a/tests/test_database_user_groups.py b/tests/test_database_user_groups.py new file mode 100644 index 0000000..b72a646 --- /dev/null +++ b/tests/test_database_user_groups.py @@ -0,0 +1,244 @@ +""" +Unit tests for user groups database methods. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from src.entities import UserGroup +from src.database.user_groups import ( + create_user_group, + read_user_group, + update_user_group, + delete_user_group, + list_user_groups, + add_user_to_group, + remove_user_from_group, + get_user_groups, + get_group_users, +) + + +@pytest.fixture +def mock_cursor(): + """Create a mock database cursor.""" + cursor = AsyncMock() + cursor.fetchone = AsyncMock() + cursor.fetchall = AsyncMock() + return cursor + + +@pytest.fixture +def mock_user_group(): + """Create a mock user group.""" + return UserGroup( + id=1, + name="Test Group", + users=["user1@undp.org", "user2@undp.org"], + ) + + +class TestUserGroupsDatabaseMethods: + """Tests for user groups database methods.""" + + @pytest.mark.asyncio + async def test_create_user_group(self, mock_cursor, mock_user_group): + """Test creating a user group.""" + # Mock the database response + mock_cursor.fetchone.return_value = (1,) + + # Call the function + group_id = await create_user_group(mock_cursor, mock_user_group) + + # Check the result + assert group_id == 1 + + # Verify the SQL was executed with the correct parameters + mock_cursor.execute.assert_called_once() + args, kwargs = mock_cursor.execute.call_args + assert "INSERT INTO user_groups" in args[0] + assert kwargs["name"] == mock_user_group.name + + @pytest.mark.asyncio + async def test_read_user_group(self, mock_cursor, mock_user_group): + """Test reading a user group.""" + # Mock the database response + mock_cursor.fetchone.return_value = (1, "Test Group", ["user1@undp.org", "user2@undp.org"]) + + # Call the function + group = await read_user_group(mock_cursor, 1) + + # Check the result + assert group.id == mock_user_group.id + assert group.name == mock_user_group.name + assert group.users == mock_user_group.users + + # Verify the SQL was executed with the correct parameters + mock_cursor.execute.assert_called_once() + args, kwargs = mock_cursor.execute.call_args + assert "FROM user_groups" in args[0] + assert args[1] == (1,) + + @pytest.mark.asyncio + async def test_read_user_group_not_found(self, mock_cursor): + """Test reading a non-existent user group.""" + # Mock the database response + mock_cursor.fetchone.return_value = None + + # Call the function + group = await read_user_group(mock_cursor, 99) + + # Check the result + assert group is None + + # Verify the SQL was executed with the correct parameters + mock_cursor.execute.assert_called_once() + args, kwargs = mock_cursor.execute.call_args + assert "FROM user_groups" in args[0] + assert args[1] == (99,) + + @pytest.mark.asyncio + async def test_update_user_group(self, mock_cursor, mock_user_group): + """Test updating a user group.""" + # Mock the database response + mock_cursor.fetchone.return_value = (1,) + + # Call the function + group_id = await update_user_group(mock_cursor, mock_user_group) + + # Check the result + assert group_id == 1 + + # Verify the SQL was executed with the correct parameters + assert mock_cursor.execute.call_count >= 2 # Update + delete existing members + + # Check the update query + args, kwargs = mock_cursor.execute.call_args_list[0] + assert "UPDATE user_groups" in args[0] + assert kwargs["id"] == mock_user_group.id + assert kwargs["name"] == mock_user_group.name + + @pytest.mark.asyncio + async def test_delete_user_group(self, mock_cursor): + """Test deleting a user group.""" + # Mock the database response + mock_cursor.fetchone.return_value = (1,) + + # Call the function + result = await delete_user_group(mock_cursor, 1) + + # Check the result + assert result is True + + # Verify the SQL was executed with the correct parameters + assert mock_cursor.execute.call_count == 2 # Delete members + delete group + + # Check the delete query + args, kwargs = mock_cursor.execute.call_args_list[1] + assert "DELETE FROM user_groups" in args[0] + assert args[1] == (1,) + + @pytest.mark.asyncio + async def test_list_user_groups(self, mock_cursor, mock_user_group): + """Test listing user groups.""" + # Mock the database cursor behavior + mock_row = (1, "Test Group", ["user1@undp.org", "user2@undp.org"]) + + # We need to make the cursor iterable to simulate async for loop + mock_cursor.__aiter__.return_value = [mock_row] + + # Call the function + groups = await list_user_groups(mock_cursor) + + # Check the result + assert len(groups) == 1 + assert groups[0].id == mock_user_group.id + assert groups[0].name == mock_user_group.name + assert groups[0].users == mock_user_group.users + + # Verify the SQL was executed + mock_cursor.execute.assert_called_once() + args, kwargs = mock_cursor.execute.call_args + assert "FROM user_groups" in args[0] + + @pytest.mark.asyncio + async def test_add_user_to_group(self, mock_cursor): + """Test adding a user to a group.""" + # Mock the database responses for each query + mock_cursor.fetchone.side_effect = [(1,), (1,), (1,)] + + # Call the function + result = await add_user_to_group(mock_cursor, 1, "user3@undp.org") + + # Check the result + assert result is True + + # Verify the SQL was executed with the correct parameters + assert mock_cursor.execute.call_count == 3 # Check user + check group + add + + # Check the insert query + args, kwargs = mock_cursor.execute.call_args_list[2] + assert "INSERT INTO user_group_members" in args[0] + assert args[1] == (1, "user3@undp.org") + + @pytest.mark.asyncio + async def test_remove_user_from_group(self, mock_cursor): + """Test removing a user from a group.""" + # Mock the database response + mock_cursor.fetchone.return_value = (1,) + + # Call the function + result = await remove_user_from_group(mock_cursor, 1, "user1@undp.org") + + # Check the result + assert result is True + + # Verify the SQL was executed with the correct parameters + mock_cursor.execute.assert_called_once() + args, kwargs = mock_cursor.execute.call_args + assert "DELETE FROM user_group_members" in args[0] + assert args[1] == (1, "user1@undp.org") + + @pytest.mark.asyncio + async def test_get_user_groups(self, mock_cursor, mock_user_group): + """Test getting groups for a user.""" + # Mock the database cursor behavior + mock_row = (1, "Test Group", ["user1@undp.org", "user2@undp.org"]) + + # We need to make the cursor iterable to simulate async for loop + mock_cursor.__aiter__.return_value = [mock_row] + + # Call the function + groups = await get_user_groups(mock_cursor, "user1@undp.org") + + # Check the result + assert len(groups) == 1 + assert groups[0].id == mock_user_group.id + assert groups[0].name == mock_user_group.name + assert groups[0].users == mock_user_group.users + + # Verify the SQL was executed with the correct parameters + mock_cursor.execute.assert_called_once() + args, kwargs = mock_cursor.execute.call_args + assert "WHERE m.user_email = %s" in args[0] + assert args[1] == ("user1@undp.org",) + + @pytest.mark.asyncio + async def test_get_group_users(self, mock_cursor): + """Test getting users in a group.""" + # Mock the database cursor behavior + mock_cursor.__aiter__.return_value = [("user1@undp.org",), ("user2@undp.org",)] + + # Call the function + users = await get_group_users(mock_cursor, 1) + + # Check the result + assert len(users) == 2 + assert "user1@undp.org" in users + assert "user2@undp.org" in users + + # Verify the SQL was executed with the correct parameters + mock_cursor.execute.assert_called_once() + args, kwargs = mock_cursor.execute.call_args + assert "WHERE m.group_id = %s" in args[0] + \ No newline at end of file diff --git a/tests/test_user_groups.py b/tests/test_user_groups.py new file mode 100644 index 0000000..c8f4306 --- /dev/null +++ b/tests/test_user_groups.py @@ -0,0 +1,305 @@ +""" +Tests for user group operations and collaborator functionality. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest_asyncio +from fastapi import FastAPI +from httpx import AsyncClient + +from src.entities import UserGroup, User, Role, Signal +from src.app import create_application + + +@pytest.fixture +def app(): + """Create a test application.""" + return create_application(configure_oauth=False, debug=True) + + +@pytest_asyncio.fixture +async def client(app: FastAPI): + """Create a test client.""" + async with AsyncClient(app=app, base_url="http://test") as client: + yield client + + +@pytest.fixture +def mock_admin_user(): + """Create a mock admin user.""" + return User( + id=1, + email="admin@undp.org", + role=Role.ADMIN, + name="Admin User", + unit="BPPS", + ) + + +@pytest.fixture +def mock_regular_user(): + """Create a mock regular user.""" + return User( + id=2, + email="user@undp.org", + role=Role.USER, + name="Regular User", + unit="BPPS", + ) + + +@pytest.fixture +def mock_user_group(): + """Create a mock user group.""" + return UserGroup( + id=1, + name="Test Group", + users=["user1@undp.org", "user2@undp.org"], + ) + + +@pytest.fixture +def mock_signal(): + """Create a mock signal with collaborators.""" + return Signal( + id=1, + headline="Test Signal", + description="Test Description", + created_by="user@undp.org", + is_draft=True, + collaborators=["collaborator@undp.org", "group:1"], + ) + + +class TestUserGroups: + """Tests for user group operations.""" + + @pytest.mark.asyncio + @patch("src.database.user_groups.list_user_groups") + @patch("src.authentication.authenticate_user") + async def test_list_user_groups( + self, mock_auth, mock_list_groups, client, mock_admin_user, mock_user_group + ): + """Test listing user groups.""" + mock_auth.return_value = mock_admin_user + mock_list_groups.return_value = [mock_user_group] + + response = await client.get("/api/user-groups") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["id"] == mock_user_group.id + assert data[0]["name"] == mock_user_group.name + assert data[0]["users"] == mock_user_group.users + + @pytest.mark.asyncio + @patch("src.database.user_groups.create_user_group") + @patch("src.database.user_groups.read_user_group") + @patch("src.authentication.authenticate_user") + async def test_create_user_group( + self, mock_auth, mock_read_group, mock_create_group, client, mock_admin_user, mock_user_group + ): + """Test creating a user group.""" + mock_auth.return_value = mock_admin_user + mock_create_group.return_value = 1 + mock_read_group.return_value = mock_user_group + + response = await client.post( + "/api/user-groups", + json={"name": "Test Group", "users": ["user1@undp.org", "user2@undp.org"]}, + ) + assert response.status_code == 200 + data = response.json() + assert data["id"] == mock_user_group.id + assert data["name"] == mock_user_group.name + assert data["users"] == mock_user_group.users + + @pytest.mark.asyncio + @patch("src.database.user_groups.read_user_group") + @patch("src.authentication.authenticate_user") + async def test_read_user_group( + self, mock_auth, mock_read_group, client, mock_admin_user, mock_user_group + ): + """Test reading a user group.""" + mock_auth.return_value = mock_admin_user + mock_read_group.return_value = mock_user_group + + response = await client.get("/api/user-groups/1") + assert response.status_code == 200 + data = response.json() + assert data["id"] == mock_user_group.id + assert data["name"] == mock_user_group.name + assert data["users"] == mock_user_group.users + + @pytest.mark.asyncio + @patch("src.database.user_groups.update_user_group") + @patch("src.database.user_groups.read_user_group") + @patch("src.authentication.authenticate_user") + async def test_update_user_group( + self, mock_auth, mock_read_group, mock_update_group, client, mock_admin_user, mock_user_group + ): + """Test updating a user group.""" + mock_auth.return_value = mock_admin_user + mock_update_group.return_value = 1 + mock_read_group.return_value = mock_user_group + + response = await client.put( + "/api/user-groups/1", + json={"id": 1, "name": "Updated Group", "users": ["user1@undp.org"]}, + ) + assert response.status_code == 200 + data = response.json() + assert data["id"] == mock_user_group.id + assert data["name"] == mock_user_group.name + assert data["users"] == mock_user_group.users + + @pytest.mark.asyncio + @patch("src.database.user_groups.delete_user_group") + @patch("src.authentication.authenticate_user") + async def test_delete_user_group( + self, mock_auth, mock_delete_group, client, mock_admin_user + ): + """Test deleting a user group.""" + mock_auth.return_value = mock_admin_user + mock_delete_group.return_value = True + + response = await client.delete("/api/user-groups/1") + assert response.status_code == 200 + data = response.json() + assert data is True + + @pytest.mark.asyncio + @patch("src.database.user_groups.add_user_to_group") + @patch("src.authentication.authenticate_user") + async def test_add_user_to_group( + self, mock_auth, mock_add_user, client, mock_admin_user + ): + """Test adding a user to a group.""" + mock_auth.return_value = mock_admin_user + mock_add_user.return_value = True + + response = await client.post("/api/user-groups/1/users/user3@undp.org") + assert response.status_code == 200 + data = response.json() + assert data is True + + @pytest.mark.asyncio + @patch("src.database.user_groups.remove_user_from_group") + @patch("src.authentication.authenticate_user") + async def test_remove_user_from_group( + self, mock_auth, mock_remove_user, client, mock_admin_user + ): + """Test removing a user from a group.""" + mock_auth.return_value = mock_admin_user + mock_remove_user.return_value = True + + response = await client.delete("/api/user-groups/1/users/user1@undp.org") + assert response.status_code == 200 + data = response.json() + assert data is True + + +class TestSignalCollaborators: + """Tests for signal collaborator operations.""" + + @pytest.mark.asyncio + @patch("src.database.signals.read_signal") + @patch("src.database.signals.can_user_edit_signal") + @patch("src.database.signals.get_signal_collaborators") + @patch("src.authentication.authenticate_user") + async def test_get_signal_collaborators( + self, mock_auth, mock_get_collaborators, mock_can_edit, mock_read_signal, + client, mock_regular_user, mock_signal + ): + """Test getting signal collaborators.""" + mock_auth.return_value = mock_regular_user + mock_read_signal.return_value = mock_signal + mock_can_edit.return_value = True + mock_get_collaborators.return_value = ["collaborator@undp.org", "group:1"] + + response = await client.get("/api/signals/1/collaborators") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + assert len(data) == 2 + assert "collaborator@undp.org" in data + assert "group:1" in data + + @pytest.mark.asyncio + @patch("src.database.signals.read_signal") + @patch("src.database.signals.add_collaborator") + @patch("src.authentication.authenticate_user") + async def test_add_signal_collaborator( + self, mock_auth, mock_add_collaborator, mock_read_signal, + client, mock_regular_user, mock_signal + ): + """Test adding a collaborator to a signal.""" + mock_auth.return_value = mock_regular_user + mock_read_signal.return_value = mock_signal + mock_add_collaborator.return_value = True + + # Test adding a user collaborator + response = await client.post("/api/signals/1/collaborators/new_user@undp.org") + assert response.status_code == 200 + data = response.json() + assert data is True + + # Test adding a group collaborator + response = await client.post("/api/signals/1/collaborators/group:2") + assert response.status_code == 200 + data = response.json() + assert data is True + + @pytest.mark.asyncio + @patch("src.database.signals.read_signal") + @patch("src.database.signals.remove_collaborator") + @patch("src.authentication.authenticate_user") + async def test_remove_signal_collaborator( + self, mock_auth, mock_remove_collaborator, mock_read_signal, + client, mock_regular_user, mock_signal + ): + """Test removing a collaborator from a signal.""" + mock_auth.return_value = mock_regular_user + mock_read_signal.return_value = mock_signal + mock_remove_collaborator.return_value = True + + # Test removing a user collaborator + response = await client.delete("/api/signals/1/collaborators/collaborator@undp.org") + assert response.status_code == 200 + data = response.json() + assert data is True + + # Test removing a group collaborator + response = await client.delete("/api/signals/1/collaborators/group:1") + assert response.status_code == 200 + data = response.json() + assert data is True + + @pytest.mark.asyncio + @patch("src.database.signals.read_signal") + @patch("src.database.signals.can_user_edit_signal") + @patch("src.authentication.authenticate_user") + async def test_can_user_edit_signal( + self, mock_auth, mock_can_edit, mock_read_signal, + client, mock_regular_user + ): + """Test checking if a user can edit a signal.""" + mock_auth.return_value = mock_regular_user + mock_read_signal.return_value = mock_signal + + # Test when user can edit + mock_can_edit.return_value = True + response = await client.get("/api/signals/1/can-edit") + assert response.status_code == 200 + data = response.json() + assert data is True + + # Test when user cannot edit + mock_can_edit.return_value = False + response = await client.get("/api/signals/1/can-edit") + assert response.status_code == 200 + data = response.json() + assert data is False From 85fd1932f29c76de359715bb564031efb7a9604f Mon Sep 17 00:00:00 2001 From: happy-devs Date: Wed, 30 Apr 2025 14:44:49 +0300 Subject: [PATCH 10/31] add endpoints to get auth user's groups + signals --- src/database/signals.py | 66 +++++++++++----------- src/database/user_groups.py | 108 +++++++++++++++++++++++++++++++++++- src/entities/user_groups.py | 35 +++++++++++- src/routers/user_groups.py | 50 ++++++++++++++++- 4 files changed, 219 insertions(+), 40 deletions(-) diff --git a/src/database/signals.py b/src/database/signals.py index a55bacb..efbab5f 100644 --- a/src/database/signals.py +++ b/src/database/signals.py @@ -530,61 +530,57 @@ async def get_signal_collaborators(cursor: AsyncCursor, signal_id: int) -> list[ return user_emails + group_ids -async def can_user_edit_signal(cursor: AsyncCursor, signal_id: int, user_email: str) -> bool: +async def can_user_edit_signal(cursor: AsyncCursor, signal_id: int, user_id: int) -> bool: """ Check if a user can edit a signal. - + A user can edit a signal if: 1. They created the signal - 2. They are in the collaborators list - 3. They are part of a group in the collaborators list - + 2. They are a direct collaborator for the signal + 3. They are part of a group that can collaborate on this signal + Parameters ---------- cursor : AsyncCursor An async database cursor. signal_id : int - The ID of the signal. - user_email : str - The email of the user. - + The ID of the signal to check. + user_id : int + The ID of the user to check. + Returns ------- bool True if the user can edit the signal, False otherwise. """ - # Check if the user created the signal - query1 = """ - SELECT 1 - FROM signals - WHERE id = %s AND created_by = %s - ; - """ - await cursor.execute(query1, (signal_id, user_email)) + # First, check if the user created the signal + from ..entities import User # Import here to avoid circular imports + + # Get user's email from ID + query = "SELECT email FROM users WHERE id = %s;" + await cursor.execute(query, (user_id,)) + row = await cursor.fetchone() + if row is None: + return False + + user_email = row[0] + + # Check if user created the signal + query = "SELECT 1 FROM signals WHERE id = %s AND created_by = %s;" + await cursor.execute(query, (signal_id, user_email)) if await cursor.fetchone() is not None: return True - # Check if the user is in the collaborators list - query2 = """ - SELECT 1 - FROM signal_collaborators - WHERE signal_id = %s AND user_email = %s - ; - """ - await cursor.execute(query2, (signal_id, user_email)) + # Check direct collaborators + query = "SELECT 1 FROM signal_collaborators WHERE signal_id = %s AND user_id = %s;" + await cursor.execute(query, (signal_id, user_id)) if await cursor.fetchone() is not None: return True - # Check if the user is part of a group in the collaborators list - query3 = """ - SELECT 1 - FROM signal_collaborator_groups scg - JOIN user_group_members ugm ON scg.group_id = ugm.group_id - WHERE scg.signal_id = %s AND ugm.user_email = %s - ; - """ - await cursor.execute(query3, (signal_id, user_email)) - if await cursor.fetchone() is not None: + # Check group collaborators + from . import user_groups # Import here to avoid circular imports + group_collaborators = await user_groups.get_signal_group_collaborators(cursor, signal_id) + if user_id in group_collaborators: return True return False diff --git a/src/database/user_groups.py b/src/database/user_groups.py index fa61837..21d7ecd 100644 --- a/src/database/user_groups.py +++ b/src/database/user_groups.py @@ -4,7 +4,7 @@ from psycopg import AsyncCursor -from ..entities import UserGroup +from ..entities import UserGroup, Signal __all__ = [ "create_user_group", @@ -16,6 +16,8 @@ "remove_user_from_group", "get_user_groups", "get_group_users", + "get_user_groups_with_signals", + "get_signal_group_collaborators", ] @@ -53,6 +55,8 @@ async def create_user_group(cursor: AsyncCursor, group: UserGroup) -> int: """ await cursor.execute(query, group.model_dump(exclude={"id"})) row = await cursor.fetchone() + if row is None: + raise ValueError("Failed to create user group") group_id = row[0] return group_id @@ -359,3 +363,105 @@ async def get_group_users(cursor: AsyncCursor, group_id: int) -> list[int]: row = await cursor.fetchone() return row[0] if row and row[0] else [] + + +async def get_user_groups_with_signals(cursor: AsyncCursor, user_id: int) -> list[dict]: + """ + Get all groups that a user is a member of, along with the associated signals data. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_id : int + The ID of the user. + + Returns + ------- + list[dict] + A list of dictionaries containing user group and signal data. + """ + # First get the groups the user belongs to + user_groups = await get_user_groups(cursor, user_id) + result = [] + + # For each group, fetch the signals data + for group in user_groups: + group_data = group.model_dump() + signals = [] + + # Get signals for this group + if group.signal_ids: + signals_query = """ + SELECT + s.*, + array_agg(c.trend_id) FILTER (WHERE c.trend_id IS NOT NULL) AS connected_trends + FROM + signals s + LEFT JOIN + connections c ON s.id = c.signal_id + WHERE + s.id = ANY(%s) + GROUP BY + s.id + ORDER BY + s.id + ; + """ + await cursor.execute(signals_query, (group.signal_ids,)) + + async for row in cursor: + signal_dict = dict(row) + # Check if user is a collaborator for this signal + can_edit = False + signal_id_str = str(signal_dict["id"]) + + if group.collaborator_map and signal_id_str in group.collaborator_map: + if user_id in group.collaborator_map[signal_id_str]: + can_edit = True + + signal_dict["can_edit"] = can_edit + signals.append(Signal(**signal_dict)) + + group_data["signals"] = signals + result.append(group_data) + + return result + + +async def get_signal_group_collaborators(cursor: AsyncCursor, signal_id: int) -> list[int]: + """ + Get all user IDs that can collaborate on a signal through group membership. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + signal_id : int + The ID of the signal. + + Returns + ------- + list[int] + A list of user IDs that can collaborate on the signal. + """ + signal_id_str = str(signal_id) + query = """ + SELECT + collaborator_map->%s as collaborators + FROM + user_groups + WHERE + %s = ANY(signal_ids) + AND collaborator_map ? %s + ; + """ + await cursor.execute(query, (signal_id_str, signal_id, signal_id_str)) + + collaborators = set() + async for row in cursor: + if row[0]: # Access first column using integer index + for user_id in row[0]: + collaborators.add(user_id) + + return list(collaborators) diff --git a/src/entities/user_groups.py b/src/entities/user_groups.py index 20d8f22..4599326 100644 --- a/src/entities/user_groups.py +++ b/src/entities/user_groups.py @@ -6,8 +6,9 @@ from pydantic import ConfigDict, Field from .base import BaseEntity +from .signal import Signal -__all__ = ["UserGroup"] +__all__ = ["UserGroup", "UserGroupWithSignals"] class UserGroup(BaseEntity): @@ -45,3 +46,35 @@ class UserGroup(BaseEntity): } } ) + + +class UserGroupWithSignals(UserGroup): + """User group with associated signals data.""" + + signals: List[Signal] = Field( + default_factory=list, + description="List of signals associated with this group." + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "id": 1, + "name": "CDO", + "signal_ids": [1, 2, 3], + "user_ids": [1, 2, 3], + "collaborator_map": { + "1": [1, 2], + "2": [1, 3], + "3": [2, 3] + }, + "signals": [ + { + "id": 1, + "headline": "Signal 1", + "can_edit": True + } + ] + } + } + ) diff --git a/src/routers/user_groups.py b/src/routers/user_groups.py index 7a2e00e..a5d9812 100644 --- a/src/routers/user_groups.py +++ b/src/routers/user_groups.py @@ -5,13 +5,14 @@ import logging from typing import Annotated, List -from fastapi import APIRouter, Depends, Path, Body +from fastapi import APIRouter, Depends, Path, Body, Query from psycopg import AsyncCursor from .. import database as db from .. import exceptions -from ..dependencies import require_admin -from ..entities import UserGroup +from ..dependencies import require_admin, require_user +from ..entities import UserGroup, User, UserGroupWithSignals +from ..authentication import authenticate_user router = APIRouter(prefix="/user-groups", tags=["user groups"]) @@ -25,6 +26,49 @@ async def list_user_groups( return groups +@router.get("/me", response_model=List[UserGroup]) +async def get_my_user_groups( + user: User = Depends(authenticate_user), + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """ + Get all user groups that the current user is a member of. + This endpoint is accessible to all authenticated users. + """ + if not user.id: + raise exceptions.not_found + + # Get groups this user is a member of + user_groups = await db.get_user_groups(cursor, user.id) + return user_groups + + +@router.get("/me/with-signals", response_model=List[UserGroupWithSignals]) +async def get_my_user_groups_with_signals( + user: User = Depends(authenticate_user), + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """ + Get all user groups that the current user is a member of along with their signals data. + + This enhanced endpoint provides detailed information about each signal associated with the groups, + including whether the current user has edit permissions for each signal. This is useful for: + + - Displaying a dashboard of all signals a user can access through their groups + - Showing which signals the user can edit vs. view-only + - Building collaborative workflows where users can see their assigned signals + + The response includes all signal details plus a `can_edit` flag for each signal indicating + if the current user has edit permissions based on the group's collaborator_map. + """ + if not user.id: + raise exceptions.not_found + + # Get groups with signals for this user + user_groups_with_signals = await db.get_user_groups_with_signals(cursor, user.id) + return user_groups_with_signals + + @router.post("", response_model=UserGroup, dependencies=[Depends(require_admin)]) async def create_user_group( group: UserGroup, From cc3104b184ab15cfa9e6d41d09fcf1524a6dc073 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Wed, 30 Apr 2025 15:17:26 +0300 Subject: [PATCH 11/31] add emails to user group routes --- src/routers/user_groups.py | 210 ++++++++++++++++++++++++++++++++++--- 1 file changed, 195 insertions(+), 15 deletions(-) diff --git a/src/routers/user_groups.py b/src/routers/user_groups.py index a5d9812..84b805c 100644 --- a/src/routers/user_groups.py +++ b/src/routers/user_groups.py @@ -3,10 +3,11 @@ """ import logging -from typing import Annotated, List +from typing import Annotated, List, Optional, Union -from fastapi import APIRouter, Depends, Path, Body, Query +from fastapi import APIRouter, Depends, Path, Body, Query, HTTPException from psycopg import AsyncCursor +from pydantic import BaseModel from .. import database as db from .. import exceptions @@ -17,6 +18,43 @@ router = APIRouter(prefix="/user-groups", tags=["user groups"]) +# Add models to support user emails in requests +class UserGroupCreate(BaseModel): + name: str + users: Optional[List[str]] = None + + +class UserEmailIdentifier(BaseModel): + email: str + + +# Helper function to get user ID from email or ID +async def get_user_id(cursor: AsyncCursor, user_identifier: Union[str, int]) -> Optional[int]: + """ + Get a user ID from either an email address or ID. + + Parameters + ---------- + cursor : AsyncCursor + Database cursor + user_identifier : Union[str, int] + Either a user email (string) or user ID (int) + + Returns + ------- + Optional[int] + User ID if found, None otherwise + """ + if isinstance(user_identifier, int): + # Check if user exists + user = await db.read_user(cursor, user_identifier) + return user.id if user else None + else: + # Try to find user by email + user = await db.read_user_by_email(cursor, user_identifier) + return user.id if user else None + + @router.get("", response_model=List[UserGroup], dependencies=[Depends(require_admin)]) async def list_user_groups( cursor: AsyncCursor = Depends(db.yield_cursor), @@ -71,10 +109,25 @@ async def get_my_user_groups_with_signals( @router.post("", response_model=UserGroup, dependencies=[Depends(require_admin)]) async def create_user_group( - group: UserGroup, + group_data: UserGroupCreate, cursor: AsyncCursor = Depends(db.yield_cursor), ): """Create a new user group.""" + # Create the base group + group = UserGroup(name=group_data.name) + + # Handle email addresses if provided + if group_data.users: + user_ids = [] + for email in group_data.users: + user = await db.read_user_by_email(cursor, email) + if user and user.id: + user_ids.append(user.id) + + if user_ids: + group.user_ids = user_ids + + # Create the group group_id = await db.create_user_group(cursor, group) return await db.read_user_group(cursor, group_id) @@ -115,27 +168,80 @@ async def delete_user_group( return True -@router.post("/{group_id}/users/{user_id}", response_model=bool, dependencies=[Depends(require_admin)]) +@router.post("/{group_id}/users", response_model=bool, dependencies=[Depends(require_admin)]) +async def add_user_to_group_by_email( + group_id: Annotated[int, Path(description="The ID of the user group")], + user_data: UserEmailIdentifier, + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """Add a user to a group by email address.""" + # Find user ID from email + user = await db.read_user_by_email(cursor, user_data.email) + if not user or not user.id: + raise HTTPException(status_code=404, detail=f"User with email {user_data.email} not found") + + # Add user to group + if not await db.add_user_to_group(cursor, group_id, user.id): + raise exceptions.not_found + + return True + + +@router.post("/{group_id}/users/{user_id_or_email}", response_model=bool, dependencies=[Depends(require_admin)]) async def add_user_to_group( group_id: Annotated[int, Path(description="The ID of the user group")], - user_id: Annotated[int, Path(description="The ID of the user to add")], + user_id_or_email: Annotated[str, Path(description="The ID or email of the user to add")], cursor: AsyncCursor = Depends(db.yield_cursor), ): - """Add a user to a group.""" + """ + Add a user to a group. + + This endpoint accepts either a numeric user ID or an email address. + If an email is provided, the system will look up the corresponding user ID. + """ + # Try to parse as int for backward compatibility + try: + user_id = int(user_id_or_email) + except ValueError: + # Not an integer, treat as email + user = await db.read_user_by_email(cursor, user_id_or_email) + if not user or not user.id: + raise HTTPException(status_code=404, detail=f"User with email {user_id_or_email} not found") + user_id = user.id + + # Add user to group if not await db.add_user_to_group(cursor, group_id, user_id): raise exceptions.not_found + return True -@router.delete("/{group_id}/users/{user_id}", response_model=bool, dependencies=[Depends(require_admin)]) +@router.delete("/{group_id}/users/{user_id_or_email}", response_model=bool, dependencies=[Depends(require_admin)]) async def remove_user_from_group( group_id: Annotated[int, Path(description="The ID of the user group")], - user_id: Annotated[int, Path(description="The ID of the user to remove")], + user_id_or_email: Annotated[str, Path(description="The ID or email of the user to remove")], cursor: AsyncCursor = Depends(db.yield_cursor), ): - """Remove a user from a group.""" + """ + Remove a user from a group. + + This endpoint accepts either a numeric user ID or an email address. + If an email is provided, the system will look up the corresponding user ID. + """ + # Try to parse as int for backward compatibility + try: + user_id = int(user_id_or_email) + except ValueError: + # Not an integer, treat as email + user = await db.read_user_by_email(cursor, user_id_or_email) + if not user or not user.id: + raise HTTPException(status_code=404, detail=f"User with email {user_id_or_email} not found") + user_id = user.id + + # Remove user from group if not await db.remove_user_from_group(cursor, group_id, user_id): raise exceptions.not_found + return True @@ -198,14 +304,73 @@ async def remove_signal_from_group( return True -@router.post("/{group_id}/signals/{signal_id}/collaborators/{user_id}", response_model=bool, dependencies=[Depends(require_admin)]) +@router.post("/{group_id}/signals/{signal_id}/collaborators", response_model=bool, dependencies=[Depends(require_admin)]) +async def add_collaborator_to_signal_by_email( + group_id: Annotated[int, Path(description="The ID of the user group")], + signal_id: Annotated[int, Path(description="The ID of the signal")], + user_data: UserEmailIdentifier, + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """Add a user as a collaborator for a specific signal in a group by email address.""" + # Find user ID from email + user = await db.read_user_by_email(cursor, user_data.email) + if not user or not user.id: + raise HTTPException(status_code=404, detail=f"User with email {user_data.email} not found") + + # Get the group + group = await db.read_user_group(cursor, group_id) + if group is None: + raise exceptions.not_found + + # Check if signal is in the group + signal_ids = group.signal_ids or [] + if signal_id not in signal_ids: + raise exceptions.not_found + + # Check if user is in the group + user_ids = group.user_ids or [] + if user.id not in user_ids: + raise exceptions.not_found + + # Add collaborator + collaborator_map = group.collaborator_map or {} + signal_key = str(signal_id) + if signal_key not in collaborator_map: + collaborator_map[signal_key] = [] + + if user.id not in collaborator_map[signal_key]: + collaborator_map[signal_key].append(user.id) + group.collaborator_map = collaborator_map + + if await db.update_user_group(cursor, group) is None: + raise exceptions.not_found + + return True + + +@router.post("/{group_id}/signals/{signal_id}/collaborators/{user_id_or_email}", response_model=bool, dependencies=[Depends(require_admin)]) async def add_collaborator_to_signal_in_group( group_id: Annotated[int, Path(description="The ID of the user group")], signal_id: Annotated[int, Path(description="The ID of the signal")], - user_id: Annotated[int, Path(description="The ID of the user to add as collaborator")], + user_id_or_email: Annotated[str, Path(description="The ID or email of the user to add as collaborator")], cursor: AsyncCursor = Depends(db.yield_cursor), ): - """Add a user as a collaborator for a specific signal in a group.""" + """ + Add a user as a collaborator for a specific signal in a group. + + This endpoint accepts either a numeric user ID or an email address. + If an email is provided, the system will look up the corresponding user ID. + """ + # Try to parse as int for backward compatibility + try: + user_id = int(user_id_or_email) + except ValueError: + # Not an integer, treat as email + user = await db.read_user_by_email(cursor, user_id_or_email) + if not user or not user.id: + raise HTTPException(status_code=404, detail=f"User with email {user_id_or_email} not found") + user_id = user.id + # Get the group group = await db.read_user_group(cursor, group_id) if group is None: @@ -237,14 +402,29 @@ async def add_collaborator_to_signal_in_group( return True -@router.delete("/{group_id}/signals/{signal_id}/collaborators/{user_id}", response_model=bool, dependencies=[Depends(require_admin)]) +@router.delete("/{group_id}/signals/{signal_id}/collaborators/{user_id_or_email}", response_model=bool, dependencies=[Depends(require_admin)]) async def remove_collaborator_from_signal_in_group( group_id: Annotated[int, Path(description="The ID of the user group")], signal_id: Annotated[int, Path(description="The ID of the signal")], - user_id: Annotated[int, Path(description="The ID of the user to remove as collaborator")], + user_id_or_email: Annotated[str, Path(description="The ID or email of the user to remove as collaborator")], cursor: AsyncCursor = Depends(db.yield_cursor), ): - """Remove a user as a collaborator for a specific signal in a group.""" + """ + Remove a user as a collaborator for a specific signal in a group. + + This endpoint accepts either a numeric user ID or an email address. + If an email is provided, the system will look up the corresponding user ID. + """ + # Try to parse as int for backward compatibility + try: + user_id = int(user_id_or_email) + except ValueError: + # Not an integer, treat as email + user = await db.read_user_by_email(cursor, user_id_or_email) + if not user or not user.id: + raise HTTPException(status_code=404, detail=f"User with email {user_id_or_email} not found") + user_id = user.id + # Get the group group = await db.read_user_group(cursor, group_id) if group is None: From e5d9e226e65acedb81ebff7311e79ff766e4201e Mon Sep 17 00:00:00 2001 From: happy-devs Date: Wed, 30 Apr 2025 15:35:27 +0300 Subject: [PATCH 12/31] configure application with bugsnag --- .gitignore | 1 + main.py | 75 ++++++- requirements.txt | 1 + src/bugsnag_config.py | 76 +++++++ src/routers/user_groups.py | 439 +++++++++++++++++++++++++++---------- test_bugsnag.py | 56 +++++ 6 files changed, 537 insertions(+), 111 deletions(-) create mode 100644 src/bugsnag_config.py create mode 100644 test_bugsnag.py diff --git a/.gitignore b/.gitignore index d42f3dc..db08835 100644 --- a/.gitignore +++ b/.gitignore @@ -145,3 +145,4 @@ create_test_user.sql /.prs Taskfile.yml .env.local +/.logs diff --git a/main.py b/main.py index 39e3ba0..810904e 100644 --- a/main.py +++ b/main.py @@ -3,17 +3,32 @@ the frontend platform with the backend database. """ +import os +import logging +import datetime from dotenv import load_dotenv -from fastapi import Depends, FastAPI +from fastapi import Depends, FastAPI, Request from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse from src import routers from src.authentication import authenticate_user from src.config.logging_config import setup_logging +from src.bugsnag_config import configure_bugsnag, setup_bugsnag_logging, get_bugsnag_middleware, BUGSNAG_ENABLED +# Load environment variables and set up logging load_dotenv() setup_logging() +# Get application version +app_version = os.environ.get("RELEASE_VERSION", "dev") +app_env = os.environ.get("ENVIRONMENT", "development") +logging.info(f"Starting application - version: {app_version}, environment: {app_env}") + +# Configure Bugsnag for error tracking +configure_bugsnag() +setup_bugsnag_logging() + app = FastAPI( debug=False, title="Future Trends and Signals API", @@ -60,6 +75,64 @@ allow_headers=["*"], ) +# Add Bugsnag exception handling middleware +app = get_bugsnag_middleware(app) + +# Add global exception handler to report errors to Bugsnag +@app.exception_handler(Exception) +async def global_exception_handler(request: Request, exc: Exception): + logging.error(f"Unhandled exception: {str(exc)}", exc_info=True) + + if BUGSNAG_ENABLED: + import bugsnag + bugsnag.notify( + exc, + metadata={ + "request": { + "url": str(request.url), + "method": request.method, + "headers": dict(request.headers), + "client": request.client.host if request.client else None, + } + } + ) + + return JSONResponse( + status_code=500, + content={"detail": "Internal server error"}, + ) + for router in routers.ALL: app.include_router(router=router, dependencies=[Depends(authenticate_user)]) + +# Add diagnostic endpoint for health checks and Bugsnag verification +@app.get("/_health", include_in_schema=False) +async def health_check(): + """Health check endpoint that also shows the current environment and version.""" + return { + "status": "ok", + "environment": app_env, + "version": app_version, + "bugsnag_enabled": BUGSNAG_ENABLED + } + +# Test endpoint to trigger a test error report to Bugsnag if enabled +@app.get("/_test-error", include_in_schema=False) +async def test_error(): + """Trigger a test error to verify Bugsnag is working.""" + if BUGSNAG_ENABLED: + import bugsnag + bugsnag.notify( + Exception("Test error triggered via /_test-error endpoint"), + metadata={ + "test_info": { + "environment": app_env, + "version": app_version, + "timestamp": str(datetime.datetime.now()) + } + } + ) + return {"status": "error_reported", "message": "Test error sent to Bugsnag"} + else: + return {"status": "disabled", "message": "Bugsnag is not enabled"} diff --git a/requirements.txt b/requirements.txt index aa12ec0..f3145fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ pillow ~= 11.0.0 beautifulsoup4 ~= 4.12.3 lxml ~= 5.3.0 openai == 1.52.2 +bugsnag>=4.0.0 diff --git a/src/bugsnag_config.py b/src/bugsnag_config.py new file mode 100644 index 0000000..e7a32d6 --- /dev/null +++ b/src/bugsnag_config.py @@ -0,0 +1,76 @@ +""" +Bugsnag configuration module for error tracking. +""" + +import os +import logging +import bugsnag +from bugsnag.handlers import BugsnagHandler +from bugsnag.asgi import BugsnagMiddleware + +# Get API key from environment variable with fallback to the provided key +BUGSNAG_API_KEY = os.environ.get("BUGSNAG_API_KEY") +ENVIRONMENT = os.environ.get("ENVIRONMENT", "development") +RELEASE_VERSION = os.environ.get("RELEASE_VERSION", "dev") + +if not BUGSNAG_API_KEY: + logging.warning("BUGSNAG_API_KEY is not set - error reporting will be disabled") + BUGSNAG_ENABLED = False +else: + BUGSNAG_ENABLED = True + +def configure_bugsnag(): + """Configure Bugsnag for error tracking.""" + if not BUGSNAG_ENABLED: + logging.warning("Bugsnag is disabled - skipping configuration") + return + + bugsnag.configure( + api_key=BUGSNAG_API_KEY, + project_root=os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + release_stage=ENVIRONMENT, + app_version=RELEASE_VERSION, + notify_release_stages=["production", "staging", "development"], + app_type="fastapi", + ) + + logging.info(f"Bugsnag configured - environment: {ENVIRONMENT}, version: {RELEASE_VERSION}") + +def setup_bugsnag_logging(level=logging.ERROR): + """ + Set up Bugsnag to capture logs at specified level. + + Parameters + ---------- + level : int + Minimum log level to send to Bugsnag + """ + if not BUGSNAG_ENABLED: + return + + logger = logging.getLogger() + handler = BugsnagHandler() + handler.setLevel(level) + logger.addHandler(handler) + + logging.info(f"Bugsnag logging handler added at level {level}") + +def get_bugsnag_middleware(app): + """ + Wrap an ASGI app with Bugsnag middleware. + + Parameters + ---------- + app : ASGI application + The FastAPI application instance + + Returns + ------- + ASGI application + The application wrapped with Bugsnag middleware + """ + if not BUGSNAG_ENABLED: + logging.warning("Bugsnag middleware not added - Bugsnag is disabled") + return app + + return BugsnagMiddleware(app) \ No newline at end of file diff --git a/src/routers/user_groups.py b/src/routers/user_groups.py index 84b805c..ab27d7f 100644 --- a/src/routers/user_groups.py +++ b/src/routers/user_groups.py @@ -3,9 +3,10 @@ """ import logging +import bugsnag from typing import Annotated, List, Optional, Union -from fastapi import APIRouter, Depends, Path, Body, Query, HTTPException +from fastapi import APIRouter, Depends, Path, Body, Query, HTTPException, Request from psycopg import AsyncCursor from pydantic import BaseModel @@ -15,6 +16,9 @@ from ..entities import UserGroup, User, UserGroupWithSignals from ..authentication import authenticate_user +# Set up logger for this module +logger = logging.getLogger(__name__) + router = APIRouter(prefix="/user-groups", tags=["user groups"]) @@ -45,27 +49,54 @@ async def get_user_id(cursor: AsyncCursor, user_identifier: Union[str, int]) -> Optional[int] User ID if found, None otherwise """ - if isinstance(user_identifier, int): - # Check if user exists - user = await db.read_user(cursor, user_identifier) - return user.id if user else None - else: - # Try to find user by email - user = await db.read_user_by_email(cursor, user_identifier) - return user.id if user else None + try: + if isinstance(user_identifier, int): + # Check if user exists + user = await db.read_user(cursor, user_identifier) + return user.id if user else None + else: + # Try to find user by email + user = await db.read_user_by_email(cursor, user_identifier) + return user.id if user else None + except Exception as e: + logger.error(f"Error in get_user_id: {str(e)}") + bugsnag.notify( + e, + metadata={ + "user_identifier": str(user_identifier), + "type": type(user_identifier).__name__ + } + ) + return None @router.get("", response_model=List[UserGroup], dependencies=[Depends(require_admin)]) async def list_user_groups( + request: Request, cursor: AsyncCursor = Depends(db.yield_cursor), ): """List all user groups.""" - groups = await db.list_user_groups(cursor) - return groups + try: + groups = await db.list_user_groups(cursor) + logger.info(f"Listed {len(groups)} user groups") + return groups + except Exception as e: + logger.error(f"Error listing user groups: {str(e)}") + bugsnag.notify( + e, + metadata={ + "request": { + "url": str(request.url), + "method": request.method, + } + } + ) + raise @router.get("/me", response_model=List[UserGroup]) async def get_my_user_groups( + request: Request, user: User = Depends(authenticate_user), cursor: AsyncCursor = Depends(db.yield_cursor), ): @@ -73,16 +104,37 @@ async def get_my_user_groups( Get all user groups that the current user is a member of. This endpoint is accessible to all authenticated users. """ - if not user.id: - raise exceptions.not_found - - # Get groups this user is a member of - user_groups = await db.get_user_groups(cursor, user.id) - return user_groups + try: + if not user.id: + logger.warning("User ID not found in get_my_user_groups") + raise exceptions.not_found + + # Get groups this user is a member of + user_groups = await db.get_user_groups(cursor, user.id) + logger.info(f"User {user.id} retrieved {len(user_groups)} groups") + return user_groups + except Exception as e: + if not isinstance(e, HTTPException): # Don't log HTTPExceptions + logger.error(f"Error in get_my_user_groups: {str(e)}") + bugsnag.notify( + e, + metadata={ + "request": { + "url": str(request.url), + "method": request.method, + }, + "user": { + "id": user.id if user else None, + "email": user.email if user else None + } + } + ) + raise @router.get("/me/with-signals", response_model=List[UserGroupWithSignals]) async def get_my_user_groups_with_signals( + request: Request, user: User = Depends(authenticate_user), cursor: AsyncCursor = Depends(db.yield_cursor), ): @@ -99,96 +151,217 @@ async def get_my_user_groups_with_signals( The response includes all signal details plus a `can_edit` flag for each signal indicating if the current user has edit permissions based on the group's collaborator_map. """ - if not user.id: - raise exceptions.not_found - - # Get groups with signals for this user - user_groups_with_signals = await db.get_user_groups_with_signals(cursor, user.id) - return user_groups_with_signals + try: + if not user.id: + logger.warning("User ID not found in get_my_user_groups_with_signals") + raise exceptions.not_found + + # Get groups with signals for this user + user_groups_with_signals = await db.get_user_groups_with_signals(cursor, user.id) + logger.info(f"User {user.id} retrieved {len(user_groups_with_signals)} groups with signals") + return user_groups_with_signals + except Exception as e: + if not isinstance(e, HTTPException): # Don't log HTTPExceptions + logger.error(f"Error in get_my_user_groups_with_signals: {str(e)}") + bugsnag.notify( + e, + metadata={ + "request": { + "url": str(request.url), + "method": request.method, + }, + "user": { + "id": user.id if user else None, + "email": user.email if user else None + } + } + ) + raise @router.post("", response_model=UserGroup, dependencies=[Depends(require_admin)]) async def create_user_group( + request: Request, group_data: UserGroupCreate, cursor: AsyncCursor = Depends(db.yield_cursor), ): """Create a new user group.""" - # Create the base group - group = UserGroup(name=group_data.name) - - # Handle email addresses if provided - if group_data.users: - user_ids = [] - for email in group_data.users: - user = await db.read_user_by_email(cursor, email) - if user and user.id: - user_ids.append(user.id) + try: + # Create the base group + group = UserGroup(name=group_data.name) - if user_ids: - group.user_ids = user_ids - - # Create the group - group_id = await db.create_user_group(cursor, group) - return await db.read_user_group(cursor, group_id) + # Handle email addresses if provided + if group_data.users: + user_ids = [] + for email in group_data.users: + user = await db.read_user_by_email(cursor, email) + if user and user.id: + user_ids.append(user.id) + + if user_ids: + group.user_ids = user_ids + + # Create the group + group_id = await db.create_user_group(cursor, group) + logger.info(f"Created user group {group_id} with name '{group.name}'") + return await db.read_user_group(cursor, group_id) + except Exception as e: + logger.error(f"Error creating user group: {str(e)}") + bugsnag.notify( + e, + metadata={ + "request": { + "url": str(request.url), + "method": request.method, + }, + "group_data": { + "name": group_data.name, + "users_count": len(group_data.users) if group_data.users else 0 + } + } + ) + raise @router.get("/{group_id}", response_model=UserGroup, dependencies=[Depends(require_admin)]) async def read_user_group( + request: Request, group_id: Annotated[int, Path(description="The ID of the user group to retrieve")], cursor: AsyncCursor = Depends(db.yield_cursor), ): """Get a user group by ID.""" - if (group := await db.read_user_group(cursor, group_id)) is None: - raise exceptions.not_found - return group + try: + if (group := await db.read_user_group(cursor, group_id)) is None: + logger.warning(f"User group {group_id} not found") + raise exceptions.not_found + logger.info(f"Retrieved user group {group_id}") + return group + except Exception as e: + if not isinstance(e, HTTPException): # Don't log HTTPExceptions + logger.error(f"Error reading user group {group_id}: {str(e)}") + bugsnag.notify( + e, + metadata={ + "request": { + "url": str(request.url), + "method": request.method, + }, + "group_id": group_id + } + ) + raise @router.put("/{group_id}", response_model=UserGroup, dependencies=[Depends(require_admin)]) async def update_user_group( + request: Request, group_id: Annotated[int, Path(description="The ID of the user group to update")], group: UserGroup, cursor: AsyncCursor = Depends(db.yield_cursor), ): """Update a user group.""" - if group_id != group.id: - raise exceptions.id_mismatch - if (updated_id := await db.update_user_group(cursor, group)) is None: - raise exceptions.not_found - return await db.read_user_group(cursor, updated_id) + try: + if group_id != group.id: + logger.warning(f"ID mismatch: path ID {group_id} != body ID {group.id}") + raise exceptions.id_mismatch + if (updated_id := await db.update_user_group(cursor, group)) is None: + logger.warning(f"User group {group_id} not found for update") + raise exceptions.not_found + logger.info(f"Updated user group {updated_id}") + return await db.read_user_group(cursor, updated_id) + except Exception as e: + if not isinstance(e, HTTPException): # Don't log HTTPExceptions + logger.error(f"Error updating user group {group_id}: {str(e)}") + bugsnag.notify( + e, + metadata={ + "request": { + "url": str(request.url), + "method": request.method, + }, + "group_id": group_id, + "group_data": { + "id": group.id, + "name": group.name, + "user_count": len(group.user_ids) if group.user_ids else 0, + "signal_count": len(group.signal_ids) if group.signal_ids else 0 + } + } + ) + raise @router.delete("/{group_id}", response_model=bool, dependencies=[Depends(require_admin)]) async def delete_user_group( + request: Request, group_id: Annotated[int, Path(description="The ID of the user group to delete")], cursor: AsyncCursor = Depends(db.yield_cursor), ): """Delete a user group.""" - if not await db.delete_user_group(cursor, group_id): - raise exceptions.not_found - return True + try: + if not await db.delete_user_group(cursor, group_id): + logger.warning(f"User group {group_id} not found for deletion") + raise exceptions.not_found + logger.info(f"Deleted user group {group_id}") + return True + except Exception as e: + if not isinstance(e, HTTPException): # Don't log HTTPExceptions + logger.error(f"Error deleting user group {group_id}: {str(e)}") + bugsnag.notify( + e, + metadata={ + "request": { + "url": str(request.url), + "method": request.method, + }, + "group_id": group_id + } + ) + raise @router.post("/{group_id}/users", response_model=bool, dependencies=[Depends(require_admin)]) async def add_user_to_group_by_email( + request: Request, group_id: Annotated[int, Path(description="The ID of the user group")], user_data: UserEmailIdentifier, cursor: AsyncCursor = Depends(db.yield_cursor), ): """Add a user to a group by email address.""" - # Find user ID from email - user = await db.read_user_by_email(cursor, user_data.email) - if not user or not user.id: - raise HTTPException(status_code=404, detail=f"User with email {user_data.email} not found") - - # Add user to group - if not await db.add_user_to_group(cursor, group_id, user.id): - raise exceptions.not_found - - return True + try: + # Find user ID from email + user = await db.read_user_by_email(cursor, user_data.email) + if not user or not user.id: + logger.warning(f"User with email {user_data.email} not found") + raise HTTPException(status_code=404, detail=f"User with email {user_data.email} not found") + + # Add user to group + if not await db.add_user_to_group(cursor, group_id, user.id): + logger.warning(f"Group {group_id} not found when adding user {user.id}") + raise exceptions.not_found + + logger.info(f"Added user {user.id} ({user.email}) to group {group_id}") + return True + except Exception as e: + if not isinstance(e, HTTPException): # Don't log HTTPExceptions + logger.error(f"Error adding user to group {group_id}: {str(e)}") + bugsnag.notify( + e, + metadata={ + "request": { + "url": str(request.url), + "method": request.method, + }, + "group_id": group_id, + "user_email": user_data.email + } + ) + raise @router.post("/{group_id}/users/{user_id_or_email}", response_model=bool, dependencies=[Depends(require_admin)]) async def add_user_to_group( + request: Request, group_id: Annotated[int, Path(description="The ID of the user group")], user_id_or_email: Annotated[str, Path(description="The ID or email of the user to add")], cursor: AsyncCursor = Depends(db.yield_cursor), @@ -199,21 +372,40 @@ async def add_user_to_group( This endpoint accepts either a numeric user ID or an email address. If an email is provided, the system will look up the corresponding user ID. """ - # Try to parse as int for backward compatibility try: - user_id = int(user_id_or_email) - except ValueError: - # Not an integer, treat as email - user = await db.read_user_by_email(cursor, user_id_or_email) - if not user or not user.id: - raise HTTPException(status_code=404, detail=f"User with email {user_id_or_email} not found") - user_id = user.id - - # Add user to group - if not await db.add_user_to_group(cursor, group_id, user_id): - raise exceptions.not_found - - return True + # Try to parse as int for backward compatibility + try: + user_id = int(user_id_or_email) + except ValueError: + # Not an integer, treat as email + user = await db.read_user_by_email(cursor, user_id_or_email) + if not user or not user.id: + logger.warning(f"User with email {user_id_or_email} not found") + raise HTTPException(status_code=404, detail=f"User with email {user_id_or_email} not found") + user_id = user.id + + # Add user to group + if not await db.add_user_to_group(cursor, group_id, user_id): + logger.warning(f"Group {group_id} not found when adding user {user_id}") + raise exceptions.not_found + + logger.info(f"Added user {user_id} to group {group_id}") + return True + except Exception as e: + if not isinstance(e, HTTPException): # Don't log HTTPExceptions + logger.error(f"Error adding user to group: {str(e)}") + bugsnag.notify( + e, + metadata={ + "request": { + "url": str(request.url), + "method": request.method, + }, + "group_id": group_id, + "user_id_or_email": user_id_or_email + } + ) + raise @router.delete("/{group_id}/users/{user_id_or_email}", response_model=bool, dependencies=[Depends(require_admin)]) @@ -350,6 +542,7 @@ async def add_collaborator_to_signal_by_email( @router.post("/{group_id}/signals/{signal_id}/collaborators/{user_id_or_email}", response_model=bool, dependencies=[Depends(require_admin)]) async def add_collaborator_to_signal_in_group( + request: Request, group_id: Annotated[int, Path(description="The ID of the user group")], signal_id: Annotated[int, Path(description="The ID of the signal")], user_id_or_email: Annotated[str, Path(description="The ID or email of the user to add as collaborator")], @@ -361,45 +554,71 @@ async def add_collaborator_to_signal_in_group( This endpoint accepts either a numeric user ID or an email address. If an email is provided, the system will look up the corresponding user ID. """ - # Try to parse as int for backward compatibility try: - user_id = int(user_id_or_email) - except ValueError: - # Not an integer, treat as email - user = await db.read_user_by_email(cursor, user_id_or_email) - if not user or not user.id: - raise HTTPException(status_code=404, detail=f"User with email {user_id_or_email} not found") - user_id = user.id - - # Get the group - group = await db.read_user_group(cursor, group_id) - if group is None: - raise exceptions.not_found - - # Check if signal is in the group - signal_ids = group.signal_ids or [] - if signal_id not in signal_ids: - raise exceptions.not_found - - # Check if user is in the group - user_ids = group.user_ids or [] - if user_id not in user_ids: - raise exceptions.not_found - - # Add collaborator - collaborator_map = group.collaborator_map or {} - signal_key = str(signal_id) - if signal_key not in collaborator_map: - collaborator_map[signal_key] = [] - - if user_id not in collaborator_map[signal_key]: - collaborator_map[signal_key].append(user_id) - group.collaborator_map = collaborator_map + # Try to parse as int for backward compatibility + try: + user_id = int(user_id_or_email) + except ValueError: + # Not an integer, treat as email + user = await db.read_user_by_email(cursor, user_id_or_email) + if not user or not user.id: + logger.warning(f"User with email {user_id_or_email} not found") + raise HTTPException(status_code=404, detail=f"User with email {user_id_or_email} not found") + user_id = user.id - if await db.update_user_group(cursor, group) is None: + # Get the group + group = await db.read_user_group(cursor, group_id) + if group is None: + logger.warning(f"Group {group_id} not found") raise exceptions.not_found - - return True + + # Check if signal is in the group + signal_ids = group.signal_ids or [] + if signal_id not in signal_ids: + logger.warning(f"Signal {signal_id} not in group {group_id}") + raise exceptions.not_found + + # Check if user is in the group + user_ids = group.user_ids or [] + if user_id not in user_ids: + logger.warning(f"User {user_id} not in group {group_id}") + raise exceptions.not_found + + # Add collaborator + collaborator_map = group.collaborator_map or {} + signal_key = str(signal_id) + if signal_key not in collaborator_map: + collaborator_map[signal_key] = [] + + if user_id not in collaborator_map[signal_key]: + collaborator_map[signal_key].append(user_id) + group.collaborator_map = collaborator_map + + if await db.update_user_group(cursor, group) is None: + logger.error(f"Failed to update group {group_id}") + raise exceptions.not_found + + logger.info(f"Added user {user_id} as collaborator for signal {signal_id} in group {group_id}") + else: + logger.info(f"User {user_id} already a collaborator for signal {signal_id} in group {group_id}") + + return True + except Exception as e: + if not isinstance(e, HTTPException): # Don't log HTTPExceptions + logger.error(f"Error adding collaborator: {str(e)}") + bugsnag.notify( + e, + metadata={ + "request": { + "url": str(request.url), + "method": request.method, + }, + "group_id": group_id, + "signal_id": signal_id, + "user_id_or_email": user_id_or_email + } + ) + raise @router.delete("/{group_id}/signals/{signal_id}/collaborators/{user_id_or_email}", response_model=bool, dependencies=[Depends(require_admin)]) diff --git a/test_bugsnag.py b/test_bugsnag.py new file mode 100644 index 0000000..a4032ba --- /dev/null +++ b/test_bugsnag.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python +""" +A simple script to test Bugsnag integration. +""" + +import os +import sys +import time +import bugsnag +from src.bugsnag_config import configure_bugsnag, setup_bugsnag_logging + +def test_bugsnag(): + """Test Bugsnag integration with various error types.""" + print("Configuring Bugsnag...") + configure_bugsnag() + setup_bugsnag_logging() + + # Send a simple test error + print("Sending a test error to Bugsnag...") + bugsnag.notify( + Exception("Test error from test_bugsnag.py"), + metadata={ + "test_data": { + "timestamp": time.time(), + "environment": os.environ.get("ENVIRONMENT", "development"), + "python_version": sys.version + } + } + ) + print("Test error sent. Check your Bugsnag dashboard.") + + # Test error with breadcrumbs + print("Testing error with breadcrumbs...") + bugsnag.leave_breadcrumb("Started test process", metadata={"step": 1}) + bugsnag.leave_breadcrumb("Processing data", metadata={"step": 2, "data_size": 1024}) + bugsnag.leave_breadcrumb("Completed processing", metadata={"step": 3, "status": "success"}) + + try: + # Simulate a real error + result = 1 / 0 + except Exception as e: + print("Sending a division by zero error to Bugsnag...") + bugsnag.notify( + e, + metadata={ + "error_context": { + "operation": "division", + "divisor": 0 + } + } + ) + + print("Test completed. Please check your Bugsnag dashboard for the reported errors.") + +if __name__ == "__main__": + test_bugsnag() \ No newline at end of file From e2a310bde1f3e443b66d8d7da4c8afc7e15e1700 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Wed, 30 Apr 2025 15:52:42 +0300 Subject: [PATCH 13/31] update app configuration with bugsnag --- main.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index 810904e..04f4a80 100644 --- a/main.py +++ b/main.py @@ -66,18 +66,6 @@ redoc_url=None, ) -# allow cors -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# Add Bugsnag exception handling middleware -app = get_bugsnag_middleware(app) - # Add global exception handler to report errors to Bugsnag @app.exception_handler(Exception) async def global_exception_handler(request: Request, exc: Exception): @@ -102,6 +90,18 @@ async def global_exception_handler(request: Request, exc: Exception): content={"detail": "Internal server error"}, ) +# allow cors +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Add Bugsnag exception handling middleware +# Important: Add middleware AFTER registering exception handlers +bugsnag_app = get_bugsnag_middleware(app) for router in routers.ALL: app.include_router(router=router, dependencies=[Depends(authenticate_user)]) @@ -136,3 +136,6 @@ async def test_error(): return {"status": "error_reported", "message": "Test error sent to Bugsnag"} else: return {"status": "disabled", "message": "Bugsnag is not enabled"} + +# Use the Bugsnag middleware wrapped app for ASGI +app = bugsnag_app From b620e43da262927bb35e9dedfa31891a61816348 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Thu, 8 May 2025 01:49:38 +0300 Subject: [PATCH 14/31] Enhance user groups functionality and improve API robustness --- main.py | 73 +++++- src/authentication.py | 91 ++++++-- src/bugsnag_config.py | 2 +- src/database/user_groups.py | 439 ++++++++++++++++++++++++++++++++---- src/entities/user_groups.py | 81 ++++++- src/routers/user_groups.py | 33 ++- 6 files changed, 631 insertions(+), 88 deletions(-) diff --git a/main.py b/main.py index 04f4a80..fab43eb 100644 --- a/main.py +++ b/main.py @@ -10,6 +10,7 @@ from fastapi import Depends, FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware from src import routers from src.authentication import authenticate_user @@ -23,6 +24,9 @@ # Get application version app_version = os.environ.get("RELEASE_VERSION", "dev") app_env = os.environ.get("ENVIRONMENT", "development") +# Override environment setting if in local mode +if os.environ.get("ENV_MODE") == "local": + app_env = "local" logging.info(f"Starting application - version: {app_version}, environment: {app_env}") # Configure Bugsnag for error tracking @@ -90,14 +94,61 @@ async def global_exception_handler(request: Request, exc: Exception): content={"detail": "Internal server error"}, ) -# allow cors -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +# Configure CORS - simplified for local development +local_origins = [ + "http://localhost:5175", + "http://127.0.0.1:5175", + "http://localhost:3000", + "http://127.0.0.1:3000" +] + +# Create a custom middleware class for handling CORS +class CORSHandlerMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + # Handle OPTIONS preflight requests + if request.method == "OPTIONS": + headers = { + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH", + "Access-Control-Allow-Headers": "access_token, Authorization, Content-Type, Accept, X-API-Key", + "Access-Control-Allow-Credentials": "true", + "Access-Control-Max-Age": "600", # Cache preflight for 10 minutes + } + + # Set specific origin if in local mode + origin = request.headers.get("origin") + if os.environ.get("ENV_MODE") == "local" and origin: + headers["Access-Control-Allow-Origin"] = origin + + return JSONResponse(content={}, status_code=200, headers=headers) + + # Process all other requests normally + response = await call_next(request) + return response + +# Apply custom CORS middleware BEFORE the standard CORS middleware +app.add_middleware(CORSHandlerMiddleware) + +# Standard CORS middleware (as a backup) +if os.environ.get("ENV_MODE") == "local": + logging.info(f"Local mode: using specific CORS origins: {local_origins}") + app.add_middleware( + CORSMiddleware, + allow_origins=local_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*", "access_token", "Authorization", "Content-Type"], + expose_headers=["*"], + ) +else: + # Production mode - use more restrictive CORS + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*", "access_token", "Authorization", "Content-Type"], + ) # Add Bugsnag exception handling middleware # Important: Add middleware AFTER registering exception handlers @@ -137,5 +188,11 @@ async def test_error(): else: return {"status": "disabled", "message": "Bugsnag is not enabled"} +# Add special route for handling OPTIONS requests to /users/me +@app.options("/users/me", include_in_schema=False) +async def options_users_me(): + """Handle OPTIONS requests to /users/me specifically.""" + return {} + # Use the Bugsnag middleware wrapped app for ASGI app = bugsnag_app diff --git a/src/authentication.py b/src/authentication.py index da64752..70fdcfb 100644 --- a/src/authentication.py +++ b/src/authentication.py @@ -4,6 +4,7 @@ import logging import os +from typing import Dict, Any, Optional, cast import httpx import jwt @@ -18,7 +19,7 @@ api_key_header = APIKeyHeader( name="access_token", description="The API access token for the application.", - auto_error=True, + auto_error=False, ) @@ -67,10 +68,10 @@ async def get_jwk(token: str) -> jwt.PyJWK: jwks = await get_jwks() except httpx.HTTPError: jwks = {} - jwk = jwks.get(header["kid"]) - if jwk is None: + jwk_dict = jwks.get(header["kid"]) + if jwk_dict is None: raise ValueError("JWK could not be obtained or found") - jwk = jwt.PyJWK.from_dict(jwk, "RS256") + jwk = jwt.PyJWK.from_dict(jwk_dict, "RS256") return jwk @@ -108,7 +109,7 @@ async def decode_token(token: str) -> dict: async def authenticate_user( - token: str = Security(api_key_header), + token: Optional[str] = Security(api_key_header), cursor: AsyncCursor = Depends(db.yield_cursor), ) -> User: """ @@ -130,38 +131,78 @@ async def authenticate_user( user : User Pydantic model for a User object (if authentication succeeded). """ - logging.debug(f"Authenticating user with token") - if os.environ.get("TEST_USER_TOKEN"): - token = os.environ.get("TEST_USER_TOKEN") + logging.debug("Authenticating user with token") + + # For local development environment if os.environ.get("ENV_MODE") == "local": - # defaul user data - user_data = { - "email": "test.user@undp.org", - "name": "Test User", - "unit": "Data Futures Exchange (DFx)", - "acclab": False, - } + # Use test token if available + if os.environ.get("TEST_USER_TOKEN"): + test_token = os.environ.get("TEST_USER_TOKEN") + if test_token is not None: + token = test_token + + # Default user data for local development + local_email = "test.user@undp.org" + name = "Test User" + unit = "Data Futures Exchange (DFx)" + acclab = False + + # Check for specific test tokens if token == "test-admin-token": - user_data["role"] = Role.ADMIN - return User(**user_data) + return User( + email=local_email, + name=name, + unit=unit, + acclab=acclab, + role=Role.ADMIN + ) elif token == "test-user-token": - user_data["role"] = Role.USER - return User(**user_data) - - if token == os.environ.get("API_KEY"): + return User( + email=local_email, + name=name, + unit=unit, + acclab=acclab, + role=Role.USER + ) + else: + # In local mode, if no valid token is provided, default to an admin user + logging.info("LOCAL MODE: No valid token provided, using default admin user") + return User( + email=local_email, + name=name, + unit=unit, + acclab=acclab, + role=Role.ADMIN + ) + + # Check for API key access + if token and token == os.environ.get("API_KEY"): if os.environ.get("ENV") == "dev": return User(email="name.surname@undp.org", role=Role.ADMIN) else: # dummy user object for anonymous access return User(email="name.surname@undp.org", role=Role.VISITOR) + + # If no token provided in non-local mode + if not token: + raise exceptions.not_authenticated + + # Try to decode and verify the token try: payload = await decode_token(token) except jwt.exceptions.PyJWTError as e: raise exceptions.not_authenticated from e - email, name = payload.get("unique_name"), payload.get("name") - if email is None or name is None: + + payload_email = payload.get("unique_name") + payload_name = payload.get("name") + + if payload_email is None or payload_name is None: raise exceptions.not_authenticated - if (user := await db.read_user_by_email(cursor, email)) is None: - user = User(email=email, role=Role.USER, name=name) + + email_str = str(payload_email) # Convert to string to satisfy type checker + name_str = str(payload_name) # Convert to string to satisfy type checker + + if (user := await db.read_user_by_email(cursor, email_str)) is None: + user = User(email=email_str, role=Role.USER, name=name_str) await db.create_user(cursor, user) return user diff --git a/src/bugsnag_config.py b/src/bugsnag_config.py index e7a32d6..7569c63 100644 --- a/src/bugsnag_config.py +++ b/src/bugsnag_config.py @@ -10,7 +10,7 @@ # Get API key from environment variable with fallback to the provided key BUGSNAG_API_KEY = os.environ.get("BUGSNAG_API_KEY") -ENVIRONMENT = os.environ.get("ENVIRONMENT", "development") +ENVIRONMENT = os.environ.get("ENVIRONMENT") RELEASE_VERSION = os.environ.get("RELEASE_VERSION", "dev") if not BUGSNAG_API_KEY: diff --git a/src/database/user_groups.py b/src/database/user_groups.py index 21d7ecd..4a30a4b 100644 --- a/src/database/user_groups.py +++ b/src/database/user_groups.py @@ -3,8 +3,10 @@ """ from psycopg import AsyncCursor +import json +import logging -from ..entities import UserGroup, Signal +from ..entities import UserGroup, Signal, User, UserGroupWithSignals, UserGroupWithUsers, UserGroupComplete __all__ = [ "create_user_group", @@ -18,8 +20,55 @@ "get_group_users", "get_user_groups_with_signals", "get_signal_group_collaborators", + "get_user_group_with_users", + "list_user_groups_with_users", + "get_user_groups_with_signals_and_users", ] +logger = logging.getLogger(__name__) + + +def handle_user_group_row(row) -> dict: + """ + Helper function to safely extract user group data from a database row. + + Parameters + ---------- + row : dict or tuple + A database result row + + Returns + ------- + dict + A dictionary of user group data ready for creating a UserGroup + """ + data = {} + if isinstance(row, dict): + data['id'] = row["id"] + data['name'] = row["name"] + data['signal_ids'] = row["signal_ids"] or [] + data['user_ids'] = row["user_ids"] or [] + collab_map = row["collaborator_map"] + else: + data['id'] = row[0] + data['name'] = row[1] + data['signal_ids'] = row[2] or [] + data['user_ids'] = row[3] or [] + collab_map = row[4] + + # Handle collaborator_map field + data['collaborator_map'] = {} + if collab_map: + if isinstance(collab_map, str): + try: + data['collaborator_map'] = json.loads(collab_map) + except json.JSONDecodeError: + data['collaborator_map'] = {} + else: + data['collaborator_map'] = collab_map + + return data + async def create_user_group(cursor: AsyncCursor, group: UserGroup) -> int: """ @@ -37,6 +86,10 @@ async def create_user_group(cursor: AsyncCursor, group: UserGroup) -> int: int The ID of the created user group. """ + # Convert model to dict and ensure collaborator_map is a JSON string + group_data = group.model_dump(exclude={"id"}) + group_data["collaborator_map"] = json.dumps(group_data["collaborator_map"]) + query = """ INSERT INTO user_groups ( name, @@ -53,11 +106,16 @@ async def create_user_group(cursor: AsyncCursor, group: UserGroup) -> int: RETURNING id ; """ - await cursor.execute(query, group.model_dump(exclude={"id"})) + await cursor.execute(query, group_data) row = await cursor.fetchone() if row is None: raise ValueError("Failed to create user group") - group_id = row[0] + + # Access the ID safely + if isinstance(row, dict): + group_id = row["id"] + else: + group_id = row[0] return group_id @@ -95,9 +153,7 @@ async def read_user_group(cursor: AsyncCursor, group_id: int) -> UserGroup | Non if (row := await cursor.fetchone()) is None: return None - # Convert row to dict - data = dict(zip(["id", "name", "signal_ids", "user_ids", "collaborator_map"], row)) - + data = handle_user_group_row(row) return UserGroup(**data) @@ -117,6 +173,10 @@ async def update_user_group(cursor: AsyncCursor, group: UserGroup) -> int | None int | None The ID of the updated user group if successful, otherwise None. """ + # Convert model to dict and ensure collaborator_map is a JSON string + group_data = group.model_dump() + group_data["collaborator_map"] = json.dumps(group_data["collaborator_map"]) + query = """ UPDATE user_groups SET @@ -128,11 +188,15 @@ async def update_user_group(cursor: AsyncCursor, group: UserGroup) -> int | None RETURNING id ; """ - await cursor.execute(query, group.model_dump()) + await cursor.execute(query, group_data) if (row := await cursor.fetchone()) is None: return None - return row[0] + # Access the ID safely + if isinstance(row, dict): + return row["id"] + else: + return row[0] async def delete_user_group(cursor: AsyncCursor, group_id: int) -> bool: @@ -188,8 +252,7 @@ async def list_user_groups(cursor: AsyncCursor) -> list[UserGroup]: result = [] async for row in cursor: - # Convert row to dict - data = dict(zip(["id", "name", "signal_ids", "user_ids", "collaborator_map"], row)) + data = handle_user_group_row(row) result.append(UserGroup(**data)) return result @@ -260,39 +323,52 @@ async def remove_user_from_group(cursor: AsyncCursor, group_id: int, user_id: in bool True if the user was removed, False otherwise. """ - # Get current user_ids + # Check if the group exists and get its current state await cursor.execute("SELECT user_ids, collaborator_map FROM user_groups WHERE id = %s;", (group_id,)) row = await cursor.fetchone() if row is None: return False user_ids = row[0] if row[0] is not None else [] - collaborator_map = row[1] if row[1] is not None else {} + + # Parse collaborator_map from JSON if needed + if row[1] is not None and isinstance(row[1], str): + try: + collaborator_map = json.loads(row[1]) + except json.JSONDecodeError: + collaborator_map = {} + else: + collaborator_map = row[1] if row[1] is not None else {} + + if user_id not in user_ids: + return True # Already not in the group # Remove user from user_ids - if user_id in user_ids: - user_ids.remove(user_id) - - # Remove user from collaborator_map - for signal_id, users in list(collaborator_map.items()): - if user_id in users: - users.remove(user_id) - if not users: # If empty, remove signal from map - del collaborator_map[signal_id] - - query = """ - UPDATE user_groups - SET - user_ids = %s, - collaborator_map = %s - WHERE id = %s - RETURNING id - ; - """ - await cursor.execute(query, (user_ids, collaborator_map, group_id)) - return await cursor.fetchone() is not None + user_ids.remove(user_id) + + # Remove user from collaborator_map + for signal_id, users in list(collaborator_map.items()): + if user_id in users: + users.remove(user_id) + + if not users: + del collaborator_map[signal_id] + + # Update the group + query = """ + UPDATE user_groups + SET + user_ids = %s, + collaborator_map = %s + WHERE id = %s + RETURNING id + ; + """ + # Convert collaborator_map to JSON string + collaborator_map_json = json.dumps(collaborator_map) + await cursor.execute(query, (user_ids, collaborator_map_json, group_id)) - return False + return await cursor.fetchone() is not None async def get_user_groups(cursor: AsyncCursor, user_id: int) -> list[UserGroup]: @@ -330,8 +406,7 @@ async def get_user_groups(cursor: AsyncCursor, user_id: int) -> list[UserGroup]: result = [] async for row in cursor: - # Convert row to dict - data = dict(zip(["id", "name", "signal_ids", "user_ids", "collaborator_map"], row)) + data = handle_user_group_row(row) result.append(UserGroup(**data)) return result @@ -365,7 +440,7 @@ async def get_group_users(cursor: AsyncCursor, group_id: int) -> list[int]: return row[0] if row and row[0] else [] -async def get_user_groups_with_signals(cursor: AsyncCursor, user_id: int) -> list[dict]: +async def get_user_groups_with_signals(cursor: AsyncCursor, user_id: int) -> list[UserGroupWithSignals]: """ Get all groups that a user is a member of, along with the associated signals data. @@ -378,9 +453,11 @@ async def get_user_groups_with_signals(cursor: AsyncCursor, user_id: int) -> lis Returns ------- - list[dict] - A list of dictionaries containing user group and signal data. + list[UserGroupWithSignals] + A list of user groups with associated signals. """ + logger.debug("Getting user groups with signals for user_id: %s", user_id) + # First get the groups the user belongs to user_groups = await get_user_groups(cursor, user_id) result = [] @@ -389,9 +466,11 @@ async def get_user_groups_with_signals(cursor: AsyncCursor, user_id: int) -> lis for group in user_groups: group_data = group.model_dump() signals = [] + users = [] # Get signals for this group if group.signal_ids: + logger.debug("Fetching signals for group_id: %s, signal_ids: %s", group.id, group.signal_ids) signals_query = """ SELECT s.*, @@ -423,9 +502,40 @@ async def get_user_groups_with_signals(cursor: AsyncCursor, user_id: int) -> lis signal_dict["can_edit"] = can_edit signals.append(Signal(**signal_dict)) - group_data["signals"] = signals - result.append(group_data) + # Get users for this group + if group.user_ids: + logger.debug("Fetching users for group_id: %s, user_ids: %s", group.id, group.user_ids) + users_query = """ + SELECT + id, + email, + role, + name, + unit, + acclab, + created_at + FROM + users + WHERE + id = ANY(%s) + ORDER BY + name + ; + """ + await cursor.execute(users_query, (group.user_ids,)) + + async for row in cursor: + user_data = dict(row) + users.append(User(**user_data)) + + # Create a UserGroupWithSignals instance + group_with_signals = UserGroupWithSignals( + **group_data, + signals=signals + ) + result.append(group_with_signals) + logger.debug("Found %s user groups with signals for user_id: %s", len(result), user_id) return result @@ -460,8 +570,251 @@ async def get_signal_group_collaborators(cursor: AsyncCursor, signal_id: int) -> collaborators = set() async for row in cursor: - if row[0]: # Access first column using integer index - for user_id in row[0]: + # Safely access collaborators + if isinstance(row, dict): + collab_data = row['collaborators'] + else: + collab_data = row[0] + + if collab_data: + for user_id in collab_data: collaborators.add(user_id) return list(collaborators) + + +async def get_user_group_with_users(cursor: AsyncCursor, group_id: int) -> UserGroupWithUsers | None: + """ + Get a user group with detailed user information for each member. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + group_id : int + The ID of the user group. + + Returns + ------- + UserGroupWithUsers | None + A user group with detailed user information, or None if the group doesn't exist. + """ + logger.debug("Getting user group with users for group_id: %s", group_id) + + # First, get the user group + group = await read_user_group(cursor, group_id) + if group is None: + logger.warning("User group not found with id: %s", group_id) + return None + + # Convert to dict for modification + group_data = group.model_dump() + users = [] + + # If there are users in the group, fetch their details + if group.user_ids: + logger.debug("Fetching users for group_id: %s, user_ids: %s", group_id, group.user_ids) + users_query = """ + SELECT + id, + email, + role, + name, + unit, + acclab, + created_at + FROM + users + WHERE + id = ANY(%s) + ORDER BY + name + ; + """ + await cursor.execute(users_query, (group.user_ids,)) + + user_count = 0 + async for row in cursor: + user_data = dict(row) + users.append(User(**user_data)) + user_count += 1 + + logger.debug("Found %s users for group_id: %s", user_count, group_id) + else: + logger.debug("No users found for group_id: %s", group_id) + + # Create a UserGroupWithUsers instance + return UserGroupWithUsers(**group_data, users=users) + + +async def list_user_groups_with_users(cursor: AsyncCursor) -> list[UserGroupWithUsers]: + """ + List all user groups with detailed user information. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + + Returns + ------- + list[UserGroupWithUsers] + A list of user groups with detailed user information. + """ + logger.debug("Listing all user groups with users") + + # Get all user groups + groups = await list_user_groups(cursor) + result = [] + + # For each group, get user details + for group in groups: + group_data = group.model_dump() + users = [] + + # If there are users in the group, fetch their details + if group.user_ids: + logger.debug("Fetching users for group_id: %s, user_ids: %s", group.id, group.user_ids) + users_query = """ + SELECT + id, + email, + role, + name, + unit, + acclab, + created_at + FROM + users + WHERE + id = ANY(%s) + ORDER BY + name + ; + """ + await cursor.execute(users_query, (group.user_ids,)) + + user_count = 0 + async for row in cursor: + user_data = dict(row) + users.append(User(**user_data)) + user_count += 1 + + logger.debug("Found %s users for group_id: %s", user_count, group.id) + + # Create a UserGroupWithUsers instance + group_with_users = UserGroupWithUsers(**group_data, users=users) + result.append(group_with_users) + + logger.debug("Listed %s user groups with users", len(result)) + return result + + +async def get_user_groups_with_signals_and_users(cursor: AsyncCursor, user_id: int) -> list[UserGroupComplete]: + """ + Get all groups that a user is a member of, along with the associated signals and users data. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_id : int + The ID of the user. + + Returns + ------- + list[UserGroupComplete] + A list of user groups with signals and users data. + """ + logger.debug("Getting user groups with signals and users for user_id: %s", user_id) + + # First get the groups the user belongs to + user_groups = await get_user_groups(cursor, user_id) + result = [] + + # For each group, fetch the signals data and user data + for group in user_groups: + group_data = group.model_dump() + signals = [] + users = [] + + # Get signals for this group + if group.signal_ids: + logger.debug("Fetching signals for group_id: %s, signal_ids: %s", group.id, group.signal_ids) + signals_query = """ + SELECT + s.*, + array_agg(c.trend_id) FILTER (WHERE c.trend_id IS NOT NULL) AS connected_trends + FROM + signals s + LEFT JOIN + connections c ON s.id = c.signal_id + WHERE + s.id = ANY(%s) + GROUP BY + s.id + ORDER BY + s.id + ; + """ + await cursor.execute(signals_query, (group.signal_ids,)) + + signal_count = 0 + async for row in cursor: + signal_dict = dict(row) + # Check if user is a collaborator for this signal + can_edit = False + signal_id_str = str(signal_dict["id"]) + + if group.collaborator_map and signal_id_str in group.collaborator_map: + if user_id in group.collaborator_map[signal_id_str]: + can_edit = True + + signal_dict["can_edit"] = can_edit + + # Create Signal instance + signal = Signal(**signal_dict) + signals.append(signal) + signal_count += 1 + + logger.debug("Found %s signals for group_id: %s", signal_count, group.id) + + # Get users for this group + if group.user_ids: + logger.debug("Fetching users for group_id: %s, user_ids: %s", group.id, group.user_ids) + users_query = """ + SELECT + id, + email, + role, + name, + unit, + acclab, + created_at + FROM + users + WHERE + id = ANY(%s) + ORDER BY + name + ; + """ + await cursor.execute(users_query, (group.user_ids,)) + + user_count = 0 + async for row in cursor: + user_data = dict(row) + + # Create User instance + user = User(**user_data) + users.append(user) + user_count += 1 + + logger.debug("Found %s users for group_id: %s", user_count, group.id) + + # Create a UserGroupComplete instance + group_complete = UserGroupComplete(**group_data, signals=signals, users=users) + result.append(group_complete) + + logger.debug("Found %s user groups with signals and users for user_id: %s", len(result), user_id) + return result diff --git a/src/entities/user_groups.py b/src/entities/user_groups.py index 4599326..5b4668f 100644 --- a/src/entities/user_groups.py +++ b/src/entities/user_groups.py @@ -7,8 +7,9 @@ from .base import BaseEntity from .signal import Signal +from .user import User -__all__ = ["UserGroup", "UserGroupWithSignals"] +__all__ = ["UserGroup", "UserGroupWithSignals", "UserGroupWithUsers", "UserGroupComplete"] class UserGroup(BaseEntity): @@ -78,3 +79,81 @@ class UserGroupWithSignals(UserGroup): } } ) + + +class UserGroupWithUsers(UserGroup): + """User group with associated users data.""" + + users: List[User] = Field( + default_factory=list, + description="List of users who are members of this group." + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "id": 1, + "name": "CDO", + "signal_ids": [1, 2, 3], + "user_ids": [1, 2, 3], + "collaborator_map": { + "1": [1, 2], + "2": [1, 3], + "3": [2, 3] + }, + "users": [ + { + "id": 1, + "email": "john.doe@undp.org", + "role": "Curator", + "name": "John Doe" + } + ] + } + } + ) + + +class UserGroupComplete(UserGroup): + """User group with both associated signals and users data.""" + + signals: List[Signal] = Field( + default_factory=list, + description="List of signals associated with this group." + ) + + users: List[User] = Field( + default_factory=list, + description="List of users who are members of this group." + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "id": 1, + "name": "CDO", + "signal_ids": [1, 2, 3], + "user_ids": [1, 2, 3], + "collaborator_map": { + "1": [1, 2], + "2": [1, 3], + "3": [2, 3] + }, + "signals": [ + { + "id": 1, + "headline": "Signal 1", + "can_edit": True + } + ], + "users": [ + { + "id": 1, + "email": "john.doe@undp.org", + "role": "Curator", + "name": "John Doe" + } + ] + } + } + ) diff --git a/src/routers/user_groups.py b/src/routers/user_groups.py index ab27d7f..d678503 100644 --- a/src/routers/user_groups.py +++ b/src/routers/user_groups.py @@ -94,7 +94,7 @@ async def list_user_groups( raise -@router.get("/me", response_model=List[UserGroup]) +@router.get("/me") async def get_my_user_groups( request: Request, user: User = Depends(authenticate_user), @@ -110,7 +110,7 @@ async def get_my_user_groups( raise exceptions.not_found # Get groups this user is a member of - user_groups = await db.get_user_groups(cursor, user.id) + user_groups = await db.get_user_groups_with_signals_and_users(cursor, user.id) logger.info(f"User {user.id} retrieved {len(user_groups)} groups") return user_groups except Exception as e: @@ -183,6 +183,7 @@ async def get_my_user_groups_with_signals( async def create_user_group( request: Request, group_data: UserGroupCreate, + current_user: User = Depends(authenticate_user), # Get the current user cursor: AsyncCursor = Depends(db.yield_cursor), ): """Create a new user group.""" @@ -190,21 +191,32 @@ async def create_user_group( # Create the base group group = UserGroup(name=group_data.name) + # Initialize user_ids list with the current user's ID + user_ids = [] + if current_user.id: + user_ids.append(current_user.id) + # Handle email addresses if provided if group_data.users: - user_ids = [] for email in group_data.users: user = await db.read_user_by_email(cursor, email) - if user and user.id: + if user and user.id and user.id not in user_ids: # Avoid duplicates user_ids.append(user.id) - - if user_ids: - group.user_ids = user_ids + if user_ids: + group.user_ids = user_ids + # Create the group group_id = await db.create_user_group(cursor, group) - logger.info(f"Created user group {group_id} with name '{group.name}'") - return await db.read_user_group(cursor, group_id) + logger.info(f"Created user group {group_id} with name '{group.name}' and {len(user_ids)} users") + + # Retrieve and return the created group + created_group = await db.read_user_group(cursor, group_id) + if not created_group: + logger.error(f"Failed to retrieve newly created group with ID {group_id}") + raise exceptions.not_found + + return created_group except Exception as e: logger.error(f"Error creating user group: {str(e)}") bugsnag.notify( @@ -216,7 +228,8 @@ async def create_user_group( }, "group_data": { "name": group_data.name, - "users_count": len(group_data.users) if group_data.users else 0 + "users_count": len(group_data.users) if group_data.users else 0, + "current_user_id": current_user.id if current_user else None } } ) From b30d588ef6ea2983ed0b9fe0a99b2be9457f9158 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Tue, 13 May 2025 15:47:04 +0300 Subject: [PATCH 15/31] update user groups calls --- src/database/user_groups.py | 403 ++++++++++++++++-------------------- src/entities/user_groups.py | 4 +- src/routers/user_groups.py | 193 ++++++++++++++--- 3 files changed, 337 insertions(+), 263 deletions(-) diff --git a/src/database/user_groups.py b/src/database/user_groups.py index 4a30a4b..fc9811b 100644 --- a/src/database/user_groups.py +++ b/src/database/user_groups.py @@ -5,6 +5,7 @@ from psycopg import AsyncCursor import json import logging +from typing import List from ..entities import UserGroup, Signal, User, UserGroupWithSignals, UserGroupWithUsers, UserGroupComplete @@ -23,10 +24,55 @@ "get_user_group_with_users", "list_user_groups_with_users", "get_user_groups_with_signals_and_users", + "get_user_groups_with_users_by_user_id", ] logger = logging.getLogger(__name__) +# SQL Query Constants +SQL_SELECT_USER_GROUP = """ + SELECT + id, + name, + signal_ids, + user_ids, + collaborator_map + FROM + user_groups +""" + +SQL_SELECT_USERS = """ + SELECT + id, + email, + role, + name, + unit, + acclab, + created_at + FROM + users + WHERE + id = ANY(%s) + ORDER BY + name +""" + +SQL_SELECT_SIGNALS = """ + SELECT + s.*, + array_agg(c.trend_id) FILTER (WHERE c.trend_id IS NOT NULL) AS connected_trends + FROM + signals s + LEFT JOIN + connections c ON s.id = c.signal_id + WHERE + s.id = ANY(%s) + GROUP BY + s.id + ORDER BY + s.id +""" def handle_user_group_row(row) -> dict: """ @@ -69,6 +115,30 @@ def handle_user_group_row(row) -> dict: return data +async def get_users_for_group(cursor: AsyncCursor, user_ids: List[int]) -> List[User]: + """ + Helper function to fetch user details for a group. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_ids : List[int] + List of user IDs to fetch. + + Returns + ------- + List[User] + List of User objects. + """ + users = [] + if user_ids: + await cursor.execute(SQL_SELECT_USERS, (user_ids,)) + async for row in cursor: + user_data = dict(row) + users.append(User(**user_data)) + return users + async def create_user_group(cursor: AsyncCursor, group: UserGroup) -> int: """ @@ -120,7 +190,7 @@ async def create_user_group(cursor: AsyncCursor, group: UserGroup) -> int: return group_id -async def read_user_group(cursor: AsyncCursor, group_id: int) -> UserGroup | None: +async def read_user_group(cursor: AsyncCursor, group_id: int, fetch_details: bool = False) -> UserGroup | UserGroupWithUsers | None: """ Read a user group from the database. @@ -130,31 +200,28 @@ async def read_user_group(cursor: AsyncCursor, group_id: int) -> UserGroup | Non An async database cursor. group_id : int The ID of the user group to read. + fetch_details : bool, optional + If True, fetches detailed user information for each group member, + by default False Returns ------- - UserGroup | None - The user group if found, otherwise None. - """ - query = """ - SELECT - id, - name, - signal_ids, - user_ids, - collaborator_map - FROM - user_groups - WHERE - id = %s - ; + UserGroup | UserGroupWithUsers | None + The user group if found (with optional detailed user data), otherwise None. """ + query = f"{SQL_SELECT_USER_GROUP} WHERE id = %s;" await cursor.execute(query, (group_id,)) if (row := await cursor.fetchone()) is None: return None - + data = handle_user_group_row(row) - return UserGroup(**data) + + if fetch_details: + # Fetch detailed user information if requested + users = await get_users_for_group(cursor, data['user_ids']) + return UserGroupWithUsers(**data, users=users) + else: + return UserGroup(**data) async def update_user_group(cursor: AsyncCursor, group: UserGroup) -> int | None: @@ -235,19 +302,7 @@ async def list_user_groups(cursor: AsyncCursor) -> list[UserGroup]: list[UserGroup] A list of all user groups. """ - query = """ - SELECT - id, - name, - signal_ids, - user_ids, - collaborator_map - FROM - user_groups - ORDER BY - name - ; - """ + query = f"{SQL_SELECT_USER_GROUP} ORDER BY name;" await cursor.execute(query) result = [] @@ -387,21 +442,7 @@ async def get_user_groups(cursor: AsyncCursor, user_id: int) -> list[UserGroup]: list[UserGroup] A list of user groups. """ - query = """ - SELECT - id, - name, - signal_ids, - user_ids, - collaborator_map - FROM - user_groups - WHERE - %s = ANY(user_ids) - ORDER BY - name - ; - """ + query = f"{SQL_SELECT_USER_GROUP} WHERE %s = ANY(user_ids) ORDER BY name;" await cursor.execute(query, (user_id,)) result = [] @@ -440,7 +481,7 @@ async def get_group_users(cursor: AsyncCursor, group_id: int) -> list[int]: return row[0] if row and row[0] else [] -async def get_user_groups_with_signals(cursor: AsyncCursor, user_id: int) -> list[UserGroupWithSignals]: +async def get_user_groups_with_signals(cursor: AsyncCursor, user_id: int, fetch_users: bool = False) -> list[UserGroupWithSignals]: """ Get all groups that a user is a member of, along with the associated signals data. @@ -450,91 +491,72 @@ async def get_user_groups_with_signals(cursor: AsyncCursor, user_id: int) -> lis An async database cursor. user_id : int The ID of the user. + fetch_users : bool, optional + If True, also fetches detailed user information for each group member, + by default False Returns ------- list[UserGroupWithSignals] - A list of user groups with associated signals. + A list of user groups with associated signals and optional user details. """ logger.debug("Getting user groups with signals for user_id: %s", user_id) - + # First get the groups the user belongs to user_groups = await get_user_groups(cursor, user_id) result = [] - + # For each group, fetch the signals data for group in user_groups: group_data = group.model_dump() signals = [] - users = [] - + # Get signals for this group if group.signal_ids: logger.debug("Fetching signals for group_id: %s, signal_ids: %s", group.id, group.signal_ids) - signals_query = """ - SELECT - s.*, - array_agg(c.trend_id) FILTER (WHERE c.trend_id IS NOT NULL) AS connected_trends - FROM - signals s - LEFT JOIN - connections c ON s.id = c.signal_id - WHERE - s.id = ANY(%s) - GROUP BY - s.id - ORDER BY - s.id - ; - """ - await cursor.execute(signals_query, (group.signal_ids,)) - - async for row in cursor: - signal_dict = dict(row) - # Check if user is a collaborator for this signal - can_edit = False - signal_id_str = str(signal_dict["id"]) - - if group.collaborator_map and signal_id_str in group.collaborator_map: - if user_id in group.collaborator_map[signal_id_str]: - can_edit = True - - signal_dict["can_edit"] = can_edit - signals.append(Signal(**signal_dict)) - - # Get users for this group - if group.user_ids: - logger.debug("Fetching users for group_id: %s, user_ids: %s", group.id, group.user_ids) - users_query = """ - SELECT - id, - email, - role, - name, - unit, - acclab, - created_at - FROM - users - WHERE - id = ANY(%s) - ORDER BY - name - ; - """ - await cursor.execute(users_query, (group.user_ids,)) - - async for row in cursor: - user_data = dict(row) - users.append(User(**user_data)) - + + # Import read_signal function directly + from .signals import read_signal + + # Get each signal individually + for signal_id in group.signal_ids: + signal = await read_signal(cursor, signal_id) + if signal: + # Check if user is a collaborator for this signal + can_edit = False + signal_id_str = str(signal_id) + + if group.collaborator_map and signal_id_str in group.collaborator_map: + if user_id in group.collaborator_map[signal_id_str]: + can_edit = True + + # Create a copy of the signal with can_edit flag + signal_dict = signal.model_dump() + signal_dict["can_edit"] = can_edit + signals.append(Signal(**signal_dict)) + + # Fetch full user details if requested + users = [] + if fetch_users and group.user_ids: + users = await get_users_for_group(cursor, group.user_ids) + # Create a UserGroupWithSignals instance - group_with_signals = UserGroupWithSignals( - **group_data, - signals=signals - ) + if fetch_users: + # If we fetched user details, create a UserGroupComplete + group_with_signals = UserGroupComplete( + **group_data, + signals=signals, + users=users + ) + else: + # Otherwise create a UserGroupWithSignals + group_with_signals = UserGroupWithSignals( + **group_data, + signals=signals + ) + result.append(group_with_signals) - + logger.debug("Found %s user groups with signals for user_id: %s", len(result), user_id) return result @@ -609,39 +631,7 @@ async def get_user_group_with_users(cursor: AsyncCursor, group_id: int) -> UserG # Convert to dict for modification group_data = group.model_dump() - users = [] - - # If there are users in the group, fetch their details - if group.user_ids: - logger.debug("Fetching users for group_id: %s, user_ids: %s", group_id, group.user_ids) - users_query = """ - SELECT - id, - email, - role, - name, - unit, - acclab, - created_at - FROM - users - WHERE - id = ANY(%s) - ORDER BY - name - ; - """ - await cursor.execute(users_query, (group.user_ids,)) - - user_count = 0 - async for row in cursor: - user_data = dict(row) - users.append(User(**user_data)) - user_count += 1 - - logger.debug("Found %s users for group_id: %s", user_count, group_id) - else: - logger.debug("No users found for group_id: %s", group_id) + users = await get_users_for_group(cursor, group.user_ids) # Create a UserGroupWithUsers instance return UserGroupWithUsers(**group_data, users=users) @@ -670,37 +660,7 @@ async def list_user_groups_with_users(cursor: AsyncCursor) -> list[UserGroupWith # For each group, get user details for group in groups: group_data = group.model_dump() - users = [] - - # If there are users in the group, fetch their details - if group.user_ids: - logger.debug("Fetching users for group_id: %s, user_ids: %s", group.id, group.user_ids) - users_query = """ - SELECT - id, - email, - role, - name, - unit, - acclab, - created_at - FROM - users - WHERE - id = ANY(%s) - ORDER BY - name - ; - """ - await cursor.execute(users_query, (group.user_ids,)) - - user_count = 0 - async for row in cursor: - user_data = dict(row) - users.append(User(**user_data)) - user_count += 1 - - logger.debug("Found %s users for group_id: %s", user_count, group.id) + users = await get_users_for_group(cursor, group.user_ids) # Create a UserGroupWithUsers instance group_with_users = UserGroupWithUsers(**group_data, users=users) @@ -736,28 +696,11 @@ async def get_user_groups_with_signals_and_users(cursor: AsyncCursor, user_id: i for group in user_groups: group_data = group.model_dump() signals = [] - users = [] # Get signals for this group if group.signal_ids: logger.debug("Fetching signals for group_id: %s, signal_ids: %s", group.id, group.signal_ids) - signals_query = """ - SELECT - s.*, - array_agg(c.trend_id) FILTER (WHERE c.trend_id IS NOT NULL) AS connected_trends - FROM - signals s - LEFT JOIN - connections c ON s.id = c.signal_id - WHERE - s.id = ANY(%s) - GROUP BY - s.id - ORDER BY - s.id - ; - """ - await cursor.execute(signals_query, (group.signal_ids,)) + await cursor.execute(SQL_SELECT_SIGNALS, (group.signal_ids,)) signal_count = 0 async for row in cursor: @@ -780,37 +723,7 @@ async def get_user_groups_with_signals_and_users(cursor: AsyncCursor, user_id: i logger.debug("Found %s signals for group_id: %s", signal_count, group.id) # Get users for this group - if group.user_ids: - logger.debug("Fetching users for group_id: %s, user_ids: %s", group.id, group.user_ids) - users_query = """ - SELECT - id, - email, - role, - name, - unit, - acclab, - created_at - FROM - users - WHERE - id = ANY(%s) - ORDER BY - name - ; - """ - await cursor.execute(users_query, (group.user_ids,)) - - user_count = 0 - async for row in cursor: - user_data = dict(row) - - # Create User instance - user = User(**user_data) - users.append(user) - user_count += 1 - - logger.debug("Found %s users for group_id: %s", user_count, group.id) + users = await get_users_for_group(cursor, group.user_ids) # Create a UserGroupComplete instance group_complete = UserGroupComplete(**group_data, signals=signals, users=users) @@ -818,3 +731,39 @@ async def get_user_groups_with_signals_and_users(cursor: AsyncCursor, user_id: i logger.debug("Found %s user groups with signals and users for user_id: %s", len(result), user_id) return result + + +async def get_user_groups_with_users_by_user_id(cursor: AsyncCursor, user_id: int) -> list[UserGroupWithUsers]: + """ + Get all groups that a user is a member of, along with detailed user information for each group member. + This is a more focused version that only fetches user data, not signals. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_id : int + The ID of the user. + + Returns + ------- + list[UserGroupWithUsers] + A list of user groups with detailed user information. + """ + logger.debug("Getting user groups with users for user_id: %s", user_id) + + # First get the groups the user belongs to + query = f"{SQL_SELECT_USER_GROUP} WHERE %s = ANY(user_ids) ORDER BY name;" + await cursor.execute(query, (user_id,)) + result = [] + + async for row in cursor: + group_data = handle_user_group_row(row) + users = await get_users_for_group(cursor, group_data['user_ids']) + + # Create a UserGroupWithUsers instance + group_with_users = UserGroupWithUsers(**group_data, users=users) + result.append(group_with_users) + + logger.debug("Found %s user groups with users for user_id: %s", len(result), user_id) + return result diff --git a/src/entities/user_groups.py b/src/entities/user_groups.py index 5b4668f..558dfce 100644 --- a/src/entities/user_groups.py +++ b/src/entities/user_groups.py @@ -23,9 +23,9 @@ class UserGroup(BaseEntity): default_factory=list, description="List of signal IDs associated with this group." ) - user_ids: List[int] = Field( + user_ids: List[str | int] = Field( default_factory=list, - description="List of user IDs who are members of this group." + description="List of user IDs (integers) or emails (strings) who are members of this group." ) collaborator_map: Dict[str, List[int]] = Field( default_factory=dict, diff --git a/src/routers/user_groups.py b/src/routers/user_groups.py index d678503..f9aaac7 100644 --- a/src/routers/user_groups.py +++ b/src/routers/user_groups.py @@ -4,7 +4,7 @@ import logging import bugsnag -from typing import Annotated, List, Optional, Union +from typing import Annotated, List, Optional, Union, Dict from fastapi import APIRouter, Depends, Path, Body, Query, HTTPException, Request from psycopg import AsyncCursor @@ -13,8 +13,9 @@ from .. import database as db from .. import exceptions from ..dependencies import require_admin, require_user -from ..entities import UserGroup, User, UserGroupWithSignals +from ..entities import UserGroup, User, UserGroupWithSignals, UserGroupWithUsers, UserGroupComplete, Signal from ..authentication import authenticate_user +from ..database.signals import read_signal # Set up logger for this module logger = logging.getLogger(__name__) @@ -28,6 +29,27 @@ class UserGroupCreate(BaseModel): users: Optional[List[str]] = None +class UserGroupUpdate(BaseModel): + id: int + name: str + signal_ids: List[int] = [] + user_ids: List[Union[str, int]] = [] # Can be either user IDs or email addresses + collaborator_map: Dict[str, List[int]] = {} + created_at: Optional[str] = None + status: Optional[str] = None + created_by: Optional[str] = None + created_for: Optional[str] = None + modified_at: Optional[str] = None + modified_by: Optional[str] = None + headline: Optional[str] = None + attachment: Optional[str] = None + steep_primary: Optional[str] = None + steep_secondary: Optional[str] = None + signature_primary: Optional[str] = None + signature_secondary: Optional[str] = None + sdgs: Optional[List[str]] = None + + class UserEmailIdentifier(BaseModel): email: str @@ -94,7 +116,7 @@ async def list_user_groups( raise -@router.get("/me") +@router.get("/me", response_model=List[UserGroup]) async def get_my_user_groups( request: Request, user: User = Depends(authenticate_user), @@ -110,7 +132,7 @@ async def get_my_user_groups( raise exceptions.not_found # Get groups this user is a member of - user_groups = await db.get_user_groups_with_signals_and_users(cursor, user.id) + user_groups = await db.get_user_groups_with_users_by_user_id(cursor, user.id) logger.info(f"User {user.id} retrieved {len(user_groups)} groups") return user_groups except Exception as e: @@ -132,32 +154,40 @@ async def get_my_user_groups( raise -@router.get("/me/with-signals", response_model=List[UserGroupWithSignals]) +@router.get("/me/with-signals", response_model=List[Union[UserGroupWithSignals, UserGroupComplete]]) async def get_my_user_groups_with_signals( request: Request, user: User = Depends(authenticate_user), cursor: AsyncCursor = Depends(db.yield_cursor), + include_users: bool = Query(True, description="If true, includes detailed user information for each group member (defaults to true)") ): """ Get all user groups that the current user is a member of along with their signals data. - + This enhanced endpoint provides detailed information about each signal associated with the groups, including whether the current user has edit permissions for each signal. This is useful for: - + - Displaying a dashboard of all signals a user can access through their groups - Showing which signals the user can edit vs. view-only - Building collaborative workflows where users can see their assigned signals - + The response includes all signal details plus a `can_edit` flag for each signal indicating if the current user has edit permissions based on the group's collaborator_map. + + By default, detailed user information for each group member is included. Set the `include_users` + parameter to false to get only basic user ID references without detailed user information. """ try: if not user.id: logger.warning("User ID not found in get_my_user_groups_with_signals") raise exceptions.not_found - - # Get groups with signals for this user - user_groups_with_signals = await db.get_user_groups_with_signals(cursor, user.id) + + # Get groups with signals for this user, optionally including full user details + user_groups_with_signals = await db.get_user_groups_with_signals( + cursor, + user.id, + fetch_users=include_users + ) logger.info(f"User {user.id} retrieved {len(user_groups_with_signals)} groups with signals") return user_groups_with_signals except Exception as e: @@ -179,43 +209,49 @@ async def get_my_user_groups_with_signals( raise -@router.post("", response_model=UserGroup, dependencies=[Depends(require_admin)]) +@router.post("", response_model=Union[UserGroup, UserGroupWithUsers], dependencies=[Depends(require_admin)]) async def create_user_group( request: Request, group_data: UserGroupCreate, current_user: User = Depends(authenticate_user), # Get the current user cursor: AsyncCursor = Depends(db.yield_cursor), + include_users: bool = Query(False, description="If true, includes detailed user information for each group member") ): - """Create a new user group.""" + """ + Create a new user group. + + Optionally include detailed user information for each group member in the response + by setting the `include_users` query parameter to true. + """ try: # Create the base group group = UserGroup(name=group_data.name) - + # Initialize user_ids list with the current user's ID user_ids = [] if current_user.id: user_ids.append(current_user.id) - + # Handle email addresses if provided if group_data.users: for email in group_data.users: user = await db.read_user_by_email(cursor, email) if user and user.id and user.id not in user_ids: # Avoid duplicates user_ids.append(user.id) - + if user_ids: group.user_ids = user_ids - + # Create the group group_id = await db.create_user_group(cursor, group) logger.info(f"Created user group {group_id} with name '{group.name}' and {len(user_ids)} users") - + # Retrieve and return the created group - created_group = await db.read_user_group(cursor, group_id) + created_group = await db.read_user_group(cursor, group_id, fetch_details=include_users) if not created_group: logger.error(f"Failed to retrieve newly created group with ID {group_id}") raise exceptions.not_found - + return created_group except Exception as e: logger.error(f"Error creating user group: {str(e)}") @@ -236,18 +272,49 @@ async def create_user_group( raise -@router.get("/{group_id}", response_model=UserGroup, dependencies=[Depends(require_admin)]) +@router.get("/{group_id}", response_model=Union[UserGroup, UserGroupWithUsers, UserGroupComplete], dependencies=[Depends(require_admin)]) async def read_user_group( request: Request, group_id: Annotated[int, Path(description="The ID of the user group to retrieve")], cursor: AsyncCursor = Depends(db.yield_cursor), + include_users: bool = Query(True, description="If true, includes detailed user and signal information (defaults to true)") ): - """Get a user group by ID.""" + """ + Get a user group by ID with detailed information. + + By default, includes detailed user and signal information. Set the `include_users` + parameter to false to get only the basic group data without user and signal details. + """ try: - if (group := await db.read_user_group(cursor, group_id)) is None: + # First, get the basic group + if (group := await db.read_user_group(cursor, group_id, fetch_details=include_users)) is None: logger.warning(f"User group {group_id} not found") raise exceptions.not_found - logger.info(f"Retrieved user group {group_id}") + + # Log basic info to avoid huge logs + logger.info(f"Retrieved user group {group_id} with name '{group.name}'") + + # If include_users is true and the group has signals, fetch those signals too + if include_users and hasattr(group, 'user_ids') and group.user_ids and hasattr(group, 'signal_ids') and group.signal_ids: + # Get signals for this group and prepare a complete response + signals = [] + + # Import the signals database function directly + from ..database.signals import read_signal + + # Fetch each signal individually + for signal_id in group.signal_ids: + signal = await read_signal(cursor, signal_id) + if signal: + signals.append(signal) + + # Convert to a UserGroupComplete if we have both users and signals + if hasattr(group, 'users') and group.users: + return UserGroupComplete( + **group.model_dump(), + signals=signals + ) + return group except Exception as e: if not isinstance(e, HTTPException): # Don't log HTTPExceptions @@ -265,23 +332,81 @@ async def read_user_group( raise -@router.put("/{group_id}", response_model=UserGroup, dependencies=[Depends(require_admin)]) +@router.put("/{group_id}", response_model=Union[UserGroup, UserGroupWithUsers, UserGroupComplete], dependencies=[Depends(require_admin)]) async def update_user_group( request: Request, group_id: Annotated[int, Path(description="The ID of the user group to update")], - group: UserGroup, + group_data: UserGroupUpdate, cursor: AsyncCursor = Depends(db.yield_cursor), + include_users: bool = Query(False, description="If true, includes detailed user information for each group member after update") ): - """Update a user group.""" + """ + Update a user group. + + Optionally include detailed user information for each group member in the response + by setting the `include_users` query parameter to true. + + This endpoint accepts both user IDs (integers) and email addresses (strings) in + the user_ids field. Email addresses will be automatically converted to user IDs. + """ try: - if group_id != group.id: - logger.warning(f"ID mismatch: path ID {group_id} != body ID {group.id}") + if group_id != group_data.id: + logger.warning(f"ID mismatch: path ID {group_id} != body ID {group_data.id}") raise exceptions.id_mismatch + + # Process user_ids field in case it contains emails instead of integer IDs + processed_user_ids = [] + for user_id in group_data.user_ids: + if isinstance(user_id, str): + # Check if it's an email (contains @ sign) + if '@' in user_id: + # This looks like an email address, try to find the user ID + user = await db.read_user_by_email(cursor, user_id) + if user and user.id: + processed_user_ids.append(user.id) + else: + logger.warning(f"User with email {user_id} not found") + raise HTTPException(status_code=404, detail=f"User with email {user_id} not found") + else: + # String but not an email, try to convert to int if it's a digit string + try: + processed_user_ids.append(int(user_id)) + except (ValueError, TypeError): + logger.warning(f"Invalid user ID format: {user_id}") + raise HTTPException(status_code=400, detail=f"Invalid user ID format: {user_id}") + else: + # Already an int + processed_user_ids.append(user_id) + + logger.info(f"Processed user_ids: {processed_user_ids}") + + # Convert UserGroupUpdate to UserGroup + group = UserGroup( + id=group_data.id, + name=group_data.name, + signal_ids=group_data.signal_ids, + user_ids=processed_user_ids, # Use the processed user IDs + collaborator_map=group_data.collaborator_map, + created_at=group_data.created_at, + status=group_data.status, + created_by=group_data.created_by, + created_for=group_data.created_for, + modified_at=group_data.modified_at, + modified_by=group_data.modified_by, + headline=group_data.headline, + attachment=group_data.attachment, + steep_primary=group_data.steep_primary, + steep_secondary=group_data.steep_secondary, + signature_primary=group_data.signature_primary, + signature_secondary=group_data.signature_secondary, + sdgs=group_data.sdgs + ) + if (updated_id := await db.update_user_group(cursor, group)) is None: logger.warning(f"User group {group_id} not found for update") raise exceptions.not_found logger.info(f"Updated user group {updated_id}") - return await db.read_user_group(cursor, updated_id) + return await db.read_user_group(cursor, updated_id, fetch_details=include_users) except Exception as e: if not isinstance(e, HTTPException): # Don't log HTTPExceptions logger.error(f"Error updating user group {group_id}: {str(e)}") @@ -294,10 +419,10 @@ async def update_user_group( }, "group_id": group_id, "group_data": { - "id": group.id, - "name": group.name, - "user_count": len(group.user_ids) if group.user_ids else 0, - "signal_count": len(group.signal_ids) if group.signal_ids else 0 + "id": group_data.id, + "name": group_data.name, + "user_count": len(group_data.user_ids) if group_data.user_ids else 0, + "signal_count": len(group_data.signal_ids) if group_data.signal_ids else 0 } } ) From 75602d8785f7fbcac229ee0f95247c46bbeb31e8 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Wed, 14 May 2025 19:15:41 +0300 Subject: [PATCH 16/31] update user group calls --- src/database/signals.py | 447 +++++++++++++++++++++++++++--------- src/database/user_groups.py | 40 ++-- src/entities/signal.py | 115 +++++++++- src/entities/user_groups.py | 16 +- src/routers/signals.py | 170 ++++++++++++-- src/routers/user_groups.py | 33 ++- 6 files changed, 671 insertions(+), 150 deletions(-) diff --git a/src/database/signals.py b/src/database/signals.py index efbab5f..56d51fe 100644 --- a/src/database/signals.py +++ b/src/database/signals.py @@ -2,15 +2,20 @@ CRUD operations for signal entities. """ +import logging +from typing import List from psycopg import AsyncCursor, sql from .. import storage -from ..entities import Signal, SignalFilters, SignalPage, Status +from ..entities import Signal, SignalFilters, SignalPage, Status, SignalWithUserGroups, UserGroup + +logger = logging.getLogger(__name__) __all__ = [ "search_signals", "create_signal", "read_signal", + "read_signal_with_user_groups", "update_signal", "delete_signal", "read_user_signals", @@ -102,10 +107,10 @@ async def search_signals(cursor: AsyncCursor, filters: SignalFilters) -> SignalP return page -async def create_signal(cursor: AsyncCursor, signal: Signal) -> int: +async def create_signal(cursor: AsyncCursor, signal: Signal, user_group_ids: List[int] = None) -> int: """ - Insert a signal into the database, connect it to trends and upload an attachment - to Azure Blob Storage if applicable. + Insert a signal into the database, connect it to trends, upload an attachment + to Azure Blob Storage if applicable, and add it to user groups if specified. Parameters ---------- @@ -114,77 +119,131 @@ async def create_signal(cursor: AsyncCursor, signal: Signal) -> int: signal : Signal A signal object to insert. The following fields are supported: - secondary_location: list[str] | None + user_group_ids : List[int], optional + List of user group IDs to add the signal to. Returns ------- signal_id : int An ID of the signal in the database. """ - query = """ - INSERT INTO signals ( - status, - created_by, - created_for, - modified_by, - headline, - description, - steep_primary, - steep_secondary, - signature_primary, - signature_secondary, - sdgs, - created_unit, - url, - relevance, - keywords, - location, - secondary_location, - score - ) - VALUES ( - %(status)s, - %(created_by)s, - %(created_for)s, - %(modified_by)s, - %(headline)s, - %(description)s, - %(steep_primary)s, - %(steep_secondary)s, - %(signature_primary)s, - %(signature_secondary)s, - %(sdgs)s, - %(created_unit)s, - %(url)s, - %(relevance)s, - %(keywords)s, - %(location)s, - %(secondary_location)s, - %(score)s - ) - RETURNING - id - ; - """ - await cursor.execute(query, signal.model_dump()) - row = await cursor.fetchone() - signal_id = row["id"] - - # add connected trends if any are present - for trend_id in signal.connected_trends or []: - query = "INSERT INTO connections (signal_id, trend_id, created_by) VALUES (%s, %s, %s);" - await cursor.execute(query, (signal_id, trend_id, signal.created_by)) - - # upload an image + logger.info(f"Creating new signal with headline: '{signal.headline}', created by: {signal.created_by}") + if user_group_ids: + logger.info(f"Will add signal to user groups: {user_group_ids}") + + # Insert signal into database + try: + query = """ + INSERT INTO signals ( + status, + created_by, + created_for, + modified_by, + headline, + description, + steep_primary, + steep_secondary, + signature_primary, + signature_secondary, + sdgs, + created_unit, + url, + relevance, + keywords, + location, + secondary_location, + score + ) + VALUES ( + %(status)s, + %(created_by)s, + %(created_for)s, + %(modified_by)s, + %(headline)s, + %(description)s, + %(steep_primary)s, + %(steep_secondary)s, + %(signature_primary)s, + %(signature_secondary)s, + %(sdgs)s, + %(created_unit)s, + %(url)s, + %(relevance)s, + %(keywords)s, + %(location)s, + %(secondary_location)s, + %(score)s + ) + RETURNING + id + ; + """ + await cursor.execute(query, signal.model_dump()) + row = await cursor.fetchone() + signal_id = row["id"] + logger.info(f"Signal created successfully with ID: {signal_id}") + except Exception as e: + logger.error(f"Failed to create signal: {e}") + raise + + # Add connected trends if any are present + try: + if signal.connected_trends: + logger.info(f"Adding connected trends for signal {signal_id}: {signal.connected_trends}") + for trend_id in signal.connected_trends: + query = "INSERT INTO connections (signal_id, trend_id, created_by) VALUES (%s, %s, %s);" + await cursor.execute(query, (signal_id, trend_id, signal.created_by)) + logger.info(f"Successfully added {len(signal.connected_trends)} trends to signal {signal_id}") + except Exception as e: + logger.error(f"Error adding connected trends to signal {signal_id}: {e}") + # Continue execution despite error with trends + + # Upload an image if provided if signal.attachment is not None: + logger.info(f"Uploading image attachment for signal {signal_id}") try: blob_url = await storage.upload_image( signal_id, "signals", signal.attachment ) - except Exception as e: - print(e) - else: query = "UPDATE signals SET attachment = %s WHERE id = %s;" await cursor.execute(query, (blob_url, signal_id)) + logger.info(f"Image attachment uploaded successfully for signal {signal_id}") + except Exception as e: + logger.error(f"Failed to upload image for signal {signal_id}: {e}") + # Continue execution despite attachment error + + # Add signal to user groups if specified + if user_group_ids: + logger.info(f"Processing user group assignments for signal {signal_id}") + from . import user_groups + groups_added = 0 + groups_failed = 0 + + for group_id in user_group_ids: + try: + logger.debug(f"Attempting to add signal {signal_id} to group {group_id}") + # Get the group + group = await user_groups.read_user_group(cursor, group_id) + if group is not None: + # Add signal to group's signal_ids + signal_ids = group.signal_ids or [] + if signal_id not in signal_ids: + signal_ids.append(signal_id) + group.signal_ids = signal_ids + await user_groups.update_user_group(cursor, group) + groups_added += 1 + logger.info(f"Signal {signal_id} added to group {group_id} ({group.name})") + else: + logger.info(f"Signal {signal_id} already in group {group_id} ({group.name})") + else: + logger.warning(f"Group with ID {group_id} not found, skipping") + groups_failed += 1 + except Exception as e: + logger.error(f"Error adding signal {signal_id} to group {group_id}: {e}") + groups_failed += 1 + + logger.info(f"User group assignment complete for signal {signal_id}: {groups_added} successful, {groups_failed} failed") + return signal_id @@ -230,10 +289,85 @@ async def read_signal(cursor: AsyncCursor, uid: int) -> Signal | None: return Signal(**row) -async def update_signal(cursor: AsyncCursor, signal: Signal) -> int | None: +async def read_signal_with_user_groups(cursor: AsyncCursor, uid: int) -> SignalWithUserGroups | None: + """ + Read a signal from the database with its associated user groups. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + uid : int + An ID of the signal to retrieve data for. + + Returns + ------- + SignalWithUserGroups | None + A signal with its user groups if it exists, otherwise None. + """ + logger.info(f"Fetching signal {uid} with its user groups") + + try: + # First, get the signal + signal = await read_signal(cursor, uid) + if signal is None: + logger.warning(f"Signal with ID {uid} not found") + return None + + logger.info(f"Found signal with ID {uid}: '{signal.headline}'") + + # Convert to SignalWithUserGroups + signal_with_groups = SignalWithUserGroups(**signal.model_dump()) + + # Get all groups that have this signal in their signal_ids + query = """ + SELECT + id, + name, + signal_ids, + user_ids, + admin_ids, + collaborator_map + FROM + user_groups + WHERE + %s = ANY(signal_ids) + ORDER BY + name; + """ + + await cursor.execute(query, (uid,)) + + # Add groups to the signal + from . import user_groups as ug_module + signal_with_groups.user_groups = [] + group_count = 0 + + async for row in cursor: + try: + # Convert row to dict + group_data = ug_module.handle_user_group_row(row) + # Create UserGroup from dict + group = UserGroup(**group_data) + signal_with_groups.user_groups.append(group) + group_count += 1 + logger.debug(f"Added group {group.id} ({group.name}) to signal {uid}") + except Exception as e: + logger.error(f"Error processing group for signal {uid}: {e}") + + logger.info(f"Signal {uid} is associated with {group_count} user groups") + return signal_with_groups + + except Exception as e: + logger.error(f"Error retrieving signal {uid} with user groups: {e}") + raise + + +async def update_signal(cursor: AsyncCursor, signal: Signal, user_group_ids: List[int] = None) -> int | None: """ Update a signal in the database, update its connected trends and update an attachment - in the Azure Blob Storage if applicable. + in the Azure Blob Storage if applicable. Optionally update the user groups the signal + belongs to. Parameters ---------- @@ -242,56 +376,153 @@ async def update_signal(cursor: AsyncCursor, signal: Signal) -> int | None: signal : Signal A signal object to update. The following fields are supported: - secondary_location: list[str] | None + user_group_ids : List[int], optional + List of user group IDs to add the signal to. Returns ------- int | None A signal ID if the update has been performed, otherwise None. """ - query = """ - UPDATE - signals - SET - status = COALESCE(%(status)s, status), - created_for = COALESCE(%(created_for)s, created_for), - modified_at = NOW(), - modified_by = %(modified_by)s, - headline = COALESCE(%(headline)s, headline), - description = COALESCE(%(description)s, description), - steep_primary = COALESCE(%(steep_primary)s, steep_primary), - steep_secondary = COALESCE(%(steep_secondary)s, steep_secondary), - signature_primary = COALESCE(%(signature_primary)s, signature_primary), - signature_secondary = COALESCE(%(signature_secondary)s, signature_secondary), - sdgs = COALESCE(%(sdgs)s, sdgs), - created_unit = COALESCE(%(created_unit)s, created_unit), - url = COALESCE(%(url)s, url), - relevance = COALESCE(%(relevance)s, relevance), - keywords = COALESCE(%(keywords)s, keywords), - location = COALESCE(%(location)s, location), - secondary_location = COALESCE(%(secondary_location)s, secondary_location), - score = COALESCE(%(score)s, score) - WHERE - id = %(id)s - RETURNING - id - ; - """ - await cursor.execute(query, signal.model_dump()) - row = await cursor.fetchone() - if row is None: - return None - signal_id = row["id"] - - # update connected trends if any are present - await cursor.execute("DELETE FROM connections WHERE signal_id = %s;", (signal_id,)) - for trend_id in signal.connected_trends or []: - query = "INSERT INTO connections (signal_id, trend_id, created_by) VALUES (%s, %s, %s);" - await cursor.execute(query, (signal_id, trend_id, signal.created_by)) - - # upload an image if it is not a URL to an existing image - blob_url = await storage.update_image(signal_id, "signals", signal.attachment) - query = "UPDATE signals SET attachment = %s WHERE id = %s;" - await cursor.execute(query, (blob_url, signal_id)) + logger.info(f"Updating signal with ID: {signal.id}, modified by: {signal.modified_by}") + if user_group_ids is not None: + logger.info(f"Will update user groups for signal {signal.id}: {user_group_ids}") + + # Update signal in database + try: + query = """ + UPDATE + signals + SET + status = COALESCE(%(status)s, status), + created_for = COALESCE(%(created_for)s, created_for), + modified_at = NOW(), + modified_by = %(modified_by)s, + headline = COALESCE(%(headline)s, headline), + description = COALESCE(%(description)s, description), + steep_primary = COALESCE(%(steep_primary)s, steep_primary), + steep_secondary = COALESCE(%(steep_secondary)s, steep_secondary), + signature_primary = COALESCE(%(signature_primary)s, signature_primary), + signature_secondary = COALESCE(%(signature_secondary)s, signature_secondary), + sdgs = COALESCE(%(sdgs)s, sdgs), + created_unit = COALESCE(%(created_unit)s, created_unit), + url = COALESCE(%(url)s, url), + relevance = COALESCE(%(relevance)s, relevance), + keywords = COALESCE(%(keywords)s, keywords), + location = COALESCE(%(location)s, location), + secondary_location = COALESCE(%(secondary_location)s, secondary_location), + score = COALESCE(%(score)s, score) + WHERE + id = %(id)s + RETURNING + id + ; + """ + await cursor.execute(query, signal.model_dump()) + row = await cursor.fetchone() + if row is None: + logger.warning(f"Signal with ID {signal.id} not found for update") + return None + signal_id = row["id"] + logger.info(f"Signal {signal_id} updated successfully in database") + except Exception as e: + logger.error(f"Failed to update signal {signal.id}: {e}") + raise + + # Update connected trends + try: + logger.info(f"Updating connected trends for signal {signal_id}") + await cursor.execute("DELETE FROM connections WHERE signal_id = %s;", (signal_id,)) + logger.debug(f"Removed existing trend connections for signal {signal_id}") + + if signal.connected_trends: + trends_added = 0 + for trend_id in signal.connected_trends: + try: + query = "INSERT INTO connections (signal_id, trend_id, created_by) VALUES (%s, %s, %s);" + await cursor.execute(query, (signal_id, trend_id, signal.created_by)) + trends_added += 1 + except Exception as trend_e: + logger.warning(f"Failed to connect trend {trend_id} to signal {signal_id}: {trend_e}") + + logger.info(f"Added {trends_added} trend connections to signal {signal_id}") + except Exception as e: + logger.error(f"Error updating trend connections for signal {signal_id}: {e}") + # Continue execution despite trend connection errors + + # Update image attachment + try: + logger.info(f"Updating image attachment for signal {signal_id}") + blob_url = await storage.update_image(signal_id, "signals", signal.attachment) + if blob_url is not None: + query = "UPDATE signals SET attachment = %s WHERE id = %s;" + await cursor.execute(query, (blob_url, signal_id)) + logger.info(f"Image attachment updated successfully for signal {signal_id}") + else: + logger.info(f"No image attachment update needed for signal {signal_id}") + except Exception as e: + logger.error(f"Failed to update image for signal {signal_id}: {e}") + # Continue execution despite attachment error + + # Update signal's user groups if specified + if user_group_ids is not None: + logger.info(f"Processing user group updates for signal {signal_id}") + try: + from . import user_groups + + # Get all groups that currently have this signal + query = """ + SELECT id, name + FROM user_groups + WHERE %s = ANY(signal_ids); + """ + await cursor.execute(query, (signal_id,)) + current_groups = {} + async for row in cursor: + current_groups[row["id"]] = row["name"] + + logger.info(f"Signal {signal_id} is currently in groups: {list(current_groups.keys())}") + + # Remove signal from groups not in user_group_ids + groups_removed = 0 + groups_to_remove_from = [g for g in current_groups.keys() if g not in user_group_ids] + for group_id in groups_to_remove_from: + try: + logger.debug(f"Removing signal {signal_id} from group {group_id} ({current_groups[group_id]})") + group = await user_groups.read_user_group(cursor, group_id) + if group is not None and signal_id in group.signal_ids: + signal_ids = group.signal_ids.copy() + signal_ids.remove(signal_id) + group.signal_ids = signal_ids + await user_groups.update_user_group(cursor, group) + groups_removed += 1 + logger.info(f"Signal {signal_id} removed from group {group_id} ({group.name})") + except Exception as e: + logger.error(f"Failed to remove signal {signal_id} from group {group_id}: {e}") + + # Add signal to new groups + groups_added = 0 + for group_id in user_group_ids: + if group_id not in current_groups: + try: + logger.debug(f"Adding signal {signal_id} to group {group_id}") + group = await user_groups.read_user_group(cursor, group_id) + if group is not None: + signal_ids = group.signal_ids or [] + if signal_id not in signal_ids: + signal_ids.append(signal_id) + group.signal_ids = signal_ids + await user_groups.update_user_group(cursor, group) + groups_added += 1 + logger.info(f"Signal {signal_id} added to group {group_id} ({group.name})") + else: + logger.warning(f"Group with ID {group_id} not found, skipping") + except Exception as e: + logger.error(f"Failed to add signal {signal_id} to group {group_id}: {e}") + + logger.info(f"User group assignments updated for signal {signal_id}: {groups_added} added, {groups_removed} removed") + except Exception as e: + logger.error(f"Error processing user group updates for signal {signal_id}: {e}") return signal_id diff --git a/src/database/user_groups.py b/src/database/user_groups.py index fc9811b..50cf118 100644 --- a/src/database/user_groups.py +++ b/src/database/user_groups.py @@ -36,6 +36,7 @@ name, signal_ids, user_ids, + admin_ids, collaborator_map FROM user_groups @@ -94,13 +95,15 @@ def handle_user_group_row(row) -> dict: data['name'] = row["name"] data['signal_ids'] = row["signal_ids"] or [] data['user_ids'] = row["user_ids"] or [] + data['admin_ids'] = row["admin_ids"] or [] collab_map = row["collaborator_map"] else: data['id'] = row[0] data['name'] = row[1] data['signal_ids'] = row[2] or [] data['user_ids'] = row[3] or [] - collab_map = row[4] + data['admin_ids'] = row[4] or [] + collab_map = row[5] # Handle collaborator_map field data['collaborator_map'] = {} @@ -157,26 +160,31 @@ async def create_user_group(cursor: AsyncCursor, group: UserGroup) -> int: The ID of the created user group. """ # Convert model to dict and ensure collaborator_map is a JSON string - group_data = group.model_dump(exclude={"id"}) - group_data["collaborator_map"] = json.dumps(group_data["collaborator_map"]) + name = group.name + signal_ids = group.signal_ids + user_ids = group.user_ids + admin_ids = group.admin_ids + collaborator_map = json.dumps(group.collaborator_map) query = """ INSERT INTO user_groups ( name, signal_ids, user_ids, + admin_ids, collaborator_map ) VALUES ( - %(name)s, - %(signal_ids)s, - %(user_ids)s, - %(collaborator_map)s + %s, + %s, + %s, + %s, + %s ) RETURNING id ; """ - await cursor.execute(query, group_data) + await cursor.execute(query, (name, signal_ids, user_ids, admin_ids, collaborator_map)) row = await cursor.fetchone() if row is None: raise ValueError("Failed to create user group") @@ -250,6 +258,7 @@ async def update_user_group(cursor: AsyncCursor, group: UserGroup) -> int | None name = %(name)s, signal_ids = %(signal_ids)s, user_ids = %(user_ids)s, + admin_ids = %(admin_ids)s, collaborator_map = %(collaborator_map)s WHERE id = %(id)s RETURNING id @@ -428,7 +437,7 @@ async def remove_user_from_group(cursor: AsyncCursor, group_id: int, user_id: in async def get_user_groups(cursor: AsyncCursor, user_id: int) -> list[UserGroup]: """ - Get all groups that a user is a member of. + Get all groups that a user is a member of or an admin of. Parameters ---------- @@ -442,8 +451,8 @@ async def get_user_groups(cursor: AsyncCursor, user_id: int) -> list[UserGroup]: list[UserGroup] A list of user groups. """ - query = f"{SQL_SELECT_USER_GROUP} WHERE %s = ANY(user_ids) ORDER BY name;" - await cursor.execute(query, (user_id,)) + query = f"{SQL_SELECT_USER_GROUP} WHERE %s = ANY(user_ids) OR %s = ANY(admin_ids) ORDER BY name;" + await cursor.execute(query, (user_id, user_id)) result = [] async for row in cursor: @@ -735,7 +744,8 @@ async def get_user_groups_with_signals_and_users(cursor: AsyncCursor, user_id: i async def get_user_groups_with_users_by_user_id(cursor: AsyncCursor, user_id: int) -> list[UserGroupWithUsers]: """ - Get all groups that a user is a member of, along with detailed user information for each group member. + Get all groups that a user is a member of or an admin of, along with detailed user information + for each group member. This is a more focused version that only fetches user data, not signals. Parameters @@ -752,9 +762,9 @@ async def get_user_groups_with_users_by_user_id(cursor: AsyncCursor, user_id: in """ logger.debug("Getting user groups with users for user_id: %s", user_id) - # First get the groups the user belongs to - query = f"{SQL_SELECT_USER_GROUP} WHERE %s = ANY(user_ids) ORDER BY name;" - await cursor.execute(query, (user_id,)) + # Get groups where the user is a member or an admin + query = f"{SQL_SELECT_USER_GROUP} WHERE %s = ANY(user_ids) OR %s = ANY(admin_ids) ORDER BY name;" + await cursor.execute(query, (user_id, user_id)) result = [] async for row in cursor: diff --git a/src/entities/signal.py b/src/entities/signal.py index 86360c3..7658be2 100644 --- a/src/entities/signal.py +++ b/src/entities/signal.py @@ -2,13 +2,17 @@ Entity (model) definitions for signal objects. """ -from typing import List, Dict +from typing import List, Dict, TYPE_CHECKING, Any, Optional from pydantic import ConfigDict, Field, field_validator, model_validator from . import utils from .base import BaseEntity -__all__ = ["Signal"] +# Import only for type checking to avoid circular imports +if TYPE_CHECKING: + from .user_groups import UserGroup + +__all__ = ["Signal", "SignalWithUserGroups", "SignalCreate", "SignalUpdate"] class Signal(BaseEntity): @@ -79,3 +83,110 @@ def convert_secondary_location(cls, data): } } ) + + +class SignalWithUserGroups(Signal): + """ + Extended signal entity that includes the user groups it belongs to. + This model is used in API responses to provide a signal with its associated user groups. + """ + + user_groups: List[Any] = Field( + default_factory=list, + description="List of user groups this signal belongs to." + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "id": 1, + "created_unit": "HQ", + "url": "https://undp.medium.com/the-cost-of-corruption-a827306696fb", + "relevance": "Of the approximately US$13 trillion that governments spend on public spending, up to 25 percent is lost to corruption.", + "keywords": ["economy", "governance"], + "location": "Global", + "favorite": False, + "is_draft": True, + "group_ids": [1, 2], + "collaborators": [1, 2, 3], + "secondary_location": ["Africa", "Asia"], + "score": None, + "connected_trends": [101, 102], + "user_groups": [ + { + "id": 1, + "name": "Research Team", + "signal_ids": [1, 2, 3], + "user_ids": [101, 102], + "admin_ids": [101], + "collaborator_map": {"1": [101, 102]} + }, + { + "id": 2, + "name": "Policy Team", + "signal_ids": [1, 4, 5], + "user_ids": [103, 104], + "admin_ids": [103], + "collaborator_map": {"1": [103]} + } + ] + } + } + ) + + +class SignalCreate(Signal): + """ + Model for signal creation request that includes user_group_ids. + This is used for the request body in the POST endpoint. + """ + + user_group_ids: Optional[List[int]] = Field( + default=None, + description="IDs of user groups to add the signal to after creation" + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "headline": "New Signal Example", + "description": "This is a new signal with user groups.", + "steep_primary": "T", + "steep_secondary": ["S", "P"], + "signature_primary": "Shift", + "signature_secondary": ["Risk"], + "keywords": ["example", "test"], + "location": "Global", + "user_group_ids": [1, 2] + } + } + ) + + +class SignalUpdate(Signal): + """ + Model for signal update request that includes user_group_ids. + This is used for the request body in the PUT endpoint. + """ + + user_group_ids: Optional[List[int]] = Field( + default=None, + description="IDs of user groups to replace the signal's current group associations" + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "id": 1, + "headline": "Updated Signal Example", + "description": "This is an updated signal with new user groups.", + "steep_primary": "T", + "steep_secondary": ["S", "P"], + "signature_primary": "Shift", + "signature_secondary": ["Risk"], + "keywords": ["updated", "test"], + "location": "Global", + "user_group_ids": [2, 3] + } + } + ) diff --git a/src/entities/user_groups.py b/src/entities/user_groups.py index 558dfce..95135a6 100644 --- a/src/entities/user_groups.py +++ b/src/entities/user_groups.py @@ -2,13 +2,16 @@ Entity (model) definitions for user group objects. """ -from typing import Dict, List +from typing import Dict, List, TYPE_CHECKING, Any from pydantic import ConfigDict, Field from .base import BaseEntity -from .signal import Signal from .user import User +# Import only for type checking to avoid circular imports +if TYPE_CHECKING: + from .signal import Signal + __all__ = ["UserGroup", "UserGroupWithSignals", "UserGroupWithUsers", "UserGroupComplete"] @@ -27,6 +30,10 @@ class UserGroup(BaseEntity): default_factory=list, description="List of user IDs (integers) or emails (strings) who are members of this group." ) + admin_ids: List[int] = Field( + default_factory=list, + description="List of user IDs who have admin privileges for this group." + ) collaborator_map: Dict[str, List[int]] = Field( default_factory=dict, description="Map of signal IDs to lists of user IDs that can collaborate on that signal." @@ -39,6 +46,7 @@ class UserGroup(BaseEntity): "name": "CDO", "signal_ids": [1, 2, 3], "user_ids": [1, 2, 3], + "admin_ids": [1], "collaborator_map": { "1": [1, 2], "2": [1, 3], @@ -52,7 +60,7 @@ class UserGroup(BaseEntity): class UserGroupWithSignals(UserGroup): """User group with associated signals data.""" - signals: List[Signal] = Field( + signals: List[Any] = Field( default_factory=list, description="List of signals associated with this group." ) @@ -117,7 +125,7 @@ class UserGroupWithUsers(UserGroup): class UserGroupComplete(UserGroup): """User group with both associated signals and users data.""" - signals: List[Signal] = Field( + signals: List[Any] = Field( default_factory=list, description="List of signals associated with this group." ) diff --git a/src/routers/signals.py b/src/routers/signals.py index 4bee7e5..5d87f4a 100644 --- a/src/routers/signals.py +++ b/src/routers/signals.py @@ -13,7 +13,10 @@ from .. import exceptions, genai, utils from ..authentication import authenticate_user from ..dependencies import require_admin, require_creator, require_curator, require_user -from ..entities import Role, Signal, SignalFilters, SignalPage, Status, User, UserGroup +from ..entities import ( + Role, Signal, SignalFilters, SignalPage, Status, User, UserGroup, + SignalWithUserGroups, SignalCreate, SignalUpdate +) logger = logging.getLogger(__name__) @@ -82,19 +85,54 @@ async def generate_signal( @router.post("", response_model=Signal, status_code=201) async def create_signal( - signal: Signal, + signal_data: SignalCreate, user: User = Depends(require_user), cursor: AsyncCursor = Depends(db.yield_cursor), ): """ Submit a signal to the database. If the signal has a base64 encoded image attachment, it will be uploaded to Azure Blob Storage. + + Optionally, the signal can be added to one or more user groups by specifying + user_group_ids in the request body. """ - signal.created_by = user.email - signal.modified_by = user.email - signal.created_unit = user.unit - signal_id = await db.create_signal(cursor, signal) - return await db.read_signal(cursor, signal_id) + logger.info(f"Creating new signal requested by user: {user.email}") + + # Extract standard Signal fields and user_group_ids + signal = Signal(**signal_data.model_dump(exclude={"user_group_ids"})) + user_group_ids = signal_data.user_group_ids + + if user_group_ids: + logger.info(f"With user_group_ids: {user_group_ids}") + + try: + # Prepare the signal object + signal.created_by = user.email + signal.modified_by = user.email + signal.created_unit = user.unit + + # Create the signal in the database with user groups if specified + signal_id = await db.create_signal(cursor, signal, user_group_ids) + logger.info(f"Signal created successfully with ID: {signal_id}") + + # Read back the created signal to return it + created_signal = await db.read_signal(cursor, signal_id) + if not created_signal: + logger.error(f"Failed to read newly created signal with ID: {signal_id}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Signal was created but could not be retrieved" + ) + + return created_signal + + except Exception as e: + logger.error(f"Error creating signal: {str(e)}") + # Raise HTTPException with appropriate status code + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to create signal: {str(e)}" + ) @router.get("/me", response_model=list[Signal]) @@ -145,20 +183,122 @@ async def read_signal( return signal +@router.get("/{uid}/with-user-groups", response_model=SignalWithUserGroups) +async def read_signal_with_user_groups( + uid: Annotated[int, Path(description="The ID of the signal to retrieve")], + user: User = Depends(authenticate_user), + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """ + Retrieve a signal from the database along with all user groups it belongs to. + This endpoint provides a comprehensive view of a signal including its group associations. + """ + logger.info("Reading signal with user groups for ID: %s, user: %s", uid, user.email) + + try: + # Fetch signal with user groups + signal = await db.read_signal_with_user_groups(cursor, uid) + + if signal is None: + logger.warning("Signal not found with ID: %s", uid) + raise exceptions.not_found + + # Check permissions + if user.role == Role.VISITOR and signal.status != Status.APPROVED: + logger.warning( + "Permission denied - visitor trying to access non-approved signal. Status: %s", + signal.status + ) + raise exceptions.permission_denied + + # Check if the signal is favorited by the user + try: + is_favorite = await db.is_signal_favorited(cursor, user.email, uid) + signal.favorite = is_favorite + logger.debug(f"Favorite status for signal {uid}, user {user.email}: {is_favorite}") + except Exception as fav_e: + logger.error(f"Failed to check favorite status: {str(fav_e)}") + # Continue even if favorite check fails + signal.favorite = False + + logger.info(f"Successfully retrieved signal {uid} with {len(signal.user_groups)} user groups") + return signal + + except exceptions.not_found: + raise + except exceptions.permission_denied: + raise + except Exception as e: + logger.error(f"Error retrieving signal {uid} with user groups: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to retrieve signal with user groups: {str(e)}" + ) + + @router.put("/{uid}", response_model=Signal) async def update_signal( uid: Annotated[int, Path(description="The ID of the signal to be updated")], - signal: Signal, + signal_data: SignalUpdate, user: User = Depends(require_creator), cursor: AsyncCursor = Depends(db.yield_cursor), ): - """Update a signal in the database.""" - if uid != signal.id: - raise exceptions.id_mismatch - signal.modified_by = user.email - if (signal_id := await db.update_signal(cursor, signal)) is None: - raise exceptions.not_found - return await db.read_signal(cursor, signal_id) + """ + Update a signal in the database. + + Optionally, the signal's user group associations can be updated by specifying + user_group_ids in the request body. If provided, the signal will only belong + to the specified groups, replacing any previous group associations. + """ + logger.info(f"Updating signal {uid} requested by user: {user.email}") + + # Extract signal data and user_group_ids from the request body + signal = Signal(**signal_data.model_dump(exclude={"user_group_ids"})) + user_group_ids = signal_data.user_group_ids + + if user_group_ids is not None: + logger.info(f"With user_group_ids: {user_group_ids}") + + try: + # Verify ID match + if uid != signal.id: + logger.warning(f"ID mismatch: URL ID {uid} doesn't match payload ID {signal.id}") + raise exceptions.id_mismatch + + # Update metadata + signal.modified_by = user.email + + # Update the signal in the database + logger.info(f"Updating signal {uid} in database") + signal_id = await db.update_signal(cursor, signal, user_group_ids) + + if signal_id is None: + logger.warning(f"Signal {uid} not found during update") + raise exceptions.not_found + + logger.info(f"Signal {uid} updated successfully") + + # Read back the updated signal + updated_signal = await db.read_signal(cursor, signal_id) + if not updated_signal: + logger.error(f"Failed to read updated signal with ID: {signal_id}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Signal was updated but could not be retrieved" + ) + + return updated_signal + + except exceptions.id_mismatch: + raise + except exceptions.not_found: + raise + except Exception as e: + logger.error(f"Error updating signal {uid}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to update signal: {str(e)}" + ) @router.delete("/{uid}", response_model=Signal, dependencies=[Depends(require_creator)]) diff --git a/src/routers/user_groups.py b/src/routers/user_groups.py index f9aaac7..9998e31 100644 --- a/src/routers/user_groups.py +++ b/src/routers/user_groups.py @@ -123,7 +123,7 @@ async def get_my_user_groups( cursor: AsyncCursor = Depends(db.yield_cursor), ): """ - Get all user groups that the current user is a member of. + Get all user groups that the current user is a member of or an admin of. This endpoint is accessible to all authenticated users. """ try: @@ -131,9 +131,9 @@ async def get_my_user_groups( logger.warning("User ID not found in get_my_user_groups") raise exceptions.not_found - # Get groups this user is a member of + # Get groups where this user is a member or an admin user_groups = await db.get_user_groups_with_users_by_user_id(cursor, user.id) - logger.info(f"User {user.id} retrieved {len(user_groups)} groups") + logger.info(f"User {user.id} retrieved {len(user_groups)} groups (as member or admin)") return user_groups except Exception as e: if not isinstance(e, HTTPException): # Don't log HTTPExceptions @@ -162,7 +162,7 @@ async def get_my_user_groups_with_signals( include_users: bool = Query(True, description="If true, includes detailed user information for each group member (defaults to true)") ): """ - Get all user groups that the current user is a member of along with their signals data. + Get all user groups that the current user is a member of or an admin of along with their signals data. This enhanced endpoint provides detailed information about each signal associated with the groups, including whether the current user has edit permissions for each signal. This is useful for: @@ -215,13 +215,17 @@ async def create_user_group( group_data: UserGroupCreate, current_user: User = Depends(authenticate_user), # Get the current user cursor: AsyncCursor = Depends(db.yield_cursor), - include_users: bool = Query(False, description="If true, includes detailed user information for each group member") + include_users: bool = Query(False, description="If true, includes detailed user information for each group member"), + admins: List[str] = Query(None, description="List of user emails to set as admins in the group") ): """ Create a new user group. Optionally include detailed user information for each group member in the response by setting the `include_users` query parameter to true. + + The current authenticated user is automatically added as both a member and an admin of the group. + Additional admin users can be specified using the `admins` query parameter. """ try: # Create the base group @@ -229,8 +233,11 @@ async def create_user_group( # Initialize user_ids list with the current user's ID user_ids = [] + admin_ids = [] + if current_user.id: user_ids.append(current_user.id) + admin_ids.append(current_user.id) # Make current user an admin # Handle email addresses if provided if group_data.users: @@ -239,12 +246,25 @@ async def create_user_group( if user and user.id and user.id not in user_ids: # Avoid duplicates user_ids.append(user.id) + # Handle admin emails if provided + if admins: + for email in admins: + user = await db.read_user_by_email(cursor, email) + if user and user.id: + if user.id not in user_ids: # If not already in user_ids, add them + user_ids.append(user.id) + if user.id not in admin_ids: # Avoid duplicates in admin_ids + admin_ids.append(user.id) + if user_ids: group.user_ids = user_ids + + if admin_ids: + group.admin_ids = admin_ids # Create the group group_id = await db.create_user_group(cursor, group) - logger.info(f"Created user group {group_id} with name '{group.name}' and {len(user_ids)} users") + logger.info(f"Created user group {group_id} with name '{group.name}', {len(user_ids)} users, and {len(admin_ids)} admins") # Retrieve and return the created group created_group = await db.read_user_group(cursor, group_id, fetch_details=include_users) @@ -265,6 +285,7 @@ async def create_user_group( "group_data": { "name": group_data.name, "users_count": len(group_data.users) if group_data.users else 0, + "admins_count": len(admins) if admins else 0, "current_user_id": current_user.id if current_user else None } } From 5203c91cafae19ce222e60e71a6b486208f75c5d Mon Sep 17 00:00:00 2001 From: happy-devs Date: Wed, 14 May 2025 19:23:04 +0300 Subject: [PATCH 17/31] Update .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index db08835..ade568a 100644 --- a/.gitignore +++ b/.gitignore @@ -146,3 +146,5 @@ create_test_user.sql Taskfile.yml .env.local /.logs +/logs +webapp_logs.zip From 9557a90826565261a26a232681ce4d3c2e510522 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Thu, 22 May 2025 00:25:55 +0300 Subject: [PATCH 18/31] update user groups issue --- .gitignore | 1 + docs/user_groups_issue.md | 189 +++++++++++++++++++ src/authentication.py | 8 +- src/database/signals.py | 25 ++- src/database/user_groups.py | 278 +++++++++++++++++++++++++--- src/database/user_groups_direct.py | 265 +++++++++++++++++++++++++++ src/entities/parameters.py | 4 + src/entities/signal.py | 18 +- src/routers/signals.py | 78 +++++++- src/routers/user_groups.py | 282 +++++++++++++++++++++++++++-- 10 files changed, 1090 insertions(+), 58 deletions(-) create mode 100644 docs/user_groups_issue.md create mode 100644 src/database/user_groups_direct.py diff --git a/.gitignore b/.gitignore index ade568a..8e2eb1c 100644 --- a/.gitignore +++ b/.gitignore @@ -148,3 +148,4 @@ Taskfile.yml /.logs /logs webapp_logs.zip +/.schemas diff --git a/docs/user_groups_issue.md b/docs/user_groups_issue.md new file mode 100644 index 0000000..b71176c --- /dev/null +++ b/docs/user_groups_issue.md @@ -0,0 +1,189 @@ +# User Groups Discrepancy Issue + +## Summary + +We've identified a discrepancy between the database state and API responses when retrieving user groups. When a user requests their groups, they're not receiving all groups where they are members. Specifically, when user ID `1062` requests their groups, they should receive 3 groups (IDs 22, 25, and 28), but only receive 2 groups (IDs 25 and 28). + +## Investigation Details + +### Debug Logs + +The logs show only two groups being returned: + +``` +2025-05-21 21:38:17,604 - src.routers.user_groups - DEBUG - Fetching groups for user 1062... +2025-05-21 21:38:17,604 - src.database.user_groups - DEBUG - Getting user groups with users for user_id: 1062 +2025-05-21 21:38:17,604 - src.database.user_groups - DEBUG - Fetching groups where user 1062 is a member... +2025-05-21 21:38:18,656 - src.database.user_groups - DEBUG - Found 1 groups where user 1062 is a member: [28] +2025-05-21 21:38:18,656 - src.database.user_groups - DEBUG - Fetching groups where user 1062 is an admin... +2025-05-21 21:38:19,606 - src.database.user_groups - DEBUG - Found 1 groups where user 1062 is an admin: [25] +2025-05-21 21:38:19,606 - src.database.user_groups - DEBUG - Added 1 additional groups as admin (not already as member): [25] +2025-05-21 21:38:19,606 - src.database.user_groups - DEBUG - Total: Found 2 user groups with users for user_id: 1062, Group IDs: [25, 28] +``` + +### Database State + +Our database queries confirm that the user should be in 3 groups: + +```sql +SELECT id, name, user_ids, admin_ids +FROM user_groups +WHERE 1062 = ANY(user_ids) OR 1062 = ANY(admin_ids) +ORDER BY id; +``` + +Result: +``` + id | name | user_ids | admin_ids +----+-------------+-----------------+----------- + 22 | UNDP Studio | {1062,774,1067} | {} + 25 | GROUP 3 | {1062} | {1062} + 28 | test4 | {774,1062} | {774} +``` + +### Code Analysis + +The code in `user_groups.py` uses the correct SQL syntax to retrieve groups where a user is a member or admin: + +1. First query: `SELECT ... FROM user_groups WHERE %s = ANY(user_ids) ORDER BY created_at DESC;` +2. Second query: `SELECT ... FROM user_groups WHERE %s = ANY(admin_ids) ORDER BY created_at DESC;` + +These queries should correctly find all groups, but the first query is only returning group 28, not both 22 and 28 as expected. + +## Impact + +Users may not see all groups they belong to in the application, which could lead to: + +1. Reduced access to signals shared in "missing" groups +2. Confusion about group membership +3. Workflow disruptions if users expect to find signals in specific groups + +## Possible Causes + +1. SQL query execution issues +2. Application-level filtering not visible in the code +3. A caching or stale data issue +4. Transaction isolation level issues +5. Potential race condition if groups are being modified simultaneously + +## Fix Implemented + +We've implemented a comprehensive solution with multiple layers of improvements: + +### 1. Enhanced Primary Functions + +Modified the approach in the affected functions to use a single combined query with explicit array handling: + +```python +# Run a direct SQL query to ensure array type handling is consistent +query = """ +SELECT + id, + name, + signal_ids, + user_ids, + admin_ids, + collaborator_map, + created_at +FROM + user_groups +WHERE + %s = ANY(user_ids) OR %s = ANY(admin_ids) +ORDER BY + created_at DESC; +""" + +await cursor.execute(query, (user_id, user_id)) +``` + +We also added explicit type conversion when checking for user membership: + +```python +# Track membership rigorously by explicitly converting IDs to integers +is_member = False +if data['user_ids']: + is_member = user_id in [int(uid) for uid in data['user_ids']] + +is_admin = False +if data['admin_ids']: + is_admin = user_id in [int(aid) for aid in data['admin_ids']] +``` + +### 2. Debug Functions + +Added a `debug_user_groups` function that runs multiple direct queries to diagnose issues: + +```python +async def debug_user_groups(cursor: AsyncCursor, user_id: int) -> dict: + # Various direct SQL queries to check database state + # and array position functions + # ... +``` + +### 3. Fallback Implementation + +Created a completely separate direct SQL implementation in `user_groups_direct.py` as a fallback: + +```python +async def get_user_groups_direct(cursor: AsyncCursor, user_id: int) -> List[UserGroup]: + """ + Get all groups that a user is a member of or an admin of using direct SQL. + """ + # Simple, direct SQL with minimal processing + # ... +``` + +### 4. Automatic Fallback in API + +Modified the user groups router to automatically detect and handle discrepancies: + +```python +# Check if there's a mismatch between direct query and regular function +if missing_ids: + logger.warning(f"MISMATCH! Direct query found groups {direct_group_ids} but function returned only {fetched_ids}") + logger.warning(f"Missing groups: {missing_ids}") + + # Fall back to direct SQL implementation + logger.warning("Falling back to direct SQL implementation") + user_groups = await user_groups_direct.get_user_groups_with_users_direct(cursor, user.id) + logger.info(f"Direct SQL implementation returned {len(user_groups)} groups") +``` + +These changes ensure: +- More reliable querying of user group memberships +- Better debug information if issues persist +- Automatic fallback to a simpler implementation if needed +- More detailed logging throughout the process + +After these changes, users should see all groups where they are members or admins consistently. + +## Additional Fix: Signal Entity can_edit Attribute + +During testing, we discovered that the Signal entity was missing the `can_edit` attribute that's dynamically added in user group contexts. This caused AttributeError exceptions when trying to access `signal.can_edit`. + +### Issue +```python +AttributeError: 'Signal' object has no attribute 'can_edit' +``` + +### Solution +Added the `can_edit` field to the Signal entity definition: + +```python +can_edit: bool = Field( + default=False, + description="Whether the current user can edit this signal (set dynamically based on group membership and collaboration).", +) +``` + +This ensures that: +- The Signal model accepts the `can_edit` attribute when created +- The attribute defaults to `False` for signals that don't have edit permissions +- Both the regular and direct SQL implementations can properly set this attribute +- Frontend code can safely access `signal.can_edit` without errors + +The fix has been applied to both endpoints: +1. `/user-groups/me` (user groups without signals) +2. `/user-groups/me/with-signals` (user groups with signals) + +Both endpoints now include the same debug checks and automatic fallback to direct SQL if discrepancies are detected. \ No newline at end of file diff --git a/src/authentication.py b/src/authentication.py index 70fdcfb..89ba2ea 100644 --- a/src/authentication.py +++ b/src/authentication.py @@ -142,10 +142,10 @@ async def authenticate_user( token = test_token # Default user data for local development - local_email = "test.user@undp.org" - name = "Test User" - unit = "Data Futures Exchange (DFx)" - acclab = False + local_email = os.environ.get("TEST_USER_EMAIL", "test.user@undp.org") + name = os.environ.get("TEST_USER_NAME", "Test User") + unit = os.environ.get("TEST_USER_UNIT", "Data Futures Exchange (DFx)") + acclab = os.environ.get("TEST_USER_ACCLAB", False) # Check for specific test tokens if token == "test-admin-token": diff --git a/src/database/signals.py b/src/database/signals.py index 56d51fe..5a64ba6 100644 --- a/src/database/signals.py +++ b/src/database/signals.py @@ -92,6 +92,12 @@ async def search_signals(cursor: AsyncCursor, filters: SignalFilters) -> SignalP AND (%(score)s IS NULL OR score = %(score)s) AND (%(unit)s IS NULL OR unit_region = %(unit)s OR unit_name = %(unit)s) AND (%(query)s IS NULL OR text_search_field @@ websearch_to_tsquery('english', %(query)s)) + AND (%(user_email)s IS NOT NULL AND ( + private = FALSE OR + created_by = %(user_email)s OR + %(is_admin)s = TRUE OR + %(is_staff)s = TRUE + )) ORDER BY {filters.order_by} {filters.direction} OFFSET @@ -128,6 +134,7 @@ async def create_signal(cursor: AsyncCursor, signal: Signal, user_group_ids: Lis An ID of the signal in the database. """ logger.info(f"Creating new signal with headline: '{signal.headline}', created by: {signal.created_by}") + logger.info(f"All Signal fields: {signal.model_dump()}") if user_group_ids: logger.info(f"Will add signal to user groups: {user_group_ids}") @@ -152,7 +159,8 @@ async def create_signal(cursor: AsyncCursor, signal: Signal, user_group_ids: Lis keywords, location, secondary_location, - score + score, + private ) VALUES ( %(status)s, @@ -172,7 +180,8 @@ async def create_signal(cursor: AsyncCursor, signal: Signal, user_group_ids: Lis %(keywords)s, %(location)s, %(secondary_location)s, - %(score)s + %(score)s, + %(private)s ) RETURNING id @@ -411,7 +420,8 @@ async def update_signal(cursor: AsyncCursor, signal: Signal, user_group_ids: Lis keywords = COALESCE(%(keywords)s, keywords), location = COALESCE(%(location)s, location), secondary_location = COALESCE(%(secondary_location)s, secondary_location), - score = COALESCE(%(score)s, score) + score = COALESCE(%(score)s, score), + private = COALESCE(%(private)s, private) WHERE id = %(id)s RETURNING @@ -559,6 +569,8 @@ async def read_user_signals( cursor: AsyncCursor, user_email: str, status: Status, + is_admin: bool = False, + is_staff: bool = False, ) -> list[Signal]: """ Read signals from the database using a user email and status filter. @@ -571,6 +583,10 @@ async def read_user_signals( An email of the user whose signals to read. status : Status A status of signals to filter by. + is_admin : bool, optional + Whether the user is an admin, by default False + is_staff : bool, optional + Whether the user is staff, by default False Returns ------- @@ -596,6 +612,9 @@ async def read_user_signals( created_by = %s AND status = %s ; """ + # Since this function is explicitly for reading a user's OWN signals, + # we don't need additional private/public filtering here. + # The user is the creator, so they should see all their signals regardless of privacy setting. await cursor.execute(query, (user_email, status)) rows = await cursor.fetchall() return [Signal(**row) for row in rows] diff --git a/src/database/user_groups.py b/src/database/user_groups.py index 50cf118..a729e17 100644 --- a/src/database/user_groups.py +++ b/src/database/user_groups.py @@ -5,7 +5,7 @@ from psycopg import AsyncCursor import json import logging -from typing import List +from typing import List, Union from ..entities import UserGroup, Signal, User, UserGroupWithSignals, UserGroupWithUsers, UserGroupComplete @@ -25,6 +25,7 @@ "list_user_groups_with_users", "get_user_groups_with_signals_and_users", "get_user_groups_with_users_by_user_id", + "debug_user_groups", ] logger = logging.getLogger(__name__) @@ -37,7 +38,8 @@ signal_ids, user_ids, admin_ids, - collaborator_map + collaborator_map, + created_at FROM user_groups """ @@ -97,6 +99,7 @@ def handle_user_group_row(row) -> dict: data['user_ids'] = row["user_ids"] or [] data['admin_ids'] = row["admin_ids"] or [] collab_map = row["collaborator_map"] + data['created_at'] = row.get("created_at") else: data['id'] = row[0] data['name'] = row[1] @@ -104,6 +107,8 @@ def handle_user_group_row(row) -> dict: data['user_ids'] = row[3] or [] data['admin_ids'] = row[4] or [] collab_map = row[5] + # Created at would be at index 6 if present + data['created_at'] = row[6] if len(row) > 6 else None # Handle collaborator_map field data['collaborator_map'] = {} @@ -311,7 +316,7 @@ async def list_user_groups(cursor: AsyncCursor) -> list[UserGroup]: list[UserGroup] A list of all user groups. """ - query = f"{SQL_SELECT_USER_GROUP} ORDER BY name;" + query = f"{SQL_SELECT_USER_GROUP} ORDER BY created_at DESC;" await cursor.execute(query) result = [] @@ -451,14 +456,76 @@ async def get_user_groups(cursor: AsyncCursor, user_id: int) -> list[UserGroup]: list[UserGroup] A list of user groups. """ - query = f"{SQL_SELECT_USER_GROUP} WHERE %s = ANY(user_ids) OR %s = ANY(admin_ids) ORDER BY name;" + logger.debug("Getting user groups for user_id: %s", user_id) + + # Run a direct SQL query to ensure array type handling is consistent + logger.debug("Fetching all groups where user %s is a member or admin...", user_id) + query = """ + SELECT + id, + name, + signal_ids, + user_ids, + admin_ids, + collaborator_map, + created_at + FROM + user_groups + WHERE + %s = ANY(user_ids) OR %s = ANY(admin_ids) + ORDER BY + created_at DESC; + """ + await cursor.execute(query, (user_id, user_id)) + + # Process results + group_ids_seen = set() result = [] + member_groups = [] + admin_groups = [] + + # Debug + row_count = 0 async for row in cursor: + row_count += 1 data = handle_user_group_row(row) - result.append(UserGroup(**data)) - + group_id = data['id'] + + # Debug + logger.debug("Processing group ID: %s, Name: %s", group_id, data['name']) + logger.debug("Group user_ids: %s", data['user_ids']) + logger.debug("Group admin_ids: %s", data['admin_ids']) + + # Track membership rigorously by explicitly converting IDs to integers + is_member = False + if data['user_ids']: + is_member = user_id in [int(uid) for uid in data['user_ids']] + + is_admin = False + if data['admin_ids']: + is_admin = user_id in [int(aid) for aid in data['admin_ids']] + + if is_member: + member_groups.append(group_id) + logger.debug("User %s is a member of group %s", user_id, group_id) + if is_admin: + admin_groups.append(group_id) + logger.debug("User %s is an admin of group %s", user_id, group_id) + + # Only add each group once + if group_id not in group_ids_seen: + group_ids_seen.add(group_id) + result.append(UserGroup(**data)) + + logger.debug("Raw query returned %s rows", row_count) + logger.debug("Found %s groups where user %s is a member: %s", + len(member_groups), user_id, member_groups) + logger.debug("Found %s groups where user %s is an admin: %s", + len(admin_groups), user_id, admin_groups) + logger.debug("Total: Found %s user groups for user_id: %s, Group IDs: %s", + len(result), user_id, list(group_ids_seen)) return result @@ -490,7 +557,7 @@ async def get_group_users(cursor: AsyncCursor, group_id: int) -> list[int]: return row[0] if row and row[0] else [] -async def get_user_groups_with_signals(cursor: AsyncCursor, user_id: int, fetch_users: bool = False) -> list[UserGroupWithSignals]: +async def get_user_groups_with_signals(cursor: AsyncCursor, user_id: int, fetch_users: bool = False) -> List[Union[UserGroupWithSignals, UserGroupComplete]]: """ Get all groups that a user is a member of, along with the associated signals data. @@ -506,7 +573,7 @@ async def get_user_groups_with_signals(cursor: AsyncCursor, user_id: int, fetch_ Returns ------- - list[UserGroupWithSignals] + list[Union[UserGroupWithSignals, UserGroupComplete]] A list of user groups with associated signals and optional user details. """ logger.debug("Getting user groups with signals for user_id: %s", user_id) @@ -545,26 +612,28 @@ async def get_user_groups_with_signals(cursor: AsyncCursor, user_id: int, fetch_ signals.append(Signal(**signal_dict)) # Fetch full user details if requested - users = [] + users_list: List[User] = [] if fetch_users and group.user_ids: - users = await get_users_for_group(cursor, group.user_ids) + # Ensure user_ids are integers for get_users_for_group + typed_user_ids = [int(uid) for uid in group.user_ids if isinstance(uid, (int, str)) and str(uid).isdigit()] + users_list = await get_users_for_group(cursor, typed_user_ids) # Create a UserGroupWithSignals instance if fetch_users: # If we fetched user details, create a UserGroupComplete - group_with_signals = UserGroupComplete( + group_with_data: Union[UserGroupWithSignals, UserGroupComplete] = UserGroupComplete( **group_data, signals=signals, - users=users + users=users_list ) else: # Otherwise create a UserGroupWithSignals - group_with_signals = UserGroupWithSignals( + group_with_data = UserGroupWithSignals( **group_data, signals=signals ) - result.append(group_with_signals) + result.append(group_with_data) logger.debug("Found %s user groups with signals for user_id: %s", len(result), user_id) return result @@ -640,7 +709,10 @@ async def get_user_group_with_users(cursor: AsyncCursor, group_id: int) -> UserG # Convert to dict for modification group_data = group.model_dump() - users = await get_users_for_group(cursor, group.user_ids) + typed_user_ids = [] + if group.user_ids: + typed_user_ids = [int(uid) for uid in group.user_ids if isinstance(uid, (int, str)) and str(uid).isdigit()] + users = await get_users_for_group(cursor, typed_user_ids) # Create a UserGroupWithUsers instance return UserGroupWithUsers(**group_data, users=users) @@ -669,7 +741,10 @@ async def list_user_groups_with_users(cursor: AsyncCursor) -> list[UserGroupWith # For each group, get user details for group in groups: group_data = group.model_dump() - users = await get_users_for_group(cursor, group.user_ids) + typed_user_ids = [] + if group.user_ids: + typed_user_ids = [int(uid) for uid in group.user_ids if isinstance(uid, (int, str)) and str(uid).isdigit()] + users = await get_users_for_group(cursor, typed_user_ids) # Create a UserGroupWithUsers instance group_with_users = UserGroupWithUsers(**group_data, users=users) @@ -732,7 +807,10 @@ async def get_user_groups_with_signals_and_users(cursor: AsyncCursor, user_id: i logger.debug("Found %s signals for group_id: %s", signal_count, group.id) # Get users for this group - users = await get_users_for_group(cursor, group.user_ids) + typed_user_ids = [] + if group.user_ids: + typed_user_ids = [int(uid) for uid in group.user_ids if isinstance(uid, (int, str)) and str(uid).isdigit()] + users = await get_users_for_group(cursor, typed_user_ids) # Create a UserGroupComplete instance group_complete = UserGroupComplete(**group_data, signals=signals, users=users) @@ -762,18 +840,170 @@ async def get_user_groups_with_users_by_user_id(cursor: AsyncCursor, user_id: in """ logger.debug("Getting user groups with users for user_id: %s", user_id) - # Get groups where the user is a member or an admin - query = f"{SQL_SELECT_USER_GROUP} WHERE %s = ANY(user_ids) OR %s = ANY(admin_ids) ORDER BY name;" + # Run a direct raw SQL query to improve reliability when dealing with array types + # This directly checks for user_id in the arrays without relying on array operators + logger.debug("Fetching all groups where user %s is a member or admin...", user_id) + + query = """ + SELECT + id, + name, + signal_ids, + user_ids, + admin_ids, + collaborator_map, + created_at + FROM + user_groups + WHERE + %s = ANY(user_ids) OR %s = ANY(admin_ids) + ORDER BY + created_at DESC; + """ + await cursor.execute(query, (user_id, user_id)) + + # Process results + group_ids_seen = set() result = [] + member_groups = [] + admin_groups = [] + + # Debug + row_count = 0 async for row in cursor: + row_count += 1 group_data = handle_user_group_row(row) - users = await get_users_for_group(cursor, group_data['user_ids']) + group_id = group_data['id'] - # Create a UserGroupWithUsers instance - group_with_users = UserGroupWithUsers(**group_data, users=users) - result.append(group_with_users) + # Debug + logger.debug("Processing group ID: %s, Name: %s", group_id, group_data['name']) + logger.debug("Group user_ids: %s", group_data['user_ids']) + logger.debug("Group admin_ids: %s", group_data['admin_ids']) + + # Track membership rigorously + is_member = False + if group_data['user_ids']: + is_member = user_id in [int(uid) for uid in group_data['user_ids']] + + is_admin = False + if group_data['admin_ids']: + is_admin = user_id in [int(aid) for aid in group_data['admin_ids']] + + if is_member: + member_groups.append(group_id) + logger.debug("User %s is a member of group %s", user_id, group_id) + if is_admin: + admin_groups.append(group_id) + logger.debug("User %s is an admin of group %s", user_id, group_id) + + # Only add each group once + if group_id not in group_ids_seen: + group_ids_seen.add(group_id) + users = await get_users_for_group(cursor, group_data['user_ids']) + + # Create a UserGroupWithUsers instance + group_with_users = UserGroupWithUsers(**group_data, users=users) + result.append(group_with_users) + + logger.debug("Raw query returned %s rows", row_count) + logger.debug("Found %s groups where user %s is a member: %s", + len(member_groups), user_id, member_groups) + logger.debug("Found %s groups where user %s is an admin: %s", + len(admin_groups), user_id, admin_groups) + logger.debug("Total: Found %s user groups with users for user_id: %s, Group IDs: %s", + len(result), user_id, list(group_ids_seen)) + return result + +async def debug_user_groups(cursor: AsyncCursor, user_id: int) -> dict: + """ + Debug function to directly query all user group information for a specific user. + This bypasses all processing logic to get raw data from the database. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_id : int + The ID of the user. + + Returns + ------- + dict + A dictionary of raw debug information about the user's groups. + """ + logger.debug("=== DEBUGGING USER GROUPS for user_id: %s ===", user_id) + + # Direct query with no array handling + query1 = """ + SELECT id, name, user_ids, admin_ids + FROM user_groups + WHERE %s = ANY(user_ids) OR %s = ANY(admin_ids) + ORDER BY id; + """ + + await cursor.execute(query1, (user_id, user_id)) + rows1 = [] + async for row in cursor: + rows1.append(dict(row)) + + # Member-only query + query2 = """ + SELECT id, name, user_ids + FROM user_groups + WHERE %s = ANY(user_ids) + ORDER BY id; + """ + + await cursor.execute(query2, (user_id,)) + rows2 = [] + async for row in cursor: + rows2.append(dict(row)) + + # Admin-only query + query3 = """ + SELECT id, name, admin_ids + FROM user_groups + WHERE %s = ANY(admin_ids) + ORDER BY id; + """ + + await cursor.execute(query3, (user_id,)) + rows3 = [] + async for row in cursor: + rows3.append(dict(row)) + + # Check if PostgreSQL sees the value in the arrays using array_position + query4 = """ + SELECT + id, + name, + array_position(user_ids, %s) as user_position, + array_position(admin_ids, %s) as admin_position + FROM + user_groups + WHERE + id IN (22, 25, 28); + """ + + await cursor.execute(query4, (user_id, user_id)) + rows4 = [] + async for row in cursor: + rows4.append(dict(row)) + + result = { + "combined_query": rows1, + "member_query": rows2, + "admin_query": rows3, + "array_position_check": rows4 + } + + logger.debug("=== DEBUG RESULTS ===") + logger.debug("Combined query found %s groups: %s", len(rows1), [r['id'] for r in rows1]) + logger.debug("Member query found %s groups: %s", len(rows2), [r['id'] for r in rows2]) + logger.debug("Admin query found %s groups: %s", len(rows3), [r['id'] for r in rows3]) + logger.debug("Array position check results: %s", rows4) + logger.debug("=== END DEBUG ===") - logger.debug("Found %s user groups with users for user_id: %s", len(result), user_id) return result diff --git a/src/database/user_groups_direct.py b/src/database/user_groups_direct.py new file mode 100644 index 0000000..4a99c23 --- /dev/null +++ b/src/database/user_groups_direct.py @@ -0,0 +1,265 @@ +""" +Alternative direct SQL implementation for user group functions. + +This module provides direct SQL implementations of key user group functions +that bypass the normal processing logic to ensure reliable results. + +These functions should only be used in case of persistent issues with +the standard implementations in user_groups.py. +""" + +import logging +from typing import List +from psycopg import AsyncCursor + +from ..entities import UserGroup, User, UserGroupWithUsers + +logger = logging.getLogger(__name__) + +async def get_user_groups_direct(cursor: AsyncCursor, user_id: int) -> List[UserGroup]: + """ + Get all groups that a user is a member of or an admin of using direct SQL. + + This implementation uses the simplest possible SQL and minimal processing + to maximize reliability. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_id : int + The ID of the user. + + Returns + ------- + List[UserGroup] + A list of user groups. + """ + logger.debug("DIRECT SQL: Getting user groups for user_id: %s", user_id) + + # Direct SQL with minimal processing + query = """ + WITH user_groups_for_user AS ( + SELECT + id, + name, + signal_ids, + user_ids, + admin_ids, + collaborator_map, + created_at + FROM + user_groups + WHERE + %s = ANY(user_ids) OR %s = ANY(admin_ids) + ) + SELECT * FROM user_groups_for_user + ORDER BY created_at DESC; + """ + + await cursor.execute(query, (user_id, user_id)) + + result = [] + row_count = 0 + + async for row in cursor: + row_count += 1 + # Convert row to dictionary + data = dict(row) + # Convert empty arrays to empty lists + if data['user_ids'] is None: + data['user_ids'] = [] + if data['admin_ids'] is None: + data['admin_ids'] = [] + if data['signal_ids'] is None: + data['signal_ids'] = [] + + # Create UserGroup instance + group = UserGroup(**data) + result.append(group) + logger.debug("DIRECT SQL: Found group ID: %s, Name: %s", group.id, group.name) + + logger.debug("DIRECT SQL: Query returned %s groups", row_count) + return result + +async def get_users_by_ids(cursor: AsyncCursor, user_ids: List[int]) -> List[User]: + """ + Get user details for a list of user IDs. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_ids : List[int] + List of user IDs. + + Returns + ------- + List[User] + List of User objects. + """ + if not user_ids: + return [] + + query = """ + SELECT + id, + email, + role, + name, + unit, + acclab, + created_at + FROM + users + WHERE + id = ANY(%s) + ORDER BY + name; + """ + + await cursor.execute(query, (user_ids,)) + users = [] + + async for row in cursor: + user_data = dict(row) + users.append(User(**user_data)) + + return users + +async def get_signals_by_ids(cursor: AsyncCursor, signal_ids: List[int]) -> List[dict]: + """ + Get signal details for a list of signal IDs. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + signal_ids : List[int] + List of signal IDs. + + Returns + ------- + List[dict] + List of signal dictionaries. + """ + if not signal_ids: + return [] + + query = """ + SELECT + s.*, + array_agg(c.trend_id) FILTER (WHERE c.trend_id IS NOT NULL) AS connected_trends + FROM + signals s + LEFT JOIN + connections c ON s.id = c.signal_id + WHERE + s.id = ANY(%s) + GROUP BY + s.id + ORDER BY + s.id; + """ + + await cursor.execute(query, (signal_ids,)) + signals = [] + + async for row in cursor: + signal_data = dict(row) + signals.append(signal_data) + + return signals + +async def get_user_groups_with_users_direct(cursor: AsyncCursor, user_id: int) -> List[UserGroupWithUsers]: + """ + Get all groups that a user is a member of or an admin of, with user details, using direct SQL. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_id : int + The ID of the user. + + Returns + ------- + List[UserGroupWithUsers] + A list of user groups with user details. + """ + logger.debug("DIRECT SQL: Getting user groups with users for user_id: %s", user_id) + + # First, get the groups + groups = await get_user_groups_direct(cursor, user_id) + result = [] + + # For each group, fetch the users + for group in groups: + group_data = group.model_dump() + users = await get_users_by_ids(cursor, group.user_ids) + + # Create UserGroupWithUsers instance + group_with_users = UserGroupWithUsers(**group_data, users=users) + result.append(group_with_users) + + logger.debug("DIRECT SQL: Returning %s groups with users", len(result)) + return result + +async def get_user_groups_with_signals_direct(cursor: AsyncCursor, user_id: int) -> List: + """ + Get all groups that a user is a member of or an admin of, with signals, using direct SQL. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_id : int + The ID of the user. + + Returns + ------- + List + A list of user groups with signals and users data. + """ + from ..entities import Signal, UserGroupComplete + + logger.debug("DIRECT SQL: Getting user groups with signals for user_id: %s", user_id) + + # First, get the groups + groups = await get_user_groups_direct(cursor, user_id) + result = [] + + # For each group, fetch the signals and users + for group in groups: + group_data = group.model_dump() + signals = [] + + # Get signals for this group if it has any + if group.signal_ids: + signal_data_list = await get_signals_by_ids(cursor, group.signal_ids) + + for signal_data in signal_data_list: + # Check if user is a collaborator for this signal + can_edit = False + signal_id_str = str(signal_data["id"]) + + if group.collaborator_map and signal_id_str in group.collaborator_map: + if user_id in group.collaborator_map[signal_id_str]: + can_edit = True + + # Add can_edit attribute to signal data + signal_data["can_edit"] = can_edit + + # Create Signal instance + signal = Signal(**signal_data) + signals.append(signal) + + # Get users for this group + users = await get_users_by_ids(cursor, group.user_ids) + + # Create UserGroupComplete instance + group_complete = UserGroupComplete(**group_data, signals=signals, users=users) + result.append(group_complete) + + logger.debug("DIRECT SQL: Returning %s groups with signals", len(result)) + return result \ No newline at end of file diff --git a/src/entities/parameters.py b/src/entities/parameters.py index 50ccc32..86451d5 100644 --- a/src/entities/parameters.py +++ b/src/entities/parameters.py @@ -69,6 +69,10 @@ class SignalFilters(BaseFilters): bureau: str | None = Field(default=None) score: Score | None = Field(default=None) unit: str | None = Field(default=None) + user_email: str | None = Field(default=None) + user_id: int | None = Field(default=None) + is_admin: bool = Field(default=False) + is_staff: bool = Field(default=False) class TrendFilters(BaseFilters): diff --git a/src/entities/signal.py b/src/entities/signal.py index 7658be2..3bf161c 100644 --- a/src/entities/signal.py +++ b/src/entities/signal.py @@ -43,9 +43,13 @@ class Signal(BaseEntity): description="Whether the current user has favorited this signal.", ) is_draft: bool = Field( - default=True, + default=False, description="Whether the signal is in draft state or published.", ) + private: bool = Field( + default=False, + description="Whether the signal is private. Private signals are only visible to their creator, collaborators, and admins.", + ) group_ids: List[int] | None = Field( default=None, description="List of user group IDs associated with this signal.", @@ -54,6 +58,10 @@ class Signal(BaseEntity): default=None, description="List of user IDs who can collaborate on this signal.", ) + can_edit: bool = Field( + default=False, + description="Whether the current user can edit this signal (set dynamically based on group membership and collaboration).", + ) @model_validator(mode='before') @classmethod @@ -75,6 +83,7 @@ def convert_secondary_location(cls, data): "location": "Global", "favorite": False, "is_draft": True, + "private": False, "group_ids": [1, 2], "collaborators": [1, 2, 3], "secondary_location": ["Africa", "Asia"], @@ -107,6 +116,7 @@ class SignalWithUserGroups(Signal): "location": "Global", "favorite": False, "is_draft": True, + "private": False, "group_ids": [1, 2], "collaborators": [1, 2, 3], "secondary_location": ["Africa", "Asia"], @@ -141,6 +151,10 @@ class SignalCreate(Signal): This is used for the request body in the POST endpoint. """ + status: Optional[utils.Status] = Field( + default=utils.Status.NEW, + description="Current signal review status. Defaults to NEW if not provided.", + ) user_group_ids: Optional[List[int]] = Field( default=None, description="IDs of user groups to add the signal to after creation" @@ -157,6 +171,7 @@ class SignalCreate(Signal): "signature_secondary": ["Risk"], "keywords": ["example", "test"], "location": "Global", + "private": False, "user_group_ids": [1, 2] } } @@ -186,6 +201,7 @@ class SignalUpdate(Signal): "signature_secondary": ["Risk"], "keywords": ["updated", "test"], "location": "Global", + "private": True, "user_group_ids": [2, 3] } } diff --git a/src/routers/signals.py b/src/routers/signals.py index 5d87f4a..24ec7a6 100644 --- a/src/routers/signals.py +++ b/src/routers/signals.py @@ -30,19 +30,32 @@ async def search_signals( cursor: AsyncCursor = Depends(db.yield_cursor), ): """Search signals in the database using pagination and filters.""" + # Add user info to filters for private signal handling + filters.user_email = user.email + filters.user_id = user.id + filters.is_admin = user.is_admin + filters.is_staff = user.is_staff + page = await db.search_signals(cursor, filters) return page.sanitise(user) -@router.get("/export", response_model=None, dependencies=[Depends(require_curator)]) +@router.get("/export", response_model=None) async def export_signals( filters: Annotated[SignalFilters, Query()], + user: User = Depends(require_curator), cursor: AsyncCursor = Depends(db.yield_cursor), ): """ Export signals that match the filters from the database. You can export up to 10k rows at once. """ + # Add user info to filters for private signal handling + filters.user_email = user.email + filters.user_id = user.id + filters.is_admin = user.is_admin + filters.is_staff = user.is_staff + page = await db.search_signals(cursor, filters) # prettify the data @@ -99,8 +112,14 @@ async def create_signal( logger.info(f"Creating new signal requested by user: {user.email}") # Extract standard Signal fields and user_group_ids - signal = Signal(**signal_data.model_dump(exclude={"user_group_ids"})) - user_group_ids = signal_data.user_group_ids + signal_dict = signal_data.model_dump(exclude={"user_group_ids"}) + + # Ensure status is "New" if None was provided + if signal_dict.get("status") is None: + signal_dict["status"] = Status.NEW + + signal = Signal(**signal_dict) + user_group_ids = signal_data.user_group_ids or [] if user_group_ids: logger.info(f"With user_group_ids: {user_group_ids}") @@ -144,7 +163,7 @@ async def read_my_signals( """ Retrieve signal with a given status submitted by the current user. """ - return await db.read_user_signals(cursor, user.email, status) + return await db.read_user_signals(cursor, user.email, status, user.is_admin, user.is_staff) @router.get("/{uid}", response_model=Signal) @@ -165,6 +184,28 @@ async def read_signal( logger.info("Retrieved signal: %s", signal.model_dump()) + # Check for permission to view private signals + if signal.private and not (user.is_admin or user.is_staff or signal.created_by == user.email): + # Check if user is a collaborator + is_collaborator = False + + # Check direct collaborator + collaborators = await db.get_signal_collaborators(cursor, uid) + if user.email in collaborators: + is_collaborator = True + + # Check group collaborator + if not is_collaborator and await db.can_user_edit_signal(cursor, uid, user.id): + is_collaborator = True + + if not is_collaborator: + logger.warning( + "Permission denied - user %s trying to access private signal %s", + user.email, uid + ) + raise exceptions.permission_denied + + # Check for visitor permission with status if user.role == Role.VISITOR and signal.status != Status.APPROVED: logger.warning( "Permission denied - visitor trying to access non-approved signal. Status: %s", @@ -203,7 +244,28 @@ async def read_signal_with_user_groups( logger.warning("Signal not found with ID: %s", uid) raise exceptions.not_found - # Check permissions + # Check for permission to view private signals + if signal.private and not (user.is_admin or user.is_staff or signal.created_by == user.email): + # Check if user is a collaborator + is_collaborator = False + + # Check direct collaborator + collaborators = await db.get_signal_collaborators(cursor, uid) + if user.email in collaborators: + is_collaborator = True + + # Check group collaborator + if not is_collaborator and await db.can_user_edit_signal(cursor, uid, user.id): + is_collaborator = True + + if not is_collaborator: + logger.warning( + "Permission denied - user %s trying to access private signal %s", + user.email, uid + ) + raise exceptions.permission_denied + + # Check visitor permissions if user.role == Role.VISITOR and signal.status != Status.APPROVED: logger.warning( "Permission denied - visitor trying to access non-approved signal. Status: %s", @@ -254,7 +316,7 @@ async def update_signal( # Extract signal data and user_group_ids from the request body signal = Signal(**signal_data.model_dump(exclude={"user_group_ids"})) - user_group_ids = signal_data.user_group_ids + user_group_ids = signal_data.user_group_ids or [] if user_group_ids is not None: logger.info(f"With user_group_ids: {user_group_ids}") @@ -366,7 +428,7 @@ async def add_signal_collaborator( ) # Add collaborator - if not await db.add_collaborator(cursor, uid, user_id): + if not await db.add_collaborator(cursor, uid, str(user_id)): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid collaborator or signal", @@ -400,7 +462,7 @@ async def remove_signal_collaborator( ) # Remove collaborator - if not await db.remove_collaborator(cursor, uid, user_id): + if not await db.remove_collaborator(cursor, uid, str(user_id)): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Collaborator not found or signal does not exist", diff --git a/src/routers/user_groups.py b/src/routers/user_groups.py index 9998e31..b332910 100644 --- a/src/routers/user_groups.py +++ b/src/routers/user_groups.py @@ -16,6 +16,7 @@ from ..entities import UserGroup, User, UserGroupWithSignals, UserGroupWithUsers, UserGroupComplete, Signal from ..authentication import authenticate_user from ..database.signals import read_signal +from ..database import user_groups_direct # Set up logger for this module logger = logging.getLogger(__name__) @@ -98,12 +99,16 @@ async def list_user_groups( cursor: AsyncCursor = Depends(db.yield_cursor), ): """List all user groups.""" + logger.info(f"Endpoint called: list_user_groups - URL: {request.url} - Method: {request.method}") try: + logger.debug("Fetching all user groups from database...") groups = await db.list_user_groups(cursor) - logger.info(f"Listed {len(groups)} user groups") + logger.info(f"Successfully listed {len(groups)} user groups") + logger.debug(f"Returning group IDs: {[g.id for g in groups]}") return groups except Exception as e: logger.error(f"Error listing user groups: {str(e)}") + logger.exception("Detailed traceback for listing user groups:") bugsnag.notify( e, metadata={ @@ -126,18 +131,51 @@ async def get_my_user_groups( Get all user groups that the current user is a member of or an admin of. This endpoint is accessible to all authenticated users. """ + logger.info(f"Endpoint called: get_my_user_groups - URL: {request.url} - Method: {request.method}") + logger.debug(f"User requesting their groups: {user.id} ({user.email})") + try: if not user.id: logger.warning("User ID not found in get_my_user_groups") raise exceptions.not_found - # Get groups where this user is a member or an admin + # Run debug queries first + logger.debug(f"Running debug queries for user {user.id}...") + debug_info = await db.debug_user_groups(cursor, user.id) + + # Check if we can directly extract groups from debug info + direct_group_ids = [] + for row in debug_info["combined_query"]: + direct_group_ids.append(row["id"]) + + logger.debug(f"Direct query found group IDs: {direct_group_ids}") + + # Now get the full groups with user details + logger.debug(f"Fetching groups for user {user.id}...") user_groups = await db.get_user_groups_with_users_by_user_id(cursor, user.id) - logger.info(f"User {user.id} retrieved {len(user_groups)} groups (as member or admin)") + + # Check if there's a mismatch + fetched_ids = [g.id for g in user_groups] + missing_ids = [gid for gid in direct_group_ids if gid not in fetched_ids] + + if missing_ids: + logger.warning(f"MISMATCH! Direct query found groups {direct_group_ids} but function returned only {fetched_ids}") + logger.warning(f"Missing groups: {missing_ids}") + + # Fall back to direct SQL implementation + logger.warning("Falling back to direct SQL implementation") + user_groups = await user_groups_direct.get_user_groups_with_users_direct(cursor, user.id) + logger.info(f"Direct SQL implementation returned {len(user_groups)} groups") + + logger.info(f"User {user.id} ({user.email}) retrieved {len(user_groups)} groups (as member or admin)") + if user_groups: + logger.debug(f"Returning group IDs: {[g.id for g in user_groups]}") + logger.debug(f"Group names: {[g.name for g in user_groups]}") return user_groups except Exception as e: if not isinstance(e, HTTPException): # Don't log HTTPExceptions logger.error(f"Error in get_my_user_groups: {str(e)}") + logger.exception(f"Detailed traceback for get_my_user_groups (user {user.id}):") bugsnag.notify( e, metadata={ @@ -177,22 +215,69 @@ async def get_my_user_groups_with_signals( By default, detailed user information for each group member is included. Set the `include_users` parameter to false to get only basic user ID references without detailed user information. """ + logger.info(f"Endpoint called: get_my_user_groups_with_signals - URL: {request.url} - Method: {request.method}") + logger.debug(f"User requesting groups with signals: {user.id} ({user.email})") + logger.debug(f"Query parameters: include_users={include_users}") + try: if not user.id: logger.warning("User ID not found in get_my_user_groups_with_signals") raise exceptions.not_found + # Run debug queries first to check for discrepancies + logger.debug(f"Running debug queries for user {user.id}...") + debug_info = await db.debug_user_groups(cursor, user.id) + + # Check if we can directly extract groups from debug info + direct_group_ids = [] + for row in debug_info["combined_query"]: + direct_group_ids.append(row["id"]) + + logger.debug(f"Direct query found group IDs: {direct_group_ids}") + + logger.debug(f"Fetching groups with signals for user {user.id}...") # Get groups with signals for this user, optionally including full user details user_groups_with_signals = await db.get_user_groups_with_signals( cursor, user.id, fetch_users=include_users ) - logger.info(f"User {user.id} retrieved {len(user_groups_with_signals)} groups with signals") + + # Check if there's a mismatch and fall back to direct implementation if needed + fetched_ids = [g.id for g in user_groups_with_signals] + missing_ids = [gid for gid in direct_group_ids if gid not in fetched_ids] + + if missing_ids: + logger.warning(f"MISMATCH in signals endpoint! Direct query found groups {direct_group_ids} but function returned only {fetched_ids}") + logger.warning(f"Missing groups: {missing_ids}") + + # Fall back to direct SQL implementation + logger.warning("Falling back to direct SQL implementation for signals") + user_groups_with_signals = await user_groups_direct.get_user_groups_with_signals_direct(cursor, user.id) + logger.info(f"Direct SQL implementation returned {len(user_groups_with_signals)} groups with signals") + + logger.info(f"User {user.id} ({user.email}) retrieved {len(user_groups_with_signals)} groups with signals") + if user_groups_with_signals: + logger.debug(f"Returning group IDs: {[g.id for g in user_groups_with_signals]}") + + # Log total signals count across all groups + total_signals = sum(len(g.signals) for g in user_groups_with_signals) + logger.debug(f"Total signals across all groups: {total_signals}") + + # Log collaborator access details + editable_signals = 0 + for group in user_groups_with_signals: + for signal in group.signals: + if signal.can_edit: + editable_signals += 1 + + logger.debug(f"User can edit {editable_signals} out of {total_signals} signals") + return user_groups_with_signals except Exception as e: if not isinstance(e, HTTPException): # Don't log HTTPExceptions logger.error(f"Error in get_my_user_groups_with_signals: {str(e)}") + logger.exception(f"Detailed traceback for get_my_user_groups_with_signals (user {user.id}):") bugsnag.notify( e, metadata={ @@ -203,6 +288,9 @@ async def get_my_user_groups_with_signals( "user": { "id": user.id if user else None, "email": user.email if user else None + }, + "query_params": { + "include_users": include_users } } ) @@ -227,9 +315,15 @@ async def create_user_group( The current authenticated user is automatically added as both a member and an admin of the group. Additional admin users can be specified using the `admins` query parameter. """ + logger.info(f"Endpoint called: create_user_group - URL: {request.url} - Method: {request.method}") + logger.debug(f"Creating user group with name: '{group_data.name}'") + logger.debug(f"Query parameters: include_users={include_users}, admins={admins}") + logger.debug(f"Current user: ID={current_user.id}, email={current_user.email}") + try: # Create the base group group = UserGroup(name=group_data.name) + logger.debug(f"Created base group entity with name: '{group.name}'") # Initialize user_ids list with the current user's ID user_ids = [] @@ -238,43 +332,75 @@ async def create_user_group( if current_user.id: user_ids.append(current_user.id) admin_ids.append(current_user.id) # Make current user an admin + logger.debug(f"Added current user (ID: {current_user.id}) as member and admin") # Handle email addresses if provided + user_emails_added = [] if group_data.users: + logger.debug(f"Processing {len(group_data.users)} user emails from request body") for email in group_data.users: + logger.debug(f"Looking up user with email: {email}") user = await db.read_user_by_email(cursor, email) if user and user.id and user.id not in user_ids: # Avoid duplicates user_ids.append(user.id) + user_emails_added.append(email) + logger.debug(f"Added user {user.id} ({email}) as member") + else: + if not user: + logger.warning(f"User with email {email} not found") + elif user.id in user_ids: + logger.debug(f"User {user.id} ({email}) already added to members list") + + logger.debug(f"Added {len(user_emails_added)} users as members from request body: {user_emails_added}") # Handle admin emails if provided + admin_emails_added = [] if admins: + logger.debug(f"Processing {len(admins)} admin emails from query parameters") for email in admins: + logger.debug(f"Looking up admin with email: {email}") user = await db.read_user_by_email(cursor, email) if user and user.id: if user.id not in user_ids: # If not already in user_ids, add them user_ids.append(user.id) + logger.debug(f"Added user {user.id} ({email}) as member") + if user.id not in admin_ids: # Avoid duplicates in admin_ids admin_ids.append(user.id) + admin_emails_added.append(email) + logger.debug(f"Added user {user.id} ({email}) as admin") + else: + logger.debug(f"User {user.id} ({email}) already added to admins list") + else: + logger.warning(f"Admin with email {email} not found") + + logger.debug(f"Added {len(admin_emails_added)} users as admins from query params: {admin_emails_added}") if user_ids: group.user_ids = user_ids + logger.debug(f"Group has {len(user_ids)} members: {user_ids}") if admin_ids: group.admin_ids = admin_ids + logger.debug(f"Group has {len(admin_ids)} admins: {admin_ids}") # Create the group + logger.debug("Creating group in database...") group_id = await db.create_user_group(cursor, group) logger.info(f"Created user group {group_id} with name '{group.name}', {len(user_ids)} users, and {len(admin_ids)} admins") # Retrieve and return the created group + logger.debug(f"Retrieving created group {group_id} from database...") created_group = await db.read_user_group(cursor, group_id, fetch_details=include_users) if not created_group: logger.error(f"Failed to retrieve newly created group with ID {group_id}") raise exceptions.not_found + logger.debug(f"Successfully retrieved created group {group_id}") return created_group except Exception as e: logger.error(f"Error creating user group: {str(e)}") + logger.exception("Detailed traceback for creating user group:") bugsnag.notify( e, metadata={ @@ -285,8 +411,11 @@ async def create_user_group( "group_data": { "name": group_data.name, "users_count": len(group_data.users) if group_data.users else 0, + "users": group_data.users if group_data.users else [], "admins_count": len(admins) if admins else 0, - "current_user_id": current_user.id if current_user else None + "admins": admins if admins else [], + "current_user_id": current_user.id if current_user else None, + "current_user_email": current_user.email if current_user else None } } ) @@ -306,7 +435,12 @@ async def read_user_group( By default, includes detailed user and signal information. Set the `include_users` parameter to false to get only the basic group data without user and signal details. """ + logger.info(f"Endpoint called: read_user_group - URL: {request.url} - Method: {request.method}") + logger.debug(f"Reading user group ID: {group_id}") + logger.debug(f"Query parameters: include_users={include_users}") + try: + logger.debug(f"Fetching group {group_id} from database...") # First, get the basic group if (group := await db.read_user_group(cursor, group_id, fetch_details=include_users)) is None: logger.warning(f"User group {group_id} not found") @@ -314,9 +448,22 @@ async def read_user_group( # Log basic info to avoid huge logs logger.info(f"Retrieved user group {group_id} with name '{group.name}'") + + # Log detailed information about the group + logger.debug(f"Group details - ID: {group.id}, Name: '{group.name}'") + logger.debug(f"Group members: {len(group.user_ids) if group.user_ids else 0} users") + logger.debug(f"Group admins: {len(group.admin_ids) if group.admin_ids else 0} users") + logger.debug(f"Group signals: {len(group.signal_ids) if group.signal_ids else 0} signals") + + # Log collaborator map details + if group.collaborator_map: + logger.debug(f"Group has {len(group.collaborator_map)} signals with collaborators") + total_collaborators = sum(len(collaborators) for collaborators in group.collaborator_map.values()) + logger.debug(f"Total collaborator assignments: {total_collaborators}") # If include_users is true and the group has signals, fetch those signals too if include_users and hasattr(group, 'user_ids') and group.user_ids and hasattr(group, 'signal_ids') and group.signal_ids: + logger.debug(f"Fetching detailed signals data for group {group_id}") # Get signals for this group and prepare a complete response signals = [] @@ -324,22 +471,33 @@ async def read_user_group( from ..database.signals import read_signal # Fetch each signal individually + signal_count = 0 for signal_id in group.signal_ids: + logger.debug(f"Fetching signal {signal_id}") signal = await read_signal(cursor, signal_id) if signal: signals.append(signal) + signal_count += 1 + else: + logger.warning(f"Signal {signal_id} referenced by group {group_id} not found") + + logger.debug(f"Successfully fetched {signal_count} signals for group {group_id}") # Convert to a UserGroupComplete if we have both users and signals if hasattr(group, 'users') and group.users: + logger.debug(f"Creating UserGroupComplete with {len(signals)} signals and {len(group.users)} users") return UserGroupComplete( **group.model_dump(), signals=signals ) + else: + logger.debug("Group lacks user details, returning without creating UserGroupComplete") return group except Exception as e: if not isinstance(e, HTTPException): # Don't log HTTPExceptions logger.error(f"Error reading user group {group_id}: {str(e)}") + logger.exception(f"Detailed traceback for reading user group {group_id}:") bugsnag.notify( e, metadata={ @@ -347,7 +505,10 @@ async def read_user_group( "url": str(request.url), "method": request.method, }, - "group_id": group_id + "group_id": group_id, + "query_params": { + "include_users": include_users + } } ) raise @@ -370,38 +531,66 @@ async def update_user_group( This endpoint accepts both user IDs (integers) and email addresses (strings) in the user_ids field. Email addresses will be automatically converted to user IDs. """ + logger.info(f"Endpoint called: update_user_group - URL: {request.url} - Method: {request.method}") + logger.debug(f"Updating user group ID: {group_id}") + logger.debug(f"Update data: name='{group_data.name}', {len(group_data.user_ids)} users, {len(group_data.signal_ids)} signals") + logger.debug(f"Query parameters: include_users={include_users}") + try: + # Validate ID consistency if group_id != group_data.id: logger.warning(f"ID mismatch: path ID {group_id} != body ID {group_data.id}") raise exceptions.id_mismatch # Process user_ids field in case it contains emails instead of integer IDs processed_user_ids = [] + email_conversions = [] + + logger.debug(f"Processing {len(group_data.user_ids)} user identifiers...") for user_id in group_data.user_ids: if isinstance(user_id, str): # Check if it's an email (contains @ sign) if '@' in user_id: # This looks like an email address, try to find the user ID + logger.debug(f"Looking up user by email: {user_id}") user = await db.read_user_by_email(cursor, user_id) if user and user.id: processed_user_ids.append(user.id) + email_conversions.append((user_id, user.id)) + logger.debug(f"Converted email {user_id} to user ID {user.id}") else: logger.warning(f"User with email {user_id} not found") raise HTTPException(status_code=404, detail=f"User with email {user_id} not found") else: # String but not an email, try to convert to int if it's a digit string try: - processed_user_ids.append(int(user_id)) + numeric_id = int(user_id) + processed_user_ids.append(numeric_id) + logger.debug(f"Converted string '{user_id}' to integer {numeric_id}") except (ValueError, TypeError): logger.warning(f"Invalid user ID format: {user_id}") raise HTTPException(status_code=400, detail=f"Invalid user ID format: {user_id}") else: # Already an int processed_user_ids.append(user_id) + logger.debug(f"Using integer user ID {user_id} as is") - logger.info(f"Processed user_ids: {processed_user_ids}") + logger.debug(f"Processed {len(processed_user_ids)} user IDs: {processed_user_ids}") + if email_conversions: + logger.debug(f"Converted {len(email_conversions)} emails to user IDs: {email_conversions}") + + # Log collaborator map details if present + if group_data.collaborator_map: + logger.debug(f"Group has {len(group_data.collaborator_map)} signals with collaborators") + total_collaborators = sum(len(collaborators) for collaborators in group_data.collaborator_map.values()) + logger.debug(f"Total collaborator assignments: {total_collaborators}") + + # Log detailed collaborator info + for signal_id, collaborators in group_data.collaborator_map.items(): + logger.debug(f"Signal {signal_id} has {len(collaborators)} collaborators: {collaborators}") # Convert UserGroupUpdate to UserGroup + logger.debug("Creating UserGroup entity from update data...") group = UserGroup( id=group_data.id, name=group_data.name, @@ -422,15 +611,26 @@ async def update_user_group( signature_secondary=group_data.signature_secondary, sdgs=group_data.sdgs ) + logger.debug("UserGroup entity created successfully") + # Update the group in the database + logger.debug(f"Updating group {group_id} in database...") if (updated_id := await db.update_user_group(cursor, group)) is None: logger.warning(f"User group {group_id} not found for update") raise exceptions.not_found - logger.info(f"Updated user group {updated_id}") - return await db.read_user_group(cursor, updated_id, fetch_details=include_users) + + logger.info(f"Successfully updated user group {updated_id}") + + # Fetch the updated group + logger.debug(f"Fetching updated group {updated_id} from database...") + updated_group = await db.read_user_group(cursor, updated_id, fetch_details=include_users) + logger.debug(f"Successfully retrieved updated group {updated_id}") + + return updated_group except Exception as e: if not isinstance(e, HTTPException): # Don't log HTTPExceptions logger.error(f"Error updating user group {group_id}: {str(e)}") + logger.exception(f"Detailed traceback for updating user group {group_id}:") bugsnag.notify( e, metadata={ @@ -443,7 +643,11 @@ async def update_user_group( "id": group_data.id, "name": group_data.name, "user_count": len(group_data.user_ids) if group_data.user_ids else 0, - "signal_count": len(group_data.signal_ids) if group_data.signal_ids else 0 + "signal_count": len(group_data.signal_ids) if group_data.signal_ids else 0, + "collaborator_map_size": len(group_data.collaborator_map) if group_data.collaborator_map else 0 + }, + "query_params": { + "include_users": include_users } } ) @@ -457,15 +661,29 @@ async def delete_user_group( cursor: AsyncCursor = Depends(db.yield_cursor), ): """Delete a user group.""" + logger.info(f"Endpoint called: delete_user_group - URL: {request.url} - Method: {request.method}") + logger.debug(f"Deleting user group ID: {group_id}") + try: + # First get the group to log what's being deleted + group = await db.read_user_group(cursor, group_id) + if group: + logger.debug(f"Found group to delete: {group.id} - '{group.name}'") + logger.debug(f"Group contains: {len(group.user_ids) if group.user_ids else 0} members, " + + f"{len(group.signal_ids) if group.signal_ids else 0} signals, " + + f"{len(group.collaborator_map) if group.collaborator_map else 0} signal collaborator maps") + + logger.debug(f"Deleting group {group_id} from database...") if not await db.delete_user_group(cursor, group_id): logger.warning(f"User group {group_id} not found for deletion") raise exceptions.not_found - logger.info(f"Deleted user group {group_id}") + + logger.info(f"Successfully deleted user group {group_id}") return True except Exception as e: if not isinstance(e, HTTPException): # Don't log HTTPExceptions logger.error(f"Error deleting user group {group_id}: {str(e)}") + logger.exception(f"Detailed traceback for deleting user group {group_id}:") bugsnag.notify( e, metadata={ @@ -713,32 +931,52 @@ async def add_collaborator_to_signal_in_group( This endpoint accepts either a numeric user ID or an email address. If an email is provided, the system will look up the corresponding user ID. """ + logger.info(f"Endpoint called: add_collaborator_to_signal_in_group - URL: {request.url} - Method: {request.method}") + logger.debug(f"Adding collaborator to group {group_id} for signal {signal_id}") + logger.debug(f"User identifier provided: {user_id_or_email}") + try: # Try to parse as int for backward compatibility + user_id = None + user_email = None + try: user_id = int(user_id_or_email) + logger.debug(f"User identifier is numeric: {user_id}") except ValueError: # Not an integer, treat as email - user = await db.read_user_by_email(cursor, user_id_or_email) + user_email = user_id_or_email + logger.debug(f"User identifier is an email: {user_email}") + + user = await db.read_user_by_email(cursor, user_email) if not user or not user.id: - logger.warning(f"User with email {user_id_or_email} not found") - raise HTTPException(status_code=404, detail=f"User with email {user_id_or_email} not found") + logger.warning(f"User with email {user_email} not found") + raise HTTPException(status_code=404, detail=f"User with email {user_email} not found") + user_id = user.id + logger.debug(f"Resolved email {user_email} to user ID {user_id}") # Get the group + logger.debug(f"Fetching group {group_id}...") group = await db.read_user_group(cursor, group_id) if group is None: logger.warning(f"Group {group_id} not found") raise exceptions.not_found + logger.debug(f"Successfully retrieved group '{group.name}' (ID: {group.id})") + # Check if signal is in the group signal_ids = group.signal_ids or [] + logger.debug(f"Group has {len(signal_ids)} signals: {signal_ids}") + if signal_id not in signal_ids: logger.warning(f"Signal {signal_id} not in group {group_id}") raise exceptions.not_found # Check if user is in the group user_ids = group.user_ids or [] + logger.debug(f"Group has {len(user_ids)} members: {user_ids}") + if user_id not in user_ids: logger.warning(f"User {user_id} not in group {group_id}") raise exceptions.not_found @@ -746,25 +984,32 @@ async def add_collaborator_to_signal_in_group( # Add collaborator collaborator_map = group.collaborator_map or {} signal_key = str(signal_id) + + logger.debug(f"Current collaborator map: {collaborator_map}") + if signal_key not in collaborator_map: + logger.debug(f"Creating new collaborator entry for signal {signal_id}") collaborator_map[signal_key] = [] if user_id not in collaborator_map[signal_key]: + logger.debug(f"Adding user {user_id} as collaborator for signal {signal_id}") collaborator_map[signal_key].append(user_id) group.collaborator_map = collaborator_map + logger.debug(f"Updating group {group_id} with new collaborator map") if await db.update_user_group(cursor, group) is None: logger.error(f"Failed to update group {group_id}") raise exceptions.not_found - logger.info(f"Added user {user_id} as collaborator for signal {signal_id} in group {group_id}") + logger.info(f"Successfully added user {user_id} as collaborator for signal {signal_id} in group {group_id}") else: - logger.info(f"User {user_id} already a collaborator for signal {signal_id} in group {group_id}") + logger.info(f"User {user_id} is already a collaborator for signal {signal_id} in group {group_id}") return True except Exception as e: if not isinstance(e, HTTPException): # Don't log HTTPExceptions logger.error(f"Error adding collaborator: {str(e)}") + logger.exception(f"Detailed traceback for adding collaborator to group {group_id}, signal {signal_id}:") bugsnag.notify( e, metadata={ @@ -774,7 +1019,8 @@ async def add_collaborator_to_signal_in_group( }, "group_id": group_id, "signal_id": signal_id, - "user_id_or_email": user_id_or_email + "user_id_or_email": user_id_or_email, + "resolved_user_id": user_id if 'user_id' in locals() else None } ) raise From a9f3ab5c7b029fe34238d523d53ab942b3edd204 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Sun, 25 May 2025 23:23:57 +0300 Subject: [PATCH 19/31] enhance trends endpoint --- src/database/trends.py | 37 +++++++++++++++++++++++++++++++++++++ src/routers/trends.py | 13 +++++++++++++ 2 files changed, 50 insertions(+) diff --git a/src/database/trends.py b/src/database/trends.py index 1b914da..a784bb3 100644 --- a/src/database/trends.py +++ b/src/database/trends.py @@ -13,6 +13,7 @@ "read_trend", "update_trend", "delete_trend", + "list_trends", ] @@ -283,3 +284,39 @@ async def delete_trend(cursor: AsyncCursor, uid: int) -> Trend | None: if trend.attachment is not None: await storage.delete_image(entity_id=trend.id, folder_name="trends") return trend + + +async def list_trends(cursor: AsyncCursor) -> list[Trend]: + """ + Retrieve all trends from the database, including connected signals. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + + Returns + ------- + list[Trend] + A list of all trends in the database. + """ + query = """ + SELECT + * + FROM + trends AS t + LEFT OUTER JOIN ( + SELECT + trend_id, array_agg(signal_id) AS connected_signals + FROM + connections + GROUP BY + trend_id + ) AS c + ON + t.id = c.trend_id + ORDER BY t.id; + """ + await cursor.execute(query) + rows = await cursor.fetchall() + return [Trend(**row) for row in rows] diff --git a/src/routers/trends.py b/src/routers/trends.py index d85bda2..e894d31 100644 --- a/src/routers/trends.py +++ b/src/routers/trends.py @@ -17,6 +17,17 @@ router = APIRouter(prefix="/trends", tags=["trends"]) +@router.get("") +async def get_all_trends( + user: User = Depends(authenticate_user), + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """ + Retrieve all trends from the database. Requires authentication. + """ + trends = await db.list_trends(cursor) + return trends + @router.get("/search", response_model=TrendPage) async def search_trends( filters: Annotated[TrendFilters, Query()], @@ -65,6 +76,8 @@ async def create_trend( return await db.read_trend(cursor, trend_id) + + @router.get("/{uid}", response_model=Trend) async def read_trend( uid: Annotated[int, Path(description="The ID of the trend to retrieve")], From d17d498c833e95303df36dfa6f214eb22054c294 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Tue, 27 May 2025 12:09:39 +0300 Subject: [PATCH 20/31] Squashed commit of the following: commit 70c750ebe546b1354b8d12480f5749058aa1f634 Author: happy-devs Date: Sun May 25 23:23:06 2025 +0300 updates commit d25cb1e4096dd7ed2961dd01064ac1f00651dcb4 Author: happy-devs Date: Sun May 25 23:22:34 2025 +0300 factory updates commit 2beb502b77f5006bdbb2df3b5838ced49e843fd0 Author: happy-devs Date: Fri May 23 21:52:10 2025 +0300 update digest attempts commit 9c86f15ac8ea2fd292cc1c10ea50fd425772b2a0 Author: happy-devs Date: Wed May 14 22:23:22 2025 +0300 add email digest logic commit 9e1b80460301d69b076cf33627662a11518cbff5 Merge: 3453c0d 5203c91 Author: happy-devs Date: Wed May 14 19:27:30 2025 +0300 Merge dev branch into features/email-sending, resolving .env.example conflicts commit 3453c0dacb8489d725a59d3ae079b14b1d54fe28 Author: happy-devs Date: Thu Apr 17 11:40:15 2025 +0300 initial routers / tests commit fef971042e34a576887f0fadfc88a479d9cfdd18 Author: happy-devs Date: Thu Dec 19 16:40:00 2024 +0000 Config: set up router / methods using sendgrid --- .docs/EMAIL_SERVICE.md | 103 +++++++ .env.example | 31 +- docs/azure_graph_mail_send_setup.md | 55 ++++ docs/email_digest_delivery_methods.md | 90 ++++++ docs/email_system.md | 207 +++++++++++++ docs/sample_digest_email.html | 121 ++++++++ docs/support_ticket_automated_email.md | 46 +++ requirements.txt | 5 +- scripts/email_requirements.txt | 3 + scripts/run_direct_test.sh | 14 + scripts/send_digest.py | 133 +++++++++ scripts/send_digest_smtp.py | 91 ++++++ scripts/test_digest_fetching.py | 180 ++++++++++++ scripts/test_direct_email.py | 151 ++++++++++ scripts/test_draft_digest.py | 240 +++++++++++++++ scripts/test_email.sh | 14 + scripts/test_email_direct.py | 195 +++++++++++++ scripts/test_enterprise_email.py | 220 ++++++++++++++ scripts/test_user_auth_email.py | 147 ++++++++++ src/main.py | 11 + src/routers/__init__.py | 2 + src/routers/email.py | 114 ++++++++ src/services/draft_digest.py | 389 +++++++++++++++++++++++++ src/services/email_factory.py | 44 +++ src/services/email_service.py | 36 +++ src/services/graph_direct_auth.py | 125 ++++++++ src/services/msgraph_service.py | 236 +++++++++++++++ src/services/sendgrid_service.py | 107 +++++++ src/services/user_auth_service.py | 186 ++++++++++++ src/services/weekly_digest.py | 315 ++++++++++++++++++++ test_mail_send.py | 95 ++++++ tests/test_email.py | 137 +++++++++ 32 files changed, 3816 insertions(+), 27 deletions(-) create mode 100644 .docs/EMAIL_SERVICE.md create mode 100644 docs/azure_graph_mail_send_setup.md create mode 100644 docs/email_digest_delivery_methods.md create mode 100644 docs/email_system.md create mode 100644 docs/sample_digest_email.html create mode 100644 docs/support_ticket_automated_email.md create mode 100644 scripts/email_requirements.txt create mode 100755 scripts/run_direct_test.sh create mode 100755 scripts/send_digest.py create mode 100644 scripts/send_digest_smtp.py create mode 100644 scripts/test_digest_fetching.py create mode 100644 scripts/test_direct_email.py create mode 100644 scripts/test_draft_digest.py create mode 100755 scripts/test_email.sh create mode 100755 scripts/test_email_direct.py create mode 100755 scripts/test_enterprise_email.py create mode 100755 scripts/test_user_auth_email.py create mode 100644 src/main.py create mode 100644 src/routers/email.py create mode 100644 src/services/draft_digest.py create mode 100644 src/services/email_factory.py create mode 100644 src/services/email_service.py create mode 100644 src/services/graph_direct_auth.py create mode 100644 src/services/msgraph_service.py create mode 100644 src/services/sendgrid_service.py create mode 100644 src/services/user_auth_service.py create mode 100644 src/services/weekly_digest.py create mode 100644 test_mail_send.py create mode 100644 tests/test_email.py diff --git a/.docs/EMAIL_SERVICE.md b/.docs/EMAIL_SERVICE.md new file mode 100644 index 0000000..554674a --- /dev/null +++ b/.docs/EMAIL_SERVICE.md @@ -0,0 +1,103 @@ +# Email Service Configuration + +This document explains how to set up and use the email service with Microsoft Graph API and the Mail.Send permission. + +## Requirements + +To use the email service with Microsoft Graph API, you need: + +1. An Azure AD application registration with the following delegated permissions: + - `Mail.Send` + - `User.Read` + +2. The following environment variables: + - `MS_FROM_EMAIL`: The email address that will be used as the sender + - `EMAIL_SERVICE_TYPE`: The type of email service to use (default: `ms_graph`) + +## Configuration + +### Setting Up the Email Service + +The application uses a factory pattern to create the appropriate email service. By default, it uses the Microsoft Graph API with the `Mail.Send` permission. + +```python +# The factory creates the appropriate email service based on the environment variables +from src.services.email_factory import create_email_service + +# Create an email service instance +email_service = create_email_service() +``` + +### Environment Variables + +Configure the following environment variables: + +```bash +# Required for Microsoft Graph Email Service +MS_FROM_EMAIL=your-sender-email@example.com +EMAIL_SERVICE_TYPE=ms_graph # Options: ms_graph, sendgrid +``` + +## Usage Examples + +### Sending a Simple Email + +```python +from src.services.email_factory import create_email_service + +# Create an email service instance +email_service = create_email_service() + +# Send an email +await email_service.send_email( + to_emails=["recipient@example.com"], + subject="Test Subject", + content="This is the email content", + content_type="text/plain" # or "text/html" for HTML content +) +``` + +### Sending a Templated Notification Email + +```python +from src.services.email_factory import create_email_service + +# Create an email service instance +email_service = create_email_service() + +# Send a notification email using a template +await email_service.send_notification_email( + to_email="recipient@example.com", + subject="Notification Subject", + template_id="welcome-template", + dynamic_data={ + "name": "John Doe", + "organization": "UNDP", + "role": "Admin" + } +) +``` + +## Testing the Email Service + +You can test the email service by running the provided test script: + +```bash +# Make the script executable +chmod +x test_mail_send.py + +# Run the test script +./test_mail_send.py +``` + +The script will prompt you to enter a recipient email address and will send a test email to verify that the `Mail.Send` permission is working correctly. + +## Troubleshooting + +If you encounter issues with sending emails: + +1. Verify that the Azure AD application has the required permissions (Mail.Send and User.Read) +2. Ensure that the permissions have been admin-consented +3. Check that the MS_FROM_EMAIL environment variable is set correctly +4. Check the application logs for detailed error messages +5. Verify that the DefaultAzureCredential is properly configured \ No newline at end of file diff --git a/.env.example b/.env.example index db22666..0f3765f 100644 --- a/.env.example +++ b/.env.example @@ -1,40 +1,19 @@ -# Authentication -TENANT_ID="" -CLIENT_ID="" -API_KEY="" # for accessing "public" endpoints - -# Database and Storage -DB_CONNECTION="postgresql://:@:5432/" -SAS_URL="https://.blob.core.windows.net/?" - -# Azure OpenAI, only required for `/signals/generation` -AZURE_OPENAI_ENDPOINT="https://.openai.azure.com/" -AZURE_OPENAI_API_KEY="" - -# Testing, only required to run tests, must be a valid token of a regular user -API_JWT="" # Email Configuration -MS_FROM_EMAIL=futureofdevelopment@undp.org -EMAIL_SERVICE_TYPE=ms_graph - -# SendGrid Configuration (if using SendGrid email service) SENDGRID_API_KEY= SENDGRID_FROM_EMAIL= +MS_FROM_EMAIL=futureofdevelopment@undp.org +EMAIL_SERVICE_TYPE=ms_graph -# Azure Authentication +# Authentication TENANT_ID= CLIENT_ID= - -# API Authentication API_KEY= API_JWT= -# Database Connection +# Database and Storage DB_CONNECTION= - -# Azure Storage SAS_URL= # Azure OpenAI Configuration AZURE_OPENAI_ENDPOINT= -AZURE_OPENAI_API_KEY= +AZURE_OPENAI_API_KEY= \ No newline at end of file diff --git a/docs/azure_graph_mail_send_setup.md b/docs/azure_graph_mail_send_setup.md new file mode 100644 index 0000000..454c1f6 --- /dev/null +++ b/docs/azure_graph_mail_send_setup.md @@ -0,0 +1,55 @@ +# Enabling Automated Email Sending via Microsoft Graph + +To allow the Future of Development platform to send emails automatically (e.g., for digests, notifications) without manual authentication, you must configure Microsoft Graph **Application permissions** for your Azure AD app registration. + +## Why Application Permissions? +- **Delegated permissions** require a user to be logged in interactively—this is not suitable for scheduled/automated jobs. +- **Application permissions** allow your backend/server to send emails as a service account using only a client ID and secret. + +## Steps for Admin + +1. **Go to Azure Portal > Azure Active Directory > App registrations > [Your App]** +2. **API permissions**: + - Click **Add a permission** > **Microsoft Graph** > **Application permissions** + - Search for and add **Mail.Send** (Application) +3. **Grant admin consent**: + - Click **Grant admin consent for [Your Org]** +4. **Verify**: + - You (or your admin) can run: + ```sh + az ad app permission list --id + ``` + - You should see a `"type": "Role"` for Mail.Send. + +## Template Email/Message to Admin + +``` +Subject: Request: Grant Application Mail.Send Permission to Azure App for Automated Email Sending + +Hi [Admin], + +We need to enable automated email sending from the "Future of Development" app (Client ID: 4b179bfc-6621-409a-a1ed-ad141c12eb11) using Microsoft Graph. + +**Please:** +1. Go to Azure Portal > Azure Active Directory > App registrations > "Future of Development". +2. Under **API permissions**, click **Add a permission** > **Microsoft Graph** > **Application permissions**. +3. Add **Mail.Send** (Application). +4. Click **Grant admin consent for [Your Org]**. + +This will allow our backend to send emails on a schedule without manual login. + +Thank you! +``` + +## After Admin Consent +- You can now use the client ID, tenant ID, and client secret to send emails via Microsoft Graph API using `/users/{user_id}/sendMail`. +- No manual login will be required for scheduled jobs. + +--- + +**If you need to check the current permissions or verify setup, use:** +```sh +az ad app permission list --id +``` + +--- \ No newline at end of file diff --git a/docs/email_digest_delivery_methods.md b/docs/email_digest_delivery_methods.md new file mode 100644 index 0000000..17f1255 --- /dev/null +++ b/docs/email_digest_delivery_methods.md @@ -0,0 +1,90 @@ +# Email Digest Delivery Methods: Summary & Lessons Learned + +This document summarizes all the methods we have tried (and considered) for sending automated email digests from the Future Trends & Signals platform, including their outcomes, blockers, and references to official documentation. + +--- + +## 1. Microsoft Graph API (Recommended, but Blocked) + +- **Approach:** Use Microsoft Graph API with Application permissions to send as `futureofdevelopment@undp.org`. +- **Status:** **Blocked** (admin consent for Application permissions not yet granted). +- **What we did:** + - Registered the app in Azure AD. + - Attempted to use `/users/{user_id}/sendMail` endpoint with client credentials. + - Only Delegated permissions are currently granted; Application permissions are missing. +- **Blocker:** + - Cannot send as a service account without `Mail.Send` Application permission and admin consent. +- **Reference:** + - See [azure_graph_mail_send_setup.md](./azure_graph_mail_send_setup.md) for detailed setup and admin request template. + - [Microsoft Docs: Send mail as any user](https://learn.microsoft.com/en-us/graph/api/user-sendmail?view=graph-rest-1.0&tabs=http) + +--- + +## 2. Microsoft Graph API (Delegated Permissions) + +- **Approach:** Use Microsoft Graph API with Delegated permissions, logging in as the sender. +- **Status:** **Not suitable for automation** +- **What we did:** + - Successfully authenticated as a user and sent test emails using `/me/sendMail`. +- **Blocker:** + - Requires interactive login; not suitable for scheduled/automated jobs. + +--- + +## 3. SMTP (Office 365/Exchange Online) + +- **Approach:** Use SMTP to send as `futureofdevelopment@undp.org` via `smtp.office365.com`. +- **Status:** **Blocked** (SMTP AUTH is disabled for the tenant). +- **What we did:** + - Created a script (`send_digest_smtp.py`) to send the digest via SMTP. + - Attempted to authenticate with valid credentials. + - Received error: `SMTPAuthenticationError: 5.7.139 Authentication unsuccessful, SmtpClientAuthentication is disabled for the Tenant.` +- **Blocker:** + - SMTP AUTH is disabled for all users by default in modern Microsoft 365 tenants for security reasons. + - Would require IT to enable SMTP AUTH for the sending account. +- **Reference:** + - [Enable or disable SMTP AUTH in Exchange Online](https://aka.ms/smtp_auth_disabled) + +--- + +## 4. SendGrid or Third-Party SMTP Relay + +- **Approach:** Use a third-party SMTP service (e.g., SendGrid) to send as the service account. +- **Status:** **Not attempted** (would require IT approval and setup). +- **Blocker:** + - May not be allowed by organizational policy. + +--- + +## 5. Distribution List/Group Delivery + +- **Approach:** Send the digest to a mail-enabled group (`futures.curator@undp.org`). +- **Status:** **Group is mail-enabled and can receive mail** +- **What we did:** + - Verified the group exists and is mail-enabled in Azure AD. + - All sending methods above (if working) can target this group. +- **Blocker:** + - Blocked by the same issues as above (Graph permissions or SMTP AUTH). + +--- + +## **Summary Table** + +| Method | Automation | Current Status | Blocker/Notes | +|-----------------------|------------|-----------------------|--------------------------------------| +| MS Graph (App perms) | Yes | Blocked | Need admin to grant permissions | +| MS Graph (Delegated) | No | Works (manual only) | Not suitable for automation | +| SMTP (O365) | Yes | Blocked | SMTP AUTH disabled for tenant | +| SendGrid/3rd-party | Yes | Not attempted | Needs IT approval | +| Distribution List | Yes | Ready | Blocked by above sending method | + +--- + +## **Next Steps** +- Await admin action to grant Application permissions for Microsoft Graph (see [azure_graph_mail_send_setup.md](./azure_graph_mail_send_setup.md)). +- Alternatively, request IT to enable SMTP AUTH for the sending account (less secure, not recommended). +- Consider third-party relay if allowed by policy. + +--- + +**This document should be updated as our setup or permissions change.** \ No newline at end of file diff --git a/docs/email_system.md b/docs/email_system.md new file mode 100644 index 0000000..14e727f --- /dev/null +++ b/docs/email_system.md @@ -0,0 +1,207 @@ +# UNDP Futures Trends & Signals Platform - Email System + +## Overview + +The UNDP Future Trends & Signals platform includes functionality to send weekly digest emails containing summaries of recently published signals. This email system keeps curators and other stakeholders informed about new content without requiring them to regularly visit the platform. + +## Components + +The email system consists of the following components: + +1. **Email Service Architecture** + - `EmailServiceBase`: Abstract base class defining the interface for all email services + - `MSGraphEmailService`: Implementation using Microsoft Graph API with enterprise application authentication + - `UserAuthEmailService`: Implementation using Azure CLI authentication + - `SendGridEmailService`: Alternative implementation using SendGrid + - `EmailFactory`: Factory pattern for creating the appropriate service based on configuration + +2. **Weekly Digest Feature** + - `WeeklyDigestService`: Core service that fetches recent signals and generates digest emails + - HTML email template with responsive design for signal summaries + - Filtering for approved/published signals within a specified date range + +3. **Testing Tools** + - `send_digest.py`: CLI script for sending weekly digests with parameterized options + - `test_email_direct.py`: Script for testing email configuration without database dependencies + +## Setup and Configuration + +### Requirements + +1. Install the required Python packages: + +```bash +# Activate your virtual environment +source venv/bin/activate + +# Install the required packages +pip install python-dotenv msgraph-core azure-identity httpx sendgrid +``` + +### Environment Variables + +The following environment variables need to be set in your `.env.local` file: + +``` +# Email Configuration +MS_FROM_EMAIL=exo.futures.curators@undp.org # Email that will appear as the sender +EMAIL_SERVICE_TYPE=ms_graph # Authentication type (ms_graph, user_auth, or sendgrid) + +# Azure Authentication for UNDP Enterprise Application +TENANT_ID=b3e5db5e-2944-4837-99f5-7488ace54319 # UNDP tenant ID +CLIENT_ID=4b179bfc-6621-409a-a1ed-ad141c12eb11 # UNDP Future Trends and Signals System App ID +CLIENT_SECRET=YOUR_CLIENT_SECRET_HERE # Generate this in Azure Portal +``` + +### Authentication Methods + +The platform supports multiple authentication methods for sending emails: + +#### 1. Enterprise Application Authentication (Recommended for Production) + +This method uses an Azure AD enterprise application with client credentials flow to authenticate and send emails on behalf of a mailbox. + +Requirements: +- UNDP Enterprise Application "UNDP Future Trends and Signals System" +- App ID: `4b179bfc-6621-409a-a1ed-ad141c12eb11` +- Tenant ID: `b3e5db5e-2944-4837-99f5-7488ace54319` (UNDP tenant) +- Client Secret (generated in Azure Portal) +- Mail.Send API permissions granted to the application + +This is the recommended approach for production as it doesn't require user presence and provides a more secure, managed identity for the application. + +#### 2. User Authentication (For Development) + +This method uses the Azure CLI authentication that's already set up on your machine. This is easier for development and testing as it doesn't require setting up app registrations or API credentials. + +Requirements: +- Azure CLI installed and logged in with `az login` +- User must have Mail.Send permissions in Microsoft Graph + +#### 3. SendGrid Authentication + +Alternative email provider if Microsoft Graph is not available. + +Requirements: +- SendGrid API key +- SendGrid from email address + +### Azure AD Enterprise Application Configuration + +To configure the enterprise application for sending emails: + +1. Sign in to the [Azure Portal](https://portal.azure.com) +2. Navigate to "Azure Active Directory" > "App registrations" +3. Search for "UNDP Future Trends and Signals System" (App ID: `4b179bfc-6621-409a-a1ed-ad141c12eb11`) +4. Under "Certificates & secrets", create a new client secret: + - Click "New client secret" + - Provide a description (e.g., "Email Sending Service") + - Set an appropriate expiration (e.g., 1 year, 2 years) + - Copy the generated secret value (only shown once) +5. Under "API permissions", verify the following permissions: + - Microsoft Graph > Application permissions > Mail.Send + - Microsoft Graph > Application permissions > User.Read.All (for accessing user profiles) +6. Ensure admin consent has been granted for these permissions +7. Update your `.env.local` file with the client secret + +## Using the Weekly Digest Feature + +### Manual Testing + +To send a test digest email: + +```bash +# Test with enterprise application authentication +python scripts/test_email_direct.py recipient@example.com + +# Test weekly digest +python scripts/send_digest.py --recipients recipient@example.com --days 7 --test +``` + +Parameters: +- `--recipients`: One or more email addresses (space-separated) +- `--days`: Number of days to look back for signals (default: 7) +- `--test`: Adds [TEST] to the email subject + +### Production Scheduling + +For regular weekly emails, set up a cron job or Azure scheduled task: + +```bash +# Example cron job (every Monday at 8am) +0 8 * * 1 /path/to/python /path/to/scripts/send_digest.py --recipients email1@undp.org email2@undp.org +``` + +## Customization + +### Email Templates + +The HTML email template is embedded in the `generate_email_html` method of the `WeeklyDigestService` class. To customize: + +1. Modify the HTML structure in the method +2. Update CSS styles to match UNDP branding guidelines +3. Adjust the content formatting as needed + +### Recipients Management + +Currently, recipients are specified manually when calling the script. Future enhancements could include: + +- Storing recipient lists in the database +- Building a subscription management UI +- Supporting user-specific preferences for digest contents + +## Troubleshooting + +### Permission Issues + +If you encounter "Access Denied" errors when sending emails: + +1. Check that the enterprise application has the necessary Mail.Send permissions +2. Ensure the permissions have been granted admin consent +3. Verify that the sender email matches an email address the application has permission to send from + +### Common Issues and Solutions + +1. **401 Unauthorized Error** + - Check that client secret is valid and not expired + - Ensure TENANT_ID and CLIENT_ID are correct + +2. **403 Forbidden Error** + - Check that the enterprise application has been granted proper permissions + - Ensure permissions have been admin consented + - Verify that the sender email has proper mailbox permissions + +3. **Connection Issues** + - Check network connectivity + - Ensure firewall rules allow outbound connections to graph.microsoft.com + +4. **Email Delivery Problems** + - Verify that the sender email address is configured correctly + - Check if the email address has sending limits or restrictions + +For detailed error logging, set `LOGLEVEL=DEBUG` in your environment variables. + +## Planned Enhancements + +### Near-term + +1. Set up scheduled task for automated weekly emails +2. Configure environment variables in production environment +3. Implement more sophisticated email templates + +### Future Enhancements + +1. **Recipient Management** + - Database table for storing subscriber information + - API endpoints for subscribing/unsubscribing + - User preferences for digest frequency and content + +2. **Email Customization** + - Different email templates for different types of notifications + - Personalized content based on user interests or roles + - Multiple language support + +3. **Analytics** + - Tracking email opens and clicks + - Reporting on engagement metrics + - A/B testing of email content and formats \ No newline at end of file diff --git a/docs/sample_digest_email.html b/docs/sample_digest_email.html new file mode 100644 index 0000000..87e480a --- /dev/null +++ b/docs/sample_digest_email.html @@ -0,0 +1,121 @@ + + + + + + UNDP Futures - Weekly Signal Digest + + + +
+

UNDP Futures - Weekly Signal Digest

+

Stay updated with the latest signals from around the world

+
+ +

Hello,

+

Here's your weekly digest of new signals from the UNDP Futures platform. Below are the latest signals that might be of interest:

+ +
+
+

Climate-Resilient Agriculture Technology in East Africa

+
+ Location: Africa + • Source: View Source +
+

New irrigation and seed technologies are enabling farmers in East Africa to adapt to changing rainfall patterns, with early pilots showing up to 40% increase in crop yields during drought conditions.

+
+ agriculture + climate + technology +
+
+ +
+

Digital Identity Systems and Financial Inclusion

+
+ Location: Global + • Source: View Source +
+

Digital identity systems are creating pathways to financial services for previously unbanked populations, with innovative biometric solutions addressing challenges in regions with limited documentation.

+
+ digital + finance + inclusion +
+
+ +
+

Community-Led Waste Management Innovations

+
+ Location: Asia + • Source: View Source +
+

Local communities in Southeast Asia are developing scalable waste management systems that combine traditional knowledge with new recycling technologies, reducing plastic pollution and creating economic opportunities.

+
+ environment + community + innovation +
+
+
+ + + + \ No newline at end of file diff --git a/docs/support_ticket_automated_email.md b/docs/support_ticket_automated_email.md new file mode 100644 index 0000000..efd0988 --- /dev/null +++ b/docs/support_ticket_automated_email.md @@ -0,0 +1,46 @@ +# Support Ticket: Enable Automated Email Sending for Future of Development Platform + +**Subject:** +Enable Automated Email Sending for Future Trends & Signals (Microsoft Graph Application Permissions) + +**Description:** +We are building a feature for the Future of Development platform that sends automated email digests (e.g., weekly summaries, notifications) to users. To do this securely and reliably, we need to configure Microsoft Graph **Application permissions** for our Azure AD app registration, and set up a dedicated internal email account for sending these digests. + +## Requirements + +1. **Azure AD App Registration:** + - App Name: **Future Trends & Signals** + - Client ID: `4b179bfc-6621-409a-a1ed-ad141c12eb11` + - The app must be able to send emails automatically (without manual login) using Microsoft Graph. + +2. **Permissions Needed:** + - Add **Mail.Send** (Application) permission to the app registration. + - Grant **admin consent** for this permission. + +3. **Service Account:** + - Please create or confirm an internal mailbox (e.g., `futureofdevelopment@undp.org`) to be used as the sender for these digests. + - Ensure this mailbox is licensed and can send emails. + +4. **Configuration Steps (for ITU):** + - Go to Azure Portal > Azure Active Directory > App registrations > "Future of Development". + - Under **API permissions**, click **Add a permission** > **Microsoft Graph** > **Application permissions**. + - Add **Mail.Send** (Application). + - Click **Grant admin consent for [Your Org]**. + - Confirm that the mailbox `futureofdevelopment@undp.org` is active and can be used by the app for sending emails. + +5. **Verification:** + - After configuration, we will verify by running: + ```sh + az ad app permission list --id 4b179bfc-6621-409a-a1ed-ad141c12eb11 + ``` + - We should see a `"type": "Role"` for Mail.Send. + +## Why This Is Needed +- Delegated permissions require a user to log in interactively, which is not suitable for scheduled/automated jobs. +- Application permissions allow our backend to send emails on a schedule, securely and without manual intervention. + +## What We Need from ITU +- Add and grant the required permissions as described above. +- Confirm the service account is ready and provide any additional configuration details if needed. + +Thank you for your support! If you need more technical details, please see the attached documentation or contact our team. \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f3145fe..0b8c731 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,7 @@ pillow ~= 11.0.0 beautifulsoup4 ~= 4.12.3 lxml ~= 5.3.0 openai == 1.52.2 -bugsnag>=4.0.0 +azure-identity ~= 1.15.0 +msgraph-core ~= 0.2.2 +pytest-asyncio ~= 0.23.5 +bugsnag>=4.0.0 \ No newline at end of file diff --git a/scripts/email_requirements.txt b/scripts/email_requirements.txt new file mode 100644 index 0000000..c5f6967 --- /dev/null +++ b/scripts/email_requirements.txt @@ -0,0 +1,3 @@ +azure-identity>=1.13.0 +msgraph-core>=0.2.2 +sendgrid>=6.10.0 \ No newline at end of file diff --git a/scripts/run_direct_test.sh b/scripts/run_direct_test.sh new file mode 100755 index 0000000..6af9aef --- /dev/null +++ b/scripts/run_direct_test.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Script to run the direct email test in the correct virtual environment + +# Change to the project directory +cd "$(dirname "$0")/.." + +# Activate the virtual environment +source venv/bin/activate + +# Run the direct test script +python scripts/test_email_direct.py andrew.maguire@undp.org + +# Deactivate the virtual environment +deactivate \ No newline at end of file diff --git a/scripts/send_digest.py b/scripts/send_digest.py new file mode 100755 index 0000000..e2d41d3 --- /dev/null +++ b/scripts/send_digest.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python +""" +Command-line script to send a weekly digest email. +This script is for manual testing and can be scheduled via cron or other job scheduler. +""" + +import os +import sys +import asyncio +import argparse +import logging +from typing import List + +# Add the parent directory to sys.path to allow importing the app modules +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.insert(0, parent_dir) + +from src.services.weekly_digest import WeeklyDigestService +from src.services.weekly_digest import Status + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +async def send_weekly_digest(recipients: List[str], days: int = None, test_mode: bool = False, status: list = None, limit: int = None) -> None: + """ + Send a weekly digest email to the specified recipients. + + Parameters + ---------- + recipients : List[str] + List of email addresses to send the digest to. + days : int, optional + Number of days to look back for signals, defaults to None. + test_mode : bool, optional + If True, adds [TEST] to the subject line. + status : list, optional + List of signal statuses to filter by, defaults to None. + limit : int, optional + Maximum number of signals to include, defaults to None. + """ + logger.info(f"Starting weekly digest email send to {recipients}") + + # Create the digest service + digest_service = WeeklyDigestService() + + # Prepare subject with test mode indicator if needed + subject = "UNDP Futures Weekly Digest" + if test_mode: + subject = f"[TEST] {subject}" + + # Map status strings to Status enum if provided + status_enum = None + if status: + status_enum = [Status(s) for s in status] + + # Generate and send the digest + success = await digest_service.generate_and_send_digest( + recipients=recipients, + days=days, + subject=subject, + status=status_enum, + limit=limit + ) + + if success: + logger.info("Weekly digest email sent successfully") + else: + logger.error("Failed to send weekly digest email") + +def main() -> None: + """Parse command line arguments and run the digest email process.""" + parser = argparse.ArgumentParser(description="Send weekly digest email of recent signals") + + parser.add_argument( + "--recipients", + nargs="+", + required=True, + help="Email addresses to send the digest to (space-separated)" + ) + + parser.add_argument( + "--days", + type=int, + default=None, + help="Number of days to look back for signals (optional)" + ) + + parser.add_argument( + "--test", + action="store_true", + help="Run in test mode (adds [TEST] to the subject line)" + ) + + parser.add_argument( + "--status", + nargs="+", + default=None, + help="Signal statuses to filter by (e.g. Draft Approved). Optional." + ) + + parser.add_argument( + "--limit", + type=int, + default=None, + help="Maximum number of signals to include (optional)" + ) + + args = parser.parse_args() + + # Validate email addresses (basic check) + for email in args.recipients: + if "@" not in email: + logger.error(f"Invalid email address: {email}") + sys.exit(1) + + # Run the async function + asyncio.run(send_weekly_digest( + args.recipients, + args.days, + args.test, + args.status, + args.limit + )) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/send_digest_smtp.py b/scripts/send_digest_smtp.py new file mode 100644 index 0000000..20be7af --- /dev/null +++ b/scripts/send_digest_smtp.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +""" +Script to send a weekly digest email using SMTP (e.g., Office 365, Gmail). +This is for testing SMTP-based delivery to a distribution list or group. +""" + +import os +import sys +import asyncio +import argparse +import logging +import smtplib +from email.mime.text import MIMEText +from typing import List + +# Add the parent directory to sys.path to allow importing the app modules +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.insert(0, parent_dir) + +from src.services.weekly_digest import WeeklyDigestService + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[logging.StreamHandler()] +) +logger = logging.getLogger(__name__) + +async def generate_digest_html(days=None, status=None, limit=None): + digest_service = WeeklyDigestService() + signals_list = await digest_service.get_recent_signals(days=days, status=status, limit=limit) + logger.info(f"Fetched {len(signals_list)} signals for digest.") + html_content = digest_service.generate_email_html(signals_list) + return html_content + +def send_email_smtp(smtp_server, smtp_port, username, password, to_emails, subject, html_content): + msg = MIMEText(html_content, 'html') + msg['Subject'] = subject + msg['From'] = username + msg['To'] = ', '.join(to_emails) + with smtplib.SMTP(smtp_server, smtp_port) as server: + server.starttls() + server.login(username, password) + server.sendmail(msg['From'], to_emails, msg.as_string()) + logger.info(f"Email sent via SMTP to {to_emails}") + +def main(): + parser = argparse.ArgumentParser(description="Send weekly digest email via SMTP") + parser.add_argument('--recipients', nargs='+', required=True, help="Email addresses to send the digest to (space-separated)") + parser.add_argument('--days', type=int, default=None, help="Number of days to look back for signals (optional)") + parser.add_argument('--status', nargs='+', default=None, help="Signal statuses to filter by (e.g. Draft Approved). Optional.") + parser.add_argument('--limit', type=int, default=None, help="Maximum number of signals to include (optional)") + parser.add_argument('--smtp-server', type=str, default='smtp.office365.com', help="SMTP server address") + parser.add_argument('--smtp-port', type=int, default=587, help="SMTP server port") + parser.add_argument('--smtp-user', type=str, required=True, help="SMTP username (your email)") + parser.add_argument('--smtp-password', type=str, required=True, help="SMTP password (or app password)") + parser.add_argument('--test', action='store_true', help="Run in test mode (adds [TEST] to the subject line)") + args = parser.parse_args() + + subject = "UNDP Futures Weekly Digest" + if args.test: + subject = f"[TEST] {subject}" + + # Validate email addresses + for email in args.recipients: + if "@" not in email: + logger.error(f"Invalid email address: {email}") + sys.exit(1) + + # Map status strings to Status enum if provided + status_enum = None + if args.status: + from src.services.weekly_digest import Status + status_enum = [Status(s) for s in args.status] + + # Generate digest HTML + html_content = asyncio.run(generate_digest_html(days=args.days, status=status_enum, limit=args.limit)) + + # Send email via SMTP + send_email_smtp( + smtp_server=args.smtp_server, + smtp_port=args.smtp_port, + username=args.smtp_user, + password=args.smtp_password, + to_emails=args.recipients, + subject=subject, + html_content=html_content + ) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/test_digest_fetching.py b/scripts/test_digest_fetching.py new file mode 100644 index 0000000..cfbcd56 --- /dev/null +++ b/scripts/test_digest_fetching.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python +""" +Script to test the weekly digest signal fetching and HTML generation without sending emails. +This helps verify that the core digest functionality is working properly. +""" + +import os +import sys +import asyncio +import argparse +import logging +import json +from typing import List +from datetime import datetime + +# Add the parent directory to sys.path to allow importing the app modules +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.insert(0, parent_dir) + +try: + from dotenv import load_dotenv +except ImportError: + # Define a simple fallback if python-dotenv is not installed + def load_dotenv(path): + print(f"Warning: python-dotenv package not installed, loading environment manually") + if not os.path.exists(path): + return False + with open(path) as f: + for line in f: + line = line.strip() + if not line or line.startswith('#') or '=' not in line: + continue + key, value = line.split('=', 1) + os.environ[key.strip()] = value.strip().strip('"').strip("'") + return True + +# Load environment variables from .env +env_file = os.path.join(parent_dir, '.env') +if os.path.exists(env_file): + load_dotenv(env_file) + print(f"Loaded environment from {env_file}") +else: + print(f"Warning: {env_file} not found") + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +# Import the weekly digest service +from src.services.weekly_digest import WeeklyDigestService + +async def test_digest_fetching(days: int = 7, save_to_file: bool = True, output_path: str = None) -> None: + """ + Test fetching recent signals and generating a digest without sending an email. + + Parameters + ---------- + days : int, optional + Number of days to look back for signals, defaults to 7. + save_to_file : bool, optional + Whether to save the generated HTML to a file, defaults to True. + output_path : str, optional + Path to save the output HTML, defaults to 'digest_output.html' in the current directory. + """ + print("\n=====================================================") + print(f"🔍 TESTING WEEKLY DIGEST SIGNAL FETCHING") + print(f"Looking back {days} days for signals...") + print("=====================================================\n") + + try: + # Create the digest service + digest_service = WeeklyDigestService() + + # Get recent signals + signals_list = await digest_service.get_recent_signals(days) + + # Print signal count and basic info + if signals_list: + print(f"\n✅ Successfully retrieved {len(signals_list)} signals from the last {days} days.") + print("\nSignals Summary:") + print("-" * 60) + + for i, signal in enumerate(signals_list, 1): + print(f"{i}. {signal.headline}") + print(f" Created: {signal.created_at}") + print(f" Location: {signal.location or 'Global'}") + if hasattr(signal, 'keywords') and signal.keywords: + print(f" Keywords: {', '.join(signal.keywords)}") + print(f" Status: {signal.status}") + print("-" * 60) + + # Generate HTML content + print("\nGenerating HTML digest content...") + html_content = digest_service.generate_email_html(signals_list) + + # Save HTML to file if requested + if save_to_file: + if output_path is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = os.path.join(os.path.dirname(__file__), f"digest_output_{timestamp}.html") + + with open(output_path, "w", encoding="utf-8") as f: + f.write(html_content) + + print(f"\n✅ HTML digest content saved to: {output_path}") + print(" You can open this file in a browser to preview the digest email.") + + # Save signals to JSON for debugging + json_path = os.path.join(os.path.dirname(__file__), "signals_data.json") + + signals_data = [] + for signal in signals_list: + # Convert to dict and handle datetime objects for JSON serialization + signal_dict = signal.model_dump() + + # Convert datetime objects to strings + for key, value in signal_dict.items(): + if isinstance(value, datetime): + signal_dict[key] = value.isoformat() + + signals_data.append(signal_dict) + + with open(json_path, "w", encoding="utf-8") as f: + json.dump(signals_data, f, indent=2) + + print(f"📊 Signal data saved to: {json_path}") + + else: + print(f"\n⚠️ No signals found in the last {days} days.") + print(" This could be because:") + print(" - There are no approved signals in the database") + print(" - The signals were created before the specified time period") + print(" - There might be an issue with the database connection") + + except Exception as e: + import traceback + logger.error(f"Error while testing digest fetching: {str(e)}") + traceback.print_exc() + print(f"\n❌ Error testing digest functionality: {str(e)}") + +def main() -> None: + """Parse command line arguments and run the digest test.""" + parser = argparse.ArgumentParser(description="Test weekly digest signal fetching and HTML generation") + + parser.add_argument( + "--days", + type=int, + default=7, + help="Number of days to look back for signals (default: 7)" + ) + + parser.add_argument( + "--no-save", + action="store_true", + help="Don't save the generated HTML to a file" + ) + + parser.add_argument( + "--output", + type=str, + help="Path to save the output HTML (default: digest_output_TIMESTAMP.html in current directory)" + ) + + args = parser.parse_args() + + # Run the async function + asyncio.run(test_digest_fetching( + days=args.days, + save_to_file=not args.no_save, + output_path=args.output + )) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/test_direct_email.py b/scripts/test_direct_email.py new file mode 100644 index 0000000..a8578a5 --- /dev/null +++ b/scripts/test_direct_email.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python +""" +Direct test script for sending emails using Graph API. +This bypasses the normal email service for testing purposes. +""" + +import os +import sys +import asyncio +import logging +from datetime import datetime +from dotenv import load_dotenv + +# Add the parent directory to sys.path +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.insert(0, parent_dir) + +# Load environment variables +env_file = os.path.join(parent_dir, '.env.local') +if os.path.exists(env_file): + load_dotenv(env_file) + print(f"Loaded environment from {env_file}") +else: + print(f"Warning: {env_file} not found") + +# Import our direct authentication module +from src.services.graph_direct_auth import GraphDirectAuth + +# Set up logging +logging.basicConfig( + level=logging.DEBUG, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +async def test_direct_email(to_email: str) -> None: + """Send a test email using direct Graph API authentication""" + try: + print(f"\nSending test email to {to_email}...") + + # Get sender email from environment + from_email = os.getenv('MS_FROM_EMAIL', 'exo.futures.curators@undp.org') + + # Create the GraphDirectAuth client + graph_auth = GraphDirectAuth() + + # Create HTML content with current timestamp + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + html_content = f""" + + + + + + UNDP Futures - Direct Test Email + + + +
+

UNDP Futures - Direct Test Email

+
+ +
+

Direct Graph API Email Test

+

This is a test email sent using direct Graph API authentication.

+

If you're receiving this, it means the email configuration is working!

+

Sent at: {timestamp}

+

Configuration:

+
    +
  • From Email: {from_email}
  • +
  • To Email: {to_email}
  • +
  • Tenant ID: {os.getenv('AZURE_TENANT_ID')}
  • +
+
+ + + + + """ + + # Send the email + success = await graph_auth.send_email( + from_email=from_email, + to_emails=[to_email], + subject=f"[TEST] UNDP Futures - Direct Email Test ({timestamp})", + content=html_content, + content_type="HTML" + ) + + if success: + print("\n=====================================================") + print(f"✅ Test email successfully sent to {to_email}!") + print("=====================================================\n") + else: + print("\n=====================================================") + print(f"❌ Failed to send test email to {to_email}") + print("=====================================================\n") + + except Exception as e: + logger.error(f"Error in test_direct_email: {str(e)}", exc_info=True) + print("\n=====================================================") + print(f"❌ Error sending test email: {str(e)}") + print("=====================================================\n") + +def main(): + """Main entry point""" + if len(sys.argv) < 2: + print("Usage: python test_direct_email.py ") + sys.exit(1) + + recipient_email = sys.argv[1] + asyncio.run(test_direct_email(recipient_email)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/test_draft_digest.py b/scripts/test_draft_digest.py new file mode 100644 index 0000000..c98bd35 --- /dev/null +++ b/scripts/test_draft_digest.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python +""" +Script to test fetching draft signals and generating a digest. +""" + +import os +import sys +import asyncio +import argparse +import logging +import json +from typing import List +from datetime import datetime +import time + +# Add the parent directory to sys.path to allow importing the app modules +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.insert(0, parent_dir) + +try: + from dotenv import load_dotenv +except ImportError: + # Define a simple fallback if python-dotenv is not installed + def load_dotenv(path): + print(f"Warning: python-dotenv package not installed, loading environment manually") + if not os.path.exists(path): + return False + with open(path) as f: + for line in f: + line = line.strip() + if not line or line.startswith('#') or '=' not in line: + continue + key, value = line.split('=', 1) + os.environ[key.strip()] = value.strip().strip('"').strip("'") + return True + +# Load environment variables from .env +env_file = os.path.join(parent_dir, '.env') +if os.path.exists(env_file): + load_dotenv(env_file) + print(f"Loaded environment from {env_file}") +else: + print(f"Warning: {env_file} not found") + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +# Import the draft digest service +from src.services.draft_digest import DraftDigestService +from src.services.email_factory import create_email_service +from src.entities import Status + +async def test_draft_digest(days: int = 7, save_to_file: bool = True, output_path: str = None, send_email: bool = False, recipient_email: str = "andrew.maguire@undp.org") -> None: + """ + Test fetching DRAFT signals and generating a digest. + + Parameters + ---------- + days : int, optional + Number of days to look back for signals, defaults to 7. + save_to_file : bool, optional + Whether to save the generated HTML to a file, defaults to True. + output_path : str, optional + Path to save the output HTML, defaults to 'draft_digest_output.html' in the current directory. + send_email : bool, optional + Whether to send the digest via email, defaults to False. + recipient_email : str, optional + Email address to send the digest to, defaults to "andrew.maguire@undp.org". + """ + print("\n=====================================================") + print(f"🔍 TESTING DRAFT SIGNAL DIGEST") + print(f"Looking back {days} days for draft signals...") + print("=====================================================\n") + + try: + # Create the digest service + digest_service = DraftDigestService() + + # Set title + title = "Draft Signals Digest" + + # Get draft signals + signals_list = await digest_service.get_recent_draft_signals(days) + + # Print signal count and basic info + if signals_list: + print(f"\n✅ Successfully retrieved {len(signals_list)} DRAFT signals from the last {days} days.") + print("\nSignals Summary:") + print("-" * 60) + + for i, signal in enumerate(signals_list, 1): + print(f"{i}. {signal.headline}") + print(f" Created: {signal.created_at}") + print(f" Status: {signal.status}") + print(f" Created by: {getattr(signal, 'created_by', 'Unknown')}") + print(f" Location: {signal.location or 'Global'}") + if hasattr(signal, 'keywords') and signal.keywords: + print(f" Keywords: {', '.join(signal.keywords)}") + print("-" * 60) + + # Generate HTML content + print(f"\nGenerating HTML digest content for draft signals...") + html_content = digest_service.generate_digest_html( + signals_list, + title=title, + intro_text=f"

Here's a digest of draft signals from the last {days} days:

" + ) + + # Save HTML to file if requested + if save_to_file: + if output_path is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = os.path.join(os.path.dirname(__file__), f"draft_digest_{timestamp}.html") + + with open(output_path, "w", encoding="utf-8") as f: + f.write(html_content) + + print(f"\n✅ HTML digest content saved to: {output_path}") + print(" You can open this file in a browser to preview the digest.") + + # Save signals to JSON for debugging + json_path = os.path.join(os.path.dirname(__file__), f"draft_signals_data.json") + + signals_data = [] + for signal in signals_list: + # Convert to dict and handle datetime objects for JSON serialization + signal_dict = signal.model_dump() + + # Convert datetime objects to strings + for key, value in signal_dict.items(): + if isinstance(value, datetime): + signal_dict[key] = value.isoformat() + + signals_data.append(signal_dict) + + with open(json_path, "w", encoding="utf-8") as f: + json.dump(signals_data, f, indent=2) + + print(f"📊 Signal data saved to: {json_path}") + + # Send email if requested + if send_email and html_content: + print(f"\n📧 Sending draft digest email to {recipient_email}...") + try: + # Create email service + email_service = create_email_service() + + # Generate subject with timestamp + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M") + subject = f"[TEST] UNDP Future Trends - Draft Signals Digest ({timestamp})" + + # Send the email + success = await email_service.send_email( + to_emails=[recipient_email], + subject=subject, + content=html_content, + content_type="text/html" + ) + + if success: + print(f"✅ Draft digest email successfully sent to {recipient_email}") + else: + print(f"❌ Failed to send draft digest email to {recipient_email}") + + except Exception as e: + logger.error(f"Error sending email: {str(e)}") + print(f"❌ Error sending email: {str(e)}") + + else: + print(f"\n⚠️ No draft signals found in the last {days} days.") + print(" This could be because:") + print(f" - There are no draft signals in the database") + print(" - The signals were created before the specified time period") + print(" - There might be an issue with the database connection") + + if send_email: + print("\n📧 Not sending email because no signals were found.") + + except Exception as e: + import traceback + logger.error(f"Error while testing draft digest: {str(e)}") + traceback.print_exc() + print(f"\n❌ Error testing draft digest functionality: {str(e)}") + +def main() -> None: + """Parse command line arguments and run the draft digest test.""" + parser = argparse.ArgumentParser(description="Test fetching draft signals and generating a digest") + + parser.add_argument( + "--days", + type=int, + default=7, + help="Number of days to look back for signals (default: 7)" + ) + + parser.add_argument( + "--no-save", + action="store_true", + help="Don't save the generated HTML to a file" + ) + + parser.add_argument( + "--output", + type=str, + help="Path to save the output HTML (default: draft_digest_TIMESTAMP.html in current directory)" + ) + + parser.add_argument( + "--email", + action="store_true", + help="Send the digest via email to the specified recipient" + ) + + parser.add_argument( + "--recipient", + type=str, + default="andrew.maguire@undp.org", + help="Email address to send the digest to (default: andrew.maguire@undp.org)" + ) + + args = parser.parse_args() + + # Run the async function + asyncio.run(test_draft_digest( + days=args.days, + save_to_file=not args.no_save, + output_path=args.output, + send_email=args.email, + recipient_email=args.recipient + )) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/test_email.sh b/scripts/test_email.sh new file mode 100755 index 0000000..230649a --- /dev/null +++ b/scripts/test_email.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Test script to run the email digest in the correct virtual environment + +# Change to the project directory +cd "$(dirname "$0")/.." + +# Activate the virtual environment +source venv/bin/activate + +# Run the digest script with test parameters +python scripts/send_digest.py --recipients andrew.maguire@undp.org --days 14 --test + +# Deactivate the virtual environment +deactivate \ No newline at end of file diff --git a/scripts/test_email_direct.py b/scripts/test_email_direct.py new file mode 100755 index 0000000..29e1357 --- /dev/null +++ b/scripts/test_email_direct.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python +""" +Script to test email sending directly without database interactions. +This is useful for isolating email configuration issues. +""" + +import os +import sys +import asyncio +import logging +import traceback +from typing import List +try: + from dotenv import load_dotenv +except ImportError: + # Define a simple fallback if python-dotenv is not installed + def load_dotenv(path): + print(f"Warning: python-dotenv package not installed, loading environment manually") + if not os.path.exists(path): + return False + with open(path) as f: + for line in f: + line = line.strip() + if not line or line.startswith('#') or '=' not in line: + continue + key, value = line.split('=', 1) + os.environ[key.strip()] = value.strip().strip('"').strip("'") + return True + +# Load environment variables from .env.local if it exists +env_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), '.env.local') +if os.path.exists(env_file): + load_dotenv(env_file) + print(f"Loaded environment from {env_file}") +else: + print(f"Warning: {env_file} not found") + +# Check if required environment variables are set +required_vars = ["MS_FROM_EMAIL", "EMAIL_SERVICE_TYPE", "TENANT_ID", "CLIENT_ID"] +missing_vars = [var for var in required_vars if not os.getenv(var)] +if missing_vars: + print(f"ERROR: The following required environment variables are not set: {', '.join(missing_vars)}") + print("Please check your .env.local file or set them manually.") + sys.exit(1) + +# Add the parent directory to sys.path to allow importing the app modules +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.insert(0, parent_dir) + +# Set up logging +logging.basicConfig( + level=logging.DEBUG, # Using DEBUG level to see more detailed info + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +async def test_email_sending(to_email: str) -> None: + """ + Send a test email directly using the configured email service. + + Parameters + ---------- + to_email : str + The email address to send the test email to. + """ + logger.info(f"Starting direct email test to {to_email}") + logger.info(f"Using email service type: {os.getenv('EMAIL_SERVICE_TYPE')}") + logger.info(f"From email: {os.getenv('MS_FROM_EMAIL')}") + + # Create the email service + try: + from src.services.email_factory import create_email_service + email_service = create_email_service() + logger.info(f"Email service created: {type(email_service).__name__}") + except Exception as e: + logger.error(f"Failed to create email service: {e}") + traceback.print_exc() + return + + # Create a simple HTML email with timestamp + from datetime import datetime + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + html_content = f""" + + + + + + UNDP Futures - Email Test + + + +
+

UNDP Futures - Test Email

+
+ +
+

Email Functionality Test

+

This is a test email to verify that the email sending functionality is working correctly.

+

If you're receiving this, it means the email configuration is properly set up!

+

Sent at: {timestamp}

+

Configuration:

+
    +
  • Email Service: {os.getenv('EMAIL_SERVICE_TYPE')}
  • +
  • From Email: {os.getenv('MS_FROM_EMAIL')}
  • +
  • To Email: {to_email}
  • +
+
+ + + + + """ + + # Send the email + try: + logger.info("Attempting to send email...") + success = await email_service.send_email( + to_emails=[to_email], + subject=f"[TEST] UNDP Futures - Email Configuration Test ({os.getenv('EMAIL_SERVICE_TYPE')})", + content=html_content, + content_type="text/html" + ) + + if success: + logger.info("✅ Test email sent successfully!") + print("\n=====================================================") + print(f"✅ Test email sent successfully to {to_email}!") + print("=====================================================\n") + else: + logger.error("❌ Failed to send test email") + print("\n=====================================================") + print(f"❌ Failed to send test email to {to_email}") + print("=====================================================\n") + except Exception as e: + logger.error(f"Error sending test email: {e}") + traceback.print_exc() + print("\n=====================================================") + print(f"❌ Error sending test email to {to_email}: {e}") + print("=====================================================\n") + +def main() -> None: + """Parse command line arguments and run the email test.""" + if len(sys.argv) < 2: + print("Usage: python test_email_direct.py ") + sys.exit(1) + + recipient_email = sys.argv[1] + + # Validate email address (basic check) + if "@" not in recipient_email: + logger.error(f"Invalid email address: {recipient_email}") + sys.exit(1) + + # Run the async function + asyncio.run(test_email_sending(recipient_email)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/test_enterprise_email.py b/scripts/test_enterprise_email.py new file mode 100755 index 0000000..654cd6e --- /dev/null +++ b/scripts/test_enterprise_email.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python +""" +Script to test email sending using enterprise application authentication. +This is useful for verifying that the enterprise app credentials are properly set up. +""" + +import os +import sys +import asyncio +import logging +import traceback +from typing import List +from datetime import datetime +try: + from dotenv import load_dotenv +except ImportError: + # Define a simple fallback if python-dotenv is not installed + def load_dotenv(path): + print(f"Warning: python-dotenv package not installed, loading environment manually") + if not os.path.exists(path): + return False + with open(path) as f: + for line in f: + line = line.strip() + if not line or line.startswith('#') or '=' not in line: + continue + key, value = line.split('=', 1) + os.environ[key.strip()] = value.strip().strip('"').strip("'") + return True + +# Load environment variables from .env.local if it exists +env_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), '.env.local') +if os.path.exists(env_file): + load_dotenv(env_file) + print(f"Loaded environment from {env_file}") +else: + print(f"Warning: {env_file} not found") + +# Check if required environment variables are set for enterprise application authentication +required_vars = ["MS_FROM_EMAIL", "TENANT_ID", "CLIENT_ID", "CLIENT_SECRET"] +missing_vars = [var for var in required_vars if not os.getenv(var)] +if missing_vars: + print(f"ERROR: The following required environment variables are not set: {', '.join(missing_vars)}") + print("These are required for enterprise application authentication.") + print("Please check your .env.local file and ensure you've generated a CLIENT_SECRET in the Azure Portal.") + sys.exit(1) + +# Set EMAIL_SERVICE_TYPE to ms_graph to force using the enterprise app authentication +os.environ["EMAIL_SERVICE_TYPE"] = "ms_graph" +print("Force setting EMAIL_SERVICE_TYPE to 'ms_graph' for enterprise application authentication") + +# Add the parent directory to sys.path to allow importing the app modules +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.insert(0, parent_dir) + +# Set up logging +logging.basicConfig( + level=logging.DEBUG, # Using DEBUG level to see more detailed info + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +async def test_enterprise_email(to_email: str) -> None: + """ + Send a test email using the enterprise application authentication. + + Parameters + ---------- + to_email : str + The email address to send the test email to. + """ + logger.info("="*80) + logger.info("ENTERPRISE APPLICATION EMAIL TEST") + logger.info("="*80) + logger.info(f"Starting enterprise authentication email test to {to_email}") + logger.info(f"From email: {os.getenv('MS_FROM_EMAIL')}") + logger.info(f"Tenant ID: {os.getenv('TENANT_ID')}") + logger.info(f"Client ID: {os.getenv('CLIENT_ID')}") + logger.info(f"Client Secret: {'*' * 8} (hidden for security)") + + # Create the email service + try: + from src.services.email_factory import create_email_service + from src.services.msgraph_service import MSGraphEmailService + + email_service = create_email_service() + + # Verify that we got an MSGraphEmailService instance + if not isinstance(email_service, MSGraphEmailService): + logger.error(f"Expected MSGraphEmailService, but got {type(email_service).__name__}") + print("\n=====================================================") + print("❌ Wrong email service type created. Check EMAIL_SERVICE_TYPE setting.") + print("=====================================================\n") + return + + logger.info(f"Email service created: {type(email_service).__name__}") + except Exception as e: + logger.error(f"Failed to create email service: {e}") + traceback.print_exc() + return + + # Create a simple HTML email with timestamp + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + html_content = f""" + + + + + + UNDP Futures - Enterprise App Email Test + + + +
+

UNDP Futures - Enterprise App Test

+
+ +
+

Enterprise Application Email Test

+

This is a test email to verify that email sending via enterprise application authentication is working correctly.

+

If you're receiving this, it means the enterprise application credentials are properly set up!

+

Sent at: {timestamp}

+

Configuration:

+
    +
  • Authentication: Enterprise Application
  • +
  • App Name: UNDP Future Trends and Signals System
  • +
  • App ID: {os.getenv('CLIENT_ID')}
  • +
  • From Email: {os.getenv('MS_FROM_EMAIL')}
  • +
  • To Email: {to_email}
  • +
+
+ + + + + """ + + # Send the email + try: + logger.info("Attempting to send email using enterprise application authentication...") + success = await email_service.send_email( + to_emails=[to_email], + subject=f"[TEST] UNDP Futures - Enterprise Application Email Test", + content=html_content, + content_type="text/html" + ) + + if success: + logger.info("✅ Test email sent successfully using enterprise application authentication!") + print("\n=====================================================") + print(f"✅ Test email sent successfully to {to_email}!") + print("The enterprise application authentication is working correctly.") + print("=====================================================\n") + else: + logger.error("❌ Failed to send test email using enterprise application authentication") + print("\n=====================================================") + print(f"❌ Failed to send test email to {to_email}") + print("Check the logs for more details.") + print("=====================================================\n") + except Exception as e: + logger.error(f"Error sending test email: {e}") + traceback.print_exc() + print("\n=====================================================") + print(f"❌ Error sending test email to {to_email}: {e}") + print("=====================================================\n") + +def main() -> None: + """Parse command line arguments and run the enterprise email test.""" + if len(sys.argv) < 2: + print("Usage: python test_enterprise_email.py ") + sys.exit(1) + + recipient_email = sys.argv[1] + + # Validate email address (basic check) + if "@" not in recipient_email: + logger.error(f"Invalid email address: {recipient_email}") + sys.exit(1) + + # Run the async function + asyncio.run(test_enterprise_email(recipient_email)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/test_user_auth_email.py b/scripts/test_user_auth_email.py new file mode 100755 index 0000000..6c6bf45 --- /dev/null +++ b/scripts/test_user_auth_email.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python +""" +Test script for sending emails using user authentication with Azure CLI. +""" + +import os +import sys +import asyncio +import logging +from datetime import datetime +from dotenv import load_dotenv + +# Add the parent directory to sys.path +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.insert(0, parent_dir) + +# Load environment variables +env_file = os.path.join(parent_dir, '.env.local') +if os.path.exists(env_file): + load_dotenv(env_file) + print(f"Loaded environment from {env_file}") +else: + print(f"Warning: {env_file} not found") + +# Set up logging +logging.basicConfig( + level=logging.DEBUG, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +async def test_email(to_email: str) -> None: + """Send a test email using the email service factory""" + try: + from src.services.email_factory import create_email_service + + print(f"\nSending test email to {to_email}...") + + # Create the email service + email_service = create_email_service() + logger.info(f"Email service created: {type(email_service).__name__}") + + # Create HTML content with current timestamp + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + html_content = f""" + + + + + + UNDP Futures - User Auth Test Email + + + +
+

UNDP Futures - User Auth Test Email

+
+ +
+

User Authentication Email Test

+

This is a test email sent using Azure CLI user authentication.

+

If you're receiving this, it means the email configuration is working!

+

Sent at: {timestamp}

+

Configuration:

+
    +
  • From Email: {os.getenv('MS_FROM_EMAIL')}
  • +
  • User Email: {os.getenv('USER_EMAIL')}
  • +
  • To Email: {to_email}
  • +
  • Tenant ID: {os.getenv('AZURE_TENANT_ID')}
  • +
+
+ + + + + """ + + # Send the email + success = await email_service.send_email( + to_emails=[to_email], + subject=f"[TEST] UNDP Futures - User Auth Email Test ({timestamp})", + content=html_content, + content_type="text/html" + ) + + if success: + print("\n=====================================================") + print(f"✅ Test email successfully sent to {to_email}!") + print("=====================================================\n") + else: + print("\n=====================================================") + print(f"❌ Failed to send test email to {to_email}") + print("=====================================================\n") + + except Exception as e: + logger.error(f"Error in test_email: {str(e)}", exc_info=True) + print("\n=====================================================") + print(f"❌ Error sending test email: {str(e)}") + print("=====================================================\n") + +def main(): + """Main entry point""" + if len(sys.argv) < 2: + print("Usage: python test_user_auth_email.py ") + sys.exit(1) + + recipient_email = sys.argv[1] + asyncio.run(test_email(recipient_email)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..eb48e1b --- /dev/null +++ b/src/main.py @@ -0,0 +1,11 @@ +import logging + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(), + logging.FileHandler('app.log') + ] +) \ No newline at end of file diff --git a/src/routers/__init__.py b/src/routers/__init__.py index 78d60dc..47ceee4 100644 --- a/src/routers/__init__.py +++ b/src/routers/__init__.py @@ -8,6 +8,7 @@ from .trends import router as trend_router from .users import router as user_router from .user_groups import router as user_group_router +from .email import router as email_router ALL = [ choice_router, @@ -16,4 +17,5 @@ trend_router, user_router, user_group_router, + email_router, ] diff --git a/src/routers/email.py b/src/routers/email.py new file mode 100644 index 0000000..bff1273 --- /dev/null +++ b/src/routers/email.py @@ -0,0 +1,114 @@ +""" +Router for email-related endpoints. +""" + +from typing import List + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, EmailStr + +from ..dependencies import require_admin +from ..entities import User +from ..services.email_factory import create_email_service +from ..authentication import authenticate_user + +router = APIRouter(prefix="/email", tags=["email"]) + +# Request models +class EmailRequest(BaseModel): + to_emails: List[EmailStr] + subject: str + content: str + content_type: str = "text/plain" + +class NotificationRequest(BaseModel): + to_email: EmailStr + subject: str + template_id: str + dynamic_data: dict + +class DigestRequest(BaseModel): + days: int | None = None + status: List[str] | None = None + limit: int | None = None + test: bool = False + +# Initialize email service +email_service = create_email_service() + +@router.post("/send", dependencies=[Depends(require_admin)]) +async def send_email(request: EmailRequest): + """ + Send an email to multiple recipients. + Only accessible by admin users. + """ + success = await email_service.send_email( + to_emails=request.to_emails, + subject=request.subject, + content=request.content, + content_type=request.content_type + ) + + if not success: + raise HTTPException(status_code=500, detail="Failed to send email") + + return {"message": "Email sent successfully"} + +@router.post("/notify", dependencies=[Depends(require_admin)]) +async def send_notification(request: NotificationRequest): + """ + Send a templated notification email. + Only accessible by admin users. + """ + success = await email_service.send_notification_email( + to_email=request.to_email, + subject=request.subject, + template_id=request.template_id, + dynamic_data=request.dynamic_data + ) + + if not success: + raise HTTPException(status_code=500, detail="Failed to send notification") + + return {"message": "Notification sent successfully"} + +@router.post("/digest") +async def trigger_digest(request: DigestRequest, user: User = Depends(authenticate_user)): + """ + Trigger the email digest process as the authenticated user (delegated permissions). + Only sends to the hardcoded cdo.curators@undp.org address. + """ + from src.services.weekly_digest import WeeklyDigestService, Status + import logging + import asyncio + + logger = logging.getLogger(__name__) + curator_email = "cdo.curators@undp.org" + logger.info(f"User {user.email} is triggering a digest email to {curator_email}") + + # Map status strings to Status enum if provided + status_enum = None + if request.status: + status_enum = [Status(s) for s in request.status] + + digest_service = WeeklyDigestService() + subject = "UNDP Futures Weekly Digest" + if request.test: + subject = f"[TEST] {subject}" + + # Generate signals and HTML + signals_list = await digest_service.get_recent_signals(days=request.days, status=status_enum, limit=request.limit) + html_content = digest_service.generate_email_html(signals_list) + + # Use user access token for this endpoint + email_service = create_email_service(useUserAccessToken=True) + success = await email_service.send_email( + to_emails=[curator_email], + subject=subject, + content=html_content, + content_type="text/html", + useUserAccessToken=True + ) + if not success: + raise HTTPException(status_code=500, detail="Failed to send digest email") + return {"message": "Digest email sent successfully"} \ No newline at end of file diff --git a/src/services/draft_digest.py b/src/services/draft_digest.py new file mode 100644 index 0000000..e3109ca --- /dev/null +++ b/src/services/draft_digest.py @@ -0,0 +1,389 @@ +""" +Service for generating digests of draft signals. +This is a specialized version of the weekly digest service that focuses on draft signals. +""" + +import logging +import datetime +from typing import List, Dict, Any, Optional +from datetime import timedelta + +from ..entities import Signal, SignalFilters, Status +from ..database import signals, connection + +logger = logging.getLogger(__name__) + +class DraftDigestService: + """Service class for generating digests of draft signals.""" + + def __init__(self): + """Initialize the draft digest service.""" + pass + + async def get_recent_draft_signals(self, days: int = 7) -> List[Signal]: + """ + Get draft signals created in the last specified number of days. + + Parameters + ---------- + days : int, optional + The number of days to look back, defaults to 7 days. + + Returns + ------- + List[Signal] + A list of draft signals created in the specified period. + """ + logger.info(f"Getting draft signals from the last {days} days") + + # Calculate date range + end_date = datetime.datetime.now() + start_date = end_date - timedelta(days=days) + start_date_str = start_date.strftime("%Y-%m-%d") + + logger.info(f"Date range: {start_date_str} to {end_date.strftime('%Y-%m-%d')}") + + # Create signal filters - specifically for DRAFT status + filters = SignalFilters( + statuses=[Status.DRAFT], # Only draft signals + # We'll filter by created_at in SQL directly + limit=100 # Limit the number of signals + ) + + # Use a DB connection to fetch signals + async with await connection.get_connection() as conn: + async with conn.cursor() as cursor: + # Get signals created after start_date + query = f""" + SELECT + *, COUNT(*) OVER() AS total_count + FROM + signals AS s + LEFT OUTER JOIN ( + SELECT + signal_id, array_agg(trend_id) AS connected_trends + FROM + connections + GROUP BY + signal_id + ) AS c + ON + s.id = c.signal_id + LEFT OUTER JOIN ( + SELECT + name AS unit_name, + region AS unit_region + FROM + units + ) AS u + ON + s.created_unit = u.unit_name + LEFT OUTER JOIN ( + SELECT + name AS location, + region AS location_region, + bureau AS location_bureau + FROM + locations + ) AS l + ON + s.location = l.location + WHERE + status = ANY(%(statuses)s) + AND created_at >= %(start_date)s + ORDER BY + created_at DESC + LIMIT + %(limit)s + ; + """ + + # Add start_date parameter to the filters + filter_params = filters.model_dump() + filter_params['start_date'] = start_date_str + + await cursor.execute(query, filter_params) + rows = await cursor.fetchall() + + signals_list = [Signal(**row) for row in rows] + + logger.info(f"Found {len(signals_list)} draft signals from the last {days} days") + return signals_list + + async def get_signals_by_status(self, statuses: List[Status], days: int = 7) -> List[Signal]: + """ + Get signals with specified statuses created in the last specified number of days. + + Parameters + ---------- + statuses : List[Status] + List of statuses to filter by (e.g., [Status.DRAFT, Status.PENDING]) + days : int, optional + The number of days to look back, defaults to 7 days. + + Returns + ------- + List[Signal] + A list of signals with the specified statuses created in the specified period. + """ + logger.info(f"Getting signals with statuses {statuses} from the last {days} days") + + # Calculate date range + end_date = datetime.datetime.now() + start_date = end_date - timedelta(days=days) + start_date_str = start_date.strftime("%Y-%m-%d") + + logger.info(f"Date range: {start_date_str} to {end_date.strftime('%Y-%m-%d')}") + + # Create signal filters with the specified statuses + filters = SignalFilters( + statuses=statuses, + # We'll filter by created_at in SQL directly + limit=100 # Limit the number of signals + ) + + # Use a DB connection to fetch signals + async with await connection.get_connection() as conn: + async with conn.cursor() as cursor: + # Get signals created after start_date + query = f""" + SELECT + *, COUNT(*) OVER() AS total_count + FROM + signals AS s + LEFT OUTER JOIN ( + SELECT + signal_id, array_agg(trend_id) AS connected_trends + FROM + connections + GROUP BY + signal_id + ) AS c + ON + s.id = c.signal_id + LEFT OUTER JOIN ( + SELECT + name AS unit_name, + region AS unit_region + FROM + units + ) AS u + ON + s.created_unit = u.unit_name + LEFT OUTER JOIN ( + SELECT + name AS location, + region AS location_region, + bureau AS location_bureau + FROM + locations + ) AS l + ON + s.location = l.location + WHERE + status = ANY(%(statuses)s) + AND created_at >= %(start_date)s + ORDER BY + created_at DESC + LIMIT + %(limit)s + ; + """ + + # Add start_date parameter to the filters + filter_params = filters.model_dump() + filter_params['start_date'] = start_date_str + + await cursor.execute(query, filter_params) + rows = await cursor.fetchall() + + signals_list = [Signal(**row) for row in rows] + + status_names = [s.value for s in statuses] + logger.info(f"Found {len(signals_list)} signals with statuses {status_names} from the last {days} days") + return signals_list + + def generate_digest_html(self, signals_list: List[Signal], intro_text: Optional[str] = None, title: str = "Signal Digest") -> str: + """ + Generate HTML content for the digest. + + Parameters + ---------- + signals_list : List[Signal] + List of signals to include in the digest. + intro_text : Optional[str] + Optional custom introduction text. + title : str + Title for the digest page. + + Returns + ------- + str + HTML content for the digest. + """ + if not signals_list: + logger.warning("No signals to include in digest") + return "

No signals were found for this period.

" + + default_intro = """ +

Here's a digest of signals from the UNDP Futures platform. + Below are the latest signals:

+ """ + + intro = intro_text or default_intro + + html = f""" + + + + + + UNDP Futures - {title} + + + +
+

UNDP Futures - {title}

+

Signals from the UNDP Futures platform

+

Generated on {datetime.datetime.now().strftime("%Y-%m-%d %H:%M")}

+
+ + {intro} + +
+ """ + + # Add each signal to the HTML + for signal in signals_list: + keywords_html = "" + if signal.keywords: + keywords_html = " ".join([f'{k}' for k in signal.keywords]) + + location_text = signal.location or "Global" + + # Add status class + status_class = f"status-{signal.status.lower()}" if hasattr(signal, 'status') else "" + status_text = signal.status.capitalize() if hasattr(signal, 'status') else "Unknown" + + # Format created date + created_date = "" + if hasattr(signal, 'created_at') and signal.created_at: + if isinstance(signal.created_at, str): + created_date = signal.created_at + else: + try: + created_date = signal.created_at.strftime("%Y-%m-%d") + except: + created_date = str(signal.created_at) + + html += f""" +
+

{signal.headline} {status_text}

+
+ Location: {location_text} + {f'• Source: View Source' if hasattr(signal, 'url') and signal.url else ''} + {f'• Created: {created_date}' if created_date else ''} + {f'• Created by: {signal.created_by}' if hasattr(signal, 'created_by') and signal.created_by else ''} +
+

{signal.description}

+
+ {keywords_html} +
+
+ """ + + html += """ +
+ + + + + """ + + return html \ No newline at end of file diff --git a/src/services/email_factory.py b/src/services/email_factory.py new file mode 100644 index 0000000..252a76c --- /dev/null +++ b/src/services/email_factory.py @@ -0,0 +1,44 @@ +""" +Factory for creating email service instances. +""" + +import os +import logging +from typing import Optional + +from .email_service import EmailServiceBase +from .msgraph_service import MSGraphEmailService +from .sendgrid_service import SendGridEmailService +from .user_auth_service import UserAuthEmailService + +logger = logging.getLogger(__name__) + +# Email service types +MS_GRAPH = "ms_graph" +SENDGRID = "sendgrid" +USER_AUTH = "user_auth" + +# Default to USER_AUTH with Azure CLI authentication +DEFAULT_EMAIL_SERVICE = USER_AUTH + +def create_email_service(useUserAccessToken: bool = False) -> EmailServiceBase: + """ + Factory function to create an email service instance based on configuration. + Accepts useUserAccessToken to control delegated vs app auth. + + Returns: + EmailServiceBase: An instance of the configured email service. + """ + service_type = os.getenv("EMAIL_SERVICE_TYPE", DEFAULT_EMAIL_SERVICE).lower() + + logger.info(f"Creating email service of type: {service_type} (useUserAccessToken={useUserAccessToken})") + + if service_type == MS_GRAPH: + return MSGraphEmailService(useUserAccessToken=useUserAccessToken) + elif service_type == SENDGRID: + return SendGridEmailService() + elif service_type == USER_AUTH: + return UserAuthEmailService() + else: + logger.warning(f"Unknown email service type: {service_type}. Defaulting to {DEFAULT_EMAIL_SERVICE}") + return UserAuthEmailService() \ No newline at end of file diff --git a/src/services/email_service.py b/src/services/email_service.py new file mode 100644 index 0000000..9565091 --- /dev/null +++ b/src/services/email_service.py @@ -0,0 +1,36 @@ +""" +Base email service interface. +""" + +import abc +import logging +from typing import Any, Dict, List + +logger = logging.getLogger(__name__) + +class EmailServiceBase(abc.ABC): + """Abstract base class for email services""" + + @abc.abstractmethod + async def send_email( + self, + to_emails: List[str], + subject: str, + content: str, + content_type: str = "text/plain", + useUserAccessToken: bool = False + ) -> bool: + """Send an email to multiple recipients""" + pass + + @abc.abstractmethod + async def send_notification_email( + self, + to_email: str, + subject: str, + template_id: str, + dynamic_data: Dict[str, Any], + useUserAccessToken: bool = False + ) -> bool: + """Send a templated notification email""" + pass \ No newline at end of file diff --git a/src/services/graph_direct_auth.py b/src/services/graph_direct_auth.py new file mode 100644 index 0000000..33add3e --- /dev/null +++ b/src/services/graph_direct_auth.py @@ -0,0 +1,125 @@ +""" +Direct authentication approach for Graph API using user credentials. +This is a simplified version for testing purposes. +""" + +import os +import httpx +import logging +import asyncio +import json +from typing import Dict, List, Any, Optional + +logger = logging.getLogger(__name__) + +class GraphDirectAuth: + """Direct authentication for Graph API using user credentials""" + + def __init__(self): + self.token = None + self.token_expires = 0 + self.tenant_id = os.getenv('AZURE_TENANT_ID') + self.client_id = os.getenv('AZURE_CLIENT_ID') + self.client_secret = os.getenv('AZURE_CLIENT_SECRET') + + if not all([self.tenant_id, self.client_id, self.client_secret]): + raise ValueError("Missing required environment variables for Graph authentication") + + self.token_url = f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/v2.0/token" + self.graph_url = "https://graph.microsoft.com/v1.0" + + async def ensure_token(self) -> str: + """Ensure we have a valid token, refreshing if necessary""" + current_time = asyncio.get_event_loop().time() + + # If token is expired or will expire in the next 5 minutes, refresh it + if not self.token or current_time > (self.token_expires - 300): + await self.refresh_token() + + return self.token + + async def refresh_token(self) -> None: + """Get a new access token using client credentials flow""" + try: + data = { + 'grant_type': 'client_credentials', + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'scope': 'https://graph.microsoft.com/.default' + } + + headers = { + 'Content-Type': 'application/x-www-form-urlencoded' + } + + async with httpx.AsyncClient() as client: + response = await client.post(self.token_url, data=data, headers=headers) + + if response.status_code != 200: + logger.error(f"Failed to get token: {response.status_code}, {response.text}") + raise Exception(f"Failed to get token: {response.status_code}") + + token_data = response.json() + self.token = token_data['access_token'] + + # Calculate token expiration time (convert expires_in from seconds to epoch time) + current_time = asyncio.get_event_loop().time() + self.token_expires = current_time + token_data['expires_in'] + + logger.info(f"Token refreshed, expires in {token_data['expires_in']} seconds") + + except Exception as e: + logger.error(f"Error refreshing token: {str(e)}") + raise + + async def send_email(self, from_email: str, to_emails: List[str], subject: str, + content: str, content_type: str = "HTML") -> bool: + """Send an email using Graph API""" + try: + token = await self.ensure_token() + + # Prepare the email message + message = { + "message": { + "subject": subject, + "body": { + "contentType": content_type, + "content": content + }, + "toRecipients": [ + { + "emailAddress": { + "address": email + } + } for email in to_emails + ], + "from": { + "emailAddress": { + "address": from_email + } + } + }, + "saveToSentItems": "true" + } + + headers = { + 'Authorization': f'Bearer {token}', + 'Content-Type': 'application/json' + } + + # Use /users/{from_email} instead of /me to send as that user + endpoint = f"{self.graph_url}/users/{from_email}/sendMail" + + async with httpx.AsyncClient() as client: + response = await client.post(endpoint, json=message, headers=headers) + + if response.status_code in [200, 201, 202, 204]: + logger.info(f"Email sent successfully: {response.status_code}") + return True + else: + logger.error(f"Failed to send email: {response.status_code}, {response.text}") + return False + + except Exception as e: + logger.error(f"Error sending email: {str(e)}") + return False \ No newline at end of file diff --git a/src/services/msgraph_service.py b/src/services/msgraph_service.py new file mode 100644 index 0000000..bbfccfa --- /dev/null +++ b/src/services/msgraph_service.py @@ -0,0 +1,236 @@ +""" +Microsoft Graph implementation of the email service using Enterprise Application credentials. +""" + +import logging +import os +import json +from typing import Any, Dict, List + +from azure.identity import ClientSecretCredential +import httpx + +from .email_service import EmailServiceBase +from .user_auth_service import UserAuthEmailService + +logger = logging.getLogger(__name__) + +# Define the Microsoft Graph scopes +GRAPH_SCOPE = "https://graph.microsoft.com/.default" +GRAPH_ENDPOINT = "https://graph.microsoft.com/v1.0" + +class MSGraphEmailService(EmailServiceBase): + """Service class for handling email operations using Microsoft Graph API""" + + def __init__(self, useUserAccessToken: bool = False): + self.useUserAccessToken = useUserAccessToken + if useUserAccessToken: + self.user_auth_service = UserAuthEmailService() + return + try: + # Get credentials from environment variables + tenant_id = os.getenv('TENANT_ID') + client_id = os.getenv('CLIENT_ID') + client_secret = os.getenv('CLIENT_SECRET') + service_type = os.getenv('EMAIL_SERVICE_TYPE') + self.from_email = os.getenv('MS_FROM_EMAIL') + logger.info(f"MSGraphEmailService config: TENANT_ID={tenant_id}, CLIENT_ID={client_id}, FROM_EMAIL={self.from_email}, EMAIL_SERVICE_TYPE={service_type}") + if not all([tenant_id, client_id, client_secret]): + logger.error("Missing required environment variables for authentication") + raise ValueError("TENANT_ID, CLIENT_ID, and CLIENT_SECRET must be set") + + # Use ClientSecretCredential for app authentication + self.credential = ClientSecretCredential( + tenant_id=tenant_id, + client_id=client_id, + client_secret=client_secret + ) + + if not self.from_email: + logger.error("MS_FROM_EMAIL environment variable is not set") + raise ValueError("Microsoft sender email is required") + + logger.info("MSGraphEmailService initialized successfully with enterprise application credentials") + + except Exception as e: + logger.error(f"Failed to initialize MSGraphEmailService: {str(e)}") + raise + + async def send_email( + self, + to_emails: List[str], + subject: str, + content: str, + content_type: str = "text/plain", + useUserAccessToken: bool = False + ) -> bool: + if getattr(self, 'useUserAccessToken', False): + return await self.user_auth_service.send_email( + to_emails=to_emails, + subject=subject, + content=content, + content_type=content_type, + useUserAccessToken=True + ) + """Send an email using Microsoft Graph API with Mail.Send permission""" + try: + logger.info(f"send_email config: TENANT_ID={os.getenv('TENANT_ID')}, CLIENT_ID={os.getenv('CLIENT_ID')}, FROM_EMAIL={self.from_email}, EMAIL_SERVICE_TYPE={os.getenv('EMAIL_SERVICE_TYPE')}, to_emails={to_emails}, subject={subject}") + logger.info(f"Preparing to send email to {len(to_emails)} recipients") + + # Prepare the email message + message = { + "message": { + "subject": subject, + "body": { + "contentType": "HTML" if content_type.lower() == "text/html" else "Text", + "content": content + }, + "toRecipients": [ + { + "emailAddress": { + "address": email + } + } for email in to_emails + ], + "from": { + "emailAddress": { + "address": self.from_email + } + } + }, + "saveToSentItems": "true" + } + + logger.debug(f"Email content prepared: subject='{subject}', type='{content_type}'") + + # For enterprise applications with app permissions, we send on behalf of a user + # using /users/{user_id}/sendMail instead of /me/sendMail + user_email = self.from_email + + logger.info("Acquiring Microsoft Graph token...") + try: + token = self.credential.get_token(GRAPH_SCOPE) + logger.info("Token acquired successfully.") + except Exception as token_exc: + logger.error(f"Failed to acquire token: {token_exc}", exc_info=True) + return False + + logger.info(f"Sending email via Graph API to /users/{user_email}/sendMail ...") + try: + response = await self._post(f"/users/{user_email}/sendMail", message, token=token) + except Exception as post_exc: + logger.error(f"Exception during HTTP POST to Graph API: {post_exc}", exc_info=True) + return False + + if response.status_code in [200, 201, 202, 204]: + logger.info(f"Email sent successfully: status_code={response.status_code}") + return True + else: + logger.error(f"Failed to send email: status_code={response.status_code}, response={response.text}") + return False + + except Exception as e: + logger.error(f"Error sending email: {str(e)}", exc_info=True) + return False + + async def send_notification_email( + self, + to_email: str, + subject: str, + template_id: str, + dynamic_data: Dict[str, Any], + useUserAccessToken: bool = False + ) -> bool: + if getattr(self, 'useUserAccessToken', False): + return await self.user_auth_service.send_notification_email( + to_email=to_email, + subject=subject, + template_id=template_id, + dynamic_data=dynamic_data, + useUserAccessToken=True + ) + """Send a templated notification email using Microsoft Graph API""" + try: + logger.info(f"Preparing to send notification email to {to_email}") + logger.debug(f"Using template_id: {template_id}") + logger.debug(f"Dynamic data: {dynamic_data}") + + # For Microsoft Graph, we'll need to handle templates differently + # This is a simplified version that just replaces variables in the template + template_content = await self._get_template_content(template_id) + if not template_content: + return False + + # Replace template variables with dynamic data + for key, value in dynamic_data.items(): + template_content = template_content.replace(f"{{{key}}}", str(value)) + + # Send the email using the processed template + message = { + "message": { + "subject": subject, + "body": { + "contentType": "HTML", + "content": template_content + }, + "toRecipients": [ + { + "emailAddress": { + "address": to_email + } + } + ], + "from": { + "emailAddress": { + "address": self.from_email + } + } + }, + "saveToSentItems": "true" + } + + # For enterprise applications with app permissions, we send on behalf of a user + user_email = self.from_email + + # Send the email + response = await self._post(f"/users/{user_email}/sendMail", message) + + if response.status_code in [200, 201, 202, 204]: + logger.info(f"Notification email sent successfully: status_code={response.status_code}") + return True + else: + logger.error(f"Failed to send notification email: status_code={response.status_code}, response={response.text}") + return False + + except Exception as e: + logger.error(f"Error sending notification email: {str(e)}", exc_info=True) + return False + + async def _post(self, endpoint: str, data: dict, token=None) -> httpx.Response: + """Helper method to make a POST request to the Graph API""" + try: + if token is None: + logger.info("Acquiring token inside _post (should be passed from send_email)...") + token = self.credential.get_token(GRAPH_SCOPE) + headers = { + "Authorization": f"Bearer {token.token}", + "Content-Type": "application/json" + } + url = f"{GRAPH_ENDPOINT}{endpoint}" + logger.info(f"Making HTTP POST to {url}") + async with httpx.AsyncClient() as client: + response = await client.post(url, headers=headers, json=data) + logger.info(f"HTTP POST completed with status {response.status_code}") + return response + except Exception as e: + logger.error(f"Error in _post method: {str(e)}", exc_info=True) + raise + + async def _get_template_content(self, template_id: str) -> str: + """ + Get the template content from Azure storage or other source. + This is a placeholder - implement based on where templates are stored. + """ + # TODO: Implement template retrieval from Azure storage or other source + logger.warning("Template retrieval not implemented - using placeholder") + return f"

Template {template_id}

This is a placeholder template.

" \ No newline at end of file diff --git a/src/services/sendgrid_service.py b/src/services/sendgrid_service.py new file mode 100644 index 0000000..57fc398 --- /dev/null +++ b/src/services/sendgrid_service.py @@ -0,0 +1,107 @@ +""" +SendGrid implementation of the email service. +""" + +import logging +import os +from typing import Any, Dict, List + +from sendgrid import SendGridAPIClient +from sendgrid.helpers.mail import Content, Email, Mail, Subject, To + +from .email_service import EmailServiceBase + +logger = logging.getLogger(__name__) + +class SendGridEmailService(EmailServiceBase): + """Service class for handling email operations using SendGrid""" + + def __init__(self): + try: + api_key = os.getenv('SENDGRID_API_KEY') + if not api_key: + logger.error("SENDGRID_API_KEY environment variable is not set") + raise ValueError("SendGrid API key is required") + + from_email = os.getenv('SENDGRID_FROM_EMAIL') + if not from_email: + logger.error("SENDGRID_FROM_EMAIL environment variable is not set") + raise ValueError("SendGrid from email is required") + + self.sg_client = SendGridAPIClient(api_key=api_key) + self.from_email = Email(from_email) + logger.info("SendGridEmailService initialized successfully") + + except Exception as e: + logger.error(f"Failed to initialize SendGridEmailService: {str(e)}") + raise + + async def send_email( + self, + to_emails: List[str], + subject: str, + content: str, + content_type: str = "text/plain" + ) -> bool: + """Send an email using SendGrid""" + try: + logger.info(f"Preparing to send email to {len(to_emails)} recipients") + + message = Mail( + from_email=self.from_email, + to_emails=[To(email) for email in to_emails], + subject=Subject(subject), + ) + + message.content = [Content(content_type, content)] + + logger.debug(f"Email content prepared: subject='{subject}', type='{content_type}'") + + response = self.sg_client.send(message) + status_code = response.status_code + + if status_code in [200, 201, 202]: + logger.info(f"Email sent successfully: status_code={status_code}") + return True + else: + logger.error(f"Failed to send email: status_code={status_code}") + return False + + except Exception as e: + logger.error(f"Error sending email: {str(e)}", exc_info=True) + return False + + async def send_notification_email( + self, + to_email: str, + subject: str, + template_id: str, + dynamic_data: Dict[str, Any] + ) -> bool: + """Send a templated notification email""" + try: + logger.info(f"Preparing to send notification email to {to_email}") + logger.debug(f"Using template_id: {template_id}") + logger.debug(f"Dynamic data: {dynamic_data}") + + message = Mail( + from_email=self.from_email, + to_emails=[To(to_email)] + ) + + message.template_id = template_id + message.dynamic_template_data = dynamic_data + + response = self.sg_client.send(message) + status_code = response.status_code + + if status_code in [200, 201, 202]: + logger.info(f"Notification email sent successfully: status_code={status_code}") + return True + else: + logger.error(f"Failed to send notification email: status_code={status_code}") + return False + + except Exception as e: + logger.error(f"Error sending notification email: {str(e)}", exc_info=True) + return False \ No newline at end of file diff --git a/src/services/user_auth_service.py b/src/services/user_auth_service.py new file mode 100644 index 0000000..0646191 --- /dev/null +++ b/src/services/user_auth_service.py @@ -0,0 +1,186 @@ +""" +Microsoft Graph implementation using user authentication. +This leverages the existing Azure CLI authentication. +""" + +import logging +import os +import json +import asyncio +import subprocess +from typing import Any, Dict, List + +import httpx + +from .email_service import EmailServiceBase + +logger = logging.getLogger(__name__) + +class UserAuthEmailService(EmailServiceBase): + """Service class for handling email operations using Microsoft Graph API with user auth""" + + def __init__(self): + try: + self.from_email = os.getenv('MS_FROM_EMAIL') + if not self.from_email: + logger.error("MS_FROM_EMAIL environment variable is not set") + raise ValueError("Microsoft sender email is required") + + self.user_email = os.getenv('USER_EMAIL') + if not self.user_email: + logger.error("USER_EMAIL environment variable is not set") + raise ValueError("User email is required") + + # Token cache + self.token = None + self.token_expires = 0 + + logger.info("UserAuthEmailService initialized successfully") + + except Exception as e: + logger.error(f"Failed to initialize UserAuthEmailService: {str(e)}") + raise + + async def _get_token(self) -> str: + """Get an access token using az cli""" + current_time = asyncio.get_event_loop().time() + + # If we have a valid token that won't expire in the next 5 minutes, use it + if self.token and current_time < (self.token_expires - 300): + return self.token + + try: + # Get token using az cli command + cmd = [ + "az", "account", "get-access-token", + "--resource", "https://graph.microsoft.com" + ] + + # Run the command + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + token_info = json.loads(result.stdout) + + # Extract token and expiration + self.token = token_info["accessToken"] + + # Calculate token expiration time (timestamp from Azure CLI is already in seconds) + self.token_expires = token_info["expiresOn"] + + logger.info(f"Successfully got access token using Azure CLI. Expires: {self.token_expires}") + return self.token + + except subprocess.CalledProcessError as e: + logger.error(f"Error executing Azure CLI command: {e.stderr}") + raise + except Exception as e: + logger.error(f"Error getting token: {str(e)}") + raise + + async def send_email( + self, + to_emails: List[str], + subject: str, + content: str, + content_type: str = "text/plain", + useUserAccessToken: bool = False + ) -> bool: + # useUserAccessToken is ignored here, always uses user token + try: + logger.info(f"Preparing to send email to {len(to_emails)} recipients") + + # Prepare the email message + message = { + "message": { + "subject": subject, + "body": { + "contentType": "HTML" if content_type.lower() == "text/html" else "Text", + "content": content + }, + "toRecipients": [ + { + "emailAddress": { + "address": email + } + } for email in to_emails + ], + "from": { + "emailAddress": { + "address": self.from_email + } + } + }, + "saveToSentItems": "true" + } + + logger.debug(f"Email content prepared: subject='{subject}', type='{content_type}'") + + # Get token + token = await self._get_token() + + # Send the email using Microsoft Graph API + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + } + + # Use the /me/sendMail endpoint to send from the authenticated user + url = "https://graph.microsoft.com/v1.0/me/sendMail" + + async with httpx.AsyncClient() as client: + response = await client.post(url, json=message, headers=headers) + + if response.status_code in [200, 201, 202, 204]: + logger.info(f"Email sent successfully: status_code={response.status_code}") + return True + else: + logger.error(f"Failed to send email: status_code={response.status_code}, response={response.text}") + return False + + except Exception as e: + logger.error(f"Error sending email: {str(e)}", exc_info=True) + return False + + async def send_notification_email( + self, + to_email: str, + subject: str, + template_id: str, + dynamic_data: Dict[str, Any], + useUserAccessToken: bool = False + ) -> bool: + # useUserAccessToken is ignored here, always uses user token + try: + logger.info(f"Preparing to send notification email to {to_email}") + logger.debug(f"Using template_id: {template_id}") + logger.debug(f"Dynamic data: {dynamic_data}") + + # For Microsoft Graph, we'll need to handle templates differently + # This is a simplified version that just replaces variables in the template + template_content = await self._get_template_content(template_id) + if not template_content: + return False + + # Replace template variables with dynamic data + for key, value in dynamic_data.items(): + template_content = template_content.replace(f"{{{key}}}", str(value)) + + # Send the email using the processed template + return await self.send_email( + to_emails=[to_email], + subject=subject, + content=template_content, + content_type="text/html" + ) + + except Exception as e: + logger.error(f"Error sending notification email: {str(e)}", exc_info=True) + return False + + async def _get_template_content(self, template_id: str) -> str: + """ + Get the template content from Azure storage or other source. + This is a placeholder - implement based on where templates are stored. + """ + # TODO: Implement template retrieval from Azure storage or other source + logger.warning("Template retrieval not implemented - using placeholder") + return f"

Template {template_id}

This is a placeholder template.

" \ No newline at end of file diff --git a/src/services/weekly_digest.py b/src/services/weekly_digest.py new file mode 100644 index 0000000..b499173 --- /dev/null +++ b/src/services/weekly_digest.py @@ -0,0 +1,315 @@ +""" +Service for generating and sending weekly digests of signals. +""" + +import logging +import datetime +from typing import List, Dict, Any, Optional +from datetime import timedelta + +from ..entities import Signal, SignalFilters, Status +from ..database import signals, connection + +logger = logging.getLogger(__name__) + +class WeeklyDigestService: + """Service class for generating and sending weekly digests of signals.""" + + def __init__(self): + """Initialize the weekly digest service.""" + pass + + async def get_recent_signals(self, days: Optional[int] = None, status: Optional[List[Status]] = None, limit: Optional[int] = None) -> List[Signal]: + """ + Get signals filtered by optional days, status, and limit. + If no filters are provided, fetch the last 10 draft signals. + """ + logger.info(f"Getting signals with filters - days: {days}, status: {status}, limit: {limit}") + start_time = datetime.datetime.now() + # Set defaults if not provided + if status is None: + status = [Status.DRAFT] + if limit is None: + limit = 10 + # Calculate date range if days is provided + end_date = datetime.datetime.now() + start_date = end_date - timedelta(days=days) if days is not None else None + start_date_str = start_date.strftime("%Y-%m-%d") if start_date else None + logger.info(f"Date range: {start_date_str} to {end_date.strftime('%Y-%m-%d') if start_date else 'ALL'}") + filters = SignalFilters( + statuses=status, + per_page=limit + ) + logger.debug("Opening database connection for signal fetch...") + async with await connection.get_connection() as conn: + logger.debug("Database connection established.") + async with conn.cursor() as cursor: + logger.debug("Cursor opened. Preparing to execute signal fetch query...") + query = f""" + SELECT + *, COUNT(*) OVER() AS total_count + FROM + signals AS s + LEFT OUTER JOIN ( + SELECT + signal_id, array_agg(trend_id) AS connected_trends + FROM + connections + GROUP BY + signal_id + ) AS c + ON + s.id = c.signal_id + LEFT OUTER JOIN ( + SELECT + name AS unit_name, + region AS unit_region + FROM + units + ) AS u + ON + s.created_unit = u.unit_name + LEFT OUTER JOIN ( + SELECT + name AS location, + region AS location_region, + bureau AS location_bureau + FROM + locations + ) AS l + ON + s.location = l.location + WHERE + status = ANY(%(statuses)s) + {f'AND created_at >= %(start_date)s' if start_date_str else ''} + ORDER BY + created_at DESC + LIMIT + %(limit)s + ; + """ + filter_params = filters.model_dump() + filter_params['limit'] = limit + if start_date_str: + filter_params['start_date'] = start_date_str + logger.debug(f"Executing query with params: {filter_params}") + await cursor.execute(query, filter_params) + logger.debug("Query executed. Fetching rows...") + rows = await cursor.fetchall() + logger.debug(f"Fetched {len(rows)} rows from database.") + signals_list = [Signal(**dict(row)) for row in rows] + logger.info(f"Found {len(signals_list)} signals with filters - days: {days}, status: {status}, limit: {limit}") + elapsed = (datetime.datetime.now() - start_time).total_seconds() + logger.info(f"Signal fetch took {elapsed:.2f} seconds.") + return signals_list + + def generate_email_html(self, signals_list: List[Signal], intro_text: Optional[str] = None) -> str: + """ + Generate HTML content for the weekly digest email. + + Parameters + ---------- + signals_list : List[Signal] + List of signals to include in the digest. + intro_text : Optional[str] + Optional custom introduction text. + + Returns + ------- + str + HTML content for the email. + """ + if not signals_list: + logger.warning("No signals to include in digest") + return "

No new signals were found for this period.

" + + default_intro = """ +

Hello,

+

Here's your weekly digest of new signals from the UNDP Futures platform. + Below are the latest signals that might be of interest:

+ """ + + intro = intro_text or default_intro + + html = f""" + + + + + + UNDP Futures - Weekly Signal Digest + + + +
+

UNDP Futures - Weekly Signal Digest

+

Stay updated with the latest signals from around the world

+
+ + {intro} + +
+ """ + + # Add each signal to the HTML + for signal in signals_list: + keywords_html = "" + if signal.keywords: + keywords_html = " ".join([f'{k}' for k in signal.keywords]) + + location_text = signal.location or "Global" + + html += f""" +
+

{signal.headline}

+
+ Location: {location_text} + {f'• Source: View Source' if signal.url else ''} +
+

{signal.description}

+
+ {keywords_html} +
+
+ """ + + html += """ +
+ + + + + """ + + return html + + async def generate_and_send_digest(self, + recipients: List[str], + days: int = 7, + subject: Optional[str] = None, + custom_intro: Optional[str] = None, + status: Optional[List[Status]] = None, + limit: Optional[int] = None, + useUserAccessToken: bool = False) -> bool: + """ + Generate and send a weekly digest email to specified recipients. + + Parameters + ---------- + recipients : List[str] + List of email addresses to send the digest to. + days : int, optional + Number of days to look back for signals, defaults to 7. + subject : Optional[str], optional + Custom email subject, defaults to standard subject with date. + custom_intro : Optional[str], optional + Custom introduction text for the email. + status : Optional[List[Status]], optional + List of signal statuses to filter by. + limit : Optional[int], optional + Maximum number of signals to include. + useUserAccessToken : bool, optional + Whether to use user access token for email sending. + + Returns + ------- + bool + True if the email was sent successfully, False otherwise. + """ + if not recipients: + logger.error("No recipients specified for weekly digest") + return False + logger.info(f"Generating weekly digest email for {len(recipients)} recipients") + step_start = datetime.datetime.now() + logger.info("Fetching recent signals for digest...") + signals_list = await self.get_recent_signals(days=days, status=status, limit=limit) + logger.info(f"Fetched {len(signals_list)} signals for digest.") + logger.info(f"Signal fetch step took {(datetime.datetime.now() - step_start).total_seconds():.2f} seconds.") + if not signals_list: + logger.warning("No signals found for digest, skipping email send") + return False + logger.info("Generating HTML content for digest email...") + html_start = datetime.datetime.now() + html_content = self.generate_email_html(signals_list, custom_intro) + logger.info(f"HTML generation took {(datetime.datetime.now() - html_start).total_seconds():.2f} seconds.") + today = datetime.datetime.now().strftime("%Y-%m-%d") + email_subject = subject or f"UNDP Futures Weekly Digest - {today}" + from .email_factory import create_email_service + logger.info("Creating email service...") + email_service = create_email_service(useUserAccessToken=useUserAccessToken) + logger.info(f"Sending weekly digest email to {recipients} with subject {email_subject}") + send_start = datetime.datetime.now() + try: + success = await email_service.send_email( + to_emails=recipients, + subject=email_subject, + content=html_content, + content_type="text/html", + useUserAccessToken=useUserAccessToken + ) + logger.info(f"Email send step took {(datetime.datetime.now() - send_start).total_seconds():.2f} seconds.") + if success: + logger.info(f"Weekly digest email sent successfully to {len(recipients)} recipients") + else: + logger.error("Failed to send weekly digest email") + return success + except Exception as e: + logger.error(f"Error sending weekly digest email: {e}", exc_info=True) + return False \ No newline at end of file diff --git a/test_mail_send.py b/test_mail_send.py new file mode 100644 index 0000000..19c903a --- /dev/null +++ b/test_mail_send.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python +""" +Test script to verify the email service with Mail.Send permission. +""" + +import asyncio +import os +import sys +import logging +import subprocess +import json + +# Add the project root to the Python path +sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) + +from src.services.email_factory import create_email_service + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) + +logger = logging.getLogger(__name__) + +def check_azure_cli_auth(): + """Check Azure CLI authentication and ensure correct scopes are set.""" + try: + # Check if Azure CLI is logged in + process = subprocess.run( + ["az", "account", "show"], + capture_output=True, + text=True, + check=False + ) + + if process.returncode != 0: + print("Azure CLI is not authenticated. Logging in...") + subprocess.run( + ["az", "login", "--scope", "https://graph.microsoft.com/.default"], + check=True + ) + else: + # Check if we need to set the correct scope + print("Azure CLI is authenticated. Ensuring correct scope is set...") + subprocess.run( + ["az", "account", "get-access-token", "--scope", "https://graph.microsoft.com/.default"], + check=True + ) + + print("✅ Azure CLI authentication complete with correct scope") + return True + except Exception as e: + print(f"❌ Failed to authenticate with Azure CLI: {str(e)}") + return False + +async def test_mail_send(): + """Test the email service with Mail.Send permission.""" + + # First check Azure CLI authentication + if not check_azure_cli_auth(): + return + + # Create the email service using the factory + email_service = create_email_service() + + # Define test parameters + recipient_email = "andrew.maguire@undp.org" + subject = "Test Email - UNDP Future Trends and Signals System" + content = """ + This is a test email sent from the UNDP Future Trends and Signals System. + It verifies that the Mail.Send permission is working correctly. + + If you received this email, the Mail.Send permission is properly configured. + """ + + # Send the test email + print(f"Sending test email to {recipient_email}...") + result = await email_service.send_email( + to_emails=[recipient_email], + subject=subject, + content=content, + content_type="text/plain" + ) + + # Check the result + if result: + print("✅ Test email sent successfully!") + print("The Mail.Send permission is working correctly.") + else: + print("❌ Failed to send test email.") + print("Check logs for more details.") + +if __name__ == "__main__": + asyncio.run(test_mail_send()) \ No newline at end of file diff --git a/tests/test_email.py b/tests/test_email.py new file mode 100644 index 0000000..8fc4a30 --- /dev/null +++ b/tests/test_email.py @@ -0,0 +1,137 @@ +""" +Tests for email services. +""" + +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + +from main import app +from src.services.msgraph_service import MSGraphEmailService +from src.services.sendgrid_service import SendGridEmailService + +client = TestClient(app) + +@pytest.fixture +def mock_msgraph_client(): + with patch('src.services.msgraph_service.GraphClient') as mock: + mock_instance = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 202 + mock_instance.post = AsyncMock(return_value=mock_response) + mock.return_value = mock_instance + yield mock_instance + +@pytest.fixture +def mock_sendgrid_client(): + with patch('src.services.sendgrid_service.SendGridAPIClient') as mock: + mock_instance = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 202 + mock_instance.send = MagicMock(return_value=mock_response) + mock.return_value = mock_instance + yield mock_instance + +@pytest.mark.asyncio +async def test_msgraph_send_email(mock_msgraph_client): + """Test sending email via Microsoft Graph API""" + # Setup + service = MSGraphEmailService() + + # Test + result = await service.send_email( + to_emails=["test@example.com"], + subject="Test Subject", + content="Test Content" + ) + + # Assert + assert result is True + mock_msgraph_client.post.assert_called_once() + call_args = mock_msgraph_client.post.call_args[0] + assert call_args[0] == "/me/sendMail" + +@pytest.mark.asyncio +async def test_msgraph_send_notification(mock_msgraph_client): + """Test sending notification via Microsoft Graph API""" + # Setup + service = MSGraphEmailService() + + # Test + result = await service.send_notification_email( + to_email="test@example.com", + subject="Test Notification", + template_id="test-template", + dynamic_data={"name": "Test User"} + ) + + # Assert + assert result is True + mock_msgraph_client.post.assert_called_once() + call_args = mock_msgraph_client.post.call_args[0] + assert call_args[0] == "/me/sendMail" + +@pytest.mark.asyncio +async def test_sendgrid_send_email(mock_sendgrid_client): + """Test sending email via SendGrid""" + # Setup + service = SendGridEmailService() + + # Test + result = await service.send_email( + to_emails=["test@example.com"], + subject="Test Subject", + content="Test Content" + ) + + # Assert + assert result is True + mock_sendgrid_client.send.assert_called_once() + +@pytest.mark.asyncio +async def test_sendgrid_send_notification(mock_sendgrid_client): + """Test sending notification via SendGrid""" + # Setup + service = SendGridEmailService() + + # Test + result = await service.send_notification_email( + to_email="test@example.com", + subject="Test Notification", + template_id="test-template", + dynamic_data={"name": "Test User"} + ) + + # Assert + assert result is True + mock_sendgrid_client.send.assert_called_once() + +@pytest.mark.skip(reason="Requires database connection") +def test_email_endpoints(headers: dict): + """Test email endpoints with authentication""" + # Test send email endpoint + response = client.post( + "/email/send", + json={ + "to_emails": ["test@example.com"], + "subject": "Test Subject", + "content": "Test Content" + }, + headers=headers + ) + assert response.status_code in [200, 403] # 200 if admin, 403 if not admin + + # Test notification endpoint + response = client.post( + "/email/notify", + json={ + "to_email": "test@example.com", + "subject": "Test Notification", + "template_id": "test-template", + "dynamic_data": {"name": "Test User"} + }, + headers=headers + ) + assert response.status_code in [200, 403] # 200 if admin, 403 if not admin From 066834c6efed0bc207d8d34cd3a82d30ca40d886 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Tue, 27 May 2025 12:46:29 +0300 Subject: [PATCH 21/31] Update authentication.py --- src/authentication.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/authentication.py b/src/authentication.py index 89ba2ea..9ff3a4f 100644 --- a/src/authentication.py +++ b/src/authentication.py @@ -11,6 +11,7 @@ from fastapi import Depends, Security from fastapi.security import APIKeyHeader from psycopg import AsyncCursor +import psycopg.errors from . import database as db from . import exceptions @@ -204,5 +205,14 @@ async def authenticate_user( if (user := await db.read_user_by_email(cursor, email_str)) is None: user = User(email=email_str, role=Role.USER, name=name_str) - await db.create_user(cursor, user) + try: + await db.create_user(cursor, user) + except psycopg.errors.UniqueViolation: + # User was created by another request in the meantime, fetch the existing user + logging.info(f"User {email_str} already exists, fetching existing user") + user = await db.read_user_by_email(cursor, email_str) + if user is None: + # This should not happen, but handle it gracefully + logging.error(f"Failed to fetch user {email_str} after UniqueViolation") + raise exceptions.not_authenticated return user From 217de9c323f22e0febfa4f7650baa48a689f632b Mon Sep 17 00:00:00 2001 From: happy-devs Date: Tue, 27 May 2025 13:21:36 +0300 Subject: [PATCH 22/31] CORS updates MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Set allow_origins to ["*"] for production mode - Disable credentials when using wildcard origins (CORS requirement) - Update CORSHandlerMiddleware to handle all origins properly - Maintain localhost support with credentials for local development 🤖 Generated with [Claude Code](https://claude.ai/code) --- .gitignore | 1 + main.py | 42 +++++++++++++++++++++++++++++------------- src/main.py | 11 ----------- 3 files changed, 30 insertions(+), 24 deletions(-) delete mode 100644 src/main.py diff --git a/.gitignore b/.gitignore index 8e2eb1c..9533938 100644 --- a/.gitignore +++ b/.gitignore @@ -149,3 +149,4 @@ Taskfile.yml /logs webapp_logs.zip /.schemas +app_logs.zip diff --git a/main.py b/main.py index fab43eb..d0576cd 100644 --- a/main.py +++ b/main.py @@ -102,23 +102,39 @@ async def global_exception_handler(request: Request, exc: Exception): "http://127.0.0.1:3000" ] +# Production origins for different environments +production_origins = [ + "https://signals.data.undp.org", + "https://thankful-forest-05a90a303-staging.westeurope.3.azurestaticapps.net", + "https://signals-staging.data.undp.org" +] + # Create a custom middleware class for handling CORS class CORSHandlerMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): # Handle OPTIONS preflight requests if request.method == "OPTIONS": - headers = { - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH", - "Access-Control-Allow-Headers": "access_token, Authorization, Content-Type, Accept, X-API-Key", - "Access-Control-Allow-Credentials": "true", - "Access-Control-Max-Age": "600", # Cache preflight for 10 minutes - } - - # Set specific origin if in local mode origin = request.headers.get("origin") - if os.environ.get("ENV_MODE") == "local" and origin: - headers["Access-Control-Allow-Origin"] = origin + + # Allow all origins but handle credentials properly + if os.environ.get("ENV_MODE") == "local" and origin in local_origins: + # Local mode: allow specific origins with credentials + headers = { + "Access-Control-Allow-Origin": origin, + "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH", + "Access-Control-Allow-Headers": "access_token, Authorization, Content-Type, Accept, X-API-Key", + "Access-Control-Allow-Credentials": "true", + "Access-Control-Max-Age": "600", + } + else: + # Production mode: allow all origins without credentials + headers = { + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH", + "Access-Control-Allow-Headers": "access_token, Authorization, Content-Type, Accept, X-API-Key", + "Access-Control-Allow-Credentials": "false", + "Access-Control-Max-Age": "600", + } return JSONResponse(content={}, status_code=200, headers=headers) @@ -141,11 +157,11 @@ async def dispatch(self, request: Request, call_next): expose_headers=["*"], ) else: - # Production mode - use more restrictive CORS + # Production mode - allow all origins for client flexibility app.add_middleware( CORSMiddleware, allow_origins=["*"], - allow_credentials=True, + allow_credentials=False, # Must be False when allow_origins is ["*"] allow_methods=["*"], allow_headers=["*", "access_token", "Authorization", "Content-Type"], ) diff --git a/src/main.py b/src/main.py deleted file mode 100644 index eb48e1b..0000000 --- a/src/main.py +++ /dev/null @@ -1,11 +0,0 @@ -import logging - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.StreamHandler(), - logging.FileHandler('app.log') - ] -) \ No newline at end of file From d2ac4de6e46dbb494fae22a247de330bb8032e03 Mon Sep 17 00:00:00 2001 From: Andrew Maguire Date: Tue, 27 May 2025 21:00:37 +0300 Subject: [PATCH 23/31] remove legacy Sengrid code --- .env.example | 2 - docs/email_digest_delivery_methods.md | 8 -- docs/email_system.md | 12 +-- scripts/email_requirements.txt | 3 - src/services/email_factory.py | 4 - src/services/sendgrid_service.py | 107 -------------------------- tests/test_email.py | 43 ----------- 7 files changed, 2 insertions(+), 177 deletions(-) delete mode 100644 scripts/email_requirements.txt delete mode 100644 src/services/sendgrid_service.py diff --git a/.env.example b/.env.example index 0f3765f..33e898f 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,4 @@ # Email Configuration -SENDGRID_API_KEY= -SENDGRID_FROM_EMAIL= MS_FROM_EMAIL=futureofdevelopment@undp.org EMAIL_SERVICE_TYPE=ms_graph diff --git a/docs/email_digest_delivery_methods.md b/docs/email_digest_delivery_methods.md index 17f1255..96cb876 100644 --- a/docs/email_digest_delivery_methods.md +++ b/docs/email_digest_delivery_methods.md @@ -47,13 +47,6 @@ This document summarizes all the methods we have tried (and considered) for send --- -## 4. SendGrid or Third-Party SMTP Relay - -- **Approach:** Use a third-party SMTP service (e.g., SendGrid) to send as the service account. -- **Status:** **Not attempted** (would require IT approval and setup). -- **Blocker:** - - May not be allowed by organizational policy. - --- ## 5. Distribution List/Group Delivery @@ -75,7 +68,6 @@ This document summarizes all the methods we have tried (and considered) for send | MS Graph (App perms) | Yes | Blocked | Need admin to grant permissions | | MS Graph (Delegated) | No | Works (manual only) | Not suitable for automation | | SMTP (O365) | Yes | Blocked | SMTP AUTH disabled for tenant | -| SendGrid/3rd-party | Yes | Not attempted | Needs IT approval | | Distribution List | Yes | Ready | Blocked by above sending method | --- diff --git a/docs/email_system.md b/docs/email_system.md index 14e727f..2bf9cd7 100644 --- a/docs/email_system.md +++ b/docs/email_system.md @@ -12,7 +12,6 @@ The email system consists of the following components: - `EmailServiceBase`: Abstract base class defining the interface for all email services - `MSGraphEmailService`: Implementation using Microsoft Graph API with enterprise application authentication - `UserAuthEmailService`: Implementation using Azure CLI authentication - - `SendGridEmailService`: Alternative implementation using SendGrid - `EmailFactory`: Factory pattern for creating the appropriate service based on configuration 2. **Weekly Digest Feature** @@ -35,7 +34,7 @@ The email system consists of the following components: source venv/bin/activate # Install the required packages -pip install python-dotenv msgraph-core azure-identity httpx sendgrid +pip install python-dotenv msgraph-core azure-identity httpx ``` ### Environment Variables @@ -45,7 +44,7 @@ The following environment variables need to be set in your `.env.local` file: ``` # Email Configuration MS_FROM_EMAIL=exo.futures.curators@undp.org # Email that will appear as the sender -EMAIL_SERVICE_TYPE=ms_graph # Authentication type (ms_graph, user_auth, or sendgrid) +EMAIL_SERVICE_TYPE=ms_graph # Authentication type (ms_graph or user_auth) # Azure Authentication for UNDP Enterprise Application TENANT_ID=b3e5db5e-2944-4837-99f5-7488ace54319 # UNDP tenant ID @@ -78,13 +77,6 @@ Requirements: - Azure CLI installed and logged in with `az login` - User must have Mail.Send permissions in Microsoft Graph -#### 3. SendGrid Authentication - -Alternative email provider if Microsoft Graph is not available. - -Requirements: -- SendGrid API key -- SendGrid from email address ### Azure AD Enterprise Application Configuration diff --git a/scripts/email_requirements.txt b/scripts/email_requirements.txt deleted file mode 100644 index c5f6967..0000000 --- a/scripts/email_requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -azure-identity>=1.13.0 -msgraph-core>=0.2.2 -sendgrid>=6.10.0 \ No newline at end of file diff --git a/src/services/email_factory.py b/src/services/email_factory.py index 252a76c..eac87b7 100644 --- a/src/services/email_factory.py +++ b/src/services/email_factory.py @@ -8,14 +8,12 @@ from .email_service import EmailServiceBase from .msgraph_service import MSGraphEmailService -from .sendgrid_service import SendGridEmailService from .user_auth_service import UserAuthEmailService logger = logging.getLogger(__name__) # Email service types MS_GRAPH = "ms_graph" -SENDGRID = "sendgrid" USER_AUTH = "user_auth" # Default to USER_AUTH with Azure CLI authentication @@ -35,8 +33,6 @@ def create_email_service(useUserAccessToken: bool = False) -> EmailServiceBase: if service_type == MS_GRAPH: return MSGraphEmailService(useUserAccessToken=useUserAccessToken) - elif service_type == SENDGRID: - return SendGridEmailService() elif service_type == USER_AUTH: return UserAuthEmailService() else: diff --git a/src/services/sendgrid_service.py b/src/services/sendgrid_service.py deleted file mode 100644 index 57fc398..0000000 --- a/src/services/sendgrid_service.py +++ /dev/null @@ -1,107 +0,0 @@ -""" -SendGrid implementation of the email service. -""" - -import logging -import os -from typing import Any, Dict, List - -from sendgrid import SendGridAPIClient -from sendgrid.helpers.mail import Content, Email, Mail, Subject, To - -from .email_service import EmailServiceBase - -logger = logging.getLogger(__name__) - -class SendGridEmailService(EmailServiceBase): - """Service class for handling email operations using SendGrid""" - - def __init__(self): - try: - api_key = os.getenv('SENDGRID_API_KEY') - if not api_key: - logger.error("SENDGRID_API_KEY environment variable is not set") - raise ValueError("SendGrid API key is required") - - from_email = os.getenv('SENDGRID_FROM_EMAIL') - if not from_email: - logger.error("SENDGRID_FROM_EMAIL environment variable is not set") - raise ValueError("SendGrid from email is required") - - self.sg_client = SendGridAPIClient(api_key=api_key) - self.from_email = Email(from_email) - logger.info("SendGridEmailService initialized successfully") - - except Exception as e: - logger.error(f"Failed to initialize SendGridEmailService: {str(e)}") - raise - - async def send_email( - self, - to_emails: List[str], - subject: str, - content: str, - content_type: str = "text/plain" - ) -> bool: - """Send an email using SendGrid""" - try: - logger.info(f"Preparing to send email to {len(to_emails)} recipients") - - message = Mail( - from_email=self.from_email, - to_emails=[To(email) for email in to_emails], - subject=Subject(subject), - ) - - message.content = [Content(content_type, content)] - - logger.debug(f"Email content prepared: subject='{subject}', type='{content_type}'") - - response = self.sg_client.send(message) - status_code = response.status_code - - if status_code in [200, 201, 202]: - logger.info(f"Email sent successfully: status_code={status_code}") - return True - else: - logger.error(f"Failed to send email: status_code={status_code}") - return False - - except Exception as e: - logger.error(f"Error sending email: {str(e)}", exc_info=True) - return False - - async def send_notification_email( - self, - to_email: str, - subject: str, - template_id: str, - dynamic_data: Dict[str, Any] - ) -> bool: - """Send a templated notification email""" - try: - logger.info(f"Preparing to send notification email to {to_email}") - logger.debug(f"Using template_id: {template_id}") - logger.debug(f"Dynamic data: {dynamic_data}") - - message = Mail( - from_email=self.from_email, - to_emails=[To(to_email)] - ) - - message.template_id = template_id - message.dynamic_template_data = dynamic_data - - response = self.sg_client.send(message) - status_code = response.status_code - - if status_code in [200, 201, 202]: - logger.info(f"Notification email sent successfully: status_code={status_code}") - return True - else: - logger.error(f"Failed to send notification email: status_code={status_code}") - return False - - except Exception as e: - logger.error(f"Error sending notification email: {str(e)}", exc_info=True) - return False \ No newline at end of file diff --git a/tests/test_email.py b/tests/test_email.py index 8fc4a30..6a68146 100644 --- a/tests/test_email.py +++ b/tests/test_email.py @@ -10,7 +10,6 @@ from main import app from src.services.msgraph_service import MSGraphEmailService -from src.services.sendgrid_service import SendGridEmailService client = TestClient(app) @@ -24,15 +23,6 @@ def mock_msgraph_client(): mock.return_value = mock_instance yield mock_instance -@pytest.fixture -def mock_sendgrid_client(): - with patch('src.services.sendgrid_service.SendGridAPIClient') as mock: - mock_instance = MagicMock() - mock_response = MagicMock() - mock_response.status_code = 202 - mock_instance.send = MagicMock(return_value=mock_response) - mock.return_value = mock_instance - yield mock_instance @pytest.mark.asyncio async def test_msgraph_send_email(mock_msgraph_client): @@ -73,40 +63,7 @@ async def test_msgraph_send_notification(mock_msgraph_client): call_args = mock_msgraph_client.post.call_args[0] assert call_args[0] == "/me/sendMail" -@pytest.mark.asyncio -async def test_sendgrid_send_email(mock_sendgrid_client): - """Test sending email via SendGrid""" - # Setup - service = SendGridEmailService() - - # Test - result = await service.send_email( - to_emails=["test@example.com"], - subject="Test Subject", - content="Test Content" - ) - - # Assert - assert result is True - mock_sendgrid_client.send.assert_called_once() -@pytest.mark.asyncio -async def test_sendgrid_send_notification(mock_sendgrid_client): - """Test sending notification via SendGrid""" - # Setup - service = SendGridEmailService() - - # Test - result = await service.send_notification_email( - to_email="test@example.com", - subject="Test Notification", - template_id="test-template", - dynamic_data={"name": "Test User"} - ) - - # Assert - assert result is True - mock_sendgrid_client.send.assert_called_once() @pytest.mark.skip(reason="Requires database connection") def test_email_endpoints(headers: dict): From 1121642efaec39a0de99e0898d2f95df0ea9cf19 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Tue, 27 May 2025 21:40:10 +0300 Subject: [PATCH 24/31] revert CORS --- main.py | 42 +++++++++++++----------------------------- 1 file changed, 13 insertions(+), 29 deletions(-) diff --git a/main.py b/main.py index d0576cd..fab43eb 100644 --- a/main.py +++ b/main.py @@ -102,39 +102,23 @@ async def global_exception_handler(request: Request, exc: Exception): "http://127.0.0.1:3000" ] -# Production origins for different environments -production_origins = [ - "https://signals.data.undp.org", - "https://thankful-forest-05a90a303-staging.westeurope.3.azurestaticapps.net", - "https://signals-staging.data.undp.org" -] - # Create a custom middleware class for handling CORS class CORSHandlerMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): # Handle OPTIONS preflight requests if request.method == "OPTIONS": - origin = request.headers.get("origin") + headers = { + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH", + "Access-Control-Allow-Headers": "access_token, Authorization, Content-Type, Accept, X-API-Key", + "Access-Control-Allow-Credentials": "true", + "Access-Control-Max-Age": "600", # Cache preflight for 10 minutes + } - # Allow all origins but handle credentials properly - if os.environ.get("ENV_MODE") == "local" and origin in local_origins: - # Local mode: allow specific origins with credentials - headers = { - "Access-Control-Allow-Origin": origin, - "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH", - "Access-Control-Allow-Headers": "access_token, Authorization, Content-Type, Accept, X-API-Key", - "Access-Control-Allow-Credentials": "true", - "Access-Control-Max-Age": "600", - } - else: - # Production mode: allow all origins without credentials - headers = { - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH", - "Access-Control-Allow-Headers": "access_token, Authorization, Content-Type, Accept, X-API-Key", - "Access-Control-Allow-Credentials": "false", - "Access-Control-Max-Age": "600", - } + # Set specific origin if in local mode + origin = request.headers.get("origin") + if os.environ.get("ENV_MODE") == "local" and origin: + headers["Access-Control-Allow-Origin"] = origin return JSONResponse(content={}, status_code=200, headers=headers) @@ -157,11 +141,11 @@ async def dispatch(self, request: Request, call_next): expose_headers=["*"], ) else: - # Production mode - allow all origins for client flexibility + # Production mode - use more restrictive CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], - allow_credentials=False, # Must be False when allow_origins is ["*"] + allow_credentials=True, allow_methods=["*"], allow_headers=["*", "access_token", "Authorization", "Content-Type"], ) From 19d333ba137d30a899f71e008f636f85c86acd25 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Tue, 27 May 2025 22:07:44 +0300 Subject: [PATCH 25/31] =?UTF-8?q?Force=20deployment=20-=20fix=20sendgrid?= =?UTF-8?q?=20import=20error=20=F0=9F=A4=96=20Generated=20with=20[Claude?= =?UTF-8?q?=20Code](https://claude.ai/code)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude From da995f01155208a1495821ea924155e7133e94fe Mon Sep 17 00:00:00 2001 From: happy-devs Date: Tue, 27 May 2025 22:14:13 +0300 Subject: [PATCH 26/31] Update .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 9533938..1814f88 100644 --- a/.gitignore +++ b/.gitignore @@ -150,3 +150,5 @@ Taskfile.yml webapp_logs.zip /.schemas app_logs.zip +/deployments +/LogFiles From b0af730d5e683513a3438bd5b4e21781bcee0a5d Mon Sep 17 00:00:00 2001 From: happy-devs Date: Wed, 28 May 2025 10:16:35 +0300 Subject: [PATCH 27/31] Force new deployment - trigger fresh build --- .cursorindexingignore | 3 +++ main.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 .cursorindexingignore diff --git a/.cursorindexingignore b/.cursorindexingignore new file mode 100644 index 0000000..953908e --- /dev/null +++ b/.cursorindexingignore @@ -0,0 +1,3 @@ + +# Don't index SpecStory auto-save files, but allow explicit context inclusion via @ references +.specstory/** diff --git a/main.py b/main.py index fab43eb..dfd3d79 100644 --- a/main.py +++ b/main.py @@ -22,7 +22,7 @@ setup_logging() # Get application version -app_version = os.environ.get("RELEASE_VERSION", "dev") +app_version = os.environ.get("RELEASE_VERSION", "dev-fixed") app_env = os.environ.get("ENVIRONMENT", "development") # Override environment setting if in local mode if os.environ.get("ENV_MODE") == "local": From c8697a4400a6bea4ead8360afbc04f9a9ac75b18 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Wed, 28 May 2025 10:49:09 +0300 Subject: [PATCH 28/31] Fix email service startup errors - lazy initialization and graceful fallback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Move email service creation from module level to endpoint level - Handle missing credentials gracefully with warnings instead of errors - Add configuration checks in email service methods - Prevent startup failures due to missing CLIENT_SECRET env var 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- scripts/test_direct_email.py | 2 +- scripts/test_user_auth_email.py | 2 +- src/routers/email.py | 8 ++++++-- src/services/graph_direct_auth.py | 15 +++++++++++---- src/services/msgraph_service.py | 17 +++++++++++++++-- 5 files changed, 34 insertions(+), 10 deletions(-) diff --git a/scripts/test_direct_email.py b/scripts/test_direct_email.py index a8578a5..d4dc5af 100644 --- a/scripts/test_direct_email.py +++ b/scripts/test_direct_email.py @@ -102,7 +102,7 @@ async def test_direct_email(to_email: str) -> None:
  • From Email: {from_email}
  • To Email: {to_email}
  • -
  • Tenant ID: {os.getenv('AZURE_TENANT_ID')}
  • +
  • Tenant ID: {os.getenv('TENANT_ID')}
diff --git a/scripts/test_user_auth_email.py b/scripts/test_user_auth_email.py index 6c6bf45..cea77d0 100755 --- a/scripts/test_user_auth_email.py +++ b/scripts/test_user_auth_email.py @@ -99,7 +99,7 @@ async def test_email(to_email: str) -> None:
  • From Email: {os.getenv('MS_FROM_EMAIL')}
  • User Email: {os.getenv('USER_EMAIL')}
  • To Email: {to_email}
  • -
  • Tenant ID: {os.getenv('AZURE_TENANT_ID')}
  • +
  • Tenant ID: {os.getenv('TENANT_ID')}
  • diff --git a/src/routers/email.py b/src/routers/email.py index bff1273..1fc627d 100644 --- a/src/routers/email.py +++ b/src/routers/email.py @@ -33,8 +33,10 @@ class DigestRequest(BaseModel): limit: int | None = None test: bool = False -# Initialize email service -email_service = create_email_service() +# Lazy email service initialization +def get_email_service(): + """Get email service instance. Created on first use to avoid startup errors.""" + return create_email_service() @router.post("/send", dependencies=[Depends(require_admin)]) async def send_email(request: EmailRequest): @@ -42,6 +44,7 @@ async def send_email(request: EmailRequest): Send an email to multiple recipients. Only accessible by admin users. """ + email_service = get_email_service() success = await email_service.send_email( to_emails=request.to_emails, subject=request.subject, @@ -60,6 +63,7 @@ async def send_notification(request: NotificationRequest): Send a templated notification email. Only accessible by admin users. """ + email_service = get_email_service() success = await email_service.send_notification_email( to_email=request.to_email, subject=request.subject, diff --git a/src/services/graph_direct_auth.py b/src/services/graph_direct_auth.py index 33add3e..e8c3165 100644 --- a/src/services/graph_direct_auth.py +++ b/src/services/graph_direct_auth.py @@ -18,18 +18,25 @@ class GraphDirectAuth: def __init__(self): self.token = None self.token_expires = 0 - self.tenant_id = os.getenv('AZURE_TENANT_ID') - self.client_id = os.getenv('AZURE_CLIENT_ID') - self.client_secret = os.getenv('AZURE_CLIENT_SECRET') + self.tenant_id = os.getenv('TENANT_ID') + self.client_id = os.getenv('CLIENT_ID') + self.client_secret = os.getenv('CLIENT_SECRET') if not all([self.tenant_id, self.client_id, self.client_secret]): - raise ValueError("Missing required environment variables for Graph authentication") + logger.warning("Missing required environment variables for Graph authentication. Service will not be available.") + self.configured = False + return + + self.configured = True self.token_url = f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/v2.0/token" self.graph_url = "https://graph.microsoft.com/v1.0" async def ensure_token(self) -> str: """Ensure we have a valid token, refreshing if necessary""" + if not getattr(self, 'configured', False): + raise ValueError("GraphDirectAuth not properly configured - missing environment variables") + current_time = asyncio.get_event_loop().time() # If token is expired or will expire in the next 5 minutes, refresh it diff --git a/src/services/msgraph_service.py b/src/services/msgraph_service.py index bbfccfa..4e442df 100644 --- a/src/services/msgraph_service.py +++ b/src/services/msgraph_service.py @@ -36,8 +36,9 @@ def __init__(self, useUserAccessToken: bool = False): self.from_email = os.getenv('MS_FROM_EMAIL') logger.info(f"MSGraphEmailService config: TENANT_ID={tenant_id}, CLIENT_ID={client_id}, FROM_EMAIL={self.from_email}, EMAIL_SERVICE_TYPE={service_type}") if not all([tenant_id, client_id, client_secret]): - logger.error("Missing required environment variables for authentication") - raise ValueError("TENANT_ID, CLIENT_ID, and CLIENT_SECRET must be set") + logger.warning("Missing required environment variables for MSGraph authentication. Service will not be available.") + self.credential = None + return # Use ClientSecretCredential for app authentication self.credential = ClientSecretCredential( @@ -72,6 +73,12 @@ async def send_email( content_type=content_type, useUserAccessToken=True ) + + # Check if service is properly configured + if not hasattr(self, 'credential') or self.credential is None: + logger.error("MSGraphEmailService not properly configured - missing credentials") + return False + """Send an email using Microsoft Graph API with Mail.Send permission""" try: logger.info(f"send_email config: TENANT_ID={os.getenv('TENANT_ID')}, CLIENT_ID={os.getenv('CLIENT_ID')}, FROM_EMAIL={self.from_email}, EMAIL_SERVICE_TYPE={os.getenv('EMAIL_SERVICE_TYPE')}, to_emails={to_emails}, subject={subject}") @@ -149,6 +156,12 @@ async def send_notification_email( dynamic_data=dynamic_data, useUserAccessToken=True ) + + # Check if service is properly configured + if not hasattr(self, 'credential') or self.credential is None: + logger.error("MSGraphEmailService not properly configured - missing credentials") + return False + """Send a templated notification email using Microsoft Graph API""" try: logger.info(f"Preparing to send notification email to {to_email}") From a1454fb46ae3df84c79527a421634f4496b4f652 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Thu, 29 May 2025 14:25:39 +0300 Subject: [PATCH 29/31] add llm data export --- scripts/export_for_llm.py | 115 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100755 scripts/export_for_llm.py diff --git a/scripts/export_for_llm.py b/scripts/export_for_llm.py new file mode 100755 index 0000000..b34f542 --- /dev/null +++ b/scripts/export_for_llm.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +""" +Export signals and trends data for LLM processing. +Exports all public signals and trends to CSV format, excluding certain fields. +""" + +import os +import sys +import asyncio +import csv +from datetime import datetime +import psycopg +from psycopg.rows import dict_row + +# Add parent directory to path to import src modules +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# No need to import get_connection_string - we'll use DB_CONNECTION directly + +# Fields to exclude from the export +EXCLUDE_FIELDS = [ + 'is_draft', 'private', 'favourite', + 'can_edit', 'modified_at', 'url', 'favorite' +] +file_path = ".exports" + +async def export_table_to_csv(conn, table_name, query, filename_prefix): + """Export data from a table to CSV, excluding certain fields.""" + print(f"Exporting {table_name}...") + + async with conn.cursor(row_factory=dict_row) as cursor: + await cursor.execute(query) + records = await cursor.fetchall() + + if not records: + print(f"No records found in {table_name}.") + return + + # Get all field names from the first record + all_fields = list(records[0].keys()) + # Filter out excluded fields + export_fields = [field for field in all_fields if field not in EXCLUDE_FIELDS] + # Add app_link as the last column + export_fields.append('app_link') + + # Compose filename + filename = f'{file_path}/{table_name}.csv' + + # Ensure export directory exists + os.makedirs(file_path, exist_ok=True) + + with open(filename, 'w', newline='', encoding='utf-8') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=export_fields) + writer.writeheader() + for record in records: + row = {field: record[field] for field in export_fields if field != 'app_link'} + for field, value in row.items(): + if isinstance(value, list): + row[field] = ', '.join(str(v) for v in value) if value else '' + # Add app_link + if table_name == 'signals': + row['app_link'] = f'https://signals.data.undp.org/signals/{record["id"]}' + elif table_name == 'trends': + row['app_link'] = f'https://signals.data.undp.org/trends/{record["id"]}' + else: + row['app_link'] = '' + writer.writerow(row) + + print(f"Exported {len(records)} {table_name} to {filename}") + return filename + +async def main(): + """Main function to export signals and trends.""" + # Get database connection string from environment + connection_string = os.environ.get("DB_CONNECTION") + + if not connection_string: + print("Error: DB_CONNECTION environment variable not set") + sys.exit(1) + + try: + # Connect to the database + async with await psycopg.AsyncConnection.connect( + connection_string, + row_factory=dict_row + ) as conn: + print("Connected to database successfully") + + # Export signals + signals_query = """ + SELECT * FROM signals + WHERE private = FALSE OR private IS NULL + ORDER BY id + """ + signals_file = await export_table_to_csv(conn, "signals", signals_query, "signals") + + # Export trends + trends_query = """ + SELECT * FROM trends + ORDER BY id + """ + trends_file = await export_table_to_csv(conn, "trends", trends_query, "trends") + + print("\nExport completed successfully!") + if signals_file: + print(f"Signals: {signals_file}") + if trends_file: + print(f"Trends: {trends_file}") + + except Exception as e: + print(f"Error during export: {e}") + sys.exit(1) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file From 2eae47756b25841821f305b4f1dd9be2e16a2b06 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Thu, 29 May 2025 18:53:32 +0300 Subject: [PATCH 30/31] update signals search --- .gitignore | 3 +++ src/database/signals.py | 7 +------ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 1814f88..12518e9 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,6 @@ webapp_logs.zip app_logs.zip /deployments /LogFiles +/.exports +schema.sql +schema.dbml diff --git a/src/database/signals.py b/src/database/signals.py index 5a64ba6..44256d9 100644 --- a/src/database/signals.py +++ b/src/database/signals.py @@ -92,12 +92,7 @@ async def search_signals(cursor: AsyncCursor, filters: SignalFilters) -> SignalP AND (%(score)s IS NULL OR score = %(score)s) AND (%(unit)s IS NULL OR unit_region = %(unit)s OR unit_name = %(unit)s) AND (%(query)s IS NULL OR text_search_field @@ websearch_to_tsquery('english', %(query)s)) - AND (%(user_email)s IS NOT NULL AND ( - private = FALSE OR - created_by = %(user_email)s OR - %(is_admin)s = TRUE OR - %(is_staff)s = TRUE - )) + AND private = FALSE ORDER BY {filters.order_by} {filters.direction} OFFSET From 5581ee894f377cb6db33dc9a8df61dedad848cb3 Mon Sep 17 00:00:00 2001 From: happy-devs Date: Thu, 29 May 2025 18:54:36 +0300 Subject: [PATCH 31/31] Revert "add llm data export" This reverts commit a1454fb46ae3df84c79527a421634f4496b4f652. --- scripts/export_for_llm.py | 115 -------------------------------------- 1 file changed, 115 deletions(-) delete mode 100755 scripts/export_for_llm.py diff --git a/scripts/export_for_llm.py b/scripts/export_for_llm.py deleted file mode 100755 index b34f542..0000000 --- a/scripts/export_for_llm.py +++ /dev/null @@ -1,115 +0,0 @@ -#!/usr/bin/env python3 -""" -Export signals and trends data for LLM processing. -Exports all public signals and trends to CSV format, excluding certain fields. -""" - -import os -import sys -import asyncio -import csv -from datetime import datetime -import psycopg -from psycopg.rows import dict_row - -# Add parent directory to path to import src modules -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -# No need to import get_connection_string - we'll use DB_CONNECTION directly - -# Fields to exclude from the export -EXCLUDE_FIELDS = [ - 'is_draft', 'private', 'favourite', - 'can_edit', 'modified_at', 'url', 'favorite' -] -file_path = ".exports" - -async def export_table_to_csv(conn, table_name, query, filename_prefix): - """Export data from a table to CSV, excluding certain fields.""" - print(f"Exporting {table_name}...") - - async with conn.cursor(row_factory=dict_row) as cursor: - await cursor.execute(query) - records = await cursor.fetchall() - - if not records: - print(f"No records found in {table_name}.") - return - - # Get all field names from the first record - all_fields = list(records[0].keys()) - # Filter out excluded fields - export_fields = [field for field in all_fields if field not in EXCLUDE_FIELDS] - # Add app_link as the last column - export_fields.append('app_link') - - # Compose filename - filename = f'{file_path}/{table_name}.csv' - - # Ensure export directory exists - os.makedirs(file_path, exist_ok=True) - - with open(filename, 'w', newline='', encoding='utf-8') as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=export_fields) - writer.writeheader() - for record in records: - row = {field: record[field] for field in export_fields if field != 'app_link'} - for field, value in row.items(): - if isinstance(value, list): - row[field] = ', '.join(str(v) for v in value) if value else '' - # Add app_link - if table_name == 'signals': - row['app_link'] = f'https://signals.data.undp.org/signals/{record["id"]}' - elif table_name == 'trends': - row['app_link'] = f'https://signals.data.undp.org/trends/{record["id"]}' - else: - row['app_link'] = '' - writer.writerow(row) - - print(f"Exported {len(records)} {table_name} to {filename}") - return filename - -async def main(): - """Main function to export signals and trends.""" - # Get database connection string from environment - connection_string = os.environ.get("DB_CONNECTION") - - if not connection_string: - print("Error: DB_CONNECTION environment variable not set") - sys.exit(1) - - try: - # Connect to the database - async with await psycopg.AsyncConnection.connect( - connection_string, - row_factory=dict_row - ) as conn: - print("Connected to database successfully") - - # Export signals - signals_query = """ - SELECT * FROM signals - WHERE private = FALSE OR private IS NULL - ORDER BY id - """ - signals_file = await export_table_to_csv(conn, "signals", signals_query, "signals") - - # Export trends - trends_query = """ - SELECT * FROM trends - ORDER BY id - """ - trends_file = await export_table_to_csv(conn, "trends", trends_query, "trends") - - print("\nExport completed successfully!") - if signals_file: - print(f"Signals: {signals_file}") - if trends_file: - print(f"Trends: {trends_file}") - - except Exception as e: - print(f"Error during export: {e}") - sys.exit(1) - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file