Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"urllib3==2.4.0",
"validate-email==1.3",
"websockets==14.2",
"yfinance>=0.2.61",
]

[dependency-groups]
Expand Down
181 changes: 180 additions & 1 deletion src/tests/test_stock_usecase.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from datetime import datetime, timezone
from unittest.mock import Mock, ANY
from unittest.mock import Mock, ANY, patch
from usecase.stock import StockUsecase
from domain.stock import CreateStock, ActionType, Stock
from domain.portfolio import Portfolio, Holding
Expand Down Expand Up @@ -410,3 +410,182 @@ def test_list_handles_repository_error(self, stock_usecase):
with pytest.raises(Exception, match="Repository error"):
usecase.list(user_id)
mock_repo.list.assert_called_once_with(user_id)

def test_calculate_total_roi_no_portfolio(self, stock_usecase):
# Arrange
usecase, _, portfolio_repo = stock_usecase
user_id = 1
portfolio_repo.get.return_value = None

# Act
result = usecase.calculate_total_roi(user_id)

# Assert
portfolio_repo.get.assert_called_once_with(user_id=user_id)
assert result == 0.0

def test_calculate_total_roi_no_total_money_in(self, stock_usecase):
# Arrange
usecase, _, portfolio_repo = stock_usecase
user_id = 1
portfolio = Portfolio(
user_id=user_id,
cash_balance=0.0,
total_money_in=0.0,
holdings=[],
created_at=ANY,
updated_at=ANY,
)
portfolio_repo.get.return_value = portfolio

# Act
result = usecase.calculate_total_roi(user_id)

# Assert
portfolio_repo.get.assert_called_once_with(user_id=user_id)
assert result == 0.0

def test_calculate_total_roi_no_holdings(self, stock_usecase):
# Arrange
usecase, _, portfolio_repo = stock_usecase
user_id = 1
portfolio = Portfolio(
user_id=user_id,
cash_balance=1000.0,
total_money_in=1000.0,
holdings=[],
created_at=ANY,
updated_at=ANY,
)
portfolio_repo.get.return_value = portfolio

# Act
result = usecase.calculate_total_roi(user_id)

# Assert
portfolio_repo.get.assert_called_once_with(user_id=user_id)
assert result == 0.0 # ROI = ((1000 - 1000) / 1000) * 100 = 0.0

@patch.object(StockUsecase, "_get_stock_price")
def test_calculate_total_roi_with_holdings(self, mock_get_stock_price, stock_usecase):
# Arrange
usecase, _, portfolio_repo = stock_usecase
user_id = 1
portfolio = Portfolio(
user_id=user_id,
cash_balance=1000.0,
total_money_in=5000.0,
holdings=[
Holding(symbol="AAPL", shares=10, total_cost=1500.0),
Holding(symbol="GOOGL", shares=5, total_cost=2000.0),
Holding(symbol="TSLA", shares=0, total_cost=0.0), # Zero shares, should be ignored
],
created_at=ANY,
updated_at=ANY,
)
portfolio_repo.get.return_value = portfolio
mock_get_stock_price.return_value = {
"AAPL": 200.0, # 10 shares * 200 = 2000
"GOOGL": 3000.0, # 5 shares * 3000 = 15000
}

# Act
result = usecase.calculate_total_roi(user_id)

# Assert
portfolio_repo.get.assert_called_once_with(user_id=user_id)
mock_get_stock_price.assert_called_once_with(symbols=["AAPL", "GOOGL"])
# Total value = 2000 (AAPL) + 15000 (GOOGL) + 1000 (cash) = 18000
# ROI = ((18000 - 5000) / 5000) * 100 = 260.0
assert result == 260.0

@patch.object(StockUsecase, "_get_stock_price")
def test_calculate_total_roi_with_missing_prices(self, mock_get_stock_price, stock_usecase):
# Arrange
usecase, _, portfolio_repo = stock_usecase
user_id = 1
portfolio = Portfolio(
user_id=user_id,
cash_balance=1000.0,
total_money_in=5000.0,
holdings=[
Holding(symbol="AAPL", shares=10, total_cost=1500.0),
Holding(symbol="INVALID", shares=5, total_cost=2000.0),
],
created_at=ANY,
updated_at=ANY,
)
portfolio_repo.get.return_value = portfolio
mock_get_stock_price.return_value = {
"AAPL": 200.0, # 10 shares * 200 = 2000
"INVALID": 0.0, # No price available
}

# Act
result = usecase.calculate_total_roi(user_id)

