dnd-srd-api/app/main.py
Cupcake f6858e6ea1 feat: D&D SRD 5.2 API — FastAPI app with flat JSON caching
- Data fetched from Open5e API: 3,215 items (339 spells, 330 creatures,
  24 classes, 2,319 magic items, 203 equipment)
- FastAPI app with API key auth (X-API-Key header or ?api_key= param)
- Sliding window rate limiting (60 req/min, 10K req/day)
- Dice rolling endpoint (e.g., /api/dice/roll?spec=2d20+5)
- Full-text search across all resource types
- Pagination, filtering (name, level, school, class, etc.)
- Admin CLI for API key management
- nginx + systemd service ready for deployment
2026-06-03 18:13:00 +00:00

537 lines
No EOL
21 KiB
Python

"""
D&D SRD 5.2 API — FastAPI application
Serves SRD data from flat JSON files with API key auth + rate limiting.
"""
import json
import os
import random
import re
import threading
import time
from pathlib import Path
from typing import Optional
from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
# ── Paths ───────────────────────────────────────────────────────────────────
# Project root: two directory levels above this module (app/main.py -> root).
BASE_DIR = Path(__file__).resolve().parent.parent
# Every on-disk artefact (SRD collections, API keys, rate limits) sits under data/.
DATA_DIR = BASE_DIR.joinpath("data")
API_KEYS_FILE = DATA_DIR.joinpath("api_keys.json")
RATE_LIMITS_FILE = DATA_DIR.joinpath("rate_limits.json")  # NOTE(review): not referenced by the visible code
# ── Data Loading ────────────────────────────────────────────────────────────
# time.time() of the most recent load_data() call; 0 means "never loaded".
_last_load = 0
# Guards _data_cache. It is a plain threading.Lock, so it is NOT reentrant.
_data_lock = threading.Lock()
# Collection name -> parsed JSON document (expected to carry "count" and
# "results" keys); replaced wholesale by load_data().
_data_cache = dict()
def load_data():
    """Load all JSON data files into memory cache.

    Reads the five collection files under DATA_DIR. A missing file is reported
    on stdout and stands in as an empty collection ``{"count": 0, "results": []}``.
    The new cache is built without holding the lock, then swapped into
    ``_data_cache`` under ``_data_lock``, and ``_last_load`` is stamped.

    Returns:
        The freshly built cache dict (collection name -> parsed JSON document).

    Raises:
        json.JSONDecodeError: if a data file is not valid JSON.
        KeyError: if a loaded document has no ``"count"`` key (used in the log line).
    """
    global _last_load
    cache = {}
    files = {
        "spells": DATA_DIR / "spells.json",
        "creatures": DATA_DIR / "creatures.json",
        "classes": DATA_DIR / "classes.json",
        "magic_items": DATA_DIR / "magic-items.json",
        "equipment": DATA_DIR / "equipment.json",
    }
    for key, path in files.items():
        if path.exists():
            # JSON is UTF-8 by definition (RFC 8259). Be explicit rather than
            # depending on the process locale's default encoding.
            with open(path, encoding="utf-8") as f:
                cache[key] = json.load(f)
            print(f"  📖 Loaded {key}: {cache[key]['count']} items")
        else:
            print(f"  ⚠ Data file not found: {path}")
            cache[key] = {"count": 0, "results": []}
    # Swap in under the lock so readers never observe a half-cleared cache.
    with _data_lock:
        _data_cache.clear()
        _data_cache.update(cache)
        _last_load = time.time()
    return cache
def get_data(collection: str) -> dict:
    """Get a data collection, loading the cache first if it lacks that collection.

    There is no time-based staleness check. ``load_data()`` runs only when the
    cache is empty or does not contain *collection*.

    Args:
        collection: Cache key, e.g. ``"spells"`` or ``"magic_items"``.

    Returns:
        The cached document for *collection*, or ``{"count": 0, "results": []}``
        if it is still absent after loading.
    """
    with _data_lock:
        needs_load = collection not in _data_cache
    if needs_load:
        # load_data() acquires _data_lock itself, and that lock is a plain
        # (non-reentrant) threading.Lock. Calling it while holding the lock
        # would deadlock, so it runs with the lock released.
        load_data()
    with _data_lock:
        return _data_cache.get(collection, {"count": 0, "results": []})
