Skip to content

Commit 41f2310

Browse files
eyo-chenEyo Chen
andauthored
Feat: get total roi usecase (#17)
* feat: add yfinance * feat: add get total roi * test: add unit testing --------- Co-authored-by: Eyo Chen <[email protected]>
1 parent 38e2f29 commit 41f2310

File tree

4 files changed

+783
-3
lines changed

4 files changed

+783
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ dependencies = [
2525
"urllib3==2.4.0",
2626
"validate-email==1.3",
2727
"websockets==14.2",
28+
"yfinance>=0.2.61",
2829
]
2930

3031
[dependency-groups]

src/tests/test_stock_usecase.py

Lines changed: 180 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
from datetime import datetime, timezone
3-
from unittest.mock import Mock, ANY
3+
from unittest.mock import Mock, ANY, patch
44
from usecase.stock import StockUsecase
55
from domain.stock import CreateStock, ActionType, Stock
66
from domain.portfolio import Portfolio, Holding
@@ -410,3 +410,182 @@ def test_list_handles_repository_error(self, stock_usecase):
410410
with pytest.raises(Exception, match="Repository error"):
411411
usecase.list(user_id)
412412
mock_repo.list.assert_called_once_with(user_id)
413+
414+
def test_calculate_total_roi_no_portfolio(self, stock_usecase):
415+
# Arrange
416+
usecase, _, portfolio_repo = stock_usecase
417+
user_id = 1
418+
portfolio_repo.get.return_value = None
419+
420+
# Act
421+
result = usecase.calculate_total_roi(user_id)
422+
423+
# Assert
424+
portfolio_repo.get.assert_called_once_with(user_id=user_id)
425+
assert result == 0.0
426+
427+
def test_calculate_total_roi_no_total_money_in(self, stock_usecase):
428+
# Arrange
429+
usecase, _, portfolio_repo = stock_usecase
430+
user_id = 1
431+
portfolio = Portfolio(
432+
user_id=user_id,
433+
cash_balance=0.0,
434+
total_money_in=0.0,
435+
holdings=[],
436+
created_at=ANY,
437+
updated_at=ANY,
438+
)
439+
portfolio_repo.get.return_value = portfolio
440+
441+
# Act
442+
result = usecase.calculate_total_roi(user_id)
443+
444+
# Assert
445+
portfolio_repo.get.assert_called_once_with(user_id=user_id)
446+
assert result == 0.0
447+
448+
def test_calculate_total_roi_no_holdings(self, stock_usecase):
449+
# Arrange
450+
usecase, _, portfolio_repo = stock_usecase
451+
user_id = 1
452+
portfolio = Portfolio(
453+
user_id=user_id,
454+
cash_balance=1000.0,
455+
total_money_in=1000.0,
456+
holdings=[],
457+
created_at=ANY,
458+
updated_at=ANY,
459+
)
460+
portfolio_repo.get.return_value = portfolio
461+
462+
# Act
463+
result = usecase.calculate_total_roi(user_id)
464+
465+
# Assert
466+
portfolio_repo.get.assert_called_once_with(user_id=user_id)
467+
assert result == 0.0 # ROI = ((1000 - 1000) / 1000) * 100 = 0.0
468+
469+
@patch.object(StockUsecase, "_get_stock_price")
470+
def test_calculate_total_roi_with_holdings(self, mock_get_stock_price, stock_usecase):
471+
# Arrange
472+
usecase, _, portfolio_repo = stock_usecase
473+
user_id = 1
474+
portfolio = Portfolio(
475+
user_id=user_id,
476+
cash_balance=1000.0,
477+
total_money_in=5000.0,
478+
holdings=[
479+
Holding(symbol="AAPL", shares=10, total_cost=1500.0),
480+
Holding(symbol="GOOGL", shares=5, total_cost=2000.0),
481+
Holding(symbol="TSLA", shares=0, total_cost=0.0), # Zero shares, should be ignored
482+
],
483+
created_at=ANY,
484+
updated_at=ANY,
485+
)
486+
portfolio_repo.get.return_value = portfolio
487+
mock_get_stock_price.return_value = {
488+
"AAPL": 200.0, # 10 shares * 200 = 2000
489+
"GOOGL": 3000.0, # 5 shares * 3000 = 15000
490+
}
491+
492+
# Act
493+
result = usecase.calculate_total_roi(user_id)
494+
495+
# Assert
496+
portfolio_repo.get.assert_called_once_with(user_id=user_id)
497+
mock_get_stock_price.assert_called_once_with(symbols=["AAPL", "GOOGL"])
498+
# Total value = 2000 (AAPL) + 15000 (GOOGL) + 1000 (cash) = 18000
499+
# ROI = ((18000 - 5000) / 5000) * 100 = 260.0
500+
assert result == 260.0
501+
502+
@patch.object(StockUsecase, "_get_stock_price")
503+
def test_calculate_total_roi_with_missing_prices(self, mock_get_stock_price, stock_usecase):
504+
# Arrange
505+
usecase, _, portfolio_repo = stock_usecase
506+
user_id = 1
507+
portfolio = Portfolio(
508+
user_id=user_id,
509+
cash_balance=1000.0,
510+
total_money_in=5000.0,
511+
holdings=[
512+
Holding(symbol="AAPL", shares=10, total_cost=1500.0),
513+
Holding(symbol="INVALID", shares=5, total_cost=2000.0),
514+
],
515+
created_at=ANY,
516+
updated_at=ANY,
517+
)
518+
portfolio_repo.get.return_value = portfolio
519+
mock_get_stock_price.return_value = {
520+
"AAPL": 200.0, # 10 shares * 200 = 2000
521+
"INVALID": 0.0, # No price available
522+
}
523+
524+
# Act
525+
result = usecase.calculate_total_roi(user_id)
526+
527+
# Assert
528+
portfolio_repo.get.assert_called_once_with(user_id=user_id)
529+
mock_get_stock_price.assert_called_once_with(symbols=["AAPL", "INVALID"])
530+
# Total value = 2000 (AAPL) + 0 (INVALID) + 1000 (cash) = 3000
531+
# ROI = ((3000 - 5000) / 5000) * 100 = -40.0
532+
assert result == -40.0
533+
534+
def test_calculate_total_roi_handles_repository_error(self, stock_usecase):
535+
# Arrange
536+
usecase, _, portfolio_repo = stock_usecase
537+
user_id = 1
538+
portfolio_repo.get.side_effect = Exception("Portfolio repository error")
539+
540+
# Act/Assert
541+
with pytest.raises(Exception, match="Portfolio repository error"):
542+
usecase.calculate_total_roi(user_id)
543+
portfolio_repo.get.assert_called_once_with(user_id=user_id)
544+
545+
@patch("usecase.stock.yf.Tickers")
546+
def test_get_stock_price_success(self, mock_yf_tickers, stock_usecase):
547+
# Arrange
548+
usecase, _, _ = stock_usecase
549+
symbols = ["AAPL", "GOOGL"]
550+
mock_ticker_aapl = Mock()
551+
mock_ticker_aapl.info = {"currentPrice": 150.0}
552+
mock_ticker_googl = Mock()
553+
mock_ticker_googl.info = {"currentPrice": 2800.0}
554+
mock_yf_tickers.return_value.tickers = {
555+
"AAPL": mock_ticker_aapl,
556+
"GOOGL": mock_ticker_googl,
557+
}
558+
559+
# Act
560+
result = usecase._get_stock_price(symbols)
561+
562+
# Assert
563+
mock_yf_tickers.assert_called_once_with(symbols)
564+
assert result == {"AAPL": 150.0, "GOOGL": 2800.0}
565+
566+
@patch("usecase.stock.yf.Tickers")
567+
def test_get_stock_price_empty_symbols(self, mock_yf_tickers, stock_usecase):
568+
# Arrange
569+
usecase, _, _ = stock_usecase
570+
symbols = []
571+
572+
# Act
573+
result = usecase._get_stock_price(symbols)
574+
575+
# Assert
576+
mock_yf_tickers.assert_not_called()
577+
assert result == {}
578+
579+
@patch("usecase.stock.yf.Tickers")
580+
def test_get_stock_price_handles_api_error(self, mock_yf_tickers, stock_usecase):
581+
# Arrange
582+
usecase, _, _ = stock_usecase
583+
symbols = ["AAPL", "GOOGL"]
584+
mock_yf_tickers.side_effect = Exception("API error")
585+
586+
# Act
587+
result = usecase._get_stock_price(symbols)
588+
589+
# Assert
590+
mock_yf_tickers.assert_called_once_with(symbols)
591+
assert result == {"AAPL": 0.0, "GOOGL": 0.0}

src/usecase/stock.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import List
1+
from typing import List, Dict
22
from datetime import datetime, timezone
3+
import yfinance as yf
34
from domain.stock import CreateStock, Stock, ActionType
45
from domain.portfolio import Portfolio, Holding
56
from adapters.base import AbstractStockRepository, AbstractPortfolioRepository
@@ -63,3 +64,43 @@ def create(self, stock: CreateStock) -> str:
6364

6465
def list(self, user_id: int) -> List[Stock]:
6566
return self.stock_repo.list(user_id)
67+
68+
def calculate_total_roi(self, user_id: int) -> float:
69+
portfolio = self.portfolio_repo.get(user_id=user_id)
70+
if portfolio is None or portfolio.total_money_in == 0.0:
71+
return 0.0
72+
73+
valid_holdings = [(holding.symbol, holding.shares) for holding in portfolio.holdings if holding.shares > 0]
74+
if not valid_holdings:
75+
# If no valid holdings, ROI depends only on cash balance
76+
return round(((portfolio.cash_balance - portfolio.total_money_in) / portfolio.total_money_in) * 100, 2)
77+
78+
# Fetch prices in batch
79+
stock_symbols = [symbol for symbol, _ in valid_holdings]
80+
stock_price_by_symbol = self._get_stock_price(symbols=stock_symbols)
81+
82+
# Calculate total stock value
83+
total_stock_price = sum(shares * stock_price_by_symbol.get(symbol, 0.0) for symbol, shares in valid_holdings)
84+
85+
# Compute ROI
86+
total_value = total_stock_price + portfolio.cash_balance
87+
roi = ((total_value - portfolio.total_money_in) / portfolio.total_money_in) * 100
88+
return round(roi, 2)
89+
90+
def _get_stock_price(self, symbols: List[str]) -> Dict[str, float]:
91+
if not symbols:
92+
return {}
93+
94+
try:
95+
tickers = yf.Tickers(symbols)
96+
return {
97+
symbol: (
98+
ticker.info.get("currentPrice", 0.0)
99+
if (ticker := tickers.tickers.get(symbol.upper())) is not None
100+
else 0.0
101+
)
102+
for symbol in symbols
103+
}
104+
except Exception as e:
105+
print(f"Error fetching prices for symbols {symbols}: {e}")
106+
return {symbol.upper(): 0.0 for symbol in symbols}

0 commit comments

Comments
 (0)