Skip to content

Commit 8bd9e24

Browse files
eyo-chenEyo Chen
andauthored
Refactor: support etf type (#18)
* feat: add etf type in stock proto * feat: add etf type in domain * feat: add stock proto in adapter * feat: add stock proto in usecase * feat: add stock proto in handler --------- Co-authored-by: Eyo Chen <eyo.chen@amazingtalker.com>
1 parent 41f2310 commit 8bd9e24

14 files changed

Lines changed: 403 additions & 143 deletions

File tree

src/adapters/portfolio.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from dataclasses import asdict
21
from datetime import datetime, timezone
32
from pymongo import MongoClient
43
from pymongo.database import Database
54
from .base import AbstractPortfolioRepository
65
from domain.portfolio import Portfolio, Holding
6+
from domain.enum import StockType
77

88

99
class PortfolioRepository(AbstractPortfolioRepository):
@@ -22,7 +22,12 @@ def get(self, user_id: int) -> Portfolio:
2222
cash_balance=result["cash_balance"],
2323
total_money_in=result["total_money_in"],
2424
holdings=[
25-
Holding(symbol=holding["symbol"], shares=holding["shares"], total_cost=holding["total_cost"])
25+
Holding(
26+
symbol=holding["symbol"],
27+
shares=holding["shares"],
28+
stock_type=StockType(holding["stock_type"]),
29+
total_cost=holding["total_cost"],
30+
)
2631
for holding in result["holdings"]
2732
],
2833
created_at=result["created_at"],
@@ -31,7 +36,7 @@ def get(self, user_id: int) -> Portfolio:
3136

3237
def update(self, portfolio: Portfolio) -> None:
3338
portfolio.updated_at = datetime.now(timezone.utc)
34-
self.collection.replace_one({"user_id": portfolio.user_id}, asdict(portfolio), upsert=True)
39+
self.collection.replace_one({"user_id": portfolio.user_id}, portfolio.as_dict(), upsert=True)
3540

3641
def __del__(self):
3742
self.client.close()

src/adapters/stock.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from typing import List
2-
32
from pymongo import MongoClient
43
from pymongo.database import Database
5-
64
from .base import AbstractStockRepository
7-
from domain.stock import CreateStock, Stock, ActionType
5+
from domain.stock import CreateStock, Stock
6+
from domain.enum import ActionType, StockType
87

98

109
class StockRepository(AbstractStockRepository):
@@ -14,17 +13,7 @@ def __init__(self, mongo_client: MongoClient, database_name: str = "stock_db"):
1413
self.collection = self.db["stocks"]
1514

1615
def create(self, stock: CreateStock) -> str:
17-
stock_dict = {
18-
"user_id": stock.user_id,
19-
"symbol": stock.symbol,
20-
"price": stock.price,
21-
"quantity": stock.quantity,
22-
"action_type": stock.action_type.value,
23-
"created_at": stock.created_at,
24-
"updated_at": stock.created_at,
25-
}
26-
27-
result = self.collection.insert_one(stock_dict)
16+
result = self.collection.insert_one(stock.as_dict())
2817
return str(result.inserted_id)
2918

3019
def list(self, user_id: int) -> List[Stock]:
@@ -38,6 +27,7 @@ def list(self, user_id: int) -> List[Stock]:
3827
price=doc["price"],
3928
quantity=doc["quantity"],
4029
action_type=ActionType(doc["action_type"]),
30+
stock_type=StockType(doc["stock_type"]),
4131
created_at=doc["created_at"],
4232
updated_at=doc["updated_at"],
4333
)

src/domain/enum.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from enum import Enum
2+
3+
4+
class ActionType(Enum):
5+
BUY = "BUY"
6+
SELL = "SELL"
7+
TRANSFER = "TRANSFER"
8+
9+
10+
ACTION_MAP = {
11+
1: ActionType.BUY,
12+
2: ActionType.SELL,
13+
3: ActionType.TRANSFER,
14+
}
15+
16+
17+
class StockType(Enum):
18+
STOCKS = "STOCKS"
19+
ETF = "ETF"
20+
21+
22+
STOCK_MAP = {
23+
1: StockType.STOCKS,
24+
2: StockType.ETF,
25+
}

