Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/domain/portfolio.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,11 @@ def __post_init__(self):

def as_dict(self) -> PortfolioDict:
return asdict(self, dict_factory=custom_dict_factory)


@dataclass
class PortfolioInfo:
user_id: int
total_portfolio_value: float
total_gain: float
roi: float
248 changes: 109 additions & 139 deletions src/tests/test_stock_usecase.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import pytest
from datetime import datetime, timezone
from unittest.mock import Mock, ANY, patch
from usecase.stock import StockUsecase
from domain.stock import CreateStock, Stock
from domain.portfolio import Portfolio, Holding
from domain.portfolio import Portfolio, Holding, PortfolioInfo
from domain.enum import ActionType, StockType


Expand All @@ -15,7 +14,7 @@ def stock_usecase():
return usecase, stock_repo, portfolio_repo


class TestStockUsecase:
class TestStockUsecaseCreate:
def test_create_transfer_new_portfolio(self, stock_usecase):
# Arrange
usecase, stock_repo, portfolio_repo = stock_usecase
Expand Down Expand Up @@ -375,6 +374,8 @@ def test_create_handles_repository_error_on_stock_create(self, stock_usecase):
portfolio_repo.update.assert_called_once_with(portfolio=ANY) # Portfolio may vary, so use ANY
stock_repo.create.assert_called_once_with(stock)


class TestStockUsecaseList:
def test_list(self, stock_usecase):
# Arrange
usecase, mock_repo, _ = stock_usecase
Expand Down Expand Up @@ -427,143 +428,8 @@ def test_list_handles_repository_error(self, stock_usecase):
usecase.list(user_id)
mock_repo.list.assert_called_once_with(user_id)

def test_calculate_total_roi_no_portfolio(self, stock_usecase):
# Arrange
usecase, _, portfolio_repo = stock_usecase
user_id = 1
portfolio_repo.get.return_value = None

# Act
result = usecase.calculate_total_roi(user_id)

# Assert
portfolio_repo.get.assert_called_once_with(user_id=user_id)
assert result == 0.0

def test_calculate_total_roi_no_total_money_in(self, stock_usecase):
# Arrange
usecase, _, portfolio_repo = stock_usecase
user_id = 1
portfolio = Portfolio(
user_id=user_id,
cash_balance=0.0,
total_money_in=0.0,
holdings=[],
created_at=ANY,
updated_at=ANY,
)
portfolio_repo.get.return_value = portfolio

# Act
result = usecase.calculate_total_roi(user_id)

# Assert
portfolio_repo.get.assert_called_once_with(user_id=user_id)
assert result == 0.0

def test_calculate_total_roi_no_holdings(self, stock_usecase):
# Arrange
usecase, _, portfolio_repo = stock_usecase
user_id = 1
portfolio = Portfolio(
user_id=user_id,
cash_balance=1000.0,
total_money_in=1000.0,
holdings=[],
created_at=ANY,
updated_at=ANY,
)
portfolio_repo.get.return_value = portfolio

# Act
result = usecase.calculate_total_roi(user_id)

# Assert
portfolio_repo.get.assert_called_once_with(user_id=user_id)
assert result == 0.0 # ROI = ((1000 - 1000) / 1000) * 100 = 0.0

@patch.object(StockUsecase, "_get_stock_price")
def test_calculate_total_roi_with_holdings(self, mock_get_stock_price, stock_usecase):
# Arrange
usecase, _, portfolio_repo = stock_usecase
user_id = 1
portfolio = Portfolio(
user_id=user_id,
cash_balance=1000.0,
total_money_in=5000.0,
holdings=[
Holding(symbol="AAPL", shares=10, stock_type=StockType.STOCKS, total_cost=1500.0),
Holding(symbol="GOOGL", shares=5, stock_type=StockType.STOCKS, total_cost=2000.0),
Holding(
symbol="TSLA", shares=0, stock_type=StockType.STOCKS, total_cost=0.0
), # Zero shares, should be ignored
],
created_at=ANY,
updated_at=ANY,
)
portfolio_repo.get.return_value = portfolio
mock_get_stock_price.return_value = {
"AAPL": 200.0, # 10 shares * 200 = 2000
"GOOGL": 3000.0, # 5 shares * 3000 = 15000
}

# Act
result = usecase.calculate_total_roi(user_id)

