Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions server/activity_tracker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from datetime import datetime, timezone, timedelta
from typing import Callable, Dict
from typing import Callable, Dict, Optional

from server.dynamodb_helpers import TableContext

Expand All @@ -23,6 +23,7 @@ class PointsConfig:
POINTS_PER_MESSAGE = 0
POINTS_PER_SUCCESSFUL_INVITE = 0
DAILY_MESSAGE_LIMIT = 20
OPG_HOLDER_DAILY_MESSAGE_LIMIT = 100


class ActivityTracker:
Expand Down Expand Up @@ -63,7 +64,9 @@ def _cache_blocked_address(self, user_address: str):
"""
self._blocked_cache[user_address] = self._get_today_end()

async def increment_message_count(self, user_address: str) -> bool:
async def increment_message_count(
self, user_address: str, daily_limit: int
) -> bool:
"""
Increment the message count for a user.
Returns True if the message was counted, False if the daily limit was reached.
Expand All @@ -90,7 +93,7 @@ async def increment_message_count(self, user_address: str) -> bool:
daily_message_count = 0

# Check if daily limit reached
if daily_message_count >= PointsConfig.DAILY_MESSAGE_LIMIT:
if daily_message_count >= daily_limit:
# Cache this blocked address
self._cache_blocked_address(user_address)
return False
Expand Down Expand Up @@ -145,7 +148,9 @@ async def award_swap_points(self, user_address: str, points: int):
},
)