# ── API Keys & Rate Limiting ────────────────────────────────────────────────
def load_api_keys() -> dict:
    """Return the API-key registry read from API_KEYS_FILE.

    If the file does not exist, returns an empty registry ``{"keys": {}}``.
    The file is re-read on every call; nothing is cached.
    """
    if not API_KEYS_FILE.exists():
        return {"keys": {}}
    with open(API_KEYS_FILE) as handle:
        registry = json.load(handle)
    return registry
def save_api_keys(keys: dict):
    """Save API keys to JSON file, atomically.

    The registry is written to a temporary sibling file, which then replaces
    API_KEYS_FILE via ``os.replace``. A concurrent reader, such as
    ``load_api_keys()`` (called by the auth middleware on every request),
    therefore sees either the old or the new file, never a truncated one.
    Permission bits of an existing file are preserved. Because the file is
    replaced rather than rewritten in place, it ends up owned by the writing user.

    Args:
        keys: The full registry, e.g. ``{"keys": {...}}``; serialised with indent=2.
    """
    import shutil

    API_KEYS_FILE.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = API_KEYS_FILE.with_name(f"{API_KEYS_FILE.name}.{os.getpid()}.tmp")
    try:
        with open(tmp_path, "w") as f:
            json.dump(keys, f, indent=2)
        if API_KEYS_FILE.exists():
            shutil.copymode(API_KEYS_FILE, tmp_path)
        os.replace(tmp_path, API_KEYS_FILE)
    except BaseException:
        # Don't leave a stray temp file behind on failure.
        try:
            os.unlink(tmp_path)
        except FileNotFoundError:
            pass
        raise
# Rate limiter: sliding window per key.
# RATE_LIMIT_CONFIG holds the limits enforced by check_rate_limit().
# _rate_limits maps each API key to the time.time() stamps of its accepted
# requests. Nothing in this module persists it; it is guarded by _rate_lock.
RATE_LIMIT_CONFIG = dict(requests_per_minute=60, requests_per_day=10000)
_rate_lock = threading.Lock()
_rate_limits = dict()  # api key -> list of request timestamps
def check_rate_limit(api_key: str) -> Optional[JSONResponse]:
    """Check if request is within rate limits. Returns error response if exceeded.

    Enforces two sliding windows over the same per-key timestamp list:
    ``requests_per_minute`` over the last 60 s and ``requests_per_day`` over
    the last 86 400 s. Only allowed requests are recorded.

    Args:
        api_key: The (already validated) key making the request.

    Returns:
        ``None`` if the request is allowed (its timestamp is then recorded),
        otherwise a 429 ``JSONResponse`` naming the exceeded limit.
    """
    now = time.time()
    minute_cutoff = now - 60
    day_cutoff = now - 86400
    with _rate_lock:
        timestamps = _rate_limits.setdefault(api_key, [])
        # Keep a full day of history. Pruning to the last minute meant the
        # "daily" count could never exceed the per-minute limit, so the daily
        # limit was never enforced. The list stays bounded by requests_per_day.
        timestamps[:] = [t for t in timestamps if t > day_cutoff]
        # Check per-minute limit
        minute_count = sum(1 for t in timestamps if t > minute_cutoff)
        if minute_count >= RATE_LIMIT_CONFIG["requests_per_minute"]:
            return JSONResponse(
                status_code=429,
                content={
                    "error": "Rate limit exceeded",
                    "message": f"Max {RATE_LIMIT_CONFIG['requests_per_minute']} requests per minute",
                    "retry_after": 60,
                },
            )
        # Check per-day limit (everything left after pruning is within 24h)
        if len(timestamps) >= RATE_LIMIT_CONFIG["requests_per_day"]:
            return JSONResponse(
                status_code=429,
                content={
                    "error": "Daily rate limit exceeded",
                    "message": f"Max {RATE_LIMIT_CONFIG['requests_per_day']} requests per day",
                },
            )
        timestamps.append(now)
        return None
# Placeholder predicate: this module never writes a rate-limit archive, so it
# unconditionally reports that none exists.
def _has_rate_limit_archive(coll: str) -> bool:
    """Return whether a rate-limit archive exists for *coll* (always False)."""
    del coll  # unused; kept for signature compatibility
    return bool(0)
