Compare commits

...

2 commits

Author SHA1 Message Date
sunwoo1524 b676ba7429 Feat: user blacklist 2024-11-03 18:10:56 +09:00
sunwoo1524 a7e1aae744 Refactor: user handle validation function 2024-11-03 17:44:18 +09:00
3 changed files with 32 additions and 18 deletions

15
database/blacklist.py Normal file
View file

@ -0,0 +1,15 @@
from typing import Optional
from sqlmodel import Field, SQLModel
import datetime
class BlacklistBase(SQLModel):
handle: str
class Blacklist(BlacklistBase, table=True):
id: int | None = Field(default=None, primary_key=True)
class BlacklistCreate(BlacklistBase):
pass

17
main.py
View file

@ -1,10 +1,11 @@
from fastapi import FastAPI, HTTPException, Response, Depends from fastapi import FastAPI, HTTPException, Response, Depends
from sqlmodel import Session, select from sqlmodel import Session, select
from typing import Annotated from typing import Annotated
import datetime import datetime, re
from database.feed import Feed, FeedCreate, FeedPublic, FeedUpdate from database.feed import Feed, FeedCreate, FeedPublic, FeedUpdate
from utils.feed_generator import generate_feed_of_user, USER_NOT_FOUND, CANNOT_ACCESS_INSTANCE, INVALID_HANDLE from database.blacklist import Blacklist, BlacklistCreate
from utils.feed_generator import generate_feed_of_user, USER_NOT_FOUND, CANNOT_ACCESS_INSTANCE
from utils.database import get_session, create_db_and_tables from utils.database import get_session, create_db_and_tables
app = FastAPI() app = FastAPI()
@ -18,6 +19,16 @@ def on_startup():
@app.get("/feed/{user_handle}") @app.get("/feed/{user_handle}")
def get_feed_of_user(user_handle: str, session: SessionDep): def get_feed_of_user(user_handle: str, session: SessionDep):
# validate user's handle
HANDLE_PATTERN = "@[a-zA-Z0-9_]+@[^\t\n\r\f\v]+"
if re.match(HANDLE_PATTERN, user_handle) is None:
return HTTPException(status_code=400, detail="The handle is invalid.")
# check if there is the user in blacklist
black_user = session.exec(select(Blacklist).where(Blacklist.handle == user_handle)).first
if black_user:
return HTTPException(status_code=401, detail="The user is in blacklist.")
# get feed on database # get feed on database
feed_db = session.exec(select(Feed).where(Feed.handle == user_handle)).first() feed_db = session.exec(select(Feed).where(Feed.handle == user_handle)).first()
@ -31,8 +42,6 @@ def get_feed_of_user(user_handle: str, session: SessionDep):
return HTTPException(status_code=404, detail="The user cannot be found.") return HTTPException(status_code=404, detail="The user cannot be found.")
if feed == CANNOT_ACCESS_INSTANCE: if feed == CANNOT_ACCESS_INSTANCE:
return HTTPException(status_code=400, detail="Cannot access the instance.") return HTTPException(status_code=400, detail="Cannot access the instance.")
if feed == INVALID_HANDLE:
return HTTPException(status_code=400, detail="The handle is invalid.")
# cache new feed # cache new feed
if feed_db: if feed_db:

View file

@ -6,21 +6,11 @@ CANNOT_ACCESS_INSTANCE = 2
INVALID_HANDLE = 3 INVALID_HANDLE = 3
def parse_handle(user_handle: str) -> list[str] | None:
# validate user's handle
HANDLE_PATTERN = "@[a-zA-Z0-9_]+@[^\t\n\r\f\v]+"
if re.match(HANDLE_PATTERN, user_handle) is None:
return None
# parse user's handle
return user_handle.split("@")[1:]
def get_statuses_of_user(user_handle: str) -> list[dict] | str: def get_statuses_of_user(user_handle: str) -> list[dict] | str:
parsed_handle = parse_handle(user_handle) # parsed_handle = parse_handle(user_handle)
if parsed_handle is None: # if parsed_handle is None:
return INVALID_HANDLE # return INVALID_HANDLE
[username, instance] = parsed_handle [username, instance] = user_handle.split("@")[1:]
try: try:
account_lookup = requests.get(f"https://{instance}/api/v1/accounts/lookup?acct={username}") account_lookup = requests.get(f"https://{instance}/api/v1/accounts/lookup?acct={username}")