# Assert
portfolio_repo.get.assert_called_once_with(user_id=user_id)
mock_get_stock_price.assert_called_once_with(
stock_info=[("AAPL", StockType.STOCKS), ("GOOGL", StockType.STOCKS)]
)
# Total value = 2000 (AAPL) + 15000 (GOOGL) + 1000 (cash) = 18000
# ROI = ((18000 - 5000) / 5000) * 100 = 260.0
assert result == 260.0

@patch.object(StockUsecase, "_get_stock_price")
def test_calculate_total_roi_with_missing_prices(self, mock_get_stock_price, stock_usecase):
# Arrange
usecase, _, portfolio_repo = stock_usecase
user_id = 1
portfolio = Portfolio(
user_id=user_id,
cash_balance=1000.0,
total_money_in=5000.0,
holdings=[
Holding(symbol="AAPL", shares=10, stock_type=StockType.STOCKS, total_cost=1500.0),
Holding(symbol="INVALID", shares=5, stock_type=StockType.STOCKS, total_cost=2000.0),
],
created_at=ANY,
updated_at=ANY,
)
portfolio_repo.get.return_value = portfolio
mock_get_stock_price.return_value = {
"AAPL": 200.0, # 10 shares * 200 = 2000
"INVALID": 0.0, # No price available
}

# Act
result = usecase.calculate_total_roi(user_id)

# Assert
portfolio_repo.get.assert_called_once_with(user_id=user_id)
mock_get_stock_price.assert_called_once_with(
stock_info=[("AAPL", StockType.STOCKS), ("INVALID", StockType.STOCKS)]
)
# Total value = 2000 (AAPL) + 0 (INVALID) + 1000 (cash) = 3000
# ROI = ((3000 - 5000) / 5000) * 100 = -40.0
assert result == -40.0

def test_calculate_total_roi_handles_repository_error(self, stock_usecase):
# Arrange
usecase, _, portfolio_repo = stock_usecase
user_id = 1
portfolio_repo.get.side_effect = Exception("Portfolio repository error")

# Act/Assert
with pytest.raises(Exception, match="Portfolio repository error"):
usecase.calculate_total_roi(user_id)
portfolio_repo.get.assert_called_once_with(user_id=user_id)