# ── Auth Middleware ──────────────────────────────────────────────────────────
# Paths reachable without an API key. A request path is exempt when it equals
# one of these or lies beneath it (e.g. "/api/docs/oauth2-redirect").
AUTH_EXEMPT_PATHS = {"/api/docs", "/api/openapi.json", "/api/redoc", "/health"}


async def auth_middleware(request: Request, call_next):
    """API key authentication middleware.

    Passes through without a key:
      * CORS preflight requests;
      * the exact path ``/api``;
      * anything at or beneath an entry of ``AUTH_EXEMPT_PATHS``.

    Every other request must supply a key via the ``X-API-Key`` header or the
    ``api_key`` query parameter. Responses:
      * 401 if no key was given;
      * 403 if the key is not in ``load_api_keys()["keys"]``;
      * the 429 built by ``check_rate_limit()`` if over the limit;
      * otherwise the downstream response.
    """
    path = request.url.path
    # Browsers send CORS preflights (OPTIONS + Access-Control-Request-Method)
    # without custom headers like X-API-Key. Requiring a key would make
    # header-based auth unusable from any browser origin, so hand preflights
    # on and let the CORS middleware answer them.
    if request.method == "OPTIONS" and "access-control-request-method" in request.headers:
        return await call_next(request)
    # Skip auth for docs and health. Match whole path segments: a bare prefix
    # test would also exempt unrelated paths such as "/healthz".
    if path == "/api" or any(path == p or path.startswith(p + "/") for p in AUTH_EXEMPT_PATHS):
        return await call_next(request)
    # Get API key from header or query param
    api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")
    if not api_key:
        return JSONResponse(
            status_code=401,
            content={"error": "Missing API key", "message": "Provide via X-API-Key header or ?api_key= parameter"},
        )
    keys = load_api_keys()
    if api_key not in keys.get("keys", {}):
        return JSONResponse(
            status_code=403,
            content={"error": "Invalid API key", "message": "The provided API key is not valid"},
        )
    # Check rate limit
    rate_check = check_rate_limit(api_key)
    if rate_check:
        return rate_check
    return await call_next(request)
# ── Filtering Helpers ────────────────────────────────────────────────────────
def filter_spells(spells: list, params: dict) -> list:
    """Apply filters to spells list.

    Each (key, value) pair in *params* narrows the result in turn; unknown keys
    are ignored. Values may be strings (as in a raw query string) or already
    typed values:

    - ``name``, ``school``: case-insensitive substring match on that field.
    - ``level``: exact match after ``int(value)``; unparseable values are ignored.
    - ``class``: case-insensitive substring match against any entry of ``classes``.
    - ``concentration``, ``ritual``: a bool, or a string where
      "true"/"1"/"yes" (any case) means True and anything else False.
    - ``search``: case-insensitive substring match on ``name`` or ``description``.

    Missing or null (None) text fields are treated as empty strings.
    """

    def text(spell, field):
        # Missing and null fields both read as "", so None never hits .lower().
        return (spell.get(field) or "").lower()

    def as_bool(value):
        # list_spells declares these as Optional[bool], so real bools can
        # arrive here. Calling .lower() on them used to raise AttributeError.
        if isinstance(value, bool):
            return value
        return str(value).lower() in ("true", "1", "yes")

    result = spells
    for key, value in params.items():
        if key == "name":
            needle = str(value).lower()
            result = [s for s in result if needle in text(s, "name")]
        elif key == "level":
            try:
                level = int(value)
            except (TypeError, ValueError):
                continue
            result = [s for s in result if s.get("level") == level]
        elif key == "school":
            needle = str(value).lower()
            result = [s for s in result if needle in text(s, "school")]
        elif key == "class":
            needle = str(value).lower()
            result = [
                s for s in result
                if any(needle in (c or "").lower() for c in (s.get("classes") or []))
            ]
        elif key in ("concentration", "ritual"):
            wanted = as_bool(value)
            result = [s for s in result if s.get(key) == wanted]
        elif key == "search":
            needle = str(value).lower()
            result = [s for s in result if needle in text(s, "name") or needle in text(s, "description")]
    return result
