diff --git a/src/domain/portfolio.py b/src/domain/portfolio.py index dae1006..a5d835e 100644 --- a/src/domain/portfolio.py +++ b/src/domain/portfolio.py @@ -57,3 +57,11 @@ def __post_init__(self): def as_dict(self) -> PortfolioDict: return asdict(self, dict_factory=custom_dict_factory) + + +@dataclass +class PortfolioInfo: + user_id: int + total_portfolio_value: float + total_gain: float + roi: float diff --git a/src/tests/test_stock_usecase.py b/src/tests/test_stock_usecase.py index 4cf71ab..521f662 100644 --- a/src/tests/test_stock_usecase.py +++ b/src/tests/test_stock_usecase.py @@ -1,9 +1,8 @@ import pytest -from datetime import datetime, timezone from unittest.mock import Mock, ANY, patch from usecase.stock import StockUsecase from domain.stock import CreateStock, Stock -from domain.portfolio import Portfolio, Holding +from domain.portfolio import Portfolio, Holding, PortfolioInfo from domain.enum import ActionType, StockType @@ -15,7 +14,7 @@ def stock_usecase(): return usecase, stock_repo, portfolio_repo -class TestStockUsecase: +class TestStockUsecaseCreate: def test_create_transfer_new_portfolio(self, stock_usecase): # Arrange usecase, stock_repo, portfolio_repo = stock_usecase @@ -375,6 +374,8 @@ def test_create_handles_repository_error_on_stock_create(self, stock_usecase): portfolio_repo.update.assert_called_once_with(portfolio=ANY) # Portfolio may vary, so use ANY stock_repo.create.assert_called_once_with(stock) + +class TestStockUsecaseList: def test_list(self, stock_usecase): # Arrange usecase, mock_repo, _ = stock_usecase @@ -427,143 +428,8 @@ def test_list_handles_repository_error(self, stock_usecase): usecase.list(user_id) mock_repo.list.assert_called_once_with(user_id) - def test_calculate_total_roi_no_portfolio(self, stock_usecase): - # Arrange - usecase, _, portfolio_repo = stock_usecase - user_id = 1 - portfolio_repo.get.return_value = None - - # Act - result = usecase.calculate_total_roi(user_id) - - # Assert - portfolio_repo.get.assert_called_once_with(user_id=user_id) - assert result == 0.0 - - def test_calculate_total_roi_no_total_money_in(self, stock_usecase): - # Arrange - usecase, _, portfolio_repo = stock_usecase - user_id = 1 - portfolio = Portfolio( - user_id=user_id, - cash_balance=0.0, - total_money_in=0.0, - holdings=[], - created_at=ANY, - updated_at=ANY, - ) - portfolio_repo.get.return_value = portfolio - - # Act - result = usecase.calculate_total_roi(user_id) - - # Assert - portfolio_repo.get.assert_called_once_with(user_id=user_id) - assert result == 0.0 - - def test_calculate_total_roi_no_holdings(self, stock_usecase): - # Arrange - usecase, _, portfolio_repo = stock_usecase - user_id = 1 - portfolio = Portfolio( - user_id=user_id, - cash_balance=1000.0, - total_money_in=1000.0, - holdings=[], - created_at=ANY, - updated_at=ANY, - ) - portfolio_repo.get.return_value = portfolio - - # Act - result = usecase.calculate_total_roi(user_id) - - # Assert - portfolio_repo.get.assert_called_once_with(user_id=user_id) - assert result == 0.0 # ROI = ((1000 - 1000) / 1000) * 100 = 0.0 - - @patch.object(StockUsecase, "_get_stock_price") - def test_calculate_total_roi_with_holdings(self, mock_get_stock_price, stock_usecase): - # Arrange - usecase, _, portfolio_repo = stock_usecase - user_id = 1 - portfolio = Portfolio( - user_id=user_id, - cash_balance=1000.0, - total_money_in=5000.0, - holdings=[ - Holding(symbol="AAPL", shares=10, stock_type=StockType.STOCKS, total_cost=1500.0), - Holding(symbol="GOOGL", shares=5, stock_type=StockType.STOCKS, total_cost=2000.0), - Holding( - symbol="TSLA", shares=0, stock_type=StockType.STOCKS, total_cost=0.0 - ), # Zero shares, should be ignored - ], - created_at=ANY, - updated_at=ANY, - ) - portfolio_repo.get.return_value = portfolio - mock_get_stock_price.return_value = { - "AAPL": 200.0, # 10 shares * 200 = 2000 - "GOOGL": 3000.0, # 5 shares * 3000 = 15000 - } - - # Act - result = usecase.calculate_total_roi(user_id) - - # Assert - portfolio_repo.get.assert_called_once_with(user_id=user_id) - mock_get_stock_price.assert_called_once_with( - stock_info=[("AAPL", StockType.STOCKS), ("GOOGL", StockType.STOCKS)] - ) - # Total value = 2000 (AAPL) + 15000 (GOOGL) + 1000 (cash) = 18000 - # ROI = ((18000 - 5000) / 5000) * 100 = 260.0 - assert result == 260.0 - - @patch.object(StockUsecase, "_get_stock_price") - def test_calculate_total_roi_with_missing_prices(self, mock_get_stock_price, stock_usecase): - # Arrange - usecase, _, portfolio_repo = stock_usecase - user_id = 1 - portfolio = Portfolio( - user_id=user_id, - cash_balance=1000.0, - total_money_in=5000.0, - holdings=[ - Holding(symbol="AAPL", shares=10, stock_type=StockType.STOCKS, total_cost=1500.0), - Holding(symbol="INVALID", shares=5, stock_type=StockType.STOCKS, total_cost=2000.0), - ], - created_at=ANY, - updated_at=ANY, - ) - portfolio_repo.get.return_value = portfolio - mock_get_stock_price.return_value = { - "AAPL": 200.0, # 10 shares * 200 = 2000 - "INVALID": 0.0, # No price available - } - - # Act - result = usecase.calculate_total_roi(user_id) - - # Assert - portfolio_repo.get.assert_called_once_with(user_id=user_id) - mock_get_stock_price.assert_called_once_with( - stock_info=[("AAPL", StockType.STOCKS), ("INVALID", StockType.STOCKS)] - ) - # Total value = 2000 (AAPL) + 0 (INVALID) + 1000 (cash) = 3000 - # ROI = ((3000 - 5000) / 5000) * 100 = -40.0 - assert result == -40.0 - - def test_calculate_total_roi_handles_repository_error(self, stock_usecase): - # Arrange - usecase, _, portfolio_repo = stock_usecase - user_id = 1 - portfolio_repo.get.side_effect = Exception("Portfolio repository error") - - # Act/Assert - with pytest.raises(Exception, match="Portfolio repository error"): - usecase.calculate_total_roi(user_id) - portfolio_repo.get.assert_called_once_with(user_id=user_id) +class TestStockUsecaseGetStockPrice: @patch("usecase.stock.yf.Tickers") def test_get_stock_price_success_stocks(self, mock_yf_tickers, stock_usecase): # Arrange @@ -690,3 +556,107 @@ def test_get_stock_price_missing_price_field(self, mock_yf_tickers, stock_usecas # Assert mock_yf_tickers.assert_called_once_with(["AAPL", "SPY"]) assert result == {"AAPL": 0.0, "SPY": 0.0} + + +class TestStockUsecaseGetPortfolioInfo: + @patch.object(StockUsecase, "_get_stock_price") + def test_get_portfolio_info_no_portfolio(self, mock_get_stock_price, stock_usecase): + # Arrange + usecase, _, portfolio_repo = stock_usecase + user_id = 1 + portfolio_repo.get.return_value = None + expected_result = PortfolioInfo(user_id=user_id, total_portfolio_value=0.0, total_gain=0.0, roi=0.0) + + # Act + result = usecase.get_portfolio_info(user_id) + + # Assert + portfolio_repo.get.assert_called_once_with(user_id=user_id) + mock_get_stock_price.assert_not_called() + assert result == expected_result + + @patch.object(StockUsecase, "_get_stock_price") + def test_get_portfolio_info_empty_portfolio(self, mock_get_stock_price, stock_usecase): + # Arrange + usecase, _, portfolio_repo = stock_usecase + user_id = 1 + portfolio = Portfolio( + user_id=user_id, + cash_balance=1000.0, + total_money_in=0.0, + holdings=[], + created_at=ANY, + updated_at=ANY, + ) + portfolio_repo.get.return_value = portfolio + expected_result = PortfolioInfo(user_id=user_id, total_portfolio_value=0.0, total_gain=0.0, roi=0.0) + + # Act + result = usecase.get_portfolio_info(user_id) + + # Assert + portfolio_repo.get.assert_called_once_with(user_id=user_id) + mock_get_stock_price.assert_not_called() + assert result == expected_result + + @patch.object(StockUsecase, "_get_stock_price") + def test_get_portfolio_info_no_valid_holdings(self, mock_get_stock_price, stock_usecase): + # Arrange + usecase, _, portfolio_repo = stock_usecase + user_id = 1 + portfolio = Portfolio( + user_id=user_id, + cash_balance=1000.0, + total_money_in=2000.0, + holdings=[Holding(symbol="AAPL", shares=0, stock_type=StockType.STOCKS, total_cost=0.0)], + created_at=ANY, + updated_at=ANY, + ) + portfolio_repo.get.return_value = portfolio + expected_result = PortfolioInfo( + user_id=user_id, + total_portfolio_value=1000.0, + total_gain=-1000.0, + roi=-50.0, + ) + + # Act + result = usecase.get_portfolio_info(user_id) + + # Assert + portfolio_repo.get.assert_called_once_with(user_id=user_id) + mock_get_stock_price.assert_not_called() + assert result == expected_result + + @patch.object(StockUsecase, "_get_stock_price") + def test_get_portfolio_info_with_valid_holdings(self, mock_get_stock_price, stock_usecase): + # Arrange + usecase, _, portfolio_repo = stock_usecase + user_id = 1 + portfolio = Portfolio( + user_id=user_id, + cash_balance=1000.0, + total_money_in=2000.0, + holdings=[ + Holding(symbol="AAPL", shares=10, stock_type=StockType.STOCKS, total_cost=1500.0), + Holding(symbol="SPY", shares=5, stock_type=StockType.ETF, total_cost=1000.0), + ], + created_at=ANY, + updated_at=ANY, + ) + mock_get_stock_price.return_value = {"AAPL": 200.0, "SPY": 400.0} + portfolio_repo.get.return_value = portfolio + expected_result = PortfolioInfo( + user_id=user_id, + total_portfolio_value=5000.0, + total_gain=3000.0, + roi=150, + ) + + # Act + result = usecase.get_portfolio_info(user_id) + + # Assert + portfolio_repo.get.assert_called_once_with(user_id=user_id) + mock_get_stock_price.assert_called_once_with(stock_info=[("AAPL", StockType.STOCKS), ("SPY", StockType.ETF)]) + assert result == expected_result diff --git a/src/usecase/stock.py b/src/usecase/stock.py index 83b0c89..c8a4964 100644 --- a/src/usecase/stock.py +++ b/src/usecase/stock.py @@ -3,7 +3,7 @@ import yfinance as yf from .base import AbstractStockUsecase from adapters.base import AbstractStockRepository, AbstractPortfolioRepository -from domain.portfolio import Portfolio, Holding +from domain.portfolio import Portfolio, Holding, PortfolioInfo from domain.stock import CreateStock, Stock from domain.enum import ActionType, StockType @@ -69,17 +69,23 @@ def create(self, stock: CreateStock) -> str: def list(self, user_id: int) -> List[Stock]: return self.stock_repo.list(user_id) - def calculate_total_roi(self, user_id: int) -> float: + def get_portfolio_info(self, user_id: int) -> PortfolioInfo: portfolio = self.portfolio_repo.get(user_id=user_id) if portfolio is None or portfolio.total_money_in == 0.0: - return 0.0 + return PortfolioInfo(user_id=user_id, total_portfolio_value=0.0, total_gain=0.0, roi=0.0) valid_holdings = [ (holding.symbol, holding.shares, holding.stock_type) for holding in portfolio.holdings if holding.shares > 0 ] if not valid_holdings: - # If no valid holdings, ROI depends only on cash balance - return round(((portfolio.cash_balance - portfolio.total_money_in) / portfolio.total_money_in) * 100, 2) + return PortfolioInfo( + user_id=user_id, + total_portfolio_value=portfolio.cash_balance, + total_gain=portfolio.cash_balance - portfolio.total_money_in, + roi=round( + ((portfolio.cash_balance - portfolio.total_money_in) / portfolio.total_money_in) * 100, 2 + ), # If no valid holdings, ROI depends only on cash balance + ) # Fetch prices in batch stock_info = [(symbol, stock_type) for symbol, _, stock_type in valid_holdings] @@ -90,8 +96,14 @@ def calculate_total_roi(self, user_id: int) -> float: # Compute ROI total_value = total_stock_price + portfolio.cash_balance - roi = ((total_value - portfolio.total_money_in) / portfolio.total_money_in) * 100 - return round(roi, 2) + roi = round(((total_value - portfolio.total_money_in) / portfolio.total_money_in) * 100, 2) + + return PortfolioInfo( + user_id=user_id, + total_portfolio_value=total_value, + total_gain=total_value - portfolio.total_money_in, + roi=roi, + ) def _get_stock_price(self, stock_info: List[Tuple[str, StockType]]) -> Dict[str, float]: if not stock_info: