commit 3794afff95a4fe384210803292d4d32cd6d1ad52 Author: Hyungi Ahn Date: Tue Mar 31 13:41:46 2026 +0900 feat: AI Gateway Phase 1 - FastAPI 코어 구현 GPU 서버 중앙 AI 라우팅 서비스 초기 구현: - OpenAI 호환 API (/v1/chat/completions, /v1/models, /v1/embeddings) - 모델 레지스트리 + 백엔드 헬스체크 (30초 루프) - Ollama SSE 프록시 (NDJSON → OpenAI SSE 변환) - JWT 인증 이중 경로 (httpOnly 쿠키 + Bearer 토큰) - owner/guest 역할 분리, 로그인 rate limiting - 백엔드별 rate limiting (NanoClaude 대비) - SQLite 스키마 사전 정의 (aiosqlite + WAL) - Docker Compose + Caddy 리버스 프록시 Co-Authored-By: Claude Opus 4.6 (1M context) diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..1603e55 --- /dev/null +++ b/.env.example @@ -0,0 +1,16 @@ +# Auth +OWNER_PASSWORD= +GUEST_PASSWORD= +JWT_SECRET= + +# CORS (dev) +CORS_ORIGINS=http://localhost:5173 + +# GPU +NVIDIA_SMI_PATH=/usr/bin/nvidia-smi + +# Backends config path (in Docker) +BACKENDS_CONFIG=/app/config/backends.json + +# DB path (in Docker) +DB_PATH=/app/data/gateway.db diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..faa622d --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +.env +*.pyc +__pycache__/ +*.db +.venv/ +venv/ +node_modules/ +dist/ +.next/ +hub-web/dist/ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..42974e2 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,42 @@ +# AI Gateway + +GPU 서버(RTX 4070 Ti Super)에서 운영하는 중앙 AI 라우팅 서비스. +모든 AI 요청을 하나의 OpenAI 호환 API로 통합. 
+ +## 서비스 구조 + +| 서비스 | 디렉토리 | 스택 | 포트 | +|--------|----------|------|------| +| Caddy | caddy/ | Caddy 2 | 80/443 | +| hub-api | hub-api/ | FastAPI + aiosqlite | 8000 | +| hub-web | hub-web/ | Vite + React + shadcn/ui | 3000 (Phase 2) | + +## 외부 연결 + +- GPU Ollama: host.docker.internal:11434 +- 맥미니 Ollama: 100.115.153.119:11434 +- NanoClaude: 100.115.153.119:PORT (Phase 1.5) + +## 개발 + +```bash +cd hub-api +pip install -r requirements.txt +uvicorn main:app --reload --port 8000 +``` + +## 배포 + +```bash +docker compose up -d --build +``` + +## API + +OpenAI 호환: `/v1/chat/completions`, `/v1/models`, `/v1/embeddings` +인증: `/auth/login` → Cookie 또는 Bearer 토큰 +모니터링: `/health`, `/gpu` + +## 백엔드 설정 + +`backends.json`에서 백엔드 추가/제거. 서비스 재시작 필요. diff --git a/backends.json b/backends.json new file mode 100644 index 0000000..c6cb860 --- /dev/null +++ b/backends.json @@ -0,0 +1,13 @@ +[ + { + "id": "ollama-gpu", + "type": "ollama", + "url": "http://host.docker.internal:11434", + "models": [ + { "id": "qwen3.5:9b-q8_0", "capabilities": ["chat"], "priority": 1 }, + { "id": "qwen3-vl:8b", "capabilities": ["chat", "vision"], "priority": 1 } + ], + "access": "all", + "rate_limit": null + } +] diff --git a/caddy/Caddyfile b/caddy/Caddyfile new file mode 100644 index 0000000..d1a0878 --- /dev/null +++ b/caddy/Caddyfile @@ -0,0 +1,21 @@ +:80 { + handle /v1/* { + reverse_proxy hub-api:8000 + } + handle /auth/* { + reverse_proxy hub-api:8000 + } + handle /health { + reverse_proxy hub-api:8000 + } + handle /health/* { + reverse_proxy hub-api:8000 + } + handle /gpu { + reverse_proxy hub-api:8000 + } + handle { + respond "AI Gateway - hub-web not yet deployed" 200 + } +} +# TLS: caddy-tailscale 플러그인 또는 tailscale cert 자동 갱신 사용 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..0d643cd --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,47 @@ +services: + caddy: + image: caddy:2-alpine + container_name: gpu-caddy + restart: unless-stopped + ports: + - 
"80:80" + - "443:443" + volumes: + - ./caddy/Caddyfile:/etc/caddy/Caddyfile + - caddy_data:/data + depends_on: + - hub-api + networks: + - gateway-net + + hub-api: + build: ./hub-api + container_name: gpu-hub-api + restart: unless-stopped + environment: + - OWNER_PASSWORD=${OWNER_PASSWORD} + - GUEST_PASSWORD=${GUEST_PASSWORD} + - JWT_SECRET=${JWT_SECRET} + - BACKENDS_CONFIG=/app/config/backends.json + - CORS_ORIGINS=${CORS_ORIGINS:-http://localhost:5173} + - DB_PATH=/app/data/gateway.db + volumes: + - hub_data:/app/data + - ./backends.json:/app/config/backends.json:ro + extra_hosts: + - "host.docker.internal:host-gateway" + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + interval: 15s + timeout: 5s + retries: 3 + networks: + - gateway-net + +volumes: + caddy_data: + hub_data: + +networks: + gateway-net: + name: gpu-gateway-network diff --git a/hub-api/Dockerfile b/hub-api/Dockerfile new file mode 100644 index 0000000..9311978 --- /dev/null +++ b/hub-api/Dockerfile @@ -0,0 +1,16 @@ +FROM python:3.12-slim + +RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY . . 
import aiosqlite

from config import settings

# SQLite schema. NOTE: chat_messages.session_id declares a REFERENCES
# constraint, but SQLite only enforces foreign keys when the
# foreign_keys pragma is enabled on each connection (see _configure).
SCHEMA = """
CREATE TABLE IF NOT EXISTS chat_sessions (
    id TEXT PRIMARY KEY,
    title TEXT,
    model TEXT NOT NULL,
    role TEXT NOT NULL DEFAULT 'guest',
    created_at REAL NOT NULL
);

CREATE TABLE IF NOT EXISTS chat_messages (
    id TEXT PRIMARY KEY,
    session_id TEXT NOT NULL REFERENCES chat_sessions(id),
    role TEXT NOT NULL,
    content TEXT NOT NULL,
    created_at REAL NOT NULL
);

CREATE TABLE IF NOT EXISTS usage_logs (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    backend_id TEXT NOT NULL,
    model TEXT NOT NULL,
    prompt_tokens INTEGER DEFAULT 0,
    completion_tokens INTEGER DEFAULT 0,
    latency_ms REAL DEFAULT 0,
    user_role TEXT NOT NULL DEFAULT 'guest',
    created_at REAL NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_messages_session ON chat_messages(session_id);
CREATE INDEX IF NOT EXISTS idx_usage_created ON usage_logs(created_at);
"""


async def _configure(db: aiosqlite.Connection) -> None:
    """Apply per-connection pragmas: WAL for concurrent reads, FK enforcement."""
    await db.execute("PRAGMA journal_mode=WAL")
    # SQLite ignores REFERENCES constraints unless this is set per connection.
    await db.execute("PRAGMA foreign_keys=ON")


async def init_db():
    """Initialize SQLite database with WAL mode and schema."""
    async with aiosqlite.connect(settings.db_path) as db:
        await _configure(db)
        await db.executescript(SCHEMA)
        await db.commit()


async def get_db() -> aiosqlite.Connection:
    """Get a database connection.

    The caller is responsible for closing the returned connection
    (``async with`` or explicit ``await db.close()``).
    """
    db = await aiosqlite.connect(settings.db_path)
    await _configure(db)
    return db
from __future__ import annotations

import time

from jose import JWTError, jwt
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

from config import settings

# Paths that never require authentication
PUBLIC_PATHS = {"/", "/health", "/auth/login", "/docs", "/openapi.json"}
PUBLIC_PREFIXES = ("/health/",)


class AuthMiddleware(BaseHTTPMiddleware):
    """Resolve the caller's role from a JWT and stash it on request.state.

    The token comes from either an ``Authorization: Bearer`` header or an
    httpOnly ``token`` cookie. Requests without a valid token proceed with
    role "anonymous"; individual routers decide whether that is acceptable.
    """

    async def dispatch(self, request: Request, call_next):
        path = request.url.path

        # Public endpoints bypass token verification entirely.
        if path in PUBLIC_PATHS or any(path.startswith(p) for p in PUBLIC_PREFIXES):
            request.state.role = "anonymous"
            return await call_next(request)

        # CORS preflight never carries credentials worth checking.
        if request.method == "OPTIONS":
            return await call_next(request)

        token = _extract_token(request)
        payload = _verify_token(token) if token else None
        # A valid token without an explicit role claim falls back to the
        # least-privileged authenticated role ("guest").
        request.state.role = payload.get("role", "guest") if payload else "anonymous"
        return await call_next(request)


def create_token(role: str) -> str:
    """Issue a signed JWT carrying the given role."""
    now = time.time()
    payload = {
        "role": role,
        "exp": now + settings.jwt_expire_hours * 3600,
        "iat": now,
    }
    return jwt.encode(payload, settings.jwt_secret, algorithm=settings.jwt_algorithm)


def _extract_token(request: Request) -> str | None:
    """Return the JWT from the Bearer header, falling back to the cookie."""
    auth_header = request.headers.get("authorization", "")
    if auth_header.startswith("Bearer "):
        return auth_header[7:]
    return request.cookies.get("token")


def _verify_token(token: str) -> dict | None:
    """Decode and validate a JWT; return its claims, or None if invalid.

    python-jose validates the ``exp`` claim during decode and raises
    ExpiredSignatureError (a JWTError subclass), so no manual expiry
    check is needed here.
    """
    try:
        return jwt.decode(
            token, settings.jwt_secret, algorithms=[settings.jwt_algorithm]
        )
    except JWTError:
        return None


# Login rate limiting (per client IP, in-memory)
_login_attempts: dict[str, list[float]] = {}
MAX_ATTEMPTS = 5
LOCKOUT_SECONDS = 60


def check_login_rate_limit(ip: str) -> bool:
    """Returns True if login is allowed for this IP."""
    now = time.time()
    attempts = [t for t in _login_attempts.get(ip, []) if now - t < LOCKOUT_SECONDS]
    if attempts:
        _login_attempts[ip] = attempts
    else:
        # Drop fully-expired entries so the map does not grow without bound
        # as distinct client IPs accumulate over time.
        _login_attempts.pop(ip, None)
    return len(attempts) < MAX_ATTEMPTS


def record_login_attempt(ip: str):
    """Record a login attempt timestamp for rate limiting."""
    _login_attempts.setdefault(ip, []).append(time.time())
import secrets

from fastapi import APIRouter, Request, Response
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from config import settings
from middleware.auth import (
    check_login_rate_limit,
    create_token,
    record_login_attempt,
)

router = APIRouter(prefix="/auth", tags=["auth"])


class LoginRequest(BaseModel):
    password: str


class LoginResponse(BaseModel):
    role: str
    token: str


@router.post("/login")
async def login(body: LoginRequest, request: Request, response: Response):
    """Password login. Sets an httpOnly cookie and returns the token."""
    ip = request.client.host if request.client else "unknown"

    if not check_login_rate_limit(ip):
        return _error_response(429, "Too many login attempts. Try again in 1 minute.")

    # Every attempt, successful or not, counts toward the limit.
    record_login_attempt(ip)

    # compare_digest performs a constant-time comparison, closing the
    # timing side channel that a plain == on passwords would leave open.
    if secrets.compare_digest(body.password, settings.owner_password):
        role = "owner"
    elif secrets.compare_digest(body.password, settings.guest_password):
        role = "guest"
    else:
        return _error_response(401, "Invalid password")

    token = create_token(role)

    # Set httpOnly cookie for web UI; API clients use the returned token.
    response.set_cookie(
        key="token",
        value=token,
        httponly=True,
        samesite="lax",
        max_age=settings.jwt_expire_hours * 3600,
    )

    return LoginResponse(role=role, token=token)


@router.get("/me")
async def me(request: Request):
    """Return the caller's role, or 401 if unauthenticated."""
    role = getattr(request.state, "role", "anonymous")
    if role == "anonymous":
        return _error_response(401, "Not authenticated")
    return {"role": role}


@router.post("/logout")
async def logout(response: Response):
    """Clear the auth cookie."""
    response.delete_cookie("token")
    return {"ok": True}


def _error_response(status_code: int, message: str) -> JSONResponse:
    """OpenAI-style error envelope with the given HTTP status."""
    return JSONResponse(
        status_code=status_code,
        content={
            "error": {
                "message": message,
                "type": "auth_error",
                "code": f"auth_{status_code}",
            }
        },
    )
from typing import List, Optional

from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

from middleware.rate_limit import check_backend_rate_limit
from services import proxy_ollama
from services.registry import registry

router = APIRouter(prefix="/v1", tags=["chat"])


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    """Subset of the OpenAI chat-completions request schema."""

    model: str
    messages: List[ChatMessage]
    stream: bool = False
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None


@router.post("/chat/completions")
async def chat_completions(body: ChatRequest, request: Request):
    """OpenAI-compatible chat completions, proxied to the resolved backend.

    Requires an authenticated role (set by AuthMiddleware). Returns SSE
    when ``stream`` is true, a JSON completion otherwise.
    """
    role = getattr(request.state, "role", "anonymous")
    if role == "anonymous":
        raise HTTPException(
            status_code=401,
            detail={"error": {"message": "Authentication required", "type": "auth_error", "code": "unauthorized"}},
        )

    # Resolve model to the healthiest / highest-priority backend
    result = registry.resolve_model(body.model, role)
    if not result:
        raise HTTPException(
            status_code=404,
            detail={
                "error": {
                    "message": f"Model '{body.model}' not found or not available",
                    "type": "invalid_request_error",
                    "code": "model_not_found",
                }
            },
        )

    backend, model_info = result

    # Enforce, then record, the per-backend rate limit
    check_backend_rate_limit(backend.id)
    registry.record_request(backend.id)

    messages = [{"role": m.role, "content": m.content} for m in body.messages]

    # Ollama reads sampling parameters from a nested "options" object;
    # top-level temperature/max_tokens fields are silently ignored by
    # /api/chat, so map them here (max_tokens -> num_predict).
    options: dict = {}
    if body.temperature is not None:
        options["temperature"] = body.temperature
    if body.max_tokens is not None:
        options["num_predict"] = body.max_tokens
    kwargs = {"options": options} if options else {}

    # Route to appropriate proxy
    if backend.type == "ollama":
        if body.stream:
            return StreamingResponse(
                proxy_ollama.stream_chat(
                    backend.url, body.model, messages, **kwargs
                ),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    # disable reverse-proxy buffering so SSE chunks flush
                    "X-Accel-Buffering": "no",
                },
            )
        result = await proxy_ollama.complete_chat(
            backend.url, body.model, messages, **kwargs
        )
        return JSONResponse(content=result)

    # Other backend types (openai-compat, anthropic) arrive in a later phase
    raise HTTPException(
        status_code=501,
        detail={
            "error": {
                "message": f"Backend type '{backend.type}' not yet implemented",
                "type": "api_error",
                "code": "not_implemented",
            }
        },
    )