def filter_creatures(creatures: list, params: dict) -> list:
    """Apply filters to creatures list.

    Each (key, value) pair in *params* narrows the result in turn; unknown keys
    are ignored.

    - ``name``, ``type``, ``size``, ``alignment``: case-insensitive substring
      match on that field.
    - ``cr`` / ``challenge_rating``: matches when ``str(challenge_rating)``
      equals the value exactly, or when both are plain CR literals with the
      same numeric value. So "1/4", "0.25" and a stored 0.25 all match each
      other, and "2" matches a stored 2.0.
    - ``search``: case-insensitive substring match on ``name`` or ``description``.

    Missing or null (None) text fields are treated as empty strings.
    """
    from fractions import Fraction

    def text(creature, field):
        # Missing and null fields both read as "", so None never hits .lower().
        return (creature.get(field) or "").lower()

    def as_cr(raw):
        # Parse "2", "0.25" or "1/4" into an exact Fraction; return None for
        # anything else. The strict, length-bounded pattern keeps untrusted
        # query input (e.g. "1e99999999") away from Fraction's exponent
        # handling, which could build enormous integers.
        literal = str(raw).strip()
        if not re.fullmatch(r"\d{1,3}(?:\.\d{1,6}|/\d{1,3})?", literal):
            return None
        try:
            return Fraction(literal)
        except ZeroDivisionError:
            return None

    result = creatures
    for key, value in params.items():
        if key in ("name", "type", "size", "alignment"):
            needle = str(value).lower()
            result = [c for c in result if needle in text(c, key)]
        elif key in ("cr", "challenge_rating"):
            wanted = str(value)
            wanted_cr = as_cr(wanted)
            result = [
                c for c in result
                if str(c.get("challenge_rating", "")) == wanted
                or (wanted_cr is not None and as_cr(c.get("challenge_rating", "")) == wanted_cr)
            ]
        elif key == "search":
            needle = str(value).lower()
            result = [c for c in result if needle in text(c, "name") or needle in text(c, "description")]
    return result
def filter_items(items: list, params: dict) -> list:
    """Apply filters to items list (magic items and equipment).

    Each (key, value) pair in *params* narrows the result in turn; unknown keys
    are ignored.

    - ``name``, ``type``, ``rarity``: case-insensitive substring match on that field.
    - ``search``: case-insensitive substring match on ``name`` or ``description``.

    Missing or null (None) text fields are treated as empty strings.
    """

    def text(item, field):
        # Missing and null fields both read as "", so None never hits .lower().
        return (item.get(field) or "").lower()

    result = items
    for key, value in params.items():
        if key in ("name", "type", "rarity"):
            needle = str(value).lower()
            result = [i for i in result if needle in text(i, key)]
        elif key == "search":
            needle = str(value).lower()
            result = [i for i in result if needle in text(i, "name") or needle in text(i, "description")]
    return result
