From eee9bb90b05b15c88ae103b840114f406818d7a4 Mon Sep 17 00:00:00 2001 From: Eyo Chen Date: Sat, 14 Jun 2025 11:33:18 +0800 Subject: [PATCH 1/5] feat: add etf type in stock proto --- src/proto/stock.proto | 18 ++++++++++++++---- src/proto/stock_pb2.py | 30 +++++++++++++++++------------- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/src/proto/stock.proto b/src/proto/stock.proto index 00e5d5b..04b8437 100644 --- a/src/proto/stock.proto +++ b/src/proto/stock.proto @@ -13,6 +13,14 @@ message Action { } } +message StockType { + enum Type { + UNSPECIFIED = 0; + STOCKS = 1; + ETF = 2; + } +} + message Stock { string id = 1 [json_name = "id"]; int32 user_id = 2 [json_name = "user_id"]; @@ -20,8 +28,9 @@ message Stock { double price = 4 [json_name = "price"]; int32 quantity = 5 [json_name = "quantity"]; string action = 6 [json_name = "action"]; - google.protobuf.Timestamp created_at = 7 [json_name = "created_at"]; - google.protobuf.Timestamp updated_at = 8 [json_name = "updated_at"]; + string stock_type = 7 [json_name = "stock_type"]; + google.protobuf.Timestamp created_at = 8 [json_name = "created_at"]; + google.protobuf.Timestamp updated_at = 9 [json_name = "updated_at"]; } message CreateReq { @@ -30,8 +39,9 @@ message CreateReq { double price = 3 [json_name = "price"]; int32 quantity = 4 [json_name = "quantity"]; Action.Type action = 5 [json_name = "action"]; // add validation rules - google.protobuf.Timestamp created_at = 6 [json_name = "created_at"]; - google.protobuf.Timestamp updated_at = 7 [json_name = "updated_at"]; + StockType.Type stock_type = 6 [json_name = "stock_type"]; // add validation rules + google.protobuf.Timestamp created_at = 7 [json_name = "created_at"]; + google.protobuf.Timestamp updated_at = 8 [json_name = "updated_at"]; } message CreateResp { diff --git a/src/proto/stock_pb2.py b/src/proto/stock_pb2.py index 665f00f..e3ec87f 100644 --- a/src/proto/stock_pb2.py +++ b/src/proto/stock_pb2.py @@ -25,7 +25,7 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11proto/stock.proto\x12\x05stock\x1a\x1fgoogle/protobuf/timestamp.proto\"B\n\x06\x41\x63tion\"8\n\x04Type\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\x07\n\x03\x42UY\x10\x01\x12\x08\n\x04SELL\x10\x02\x12\x0c\n\x08TRANSFER\x10\x03\"\x8b\x02\n\x05Stock\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x18\n\x07user_id\x18\x02 \x01(\x05R\x07user_id\x12\x16\n\x06symbol\x18\x03 \x01(\tR\x06symbol\x12\x14\n\x05price\x18\x04 \x01(\x01R\x05price\x12\x1a\n\x08quantity\x18\x05 \x01(\x05R\x08quantity\x12\x16\n\x06\x61\x63tion\x18\x06 \x01(\tR\x06\x61\x63tion\x12:\n\ncreated_at\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ncreated_at\x12:\n\nupdated_at\x18\x08 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\nupdated_at\"\x93\x02\n\tCreateReq\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\x12\x16\n\x06symbol\x18\x02 \x01(\tR\x06symbol\x12\x14\n\x05price\x18\x03 \x01(\x01R\x05price\x12\x1a\n\x08quantity\x18\x04 \x01(\x05R\x08quantity\x12*\n\x06\x61\x63tion\x18\x05 \x01(\x0e\x32\x12.stock.Action.TypeR\x06\x61\x63tion\x12:\n\ncreated_at\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ncreated_at\x12:\n\nupdated_at\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\nupdated_at\"\x1c\n\nCreateResp\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\"#\n\x07ListReq\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\"8\n\x08ListResp\x12,\n\nstock_list\x18\x01 \x03(\x0b\x32\x0c.stock.StockR\nstock_list2j\n\x0cStockService\x12/\n\x06\x43reate\x12\x10.stock.CreateReq\x1a\x11.stock.CreateResp\"\x00\x12)\n\x04List\x12\x0e.stock.ListReq\x1a\x0f.stock.ListResp\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11proto/stock.proto\x12\x05stock\x1a\x1fgoogle/protobuf/timestamp.proto\"B\n\x06\x41\x63tion\"8\n\x04Type\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\x07\n\x03\x42UY\x10\x01\x12\x08\n\x04SELL\x10\x02\x12\x0c\n\x08TRANSFER\x10\x03\"9\n\tStockType\",\n\x04Type\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\n\n\x06STOCKS\x10\x01\x12\x07\n\x03\x45TF\x10\x02\"\xab\x02\n\x05Stock\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x18\n\x07user_id\x18\x02 \x01(\x05R\x07user_id\x12\x16\n\x06symbol\x18\x03 \x01(\tR\x06symbol\x12\x14\n\x05price\x18\x04 \x01(\x01R\x05price\x12\x1a\n\x08quantity\x18\x05 \x01(\x05R\x08quantity\x12\x16\n\x06\x61\x63tion\x18\x06 \x01(\tR\x06\x61\x63tion\x12\x1e\n\nstock_type\x18\x07 \x01(\tR\nstock_type\x12:\n\ncreated_at\x18\x08 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ncreated_at\x12:\n\nupdated_at\x18\t \x01(\x0b\x32\x1a.google.protobuf.TimestampR\nupdated_at\"\xca\x02\n\tCreateReq\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\x12\x16\n\x06symbol\x18\x02 \x01(\tR\x06symbol\x12\x14\n\x05price\x18\x03 \x01(\x01R\x05price\x12\x1a\n\x08quantity\x18\x04 \x01(\x05R\x08quantity\x12*\n\x06\x61\x63tion\x18\x05 \x01(\x0e\x32\x12.stock.Action.TypeR\x06\x61\x63tion\x12\x35\n\nstock_type\x18\x06 \x01(\x0e\x32\x15.stock.StockType.TypeR\nstock_type\x12:\n\ncreated_at\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ncreated_at\x12:\n\nupdated_at\x18\x08 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\nupdated_at\"\x1c\n\nCreateResp\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\"#\n\x07ListReq\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\"8\n\x08ListResp\x12,\n\nstock_list\x18\x01 \x03(\x0b\x32\x0c.stock.StockR\nstock_list2j\n\x0cStockService\x12/\n\x06\x43reate\x12\x10.stock.CreateReq\x1a\x11.stock.CreateResp\"\x00\x12)\n\x04List\x12\x0e.stock.ListReq\x1a\x0f.stock.ListResp\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -36,16 +36,20 @@ _globals['_ACTION']._serialized_end=127 _globals['_ACTION_TYPE']._serialized_start=71 _globals['_ACTION_TYPE']._serialized_end=127 - _globals['_STOCK']._serialized_start=130 - _globals['_STOCK']._serialized_end=397 - _globals['_CREATEREQ']._serialized_start=400 - _globals['_CREATEREQ']._serialized_end=675 - _globals['_CREATERESP']._serialized_start=677 - _globals['_CREATERESP']._serialized_end=705 - _globals['_LISTREQ']._serialized_start=707 - _globals['_LISTREQ']._serialized_end=742 - _globals['_LISTRESP']._serialized_start=744 - _globals['_LISTRESP']._serialized_end=800 - _globals['_STOCKSERVICE']._serialized_start=802 - _globals['_STOCKSERVICE']._serialized_end=908 + _globals['_STOCKTYPE']._serialized_start=129 + _globals['_STOCKTYPE']._serialized_end=186 + _globals['_STOCKTYPE_TYPE']._serialized_start=142 + _globals['_STOCKTYPE_TYPE']._serialized_end=186 + _globals['_STOCK']._serialized_start=189 + _globals['_STOCK']._serialized_end=488 + _globals['_CREATEREQ']._serialized_start=491 + _globals['_CREATEREQ']._serialized_end=821 + _globals['_CREATERESP']._serialized_start=823 + _globals['_CREATERESP']._serialized_end=851 + _globals['_LISTREQ']._serialized_start=853 + _globals['_LISTREQ']._serialized_end=888 + _globals['_LISTRESP']._serialized_start=890 + _globals['_LISTRESP']._serialized_end=946 + _globals['_STOCKSERVICE']._serialized_start=948 + _globals['_STOCKSERVICE']._serialized_end=1054 # @@protoc_insertion_point(module_scope) From 45da3593367953b3e4476869fb1dc296da909a8a Mon Sep 17 00:00:00 2001 From: Eyo Chen Date: Sat, 14 Jun 2025 11:37:53 +0800 Subject: [PATCH 2/5] feat: add etf type in domain --- src/domain/enum.py | 25 +++++++++++++++++ src/domain/portfolio.py | 43 ++++++++++++++++++++++++++++-- src/domain/stock.py | 59 ++++++++++++++++++++++++++++++++++++----- src/utils/utils.py | 14 ++++++++++ 4 files changed, 132 insertions(+), 9 deletions(-) create mode 100644 src/domain/enum.py create mode 100644 src/utils/utils.py diff --git a/src/domain/enum.py b/src/domain/enum.py new file mode 100644 index 0000000..6cfbdd1 --- /dev/null +++ b/src/domain/enum.py @@ -0,0 +1,25 @@ +from enum import Enum + + +class ActionType(Enum): + BUY = "BUY" + SELL = "SELL" + TRANSFER = "TRANSFER" + + +ACTION_MAP = { + 1: ActionType.BUY, + 2: ActionType.SELL, + 3: ActionType.TRANSFER, +} + + +class StockType(Enum): + STOCKS = "STOCKS" + ETF = "ETF" + + +STOCK_MAP = { + 1: StockType.STOCKS, + 2: StockType.ETF, +} diff --git a/src/domain/portfolio.py b/src/domain/portfolio.py index b1d6610..dae1006 100644 --- a/src/domain/portfolio.py +++ b/src/domain/portfolio.py @@ -1,14 +1,44 @@ -from dataclasses import dataclass +from dataclasses import dataclass, asdict from datetime import datetime -from typing import List +from typing import List, TypedDict +from .enum import StockType +from utils.utils import custom_dict_factory + + +class HoldingDict(TypedDict): + symbol: str + shares: int + stock_type: str + total_cost: float + + +class PortfolioDict(TypedDict): + user_id: int + cash_balance: float + total_money_in: float + holdings: List[HoldingDict] + created_at: datetime + updated_at: datetime @dataclass class Holding: symbol: str shares: int + stock_type: StockType total_cost: float + def __post_init__(self): + if not self.symbol: + raise ValueError("symbol cannot be empty") + if self.shares < 0: + raise ValueError("shares cannot be negative") + if self.total_cost < 0: + raise ValueError("total_cost cannot be negative") + + def as_dict(self) -> HoldingDict: + return asdict(self, dict_factory=custom_dict_factory) + @dataclass class Portfolio: @@ -18,3 +48,12 @@ class Portfolio: holdings: List[Holding] created_at: datetime updated_at: datetime + + def __post_init__(self): + if self.cash_balance < 0: + raise ValueError("cash_balance cannot be negative") + if self.total_money_in < 0: + raise ValueError("total_money_in cannot be negative") + + def as_dict(self) -> PortfolioDict: + return asdict(self, dict_factory=custom_dict_factory) diff --git a/src/domain/stock.py b/src/domain/stock.py index 4cf1a65..4d92a3e 100644 --- a/src/domain/stock.py +++ b/src/domain/stock.py @@ -1,15 +1,30 @@ -from dataclasses import dataclass +from dataclasses import asdict, dataclass +from typing import TypedDict from datetime import datetime -from enum import Enum +from utils.utils import custom_dict_factory +from .enum import ActionType, StockType -class ActionType(Enum): - BUY = "BUY" - SELL = "SELL" - TRANSFER = "TRANSFER" +class CreateStockDict(TypedDict): + user_id: int + symbol: str + price: float + quantity: int + action_type: str + stock_type: str + created_at: datetime -ACTION_MAP = {1: ActionType.BUY, 2: ActionType.SELL, 3: ActionType.TRANSFER} +class StockDict(TypedDict): + id: str + user_id: int + symbol: str + price: float + quantity: int + action_type: str + stock_type: str + created_at: datetime + updated_at: datetime @dataclass @@ -19,8 +34,22 @@ class CreateStock: price: float quantity: int action_type: ActionType + stock_type: StockType created_at: datetime + def __post_init__(self): + if not self.user_id or self.user_id <= 0: + raise ValueError("user_id must be non-empty and greater than 0") + if not self.symbol or self.symbol.strip() == "": + raise ValueError("symbol must be a non-empty string") + if self.price <= 0: + raise ValueError("price must be greater than 0") + if self.quantity <= 0: + raise ValueError("quantity must be greater than 0") + + def as_dict(self) -> CreateStockDict: + return asdict(self, dict_factory=custom_dict_factory) + @dataclass class Stock: @@ -30,5 +59,21 @@ class Stock: price: float quantity: int action_type: ActionType + stock_type: StockType created_at: datetime updated_at: datetime + + def __post_init__(self): + if not self.id: + raise ValueError("id cannot be empty") + if self.user_id <= 0: + raise ValueError("user_id must be positive") + if not self.symbol: + raise ValueError("symbol cannot be empty") + if self.price < 0: + raise ValueError("price cannot be negative") + if self.quantity <= 0: + raise ValueError("quantity must be positive") + + def as_dict(self) -> StockDict: + return asdict(self, dict_factory=custom_dict_factory) diff --git a/src/utils/utils.py b/src/utils/utils.py new file mode 100644 index 0000000..ed9b7d3 --- /dev/null +++ b/src/utils/utils.py @@ -0,0 +1,14 @@ +from typing import List, Any +from enum import Enum + + +def _convert_value(value: Any) -> Any: + if isinstance(value, Enum): + return value.value + if isinstance(value, list): + return [_convert_value(item) for item in value] + return value + + +def custom_dict_factory(data: List[tuple[str, Any]]) -> dict: + return {key: _convert_value(value) for key, value in data} From b5e9d933d37ed3f6dd15a6e3b93193a9da081a5f Mon Sep 17 00:00:00 2001 From: Eyo Chen Date: Sat, 14 Jun 2025 11:45:45 +0800 Subject: [PATCH 3/5] feat: add stock proto in adapter --- src/adapters/portfolio.py | 11 ++++++--- src/adapters/stock.py | 18 ++++----------- src/tests/test_portfolio_adapters.py | 24 +++++++++++--------- src/tests/test_stock_adapters.py | 34 ++++++++++++++++++---------- 4 files changed, 47 insertions(+), 40 deletions(-) diff --git a/src/adapters/portfolio.py b/src/adapters/portfolio.py index ec4de53..69ce8ca 100644 --- a/src/adapters/portfolio.py +++ b/src/adapters/portfolio.py @@ -1,9 +1,9 @@ -from dataclasses import asdict from datetime import datetime, timezone from pymongo import MongoClient from pymongo.database import Database from .base import AbstractPortfolioRepository from domain.portfolio import Portfolio, Holding +from domain.enum import StockType class PortfolioRepository(AbstractPortfolioRepository): @@ -22,7 +22,12 @@ def get(self, user_id: int) -> Portfolio: cash_balance=result["cash_balance"], total_money_in=result["total_money_in"], holdings=[ - Holding(symbol=holding["symbol"], shares=holding["shares"], total_cost=holding["total_cost"]) + Holding( + symbol=holding["symbol"], + shares=holding["shares"], + stock_type=StockType(holding["stock_type"]), + total_cost=holding["total_cost"], + ) for holding in result["holdings"] ], created_at=result["created_at"], @@ -31,7 +36,7 @@ def get(self, user_id: int) -> Portfolio: def update(self, portfolio: Portfolio) -> None: portfolio.updated_at = datetime.now(timezone.utc) - self.collection.replace_one({"user_id": portfolio.user_id}, asdict(portfolio), upsert=True) + self.collection.replace_one({"user_id": portfolio.user_id}, portfolio.as_dict(), upsert=True) def __del__(self): self.client.close() diff --git a/src/adapters/stock.py b/src/adapters/stock.py index bbd479b..63f3178 100644 --- a/src/adapters/stock.py +++ b/src/adapters/stock.py @@ -1,10 +1,9 @@ from typing import List - from pymongo import MongoClient from pymongo.database import Database - from .base import AbstractStockRepository -from domain.stock import CreateStock, Stock, ActionType +from domain.stock import CreateStock, Stock +from domain.enum import ActionType, StockType class StockRepository(AbstractStockRepository): @@ -14,17 +13,7 @@ def __init__(self, mongo_client: MongoClient, database_name: str = "stock_db"): self.collection = self.db["stocks"] def create(self, stock: CreateStock) -> str: - stock_dict = { - "user_id": stock.user_id, - "symbol": stock.symbol, - "price": stock.price, - "quantity": stock.quantity, - "action_type": stock.action_type.value, - "created_at": stock.created_at, - "updated_at": stock.created_at, - } - - result = self.collection.insert_one(stock_dict) + result = self.collection.insert_one(stock.as_dict()) return str(result.inserted_id) def list(self, user_id: int) -> List[Stock]: @@ -38,6 +27,7 @@ def list(self, user_id: int) -> List[Stock]: price=doc["price"], quantity=doc["quantity"], action_type=ActionType(doc["action_type"]), + stock_type=StockType(doc["stock_type"]), created_at=doc["created_at"], updated_at=doc["updated_at"], ) diff --git a/src/tests/test_portfolio_adapters.py b/src/tests/test_portfolio_adapters.py index 3b49b80..00bb6b0 100644 --- a/src/tests/test_portfolio_adapters.py +++ b/src/tests/test_portfolio_adapters.py @@ -1,10 +1,10 @@ import pytest from datetime import datetime, timezone from unittest.mock import ANY -from dataclasses import asdict from pymongo import MongoClient -from domain.portfolio import Portfolio, Holding from adapters.portfolio import PortfolioRepository +from domain.portfolio import Portfolio, Holding +from domain.enum import StockType @pytest.fixture(scope="module") @@ -33,7 +33,7 @@ def test_update_new_portfolio(self, portfolio_repository): user_id=1, cash_balance=1000.0, total_money_in=1000.0, - holdings=[Holding(symbol="AAPL", shares=10, total_cost=1500.0)], + holdings=[Holding(symbol="AAPL", shares=10, stock_type=StockType.STOCKS, total_cost=1500.0)], created_at=created_at, updated_at=created_at, ) @@ -49,6 +49,7 @@ def test_update_new_portfolio(self, portfolio_repository): assert len(result["holdings"]) == 1 assert result["holdings"][0]["symbol"] == "AAPL" assert result["holdings"][0]["shares"] == 10 + assert result["holdings"][0]["stock_type"] == StockType.STOCKS.value assert result["holdings"][0]["total_cost"] == 1500.0 def test_update_existing_portfolio(self, portfolio_repository): @@ -57,17 +58,17 @@ def test_update_existing_portfolio(self, portfolio_repository): user_id=1, cash_balance=1000.0, total_money_in=1000.0, - holdings=[Holding(symbol="AAPL", shares=10, total_cost=1500.0)], + holdings=[Holding(symbol="AAPL", shares=10, stock_type=StockType.STOCKS, total_cost=1500.0)], created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) - portfolio_repository.collection.insert_one(asdict(initial_portfolio)) + result = portfolio_repository.collection.insert_one(initial_portfolio.as_dict()) updated_portfolio = Portfolio( user_id=1, cash_balance=2000.0, total_money_in=2000.0, - holdings=[Holding(symbol="AAPL", shares=20, total_cost=3000.0)], + holdings=[Holding(symbol="AAPL", shares=20, stock_type=StockType.STOCKS, total_cost=3000.0)], created_at=initial_portfolio.created_at, updated_at=datetime.now(timezone.utc), ) @@ -84,6 +85,7 @@ def test_update_existing_portfolio(self, portfolio_repository): assert len(result["holdings"]) == 1 assert result["holdings"][0]["symbol"] == "AAPL" assert result["holdings"][0]["shares"] == 20 + assert result["holdings"][0]["stock_type"] == StockType.STOCKS.value assert result["holdings"][0]["total_cost"] == 3000.0 def test_get_existing_portfolio(self, portfolio_repository): @@ -93,16 +95,16 @@ def test_get_existing_portfolio(self, portfolio_repository): user_id=1, cash_balance=1000.0, total_money_in=1000.0, - holdings=[Holding(symbol="AAPL", shares=10, total_cost=1500.0)], + holdings=[Holding(symbol="AAPL", shares=10, stock_type=StockType.STOCKS, total_cost=1500.0)], created_at=created_at, updated_at=created_at, ) - portfolio_repository.collection.insert_one(asdict(portfolio)) + portfolio_repository.collection.insert_one(portfolio.as_dict()) expected_result = Portfolio( user_id=1, cash_balance=1000.0, total_money_in=1000.0, - holdings=[Holding(symbol="AAPL", shares=10, total_cost=1500.0)], + holdings=[Holding(symbol="AAPL", shares=10, stock_type=StockType.STOCKS, total_cost=1500.0)], created_at=ANY, updated_at=ANY, ) @@ -120,11 +122,11 @@ def test_get_non_existent_portfolio(self, portfolio_repository): user_id=1, cash_balance=1000.0, total_money_in=1000.0, - holdings=[Holding(symbol="AAPL", shares=10, total_cost=1500.0)], + holdings=[Holding(symbol="AAPL", shares=10, stock_type=StockType.STOCKS, total_cost=1500.0)], created_at=created_at, updated_at=created_at, ) - portfolio_repository.collection.insert_one(asdict(portfolio)) + portfolio_repository.collection.insert_one(portfolio.as_dict()) # Action result = portfolio_repository.get(user_id=999) diff --git a/src/tests/test_stock_adapters.py b/src/tests/test_stock_adapters.py index f458143..eca3756 100644 --- a/src/tests/test_stock_adapters.py +++ b/src/tests/test_stock_adapters.py @@ -1,11 +1,11 @@ import pytest from datetime import datetime, timezone - from pymongo import MongoClient +from unittest.mock import ANY from bson.objectid import ObjectId - -from domain.stock import CreateStock, ActionType, Stock from adapters.stock import StockRepository +from domain.stock import CreateStock, Stock +from domain.enum import ActionType, StockType @pytest.fixture(scope="module") @@ -34,7 +34,8 @@ def test_create_stock(self, stock_repository): price=150.25, quantity=100, action_type=ActionType.BUY, - created_at=datetime.utcnow(), + stock_type=StockType.STOCKS, + created_at=datetime.now(timezone.utc), ) expected_result = { @@ -43,8 +44,7 @@ def test_create_stock(self, stock_repository): "price": mock_stock.price, "quantity": mock_stock.quantity, "action_type": mock_stock.action_type.value, - "created_at": mock_stock.created_at, - "updated_at": mock_stock.created_at, + "stock_type": mock_stock.stock_type.value, } stock_id = stock_repository.create(mock_stock) @@ -59,13 +59,15 @@ def test_create_stock(self, stock_repository): def test_create_multiple_stocks(self, stock_repository): # Create mock stocks + created_at = datetime.now(timezone.utc) mock_stock1 = CreateStock( user_id=1, symbol="TSLA", price=110.25, quantity=50, action_type=ActionType.BUY, - created_at=datetime.utcnow(), + stock_type=StockType.STOCKS, + created_at=created_at, ) mock_stock2 = CreateStock( user_id=1, @@ -73,7 +75,8 @@ def test_create_multiple_stocks(self, stock_repository): price=2500.50, quantity=10, action_type=ActionType.SELL, - created_at=datetime.utcnow(), + stock_type=StockType.STOCKS, + created_at=created_at, ) # Define expected results @@ -83,8 +86,7 @@ def test_create_multiple_stocks(self, stock_repository): "price": mock_stock1.price, "quantity": mock_stock1.quantity, "action_type": mock_stock1.action_type.value, - "created_at": mock_stock1.created_at, - "updated_at": mock_stock1.created_at, + "stock_type": mock_stock1.stock_type.value, } expected_result2 = { "user_id": mock_stock2.user_id, @@ -92,8 +94,7 @@ def test_create_multiple_stocks(self, stock_repository): "price": mock_stock2.price, "quantity": mock_stock2.quantity, "action_type": mock_stock2.action_type.value, - "created_at": mock_stock2.created_at, - "updated_at": mock_stock2.created_at, + "stock_type": mock_stock2.stock_type.value, } expected_data = [expected_result1, expected_result2] @@ -127,6 +128,7 @@ def test_list_stocks(self, stock_repository): "price": 150.25, "quantity": 100, "action_type": ActionType.BUY.value, + "stock_type": StockType.STOCKS.value, "created_at": created_at, "updated_at": created_at, } @@ -136,6 +138,7 @@ def test_list_stocks(self, stock_repository): "price": 110.25, "quantity": 50, "action_type": ActionType.SELL.value, + "stock_type": StockType.STOCKS.value, "created_at": created_at, "updated_at": created_at, } @@ -145,6 +148,7 @@ def test_list_stocks(self, stock_repository): "price": 2500.50, "quantity": 10, "action_type": ActionType.BUY.value, + "stock_type": StockType.STOCKS.value, "created_at": created_at, "updated_at": created_at, } @@ -167,6 +171,7 @@ def test_list_stocks(self, stock_repository): price=mock_stock1["price"], quantity=mock_stock1["quantity"], action_type=ActionType(mock_stock1["action_type"]), + stock_type=StockType(mock_stock1["stock_type"]), created_at=mock_stock1["created_at"], updated_at=mock_stock1["updated_at"], ), @@ -177,6 +182,7 @@ def test_list_stocks(self, stock_repository): price=mock_stock2["price"], quantity=mock_stock2["quantity"], action_type=ActionType(mock_stock2["action_type"]), + stock_type=StockType(mock_stock2["stock_type"]), created_at=mock_stock2["created_at"], updated_at=mock_stock2["updated_at"], ), @@ -198,6 +204,7 @@ def test_list_stocks(self, stock_repository): stock.price, stock.quantity, stock.action_type, + stock.stock_type, ) for stock in result } @@ -209,6 +216,7 @@ def test_list_stocks(self, stock_repository): stock.price, stock.quantity, stock.action_type, + stock.stock_type, ) for stock in expected_stocks } @@ -224,6 +232,7 @@ def test_list_stocks_no_data(self, stock_repository): "price": 150.25, "quantity": 100, "action_type": ActionType.BUY.value, + "stock_type": StockType.STOCKS.value, "created_at": created_at, "updated_at": created_at, } @@ -233,6 +242,7 @@ def test_list_stocks_no_data(self, stock_repository): "price": 110.25, "quantity": 50, "action_type": ActionType.SELL.value, + "stock_type": StockType.STOCKS.value, "created_at": created_at, "updated_at": created_at, } From b30e18c86a4f7f146ef683b6d4a83255bdd1db23 Mon Sep 17 00:00:00 2001 From: Eyo Chen Date: Sat, 14 Jun 2025 11:47:50 +0800 Subject: [PATCH 4/5] feat: add stock proto in usecase --- src/tests/test_stock_usecase.py | 179 +++++++++++++++++++++++++------- src/usecase/stock.py | 55 ++++++---- 2 files changed, 173 insertions(+), 61 deletions(-) diff --git a/src/tests/test_stock_usecase.py b/src/tests/test_stock_usecase.py index 0cf7996..4cf71ab 100644 --- a/src/tests/test_stock_usecase.py +++ b/src/tests/test_stock_usecase.py @@ -2,8 +2,9 @@ from datetime import datetime, timezone from unittest.mock import Mock, ANY, patch from usecase.stock import StockUsecase -from domain.stock import CreateStock, ActionType, Stock +from domain.stock import CreateStock, Stock from domain.portfolio import Portfolio, Holding +from domain.enum import ActionType, StockType @pytest.fixture @@ -21,10 +22,11 @@ def test_create_transfer_new_portfolio(self, stock_usecase): user_id, stock_id = 1, "123" stock = CreateStock( user_id=user_id, - symbol="", + symbol="TRANSFER", price=3000.0, quantity=1, action_type=ActionType.TRANSFER, + stock_type=StockType.STOCKS, created_at=ANY, ) portfolio_repo.get.return_value = None @@ -53,10 +55,11 @@ def test_create_transfer_existing_portfolio(self, stock_usecase): user_id, stock_id = 1, "123" stock = CreateStock( user_id=user_id, - symbol="", + symbol="TRANSFER", price=3000.0, quantity=1, action_type=ActionType.TRANSFER, + stock_type=StockType.STOCKS, created_at=ANY, ) existing_portfolio = Portfolio( @@ -96,7 +99,8 @@ def test_create_buy_new_stock(self, stock_usecase): symbol="TSLA", price=2000.0, quantity=2, - action_type=ActionType.BUY.value, + action_type=ActionType.BUY, + stock_type=StockType.STOCKS, created_at=ANY, ) existing_portfolio = Portfolio( @@ -113,7 +117,7 @@ def test_create_buy_new_stock(self, stock_usecase): user_id=existing_portfolio.user_id, cash_balance=1000.0, total_money_in=5000.0, - holdings=[Holding(symbol="TSLA", shares=2, total_cost=4000.0)], + holdings=[Holding(symbol="TSLA", shares=2, stock_type=StockType.STOCKS, total_cost=4000.0)], created_at=ANY, updated_at=ANY, ) @@ -136,14 +140,15 @@ def test_create_buy_existing_holding(self, stock_usecase): symbol="AAPL", price=150.0, quantity=3, - action_type=ActionType.BUY.value, + action_type=ActionType.BUY, + stock_type=StockType.STOCKS, created_at=ANY, ) existing_portfolio = Portfolio( user_id=user_id, cash_balance=1000.0, total_money_in=1000.0, - holdings=[Holding(symbol="AAPL", shares=5, total_cost=750.0)], + holdings=[Holding(symbol="AAPL", shares=5, stock_type=StockType.STOCKS, total_cost=750.0)], created_at=ANY, updated_at=ANY, ) @@ -153,7 +158,9 @@ def test_create_buy_existing_holding(self, stock_usecase): user_id=user_id, cash_balance=550.0, # 1000 - (150 * 3) total_money_in=1000.0, - holdings=[Holding(symbol="AAPL", shares=8, total_cost=1200.0)], # 750 + (150 * 3) + holdings=[ + Holding(symbol="AAPL", shares=8, stock_type=StockType.STOCKS, total_cost=1200.0) + ], # 750 + (150 * 3) created_at=ANY, updated_at=ANY, ) @@ -176,14 +183,15 @@ def test_create_sell_existing_holding_partial(self, stock_usecase): symbol="AAPL", price=200.0, quantity=2, - action_type=ActionType.SELL.value, + action_type=ActionType.SELL, + stock_type=StockType.STOCKS, created_at=ANY, ) existing_portfolio = Portfolio( user_id=user_id, cash_balance=1000.0, total_money_in=1000.0, - holdings=[Holding(symbol="AAPL", shares=5, total_cost=750.0)], + holdings=[Holding(symbol="AAPL", shares=5, stock_type=StockType.STOCKS, total_cost=750.0)], created_at=ANY, updated_at=ANY, ) @@ -193,7 +201,9 @@ def test_create_sell_existing_holding_partial(self, stock_usecase): user_id=user_id, cash_balance=1400.0, # 1000 + (200 * 2) total_money_in=1000.0, - holdings=[Holding(symbol="AAPL", shares=3, total_cost=450.0)], # 750 - (150 * 2) + holdings=[ + Holding(symbol="AAPL", shares=3, stock_type=StockType.STOCKS, total_cost=450.0) + ], # 750 - (150 * 2) created_at=ANY, updated_at=ANY, ) @@ -216,14 +226,15 @@ def test_create_sell_existing_holding_all_shares(self, stock_usecase): symbol="AAPL", price=300.0, quantity=5, - action_type=ActionType.SELL.value, + action_type=ActionType.SELL, + stock_type=StockType.STOCKS, created_at=ANY, ) existing_portfolio = Portfolio( user_id=user_id, cash_balance=1000.0, total_money_in=1000.0, - holdings=[Holding(symbol="AAPL", shares=5, total_cost=750.0)], + holdings=[Holding(symbol="AAPL", shares=5, stock_type=StockType.STOCKS, total_cost=750.0)], created_at=ANY, updated_at=ANY, ) @@ -256,7 +267,8 @@ def test_create_sell_non_existent_holding(self, stock_usecase): symbol="AAPL", price=150.0, quantity=5, - action_type=ActionType.SELL.value, + action_type=ActionType.SELL, + stock_type=StockType.STOCKS, created_at=ANY, ) existing_portfolio = Portfolio( @@ -286,7 +298,8 @@ def test_create_handles_repository_error_on_get(self, stock_usecase): symbol="AAPL", price=150.0, quantity=10, - action_type=ActionType.BUY.value, + action_type=ActionType.BUY, + stock_type=StockType.STOCKS, created_at=ANY, ) portfolio_repo.get.side_effect = Exception("Portfolio repository error") @@ -308,7 +321,8 @@ def test_create_handles_repository_error_on_update(self, stock_usecase): symbol="AAPL", price=150.0, quantity=10, - action_type=ActionType.BUY.value, + action_type=ActionType.BUY, + stock_type=StockType.STOCKS, created_at=ANY, ) existing_portfolio = Portfolio( @@ -339,7 +353,8 @@ def test_create_handles_repository_error_on_stock_create(self, stock_usecase): symbol="AAPL", price=150.0, quantity=10, - action_type=ActionType.BUY.value, + action_type=ActionType.BUY, + stock_type=StockType.STOCKS, created_at=ANY, ) existing_portfolio = Portfolio( @@ -364,7 +379,6 @@ def test_list(self, stock_usecase): # Arrange usecase, mock_repo, _ = stock_usecase user_id = 1 - created_at = datetime.now(timezone.utc) mock_stocks = [ Stock( id="stock_123", @@ -373,8 +387,9 @@ def test_list(self, stock_usecase): price=150.25, quantity=100, action_type=ActionType.BUY, - created_at=created_at, - updated_at=created_at, + stock_type=StockType.STOCKS, + created_at=ANY, + updated_at=ANY, ), Stock( id="stock_124", @@ -383,8 +398,9 @@ def test_list(self, stock_usecase): price=2800.50, quantity=10, action_type=ActionType.BUY, - created_at=created_at, - updated_at=created_at, + stock_type=StockType.STOCKS, + created_at=ANY, + updated_at=ANY, ), ] mock_repo.list.return_value = mock_stocks @@ -476,9 +492,11 @@ def test_calculate_total_roi_with_holdings(self, mock_get_stock_price, stock_use cash_balance=1000.0, total_money_in=5000.0, holdings=[ - Holding(symbol="AAPL", shares=10, total_cost=1500.0), - Holding(symbol="GOOGL", shares=5, total_cost=2000.0), - Holding(symbol="TSLA", shares=0, total_cost=0.0), # Zero shares, should be ignored + Holding(symbol="AAPL", shares=10, stock_type=StockType.STOCKS, total_cost=1500.0), + Holding(symbol="GOOGL", shares=5, stock_type=StockType.STOCKS, total_cost=2000.0), + Holding( + symbol="TSLA", shares=0, stock_type=StockType.STOCKS, total_cost=0.0 + ), # Zero shares, should be ignored ], created_at=ANY, updated_at=ANY, @@ -494,7 +512,9 @@ def test_calculate_total_roi_with_holdings(self, mock_get_stock_price, stock_use # Assert portfolio_repo.get.assert_called_once_with(user_id=user_id) - mock_get_stock_price.assert_called_once_with(symbols=["AAPL", "GOOGL"]) + mock_get_stock_price.assert_called_once_with( + stock_info=[("AAPL", StockType.STOCKS), ("GOOGL", StockType.STOCKS)] + ) # Total value = 2000 (AAPL) + 15000 (GOOGL) + 1000 (cash) = 18000 # ROI = ((18000 - 5000) / 5000) * 100 = 260.0 assert result == 260.0 @@ -509,8 +529,8 @@ def test_calculate_total_roi_with_missing_prices(self, mock_get_stock_price, sto cash_balance=1000.0, total_money_in=5000.0, holdings=[ - Holding(symbol="AAPL", shares=10, total_cost=1500.0), - Holding(symbol="INVALID", shares=5, total_cost=2000.0), + Holding(symbol="AAPL", shares=10, stock_type=StockType.STOCKS, total_cost=1500.0), + Holding(symbol="INVALID", shares=5, stock_type=StockType.STOCKS, total_cost=2000.0), ], created_at=ANY, updated_at=ANY, @@ -526,7 +546,9 @@ def test_calculate_total_roi_with_missing_prices(self, mock_get_stock_price, sto # Assert portfolio_repo.get.assert_called_once_with(user_id=user_id) - mock_get_stock_price.assert_called_once_with(symbols=["AAPL", "INVALID"]) + mock_get_stock_price.assert_called_once_with( + stock_info=[("AAPL", StockType.STOCKS), ("INVALID", StockType.STOCKS)] + ) # Total value = 2000 (AAPL) + 0 (INVALID) + 1000 (cash) = 3000 # ROI = ((3000 - 5000) / 5000) * 100 = -40.0 assert result == -40.0 @@ -543,10 +565,10 @@ def test_calculate_total_roi_handles_repository_error(self, stock_usecase): portfolio_repo.get.assert_called_once_with(user_id=user_id) @patch("usecase.stock.yf.Tickers") - def test_get_stock_price_success(self, mock_yf_tickers, stock_usecase): + def test_get_stock_price_success_stocks(self, mock_yf_tickers, stock_usecase): # Arrange usecase, _, _ = stock_usecase - symbols = ["AAPL", "GOOGL"] + stock_info = [("AAPL", StockType.STOCKS), ("GOOGL", StockType.STOCKS)] mock_ticker_aapl = Mock() mock_ticker_aapl.info = {"currentPrice": 150.0} mock_ticker_googl = Mock() @@ -557,20 +579,62 @@ def test_get_stock_price_success(self, mock_yf_tickers, stock_usecase): } # Act - result = usecase._get_stock_price(symbols) + result = usecase._get_stock_price(stock_info=stock_info) # Assert - mock_yf_tickers.assert_called_once_with(symbols) + mock_yf_tickers.assert_called_once_with(["AAPL", "GOOGL"]) assert result == {"AAPL": 150.0, "GOOGL": 2800.0} @patch("usecase.stock.yf.Tickers") - def test_get_stock_price_empty_symbols(self, mock_yf_tickers, stock_usecase): + def test_get_stock_price_success_etf(self, mock_yf_tickers, stock_usecase): # Arrange usecase, _, _ = stock_usecase - symbols = [] + stock_info = [("SPY", StockType.ETF), ("VTI", StockType.ETF)] + mock_ticker_spy = Mock() + mock_ticker_spy.info = {"navPrice": 400.0} + mock_ticker_vti = Mock() + mock_ticker_vti.info = {"navPrice": 200.0} + mock_yf_tickers.return_value.tickers = { + "SPY": mock_ticker_spy, + "VTI": mock_ticker_vti, + } # Act - result = usecase._get_stock_price(symbols) + result = usecase._get_stock_price(stock_info=stock_info) + + # Assert + mock_yf_tickers.assert_called_once_with(["SPY", "VTI"]) + assert result == {"SPY": 400.0, "VTI": 200.0} + + @patch("usecase.stock.yf.Tickers") + def test_get_stock_price_mixed_types(self, mock_yf_tickers, stock_usecase): + # Arrange + usecase, _, _ = stock_usecase + stock_info = [("AAPL", StockType.STOCKS), ("SPY", StockType.ETF)] + mock_ticker_aapl = Mock() + mock_ticker_aapl.info = {"currentPrice": 150.0} + mock_ticker_spy = Mock() + mock_ticker_spy.info = {"navPrice": 400.0} + mock_yf_tickers.return_value.tickers = { + "AAPL": mock_ticker_aapl, + "SPY": mock_ticker_spy, + } + + # Act + result = usecase._get_stock_price(stock_info=stock_info) + + # Assert + mock_yf_tickers.assert_called_once_with(["AAPL", "SPY"]) + assert result == {"AAPL": 150.0, "SPY": 400.0} + + @patch("usecase.stock.yf.Tickers") + def test_get_stock_price_empty_stock_info(self, mock_yf_tickers, stock_usecase): + # Arrange + usecase, _, _ = stock_usecase + stock_info = [] + + # Act + result = usecase._get_stock_price(stock_info=stock_info) # Assert mock_yf_tickers.assert_not_called() @@ -580,12 +644,49 @@ def test_get_stock_price_empty_symbols(self, mock_yf_tickers, stock_usecase): def test_get_stock_price_handles_api_error(self, mock_yf_tickers, stock_usecase): # Arrange usecase, _, _ = stock_usecase - symbols = ["AAPL", "GOOGL"] + stock_info = [("AAPL", StockType.STOCKS), ("GOOGL", StockType.STOCKS)] mock_yf_tickers.side_effect = Exception("API error") # Act - result = usecase._get_stock_price(symbols) + result = usecase._get_stock_price(stock_info=stock_info) # Assert - mock_yf_tickers.assert_called_once_with(symbols) + mock_yf_tickers.assert_called_once_with(["AAPL", "GOOGL"]) assert result == {"AAPL": 0.0, "GOOGL": 0.0} + + @patch("usecase.stock.yf.Tickers") + def test_get_stock_price_missing_ticker(self, mock_yf_tickers, stock_usecase): + # Arrange + usecase, _, _ = stock_usecase + stock_info = [("AAPL", StockType.STOCKS), ("INVALID", StockType.STOCKS)] + mock_ticker_aapl = Mock() + mock_ticker_aapl.info = {"currentPrice": 150.0} + mock_yf_tickers.return_value.tickers = {"AAPL": mock_ticker_aapl} + + # Act + result = usecase._get_stock_price(stock_info=stock_info) + + # Assert + mock_yf_tickers.assert_called_once_with(["AAPL", "INVALID"]) + assert result == {"AAPL": 150.0, "INVALID": 0.0} + + @patch("usecase.stock.yf.Tickers") + def test_get_stock_price_missing_price_field(self, mock_yf_tickers, stock_usecase): + # Arrange + usecase, _, _ = stock_usecase + stock_info = [("AAPL", StockType.STOCKS), ("SPY", StockType.ETF)] + mock_ticker_aapl = Mock() + mock_ticker_aapl.info = {} # No currentPrice + mock_ticker_spy = Mock() + mock_ticker_spy.info = {} # No navPrice + mock_yf_tickers.return_value.tickers = { + "AAPL": mock_ticker_aapl, + "SPY": mock_ticker_spy, + } + + # Act + result = usecase._get_stock_price(stock_info=stock_info) + + # Assert + mock_yf_tickers.assert_called_once_with(["AAPL", "SPY"]) + assert result == {"AAPL": 0.0, "SPY": 0.0} diff --git a/src/usecase/stock.py b/src/usecase/stock.py index 4584fe3..996ff17 100644 --- a/src/usecase/stock.py +++ b/src/usecase/stock.py @@ -1,10 +1,14 @@ -from typing import List, Dict +from typing import List, Dict, Tuple from datetime import datetime, timezone import yfinance as yf -from domain.stock import CreateStock, Stock, ActionType -from domain.portfolio import Portfolio, Holding -from adapters.base import AbstractStockRepository, AbstractPortfolioRepository from .base import AbstractStockUsecase +from adapters.base import AbstractStockRepository, AbstractPortfolioRepository +from domain.portfolio import Portfolio, Holding +from domain.stock import CreateStock, Stock +from domain.enum import ActionType, StockType + +ETF_KEY = "navPrice" +STOCK_KEY = "currentPrice" class StockUsecase(AbstractStockUsecase): @@ -28,7 +32,7 @@ def create(self, stock: CreateStock) -> str: symbol = stock.symbol price = stock.price quantity = stock.quantity - action_type = ActionType(stock.action_type) + action_type = stock.action_type if action_type == ActionType.TRANSFER: portfolio.cash_balance += price * quantity @@ -38,7 +42,7 @@ def create(self, stock: CreateStock) -> str: holding = next((h for h in portfolio.holdings if h.symbol == symbol), None) # Find the first holding if not holding: - holding = Holding(symbol=symbol, shares=0, total_cost=0.0) + holding = Holding(symbol=symbol, shares=0, stock_type=stock.stock_type, total_cost=0.0) portfolio.holdings.append(holding) holding.shares += quantity @@ -70,37 +74,44 @@ def calculate_total_roi(self, user_id: int) -> float: if portfolio is None or portfolio.total_money_in == 0.0: return 0.0 - valid_holdings = [(holding.symbol, holding.shares) for holding in portfolio.holdings if holding.shares > 0] + valid_holdings = [ + (holding.symbol, holding.shares, holding.stock_type) for holding in portfolio.holdings if holding.shares > 0 + ] if not valid_holdings: # If no valid holdings, ROI depends only on cash balance return round(((portfolio.cash_balance - portfolio.total_money_in) / portfolio.total_money_in) * 100, 2) # Fetch prices in batch - stock_symbols = [symbol for symbol, _ in valid_holdings] - stock_price_by_symbol = self._get_stock_price(symbols=stock_symbols) + stock_info = [(symbol, stock_type) for symbol, _, stock_type in valid_holdings] + stock_price_by_symbol = self._get_stock_price(stock_info=stock_info) # Calculate total stock value - total_stock_price = sum(shares * stock_price_by_symbol.get(symbol, 0.0) for symbol, shares in valid_holdings) + total_stock_price = sum(shares * stock_price_by_symbol.get(symbol, 0.0) for symbol, shares, _ in valid_holdings) # Compute ROI total_value = total_stock_price + portfolio.cash_balance roi = ((total_value - portfolio.total_money_in) / portfolio.total_money_in) * 100 return round(roi, 2) - def _get_stock_price(self, symbols: List[str]) -> Dict[str, float]: - if not symbols: + def _get_stock_price(self, stock_info: List[Tuple[str, StockType]]) -> Dict[str, float]: + if not stock_info: return {} try: + symbols = [symbol for symbol, _ in stock_info] tickers = yf.Tickers(symbols) - return { - symbol: ( - ticker.info.get("currentPrice", 0.0) - if (ticker := tickers.tickers.get(symbol.upper())) is not None - else 0.0 - ) - for symbol in symbols - } + stock_price_by_symbol = {} + + for symbol, stock_type in stock_info: + ticker = tickers.tickers.get(symbol.upper()) + if ticker is None: + stock_price_by_symbol[symbol] = 0.0 + continue + + price_field = STOCK_KEY if stock_type == StockType.STOCKS else ETF_KEY + stock_price_by_symbol[symbol] = ticker.info.get(price_field, 0.0) + + return stock_price_by_symbol except Exception as e: - print(f"Error fetching prices for symbols {symbols}: {e}") - return {symbol.upper(): 0.0 for symbol in symbols} + print(f"Error fetching prices for symbols {[symbol for symbol, _ in stock_info]}: {e}") + return {symbol: 0.0 for symbol, _ in stock_info}s \ No newline at end of file From d8f2893798ce556c43cdf055b130eadfff2ffc62 Mon Sep 17 00:00:00 2001 From: Eyo Chen Date: Sat, 14 Jun 2025 11:49:41 +0800 Subject: [PATCH 5/5] feat: add stock proto in handler --- src/handler/stock.py | 21 ++++++++------------- src/tests/test_stock_handler.py | 15 ++++++++++++--- src/usecase/stock.py | 2 +- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/src/handler/stock.py b/src/handler/stock.py index ca4090e..135995f 100644 --- a/src/handler/stock.py +++ b/src/handler/stock.py @@ -1,12 +1,11 @@ import logging from datetime import datetime, timezone - import grpc import proto.stock_pb2 as stock_pb2 import proto.stock_pb2_grpc as stock_pb2_grpc - from usecase.base import AbstractStockUsecase -from domain.stock import CreateStock, ActionType, ACTION_MAP +from domain.stock import CreateStock +from domain.enum import ActionType, ACTION_MAP, StockType, STOCK_MAP class StockService(stock_pb2_grpc.StockService): @@ -15,13 +14,13 @@ def __init__(self, stock_usecase: AbstractStockUsecase): def Create(self, request, context): try: - self._validate_create_request(request) stock = CreateStock( user_id=request.user_id, symbol=request.symbol, price=request.price, quantity=request.quantity, action_type=self._map_action_type(request.action), + stock_type=self._map_stock_type(request.stock_type), created_at=datetime.now(timezone.utc), ) @@ -64,15 +63,10 @@ def _map_action_type(self, action: int) -> ActionType: raise ValueError(f"Invalid action type: {action}. Must be 1 (BUY), 2 (SELL), or 3 (TRANSFER).") return ACTION_MAP[action] - def _validate_create_request(self, request): - if not request.user_id or request.user_id <= 0: - raise ValueError("user_id must be non-empty and greater than 0") - if not request.symbol or request.symbol.strip() == "": - raise ValueError("symbol must be a non-empty string") - if request.price <= 0: - raise ValueError("price must be greater than 0") - if request.quantity <= 0: - raise ValueError("quantity must be greater than 0") + def _map_stock_type(self, stock_type: int) -> StockType: + if stock_type not in STOCK_MAP: + raise ValueError(f"Invalid stock type: {stock_type}. Must be 1 (STOCKS), 2 (ETF).") + return STOCK_MAP[stock_type] def _convert_to_proto_stock_list(self, stock_list): return [ @@ -83,6 +77,7 @@ def _convert_to_proto_stock_list(self, stock_list): price=stock.price, quantity=stock.quantity, action=stock.action_type.value, + stock_type=stock.stock_type.value, created_at=stock.created_at, updated_at=stock.updated_at, ) diff --git a/src/tests/test_stock_handler.py b/src/tests/test_stock_handler.py index 53fccf5..a81ac61 100644 --- a/src/tests/test_stock_handler.py +++ b/src/tests/test_stock_handler.py @@ -1,13 +1,12 @@ import pytest import grpc import proto.stock_pb2 as stock_pb2 - from datetime import datetime, timezone from unittest.mock import Mock - -from domain.stock import CreateStock, ActionType, Stock from handler.stock import StockService from usecase.base import AbstractStockUsecase +from domain.stock import CreateStock, Stock +from domain.enum import ActionType, StockType class TestStockServiceCreate: @@ -35,6 +34,7 @@ def valid_request(self): request.price = 100.0 request.quantity = 10 request.action = 1 # Maps to ActionType.BUY + request.stock_type = 1 # Maps to StockType.STOCKS return request def test_success(self, mock_stock_usecase, mock_context, valid_request): @@ -68,6 +68,7 @@ def test_invalid_user_id(self, mock_stock_usecase, mock_context): request.price = 100.0 request.quantity = 10 request.action = 1 + request.stock_type = 1 # Act/Assertion with pytest.raises(grpc.RpcError) as exc_info: @@ -86,6 +87,7 @@ def test_invalid_symbol(self, mock_stock_usecase, mock_context): request.price = 100.0 request.quantity = 10 request.action = 1 + request.stock_type = 1 # Act/Assertion with pytest.raises(grpc.RpcError) as exc_info: @@ -104,6 +106,7 @@ def test_invalid_price(self, mock_stock_usecase, mock_context): request.price = 0.0 # Invalid request.quantity = 10 request.action = 1 + request.stock_type = 1 # Act/Assertion with pytest.raises(grpc.RpcError) as exc_info: @@ -122,6 +125,7 @@ def test_invalid_quantity(self, mock_stock_usecase, mock_context): request.price = 100.0 request.quantity = 0 # Invalid request.action = 1 + request.stock_type = 1 # Act/Assertion with pytest.raises(grpc.RpcError) as exc_info: @@ -140,6 +144,7 @@ def test_invalid_action(self, mock_stock_usecase, mock_context): request.price = 100.0 request.quantity = 10 request.action = 4 # Invalid + request.stock_type = 1 # Act/Assertion with pytest.raises(grpc.RpcError) as exc_info: @@ -180,6 +185,7 @@ def mock_stock_usecase(self): price=100.0, quantity=10, action_type=ActionType.BUY, + stock_type=StockType.STOCKS, created_at=datetime(2023, 1, 1, tzinfo=timezone.utc), updated_at=datetime(2023, 1, 1, tzinfo=timezone.utc), # Include if required ), @@ -190,6 +196,7 @@ def mock_stock_usecase(self): price=1500.0, quantity=5, action_type=ActionType.SELL, + stock_type=StockType.STOCKS, created_at=datetime(2023, 1, 2, tzinfo=timezone.utc), updated_at=datetime(2023, 1, 2, tzinfo=timezone.utc), # Include if required ), @@ -227,12 +234,14 @@ def test_success(self, mock_stock_usecase, mock_context, valid_request): assert response.stock_list[0].price == 100.0 assert response.stock_list[0].quantity == 10 assert response.stock_list[0].action == ActionType.BUY.value + assert response.stock_list[0].stock_type == StockType.STOCKS.value assert response.stock_list[1].id == "stock_124" assert response.stock_list[1].user_id == 1 assert response.stock_list[1].symbol == "GOOGL" assert response.stock_list[1].price == 1500.0 assert response.stock_list[1].quantity == 5 assert response.stock_list[1].action == ActionType.SELL.value + assert response.stock_list[1].stock_type == StockType.STOCKS.value mock_stock_usecase.list.assert_called_once_with(1) mock_context.set_code.assert_not_called() mock_context.set_details.assert_not_called() diff --git a/src/usecase/stock.py b/src/usecase/stock.py index 996ff17..83b0c89 100644 --- a/src/usecase/stock.py +++ b/src/usecase/stock.py @@ -114,4 +114,4 @@ def _get_stock_price(self, stock_info: List[Tuple[str, StockType]]) -> Dict[str, return stock_price_by_symbol except Exception as e: print(f"Error fetching prices for symbols {[symbol for symbol, _ in stock_info]}: {e}") - return {symbol: 0.0 for symbol, _ in stock_info}s \ No newline at end of file + return {symbol: 0.0 for symbol, _ in stock_info}