from fastapi import APIRouter, HTTPException

from services.gpu_monitor import get_gpu_info
from services.registry import registry

router = APIRouter(tags=["health"])


@router.get("/health")
async def health():
    """Gateway liveness: backend health summary plus GPU telemetry."""
    gpu = await get_gpu_info()
    return {
        "status": "ok",
        "backends": registry.get_health_summary(),
        "gpu": gpu,
    }


@router.get("/health/{backend_id}")
async def backend_health(backend_id: str):
    """Detailed health for one backend; 404 if the id is unknown."""
    backend = registry.backends.get(backend_id)
    if not backend:
        # Previously this returned 200 with an error body; a real 404 lets
        # monitoring clients distinguish "unknown backend" from "down".
        raise HTTPException(
            status_code=404,
            detail={"error": {"message": f"Backend '{backend_id}' not found"}},
        )

    return {
        "id": backend.id,
        "type": backend.type,
        "status": "healthy" if backend.healthy else "down",
        "models": [m.id for m in backend.models],
        "latency_ms": backend.latency_ms,
    }
from __future__ import annotations

import asyncio
import logging

from config import settings

logger = logging.getLogger(__name__)


async def get_gpu_info() -> dict | None:
    """Query nvidia-smi once and return parsed GPU telemetry.

    Returns None when nvidia-smi is missing, times out (>5s), exits
    non-zero, or emits output that cannot be parsed.
    """
    try:
        proc = await asyncio.create_subprocess_exec(
            settings.nvidia_smi_path,
            "--query-gpu=utilization.gpu,temperature.gpu,memory.used,memory.total,power.draw,name",
            "--format=csv,noheader,nounits",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=5.0)

        if proc.returncode != 0:
            logger.debug("nvidia-smi failed: %s", stderr.decode())
            return None

        # Only the first GPU is reported.
        line = stdout.decode().strip().split("\n")[0]
        parts = [p.strip() for p in line.split(",")]
        if len(parts) < 6:
            return None

        return {
            "utilization": int(parts[0]),
            "temperature": int(parts[1]),
            "vram_used": int(parts[2]),
            "vram_total": int(parts[3]),
            "power_draw": float(parts[4]),
            "name": parts[5],
        }
    except (FileNotFoundError, asyncio.TimeoutError):
        # Binary missing or probe hung.
        return None
    except ValueError:
        # nvidia-smi prints "[N/A]" for fields some boards don't support
        # (e.g. power.draw), which would crash int()/float(); treat as
        # unavailable instead of surfacing a 500 to the endpoint.
        logger.debug("unparseable nvidia-smi output")
        return None