class TestStockUsecaseGetStockPrice:
@patch("usecase.stock.yf.Tickers")
def test_get_stock_price_success_stocks(self, mock_yf_tickers, stock_usecase):
# Arrange
Expand Down Expand Up @@ -690,3 +556,107 @@ def test_get_stock_price_missing_price_field(self, mock_yf_tickers, stock_usecas
# Assert
mock_yf_tickers.assert_called_once_with(["AAPL", "SPY"])
assert result == {"AAPL": 0.0, "SPY": 0.0}


class TestStockUsecaseGetPortfolioInfo:
@patch.object(StockUsecase, "_get_stock_price")
def test_get_portfolio_info_no_portfolio(self, mock_get_stock_price, stock_usecase):
# Arrange
usecase, _, portfolio_repo = stock_usecase
user_id = 1
portfolio_repo.get.return_value = None
expected_result = PortfolioInfo(user_id=user_id, total_portfolio_value=0.0, total_gain=0.0, roi=0.0)

# Act
result = usecase.get_portfolio_info(user_id)

# Assert
portfolio_repo.get.assert_called_once_with(user_id=user_id)
mock_get_stock_price.assert_not_called()
assert result == expected_result

@patch.object(StockUsecase, "_get_stock_price")
def test_get_portfolio_info_empty_portfolio(self, mock_get_stock_price, stock_usecase):
# Arrange
usecase, _, portfolio_repo = stock_usecase
user_id = 1
portfolio = Portfolio(
user_id=user_id,
cash_balance=1000.0,
total_money_in=0.0,
holdings=[],
created_at=ANY,
updated_at=ANY,
)
portfolio_repo.get.return_value = portfolio
expected_result = PortfolioInfo(user_id=user_id, total_portfolio_value=0.0, total_gain=0.0, roi=0.0)

# Act
result = usecase.get_portfolio_info(user_id)

# Assert
portfolio_repo.get.assert_called_once_with(user_id=user_id)
mock_get_stock_price.assert_not_called()
assert result == expected_result

@patch.object(StockUsecase, "_get_stock_price")
def test_get_portfolio_info_no_valid_holdings(self, mock_get_stock_price, stock_usecase):
# Arrange
usecase, _, portfolio_repo = stock_usecase
user_id = 1
portfolio = Portfolio(
user_id=user_id,
cash_balance=1000.0,
total_money_in=2000.0,
holdings=[Holding(symbol="AAPL", shares=0, stock_type=StockType.STOCKS, total_cost=0.0)],
created_at=ANY,
updated_at=ANY,
)
portfolio_repo.get.return_value = portfolio
expected_result = PortfolioInfo(
user_id=user_id,
total_portfolio_value=1000.0,
total_gain=-1000.0,
roi=-50.0,
)

# Act
result = usecase.get_portfolio_info(user_id)

# Assert
portfolio_repo.get.assert_called_once_with(user_id=user_id)
mock_get_stock_price.assert_not_called()
assert result == expected_result

@patch.object(StockUsecase, "_get_stock_price")
def test_get_portfolio_info_with_valid_holdings(self, mock_get_stock_price, stock_usecase):
# Arrange
usecase, _, portfolio_repo = stock_usecase
user_id = 1
portfolio = Portfolio(
user_id=user_id,
cash_balance=1000.0,
total_money_in=2000.0,
holdings=[
Holding(symbol="AAPL", shares=10, stock_type=StockType.STOCKS, total_cost=1500.0),
Holding(symbol="SPY", shares=5, stock_type=StockType.ETF, total_cost=1000.0),
],
created_at=ANY,
updated_at=ANY,
)
mock_get_stock_price.return_value = {"AAPL": 200.0, "SPY": 400.0}
portfolio_repo.get.return_value = portfolio
expected_result = PortfolioInfo(
user_id=user_id,
total_portfolio_value=5000.0,
total_gain=3000.0,
roi=150,
)

# Act
result = usecase.get_portfolio_info(user_id)

# Assert
portfolio_repo.get.assert_called_once_with(user_id=user_id)
mock_get_stock_price.assert_called_once_with(stock_info=[("AAPL", StockType.STOCKS), ("SPY", StockType.ETF)])
assert result == expected_result
26 changes: 19 additions & 7 deletions src/usecase/stock.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import yfinance as yf
from .base import AbstractStockUsecase
from adapters.base import AbstractStockRepository, AbstractPortfolioRepository
from domain.portfolio import Portfolio, Holding
from domain.portfolio import Portfolio, Holding, PortfolioInfo
from domain.stock import CreateStock, Stock
from domain.enum import ActionType, StockType

Expand Down Expand Up @@ -69,17 +69,23 @@ def create(self, stock: CreateStock) -> str:
def list(self, user_id: int) -> List[Stock]:
return self.stock_repo.list(user_id)

def calculate_total_roi(self, user_id: int) -> float:
def get_portfolio_info(self, user_id: int) -> PortfolioInfo:
portfolio = self.portfolio_repo.get(user_id=user_id)
if portfolio is None or portfolio.total_money_in == 0.0:
return 0.0
return PortfolioInfo(user_id=user_id, total_portfolio_value=0.0, total_gain=0.0, roi=0.0)

valid_holdings = [
(holding.symbol, holding.shares, holding.stock_type) for holding in portfolio.holdings if holding.shares > 0
]
if not valid_holdings:
# If no valid holdings, ROI depends only on cash balance
return round(((portfolio.cash_balance - portfolio.total_money_in) / portfolio.total_money_in) * 100, 2)
return PortfolioInfo(
user_id=user_id,
total_portfolio_value=portfolio.cash_balance,
total_gain=portfolio.cash_balance - portfolio.total_money_in,
roi=round(
((portfolio.cash_balance - portfolio.total_money_in) / portfolio.total_money_in) * 100, 2
), # If no valid holdings, ROI depends only on cash balance
)

# Fetch prices in batch
stock_info = [(symbol, stock_type) for symbol, _, stock_type in valid_holdings]
Expand All @@ -90,8 +96,14 @@ def calculate_total_roi(self, user_id: int) -> float:

# Compute ROI
total_value = total_stock_price + portfolio.cash_balance
roi = ((total_value - portfolio.total_money_in) / portfolio.total_money_in) * 100
return round(roi, 2)
roi = round(((total_value - portfolio.total_money_in) / portfolio.total_money_in) * 100, 2)

return PortfolioInfo(
user_id=user_id,
total_portfolio_value=total_value,
total_gain=total_value - portfolio.total_money_in,
roi=roi,
)

def _get_stock_price(self, stock_info: List[Tuple[str, StockType]]) -> Dict[str, float]:
if not stock_info:
Expand Down