# Assert
portfolio_repo.get.assert_called_once_with(user_id=user_id)
mock_get_stock_price.assert_called_once_with(symbols=["AAPL", "INVALID"])
# Total value = 2000 (AAPL) + 0 (INVALID) + 1000 (cash) = 3000
# ROI = ((3000 - 5000) / 5000) * 100 = -40.0
assert result == -40.0

def test_calculate_total_roi_handles_repository_error(self, stock_usecase):
# Arrange
usecase, _, portfolio_repo = stock_usecase
user_id = 1
portfolio_repo.get.side_effect = Exception("Portfolio repository error")

# Act/Assert
with pytest.raises(Exception, match="Portfolio repository error"):
usecase.calculate_total_roi(user_id)
portfolio_repo.get.assert_called_once_with(user_id=user_id)

@patch("usecase.stock.yf.Tickers")
def test_get_stock_price_success(self, mock_yf_tickers, stock_usecase):
# Arrange
usecase, _, _ = stock_usecase
symbols = ["AAPL", "GOOGL"]
mock_ticker_aapl = Mock()
mock_ticker_aapl.info = {"currentPrice": 150.0}
mock_ticker_googl = Mock()
mock_ticker_googl.info = {"currentPrice": 2800.0}
mock_yf_tickers.return_value.tickers = {
"AAPL": mock_ticker_aapl,
"GOOGL": mock_ticker_googl,
}

# Act
result = usecase._get_stock_price(symbols)

# Assert
mock_yf_tickers.assert_called_once_with(symbols)
assert result == {"AAPL": 150.0, "GOOGL": 2800.0}

@patch("usecase.stock.yf.Tickers")
def test_get_stock_price_empty_symbols(self, mock_yf_tickers, stock_usecase):
# Arrange
usecase, _, _ = stock_usecase
symbols = []

# Act
result = usecase._get_stock_price(symbols)

# Assert
mock_yf_tickers.assert_not_called()
assert result == {}

@patch("usecase.stock.yf.Tickers")
def test_get_stock_price_handles_api_error(self, mock_yf_tickers, stock_usecase):
# Arrange
usecase, _, _ = stock_usecase
symbols = ["AAPL", "GOOGL"]
mock_yf_tickers.side_effect = Exception("API error")

# Act
result = usecase._get_stock_price(symbols)

# Assert
mock_yf_tickers.assert_called_once_with(symbols)
assert result == {"AAPL": 0.0, "GOOGL": 0.0}
43 changes: 42 additions & 1 deletion src/usecase/stock.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List
from typing import List, Dict
from datetime import datetime, timezone
import yfinance as yf
from domain.stock import CreateStock, Stock, ActionType
from domain.portfolio import Portfolio, Holding
from adapters.base import AbstractStockRepository, AbstractPortfolioRepository
Expand Down Expand Up @@ -63,3 +64,43 @@ def create(self, stock: CreateStock) -> str:

def list(self, user_id: int) -> List[Stock]:
return self.stock_repo.list(user_id)

def calculate_total_roi(self, user_id: int) -> float:
portfolio = self.portfolio_repo.get(user_id=user_id)
if portfolio is None or portfolio.total_money_in == 0.0:
return 0.0

valid_holdings = [(holding.symbol, holding.shares) for holding in portfolio.holdings if holding.shares > 0]
if not valid_holdings:
# If no valid holdings, ROI depends only on cash balance
return round(((portfolio.cash_balance - portfolio.total_money_in) / portfolio.total_money_in) * 100, 2)

# Fetch prices in batch
stock_symbols = [symbol for symbol, _ in valid_holdings]
stock_price_by_symbol = self._get_stock_price(symbols=stock_symbols)

# Calculate total stock value
total_stock_price = sum(shares * stock_price_by_symbol.get(symbol, 0.0) for symbol, shares in valid_holdings)

# Compute ROI
total_value = total_stock_price + portfolio.cash_balance
roi = ((total_value - portfolio.total_money_in) / portfolio.total_money_in) * 100
return round(roi, 2)

def _get_stock_price(self, symbols: List[str]) -> Dict[str, float]:
if not symbols:
return {}

try:
tickers = yf.Tickers(symbols)
return {
symbol: (
ticker.info.get("currentPrice", 0.0)
if (ticker := tickers.tickers.get(symbol.upper())) is not None
else 0.0
)
for symbol in symbols
}
except Exception as e:
print(f"Error fetching prices for symbols {symbols}: {e}")
return {symbol.upper(): 0.0 for symbol in symbols}
Loading