diff --git a/.gitignore b/.gitignore index db08835..8e2eb1c 100644 --- a/.gitignore +++ b/.gitignore @@ -146,3 +146,6 @@ create_test_user.sql Taskfile.yml .env.local /.logs +/logs +webapp_logs.zip +/.schemas diff --git a/docs/user_groups_issue.md b/docs/user_groups_issue.md new file mode 100644 index 0000000..b71176c --- /dev/null +++ b/docs/user_groups_issue.md @@ -0,0 +1,189 @@ +# User Groups Discrepancy Issue + +## Summary + +We've identified a discrepancy between the database state and API responses when retrieving user groups. When a user requests their groups, they're not receiving all groups where they are members. Specifically, when user ID `1062` requests their groups, they should receive 3 groups (IDs 22, 25, and 28), but only receive 2 groups (IDs 25 and 28). + +## Investigation Details + +### Debug Logs + +The logs show only two groups being returned: + +``` +2025-05-21 21:38:17,604 - src.routers.user_groups - DEBUG - Fetching groups for user 1062... +2025-05-21 21:38:17,604 - src.database.user_groups - DEBUG - Getting user groups with users for user_id: 1062 +2025-05-21 21:38:17,604 - src.database.user_groups - DEBUG - Fetching groups where user 1062 is a member... +2025-05-21 21:38:18,656 - src.database.user_groups - DEBUG - Found 1 groups where user 1062 is a member: [28] +2025-05-21 21:38:18,656 - src.database.user_groups - DEBUG - Fetching groups where user 1062 is an admin... 
+2025-05-21 21:38:19,606 - src.database.user_groups - DEBUG - Found 1 groups where user 1062 is an admin: [25] +2025-05-21 21:38:19,606 - src.database.user_groups - DEBUG - Added 1 additional groups as admin (not already as member): [25] +2025-05-21 21:38:19,606 - src.database.user_groups - DEBUG - Total: Found 2 user groups with users for user_id: 1062, Group IDs: [25, 28] +``` + +### Database State + +Our database queries confirm that the user should be in 3 groups: + +```sql +SELECT id, name, user_ids, admin_ids +FROM user_groups +WHERE 1062 = ANY(user_ids) OR 1062 = ANY(admin_ids) +ORDER BY id; +``` + +Result: +``` + id | name | user_ids | admin_ids +----+-------------+-----------------+----------- + 22 | UNDP Studio | {1062,774,1067} | {} + 25 | GROUP 3 | {1062} | {1062} + 28 | test4 | {774,1062} | {774} +``` + +### Code Analysis + +The code in `user_groups.py` uses the correct SQL syntax to retrieve groups where a user is a member or admin: + +1. First query: `SELECT ... FROM user_groups WHERE %s = ANY(user_ids) ORDER BY created_at DESC;` +2. Second query: `SELECT ... FROM user_groups WHERE %s = ANY(admin_ids) ORDER BY created_at DESC;` + +These queries should correctly find all groups, but the first query is only returning group 28, not both 22 and 28 as expected. + +## Impact + +Users may not see all groups they belong to in the application, which could lead to: + +1. Reduced access to signals shared in "missing" groups +2. Confusion about group membership +3. Workflow disruptions if users expect to find signals in specific groups + +## Possible Causes + +1. SQL query execution issues +2. Application-level filtering not visible in the code +3. A caching or stale data issue +4. Transaction isolation level issues +5. Potential race condition if groups are being modified simultaneously + +## Fix Implemented + +We've implemented a comprehensive solution with multiple layers of improvements: + +### 1. 
Enhanced Primary Functions + +Modified the approach in the affected functions to use a single combined query with explicit array handling: + +```python +# Run a direct SQL query to ensure array type handling is consistent +query = """ +SELECT + id, + name, + signal_ids, + user_ids, + admin_ids, + collaborator_map, + created_at +FROM + user_groups +WHERE + %s = ANY(user_ids) OR %s = ANY(admin_ids) +ORDER BY + created_at DESC; +""" + +await cursor.execute(query, (user_id, user_id)) +``` + +We also added explicit type conversion when checking for user membership: + +```python +# Track membership rigorously by explicitly converting IDs to integers +is_member = False +if data['user_ids']: + is_member = user_id in [int(uid) for uid in data['user_ids']] + +is_admin = False +if data['admin_ids']: + is_admin = user_id in [int(aid) for aid in data['admin_ids']] +``` + +### 2. Debug Functions + +Added a `debug_user_groups` function that runs multiple direct queries to diagnose issues: + +```python +async def debug_user_groups(cursor: AsyncCursor, user_id: int) -> dict: + # Various direct SQL queries to check database state + # and array position functions + # ... +``` + +### 3. Fallback Implementation + +Created a completely separate direct SQL implementation in `user_groups_direct.py` as a fallback: + +```python +async def get_user_groups_direct(cursor: AsyncCursor, user_id: int) -> List[UserGroup]: + """ + Get all groups that a user is a member of or an admin of using direct SQL. + """ + # Simple, direct SQL with minimal processing + # ... +``` + +### 4. Automatic Fallback in API + +Modified the user groups router to automatically detect and handle discrepancies: + +```python +# Check if there's a mismatch between direct query and regular function +if missing_ids: + logger.warning(f"MISMATCH! 
Direct query found groups {direct_group_ids} but function returned only {fetched_ids}") + logger.warning(f"Missing groups: {missing_ids}") + + # Fall back to direct SQL implementation + logger.warning("Falling back to direct SQL implementation") + user_groups = await user_groups_direct.get_user_groups_with_users_direct(cursor, user.id) + logger.info(f"Direct SQL implementation returned {len(user_groups)} groups") +``` + +These changes ensure: +- More reliable querying of user group memberships +- Better debug information if issues persist +- Automatic fallback to a simpler implementation if needed +- More detailed logging throughout the process + +After these changes, users should see all groups where they are members or admins consistently. + +## Additional Fix: Signal Entity can_edit Attribute + +During testing, we discovered that the Signal entity was missing the `can_edit` attribute that's dynamically added in user group contexts. This caused AttributeError exceptions when trying to access `signal.can_edit`. + +### Issue +```python +AttributeError: 'Signal' object has no attribute 'can_edit' +``` + +### Solution +Added the `can_edit` field to the Signal entity definition: + +```python +can_edit: bool = Field( + default=False, + description="Whether the current user can edit this signal (set dynamically based on group membership and collaboration).", +) +``` + +This ensures that: +- The Signal model accepts the `can_edit` attribute when created +- The attribute defaults to `False` for signals that don't have edit permissions +- Both the regular and direct SQL implementations can properly set this attribute +- Frontend code can safely access `signal.can_edit` without errors + +The fix has been applied to both endpoints: +1. `/user-groups/me` (user groups without signals) +2. `/user-groups/me/with-signals` (user groups with signals) + +Both endpoints now include the same debug checks and automatic fallback to direct SQL if discrepancies are detected. 
\ No newline at end of file diff --git a/main.py b/main.py index 04f4a80..fab43eb 100644 --- a/main.py +++ b/main.py @@ -10,6 +10,7 @@ from fastapi import Depends, FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware from src import routers from src.authentication import authenticate_user @@ -23,6 +24,9 @@ # Get application version app_version = os.environ.get("RELEASE_VERSION", "dev") app_env = os.environ.get("ENVIRONMENT", "development") +# Override environment setting if in local mode +if os.environ.get("ENV_MODE") == "local": + app_env = "local" logging.info(f"Starting application - version: {app_version}, environment: {app_env}") # Configure Bugsnag for error tracking @@ -90,14 +94,61 @@ async def global_exception_handler(request: Request, exc: Exception): content={"detail": "Internal server error"}, ) -# allow cors -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +# Configure CORS - simplified for local development +local_origins = [ + "http://localhost:5175", + "http://127.0.0.1:5175", + "http://localhost:3000", + "http://127.0.0.1:3000" +] + +# Create a custom middleware class for handling CORS +class CORSHandlerMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + # Handle OPTIONS preflight requests + if request.method == "OPTIONS": + headers = { + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH", + "Access-Control-Allow-Headers": "access_token, Authorization, Content-Type, Accept, X-API-Key", + "Access-Control-Allow-Credentials": "true", + "Access-Control-Max-Age": "600", # Cache preflight for 10 minutes + } + + # Set specific origin if in local mode + origin = request.headers.get("origin") + if os.environ.get("ENV_MODE") == "local" and origin: + 
headers["Access-Control-Allow-Origin"] = origin + + return JSONResponse(content={}, status_code=200, headers=headers) + + # Process all other requests normally + response = await call_next(request) + return response + +# Apply custom CORS middleware BEFORE the standard CORS middleware +app.add_middleware(CORSHandlerMiddleware) + +# Standard CORS middleware (as a backup) +if os.environ.get("ENV_MODE") == "local": + logging.info(f"Local mode: using specific CORS origins: {local_origins}") + app.add_middleware( + CORSMiddleware, + allow_origins=local_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*", "access_token", "Authorization", "Content-Type"], + expose_headers=["*"], + ) +else: + # Production mode - use more restrictive CORS + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*", "access_token", "Authorization", "Content-Type"], + ) # Add Bugsnag exception handling middleware # Important: Add middleware AFTER registering exception handlers @@ -137,5 +188,11 @@ async def test_error(): else: return {"status": "disabled", "message": "Bugsnag is not enabled"} +# Add special route for handling OPTIONS requests to /users/me +@app.options("/users/me", include_in_schema=False) +async def options_users_me(): + """Handle OPTIONS requests to /users/me specifically.""" + return {} + # Use the Bugsnag middleware wrapped app for ASGI app = bugsnag_app diff --git a/src/authentication.py b/src/authentication.py index da64752..89ba2ea 100644 --- a/src/authentication.py +++ b/src/authentication.py @@ -4,6 +4,7 @@ import logging import os +from typing import Dict, Any, Optional, cast import httpx import jwt @@ -18,7 +19,7 @@ api_key_header = APIKeyHeader( name="access_token", description="The API access token for the application.", - auto_error=True, + auto_error=False, ) @@ -67,10 +68,10 @@ async def get_jwk(token: str) -> jwt.PyJWK: jwks = await get_jwks() except 
httpx.HTTPError: jwks = {} - jwk = jwks.get(header["kid"]) - if jwk is None: + jwk_dict = jwks.get(header["kid"]) + if jwk_dict is None: raise ValueError("JWK could not be obtained or found") - jwk = jwt.PyJWK.from_dict(jwk, "RS256") + jwk = jwt.PyJWK.from_dict(jwk_dict, "RS256") return jwk @@ -108,7 +109,7 @@ async def decode_token(token: str) -> dict: async def authenticate_user( - token: str = Security(api_key_header), + token: Optional[str] = Security(api_key_header), cursor: AsyncCursor = Depends(db.yield_cursor), ) -> User: """ @@ -130,38 +131,78 @@ async def authenticate_user( user : User Pydantic model for a User object (if authentication succeeded). """ - logging.debug(f"Authenticating user with token") - if os.environ.get("TEST_USER_TOKEN"): - token = os.environ.get("TEST_USER_TOKEN") + logging.debug("Authenticating user with token") + + # For local development environment if os.environ.get("ENV_MODE") == "local": - # defaul user data - user_data = { - "email": "test.user@undp.org", - "name": "Test User", - "unit": "Data Futures Exchange (DFx)", - "acclab": False, - } + # Use test token if available + if os.environ.get("TEST_USER_TOKEN"): + test_token = os.environ.get("TEST_USER_TOKEN") + if test_token is not None: + token = test_token + + # Default user data for local development + local_email = os.environ.get("TEST_USER_EMAIL", "test.user@undp.org") + name = os.environ.get("TEST_USER_NAME", "Test User") + unit = os.environ.get("TEST_USER_UNIT", "Data Futures Exchange (DFx)") + acclab = os.environ.get("TEST_USER_ACCLAB", False) + + # Check for specific test tokens if token == "test-admin-token": - user_data["role"] = Role.ADMIN - return User(**user_data) + return User( + email=local_email, + name=name, + unit=unit, + acclab=acclab, + role=Role.ADMIN + ) elif token == "test-user-token": - user_data["role"] = Role.USER - return User(**user_data) - - if token == os.environ.get("API_KEY"): + return User( + email=local_email, + name=name, + unit=unit, + 
acclab=acclab, + role=Role.USER + ) + else: + # In local mode, if no valid token is provided, default to an admin user + logging.info("LOCAL MODE: No valid token provided, using default admin user") + return User( + email=local_email, + name=name, + unit=unit, + acclab=acclab, + role=Role.ADMIN + ) + + # Check for API key access + if token and token == os.environ.get("API_KEY"): if os.environ.get("ENV") == "dev": return User(email="name.surname@undp.org", role=Role.ADMIN) else: # dummy user object for anonymous access return User(email="name.surname@undp.org", role=Role.VISITOR) + + # If no token provided in non-local mode + if not token: + raise exceptions.not_authenticated + + # Try to decode and verify the token try: payload = await decode_token(token) except jwt.exceptions.PyJWTError as e: raise exceptions.not_authenticated from e - email, name = payload.get("unique_name"), payload.get("name") - if email is None or name is None: + + payload_email = payload.get("unique_name") + payload_name = payload.get("name") + + if payload_email is None or payload_name is None: raise exceptions.not_authenticated - if (user := await db.read_user_by_email(cursor, email)) is None: - user = User(email=email, role=Role.USER, name=name) + + email_str = str(payload_email) # Convert to string to satisfy type checker + name_str = str(payload_name) # Convert to string to satisfy type checker + + if (user := await db.read_user_by_email(cursor, email_str)) is None: + user = User(email=email_str, role=Role.USER, name=name_str) await db.create_user(cursor, user) return user diff --git a/src/bugsnag_config.py b/src/bugsnag_config.py index e7a32d6..7569c63 100644 --- a/src/bugsnag_config.py +++ b/src/bugsnag_config.py @@ -10,7 +10,7 @@ # Get API key from environment variable with fallback to the provided key BUGSNAG_API_KEY = os.environ.get("BUGSNAG_API_KEY") -ENVIRONMENT = os.environ.get("ENVIRONMENT", "development") +ENVIRONMENT = os.environ.get("ENVIRONMENT") RELEASE_VERSION = 
os.environ.get("RELEASE_VERSION", "dev") if not BUGSNAG_API_KEY: diff --git a/src/database/signals.py b/src/database/signals.py index efbab5f..5a64ba6 100644 --- a/src/database/signals.py +++ b/src/database/signals.py @@ -2,15 +2,20 @@ CRUD operations for signal entities. """ +import logging +from typing import List from psycopg import AsyncCursor, sql from .. import storage -from ..entities import Signal, SignalFilters, SignalPage, Status +from ..entities import Signal, SignalFilters, SignalPage, Status, SignalWithUserGroups, UserGroup + +logger = logging.getLogger(__name__) __all__ = [ "search_signals", "create_signal", "read_signal", + "read_signal_with_user_groups", "update_signal", "delete_signal", "read_user_signals", @@ -87,6 +92,12 @@ async def search_signals(cursor: AsyncCursor, filters: SignalFilters) -> SignalP AND (%(score)s IS NULL OR score = %(score)s) AND (%(unit)s IS NULL OR unit_region = %(unit)s OR unit_name = %(unit)s) AND (%(query)s IS NULL OR text_search_field @@ websearch_to_tsquery('english', %(query)s)) + AND (%(user_email)s IS NOT NULL AND ( + private = FALSE OR + created_by = %(user_email)s OR + %(is_admin)s = TRUE OR + %(is_staff)s = TRUE + )) ORDER BY {filters.order_by} {filters.direction} OFFSET @@ -102,10 +113,10 @@ async def search_signals(cursor: AsyncCursor, filters: SignalFilters) -> SignalP return page -async def create_signal(cursor: AsyncCursor, signal: Signal) -> int: +async def create_signal(cursor: AsyncCursor, signal: Signal, user_group_ids: List[int] = None) -> int: """ - Insert a signal into the database, connect it to trends and upload an attachment - to Azure Blob Storage if applicable. + Insert a signal into the database, connect it to trends, upload an attachment + to Azure Blob Storage if applicable, and add it to user groups if specified. Parameters ---------- @@ -114,77 +125,134 @@ async def create_signal(cursor: AsyncCursor, signal: Signal) -> int: signal : Signal A signal object to insert. 
The following fields are supported: - secondary_location: list[str] | None + user_group_ids : List[int], optional + List of user group IDs to add the signal to. Returns ------- signal_id : int An ID of the signal in the database. """ - query = """ - INSERT INTO signals ( - status, - created_by, - created_for, - modified_by, - headline, - description, - steep_primary, - steep_secondary, - signature_primary, - signature_secondary, - sdgs, - created_unit, - url, - relevance, - keywords, - location, - secondary_location, - score - ) - VALUES ( - %(status)s, - %(created_by)s, - %(created_for)s, - %(modified_by)s, - %(headline)s, - %(description)s, - %(steep_primary)s, - %(steep_secondary)s, - %(signature_primary)s, - %(signature_secondary)s, - %(sdgs)s, - %(created_unit)s, - %(url)s, - %(relevance)s, - %(keywords)s, - %(location)s, - %(secondary_location)s, - %(score)s - ) - RETURNING - id - ; - """ - await cursor.execute(query, signal.model_dump()) - row = await cursor.fetchone() - signal_id = row["id"] - - # add connected trends if any are present - for trend_id in signal.connected_trends or []: - query = "INSERT INTO connections (signal_id, trend_id, created_by) VALUES (%s, %s, %s);" - await cursor.execute(query, (signal_id, trend_id, signal.created_by)) - - # upload an image + logger.info(f"Creating new signal with headline: '{signal.headline}', created by: {signal.created_by}") + logger.info(f"All Signal fields: {signal.model_dump()}") + if user_group_ids: + logger.info(f"Will add signal to user groups: {user_group_ids}") + + # Insert signal into database + try: + query = """ + INSERT INTO signals ( + status, + created_by, + created_for, + modified_by, + headline, + description, + steep_primary, + steep_secondary, + signature_primary, + signature_secondary, + sdgs, + created_unit, + url, + relevance, + keywords, + location, + secondary_location, + score, + private + ) + VALUES ( + %(status)s, + %(created_by)s, + %(created_for)s, + %(modified_by)s, + %(headline)s, + 
%(description)s, + %(steep_primary)s, + %(steep_secondary)s, + %(signature_primary)s, + %(signature_secondary)s, + %(sdgs)s, + %(created_unit)s, + %(url)s, + %(relevance)s, + %(keywords)s, + %(location)s, + %(secondary_location)s, + %(score)s, + %(private)s + ) + RETURNING + id + ; + """ + await cursor.execute(query, signal.model_dump()) + row = await cursor.fetchone() + signal_id = row["id"] + logger.info(f"Signal created successfully with ID: {signal_id}") + except Exception as e: + logger.error(f"Failed to create signal: {e}") + raise + + # Add connected trends if any are present + try: + if signal.connected_trends: + logger.info(f"Adding connected trends for signal {signal_id}: {signal.connected_trends}") + for trend_id in signal.connected_trends: + query = "INSERT INTO connections (signal_id, trend_id, created_by) VALUES (%s, %s, %s);" + await cursor.execute(query, (signal_id, trend_id, signal.created_by)) + logger.info(f"Successfully added {len(signal.connected_trends)} trends to signal {signal_id}") + except Exception as e: + logger.error(f"Error adding connected trends to signal {signal_id}: {e}") + # Continue execution despite error with trends + + # Upload an image if provided if signal.attachment is not None: + logger.info(f"Uploading image attachment for signal {signal_id}") try: blob_url = await storage.upload_image( signal_id, "signals", signal.attachment ) - except Exception as e: - print(e) - else: query = "UPDATE signals SET attachment = %s WHERE id = %s;" await cursor.execute(query, (blob_url, signal_id)) + logger.info(f"Image attachment uploaded successfully for signal {signal_id}") + except Exception as e: + logger.error(f"Failed to upload image for signal {signal_id}: {e}") + # Continue execution despite attachment error + + # Add signal to user groups if specified + if user_group_ids: + logger.info(f"Processing user group assignments for signal {signal_id}") + from . 
import user_groups + groups_added = 0 + groups_failed = 0 + + for group_id in user_group_ids: + try: + logger.debug(f"Attempting to add signal {signal_id} to group {group_id}") + # Get the group + group = await user_groups.read_user_group(cursor, group_id) + if group is not None: + # Add signal to group's signal_ids + signal_ids = group.signal_ids or [] + if signal_id not in signal_ids: + signal_ids.append(signal_id) + group.signal_ids = signal_ids + await user_groups.update_user_group(cursor, group) + groups_added += 1 + logger.info(f"Signal {signal_id} added to group {group_id} ({group.name})") + else: + logger.info(f"Signal {signal_id} already in group {group_id} ({group.name})") + else: + logger.warning(f"Group with ID {group_id} not found, skipping") + groups_failed += 1 + except Exception as e: + logger.error(f"Error adding signal {signal_id} to group {group_id}: {e}") + groups_failed += 1 + + logger.info(f"User group assignment complete for signal {signal_id}: {groups_added} successful, {groups_failed} failed") + return signal_id @@ -230,10 +298,85 @@ async def read_signal(cursor: AsyncCursor, uid: int) -> Signal | None: return Signal(**row) -async def update_signal(cursor: AsyncCursor, signal: Signal) -> int | None: +async def read_signal_with_user_groups(cursor: AsyncCursor, uid: int) -> SignalWithUserGroups | None: + """ + Read a signal from the database with its associated user groups. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + uid : int + An ID of the signal to retrieve data for. + + Returns + ------- + SignalWithUserGroups | None + A signal with its user groups if it exists, otherwise None. 
+ """ + logger.info(f"Fetching signal {uid} with its user groups") + + try: + # First, get the signal + signal = await read_signal(cursor, uid) + if signal is None: + logger.warning(f"Signal with ID {uid} not found") + return None + + logger.info(f"Found signal with ID {uid}: '{signal.headline}'") + + # Convert to SignalWithUserGroups + signal_with_groups = SignalWithUserGroups(**signal.model_dump()) + + # Get all groups that have this signal in their signal_ids + query = """ + SELECT + id, + name, + signal_ids, + user_ids, + admin_ids, + collaborator_map + FROM + user_groups + WHERE + %s = ANY(signal_ids) + ORDER BY + name; + """ + + await cursor.execute(query, (uid,)) + + # Add groups to the signal + from . import user_groups as ug_module + signal_with_groups.user_groups = [] + group_count = 0 + + async for row in cursor: + try: + # Convert row to dict + group_data = ug_module.handle_user_group_row(row) + # Create UserGroup from dict + group = UserGroup(**group_data) + signal_with_groups.user_groups.append(group) + group_count += 1 + logger.debug(f"Added group {group.id} ({group.name}) to signal {uid}") + except Exception as e: + logger.error(f"Error processing group for signal {uid}: {e}") + + logger.info(f"Signal {uid} is associated with {group_count} user groups") + return signal_with_groups + + except Exception as e: + logger.error(f"Error retrieving signal {uid} with user groups: {e}") + raise + + +async def update_signal(cursor: AsyncCursor, signal: Signal, user_group_ids: List[int] = None) -> int | None: """ Update a signal in the database, update its connected trends and update an attachment - in the Azure Blob Storage if applicable. + in the Azure Blob Storage if applicable. Optionally update the user groups the signal + belongs to. Parameters ---------- @@ -242,56 +385,154 @@ async def update_signal(cursor: AsyncCursor, signal: Signal) -> int | None: signal : Signal A signal object to update. 
The following fields are supported: - secondary_location: list[str] | None + user_group_ids : List[int], optional + List of user group IDs to add the signal to. Returns ------- int | None A signal ID if the update has been performed, otherwise None. """ - query = """ - UPDATE - signals - SET - status = COALESCE(%(status)s, status), - created_for = COALESCE(%(created_for)s, created_for), - modified_at = NOW(), - modified_by = %(modified_by)s, - headline = COALESCE(%(headline)s, headline), - description = COALESCE(%(description)s, description), - steep_primary = COALESCE(%(steep_primary)s, steep_primary), - steep_secondary = COALESCE(%(steep_secondary)s, steep_secondary), - signature_primary = COALESCE(%(signature_primary)s, signature_primary), - signature_secondary = COALESCE(%(signature_secondary)s, signature_secondary), - sdgs = COALESCE(%(sdgs)s, sdgs), - created_unit = COALESCE(%(created_unit)s, created_unit), - url = COALESCE(%(url)s, url), - relevance = COALESCE(%(relevance)s, relevance), - keywords = COALESCE(%(keywords)s, keywords), - location = COALESCE(%(location)s, location), - secondary_location = COALESCE(%(secondary_location)s, secondary_location), - score = COALESCE(%(score)s, score) - WHERE - id = %(id)s - RETURNING - id - ; - """ - await cursor.execute(query, signal.model_dump()) - row = await cursor.fetchone() - if row is None: - return None - signal_id = row["id"] + logger.info(f"Updating signal with ID: {signal.id}, modified by: {signal.modified_by}") + if user_group_ids is not None: + logger.info(f"Will update user groups for signal {signal.id}: {user_group_ids}") - # update connected trends if any are present - await cursor.execute("DELETE FROM connections WHERE signal_id = %s;", (signal_id,)) - for trend_id in signal.connected_trends or []: - query = "INSERT INTO connections (signal_id, trend_id, created_by) VALUES (%s, %s, %s);" - await cursor.execute(query, (signal_id, trend_id, signal.created_by)) - - # upload an image if it is not a URL to 
an existing image - blob_url = await storage.update_image(signal_id, "signals", signal.attachment) - query = "UPDATE signals SET attachment = %s WHERE id = %s;" - await cursor.execute(query, (blob_url, signal_id)) + # Update signal in database + try: + query = """ + UPDATE + signals + SET + status = COALESCE(%(status)s, status), + created_for = COALESCE(%(created_for)s, created_for), + modified_at = NOW(), + modified_by = %(modified_by)s, + headline = COALESCE(%(headline)s, headline), + description = COALESCE(%(description)s, description), + steep_primary = COALESCE(%(steep_primary)s, steep_primary), + steep_secondary = COALESCE(%(steep_secondary)s, steep_secondary), + signature_primary = COALESCE(%(signature_primary)s, signature_primary), + signature_secondary = COALESCE(%(signature_secondary)s, signature_secondary), + sdgs = COALESCE(%(sdgs)s, sdgs), + created_unit = COALESCE(%(created_unit)s, created_unit), + url = COALESCE(%(url)s, url), + relevance = COALESCE(%(relevance)s, relevance), + keywords = COALESCE(%(keywords)s, keywords), + location = COALESCE(%(location)s, location), + secondary_location = COALESCE(%(secondary_location)s, secondary_location), + score = COALESCE(%(score)s, score), + private = COALESCE(%(private)s, private) + WHERE + id = %(id)s + RETURNING + id + ; + """ + await cursor.execute(query, signal.model_dump()) + row = await cursor.fetchone() + if row is None: + logger.warning(f"Signal with ID {signal.id} not found for update") + return None + signal_id = row["id"] + logger.info(f"Signal {signal_id} updated successfully in database") + except Exception as e: + logger.error(f"Failed to update signal {signal.id}: {e}") + raise + + # Update connected trends + try: + logger.info(f"Updating connected trends for signal {signal_id}") + await cursor.execute("DELETE FROM connections WHERE signal_id = %s;", (signal_id,)) + logger.debug(f"Removed existing trend connections for signal {signal_id}") + + if signal.connected_trends: + trends_added = 0 + 
for trend_id in signal.connected_trends: + try: + query = "INSERT INTO connections (signal_id, trend_id, created_by) VALUES (%s, %s, %s);" + await cursor.execute(query, (signal_id, trend_id, signal.created_by)) + trends_added += 1 + except Exception as trend_e: + logger.warning(f"Failed to connect trend {trend_id} to signal {signal_id}: {trend_e}") + + logger.info(f"Added {trends_added} trend connections to signal {signal_id}") + except Exception as e: + logger.error(f"Error updating trend connections for signal {signal_id}: {e}") + # Continue execution despite trend connection errors + + # Update image attachment + try: + logger.info(f"Updating image attachment for signal {signal_id}") + blob_url = await storage.update_image(signal_id, "signals", signal.attachment) + if blob_url is not None: + query = "UPDATE signals SET attachment = %s WHERE id = %s;" + await cursor.execute(query, (blob_url, signal_id)) + logger.info(f"Image attachment updated successfully for signal {signal_id}") + else: + logger.info(f"No image attachment update needed for signal {signal_id}") + except Exception as e: + logger.error(f"Failed to update image for signal {signal_id}: {e}") + # Continue execution despite attachment error + + # Update signal's user groups if specified + if user_group_ids is not None: + logger.info(f"Processing user group updates for signal {signal_id}") + try: + from . 
import user_groups + + # Get all groups that currently have this signal + query = """ + SELECT id, name + FROM user_groups + WHERE %s = ANY(signal_ids); + """ + await cursor.execute(query, (signal_id,)) + current_groups = {} + async for row in cursor: + current_groups[row["id"]] = row["name"] + + logger.info(f"Signal {signal_id} is currently in groups: {list(current_groups.keys())}") + + # Remove signal from groups not in user_group_ids + groups_removed = 0 + groups_to_remove_from = [g for g in current_groups.keys() if g not in user_group_ids] + for group_id in groups_to_remove_from: + try: + logger.debug(f"Removing signal {signal_id} from group {group_id} ({current_groups[group_id]})") + group = await user_groups.read_user_group(cursor, group_id) + if group is not None and signal_id in group.signal_ids: + signal_ids = group.signal_ids.copy() + signal_ids.remove(signal_id) + group.signal_ids = signal_ids + await user_groups.update_user_group(cursor, group) + groups_removed += 1 + logger.info(f"Signal {signal_id} removed from group {group_id} ({group.name})") + except Exception as e: + logger.error(f"Failed to remove signal {signal_id} from group {group_id}: {e}") + + # Add signal to new groups + groups_added = 0 + for group_id in user_group_ids: + if group_id not in current_groups: + try: + logger.debug(f"Adding signal {signal_id} to group {group_id}") + group = await user_groups.read_user_group(cursor, group_id) + if group is not None: + signal_ids = group.signal_ids or [] + if signal_id not in signal_ids: + signal_ids.append(signal_id) + group.signal_ids = signal_ids + await user_groups.update_user_group(cursor, group) + groups_added += 1 + logger.info(f"Signal {signal_id} added to group {group_id} ({group.name})") + else: + logger.warning(f"Group with ID {group_id} not found, skipping") + except Exception as e: + logger.error(f"Failed to add signal {signal_id} to group {group_id}: {e}") + + logger.info(f"User group assignments updated for signal {signal_id}: 
{groups_added} added, {groups_removed} removed") + except Exception as e: + logger.error(f"Error processing user group updates for signal {signal_id}: {e}") return signal_id @@ -328,6 +569,8 @@ async def read_user_signals( cursor: AsyncCursor, user_email: str, status: Status, + is_admin: bool = False, + is_staff: bool = False, ) -> list[Signal]: """ Read signals from the database using a user email and status filter. @@ -340,6 +583,10 @@ async def read_user_signals( An email of the user whose signals to read. status : Status A status of signals to filter by. + is_admin : bool, optional + Whether the user is an admin, by default False + is_staff : bool, optional + Whether the user is staff, by default False Returns ------- @@ -365,6 +612,9 @@ async def read_user_signals( created_by = %s AND status = %s ; """ + # Since this function is explicitly for reading a user's OWN signals, + # we don't need additional private/public filtering here. + # The user is the creator, so they should see all their signals regardless of privacy setting. 
await cursor.execute(query, (user_email, status)) rows = await cursor.fetchall() return [Signal(**row) for row in rows] diff --git a/src/database/user_groups.py b/src/database/user_groups.py index 21d7ecd..a729e17 100644 --- a/src/database/user_groups.py +++ b/src/database/user_groups.py @@ -3,8 +3,11 @@ """ from psycopg import AsyncCursor +import json +import logging +from typing import List, Union -from ..entities import UserGroup, Signal +from ..entities import UserGroup, Signal, User, UserGroupWithSignals, UserGroupWithUsers, UserGroupComplete __all__ = [ "create_user_group", @@ -18,8 +21,132 @@ "get_group_users", "get_user_groups_with_signals", "get_signal_group_collaborators", + "get_user_group_with_users", + "list_user_groups_with_users", + "get_user_groups_with_signals_and_users", + "get_user_groups_with_users_by_user_id", + "debug_user_groups", ] +logger = logging.getLogger(__name__) + +# SQL Query Constants +SQL_SELECT_USER_GROUP = """ + SELECT + id, + name, + signal_ids, + user_ids, + admin_ids, + collaborator_map, + created_at + FROM + user_groups +""" + +SQL_SELECT_USERS = """ + SELECT + id, + email, + role, + name, + unit, + acclab, + created_at + FROM + users + WHERE + id = ANY(%s) + ORDER BY + name +""" + +SQL_SELECT_SIGNALS = """ + SELECT + s.*, + array_agg(c.trend_id) FILTER (WHERE c.trend_id IS NOT NULL) AS connected_trends + FROM + signals s + LEFT JOIN + connections c ON s.id = c.signal_id + WHERE + s.id = ANY(%s) + GROUP BY + s.id + ORDER BY + s.id +""" + +def handle_user_group_row(row) -> dict: + """ + Helper function to safely extract user group data from a database row. 
+ + Parameters + ---------- + row : dict or tuple + A database result row + + Returns + ------- + dict + A dictionary of user group data ready for creating a UserGroup + """ + data = {} + if isinstance(row, dict): + data['id'] = row["id"] + data['name'] = row["name"] + data['signal_ids'] = row["signal_ids"] or [] + data['user_ids'] = row["user_ids"] or [] + data['admin_ids'] = row["admin_ids"] or [] + collab_map = row["collaborator_map"] + data['created_at'] = row.get("created_at") + else: + data['id'] = row[0] + data['name'] = row[1] + data['signal_ids'] = row[2] or [] + data['user_ids'] = row[3] or [] + data['admin_ids'] = row[4] or [] + collab_map = row[5] + # Created at would be at index 6 if present + data['created_at'] = row[6] if len(row) > 6 else None + + # Handle collaborator_map field + data['collaborator_map'] = {} + if collab_map: + if isinstance(collab_map, str): + try: + data['collaborator_map'] = json.loads(collab_map) + except json.JSONDecodeError: + data['collaborator_map'] = {} + else: + data['collaborator_map'] = collab_map + + return data + +async def get_users_for_group(cursor: AsyncCursor, user_ids: List[int]) -> List[User]: + """ + Helper function to fetch user details for a group. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_ids : List[int] + List of user IDs to fetch. + + Returns + ------- + List[User] + List of User objects. + """ + users = [] + if user_ids: + await cursor.execute(SQL_SELECT_USERS, (user_ids,)) + async for row in cursor: + user_data = dict(row) + users.append(User(**user_data)) + return users + async def create_user_group(cursor: AsyncCursor, group: UserGroup) -> int: """ @@ -37,32 +164,46 @@ async def create_user_group(cursor: AsyncCursor, group: UserGroup) -> int: int The ID of the created user group. 
""" + # Convert model to dict and ensure collaborator_map is a JSON string + name = group.name + signal_ids = group.signal_ids + user_ids = group.user_ids + admin_ids = group.admin_ids + collaborator_map = json.dumps(group.collaborator_map) + query = """ INSERT INTO user_groups ( name, signal_ids, user_ids, + admin_ids, collaborator_map ) VALUES ( - %(name)s, - %(signal_ids)s, - %(user_ids)s, - %(collaborator_map)s + %s, + %s, + %s, + %s, + %s ) RETURNING id ; """ - await cursor.execute(query, group.model_dump(exclude={"id"})) + await cursor.execute(query, (name, signal_ids, user_ids, admin_ids, collaborator_map)) row = await cursor.fetchone() if row is None: raise ValueError("Failed to create user group") - group_id = row[0] + + # Access the ID safely + if isinstance(row, dict): + group_id = row["id"] + else: + group_id = row[0] return group_id -async def read_user_group(cursor: AsyncCursor, group_id: int) -> UserGroup | None: +async def read_user_group(cursor: AsyncCursor, group_id: int, fetch_details: bool = False) -> UserGroup | UserGroupWithUsers | None: """ Read a user group from the database. @@ -72,33 +213,28 @@ async def read_user_group(cursor: AsyncCursor, group_id: int) -> UserGroup | Non An async database cursor. group_id : int The ID of the user group to read. + fetch_details : bool, optional + If True, fetches detailed user information for each group member, + by default False Returns ------- - UserGroup | None - The user group if found, otherwise None. - """ - query = """ - SELECT - id, - name, - signal_ids, - user_ids, - collaborator_map - FROM - user_groups - WHERE - id = %s - ; + UserGroup | UserGroupWithUsers | None + The user group if found (with optional detailed user data), otherwise None. 
""" + query = f"{SQL_SELECT_USER_GROUP} WHERE id = %s;" await cursor.execute(query, (group_id,)) if (row := await cursor.fetchone()) is None: return None - - # Convert row to dict - data = dict(zip(["id", "name", "signal_ids", "user_ids", "collaborator_map"], row)) - - return UserGroup(**data) + + data = handle_user_group_row(row) + + if fetch_details: + # Fetch detailed user information if requested + users = await get_users_for_group(cursor, data['user_ids']) + return UserGroupWithUsers(**data, users=users) + else: + return UserGroup(**data) async def update_user_group(cursor: AsyncCursor, group: UserGroup) -> int | None: @@ -117,22 +253,31 @@ async def update_user_group(cursor: AsyncCursor, group: UserGroup) -> int | None int | None The ID of the updated user group if successful, otherwise None. """ + # Convert model to dict and ensure collaborator_map is a JSON string + group_data = group.model_dump() + group_data["collaborator_map"] = json.dumps(group_data["collaborator_map"]) + query = """ UPDATE user_groups SET name = %(name)s, signal_ids = %(signal_ids)s, user_ids = %(user_ids)s, + admin_ids = %(admin_ids)s, collaborator_map = %(collaborator_map)s WHERE id = %(id)s RETURNING id ; """ - await cursor.execute(query, group.model_dump()) + await cursor.execute(query, group_data) if (row := await cursor.fetchone()) is None: return None - return row[0] + # Access the ID safely + if isinstance(row, dict): + return row["id"] + else: + return row[0] async def delete_user_group(cursor: AsyncCursor, group_id: int) -> bool: @@ -171,25 +316,12 @@ async def list_user_groups(cursor: AsyncCursor) -> list[UserGroup]: list[UserGroup] A list of all user groups. 
""" - query = """ - SELECT - id, - name, - signal_ids, - user_ids, - collaborator_map - FROM - user_groups - ORDER BY - name - ; - """ + query = f"{SQL_SELECT_USER_GROUP} ORDER BY created_at DESC;" await cursor.execute(query) result = [] async for row in cursor: - # Convert row to dict - data = dict(zip(["id", "name", "signal_ids", "user_ids", "collaborator_map"], row)) + data = handle_user_group_row(row) result.append(UserGroup(**data)) return result @@ -260,44 +392,57 @@ async def remove_user_from_group(cursor: AsyncCursor, group_id: int, user_id: in bool True if the user was removed, False otherwise. """ - # Get current user_ids + # Check if the group exists and get its current state await cursor.execute("SELECT user_ids, collaborator_map FROM user_groups WHERE id = %s;", (group_id,)) row = await cursor.fetchone() if row is None: return False user_ids = row[0] if row[0] is not None else [] - collaborator_map = row[1] if row[1] is not None else {} + + # Parse collaborator_map from JSON if needed + if row[1] is not None and isinstance(row[1], str): + try: + collaborator_map = json.loads(row[1]) + except json.JSONDecodeError: + collaborator_map = {} + else: + collaborator_map = row[1] if row[1] is not None else {} + + if user_id not in user_ids: + return True # Already not in the group # Remove user from user_ids - if user_id in user_ids: - user_ids.remove(user_id) - - # Remove user from collaborator_map - for signal_id, users in list(collaborator_map.items()): - if user_id in users: - users.remove(user_id) - if not users: # If empty, remove signal from map - del collaborator_map[signal_id] - - query = """ - UPDATE user_groups - SET - user_ids = %s, - collaborator_map = %s - WHERE id = %s - RETURNING id - ; - """ - await cursor.execute(query, (user_ids, collaborator_map, group_id)) - return await cursor.fetchone() is not None + user_ids.remove(user_id) - return False + # Remove user from collaborator_map + for signal_id, users in list(collaborator_map.items()): + 
if user_id in users: + users.remove(user_id) + + if not users: + del collaborator_map[signal_id] + + # Update the group + query = """ + UPDATE user_groups + SET + user_ids = %s, + collaborator_map = %s + WHERE id = %s + RETURNING id + ; + """ + # Convert collaborator_map to JSON string + collaborator_map_json = json.dumps(collaborator_map) + await cursor.execute(query, (user_ids, collaborator_map_json, group_id)) + + return await cursor.fetchone() is not None async def get_user_groups(cursor: AsyncCursor, user_id: int) -> list[UserGroup]: """ - Get all groups that a user is a member of. + Get all groups that a user is a member of or an admin of. Parameters ---------- @@ -311,29 +456,76 @@ async def get_user_groups(cursor: AsyncCursor, user_id: int) -> list[UserGroup]: list[UserGroup] A list of user groups. """ + logger.debug("Getting user groups for user_id: %s", user_id) + + # Run a direct SQL query to ensure array type handling is consistent + logger.debug("Fetching all groups where user %s is a member or admin...", user_id) query = """ - SELECT - id, - name, - signal_ids, - user_ids, - collaborator_map - FROM - user_groups - WHERE - %s = ANY(user_ids) - ORDER BY - name - ; + SELECT + id, + name, + signal_ids, + user_ids, + admin_ids, + collaborator_map, + created_at + FROM + user_groups + WHERE + %s = ANY(user_ids) OR %s = ANY(admin_ids) + ORDER BY + created_at DESC; """ - await cursor.execute(query, (user_id,)) + + await cursor.execute(query, (user_id, user_id)) + + # Process results + group_ids_seen = set() result = [] + member_groups = [] + admin_groups = [] + + # Debug + row_count = 0 async for row in cursor: - # Convert row to dict - data = dict(zip(["id", "name", "signal_ids", "user_ids", "collaborator_map"], row)) - result.append(UserGroup(**data)) + row_count += 1 + data = handle_user_group_row(row) + group_id = data['id'] + + # Debug + logger.debug("Processing group ID: %s, Name: %s", group_id, data['name']) + logger.debug("Group user_ids: %s", 
data['user_ids']) + logger.debug("Group admin_ids: %s", data['admin_ids']) + + # Track membership rigorously by explicitly converting IDs to integers + is_member = False + if data['user_ids']: + is_member = user_id in [int(uid) for uid in data['user_ids']] + + is_admin = False + if data['admin_ids']: + is_admin = user_id in [int(aid) for aid in data['admin_ids']] + + if is_member: + member_groups.append(group_id) + logger.debug("User %s is a member of group %s", user_id, group_id) + if is_admin: + admin_groups.append(group_id) + logger.debug("User %s is an admin of group %s", user_id, group_id) + + # Only add each group once + if group_id not in group_ids_seen: + group_ids_seen.add(group_id) + result.append(UserGroup(**data)) + logger.debug("Raw query returned %s rows", row_count) + logger.debug("Found %s groups where user %s is a member: %s", + len(member_groups), user_id, member_groups) + logger.debug("Found %s groups where user %s is an admin: %s", + len(admin_groups), user_id, admin_groups) + logger.debug("Total: Found %s user groups for user_id: %s, Group IDs: %s", + len(result), user_id, list(group_ids_seen)) return result @@ -365,7 +557,7 @@ async def get_group_users(cursor: AsyncCursor, group_id: int) -> list[int]: return row[0] if row and row[0] else [] -async def get_user_groups_with_signals(cursor: AsyncCursor, user_id: int) -> list[dict]: +async def get_user_groups_with_signals(cursor: AsyncCursor, user_id: int, fetch_users: bool = False) -> List[Union[UserGroupWithSignals, UserGroupComplete]]: """ Get all groups that a user is a member of, along with the associated signals data. @@ -375,57 +567,75 @@ async def get_user_groups_with_signals(cursor: AsyncCursor, user_id: int) -> lis An async database cursor. user_id : int The ID of the user. + fetch_users : bool, optional + If True, also fetches detailed user information for each group member, + by default False Returns ------- - list[dict] - A list of dictionaries containing user group and signal data. 
+ list[Union[UserGroupWithSignals, UserGroupComplete]] + A list of user groups with associated signals and optional user details. """ + logger.debug("Getting user groups with signals for user_id: %s", user_id) + # First get the groups the user belongs to user_groups = await get_user_groups(cursor, user_id) result = [] - + # For each group, fetch the signals data for group in user_groups: group_data = group.model_dump() signals = [] - + # Get signals for this group if group.signal_ids: - signals_query = """ - SELECT - s.*, - array_agg(c.trend_id) FILTER (WHERE c.trend_id IS NOT NULL) AS connected_trends - FROM - signals s - LEFT JOIN - connections c ON s.id = c.signal_id - WHERE - s.id = ANY(%s) - GROUP BY - s.id - ORDER BY - s.id - ; - """ - await cursor.execute(signals_query, (group.signal_ids,)) - - async for row in cursor: - signal_dict = dict(row) - # Check if user is a collaborator for this signal - can_edit = False - signal_id_str = str(signal_dict["id"]) - - if group.collaborator_map and signal_id_str in group.collaborator_map: - if user_id in group.collaborator_map[signal_id_str]: - can_edit = True - - signal_dict["can_edit"] = can_edit - signals.append(Signal(**signal_dict)) - - group_data["signals"] = signals - result.append(group_data) - + logger.debug("Fetching signals for group_id: %s, signal_ids: %s", group.id, group.signal_ids) + + # Import read_signal function directly + from .signals import read_signal + + # Get each signal individually + for signal_id in group.signal_ids: + signal = await read_signal(cursor, signal_id) + if signal: + # Check if user is a collaborator for this signal + can_edit = False + signal_id_str = str(signal_id) + + if group.collaborator_map and signal_id_str in group.collaborator_map: + if user_id in group.collaborator_map[signal_id_str]: + can_edit = True + + # Create a copy of the signal with can_edit flag + signal_dict = signal.model_dump() + signal_dict["can_edit"] = can_edit + signals.append(Signal(**signal_dict)) + + # 
Fetch full user details if requested + users_list: List[User] = [] + if fetch_users and group.user_ids: + # Ensure user_ids are integers for get_users_for_group + typed_user_ids = [int(uid) for uid in group.user_ids if isinstance(uid, (int, str)) and str(uid).isdigit()] + users_list = await get_users_for_group(cursor, typed_user_ids) + + # Create a UserGroupWithSignals instance + if fetch_users: + # If we fetched user details, create a UserGroupComplete + group_with_data: Union[UserGroupWithSignals, UserGroupComplete] = UserGroupComplete( + **group_data, + signals=signals, + users=users_list + ) + else: + # Otherwise create a UserGroupWithSignals + group_with_data = UserGroupWithSignals( + **group_data, + signals=signals + ) + + result.append(group_with_data) + + logger.debug("Found %s user groups with signals for user_id: %s", len(result), user_id) return result @@ -460,8 +670,340 @@ async def get_signal_group_collaborators(cursor: AsyncCursor, signal_id: int) -> collaborators = set() async for row in cursor: - if row[0]: # Access first column using integer index - for user_id in row[0]: + # Safely access collaborators + if isinstance(row, dict): + collab_data = row['collaborators'] + else: + collab_data = row[0] + + if collab_data: + for user_id in collab_data: collaborators.add(user_id) return list(collaborators) + + +async def get_user_group_with_users(cursor: AsyncCursor, group_id: int) -> UserGroupWithUsers | None: + """ + Get a user group with detailed user information for each member. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + group_id : int + The ID of the user group. + + Returns + ------- + UserGroupWithUsers | None + A user group with detailed user information, or None if the group doesn't exist. 
+ """ + logger.debug("Getting user group with users for group_id: %s", group_id) + + # First, get the user group + group = await read_user_group(cursor, group_id) + if group is None: + logger.warning("User group not found with id: %s", group_id) + return None + + # Convert to dict for modification + group_data = group.model_dump() + typed_user_ids = [] + if group.user_ids: + typed_user_ids = [int(uid) for uid in group.user_ids if isinstance(uid, (int, str)) and str(uid).isdigit()] + users = await get_users_for_group(cursor, typed_user_ids) + + # Create a UserGroupWithUsers instance + return UserGroupWithUsers(**group_data, users=users) + + +async def list_user_groups_with_users(cursor: AsyncCursor) -> list[UserGroupWithUsers]: + """ + List all user groups with detailed user information. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + + Returns + ------- + list[UserGroupWithUsers] + A list of user groups with detailed user information. + """ + logger.debug("Listing all user groups with users") + + # Get all user groups + groups = await list_user_groups(cursor) + result = [] + + # For each group, get user details + for group in groups: + group_data = group.model_dump() + typed_user_ids = [] + if group.user_ids: + typed_user_ids = [int(uid) for uid in group.user_ids if isinstance(uid, (int, str)) and str(uid).isdigit()] + users = await get_users_for_group(cursor, typed_user_ids) + + # Create a UserGroupWithUsers instance + group_with_users = UserGroupWithUsers(**group_data, users=users) + result.append(group_with_users) + + logger.debug("Listed %s user groups with users", len(result)) + return result + + +async def get_user_groups_with_signals_and_users(cursor: AsyncCursor, user_id: int) -> list[UserGroupComplete]: + """ + Get all groups that a user is a member of, along with the associated signals and users data. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_id : int + The ID of the user. 
+ + Returns + ------- + list[UserGroupComplete] + A list of user groups with signals and users data. + """ + logger.debug("Getting user groups with signals and users for user_id: %s", user_id) + + # First get the groups the user belongs to + user_groups = await get_user_groups(cursor, user_id) + result = [] + + # For each group, fetch the signals data and user data + for group in user_groups: + group_data = group.model_dump() + signals = [] + + # Get signals for this group + if group.signal_ids: + logger.debug("Fetching signals for group_id: %s, signal_ids: %s", group.id, group.signal_ids) + await cursor.execute(SQL_SELECT_SIGNALS, (group.signal_ids,)) + + signal_count = 0 + async for row in cursor: + signal_dict = dict(row) + # Check if user is a collaborator for this signal + can_edit = False + signal_id_str = str(signal_dict["id"]) + + if group.collaborator_map and signal_id_str in group.collaborator_map: + if user_id in group.collaborator_map[signal_id_str]: + can_edit = True + + signal_dict["can_edit"] = can_edit + + # Create Signal instance + signal = Signal(**signal_dict) + signals.append(signal) + signal_count += 1 + + logger.debug("Found %s signals for group_id: %s", signal_count, group.id) + + # Get users for this group + typed_user_ids = [] + if group.user_ids: + typed_user_ids = [int(uid) for uid in group.user_ids if isinstance(uid, (int, str)) and str(uid).isdigit()] + users = await get_users_for_group(cursor, typed_user_ids) + + # Create a UserGroupComplete instance + group_complete = UserGroupComplete(**group_data, signals=signals, users=users) + result.append(group_complete) + + logger.debug("Found %s user groups with signals and users for user_id: %s", len(result), user_id) + return result + + +async def get_user_groups_with_users_by_user_id(cursor: AsyncCursor, user_id: int) -> list[UserGroupWithUsers]: + """ + Get all groups that a user is a member of or an admin of, along with detailed user information + for each group member. 
+ This is a more focused version that only fetches user data, not signals. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_id : int + The ID of the user. + + Returns + ------- + list[UserGroupWithUsers] + A list of user groups with detailed user information. + """ + logger.debug("Getting user groups with users for user_id: %s", user_id) + + # Run a direct raw SQL query to improve reliability when dealing with array types + # This directly checks for user_id in the arrays without relying on array operators + logger.debug("Fetching all groups where user %s is a member or admin...", user_id) + + query = """ + SELECT + id, + name, + signal_ids, + user_ids, + admin_ids, + collaborator_map, + created_at + FROM + user_groups + WHERE + %s = ANY(user_ids) OR %s = ANY(admin_ids) + ORDER BY + created_at DESC; + """ + + await cursor.execute(query, (user_id, user_id)) + + # Process results + group_ids_seen = set() + result = [] + member_groups = [] + admin_groups = [] + + # Debug + row_count = 0 + + async for row in cursor: + row_count += 1 + group_data = handle_user_group_row(row) + group_id = group_data['id'] + + # Debug + logger.debug("Processing group ID: %s, Name: %s", group_id, group_data['name']) + logger.debug("Group user_ids: %s", group_data['user_ids']) + logger.debug("Group admin_ids: %s", group_data['admin_ids']) + + # Track membership rigorously + is_member = False + if group_data['user_ids']: + is_member = user_id in [int(uid) for uid in group_data['user_ids']] + + is_admin = False + if group_data['admin_ids']: + is_admin = user_id in [int(aid) for aid in group_data['admin_ids']] + + if is_member: + member_groups.append(group_id) + logger.debug("User %s is a member of group %s", user_id, group_id) + if is_admin: + admin_groups.append(group_id) + logger.debug("User %s is an admin of group %s", user_id, group_id) + + # Only add each group once + if group_id not in group_ids_seen: + group_ids_seen.add(group_id) + users = await 
get_users_for_group(cursor, group_data['user_ids']) + + # Create a UserGroupWithUsers instance + group_with_users = UserGroupWithUsers(**group_data, users=users) + result.append(group_with_users) + + logger.debug("Raw query returned %s rows", row_count) + logger.debug("Found %s groups where user %s is a member: %s", + len(member_groups), user_id, member_groups) + logger.debug("Found %s groups where user %s is an admin: %s", + len(admin_groups), user_id, admin_groups) + logger.debug("Total: Found %s user groups with users for user_id: %s, Group IDs: %s", + len(result), user_id, list(group_ids_seen)) + return result + +async def debug_user_groups(cursor: AsyncCursor, user_id: int) -> dict: + """ + Debug function to directly query all user group information for a specific user. + This bypasses all processing logic to get raw data from the database. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_id : int + The ID of the user. + + Returns + ------- + dict + A dictionary of raw debug information about the user's groups. 
+ """ + logger.debug("=== DEBUGGING USER GROUPS for user_id: %s ===", user_id) + + # Direct query with no array handling + query1 = """ + SELECT id, name, user_ids, admin_ids + FROM user_groups + WHERE %s = ANY(user_ids) OR %s = ANY(admin_ids) + ORDER BY id; + """ + + await cursor.execute(query1, (user_id, user_id)) + rows1 = [] + async for row in cursor: + rows1.append(dict(row)) + + # Member-only query + query2 = """ + SELECT id, name, user_ids + FROM user_groups + WHERE %s = ANY(user_ids) + ORDER BY id; + """ + + await cursor.execute(query2, (user_id,)) + rows2 = [] + async for row in cursor: + rows2.append(dict(row)) + + # Admin-only query + query3 = """ + SELECT id, name, admin_ids + FROM user_groups + WHERE %s = ANY(admin_ids) + ORDER BY id; + """ + + await cursor.execute(query3, (user_id,)) + rows3 = [] + async for row in cursor: + rows3.append(dict(row)) + + # Check if PostgreSQL sees the value in the arrays using array_position + query4 = """ + SELECT + id, + name, + array_position(user_ids, %s) as user_position, + array_position(admin_ids, %s) as admin_position + FROM + user_groups + WHERE + id IN (22, 25, 28); + """ + + await cursor.execute(query4, (user_id, user_id)) + rows4 = [] + async for row in cursor: + rows4.append(dict(row)) + + result = { + "combined_query": rows1, + "member_query": rows2, + "admin_query": rows3, + "array_position_check": rows4 + } + + logger.debug("=== DEBUG RESULTS ===") + logger.debug("Combined query found %s groups: %s", len(rows1), [r['id'] for r in rows1]) + logger.debug("Member query found %s groups: %s", len(rows2), [r['id'] for r in rows2]) + logger.debug("Admin query found %s groups: %s", len(rows3), [r['id'] for r in rows3]) + logger.debug("Array position check results: %s", rows4) + logger.debug("=== END DEBUG ===") + + return result diff --git a/src/database/user_groups_direct.py b/src/database/user_groups_direct.py new file mode 100644 index 0000000..4a99c23 --- /dev/null +++ b/src/database/user_groups_direct.py @@ 
-0,0 +1,265 @@ +""" +Alternative direct SQL implementation for user group functions. + +This module provides direct SQL implementations of key user group functions +that bypass the normal processing logic to ensure reliable results. + +These functions should only be used in case of persistent issues with +the standard implementations in user_groups.py. +""" + +import logging +from typing import List +from psycopg import AsyncCursor + +from ..entities import UserGroup, User, UserGroupWithUsers + +logger = logging.getLogger(__name__) + +async def get_user_groups_direct(cursor: AsyncCursor, user_id: int) -> List[UserGroup]: + """ + Get all groups that a user is a member of or an admin of using direct SQL. + + This implementation uses the simplest possible SQL and minimal processing + to maximize reliability. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_id : int + The ID of the user. + + Returns + ------- + List[UserGroup] + A list of user groups. + """ + logger.debug("DIRECT SQL: Getting user groups for user_id: %s", user_id) + + # Direct SQL with minimal processing + query = """ + WITH user_groups_for_user AS ( + SELECT + id, + name, + signal_ids, + user_ids, + admin_ids, + collaborator_map, + created_at + FROM + user_groups + WHERE + %s = ANY(user_ids) OR %s = ANY(admin_ids) + ) + SELECT * FROM user_groups_for_user + ORDER BY created_at DESC; + """ + + await cursor.execute(query, (user_id, user_id)) + + result = [] + row_count = 0 + + async for row in cursor: + row_count += 1 + # Convert row to dictionary + data = dict(row) + # Convert empty arrays to empty lists + if data['user_ids'] is None: + data['user_ids'] = [] + if data['admin_ids'] is None: + data['admin_ids'] = [] + if data['signal_ids'] is None: + data['signal_ids'] = [] + + # Create UserGroup instance + group = UserGroup(**data) + result.append(group) + logger.debug("DIRECT SQL: Found group ID: %s, Name: %s", group.id, group.name) + + logger.debug("DIRECT SQL: 
Query returned %s groups", row_count) + return result + +async def get_users_by_ids(cursor: AsyncCursor, user_ids: List[int]) -> List[User]: + """ + Get user details for a list of user IDs. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_ids : List[int] + List of user IDs. + + Returns + ------- + List[User] + List of User objects. + """ + if not user_ids: + return [] + + query = """ + SELECT + id, + email, + role, + name, + unit, + acclab, + created_at + FROM + users + WHERE + id = ANY(%s) + ORDER BY + name; + """ + + await cursor.execute(query, (user_ids,)) + users = [] + + async for row in cursor: + user_data = dict(row) + users.append(User(**user_data)) + + return users + +async def get_signals_by_ids(cursor: AsyncCursor, signal_ids: List[int]) -> List[dict]: + """ + Get signal details for a list of signal IDs. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + signal_ids : List[int] + List of signal IDs. + + Returns + ------- + List[dict] + List of signal dictionaries. + """ + if not signal_ids: + return [] + + query = """ + SELECT + s.*, + array_agg(c.trend_id) FILTER (WHERE c.trend_id IS NOT NULL) AS connected_trends + FROM + signals s + LEFT JOIN + connections c ON s.id = c.signal_id + WHERE + s.id = ANY(%s) + GROUP BY + s.id + ORDER BY + s.id; + """ + + await cursor.execute(query, (signal_ids,)) + signals = [] + + async for row in cursor: + signal_data = dict(row) + signals.append(signal_data) + + return signals + +async def get_user_groups_with_users_direct(cursor: AsyncCursor, user_id: int) -> List[UserGroupWithUsers]: + """ + Get all groups that a user is a member of or an admin of, with user details, using direct SQL. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_id : int + The ID of the user. + + Returns + ------- + List[UserGroupWithUsers] + A list of user groups with user details. 
+ """ + logger.debug("DIRECT SQL: Getting user groups with users for user_id: %s", user_id) + + # First, get the groups + groups = await get_user_groups_direct(cursor, user_id) + result = [] + + # For each group, fetch the users + for group in groups: + group_data = group.model_dump() + users = await get_users_by_ids(cursor, group.user_ids) + + # Create UserGroupWithUsers instance + group_with_users = UserGroupWithUsers(**group_data, users=users) + result.append(group_with_users) + + logger.debug("DIRECT SQL: Returning %s groups with users", len(result)) + return result + +async def get_user_groups_with_signals_direct(cursor: AsyncCursor, user_id: int) -> List: + """ + Get all groups that a user is a member of or an admin of, with signals, using direct SQL. + + Parameters + ---------- + cursor : AsyncCursor + An async database cursor. + user_id : int + The ID of the user. + + Returns + ------- + List + A list of user groups with signals and users data. + """ + from ..entities import Signal, UserGroupComplete + + logger.debug("DIRECT SQL: Getting user groups with signals for user_id: %s", user_id) + + # First, get the groups + groups = await get_user_groups_direct(cursor, user_id) + result = [] + + # For each group, fetch the signals and users + for group in groups: + group_data = group.model_dump() + signals = [] + + # Get signals for this group if it has any + if group.signal_ids: + signal_data_list = await get_signals_by_ids(cursor, group.signal_ids) + + for signal_data in signal_data_list: + # Check if user is a collaborator for this signal + can_edit = False + signal_id_str = str(signal_data["id"]) + + if group.collaborator_map and signal_id_str in group.collaborator_map: + if user_id in group.collaborator_map[signal_id_str]: + can_edit = True + + # Add can_edit attribute to signal data + signal_data["can_edit"] = can_edit + + # Create Signal instance + signal = Signal(**signal_data) + signals.append(signal) + + # Get users for this group + users = await 
get_users_by_ids(cursor, group.user_ids) + + # Create UserGroupComplete instance + group_complete = UserGroupComplete(**group_data, signals=signals, users=users) + result.append(group_complete) + + logger.debug("DIRECT SQL: Returning %s groups with signals", len(result)) + return result \ No newline at end of file diff --git a/src/entities/parameters.py b/src/entities/parameters.py index 50ccc32..86451d5 100644 --- a/src/entities/parameters.py +++ b/src/entities/parameters.py @@ -69,6 +69,10 @@ class SignalFilters(BaseFilters): bureau: str | None = Field(default=None) score: Score | None = Field(default=None) unit: str | None = Field(default=None) + user_email: str | None = Field(default=None) + user_id: int | None = Field(default=None) + is_admin: bool = Field(default=False) + is_staff: bool = Field(default=False) class TrendFilters(BaseFilters): diff --git a/src/entities/signal.py b/src/entities/signal.py index 86360c3..3bf161c 100644 --- a/src/entities/signal.py +++ b/src/entities/signal.py @@ -2,13 +2,17 @@ Entity (model) definitions for signal objects. """ -from typing import List, Dict +from typing import List, Dict, TYPE_CHECKING, Any, Optional from pydantic import ConfigDict, Field, field_validator, model_validator from . import utils from .base import BaseEntity -__all__ = ["Signal"] +# Import only for type checking to avoid circular imports +if TYPE_CHECKING: + from .user_groups import UserGroup + +__all__ = ["Signal", "SignalWithUserGroups", "SignalCreate", "SignalUpdate"] class Signal(BaseEntity): @@ -39,9 +43,13 @@ class Signal(BaseEntity): description="Whether the current user has favorited this signal.", ) is_draft: bool = Field( - default=True, + default=False, description="Whether the signal is in draft state or published.", ) + private: bool = Field( + default=False, + description="Whether the signal is private. 
Private signals are only visible to their creator, collaborators, and admins.", + ) group_ids: List[int] | None = Field( default=None, description="List of user group IDs associated with this signal.", @@ -50,6 +58,10 @@ class Signal(BaseEntity): default=None, description="List of user IDs who can collaborate on this signal.", ) + can_edit: bool = Field( + default=False, + description="Whether the current user can edit this signal (set dynamically based on group membership and collaboration).", + ) @model_validator(mode='before') @classmethod @@ -71,11 +83,126 @@ def convert_secondary_location(cls, data): "location": "Global", "favorite": False, "is_draft": True, + "private": False, + "group_ids": [1, 2], + "collaborators": [1, 2, 3], + "secondary_location": ["Africa", "Asia"], + "score": None, + "connected_trends": [101, 102], + } + } + ) + + +class SignalWithUserGroups(Signal): + """ + Extended signal entity that includes the user groups it belongs to. + This model is used in API responses to provide a signal with its associated user groups. + """ + + user_groups: List[Any] = Field( + default_factory=list, + description="List of user groups this signal belongs to." 
+ ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "id": 1, + "created_unit": "HQ", + "url": "https://undp.medium.com/the-cost-of-corruption-a827306696fb", + "relevance": "Of the approximately US$13 trillion that governments spend on public spending, up to 25 percent is lost to corruption.", + "keywords": ["economy", "governance"], + "location": "Global", + "favorite": False, + "is_draft": True, + "private": False, "group_ids": [1, 2], "collaborators": [1, 2, 3], "secondary_location": ["Africa", "Asia"], "score": None, "connected_trends": [101, 102], + "user_groups": [ + { + "id": 1, + "name": "Research Team", + "signal_ids": [1, 2, 3], + "user_ids": [101, 102], + "admin_ids": [101], + "collaborator_map": {"1": [101, 102]} + }, + { + "id": 2, + "name": "Policy Team", + "signal_ids": [1, 4, 5], + "user_ids": [103, 104], + "admin_ids": [103], + "collaborator_map": {"1": [103]} + } + ] + } + } + ) + + +class SignalCreate(Signal): + """ + Model for signal creation request that includes user_group_ids. + This is used for the request body in the POST endpoint. + """ + + status: Optional[utils.Status] = Field( + default=utils.Status.NEW, + description="Current signal review status. Defaults to NEW if not provided.", + ) + user_group_ids: Optional[List[int]] = Field( + default=None, + description="IDs of user groups to add the signal to after creation" + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "headline": "New Signal Example", + "description": "This is a new signal with user groups.", + "steep_primary": "T", + "steep_secondary": ["S", "P"], + "signature_primary": "Shift", + "signature_secondary": ["Risk"], + "keywords": ["example", "test"], + "location": "Global", + "private": False, + "user_group_ids": [1, 2] + } + } + ) + + +class SignalUpdate(Signal): + """ + Model for signal update request that includes user_group_ids. + This is used for the request body in the PUT endpoint. 
+ """ + + user_group_ids: Optional[List[int]] = Field( + default=None, + description="IDs of user groups to replace the signal's current group associations" + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "id": 1, + "headline": "Updated Signal Example", + "description": "This is an updated signal with new user groups.", + "steep_primary": "T", + "steep_secondary": ["S", "P"], + "signature_primary": "Shift", + "signature_secondary": ["Risk"], + "keywords": ["updated", "test"], + "location": "Global", + "private": True, + "user_group_ids": [2, 3] } } ) diff --git a/src/entities/user_groups.py b/src/entities/user_groups.py index 4599326..95135a6 100644 --- a/src/entities/user_groups.py +++ b/src/entities/user_groups.py @@ -2,13 +2,17 @@ Entity (model) definitions for user group objects. """ -from typing import Dict, List +from typing import Dict, List, TYPE_CHECKING, Any from pydantic import ConfigDict, Field from .base import BaseEntity -from .signal import Signal +from .user import User -__all__ = ["UserGroup", "UserGroupWithSignals"] +# Import only for type checking to avoid circular imports +if TYPE_CHECKING: + from .signal import Signal + +__all__ = ["UserGroup", "UserGroupWithSignals", "UserGroupWithUsers", "UserGroupComplete"] class UserGroup(BaseEntity): @@ -22,9 +26,13 @@ class UserGroup(BaseEntity): default_factory=list, description="List of signal IDs associated with this group." ) - user_ids: List[int] = Field( + user_ids: List[str | int] = Field( + default_factory=list, + description="List of user IDs (integers) or emails (strings) who are members of this group." + ) + admin_ids: List[int] = Field( default_factory=list, - description="List of user IDs who are members of this group." + description="List of user IDs who have admin privileges for this group." 
) collaborator_map: Dict[str, List[int]] = Field( default_factory=dict, @@ -38,6 +46,7 @@ class UserGroup(BaseEntity): "name": "CDO", "signal_ids": [1, 2, 3], "user_ids": [1, 2, 3], + "admin_ids": [1], "collaborator_map": { "1": [1, 2], "2": [1, 3], @@ -51,7 +60,7 @@ class UserGroup(BaseEntity): class UserGroupWithSignals(UserGroup): """User group with associated signals data.""" - signals: List[Signal] = Field( + signals: List[Any] = Field( default_factory=list, description="List of signals associated with this group." ) @@ -78,3 +87,81 @@ class UserGroupWithSignals(UserGroup): } } ) + + +class UserGroupWithUsers(UserGroup): + """User group with associated users data.""" + + users: List[User] = Field( + default_factory=list, + description="List of users who are members of this group." + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "id": 1, + "name": "CDO", + "signal_ids": [1, 2, 3], + "user_ids": [1, 2, 3], + "collaborator_map": { + "1": [1, 2], + "2": [1, 3], + "3": [2, 3] + }, + "users": [ + { + "id": 1, + "email": "john.doe@undp.org", + "role": "Curator", + "name": "John Doe" + } + ] + } + } + ) + + +class UserGroupComplete(UserGroup): + """User group with both associated signals and users data.""" + + signals: List[Any] = Field( + default_factory=list, + description="List of signals associated with this group." + ) + + users: List[User] = Field( + default_factory=list, + description="List of users who are members of this group." 
+ ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "id": 1, + "name": "CDO", + "signal_ids": [1, 2, 3], + "user_ids": [1, 2, 3], + "collaborator_map": { + "1": [1, 2], + "2": [1, 3], + "3": [2, 3] + }, + "signals": [ + { + "id": 1, + "headline": "Signal 1", + "can_edit": True + } + ], + "users": [ + { + "id": 1, + "email": "john.doe@undp.org", + "role": "Curator", + "name": "John Doe" + } + ] + } + } + ) diff --git a/src/routers/signals.py b/src/routers/signals.py index 4bee7e5..24ec7a6 100644 --- a/src/routers/signals.py +++ b/src/routers/signals.py @@ -13,7 +13,10 @@ from .. import exceptions, genai, utils from ..authentication import authenticate_user from ..dependencies import require_admin, require_creator, require_curator, require_user -from ..entities import Role, Signal, SignalFilters, SignalPage, Status, User, UserGroup +from ..entities import ( + Role, Signal, SignalFilters, SignalPage, Status, User, UserGroup, + SignalWithUserGroups, SignalCreate, SignalUpdate +) logger = logging.getLogger(__name__) @@ -27,19 +30,32 @@ async def search_signals( cursor: AsyncCursor = Depends(db.yield_cursor), ): """Search signals in the database using pagination and filters.""" + # Add user info to filters for private signal handling + filters.user_email = user.email + filters.user_id = user.id + filters.is_admin = user.is_admin + filters.is_staff = user.is_staff + page = await db.search_signals(cursor, filters) return page.sanitise(user) -@router.get("/export", response_model=None, dependencies=[Depends(require_curator)]) +@router.get("/export", response_model=None) async def export_signals( filters: Annotated[SignalFilters, Query()], + user: User = Depends(require_curator), cursor: AsyncCursor = Depends(db.yield_cursor), ): """ Export signals that match the filters from the database. You can export up to 10k rows at once. 
""" + # Add user info to filters for private signal handling + filters.user_email = user.email + filters.user_id = user.id + filters.is_admin = user.is_admin + filters.is_staff = user.is_staff + page = await db.search_signals(cursor, filters) # prettify the data @@ -82,19 +98,60 @@ async def generate_signal( @router.post("", response_model=Signal, status_code=201) async def create_signal( - signal: Signal, + signal_data: SignalCreate, user: User = Depends(require_user), cursor: AsyncCursor = Depends(db.yield_cursor), ): """ Submit a signal to the database. If the signal has a base64 encoded image attachment, it will be uploaded to Azure Blob Storage. + + Optionally, the signal can be added to one or more user groups by specifying + user_group_ids in the request body. """ - signal.created_by = user.email - signal.modified_by = user.email - signal.created_unit = user.unit - signal_id = await db.create_signal(cursor, signal) - return await db.read_signal(cursor, signal_id) + logger.info(f"Creating new signal requested by user: {user.email}") + + # Extract standard Signal fields and user_group_ids + signal_dict = signal_data.model_dump(exclude={"user_group_ids"}) + + # Ensure status is "New" if None was provided + if signal_dict.get("status") is None: + signal_dict["status"] = Status.NEW + + signal = Signal(**signal_dict) + user_group_ids = signal_data.user_group_ids or [] + + if user_group_ids: + logger.info(f"With user_group_ids: {user_group_ids}") + + try: + # Prepare the signal object + signal.created_by = user.email + signal.modified_by = user.email + signal.created_unit = user.unit + + # Create the signal in the database with user groups if specified + signal_id = await db.create_signal(cursor, signal, user_group_ids) + logger.info(f"Signal created successfully with ID: {signal_id}") + + # Read back the created signal to return it + created_signal = await db.read_signal(cursor, signal_id) + if not created_signal: + logger.error(f"Failed to read newly created 
signal with ID: {signal_id}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Signal was created but could not be retrieved" + ) + + return created_signal + + except Exception as e: + logger.error(f"Error creating signal: {str(e)}") + # Raise HTTPException with appropriate status code + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to create signal: {str(e)}" + ) @router.get("/me", response_model=list[Signal]) @@ -106,7 +163,7 @@ async def read_my_signals( """ Retrieve signal with a given status submitted by the current user. """ - return await db.read_user_signals(cursor, user.email, status) + return await db.read_user_signals(cursor, user.email, status, user.is_admin, user.is_staff) @router.get("/{uid}", response_model=Signal) @@ -127,6 +184,28 @@ async def read_signal( logger.info("Retrieved signal: %s", signal.model_dump()) + # Check for permission to view private signals + if signal.private and not (user.is_admin or user.is_staff or signal.created_by == user.email): + # Check if user is a collaborator + is_collaborator = False + + # Check direct collaborator + collaborators = await db.get_signal_collaborators(cursor, uid) + if user.email in collaborators: + is_collaborator = True + + # Check group collaborator + if not is_collaborator and await db.can_user_edit_signal(cursor, uid, user.id): + is_collaborator = True + + if not is_collaborator: + logger.warning( + "Permission denied - user %s trying to access private signal %s", + user.email, uid + ) + raise exceptions.permission_denied + + # Check for visitor permission with status if user.role == Role.VISITOR and signal.status != Status.APPROVED: logger.warning( "Permission denied - visitor trying to access non-approved signal. 
Status: %s", @@ -145,20 +224,143 @@ async def read_signal( return signal +@router.get("/{uid}/with-user-groups", response_model=SignalWithUserGroups) +async def read_signal_with_user_groups( + uid: Annotated[int, Path(description="The ID of the signal to retrieve")], + user: User = Depends(authenticate_user), + cursor: AsyncCursor = Depends(db.yield_cursor), +): + """ + Retrieve a signal from the database along with all user groups it belongs to. + This endpoint provides a comprehensive view of a signal including its group associations. + """ + logger.info("Reading signal with user groups for ID: %s, user: %s", uid, user.email) + + try: + # Fetch signal with user groups + signal = await db.read_signal_with_user_groups(cursor, uid) + + if signal is None: + logger.warning("Signal not found with ID: %s", uid) + raise exceptions.not_found + + # Check for permission to view private signals + if signal.private and not (user.is_admin or user.is_staff or signal.created_by == user.email): + # Check if user is a collaborator + is_collaborator = False + + # Check direct collaborator + collaborators = await db.get_signal_collaborators(cursor, uid) + if user.email in collaborators: + is_collaborator = True + + # Check group collaborator + if not is_collaborator and await db.can_user_edit_signal(cursor, uid, user.id): + is_collaborator = True + + if not is_collaborator: + logger.warning( + "Permission denied - user %s trying to access private signal %s", + user.email, uid + ) + raise exceptions.permission_denied + + # Check visitor permissions + if user.role == Role.VISITOR and signal.status != Status.APPROVED: + logger.warning( + "Permission denied - visitor trying to access non-approved signal. 
Status: %s", + signal.status + ) + raise exceptions.permission_denied + + # Check if the signal is favorited by the user + try: + is_favorite = await db.is_signal_favorited(cursor, user.email, uid) + signal.favorite = is_favorite + logger.debug(f"Favorite status for signal {uid}, user {user.email}: {is_favorite}") + except Exception as fav_e: + logger.error(f"Failed to check favorite status: {str(fav_e)}") + # Continue even if favorite check fails + signal.favorite = False + + logger.info(f"Successfully retrieved signal {uid} with {len(signal.user_groups)} user groups") + return signal + + except exceptions.not_found: + raise + except exceptions.permission_denied: + raise + except Exception as e: + logger.error(f"Error retrieving signal {uid} with user groups: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to retrieve signal with user groups: {str(e)}" + ) + + @router.put("/{uid}", response_model=Signal) async def update_signal( uid: Annotated[int, Path(description="The ID of the signal to be updated")], - signal: Signal, + signal_data: SignalUpdate, user: User = Depends(require_creator), cursor: AsyncCursor = Depends(db.yield_cursor), ): - """Update a signal in the database.""" - if uid != signal.id: - raise exceptions.id_mismatch - signal.modified_by = user.email - if (signal_id := await db.update_signal(cursor, signal)) is None: - raise exceptions.not_found - return await db.read_signal(cursor, signal_id) + """ + Update a signal in the database. + + Optionally, the signal's user group associations can be updated by specifying + user_group_ids in the request body. If provided, the signal will only belong + to the specified groups, replacing any previous group associations. 
+ """ + logger.info(f"Updating signal {uid} requested by user: {user.email}") + + # Extract signal data and user_group_ids from the request body + signal = Signal(**signal_data.model_dump(exclude={"user_group_ids"})) + user_group_ids = signal_data.user_group_ids or [] + + if user_group_ids is not None: + logger.info(f"With user_group_ids: {user_group_ids}") + + try: + # Verify ID match + if uid != signal.id: + logger.warning(f"ID mismatch: URL ID {uid} doesn't match payload ID {signal.id}") + raise exceptions.id_mismatch + + # Update metadata + signal.modified_by = user.email + + # Update the signal in the database + logger.info(f"Updating signal {uid} in database") + signal_id = await db.update_signal(cursor, signal, user_group_ids) + + if signal_id is None: + logger.warning(f"Signal {uid} not found during update") + raise exceptions.not_found + + logger.info(f"Signal {uid} updated successfully") + + # Read back the updated signal + updated_signal = await db.read_signal(cursor, signal_id) + if not updated_signal: + logger.error(f"Failed to read updated signal with ID: {signal_id}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Signal was updated but could not be retrieved" + ) + + return updated_signal + + except exceptions.id_mismatch: + raise + except exceptions.not_found: + raise + except Exception as e: + logger.error(f"Error updating signal {uid}: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to update signal: {str(e)}" + ) @router.delete("/{uid}", response_model=Signal, dependencies=[Depends(require_creator)]) @@ -226,7 +428,7 @@ async def add_signal_collaborator( ) # Add collaborator - if not await db.add_collaborator(cursor, uid, user_id): + if not await db.add_collaborator(cursor, uid, str(user_id)): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid collaborator or signal", @@ -260,7 +462,7 @@ async def remove_signal_collaborator( 
) # Remove collaborator - if not await db.remove_collaborator(cursor, uid, user_id): + if not await db.remove_collaborator(cursor, uid, str(user_id)): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Collaborator not found or signal does not exist", diff --git a/src/routers/user_groups.py b/src/routers/user_groups.py index ab27d7f..b332910 100644 --- a/src/routers/user_groups.py +++ b/src/routers/user_groups.py @@ -4,7 +4,7 @@ import logging import bugsnag -from typing import Annotated, List, Optional, Union +from typing import Annotated, List, Optional, Union, Dict from fastapi import APIRouter, Depends, Path, Body, Query, HTTPException, Request from psycopg import AsyncCursor @@ -13,8 +13,10 @@ from .. import database as db from .. import exceptions from ..dependencies import require_admin, require_user -from ..entities import UserGroup, User, UserGroupWithSignals +from ..entities import UserGroup, User, UserGroupWithSignals, UserGroupWithUsers, UserGroupComplete, Signal from ..authentication import authenticate_user +from ..database.signals import read_signal +from ..database import user_groups_direct # Set up logger for this module logger = logging.getLogger(__name__) @@ -28,6 +30,27 @@ class UserGroupCreate(BaseModel): users: Optional[List[str]] = None +class UserGroupUpdate(BaseModel): + id: int + name: str + signal_ids: List[int] = [] + user_ids: List[Union[str, int]] = [] # Can be either user IDs or email addresses + collaborator_map: Dict[str, List[int]] = {} + created_at: Optional[str] = None + status: Optional[str] = None + created_by: Optional[str] = None + created_for: Optional[str] = None + modified_at: Optional[str] = None + modified_by: Optional[str] = None + headline: Optional[str] = None + attachment: Optional[str] = None + steep_primary: Optional[str] = None + steep_secondary: Optional[str] = None + signature_primary: Optional[str] = None + signature_secondary: Optional[str] = None + sdgs: Optional[List[str]] = None + + class 
UserEmailIdentifier(BaseModel): email: str @@ -76,12 +99,16 @@ async def list_user_groups( cursor: AsyncCursor = Depends(db.yield_cursor), ): """List all user groups.""" + logger.info(f"Endpoint called: list_user_groups - URL: {request.url} - Method: {request.method}") try: + logger.debug("Fetching all user groups from database...") groups = await db.list_user_groups(cursor) - logger.info(f"Listed {len(groups)} user groups") + logger.info(f"Successfully listed {len(groups)} user groups") + logger.debug(f"Returning group IDs: {[g.id for g in groups]}") return groups except Exception as e: logger.error(f"Error listing user groups: {str(e)}") + logger.exception("Detailed traceback for listing user groups:") bugsnag.notify( e, metadata={ @@ -101,21 +128,54 @@ async def get_my_user_groups( cursor: AsyncCursor = Depends(db.yield_cursor), ): """ - Get all user groups that the current user is a member of. + Get all user groups that the current user is a member of or an admin of. This endpoint is accessible to all authenticated users. 
""" + logger.info(f"Endpoint called: get_my_user_groups - URL: {request.url} - Method: {request.method}") + logger.debug(f"User requesting their groups: {user.id} ({user.email})") + try: if not user.id: logger.warning("User ID not found in get_my_user_groups") raise exceptions.not_found - # Get groups this user is a member of - user_groups = await db.get_user_groups(cursor, user.id) - logger.info(f"User {user.id} retrieved {len(user_groups)} groups") + # Run debug queries first + logger.debug(f"Running debug queries for user {user.id}...") + debug_info = await db.debug_user_groups(cursor, user.id) + + # Check if we can directly extract groups from debug info + direct_group_ids = [] + for row in debug_info["combined_query"]: + direct_group_ids.append(row["id"]) + + logger.debug(f"Direct query found group IDs: {direct_group_ids}") + + # Now get the full groups with user details + logger.debug(f"Fetching groups for user {user.id}...") + user_groups = await db.get_user_groups_with_users_by_user_id(cursor, user.id) + + # Check if there's a mismatch + fetched_ids = [g.id for g in user_groups] + missing_ids = [gid for gid in direct_group_ids if gid not in fetched_ids] + + if missing_ids: + logger.warning(f"MISMATCH! 
Direct query found groups {direct_group_ids} but function returned only {fetched_ids}") + logger.warning(f"Missing groups: {missing_ids}") + + # Fall back to direct SQL implementation + logger.warning("Falling back to direct SQL implementation") + user_groups = await user_groups_direct.get_user_groups_with_users_direct(cursor, user.id) + logger.info(f"Direct SQL implementation returned {len(user_groups)} groups") + + logger.info(f"User {user.id} ({user.email}) retrieved {len(user_groups)} groups (as member or admin)") + if user_groups: + logger.debug(f"Returning group IDs: {[g.id for g in user_groups]}") + logger.debug(f"Group names: {[g.name for g in user_groups]}") return user_groups except Exception as e: if not isinstance(e, HTTPException): # Don't log HTTPExceptions logger.error(f"Error in get_my_user_groups: {str(e)}") + logger.exception(f"Detailed traceback for get_my_user_groups (user {user.id}):") bugsnag.notify( e, metadata={ @@ -132,37 +192,92 @@ async def get_my_user_groups( raise -@router.get("/me/with-signals", response_model=List[UserGroupWithSignals]) +@router.get("/me/with-signals", response_model=List[Union[UserGroupWithSignals, UserGroupComplete]]) async def get_my_user_groups_with_signals( request: Request, user: User = Depends(authenticate_user), cursor: AsyncCursor = Depends(db.yield_cursor), + include_users: bool = Query(True, description="If true, includes detailed user information for each group member (defaults to true)") ): """ - Get all user groups that the current user is a member of along with their signals data. - + Get all user groups that the current user is a member of or an admin of along with their signals data. + This enhanced endpoint provides detailed information about each signal associated with the groups, including whether the current user has edit permissions for each signal. This is useful for: - + - Displaying a dashboard of all signals a user can access through their groups - Showing which signals the user can edit vs. 
view-only - Building collaborative workflows where users can see their assigned signals - + The response includes all signal details plus a `can_edit` flag for each signal indicating if the current user has edit permissions based on the group's collaborator_map. + + By default, detailed user information for each group member is included. Set the `include_users` + parameter to false to get only basic user ID references without detailed user information. """ + logger.info(f"Endpoint called: get_my_user_groups_with_signals - URL: {request.url} - Method: {request.method}") + logger.debug(f"User requesting groups with signals: {user.id} ({user.email})") + logger.debug(f"Query parameters: include_users={include_users}") + try: if not user.id: logger.warning("User ID not found in get_my_user_groups_with_signals") raise exceptions.not_found + + # Run debug queries first to check for discrepancies + logger.debug(f"Running debug queries for user {user.id}...") + debug_info = await db.debug_user_groups(cursor, user.id) + + # Check if we can directly extract groups from debug info + direct_group_ids = [] + for row in debug_info["combined_query"]: + direct_group_ids.append(row["id"]) + + logger.debug(f"Direct query found group IDs: {direct_group_ids}") + + logger.debug(f"Fetching groups with signals for user {user.id}...") + # Get groups with signals for this user, optionally including full user details + user_groups_with_signals = await db.get_user_groups_with_signals( + cursor, + user.id, + fetch_users=include_users + ) + + # Check if there's a mismatch and fall back to direct implementation if needed + fetched_ids = [g.id for g in user_groups_with_signals] + missing_ids = [gid for gid in direct_group_ids if gid not in fetched_ids] + + if missing_ids: + logger.warning(f"MISMATCH in signals endpoint! 
Direct query found groups {direct_group_ids} but function returned only {fetched_ids}") + logger.warning(f"Missing groups: {missing_ids}") + + # Fall back to direct SQL implementation + logger.warning("Falling back to direct SQL implementation for signals") + user_groups_with_signals = await user_groups_direct.get_user_groups_with_signals_direct(cursor, user.id) + logger.info(f"Direct SQL implementation returned {len(user_groups_with_signals)} groups with signals") + + logger.info(f"User {user.id} ({user.email}) retrieved {len(user_groups_with_signals)} groups with signals") + if user_groups_with_signals: + logger.debug(f"Returning group IDs: {[g.id for g in user_groups_with_signals]}") + + # Log total signals count across all groups + total_signals = sum(len(g.signals) for g in user_groups_with_signals) + logger.debug(f"Total signals across all groups: {total_signals}") + + # Log collaborator access details + editable_signals = 0 + for group in user_groups_with_signals: + for signal in group.signals: + if signal.can_edit: + editable_signals += 1 + + logger.debug(f"User can edit {editable_signals} out of {total_signals} signals") - # Get groups with signals for this user - user_groups_with_signals = await db.get_user_groups_with_signals(cursor, user.id) - logger.info(f"User {user.id} retrieved {len(user_groups_with_signals)} groups with signals") return user_groups_with_signals except Exception as e: if not isinstance(e, HTTPException): # Don't log HTTPExceptions logger.error(f"Error in get_my_user_groups_with_signals: {str(e)}") + logger.exception(f"Detailed traceback for get_my_user_groups_with_signals (user {user.id}):") bugsnag.notify( e, metadata={ @@ -173,40 +288,119 @@ async def get_my_user_groups_with_signals( "user": { "id": user.id if user else None, "email": user.email if user else None + }, + "query_params": { + "include_users": include_users } } ) raise -@router.post("", response_model=UserGroup, dependencies=[Depends(require_admin)]) +@router.post("", 
@router.post("",
             response_model=Union[UserGroup, UserGroupWithUsers],
             dependencies=[Depends(require_admin)])
async def create_user_group(
    request: Request,
    group_data: UserGroupCreate,
    current_user: User = Depends(authenticate_user),  # creator becomes member + admin
    cursor: AsyncCursor = Depends(db.yield_cursor),
    include_users: bool = Query(False, description="If true, includes detailed user information for each group member"),
    admins: List[str] = Query(None, description="List of user emails to set as admins in the group"),
):
    """
    Create a new user group.

    The current authenticated user is automatically added as both a member and
    an admin of the group. Additional members may be supplied as email
    addresses in the request body (``group_data.users``); additional admins via
    the ``admins`` query parameter. Admin emails are also added as members.

    Set ``include_users=true`` to include detailed user information for each
    group member in the response.

    Raises:
        HTTPException: ``exceptions.not_found`` if the group cannot be re-read
            immediately after creation (indicates a persistence problem).
    """
    logger.info(f"Endpoint called: create_user_group - URL: {request.url} - Method: {request.method}")
    logger.debug(f"Creating user group with name: '{group_data.name}'")
    logger.debug(f"Query parameters: include_users={include_users}, admins={admins}")
    logger.debug(f"Current user: ID={current_user.id}, email={current_user.email}")

    try:
        # Create the base group entity.
        group = UserGroup(name=group_data.name)
        logger.debug(f"Created base group entity with name: '{group.name}'")

        # Seed membership with the creator, who is also made an admin.
        user_ids: List[int] = []
        admin_ids: List[int] = []
        if current_user.id:
            user_ids.append(current_user.id)
            admin_ids.append(current_user.id)
            logger.debug(f"Added current user (ID: {current_user.id}) as member and admin")

        # Resolve member emails from the request body into user IDs.
        user_emails_added = []
        if group_data.users:
            logger.debug(f"Processing {len(group_data.users)} user emails from request body")
            for email in group_data.users:
                logger.debug(f"Looking up user with email: {email}")
                user = await db.read_user_by_email(cursor, email)
                if user is None:
                    logger.warning(f"User with email {email} not found")
                elif not user.id:
                    # BUG FIX: this case previously fell through every branch
                    # without any log, silently dropping the email.
                    logger.warning(f"User with email {email} has no ID; skipping")
                elif user.id in user_ids:
                    logger.debug(f"User {user.id} ({email}) already added to members list")
                else:
                    user_ids.append(user.id)
                    user_emails_added.append(email)
                    logger.debug(f"Added user {user.id} ({email}) as member")
            logger.debug(f"Added {len(user_emails_added)} users as members from request body: {user_emails_added}")

        # Resolve admin emails from the query string; admins must also be members.
        admin_emails_added = []
        if admins:
            logger.debug(f"Processing {len(admins)} admin emails from query parameters")
            for email in admins:
                logger.debug(f"Looking up admin with email: {email}")
                user = await db.read_user_by_email(cursor, email)
                if user and user.id:
                    if user.id not in user_ids:  # not yet a member: add them
                        user_ids.append(user.id)
                        logger.debug(f"Added user {user.id} ({email}) as member")
                    if user.id not in admin_ids:  # avoid duplicates in admin_ids
                        admin_ids.append(user.id)
                        admin_emails_added.append(email)
                        logger.debug(f"Added user {user.id} ({email}) as admin")
                    else:
                        logger.debug(f"User {user.id} ({email}) already added to admins list")
                else:
                    logger.warning(f"Admin with email {email} not found")
            logger.debug(f"Added {len(admin_emails_added)} users as admins from query params: {admin_emails_added}")

        if user_ids:
            group.user_ids = user_ids
            logger.debug(f"Group has {len(user_ids)} members: {user_ids}")
        if admin_ids:
            group.admin_ids = admin_ids
            logger.debug(f"Group has {len(admin_ids)} admins: {admin_ids}")

        # Persist the group, then re-read it so the response reflects DB state.
        logger.debug("Creating group in database...")
        group_id = await db.create_user_group(cursor, group)
        logger.info(f"Created user group {group_id} with name '{group.name}', {len(user_ids)} users, and {len(admin_ids)} admins")

        logger.debug(f"Retrieving created group {group_id} from database...")
        created_group = await db.read_user_group(cursor, group_id, fetch_details=include_users)
        if not created_group:
            logger.error(f"Failed to retrieve newly created group with ID {group_id}")
            raise exceptions.not_found

        logger.debug(f"Successfully retrieved created group {group_id}")
        return created_group
    except Exception as e:
        logger.error(f"Error creating user group: {str(e)}")
        logger.exception("Detailed traceback for creating user group:")
        bugsnag.notify(
            e,
            metadata={
                # NOTE(review): wrapper key for url/method is elided in the
                # source diff hunk — "request" inferred from sibling handlers;
                # confirm against the full file.
                "request": {
                    "url": str(request.url),
                    "method": request.method,
                },
                "group_data": {
                    "name": group_data.name,
                    "users_count": len(group_data.users) if group_data.users else 0,
                    "users": group_data.users if group_data.users else [],
                    "admins_count": len(admins) if admins else 0,
                    "admins": admins if admins else [],
                    "current_user_id": current_user.id if current_user else None,
                    "current_user_email": current_user.email if current_user else None,
                },
            },
        )
        raise
@router.get("/{group_id}", response_model=Union[UserGroup, UserGroupWithUsers, UserGroupComplete], dependencies=[Depends(require_admin)])
async def read_user_group(
    request: Request,
    group_id: Annotated[int, Path(description="The ID of the user group to retrieve")],
    cursor: AsyncCursor = Depends(db.yield_cursor),
    include_users: bool = Query(True, description="If true, includes detailed user and signal information (defaults to true)"),
):
    """
    Get a user group by ID with detailed information.

    By default, includes detailed user and signal information. Set the
    `include_users` parameter to false to get only the basic group data
    without user and signal details.

    Raises:
        HTTPException: ``exceptions.not_found`` when the group does not exist.
    """
    logger.info(f"Endpoint called: read_user_group - URL: {request.url} - Method: {request.method}")
    logger.debug(f"Reading user group ID: {group_id}")
    logger.debug(f"Query parameters: include_users={include_users}")

    try:
        logger.debug(f"Fetching group {group_id} from database...")
        if (group := await db.read_user_group(cursor, group_id, fetch_details=include_users)) is None:
            logger.warning(f"User group {group_id} not found")
            raise exceptions.not_found

        # Keep the INFO line short; details go to DEBUG to avoid huge logs.
        logger.info(f"Retrieved user group {group_id} with name '{group.name}'")
        logger.debug(f"Group details - ID: {group.id}, Name: '{group.name}'")
        logger.debug(f"Group members: {len(group.user_ids) if group.user_ids else 0} users")
        logger.debug(f"Group admins: {len(group.admin_ids) if group.admin_ids else 0} users")
        logger.debug(f"Group signals: {len(group.signal_ids) if group.signal_ids else 0} signals")

        if group.collaborator_map:
            logger.debug(f"Group has {len(group.collaborator_map)} signals with collaborators")
            total_collaborators = sum(len(collaborators) for collaborators in group.collaborator_map.values())
            logger.debug(f"Total collaborator assignments: {total_collaborators}")

        # When details are requested and the group actually has members and
        # signals, expand the signal IDs into full signal objects.
        if include_users and getattr(group, "user_ids", None) and getattr(group, "signal_ids", None):
            logger.debug(f"Fetching detailed signals data for group {group_id}")
            # Local import, presumably to avoid a circular dependency at
            # module load time — confirm before hoisting to the top of file.
            from ..database.signals import read_signal

            signals = []
            for signal_id in group.signal_ids:
                logger.debug(f"Fetching signal {signal_id}")
                signal = await read_signal(cursor, signal_id)
                if signal:
                    signals.append(signal)
                else:
                    logger.warning(f"Signal {signal_id} referenced by group {group_id} not found")
            logger.debug(f"Successfully fetched {len(signals)} signals for group {group_id}")

            # Only build the complete view when user details were hydrated.
            if getattr(group, "users", None):
                logger.debug(f"Creating UserGroupComplete with {len(signals)} signals and {len(group.users)} users")
                return UserGroupComplete(
                    **group.model_dump(),
                    signals=signals,
                )
            logger.debug("Group lacks user details, returning without creating UserGroupComplete")
            return group

        # BUG FIX: the previous version fell through here without returning,
        # so the endpoint yielded None whenever the detailed branch was
        # skipped (e.g. include_users=false or no members/signals).
        return group
    except Exception as e:
        if not isinstance(e, HTTPException):  # Don't log HTTPExceptions
            logger.error(f"Error reading user group {group_id}: {str(e)}")
            logger.exception(f"Detailed traceback for reading user group {group_id}:")
            bugsnag.notify(
                e,
                metadata={
                    # NOTE(review): wrapper key for url/method is elided in the
                    # source diff hunk — "request" inferred; confirm.
                    "request": {
                        "url": str(request.url),
                        "method": request.method,
                    },
                    "group_id": group_id,
                    "query_params": {
                        "include_users": include_users,
                    },
                },
            )
        raise
@router.put("/{group_id}", response_model=Union[UserGroup, UserGroupWithUsers, UserGroupComplete], dependencies=[Depends(require_admin)])
async def update_user_group(
    request: Request,
    group_id: Annotated[int, Path(description="The ID of the user group to update")],
    group_data: UserGroupUpdate,
    cursor: AsyncCursor = Depends(db.yield_cursor),
    include_users: bool = Query(False, description="If true, includes detailed user information for each group member after update"),
):
    """
    Update a user group.

    Optionally include detailed user information for each group member in the
    response by setting the `include_users` query parameter to true.

    This endpoint accepts both user IDs (integers) and email addresses
    (strings) in the user_ids field. Email addresses will be automatically
    converted to user IDs.

    Raises:
        HTTPException: 400 for an unparseable user identifier, 404 for an
            unknown email, plus ``exceptions.id_mismatch`` /
            ``exceptions.not_found`` for inconsistent or missing groups.
    """
    logger.info(f"Endpoint called: update_user_group - URL: {request.url} - Method: {request.method}")
    logger.debug(f"Updating user group ID: {group_id}")
    # BUG FIX: guard against user_ids/signal_ids being None. The function's
    # own bugsnag metadata guards both with `if ... else 0`, which shows they
    # can be None, yet this log line previously called len() unconditionally
    # and would raise TypeError on entry.
    logger.debug(
        f"Update data: name='{group_data.name}', "
        f"{len(group_data.user_ids or [])} users, "
        f"{len(group_data.signal_ids or [])} signals"
    )
    logger.debug(f"Query parameters: include_users={include_users}")

    try:
        # The path ID and body ID must agree.
        if group_id != group_data.id:
            logger.warning(f"ID mismatch: path ID {group_id} != body ID {group_data.id}")
            raise exceptions.id_mismatch

        # user_ids may mix integer IDs, digit strings and email addresses;
        # normalise everything to integer IDs.
        processed_user_ids = []
        email_conversions = []
        incoming_ids = group_data.user_ids or []  # BUG FIX: tolerate None
        logger.debug(f"Processing {len(incoming_ids)} user identifiers...")
        for user_id in incoming_ids:
            if isinstance(user_id, str):
                if '@' in user_id:
                    # Looks like an email address: resolve it to a user ID.
                    logger.debug(f"Looking up user by email: {user_id}")
                    user = await db.read_user_by_email(cursor, user_id)
                    if user and user.id:
                        processed_user_ids.append(user.id)
                        email_conversions.append((user_id, user.id))
                        logger.debug(f"Converted email {user_id} to user ID {user.id}")
                    else:
                        logger.warning(f"User with email {user_id} not found")
                        raise HTTPException(status_code=404, detail=f"User with email {user_id} not found")
                else:
                    # String but not an email: accept digit strings only.
                    try:
                        numeric_id = int(user_id)
                        processed_user_ids.append(numeric_id)
                        logger.debug(f"Converted string '{user_id}' to integer {numeric_id}")
                    except (ValueError, TypeError):
                        logger.warning(f"Invalid user ID format: {user_id}")
                        raise HTTPException(status_code=400, detail=f"Invalid user ID format: {user_id}")
            else:
                # Already an integer ID.
                processed_user_ids.append(user_id)
                logger.debug(f"Using integer user ID {user_id} as is")

        logger.debug(f"Processed {len(processed_user_ids)} user IDs: {processed_user_ids}")
        if email_conversions:
            logger.debug(f"Converted {len(email_conversions)} emails to user IDs: {email_conversions}")

        if group_data.collaborator_map:
            logger.debug(f"Group has {len(group_data.collaborator_map)} signals with collaborators")
            total_collaborators = sum(len(collaborators) for collaborators in group_data.collaborator_map.values())
            logger.debug(f"Total collaborator assignments: {total_collaborators}")
            for signal_id, collaborators in group_data.collaborator_map.items():
                logger.debug(f"Signal {signal_id} has {len(collaborators)} collaborators: {collaborators}")

        # Convert the update payload into a UserGroup entity for persistence.
        logger.debug("Creating UserGroup entity from update data...")
        group = UserGroup(
            id=group_data.id,
            name=group_data.name,
            signal_ids=group_data.signal_ids,
            user_ids=processed_user_ids,  # use the normalised user IDs
            collaborator_map=group_data.collaborator_map,
            created_at=group_data.created_at,
            status=group_data.status,
            created_by=group_data.created_by,
            created_for=group_data.created_for,
            modified_at=group_data.modified_at,
            modified_by=group_data.modified_by,
            headline=group_data.headline,
            attachment=group_data.attachment,
            steep_primary=group_data.steep_primary,
            steep_secondary=group_data.steep_secondary,
            signature_primary=group_data.signature_primary,
            signature_secondary=group_data.signature_secondary,
            sdgs=group_data.sdgs,
        )
        logger.debug("UserGroup entity created successfully")

        logger.debug(f"Updating group {group_id} in database...")
        if (updated_id := await db.update_user_group(cursor, group)) is None:
            logger.warning(f"User group {group_id} not found for update")
            raise exceptions.not_found

        logger.info(f"Successfully updated user group {updated_id}")
        logger.debug(f"Fetching updated group {updated_id} from database...")
        updated_group = await db.read_user_group(cursor, updated_id, fetch_details=include_users)
        logger.debug(f"Successfully retrieved updated group {updated_id}")
        return updated_group
    except Exception as e:
        if not isinstance(e, HTTPException):  # Don't log HTTPExceptions
            logger.error(f"Error updating user group {group_id}: {str(e)}")
            logger.exception(f"Detailed traceback for updating user group {group_id}:")
            # NOTE(review): the flattened diff makes the bugsnag call's
            # indentation ambiguous; assumed inside the HTTPException guard so
            # expected 4xx responses are not reported — confirm.
            bugsnag.notify(
                e,
                metadata={
                    # NOTE(review): wrapper key elided in source diff — confirm.
                    "request": {
                        "url": str(request.url),
                        "method": request.method,
                    },
                    "group_id": group_id,
                    "group_data": {
                        "id": group_data.id,
                        "name": group_data.name,
                        "user_count": len(group_data.user_ids) if group_data.user_ids else 0,
                        "signal_count": len(group_data.signal_ids) if group_data.signal_ids else 0,
                        "collaborator_map_size": len(group_data.collaborator_map) if group_data.collaborator_map else 0,
                    },
                    "query_params": {
                        "include_users": include_users,
                    },
                },
            )
        raise
logger.debug(f"Updating group {group_id} in database...") if (updated_id := await db.update_user_group(cursor, group)) is None: logger.warning(f"User group {group_id} not found for update") raise exceptions.not_found - logger.info(f"Updated user group {updated_id}") - return await db.read_user_group(cursor, updated_id) + + logger.info(f"Successfully updated user group {updated_id}") + + # Fetch the updated group + logger.debug(f"Fetching updated group {updated_id} from database...") + updated_group = await db.read_user_group(cursor, updated_id, fetch_details=include_users) + logger.debug(f"Successfully retrieved updated group {updated_id}") + + return updated_group except Exception as e: if not isinstance(e, HTTPException): # Don't log HTTPExceptions logger.error(f"Error updating user group {group_id}: {str(e)}") + logger.exception(f"Detailed traceback for updating user group {group_id}:") bugsnag.notify( e, metadata={ @@ -281,10 +640,14 @@ async def update_user_group( }, "group_id": group_id, "group_data": { - "id": group.id, - "name": group.name, - "user_count": len(group.user_ids) if group.user_ids else 0, - "signal_count": len(group.signal_ids) if group.signal_ids else 0 + "id": group_data.id, + "name": group_data.name, + "user_count": len(group_data.user_ids) if group_data.user_ids else 0, + "signal_count": len(group_data.signal_ids) if group_data.signal_ids else 0, + "collaborator_map_size": len(group_data.collaborator_map) if group_data.collaborator_map else 0 + }, + "query_params": { + "include_users": include_users } } ) @@ -298,15 +661,29 @@ async def delete_user_group( cursor: AsyncCursor = Depends(db.yield_cursor), ): """Delete a user group.""" + logger.info(f"Endpoint called: delete_user_group - URL: {request.url} - Method: {request.method}") + logger.debug(f"Deleting user group ID: {group_id}") + try: + # First get the group to log what's being deleted + group = await db.read_user_group(cursor, group_id) + if group: + logger.debug(f"Found group to 
delete: {group.id} - '{group.name}'") + logger.debug(f"Group contains: {len(group.user_ids) if group.user_ids else 0} members, " + + f"{len(group.signal_ids) if group.signal_ids else 0} signals, " + + f"{len(group.collaborator_map) if group.collaborator_map else 0} signal collaborator maps") + + logger.debug(f"Deleting group {group_id} from database...") if not await db.delete_user_group(cursor, group_id): logger.warning(f"User group {group_id} not found for deletion") raise exceptions.not_found - logger.info(f"Deleted user group {group_id}") + + logger.info(f"Successfully deleted user group {group_id}") return True except Exception as e: if not isinstance(e, HTTPException): # Don't log HTTPExceptions logger.error(f"Error deleting user group {group_id}: {str(e)}") + logger.exception(f"Detailed traceback for deleting user group {group_id}:") bugsnag.notify( e, metadata={ @@ -554,32 +931,52 @@ async def add_collaborator_to_signal_in_group( This endpoint accepts either a numeric user ID or an email address. If an email is provided, the system will look up the corresponding user ID. 
""" + logger.info(f"Endpoint called: add_collaborator_to_signal_in_group - URL: {request.url} - Method: {request.method}") + logger.debug(f"Adding collaborator to group {group_id} for signal {signal_id}") + logger.debug(f"User identifier provided: {user_id_or_email}") + try: # Try to parse as int for backward compatibility + user_id = None + user_email = None + try: user_id = int(user_id_or_email) + logger.debug(f"User identifier is numeric: {user_id}") except ValueError: # Not an integer, treat as email - user = await db.read_user_by_email(cursor, user_id_or_email) + user_email = user_id_or_email + logger.debug(f"User identifier is an email: {user_email}") + + user = await db.read_user_by_email(cursor, user_email) if not user or not user.id: - logger.warning(f"User with email {user_id_or_email} not found") - raise HTTPException(status_code=404, detail=f"User with email {user_id_or_email} not found") + logger.warning(f"User with email {user_email} not found") + raise HTTPException(status_code=404, detail=f"User with email {user_email} not found") + user_id = user.id + logger.debug(f"Resolved email {user_email} to user ID {user_id}") # Get the group + logger.debug(f"Fetching group {group_id}...") group = await db.read_user_group(cursor, group_id) if group is None: logger.warning(f"Group {group_id} not found") raise exceptions.not_found + logger.debug(f"Successfully retrieved group '{group.name}' (ID: {group.id})") + # Check if signal is in the group signal_ids = group.signal_ids or [] + logger.debug(f"Group has {len(signal_ids)} signals: {signal_ids}") + if signal_id not in signal_ids: logger.warning(f"Signal {signal_id} not in group {group_id}") raise exceptions.not_found # Check if user is in the group user_ids = group.user_ids or [] + logger.debug(f"Group has {len(user_ids)} members: {user_ids}") + if user_id not in user_ids: logger.warning(f"User {user_id} not in group {group_id}") raise exceptions.not_found @@ -587,25 +984,32 @@ async def 
add_collaborator_to_signal_in_group( # Add collaborator collaborator_map = group.collaborator_map or {} signal_key = str(signal_id) + + logger.debug(f"Current collaborator map: {collaborator_map}") + if signal_key not in collaborator_map: + logger.debug(f"Creating new collaborator entry for signal {signal_id}") collaborator_map[signal_key] = [] if user_id not in collaborator_map[signal_key]: + logger.debug(f"Adding user {user_id} as collaborator for signal {signal_id}") collaborator_map[signal_key].append(user_id) group.collaborator_map = collaborator_map + logger.debug(f"Updating group {group_id} with new collaborator map") if await db.update_user_group(cursor, group) is None: logger.error(f"Failed to update group {group_id}") raise exceptions.not_found - logger.info(f"Added user {user_id} as collaborator for signal {signal_id} in group {group_id}") + logger.info(f"Successfully added user {user_id} as collaborator for signal {signal_id} in group {group_id}") else: - logger.info(f"User {user_id} already a collaborator for signal {signal_id} in group {group_id}") + logger.info(f"User {user_id} is already a collaborator for signal {signal_id} in group {group_id}") return True except Exception as e: if not isinstance(e, HTTPException): # Don't log HTTPExceptions logger.error(f"Error adding collaborator: {str(e)}") + logger.exception(f"Detailed traceback for adding collaborator to group {group_id}, signal {signal_id}:") bugsnag.notify( e, metadata={ @@ -615,7 +1019,8 @@ async def add_collaborator_to_signal_in_group( }, "group_id": group_id, "signal_id": signal_id, - "user_id_or_email": user_id_or_email + "user_id_or_email": user_id_or_email, + "resolved_user_id": user_id if 'user_id' in locals() else None } ) raise