to OpenAI SSE format.""" + payload = { + "model": model, + "messages": messages, + "stream": True, + **{k: v for k, v in kwargs.items() if v is not None}, + } + + async with httpx.AsyncClient(timeout=120.0) as client: + async with client.stream( + "POST", + f"{base_url}/api/chat", + json=payload, + ) as resp: + if resp.status_code != 200: + body = await resp.aread() + error_msg = body.decode("utf-8", errors="replace") + yield _error_event(f"Ollama error: {error_msg}") + return + + async for line in resp.aiter_lines(): + if not line.strip(): + continue + try: + chunk = json.loads(line) + except json.JSONDecodeError: + continue + + if chunk.get("done"): + # Final chunk — send [DONE] + yield "data: [DONE]\n\n" + return + + content = chunk.get("message", {}).get("content", "") + if content: + openai_chunk = { + "id": "chatcmpl-gateway", + "object": "chat.completion.chunk", + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": content}, + "finish_reason": None, + } + ], + } + yield f"data: {json.dumps(openai_chunk)}\n\n" + + +async def complete_chat( + base_url: str, + model: str, + messages: list[dict], + **kwargs, +) -> dict: + """Non-streaming Ollama chat, returns OpenAI-compatible response.""" + payload = { + "model": model, + "messages": messages, + "stream": False, + **{k: v for k, v in kwargs.items() if v is not None}, + } + + async with httpx.AsyncClient(timeout=120.0) as client: + resp = await client.post(f"{base_url}/api/chat", json=payload) + resp.raise_for_status() + data = resp.json() + + return { + "id": "chatcmpl-gateway", + "object": "chat.completion", + "model": model, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": data.get("message", {}).get("content", ""), + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": data.get("prompt_eval_count", 0), + "completion_tokens": data.get("eval_count", 0), + "total_tokens": data.get("prompt_eval_count", 0) + + data.get("eval_count", 0), + }, 
from __future__ import annotations

import asyncio
import json
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path

logger = logging.getLogger(__name__)


@dataclass
class ModelInfo:
    """One model advertised by a backend."""

    id: str
    capabilities: list[str]  # e.g. ["chat"], ["chat", "vision"], ["embed"]
    priority: int = 1  # lower value wins when several backends serve the model


@dataclass
class RateLimitConfig:
    """Per-backend request limits; 0 disables the corresponding window."""

    rpm: int = 0  # requests per minute
    rph: int = 0  # requests per hour
    scope: str = "global"


@dataclass
class BackendInfo:
    """Static backend configuration plus runtime health state."""

    id: str
    type: str  # "ollama", "openai-compat", "anthropic"
    url: str
    models: list[ModelInfo]
    access: str = "all"  # "all" or "owner"
    rate_limit: RateLimitConfig | None = None

    # runtime state maintained by the health loop
    healthy: bool = False
    last_check: float = 0
    latency_ms: float = 0


@dataclass
class RateLimitState:
    """Sliding-window request timestamps for one backend."""

    minute_timestamps: list[float] = field(default_factory=list)
    hour_timestamps: list[float] = field(default_factory=list)


class Registry:
    """In-memory registry of AI backends: routing, health checks, rate limits."""

    def __init__(self):
        self.backends: dict[str, BackendInfo] = {}
        self._health_task: asyncio.Task | None = None
        self._rate_limits: dict[str, RateLimitState] = {}

    async def load_backends(self, config_path: str):
        """Load backend definitions from a JSON file; a missing file is non-fatal."""
        path = Path(config_path)
        if not path.exists():
            logger.warning("Backends config not found: %s", config_path)
            return

        with open(path) as f:
            data = json.load(f)

        for entry in data:
            models = [
                ModelInfo(
                    id=m["id"],
                    capabilities=m.get("capabilities", ["chat"]),
                    priority=m.get("priority", 1),
                )
                for m in entry.get("models", [])
            ]
            rl_data = entry.get("rate_limit")
            rate_limit = (
                RateLimitConfig(
                    rpm=rl_data.get("rpm", 0),
                    rph=rl_data.get("rph", 0),
                    scope=rl_data.get("scope", "global"),
                )
                if rl_data
                else None
            )
            backend = BackendInfo(
                id=entry["id"],
                type=entry["type"],
                url=entry["url"].rstrip("/"),
                models=models,
                access=entry.get("access", "all"),
                rate_limit=rate_limit,
            )
            self.backends[backend.id] = backend
            if rate_limit:
                self._rate_limits[backend.id] = RateLimitState()

        logger.info("Loaded %d backends", len(self.backends))

    def start_health_loop(self, interval: float = 30.0):
        """Start the periodic health-check task.

        Idempotent: calling it again while a loop is already running is a
        no-op, so repeated lifespan events cannot leak background tasks.
        """
        if self._health_task and not self._health_task.done():
            return
        self._health_task = asyncio.create_task(self._health_loop(interval))

    def stop_health_loop(self):
        """Cancel the health-check task if one is running."""
        if self._health_task:
            self._health_task.cancel()
            self._health_task = None

    async def _health_loop(self, interval: float):
        while True:
            try:
                await self._check_all_backends()
            except Exception:
                # A transient failure (DNS, client setup, ...) must not
                # permanently kill the loop; log and try again next tick.
                logger.exception("Health check sweep failed")
            await asyncio.sleep(interval)

    async def _check_all_backends(self):
        # httpx is imported lazily so the registry's routing/rate-limit
        # logic stays importable without the HTTP client installed.
        import httpx

        async with httpx.AsyncClient(timeout=5.0) as client:
            tasks = [
                self._check_backend(client, backend)
                for backend in self.backends.values()
            ]
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _check_backend(self, client, backend: BackendInfo):
        """Probe one backend's health endpoint and update its runtime state."""
        try:
            start = time.monotonic()
            if backend.type == "ollama":
                resp = await client.get(f"{backend.url}/api/tags")
            elif backend.type in ("openai-compat", "anthropic"):
                resp = await client.get(f"{backend.url}/v1/models")
            else:
                resp = await client.get(f"{backend.url}/health")
            elapsed = (time.monotonic() - start) * 1000

            # A 4xx still proves the server is up; only 5xx counts as down.
            backend.healthy = resp.status_code < 500
            backend.latency_ms = round(elapsed, 1)
            backend.last_check = time.time()
        except Exception:
            backend.healthy = False
            backend.latency_ms = 0
            backend.last_check = time.time()
            logger.debug("Health check failed for %s", backend.id)

    def resolve_model(self, model_id: str, role: str) -> tuple[BackendInfo, ModelInfo] | None:
        """Find the best backend for a given model ID.

        Only healthy backends accessible to *role* are considered; the
        candidate with the lowest priority number wins. Returns
        (backend, model) or None.
        """
        candidates: list[tuple[BackendInfo, ModelInfo, int]] = []

        for backend in self.backends.values():
            if not backend.healthy:
                continue
            if backend.access == "owner" and role != "owner":
                continue
            for model in backend.models:
                if model.id == model_id:
                    candidates.append((backend, model, model.priority))

        if not candidates:
            return None

        candidates.sort(key=lambda c: c[2])
        best = candidates[0]
        return best[0], best[1]

    def list_models(self, role: str) -> list[dict]:
        """List all models on healthy backends visible to the given role."""
        result = []
        for backend in self.backends.values():
            if not backend.healthy:
                continue
            if backend.access == "owner" and role != "owner":
                continue
            for model in backend.models:
                result.append({
                    "id": model.id,
                    "object": "model",
                    "owned_by": backend.id,
                    "capabilities": model.capabilities,
                    "backend_id": backend.id,
                    "backend_status": "healthy" if backend.healthy else "down",
                })
        return result

    def check_rate_limit(self, backend_id: str) -> bool:
        """Check if a request to this backend is within rate limits. Returns True if allowed."""
        backend = self.backends.get(backend_id)
        if not backend or not backend.rate_limit:
            return True

        state = self._rate_limits.get(backend_id)
        if not state:
            return True

        now = time.time()
        rl = backend.rate_limit

        # Prune expired timestamps, then compare against the window caps.
        if rl.rpm > 0:
            state.minute_timestamps = [t for t in state.minute_timestamps if now - t < 60]
            if len(state.minute_timestamps) >= rl.rpm:
                return False

        if rl.rph > 0:
            state.hour_timestamps = [t for t in state.hour_timestamps if now - t < 3600]
            if len(state.hour_timestamps) >= rl.rph:
                return False

        return True

    def record_request(self, backend_id: str):
        """Record a request timestamp for rate limiting (no-op if unlimited)."""
        state = self._rate_limits.get(backend_id)
        if not state:
            return
        now = time.time()
        state.minute_timestamps.append(now)
        state.hour_timestamps.append(now)

    def get_health_summary(self) -> list[dict]:
        """Snapshot of all backends' health for the /health endpoint."""
        return [
            {
                "id": b.id,
                "type": b.type,
                "status": "healthy" if b.healthy else "down",
                "models": [m.id for m in b.models],
                "latency_ms": b.latency_ms,
                "last_check": b.last_check,
            }
            for b in self.backends.values()
        ]


registry = Registry()