def paginate(items: list, page: int = 1, page_size: int = 50) -> dict:
    """Cut one page out of *items* and describe its neighbours.

    *page* is clamped into ``[1, total_pages]``; ``total_pages`` is at least 1,
    even for an empty list. ``next`` / ``previous`` are relative query strings
    carrying only ``page`` and ``page_size`` (other query parameters such as
    filters are not included), or None at either end.
    """
    count = len(items)
    last_page = max(1, (count + page_size - 1) // page_size)
    current = page
    if current > last_page:
        current = last_page
    if current < 1:
        current = 1
    offset = (current - 1) * page_size

    next_link = None
    if current < last_page:
        next_link = f"?page={current + 1}&page_size={page_size}"
    previous_link = None
    if current > 1:
        previous_link = f"?page={current - 1}&page_size={page_size}"

    return {
        "count": count,
        "page": current,
        "page_size": page_size,
        "total_pages": last_page,
        "next": next_link,
        "previous": previous_link,
        "results": items[offset:offset + page_size],
    }
# ── App Initialization ──────────────────────────────────────────────────────
# Docs, ReDoc and the schema all live under /api. redoc_url is set explicitly
# so ReDoc is served at /api/redoc (the path AUTH_EXEMPT_PATHS exempts)
# instead of FastAPI's default location.
app = FastAPI(
    title="D&D SRD 5.2 API",
    description="An open API for the Dungeons & Dragons 5.2 System Reference Document (SRD). "
                "Powered by Open5e data with flat JSON caching.\n\n"
                "**Authentication:** Pass `X-API-Key` header or `?api_key=` query parameter.",
    version="1.0.0",
    docs_url="/api/docs",
    redoc_url="/api/redoc",
    openapi_url="/api/openapi.json",
)
# Starlette makes the most recently registered middleware the outermost one.
# Register auth first and CORS second so CORS wraps auth. Preflight OPTIONS
# requests are then answered without an API key, and 401/403/429 responses
# from auth still carry the CORS headers browsers need to read them.
app.middleware("http")(auth_middleware)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Startup ──────────────────────────────────────────────────────────────────
@app.on_event("startup")
async def startup():
    """Warm the data cache at server start and log how many items were loaded."""
    print("🚀 D&D SRD 5.2 API starting up...")
    load_data()
    item_total = sum(collection["count"] for collection in _data_cache.values())
    print(f"✅ Loaded {item_total} total items")
# ── Dice Rolling ────────────────────────────────────────────────────────────
DICE_PATTERN = re.compile(r"^(\d*)d(\d+)([+-]\d+)?$")
# A literal "+" in a URL query string decodes to a space, so "?spec=2d20+5"
# arrives as "2d20 5". Whitespace between the die size and a bare number is
# read as an implicit "+"; gluing the digits together would roll a d205.
_IMPLICIT_PLUS_PATTERN = re.compile(r"(d\s*\d+)\s+(\d)")


def roll_dice(spec: str) -> dict:
    """Parse and roll dice notation like '2d20+5' or 'd6'.

    Args:
        spec: Notation ``[N]dS[+M|-M]``. Case and spaces are ignored, except
            that a bare number separated from the die size by whitespace is
            treated as ``+M`` (the URL-decoded form of ``+``).

    Returns:
        On success: ``spec`` (stripped input), ``num_dice``, ``sides``,
        ``modifier``, the individual ``rolls`` and ``total``. ``total`` is the
        sum plus modifier, clamped to a minimum of 1. On bad input: a dict
        with a single ``error`` message. N must be 1-100 and S 2-1000.
    """
    invalid = {"error": f"Invalid dice notation: '{spec}'. Use format like 2d20+5, d6, 3d8-2"}
    normalized = _IMPLICIT_PLUS_PATTERN.sub(r"\1+\2", spec.lower())
    m = DICE_PATTERN.match(normalized.replace(" ", ""))
    if not m:
        return invalid
    try:
        num = int(m.group(1)) if m.group(1) else 1
        sides = int(m.group(2))
        mod = int(m.group(3)) if m.group(3) else 0
    except ValueError:
        # Python 3.11+ refuses int() on strings of thousands of digits.
        return invalid
    if num < 1 or num > 100:
        return {"error": "Number of dice must be between 1 and 100"}
    if sides < 2 or sides > 1000:
        return {"error": "Dice sides must be between 2 and 1000"}
    rolls = [random.randint(1, sides) for _ in range(num)]
    total = sum(rolls) + mod
    return {
        "spec": spec.strip(),
        "num_dice": num,
        "sides": sides,
        "modifier": mod,
        "rolls": rolls,
        "total": max(total, 1),  # Minimum 1
    }
# ── Endpoints ────────────────────────────────────────────────────────────────
@app.get("/")
async def root_redirect():
    """Landing payload that points clients at the docs and the API root.

    NOTE(review): "/" is not in AUTH_EXEMPT_PATHS, so this currently requires
    an API key. Confirm that is intended.
    """
    payload = {"message": "D&D SRD 5.2 API"}
    payload["docs"] = "/api/docs"
    payload["api_root"] = "/api"
    return payload
@app.get("/api")
async def api_root():
    """API index: name, version, docs location, and every resource endpoint.

    The ``endpoints`` map now also lists ``/api/search``, which is served by
    this app but was missing from the index.
    """
    return {
        "name": "D&D SRD 5.2 API",
        "version": "1.0.0",
        "documentation": "/api/docs",
        "endpoints": {
            "spells": "/api/spells",
            "creatures": "/api/creatures",
            "classes": "/api/classes",
            "magic_items": "/api/magic-items",
            "equipment": "/api/equipment",
            "dice": "/api/dice/roll",
            "search": "/api/search",
        },
    }
@app.get("/health")
async def health():
    """Health probe: total cached items and seconds since the last data load.

    Both numbers are 0 before the first load. "/health" is listed in
    AUTH_EXEMPT_PATHS.
    """
    total_items = 0
    if _data_cache:
        total_items = sum(entry["count"] for entry in _data_cache.values())
    age_seconds = 0
    if _last_load:
        age_seconds = int(time.time() - _last_load)
    return {"status": "healthy", "data_items": total_items, "cache_age_s": age_seconds}
# ── Spells ───────────────────────────────────────────────────────────────────
@app.get("/api/spells")
async def list_spells(
    request: Request,  # unused; kept so the handler signature is unchanged
    name: Optional[str] = Query(None, description="Filter by name (partial match)"),
    level: Optional[int] = Query(None, description="Filter by spell level (0=cantrip, 1-9)"),
    school: Optional[str] = Query(None, description="Filter by school of magic"),
    class_name: Optional[str] = Query(None, alias="class", description="Filter by class name"),
    concentration: Optional[bool] = Query(None),
    ritual: Optional[bool] = Query(None),
    search: Optional[str] = Query(None, description="Search name and description"),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=100),
):
    """List spells, filtered by the given query parameters and paginated.

    Filter values are handed to filter_spells() as strings, the same form a
    raw query string has. FastAPI parses ``level`` to int and
    ``concentration``/``ritual`` to bool, and calling ``.lower()`` on a bool
    raised AttributeError (a 500). ``str(True)`` is ``"True"``, which
    filter_spells treats as true.
    """
    data = get_data("spells")
    params = {
        key: str(value)
        for key, value in {
            "name": name, "level": level, "school": school, "class": class_name,
            "concentration": concentration, "ritual": ritual, "search": search,
        }.items()
        if value is not None
    }
    filtered = filter_spells(data["results"], params)
    return paginate(filtered, page, page_size)