src/domain/portfolio.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,44 @@
1-
from dataclasses import dataclass
1+
from dataclasses import dataclass, asdict
22
from datetime import datetime
3-
from typing import List
3+
from typing import List, TypedDict
4+
from .enum import StockType
5+
from utils.utils import custom_dict_factory
6+
7+
8+
class HoldingDict(TypedDict):
9+
symbol: str
10+
shares: int
11+
stock_type: str
12+
total_cost: float
13+
14+
15+
class PortfolioDict(TypedDict):
16+
user_id: int
17+
cash_balance: float
18+
total_money_in: float
19+
holdings: List[HoldingDict]
20+
created_at: datetime
21+
updated_at: datetime
422

523

624
@dataclass
725
class Holding:
826
symbol: str
927
shares: int
28+
stock_type: StockType
1029
total_cost: float
1130

31+
def __post_init__(self):
32+
if not self.symbol:
33+
raise ValueError("symbol cannot be empty")
34+
if self.shares < 0:
35+
raise ValueError("shares cannot be negative")
36+
if self.total_cost < 0:
37+
raise ValueError("total_cost cannot be negative")
38+
39+
def as_dict(self) -> HoldingDict:
40+
return asdict(self, dict_factory=custom_dict_factory)
41+
1242

1343
@dataclass
1444
class Portfolio:
@@ -18,3 +48,12 @@ class Portfolio:
1848
holdings: List[Holding]
1949
created_at: datetime
2050
updated_at: datetime
51+
52+
def __post_init__(self):
53+
if self.cash_balance < 0:
54+
raise ValueError("cash_balance cannot be negative")
55+
if self.total_money_in < 0:
56+
raise ValueError("total_money_in cannot be negative")
57+
58+
def as_dict(self) -> PortfolioDict:
59+
return asdict(self, dict_factory=custom_dict_factory)

src/domain/stock.py

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,30 @@
1-
from dataclasses import dataclass
1+
from dataclasses import asdict, dataclass
2+
from typing import TypedDict
23
from datetime import datetime
3-
from enum import Enum
4+
from utils.utils import custom_dict_factory
5+
from .enum import ActionType, StockType
46

57

6-
class ActionType(Enum):
7-
BUY = "BUY"
8-
SELL = "SELL"
9-
TRANSFER = "TRANSFER"
8+
class CreateStockDict(TypedDict):
9+
user_id: int
10+
symbol: str
11+
price: float
12+
quantity: int
13+
action_type: str
14+
stock_type: str
15+
created_at: datetime
1016

1117

12-
ACTION_MAP = {1: ActionType.BUY, 2: ActionType.SELL, 3: ActionType.TRANSFER}
18+
class StockDict(TypedDict):
19+
id: str
20+
user_id: int
21+
symbol: str
22+
price: float
23+
quantity: int
24+
action_type: str
25+
stock_type: str
26+
created_at: datetime
27+
updated_at: datetime
1328

1429

1530
@dataclass
@@ -19,8 +34,22 @@ class CreateStock:
1934
price: float
2035
quantity: int
2136
action_type: ActionType
37+
stock_type: StockType
2238
created_at: datetime
2339

40+
def __post_init__(self):
41+
if not self.user_id or self.user_id <= 0:
42+
raise ValueError("user_id must be non-empty and greater than 0")
43+
if not self.symbol or self.symbol.strip() == "":
44+
raise ValueError("symbol must be a non-empty string")
45+
if self.price <= 0:
46+
raise ValueError("price must be greater than 0")
47+
if self.quantity <= 0:
48+
raise ValueError("quantity must be greater than 0")
49+
50+
def as_dict(self) -> CreateStockDict:
51+
return asdict(self, dict_factory=custom_dict_factory)
52+
2453

2554
@dataclass
2655
class Stock:
@@ -30,5 +59,21 @@ class Stock:
3059
price: float
3160
quantity: int
3261
action_type: ActionType
62+
stock_type: StockType
3363
created_at: datetime
3464
updated_at: datetime
65+
66+
def __post_init__(self):
67+
if not self.id:
68+
raise ValueError("id cannot be empty")
69+
if self.user_id <= 0:
70+
raise ValueError("user_id must be positive")
71+
if not self.symbol:
72+
raise ValueError("symbol cannot be empty")
73+
if self.price < 0:
74+
raise ValueError("price cannot be negative")
75+
if self.quantity <= 0:
76+
raise ValueError("quantity must be positive")
77+
78+
def as_dict(self) -> StockDict:
79+
return asdict(self, dict_factory=custom_dict_factory)

