Compare commits

..

No commits in common. "b676ba742913e8e4a3f984bb2ef865089fa7579a" and "98b4954065dbad26b19b7066a16b6876e829b8e2" have entirely different histories.

3 changed files with 18 additions and 32 deletions

View file

@ -1,15 +0,0 @@
from typing import Optional
from sqlmodel import Field, SQLModel
import datetime
class BlacklistBase(SQLModel):
handle: str
class Blacklist(BlacklistBase, table=True):
id: int | None = Field(default=None, primary_key=True)
class BlacklistCreate(BlacklistBase):
pass

17
main.py
View file

@ -1,11 +1,10 @@
from fastapi import FastAPI, HTTPException, Response, Depends from fastapi import FastAPI, HTTPException, Response, Depends
from sqlmodel import Session, select from sqlmodel import Session, select
from typing import Annotated from typing import Annotated
import datetime, re import datetime
from database.feed import Feed, FeedCreate, FeedPublic, FeedUpdate from database.feed import Feed, FeedCreate, FeedPublic, FeedUpdate
from database.blacklist import Blacklist, BlacklistCreate from utils.feed_generator import generate_feed_of_user, USER_NOT_FOUND, CANNOT_ACCESS_INSTANCE, INVALID_HANDLE
from utils.feed_generator import generate_feed_of_user, USER_NOT_FOUND, CANNOT_ACCESS_INSTANCE
from utils.database import get_session, create_db_and_tables from utils.database import get_session, create_db_and_tables
app = FastAPI() app = FastAPI()
@ -19,16 +18,6 @@ def on_startup():
@app.get("/feed/{user_handle}") @app.get("/feed/{user_handle}")
def get_feed_of_user(user_handle: str, session: SessionDep): def get_feed_of_user(user_handle: str, session: SessionDep):
# validate user's handle
HANDLE_PATTERN = "@[a-zA-Z0-9_]+@[^\t\n\r\f\v]+"
if re.match(HANDLE_PATTERN, user_handle) is None:
return HTTPException(status_code=400, detail="The handle is invalid.")
# check if there is the user in blacklist
black_user = session.exec(select(Blacklist).where(Blacklist.handle == user_handle)).first
if black_user:
return HTTPException(status_code=401, detail="The user is in blacklist.")
# get feed on database # get feed on database
feed_db = session.exec(select(Feed).where(Feed.handle == user_handle)).first() feed_db = session.exec(select(Feed).where(Feed.handle == user_handle)).first()
@ -42,6 +31,8 @@ def get_feed_of_user(user_handle: str, session: SessionDep):
return HTTPException(status_code=404, detail="The user cannot be found.") return HTTPException(status_code=404, detail="The user cannot be found.")
if feed == CANNOT_ACCESS_INSTANCE: if feed == CANNOT_ACCESS_INSTANCE:
return HTTPException(status_code=400, detail="Cannot access the instance.") return HTTPException(status_code=400, detail="Cannot access the instance.")
if feed == INVALID_HANDLE:
return HTTPException(status_code=400, detail="The handle is invalid.")
# cache new feed # cache new feed
if feed_db: if feed_db:

View file

@ -6,11 +6,21 @@ CANNOT_ACCESS_INSTANCE = 2
INVALID_HANDLE = 3 INVALID_HANDLE = 3
def parse_handle(user_handle: str) -> list[str] | None:
# validate user's handle
HANDLE_PATTERN = "@[a-zA-Z0-9_]+@[^\t\n\r\f\v]+"
if re.match(HANDLE_PATTERN, user_handle) is None:
return None
# parse user's handle
return user_handle.split("@")[1:]
def get_statuses_of_user(user_handle: str) -> list[dict] | str: def get_statuses_of_user(user_handle: str) -> list[dict] | str:
# parsed_handle = parse_handle(user_handle) parsed_handle = parse_handle(user_handle)
# if parsed_handle is None: if parsed_handle is None:
# return INVALID_HANDLE return INVALID_HANDLE
[username, instance] = user_handle.split("@")[1:] [username, instance] = parsed_handle
try: try:
account_lookup = requests.get(f"https://{instance}/api/v1/accounts/lookup?acct={username}") account_lookup = requests.get(f"https://{instance}/api/v1/accounts/lookup?acct={username}")