@app.get("/api/spells/{spell_key}")
async def get_spell(spell_key: str):
    """Return the first spell whose ``key`` equals *spell_key*; 404 if none does."""
    match = next(
        (spell for spell in get_data("spells")["results"] if spell.get("key") == spell_key),
        None,
    )
    if match is None:
        raise HTTPException(status_code=404, detail=f"Spell '{spell_key}' not found")
    return match
# ── Creatures ────────────────────────────────────────────────────────────────
@app.get("/api/creatures")
async def list_creatures(
    name: Optional[str] = Query(None),
    type: Optional[str] = Query(None),
    cr: Optional[str] = Query(None, alias="challenge_rating"),
    size: Optional[str] = Query(None),
    alignment: Optional[str] = Query(None),
    search: Optional[str] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=100),
):
    """List creatures, filtered and paginated.

    The CR filter is read from the ``challenge_rating`` query parameter (the
    alias of ``cr``) and forwarded to filter_creatures() under the key "cr".
    """
    requested = (
        ("name", name), ("type", type), ("cr", cr), ("size", size),
        ("alignment", alignment), ("search", search),
    )
    active = {field: value for field, value in requested if value is not None}
    creatures = get_data("creatures")["results"]
    return paginate(filter_creatures(creatures, active), page, page_size)
@app.get("/api/creatures/{creature_key}")
async def get_creature(creature_key: str):
    """Return the first creature whose ``key`` equals *creature_key*; 404 if none does."""
    match = next(
        (entry for entry in get_data("creatures")["results"] if entry.get("key") == creature_key),
        None,
    )
    if match is None:
        raise HTTPException(status_code=404, detail=f"Creature '{creature_key}' not found")
    return match
# ── Classes ──────────────────────────────────────────────────────────────────
@app.get("/api/classes")
async def list_classes(
    name: Optional[str] = Query(None),
    search: Optional[str] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=100),
):
    """List classes, optionally filtered by name and/or free-text search, paginated.

    Both filters are case-insensitive substring matches (``search`` covers
    ``name`` and ``description``). Null (None) fields are treated as empty
    text instead of crashing on ``.lower()``.
    """
    results = get_data("classes")["results"]
    if name:
        needle = name.lower()
        results = [c for c in results if needle in (c.get("name") or "").lower()]
    if search:
        val = search.lower()
        results = [
            c for c in results
            if val in (c.get("name") or "").lower() or val in (c.get("description") or "").lower()
        ]
    return paginate(results, page, page_size)
@app.get("/api/classes/{class_key}")
async def get_class(class_key: str):
    """Return the first class whose ``key`` equals *class_key*; 404 if none does."""
    match = next(
        (entry for entry in get_data("classes")["results"] if entry.get("key") == class_key),
        None,
    )
    if match is None:
        raise HTTPException(status_code=404, detail=f"Class '{class_key}' not found")
    return match