src/handler/stock.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import logging
22
from datetime import datetime, timezone
3-
43
import grpc
54
import proto.stock_pb2 as stock_pb2
65
import proto.stock_pb2_grpc as stock_pb2_grpc
7-
86
from usecase.base import AbstractStockUsecase
9-
from domain.stock import CreateStock, ActionType, ACTION_MAP
7+
from domain.stock import CreateStock
8+
from domain.enum import ActionType, ACTION_MAP, StockType, STOCK_MAP
109

1110

1211
class StockService(stock_pb2_grpc.StockService):
@@ -15,13 +14,13 @@ def __init__(self, stock_usecase: AbstractStockUsecase):
1514

1615
def Create(self, request, context):
1716
try:
18-
self._validate_create_request(request)
1917
stock = CreateStock(
2018
user_id=request.user_id,
2119
symbol=request.symbol,
2220
price=request.price,
2321
quantity=request.quantity,
2422
action_type=self._map_action_type(request.action),
23+
stock_type=self._map_stock_type(request.stock_type),
2524
created_at=datetime.now(timezone.utc),
2625
)
2726

@@ -64,15 +63,10 @@ def _map_action_type(self, action: int) -> ActionType:
6463
raise ValueError(f"Invalid action type: {action}. Must be 1 (BUY), 2 (SELL), or 3 (TRANSFER).")
6564
return ACTION_MAP[action]
6665

67-
def _validate_create_request(self, request):
68-
if not request.user_id or request.user_id <= 0:
69-
raise ValueError("user_id must be non-empty and greater than 0")
70-
if not request.symbol or request.symbol.strip() == "":
71-
raise ValueError("symbol must be a non-empty string")
72-
if request.price <= 0:
73-
raise ValueError("price must be greater than 0")
74-
if request.quantity <= 0:
75-
raise ValueError("quantity must be greater than 0")
66+
def _map_stock_type(self, stock_type: int) -> StockType:
67+
if stock_type not in STOCK_MAP:
68+
raise ValueError(f"Invalid stock type: {stock_type}. Must be 1 (STOCKS), 2 (ETF).")
69+
return STOCK_MAP[stock_type]
7670

7771
def _convert_to_proto_stock_list(self, stock_list):
7872
return [
@@ -83,6 +77,7 @@ def _convert_to_proto_stock_list(self, stock_list):
8377
price=stock.price,
8478
quantity=stock.quantity,
8579
action=stock.action_type.value,
80+
stock_type=stock.stock_type.value,
8681
created_at=stock.created_at,
8782
updated_at=stock.updated_at,
8883
)

src/proto/stock.proto

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,24 @@ message Action {
1313
}
1414
}
1515

16+
message StockType {
17+
enum Type {
18+
UNSPECIFIED = 0;
19+
STOCKS = 1;
20+
ETF = 2;
21+
}
22+
}
23+
1624
message Stock {
1725
string id = 1 [json_name = "id"];
1826
int32 user_id = 2 [json_name = "user_id"];
1927
string symbol = 3 [json_name = "symbol"];
2028
double price = 4 [json_name = "price"];
2129
int32 quantity = 5 [json_name = "quantity"];
2230
string action = 6 [json_name = "action"];
23-
google.protobuf.Timestamp created_at = 7 [json_name = "created_at"];
24-
google.protobuf.Timestamp updated_at = 8 [json_name = "updated_at"];
31+
string stock_type = 7 [json_name = "stock_type"];
32+
google.protobuf.Timestamp created_at = 8 [json_name = "created_at"];
33+
google.protobuf.Timestamp updated_at = 9 [json_name = "updated_at"];
2534
}
2635

2736
message CreateReq {
@@ -30,8 +39,9 @@ message CreateReq {
3039
double price = 3 [json_name = "price"];
3140
int32 quantity = 4 [json_name = "quantity"];
3241
Action.Type action = 5 [json_name = "action"]; // add validation rules
33-
google.protobuf.Timestamp created_at = 6 [json_name = "created_at"];
34-
google.protobuf.Timestamp updated_at = 7 [json_name = "updated_at"];
42+
StockType.Type stock_type = 6 [json_name = "stock_type"]; // add validation rules
43+
google.protobuf.Timestamp created_at = 7 [json_name = "created_at"];
44+
google.protobuf.Timestamp updated_at = 8 [json_name = "updated_at"];
3545
}
3646

3747
message CreateResp {

src/proto/stock_pb2.py

Lines changed: 17 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)