Skip to content

Commit ff05ee7

Browse files
eyo-chenEyo Chen
andauthored
Feat: add get portfolio info usecase (#19)
* feat: add get portfolio info * test: add unit testing --------- Co-authored-by: Eyo Chen <[email protected]>
1 parent 8bd9e24 commit ff05ee7

File tree

3 files changed

+136
-146
lines changed

3 files changed

+136
-146
lines changed

src/domain/portfolio.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,11 @@ def __post_init__(self):
5757

5858
def as_dict(self) -> PortfolioDict:
5959
return asdict(self, dict_factory=custom_dict_factory)
60+
61+
62+
@dataclass
63+
class PortfolioInfo:
64+
user_id: int
65+
total_portfolio_value: float
66+
total_gain: float
67+
roi: float

src/tests/test_stock_usecase.py

Lines changed: 109 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import pytest
2-
from datetime import datetime, timezone
32
from unittest.mock import Mock, ANY, patch
43
from usecase.stock import StockUsecase
54
from domain.stock import CreateStock, Stock
6-
from domain.portfolio import Portfolio, Holding
5+
from domain.portfolio import Portfolio, Holding, PortfolioInfo
76
from domain.enum import ActionType, StockType
87

98

@@ -15,7 +14,7 @@ def stock_usecase():
1514
return usecase, stock_repo, portfolio_repo
1615

1716

18-
class TestStockUsecase:
17+
class TestStockUsecaseCreate:
1918
def test_create_transfer_new_portfolio(self, stock_usecase):
2019
# Arrange
2120
usecase, stock_repo, portfolio_repo = stock_usecase
@@ -375,6 +374,8 @@ def test_create_handles_repository_error_on_stock_create(self, stock_usecase):
375374
portfolio_repo.update.assert_called_once_with(portfolio=ANY) # Portfolio may vary, so use ANY
376375
stock_repo.create.assert_called_once_with(stock)
377376

377+
378+
class TestStockUsecaseList:
378379
def test_list(self, stock_usecase):
379380
# Arrange
380381
usecase, mock_repo, _ = stock_usecase
@@ -427,143 +428,8 @@ def test_list_handles_repository_error(self, stock_usecase):
427428
usecase.list(user_id)
428429
mock_repo.list.assert_called_once_with(user_id)
429430

430-
def test_calculate_total_roi_no_portfolio(self, stock_usecase):
431-
# Arrange
432-
usecase, _, portfolio_repo = stock_usecase
433-
user_id = 1
434-
portfolio_repo.get.return_value = None
435-
436-
# Act
437-
result = usecase.calculate_total_roi(user_id)
438-
439-
# Assert
440-
portfolio_repo.get.assert_called_once_with(user_id=user_id)
441-
assert result == 0.0
442-
443-
def test_calculate_total_roi_no_total_money_in(self, stock_usecase):
444-
# Arrange
445-
usecase, _, portfolio_repo = stock_usecase
446-
user_id = 1
447-
portfolio = Portfolio(
448-
user_id=user_id,
449-
cash_balance=0.0,
450-
total_money_in=0.0,
451-
holdings=[],
452-
created_at=ANY,
453-
updated_at=ANY,
454-
)
455-
portfolio_repo.get.return_value = portfolio
456-
457-
# Act
458-
result = usecase.calculate_total_roi(user_id)
459-
460-
# Assert
461-
portfolio_repo.get.assert_called_once_with(user_id=user_id)
462-
assert result == 0.0
463-
464-
def test_calculate_total_roi_no_holdings(self, stock_usecase):
465-
# Arrange
466-
usecase, _, portfolio_repo = stock_usecase
467-
user_id = 1
468-
portfolio = Portfolio(
469-
user_id=user_id,
470-
cash_balance=1000.0,
471-
total_money_in=1000.0,
472-
holdings=[],
473-
created_at=ANY,
474-
updated_at=ANY,
475-
)
476-
portfolio_repo.get.return_value = portfolio
477-
478-
# Act
479-
result = usecase.calculate_total_roi(user_id)
480-
481-
# Assert
482-
portfolio_repo.get.assert_called_once_with(user_id=user_id)
483-
assert result == 0.0 # ROI = ((1000 - 1000) / 1000) * 100 = 0.0
484-
485-
@patch.object(StockUsecase, "_get_stock_price")
486-
def test_calculate_total_roi_with_holdings(self, mock_get_stock_price, stock_usecase):
487-
# Arrange
488-
usecase, _, portfolio_repo = stock_usecase
489-
user_id = 1
490-
portfolio = Portfolio(
491-
user_id=user_id,
492-
cash_balance=1000.0,
493-
total_money_in=5000.0,
494-
holdings=[
495-
Holding(symbol="AAPL", shares=10, stock_type=StockType.STOCKS, total_cost=1500.0),
496-
Holding(symbol="GOOGL", shares=5, stock_type=StockType.STOCKS, total_cost=2000.0),
497-
Holding(
498-
symbol="TSLA", shares=0, stock_type=StockType.STOCKS, total_cost=0.0
499-
), # Zero shares, should be ignored
500-
],
501-
created_at=ANY,
502-
updated_at=ANY,
503-
)
504-
portfolio_repo.get.return_value = portfolio
505-
mock_get_stock_price.return_value = {
506-
"AAPL": 200.0, # 10 shares * 200 = 2000
507-
"GOOGL": 3000.0, # 5 shares * 3000 = 15000
508-
}
509-
510-
# Act
511-
result = usecase.calculate_total_roi(user_id)
512-
513-
# Assert
514-
portfolio_repo.get.assert_called_once_with(user_id=user_id)
515-
mock_get_stock_price.assert_called_once_with(
516-
stock_info=[("AAPL", StockType.STOCKS), ("GOOGL", StockType.STOCKS)]
517-
)
518-
# Total value = 2000 (AAPL) + 15000 (GOOGL) + 1000 (cash) = 18000
519-
# ROI = ((18000 - 5000) / 5000) * 100 = 260.0
520-
assert result == 260.0
521-
522-
@patch.object(StockUsecase, "_get_stock_price")
523-
def test_calculate_total_roi_with_missing_prices(self, mock_get_stock_price, stock_usecase):
524-
# Arrange
525-
usecase, _, portfolio_repo = stock_usecase
526-
user_id = 1
527-
portfolio = Portfolio(
528-
user_id=user_id,
529-
cash_balance=1000.0,
530-
total_money_in=5000.0,
531-
holdings=[
532-
Holding(symbol="AAPL", shares=10, stock_type=StockType.STOCKS, total_cost=1500.0),
533-
Holding(symbol="INVALID", shares=5, stock_type=StockType.STOCKS, total_cost=2000.0),
534-
],
535-
created_at=ANY,
536-
updated_at=ANY,
537-
)
538-
portfolio_repo.get.return_value = portfolio
539-
mock_get_stock_price.return_value = {
540-
"AAPL": 200.0, # 10 shares * 200 = 2000
541-
"INVALID": 0.0, # No price available
542-
}
543-
544-
# Act
545-
result = usecase.calculate_total_roi(user_id)
546-
547-
# Assert
548-
portfolio_repo.get.assert_called_once_with(user_id=user_id)
549-
mock_get_stock_price.assert_called_once_with(
550-
stock_info=[("AAPL", StockType.STOCKS), ("INVALID", StockType.STOCKS)]
551-
)
552-
# Total value = 2000 (AAPL) + 0 (INVALID) + 1000 (cash) = 3000
553-
# ROI = ((3000 - 5000) / 5000) * 100 = -40.0
554-
assert result == -40.0
555-
556-
def test_calculate_total_roi_handles_repository_error(self, stock_usecase):
557-
# Arrange
558-
usecase, _, portfolio_repo = stock_usecase
559-
user_id = 1
560-
portfolio_repo.get.side_effect = Exception("Portfolio repository error")
561-
562-
# Act/Assert
563-
with pytest.raises(Exception, match="Portfolio repository error"):
564-
usecase.calculate_total_roi(user_id)
565-
portfolio_repo.get.assert_called_once_with(user_id=user_id)
566431

432+
class TestStockUsecaseGetStockPrice:
567433
@patch("usecase.stock.yf.Tickers")
568434
def test_get_stock_price_success_stocks(self, mock_yf_tickers, stock_usecase):
569435
# Arrange
@@ -690,3 +556,107 @@ def test_get_stock_price_missing_price_field(self, mock_yf_tickers, stock_usecas
690556
# Assert
691557
mock_yf_tickers.assert_called_once_with(["AAPL", "SPY"])
692558
assert result == {"AAPL": 0.0, "SPY": 0.0}
559+
560+
561+
class TestStockUsecaseGetPortfolioInfo:
562+
@patch.object(StockUsecase, "_get_stock_price")
563+
def test_get_portfolio_info_no_portfolio(self, mock_get_stock_price, stock_usecase):
564+
# Arrange
565+
usecase, _, portfolio_repo = stock_usecase
566+
user_id = 1
567+
portfolio_repo.get.return_value = None
568+
expected_result = PortfolioInfo(user_id=user_id, total_portfolio_value=0.0, total_gain=0.0, roi=0.0)
569+
570+
# Act
571+
result = usecase.get_portfolio_info(user_id)
572+
573+
# Assert
574+
portfolio_repo.get.assert_called_once_with(user_id=user_id)
575+
mock_get_stock_price.assert_not_called()
576+
assert result == expected_result
577+
578+
@patch.object(StockUsecase, "_get_stock_price")
579+
def test_get_portfolio_info_empty_portfolio(self, mock_get_stock_price, stock_usecase):
580+
# Arrange
581+
usecase, _, portfolio_repo = stock_usecase
582+
user_id = 1
583+
portfolio = Portfolio(
584+
user_id=user_id,
585+
cash_balance=1000.0,
586+
total_money_in=0.0,
587+
holdings=[],
588+
created_at=ANY,
589+
updated_at=ANY,
590+
)
591+
portfolio_repo.get.return_value = portfolio
592+
expected_result = PortfolioInfo(user_id=user_id, total_portfolio_value=0.0, total_gain=0.0, roi=0.0)
593+
594+
# Act
595+
result = usecase.get_portfolio_info(user_id)
596+
597+
# Assert
598+
portfolio_repo.get.assert_called_once_with(user_id=user_id)
599+
mock_get_stock_price.assert_not_called()
600+
assert result == expected_result
601+
602+
@patch.object(StockUsecase, "_get_stock_price")
603+
def test_get_portfolio_info_no_valid_holdings(self, mock_get_stock_price, stock_usecase):
604+
# Arrange
605+
usecase, _, portfolio_repo = stock_usecase
606+
user_id = 1
607+
portfolio = Portfolio(
608+
user_id=user_id,
609+
cash_balance=1000.0,
610+
total_money_in=2000.0,
611+
holdings=[Holding(symbol="AAPL", shares=0, stock_type=StockType.STOCKS, total_cost=0.0)],
612+
created_at=ANY,
613+
updated_at=ANY,
614+
)
615+
portfolio_repo.get.return_value = portfolio
616+
expected_result = PortfolioInfo(
617+
user_id=user_id,
618+
total_portfolio_value=1000.0,
619+
total_gain=-1000.0,
620+
roi=-50.0,
621+
)
622+
623+
# Act
624+
result = usecase.get_portfolio_info(user_id)
625+
626+
# Assert
627+
portfolio_repo.get.assert_called_once_with(user_id=user_id)
628+
mock_get_stock_price.assert_not_called()
629+
assert result == expected_result
630+
631+
@patch.object(StockUsecase, "_get_stock_price")
632+
def test_get_portfolio_info_with_valid_holdings(self, mock_get_stock_price, stock_usecase):
633+
# Arrange
634+
usecase, _, portfolio_repo = stock_usecase
635+
user_id = 1
636+
portfolio = Portfolio(
637+
user_id=user_id,
638+
cash_balance=1000.0,
639+
total_money_in=2000.0,
640+
holdings=[
641+
Holding(symbol="AAPL", shares=10, stock_type=StockType.STOCKS, total_cost=1500.0),
642+
Holding(symbol="SPY", shares=5, stock_type=StockType.ETF, total_cost=1000.0),
643+
],
644+
created_at=ANY,
645+
updated_at=ANY,
646+
)
647+
mock_get_stock_price.return_value = {"AAPL": 200.0, "SPY": 400.0}
648+
portfolio_repo.get.return_value = portfolio
649+
expected_result = PortfolioInfo(
650+
user_id=user_id,
651+
total_portfolio_value=5000.0,
652+
total_gain=3000.0,
653+
roi=150,
654+
)
655+
656+
# Act
657+
result = usecase.get_portfolio_info(user_id)
658+
659+
# Assert
660+
portfolio_repo.get.assert_called_once_with(user_id=user_id)
661+
mock_get_stock_price.assert_called_once_with(stock_info=[("AAPL", StockType.STOCKS), ("SPY", StockType.ETF)])
662+
assert result == expected_result

src/usecase/stock.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import yfinance as yf
44
from .base import AbstractStockUsecase
55
from adapters.base import AbstractStockRepository, AbstractPortfolioRepository
6-
from domain.portfolio import Portfolio, Holding
6+
from domain.portfolio import Portfolio, Holding, PortfolioInfo
77
from domain.stock import CreateStock, Stock
88
from domain.enum import ActionType, StockType
99

@@ -69,17 +69,23 @@ def create(self, stock: CreateStock) -> str:
6969
def list(self, user_id: int) -> List[Stock]:
7070
return self.stock_repo.list(user_id)
7171

72-
def calculate_total_roi(self, user_id: int) -> float:
72+
def get_portfolio_info(self, user_id: int) -> PortfolioInfo:
7373
portfolio = self.portfolio_repo.get(user_id=user_id)
7474
if portfolio is None or portfolio.total_money_in == 0.0:
75-
return 0.0
75+
return PortfolioInfo(user_id=user_id, total_portfolio_value=0.0, total_gain=0.0, roi=0.0)
7676

7777
valid_holdings = [
7878
(holding.symbol, holding.shares, holding.stock_type) for holding in portfolio.holdings if holding.shares > 0
7979
]
8080
if not valid_holdings:
81-
# If no valid holdings, ROI depends only on cash balance
82-
return round(((portfolio.cash_balance - portfolio.total_money_in) / portfolio.total_money_in) * 100, 2)
81+
return PortfolioInfo(
82+
user_id=user_id,
83+
total_portfolio_value=portfolio.cash_balance,
84+
total_gain=portfolio.cash_balance - portfolio.total_money_in,
85+
roi=round(
86+
((portfolio.cash_balance - portfolio.total_money_in) / portfolio.total_money_in) * 100, 2
87+
), # If no valid holdings, ROI depends only on cash balance
88+
)
8389

8490
# Fetch prices in batch
8591
stock_info = [(symbol, stock_type) for symbol, _, stock_type in valid_holdings]
@@ -90,8 +96,14 @@ def calculate_total_roi(self, user_id: int) -> float:
9096

9197
# Compute ROI
9298
total_value = total_stock_price + portfolio.cash_balance
93-
roi = ((total_value - portfolio.total_money_in) / portfolio.total_money_in) * 100
94-
return round(roi, 2)
99+
roi = round(((total_value - portfolio.total_money_in) / portfolio.total_money_in) * 100, 2)
100+
101+
return PortfolioInfo(
102+
user_id=user_id,
103+
total_portfolio_value=total_value,
104+
total_gain=total_value - portfolio.total_money_in,
105+
roi=roi,
106+
)
95107

96108
def _get_stock_price(self, stock_info: List[Tuple[str, StockType]]) -> Dict[str, float]:
97109
if not stock_info:

0 commit comments

Comments
 (0)