async def get_activity_stats(self, user_address: str) -> ActivityStats:
async def get_activity_stats(
self, user_address: str, daily_message_limit: int
) -> ActivityStats:
"""
Get the message count and successful invites count for a user.
Returns ActivityStats with 0 for both counts if the user doesn't exist.
Expand Down Expand Up @@ -175,7 +180,7 @@ async def get_activity_stats(self, user_address: str) -> ActivityStats:
successful_invites=successful_invites,
points=points,
daily_message_count=daily_message_count,
daily_message_limit=PointsConfig.DAILY_MESSAGE_LIMIT,
daily_message_limit=daily_message_limit,
rank=-1,
)

Expand All @@ -186,6 +191,6 @@ async def get_activity_stats(self, user_address: str) -> ActivityStats:
successful_invites=0,
points=0,
daily_message_count=0,
daily_message_limit=PointsConfig.DAILY_MESSAGE_LIMIT,
daily_message_limit=daily_message_limit,
rank=-1, # Return -1 for rank if there's an error
)
8 changes: 8 additions & 0 deletions server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,11 @@

# Use OG TEE flag for LLM inference
USE_TEE = os.getenv("USE_OG_TEE", "").lower() == "true"

# Base chain RPC for OPG token gating
BASE_RPC_URL: str = os.getenv(
"BASE_RPC_URL",
"https://responsive-attentive-panorama.base-mainnet.quiknode.pro/11a3fd4381ebfe3d6cef02189257575b0b4250cc/",
)
OPG_TOKEN_ADDRESS = "0xFbC2051AE2265686a469421b2C5A2D5462FbF5eB"
OPG_HOLDER_THRESHOLD = 1000 * 10**18 # raw units, 18 decimals
23 changes: 19 additions & 4 deletions server/fastapi_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,14 @@
)
from langchain_openai import ChatOpenAI
from server.invitecode import InviteCodeManager
from server.activity_tracker import ActivityTracker
from server.activity_tracker import ActivityTracker, PointsConfig
from server.utils import extract_patterns, convert_to_agent_msg
from server.dynamodb_helpers import DatabaseManager
from server.middleware import DatadogMetricsMiddleware
from server.swap_tracker import SwapTracker
from server.jup_validator import JUPValidator
from server.cow_validator import COWValidator
from server.opg_token_gate import OPGTokenGate

from . import service
from .auth import FirebaseIDTokenData, get_current_user
Expand Down Expand Up @@ -120,6 +121,7 @@ def create_fastapi_app() -> FastAPI:
)
jup_validator = JUPValidator(token_metadata_repo)
cow_validator = COWValidator(token_metadata_repo)
opg_gate = OPGTokenGate()

# Store services in app state for access in routes
app.state.activity_tracker = activity_tracker
Expand Down Expand Up @@ -215,6 +217,11 @@ async def verify_captcha_token(captchaToken: str):
logging.error(f"Captcha verification failed: {result}")
return False

async def _get_daily_limit(address: str | None) -> int:
if address and await opg_gate.is_opg_holder(address):
return PointsConfig.OPG_HOLDER_DAILY_MESSAGE_LIMIT
return PointsConfig.DAILY_MESSAGE_LIMIT

# Routes
@app.post("/api/cloudflare/turnstile/v0/siteverify")
async def verify_cloudflare_turnstile_token(request: Request):
Expand Down Expand Up @@ -320,8 +327,9 @@ async def run_agent(
raise HTTPException(status_code=400, detail="Captcha token is required")

# Increment message count, return 429 if limit reached
daily_limit = await _get_daily_limit(agent_request.context.address)
if not await activity_tracker.increment_message_count(
agent_request.context.address
agent_request.context.address, daily_limit=daily_limit
):
statsd.increment("agent.message.daily_limit_reached")
raise HTTPException(status_code=429, detail="Daily message limit reached")
Expand Down Expand Up @@ -361,7 +369,10 @@ async def run_suggestions(
# raise HTTPException(status_code=429, detail="Invalid captcha token")

# Check if user has reached daily message limit (without incrementing)
stats = await activity_tracker.get_activity_stats(agent_request.context.address)
daily_limit = await _get_daily_limit(agent_request.context.address)
stats = await activity_tracker.get_activity_stats(
agent_request.context.address, daily_message_limit=daily_limit
)
if stats.daily_message_count >= stats.daily_message_limit:
statsd.increment("agent.suggestions.daily_limit_reached")
raise HTTPException(status_code=429, detail="Daily message limit reached")
Expand Down Expand Up @@ -413,6 +424,7 @@ async def use_invite_code(request: Request):
@app.get("/api/activity/stats")
async def get_activity_stats(
address: str,
evm_address: str = None,
user: FirebaseIDTokenData = Depends(get_current_user),
):
try:
Expand All @@ -421,7 +433,10 @@ async def get_activity_stats(
status_code=400, detail="Address parameter is required"
)

stats = await activity_tracker.get_activity_stats(address)
daily_limit = await _get_daily_limit(evm_address)
stats = await activity_tracker.get_activity_stats(
address, daily_message_limit=daily_limit
)
return stats
except Exception as e:
logging.error(
Expand Down
63 changes: 63 additions & 0 deletions server/opg_token_gate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import logging
import asyncio

from web3 import AsyncWeb3
from web3.providers import AsyncHTTPProvider
from cachetools import TTLCache

from server.config import BASE_RPC_URL, OPG_TOKEN_ADDRESS, OPG_HOLDER_THRESHOLD

ERC20_BALANCE_OF_ABI = [
{
"name": "balanceOf",
"type": "function",
"stateMutability": "view",
"inputs": [{"name": "account", "type": "address"}],
"outputs": [{"name": "", "type": "uint256"}],
}
]


class OPGTokenGate:
"""Checks whether an EVM address holds enough $OPG tokens on Base."""

def __init__(self):
self.w3 = AsyncWeb3(AsyncHTTPProvider(BASE_RPC_URL))
self.contract = self.w3.eth.contract(
address=AsyncWeb3.to_checksum_address(OPG_TOKEN_ADDRESS),
abi=ERC20_BALANCE_OF_ABI,
)
self.threshold = OPG_HOLDER_THRESHOLD
self._cache: TTLCache = TTLCache(maxsize=4096, ttl=300)
self._lock = asyncio.Lock()

async def is_opg_holder(self, evm_address: str) -> bool:
"""Return True if evm_address holds >= threshold $OPG on Base.

Fail-closed: returns False on any error so users get default limits.
"""
try:
# Validate EVM address
checksum = AsyncWeb3.to_checksum_address(evm_address)
except Exception:
return False

# Check cache
cached = self._cache.get(checksum)
if cached is not None:
return cached

async with self._lock:
# Double-check after acquiring lock
cached = self._cache.get(checksum)
if cached is not None:
return cached

try:
balance = await self.contract.functions.balanceOf(checksum).call()
result = balance >= self.threshold
self._cache[checksum] = result
return result
except Exception as e:
logging.warning(f"OPG balance check failed for {checksum}: {e}")
return False