import datetime
import re
from typing import Annotated

from fastapi import Depends, FastAPI, HTTPException, Response
from sqlmodel import Session, select

from database.blacklist import Blacklist, BlacklistCreate
from database.feed import Feed, FeedCreate, FeedPublic, FeedUpdate
from utils.database import create_db_and_tables, get_session
from utils.feed_generator import CANNOT_ACCESS_INSTANCE, USER_NOT_FOUND, generate_feed_of_user

app = FastAPI()

SessionDep = Annotated[Session, Depends(get_session)]

# A handle of the form "@user@instance". The user part is restricted to
# letters, digits and "_"; the instance part may be anything except
# whitespace and "@". Matched with fullmatch so trailing junk is rejected.
HANDLE_PATTERN = re.compile(r"@[a-zA-Z0-9_]+@[^\s@]+")

# How long a generated feed is served from the database before regenerating.
CACHE_TTL = datetime.timedelta(minutes=10)

FEED_MEDIA_TYPE = "application/xml"


@app.on_event("startup")
def on_startup():
    """Create the database tables when the application starts."""
    create_db_and_tables()


@app.get("/feed/{user_handle}")
def get_feed_of_user(user_handle: str, session: SessionDep):
    """Return the XML feed for ``user_handle``, cached in the database.

    A cached feed younger than ``CACHE_TTL`` is returned as-is; otherwise the
    feed is regenerated with ``generate_feed_of_user`` and the cache row is
    created or updated.

    Raises:
        HTTPException(400): the handle is not of the form "@user@instance".
        HTTPException(403): the handle is in the blacklist.
        HTTPException(404): the feed generator reports the user was not found.
        HTTPException(502): the feed generator cannot reach the user's instance.
    """
    # HTTPException must be *raised*; returning it would send a 200 response.
    if HANDLE_PATTERN.fullmatch(user_handle) is None:
        raise HTTPException(status_code=400, detail="The handle is invalid.")

    blacklisted = session.exec(
        select(Blacklist).where(Blacklist.handle == user_handle)
    ).first()
    if blacklisted:
        # 403 (forbidden), not 401: no credentials would make this succeed.
        raise HTTPException(status_code=403, detail="The user is in blacklist.")

    feed_db = session.exec(select(Feed).where(Feed.handle == user_handle)).first()

    # Serve the cached copy while it is still fresh.
    # NOTE(review): compares against naive local time; assumes updated_at is
    # stored with the same naive clock — confirm against the Feed model.
    if feed_db and datetime.datetime.now() - feed_db.updated_at < CACHE_TTL:
        return Response(content=feed_db.feed, media_type=FEED_MEDIA_TYPE)

    feed = generate_feed_of_user(user_handle)
    if feed == USER_NOT_FOUND:
        raise HTTPException(status_code=404, detail="The user cannot be found.")
    if feed == CANNOT_ACCESS_INSTANCE:
        # The upstream instance failed, not the client's request.
        raise HTTPException(status_code=502, detail="Cannot access the instance.")

    # Store the freshly generated feed.
    if feed_db:
        update = FeedUpdate(feed=feed, updated_at=datetime.datetime.now())
        feed_db.sqlmodel_update(update.model_dump(exclude_unset=True))
    else:
        # NOTE(review): updated_at for a new row presumably comes from a model
        # default not visible here — confirm it is populated.
        feed_db = Feed.model_validate(FeedCreate(handle=user_handle, feed=feed))
    session.add(feed_db)
    session.commit()

    return Response(content=feed, media_type=FEED_MEDIA_TYPE)