# ── Magic Items ──────────────────────────────────────────────────────────────
@app.get("/api/magic-items")
async def list_magic_items(
    name: Optional[str] = Query(None),
    type: Optional[str] = Query(None),
    rarity: Optional[str] = Query(None),
    search: Optional[str] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=100),
):
    """List magic items filtered by name/type/rarity/search, one page at a time."""
    requested = (("name", name), ("type", type), ("rarity", rarity), ("search", search))
    active = {field: value for field, value in requested if value is not None}
    items = get_data("magic_items")["results"]
    return paginate(filter_items(items, active), page, page_size)
@app.get("/api/magic-items/{item_key}")
async def get_magic_item(item_key: str):
    """Return the first magic item whose ``key`` equals *item_key*; 404 if none does."""
    match = next(
        (entry for entry in get_data("magic_items")["results"] if entry.get("key") == item_key),
        None,
    )
    if match is None:
        raise HTTPException(status_code=404, detail=f"Magic item '{item_key}' not found")
    return match
# ── Equipment ────────────────────────────────────────────────────────────────
@app.get("/api/equipment")
async def list_equipment(
    name: Optional[str] = Query(None),
    type: Optional[str] = Query(None),
    search: Optional[str] = Query(None),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=100),
):
    """List equipment filtered by name/type/search, one page at a time."""
    requested = (("name", name), ("type", type), ("search", search))
    active = {field: value for field, value in requested if value is not None}
    items = get_data("equipment")["results"]
    return paginate(filter_items(items, active), page, page_size)
@app.get("/api/equipment/{item_key}")
async def get_equipment(item_key: str):
    """Return the first equipment entry whose ``key`` equals *item_key*; 404 if none does."""
    match = next(
        (entry for entry in get_data("equipment")["results"] if entry.get("key") == item_key),
        None,
    )
    if match is None:
        raise HTTPException(status_code=404, detail=f"Equipment '{item_key}' not found")
    return match
# ── Dice Rolling ─────────────────────────────────────────────────────────────
@app.get("/api/dice/roll")
async def dice_roll(
    spec: str = Query("d20", description="Dice notation (e.g., 2d20+5, d6, 3d8)"),
    count: int = Query(1, ge=1, le=10, description="Number of times to roll"),
):
    """Roll *spec* once, or *count* times with a grand total.

    An invalid spec produces a 400 response whose body is roll_dice()'s
    ``{"error": ...}`` dict. Every roll of an invalid spec fails identically,
    so it is reported once instead of as a 200 carrying *count* copies.

    Returns:
        For ``count == 1``: the single roll_dice() result. Otherwise
        ``{"spec", "rolls", "grand_total", "count"}``, where ``grand_total``
        sums each roll's (minimum-1) ``total``.
    """
    first = roll_dice(spec)
    if "error" in first:
        return JSONResponse(status_code=400, content=first)
    if count == 1:
        return first
    rolls = [first] + [roll_dice(spec) for _ in range(count - 1)]
    return {
        "spec": spec,
        "rolls": rolls,
        "grand_total": sum(r["total"] for r in rolls),
        "count": count,
    }
# ── Search ───────────────────────────────────────────────────────────────────
@app.get("/api/search")
async def search_all(
    q: str = Query(..., description="Search query"),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=50),
):
    """Search across all resource types.

    Case-insensitive substring match of *q* against each item's ``name`` and
    ``description`` (null fields count as empty). Each hit is returned as a
    shallow copy tagged with ``_type`` (the URL slug of its collection), in
    collection order, then paginated.
    """
    query = q.lower()
    results = []
    for collection_key, label in [
        ("spells", "spells"),
        ("creatures", "creatures"),
        ("classes", "classes"),
        ("magic_items", "magic-items"),
        ("equipment", "equipment"),
    ]:
        data = get_data(collection_key)
        for item in data["results"]:
            name = (item.get("name") or "").lower()
            desc = (item.get("description") or "").lower()
            if query in name or query in desc:
                # Tag a copy. The items are the shared cached dicts that the
                # per-collection endpoints also serve, so tagging in place
                # leaked "_type" into those responses and raced between
                # concurrent requests.
                results.append({**item, "_type": label})
    return paginate(results, page, page_size)
# ── Run ──────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Development entry point: serve on localhost:8000 with auto-reload on.
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host="127.0.0.1",
        port=8000,
        reload=True,
    )