feat: AI Gateway Phase 1 - FastAPI 코어 구현
GPU 서버 중앙 AI 라우팅 서비스 초기 구현: - OpenAI 호환 API (/v1/chat/completions, /v1/models, /v1/embeddings) - 모델 레지스트리 + 백엔드 헬스체크 (30초 루프) - Ollama SSE 프록시 (NDJSON → OpenAI SSE 변환) - JWT 인증 이중 경로 (httpOnly 쿠키 + Bearer 토큰) - owner/guest 역할 분리, 로그인 rate limiting - 백엔드별 rate limiting (NanoClaude 대비) - SQLite 스키마 사전 정의 (aiosqlite + WAL) - Docker Compose + Caddy 리버스 프록시 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
16
.env.example
Normal file
16
.env.example
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Auth
|
||||||
|
OWNER_PASSWORD=
|
||||||
|
GUEST_PASSWORD=
|
||||||
|
JWT_SECRET=
|
||||||
|
|
||||||
|
# CORS (dev)
|
||||||
|
CORS_ORIGINS=http://localhost:5173
|
||||||
|
|
||||||
|
# GPU
|
||||||
|
NVIDIA_SMI_PATH=/usr/bin/nvidia-smi
|
||||||
|
|
||||||
|
# Backends config path (in Docker)
|
||||||
|
BACKENDS_CONFIG=/app/config/backends.json
|
||||||
|
|
||||||
|
# DB path (in Docker)
|
||||||
|
DB_PATH=/app/data/gateway.db
|
||||||
10
.gitignore
vendored
Normal file
10
.gitignore
vendored
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
.env
|
||||||
|
*.pyc
|
||||||
|
__pycache__/
|
||||||
|
*.db
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
node_modules/
|
||||||
|
dist/
|
||||||
|
.next/
|
||||||
|
hub-web/dist/
|
||||||
42
CLAUDE.md
Normal file
42
CLAUDE.md
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# AI Gateway
|
||||||
|
|
||||||
|
GPU 서버(RTX 4070 Ti Super)에서 운영하는 중앙 AI 라우팅 서비스.
|
||||||
|
모든 AI 요청을 하나의 OpenAI 호환 API로 통합.
|
||||||
|
|
||||||
|
## 서비스 구조
|
||||||
|
|
||||||
|
| 서비스 | 디렉토리 | 스택 | 포트 |
|
||||||
|
|--------|----------|------|------|
|
||||||
|
| Caddy | caddy/ | Caddy 2 | 80/443 |
|
||||||
|
| hub-api | hub-api/ | FastAPI + aiosqlite | 8000 |
|
||||||
|
| hub-web | hub-web/ | Vite + React + shadcn/ui | 3000 (Phase 2) |
|
||||||
|
|
||||||
|
## 외부 연결
|
||||||
|
|
||||||
|
- GPU Ollama: host.docker.internal:11434
|
||||||
|
- 맥미니 Ollama: 100.115.153.119:11434
|
||||||
|
- NanoClaude: 100.115.153.119:PORT (Phase 1.5)
|
||||||
|
|
||||||
|
## 개발
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd hub-api
|
||||||
|
pip install -r requirements.txt
|
||||||
|
uvicorn main:app --reload --port 8000
|
||||||
|
```
|
||||||
|
|
||||||
|
## 배포
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose up -d --build
|
||||||
|
```
|
||||||
|
|
||||||
|
## API
|
||||||
|
|
||||||
|
OpenAI 호환: `/v1/chat/completions`, `/v1/models`, `/v1/embeddings`
|
||||||
|
인증: `/auth/login` → Cookie 또는 Bearer 토큰
|
||||||
|
모니터링: `/health`, `/gpu`
|
||||||
|
|
||||||
|
## 백엔드 설정
|
||||||
|
|
||||||
|
`backends.json`에서 백엔드 추가/제거. 서비스 재시작 필요.
|
||||||
13
backends.json
Normal file
13
backends.json
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"id": "ollama-gpu",
|
||||||
|
"type": "ollama",
|
||||||
|
"url": "http://host.docker.internal:11434",
|
||||||
|
"models": [
|
||||||
|
{ "id": "qwen3.5:9b-q8_0", "capabilities": ["chat"], "priority": 1 },
|
||||||
|
{ "id": "qwen3-vl:8b", "capabilities": ["chat", "vision"], "priority": 1 }
|
||||||
|
],
|
||||||
|
"access": "all",
|
||||||
|
"rate_limit": null
|
||||||
|
}
|
||||||
|
]
|
||||||
21
caddy/Caddyfile
Normal file
21
caddy/Caddyfile
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
:80 {
|
||||||
|
handle /v1/* {
|
||||||
|
reverse_proxy hub-api:8000
|
||||||
|
}
|
||||||
|
handle /auth/* {
|
||||||
|
reverse_proxy hub-api:8000
|
||||||
|
}
|
||||||
|
handle /health {
|
||||||
|
reverse_proxy hub-api:8000
|
||||||
|
}
|
||||||
|
handle /health/* {
|
||||||
|
reverse_proxy hub-api:8000
|
||||||
|
}
|
||||||
|
handle /gpu {
|
||||||
|
reverse_proxy hub-api:8000
|
||||||
|
}
|
||||||
|
handle {
|
||||||
|
respond "AI Gateway - hub-web not yet deployed" 200
|
||||||
|
}
|
||||||
|
}
|
||||||
|
# TLS: caddy-tailscale 플러그인 또는 tailscale cert 자동 갱신 사용
|
||||||
47
docker-compose.yml
Normal file
47
docker-compose.yml
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
services:
|
||||||
|
caddy:
|
||||||
|
image: caddy:2-alpine
|
||||||
|
container_name: gpu-caddy
|
||||||
|
restart: unless-stopped
|
||||||
|
ports:
|
||||||
|
- "80:80"
|
||||||
|
- "443:443"
|
||||||
|
volumes:
|
||||||
|
- ./caddy/Caddyfile:/etc/caddy/Caddyfile
|
||||||
|
- caddy_data:/data
|
||||||
|
depends_on:
|
||||||
|
- hub-api
|
||||||
|
networks:
|
||||||
|
- gateway-net
|
||||||
|
|
||||||
|
hub-api:
|
||||||
|
build: ./hub-api
|
||||||
|
container_name: gpu-hub-api
|
||||||
|
restart: unless-stopped
|
||||||
|
environment:
|
||||||
|
- OWNER_PASSWORD=${OWNER_PASSWORD}
|
||||||
|
- GUEST_PASSWORD=${GUEST_PASSWORD}
|
||||||
|
- JWT_SECRET=${JWT_SECRET}
|
||||||
|
- BACKENDS_CONFIG=/app/config/backends.json
|
||||||
|
- CORS_ORIGINS=${CORS_ORIGINS:-http://localhost:5173}
|
||||||
|
- DB_PATH=/app/data/gateway.db
|
||||||
|
volumes:
|
||||||
|
- hub_data:/app/data
|
||||||
|
- ./backends.json:/app/config/backends.json:ro
|
||||||
|
extra_hosts:
|
||||||
|
- "host.docker.internal:host-gateway"
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||||
|
interval: 15s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 3
|
||||||
|
networks:
|
||||||
|
- gateway-net
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
caddy_data:
|
||||||
|
hub_data:
|
||||||
|
|
||||||
|
networks:
|
||||||
|
gateway-net:
|
||||||
|
name: gpu-gateway-network
|
||||||
16
hub-api/Dockerfile
Normal file
16
hub-api/Dockerfile
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY requirements.txt .
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
RUN mkdir -p /app/data
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
21
hub-api/config.py
Normal file
21
hub-api/config.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
owner_password: str = "changeme"
|
||||||
|
guest_password: str = "guest"
|
||||||
|
jwt_secret: str = "dev-secret-change-in-production"
|
||||||
|
jwt_algorithm: str = "HS256"
|
||||||
|
jwt_expire_hours: int = 24
|
||||||
|
|
||||||
|
backends_config: str = "/app/config/backends.json"
|
||||||
|
cors_origins: str = "http://localhost:5173"
|
||||||
|
|
||||||
|
nvidia_smi_path: str = "/usr/bin/nvidia-smi"
|
||||||
|
|
||||||
|
db_path: str = "/app/data/gateway.db"
|
||||||
|
|
||||||
|
model_config = {"env_file": ".env", "extra": "ignore"}
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
0
hub-api/db/__init__.py
Normal file
0
hub-api/db/__init__.py
Normal file
50
hub-api/db/database.py
Normal file
50
hub-api/db/database.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
|
||||||
|
SCHEMA = """
|
||||||
|
CREATE TABLE IF NOT EXISTS chat_sessions (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
title TEXT,
|
||||||
|
model TEXT NOT NULL,
|
||||||
|
role TEXT NOT NULL DEFAULT 'guest',
|
||||||
|
created_at REAL NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS chat_messages (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
session_id TEXT NOT NULL REFERENCES chat_sessions(id),
|
||||||
|
role TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
created_at REAL NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS usage_logs (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
backend_id TEXT NOT NULL,
|
||||||
|
model TEXT NOT NULL,
|
||||||
|
prompt_tokens INTEGER DEFAULT 0,
|
||||||
|
completion_tokens INTEGER DEFAULT 0,
|
||||||
|
latency_ms REAL DEFAULT 0,
|
||||||
|
user_role TEXT NOT NULL DEFAULT 'guest',
|
||||||
|
created_at REAL NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_messages_session ON chat_messages(session_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_usage_created ON usage_logs(created_at);
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def init_db():
|
||||||
|
"""Initialize SQLite database with WAL mode and schema."""
|
||||||
|
async with aiosqlite.connect(settings.db_path) as db:
|
||||||
|
await db.execute("PRAGMA journal_mode=WAL")
|
||||||
|
await db.executescript(SCHEMA)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db() -> aiosqlite.Connection:
|
||||||
|
"""Get a database connection."""
|
||||||
|
db = await aiosqlite.connect(settings.db_path)
|
||||||
|
await db.execute("PRAGMA journal_mode=WAL")
|
||||||
|
return db
|
||||||
2
hub-api/db/models.py
Normal file
2
hub-api/db/models.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# DB model helpers — used in Phase 3 for logging
|
||||||
|
# Schema defined in database.py
|
||||||
46
hub-api/main.py
Normal file
46
hub-api/main.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
from middleware.auth import AuthMiddleware
|
||||||
|
from routers import auth, chat, embeddings, gpu, health, models
|
||||||
|
from services.registry import registry
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
await registry.load_backends(settings.backends_config)
|
||||||
|
registry.start_health_loop()
|
||||||
|
yield
|
||||||
|
registry.stop_health_loop()
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="AI Gateway",
|
||||||
|
version="0.1.0",
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=settings.cors_origins.split(","),
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(AuthMiddleware)
|
||||||
|
|
||||||
|
app.include_router(auth.router)
|
||||||
|
app.include_router(chat.router)
|
||||||
|
app.include_router(models.router)
|
||||||
|
app.include_router(embeddings.router)
|
||||||
|
app.include_router(health.router)
|
||||||
|
app.include_router(gpu.router)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
return {"service": "AI Gateway", "version": "0.1.0"}
|
||||||
0
hub-api/middleware/__init__.py
Normal file
0
hub-api/middleware/__init__.py
Normal file
96
hub-api/middleware/auth.py
Normal file
96
hub-api/middleware/auth.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
|
||||||
|
# Paths that don't require authentication
|
||||||
|
PUBLIC_PATHS = {"/", "/health", "/auth/login", "/docs", "/openapi.json"}
|
||||||
|
PUBLIC_PREFIXES = ("/health/",)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthMiddleware(BaseHTTPMiddleware):
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
path = request.url.path
|
||||||
|
|
||||||
|
# Skip auth for public paths
|
||||||
|
if path in PUBLIC_PATHS or any(path.startswith(p) for p in PUBLIC_PREFIXES):
|
||||||
|
request.state.role = "anonymous"
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# Skip auth for OPTIONS (CORS preflight)
|
||||||
|
if request.method == "OPTIONS":
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# Try Bearer token first, then cookie
|
||||||
|
token = _extract_token(request)
|
||||||
|
if not token:
|
||||||
|
request.state.role = "anonymous"
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# Verify JWT
|
||||||
|
payload = _verify_token(token)
|
||||||
|
if payload:
|
||||||
|
request.state.role = payload.get("role", "guest")
|
||||||
|
else:
|
||||||
|
request.state.role = "anonymous"
|
||||||
|
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
|
def create_token(role: str) -> str:
|
||||||
|
payload = {
|
||||||
|
"role": role,
|
||||||
|
"exp": time.time() + settings.jwt_expire_hours * 3600,
|
||||||
|
"iat": time.time(),
|
||||||
|
}
|
||||||
|
return jwt.encode(payload, settings.jwt_secret, algorithm=settings.jwt_algorithm)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_token(request: Request) -> str | None:
|
||||||
|
# 1. Authorization: Bearer header
|
||||||
|
auth_header = request.headers.get("authorization", "")
|
||||||
|
if auth_header.startswith("Bearer "):
|
||||||
|
return auth_header[7:]
|
||||||
|
|
||||||
|
# 2. httpOnly cookie
|
||||||
|
return request.cookies.get("token")
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_token(token: str) -> dict | None:
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, settings.jwt_secret, algorithms=[settings.jwt_algorithm]
|
||||||
|
)
|
||||||
|
if payload.get("exp", 0) < time.time():
|
||||||
|
return None
|
||||||
|
return payload
|
||||||
|
except JWTError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Login rate limiting (IP-based)
|
||||||
|
_login_attempts: dict[str, list[float]] = {}
|
||||||
|
MAX_ATTEMPTS = 5
|
||||||
|
LOCKOUT_SECONDS = 60
|
||||||
|
|
||||||
|
|
||||||
|
def check_login_rate_limit(ip: str) -> bool:
|
||||||
|
"""Returns True if login is allowed for this IP."""
|
||||||
|
now = time.time()
|
||||||
|
attempts = _login_attempts.get(ip, [])
|
||||||
|
# Clean old attempts
|
||||||
|
attempts = [t for t in attempts if now - t < LOCKOUT_SECONDS]
|
||||||
|
_login_attempts[ip] = attempts
|
||||||
|
return len(attempts) < MAX_ATTEMPTS
|
||||||
|
|
||||||
|
|
||||||
|
def record_login_attempt(ip: str):
|
||||||
|
now = time.time()
|
||||||
|
if ip not in _login_attempts:
|
||||||
|
_login_attempts[ip] = []
|
||||||
|
_login_attempts[ip].append(now)
|
||||||
18
hub-api/middleware/rate_limit.py
Normal file
18
hub-api/middleware/rate_limit.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from services.registry import registry
|
||||||
|
|
||||||
|
|
||||||
|
def check_backend_rate_limit(backend_id: str):
|
||||||
|
"""Raise 429 if rate limit exceeded for this backend."""
|
||||||
|
if not registry.check_rate_limit(backend_id):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail={
|
||||||
|
"error": {
|
||||||
|
"message": f"Rate limit exceeded for backend '{backend_id}'",
|
||||||
|
"type": "rate_limit_error",
|
||||||
|
"code": "rate_limit_exceeded",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
7
hub-api/requirements.txt
Normal file
7
hub-api/requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
fastapi==0.115.0
|
||||||
|
uvicorn[standard]==0.30.0
|
||||||
|
httpx==0.27.0
|
||||||
|
pydantic-settings==2.5.0
|
||||||
|
python-jose[cryptography]==3.3.0
|
||||||
|
python-multipart==0.0.9
|
||||||
|
aiosqlite==0.20.0
|
||||||
0
hub-api/routers/__init__.py
Normal file
0
hub-api/routers/__init__.py
Normal file
79
hub-api/routers/auth.py
Normal file
79
hub-api/routers/auth.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
from fastapi import APIRouter, Request, Response
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
from middleware.auth import (
|
||||||
|
check_login_rate_limit,
|
||||||
|
create_token,
|
||||||
|
record_login_attempt,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class LoginResponse(BaseModel):
|
||||||
|
role: str
|
||||||
|
token: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login")
|
||||||
|
async def login(body: LoginRequest, request: Request, response: Response):
|
||||||
|
ip = request.client.host if request.client else "unknown"
|
||||||
|
|
||||||
|
if not check_login_rate_limit(ip):
|
||||||
|
return _error_response(429, "Too many login attempts. Try again in 1 minute.")
|
||||||
|
|
||||||
|
record_login_attempt(ip)
|
||||||
|
|
||||||
|
if body.password == settings.owner_password:
|
||||||
|
role = "owner"
|
||||||
|
elif body.password == settings.guest_password:
|
||||||
|
role = "guest"
|
||||||
|
else:
|
||||||
|
return _error_response(401, "Invalid password")
|
||||||
|
|
||||||
|
token = create_token(role)
|
||||||
|
|
||||||
|
# Set httpOnly cookie for web UI
|
||||||
|
response.set_cookie(
|
||||||
|
key="token",
|
||||||
|
value=token,
|
||||||
|
httponly=True,
|
||||||
|
samesite="lax",
|
||||||
|
max_age=settings.jwt_expire_hours * 3600,
|
||||||
|
)
|
||||||
|
|
||||||
|
return LoginResponse(role=role, token=token)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me")
|
||||||
|
async def me(request: Request):
|
||||||
|
role = getattr(request.state, "role", "anonymous")
|
||||||
|
if role == "anonymous":
|
||||||
|
return _error_response(401, "Not authenticated")
|
||||||
|
return {"role": role}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/logout")
|
||||||
|
async def logout(response: Response):
|
||||||
|
response.delete_cookie("token")
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
def _error_response(status_code: int, message: str):
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status_code,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"message": message,
|
||||||
|
"type": "auth_error",
|
||||||
|
"code": f"auth_{status_code}",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
92
hub-api/routers/chat.py
Normal file
92
hub-api/routers/chat.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from middleware.rate_limit import check_backend_rate_limit
|
||||||
|
from services import proxy_ollama
|
||||||
|
from services.registry import registry
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/v1", tags=["chat"])
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(BaseModel):
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
model: str
|
||||||
|
messages: List[ChatMessage]
|
||||||
|
stream: bool = False
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/chat/completions")
|
||||||
|
async def chat_completions(body: ChatRequest, request: Request):
|
||||||
|
role = getattr(request.state, "role", "anonymous")
|
||||||
|
if role == "anonymous":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail={"error": {"message": "Authentication required", "type": "auth_error", "code": "unauthorized"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resolve model to backend
|
||||||
|
result = registry.resolve_model(body.model, role)
|
||||||
|
if not result:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail={
|
||||||
|
"error": {
|
||||||
|
"message": f"Model '{body.model}' not found or not available",
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"code": "model_not_found",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
backend, model_info = result
|
||||||
|
|
||||||
|
# Check rate limit
|
||||||
|
check_backend_rate_limit(backend.id)
|
||||||
|
|
||||||
|
# Record request for rate limiting
|
||||||
|
registry.record_request(backend.id)
|
||||||
|
|
||||||
|
messages = [{"role": m.role, "content": m.content} for m in body.messages]
|
||||||
|
kwargs = {}
|
||||||
|
if body.temperature is not None:
|
||||||
|
kwargs["temperature"] = body.temperature
|
||||||
|
|
||||||
|
# Route to appropriate proxy
|
||||||
|
if backend.type == "ollama":
|
||||||
|
if body.stream:
|
||||||
|
return StreamingResponse(
|
||||||
|
proxy_ollama.stream_chat(
|
||||||
|
backend.url, body.model, messages, **kwargs
|
||||||
|
),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = await proxy_ollama.complete_chat(
|
||||||
|
backend.url, body.model, messages, **kwargs
|
||||||
|
)
|
||||||
|
return JSONResponse(content=result)
|
||||||
|
|
||||||
|
# Placeholder for other backend types
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=501,
|
||||||
|
detail={
|
||||||
|
"error": {
|
||||||
|
"message": f"Backend type '{backend.type}' not yet implemented",
|
||||||
|
"type": "api_error",
|
||||||
|
"code": "not_implemented",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
67
hub-api/routers/embeddings.py
Normal file
67
hub-api/routers/embeddings.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from services import proxy_ollama
|
||||||
|
from services.registry import registry
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/v1", tags=["embeddings"])
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingRequest(BaseModel):
|
||||||
|
model: str
|
||||||
|
input: Union[str, List[str]]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/embeddings")
|
||||||
|
async def create_embedding(body: EmbeddingRequest, request: Request):
|
||||||
|
role = getattr(request.state, "role", "anonymous")
|
||||||
|
if role == "anonymous":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail={"error": {"message": "Authentication required", "type": "auth_error", "code": "unauthorized"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = registry.resolve_model(body.model, role)
|
||||||
|
if not result:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail={
|
||||||
|
"error": {
|
||||||
|
"message": f"Model '{body.model}' not found or not available",
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"code": "model_not_found",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
backend, model_info = result
|
||||||
|
|
||||||
|
if "embed" not in model_info.capabilities:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": {
|
||||||
|
"message": f"Model '{body.model}' does not support embeddings",
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"code": "capability_mismatch",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if backend.type == "ollama":
|
||||||
|
return await proxy_ollama.generate_embedding(
|
||||||
|
backend.url, body.model, body.input
|
||||||
|
)
|
||||||
|
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=501,
|
||||||
|
detail={
|
||||||
|
"error": {
|
||||||
|
"message": f"Embedding not supported for backend type '{backend.type}'",
|
||||||
|
"type": "api_error",
|
||||||
|
"code": "not_implemented",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
13
hub-api/routers/gpu.py
Normal file
13
hub-api/routers/gpu.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from services.gpu_monitor import get_gpu_info
|
||||||
|
|
||||||
|
router = APIRouter(tags=["gpu"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/gpu")
|
||||||
|
async def gpu_status():
|
||||||
|
info = await get_gpu_info()
|
||||||
|
if not info:
|
||||||
|
return {"error": {"message": "GPU info unavailable", "type": "api_error", "code": "gpu_unavailable"}}
|
||||||
|
return info
|
||||||
31
hub-api/routers/health.py
Normal file
31
hub-api/routers/health.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from services.gpu_monitor import get_gpu_info
|
||||||
|
from services.registry import registry
|
||||||
|
|
||||||
|
router = APIRouter(tags=["health"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/health")
|
||||||
|
async def health():
|
||||||
|
gpu = await get_gpu_info()
|
||||||
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"backends": registry.get_health_summary(),
|
||||||
|
"gpu": gpu,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/health/{backend_id}")
|
||||||
|
async def backend_health(backend_id: str):
|
||||||
|
backend = registry.backends.get(backend_id)
|
||||||
|
if not backend:
|
||||||
|
return {"error": {"message": f"Backend '{backend_id}' not found"}}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": backend.id,
|
||||||
|
"type": backend.type,
|
||||||
|
"status": "healthy" if backend.healthy else "down",
|
||||||
|
"models": [m.id for m in backend.models],
|
||||||
|
"latency_ms": backend.latency_ms,
|
||||||
|
}
|
||||||
12
hub-api/routers/models.py
Normal file
12
hub-api/routers/models.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from fastapi import APIRouter, Request
|
||||||
|
|
||||||
|
from services.registry import registry
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/v1", tags=["models"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/models")
|
||||||
|
async def list_models(request: Request):
|
||||||
|
role = getattr(request.state, "role", "anonymous")
|
||||||
|
models = registry.list_models(role)
|
||||||
|
return {"object": "list", "data": models}
|
||||||
0
hub-api/services/__init__.py
Normal file
0
hub-api/services/__init__.py
Normal file
41
hub-api/services/gpu_monitor.py
Normal file
41
hub-api/services/gpu_monitor.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_gpu_info() -> dict | None:
|
||||||
|
"""Run nvidia-smi and parse GPU info."""
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
settings.nvidia_smi_path,
|
||||||
|
"--query-gpu=utilization.gpu,temperature.gpu,memory.used,memory.total,power.draw,name",
|
||||||
|
"--format=csv,noheader,nounits",
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=5.0)
|
||||||
|
|
||||||
|
if proc.returncode != 0:
|
||||||
|
logger.debug("nvidia-smi failed: %s", stderr.decode())
|
||||||
|
return None
|
||||||
|
|
||||||
|
line = stdout.decode().strip().split("\n")[0]
|
||||||
|
parts = [p.strip() for p in line.split(",")]
|
||||||
|
if len(parts) < 6:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"utilization": int(parts[0]),
|
||||||
|
"temperature": int(parts[1]),
|
||||||
|
"vram_used": int(parts[2]),
|
||||||
|
"vram_total": int(parts[3]),
|
||||||
|
"power_draw": float(parts[4]),
|
||||||
|
"name": parts[5],
|
||||||
|
}
|
||||||
|
except (FileNotFoundError, asyncio.TimeoutError):
|
||||||
|
return None
|
||||||
156
hub-api/services/proxy_ollama.py
Normal file
156
hub-api/services/proxy_ollama.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_chat(
|
||||||
|
base_url: str,
|
||||||
|
model: str,
|
||||||
|
messages: list[dict],
|
||||||
|
**kwargs,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Proxy Ollama chat streaming, converting NDJSON to OpenAI SSE format."""
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
"stream": True,
|
||||||
|
**{k: v for k, v in kwargs.items() if v is not None},
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
f"{base_url}/api/chat",
|
||||||
|
json=payload,
|
||||||
|
) as resp:
|
||||||
|
if resp.status_code != 200:
|
||||||
|
body = await resp.aread()
|
||||||
|
error_msg = body.decode("utf-8", errors="replace")
|
||||||
|
yield _error_event(f"Ollama error: {error_msg}")
|
||||||
|
return
|
||||||
|
|
||||||
|
async for line in resp.aiter_lines():
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
chunk = json.loads(line)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if chunk.get("done"):
|
||||||
|
# Final chunk — send [DONE]
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
content = chunk.get("message", {}).get("content", "")
|
||||||
|
if content:
|
||||||
|
openai_chunk = {
|
||||||
|
"id": "chatcmpl-gateway",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"model": model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"content": content},
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
async def complete_chat(
|
||||||
|
base_url: str,
|
||||||
|
model: str,
|
||||||
|
messages: list[dict],
|
||||||
|
**kwargs,
|
||||||
|
) -> dict:
|
||||||
|
"""Non-streaming Ollama chat, returns OpenAI-compatible response."""
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
"stream": False,
|
||||||
|
**{k: v for k, v in kwargs.items() if v is not None},
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||||
|
resp = await client.post(f"{base_url}/api/chat", json=payload)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": "chatcmpl-gateway",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"model": model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": data.get("message", {}).get("content", ""),
|
||||||
|
},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": data.get("prompt_eval_count", 0),
|
||||||
|
"completion_tokens": data.get("eval_count", 0),
|
||||||
|
"total_tokens": data.get("prompt_eval_count", 0)
|
||||||
|
+ data.get("eval_count", 0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_embedding(
|
||||||
|
base_url: str,
|
||||||
|
model: str,
|
||||||
|
input_text: str | list[str],
|
||||||
|
) -> dict:
|
||||||
|
"""Ollama embedding, returns OpenAI-compatible response."""
|
||||||
|
texts = [input_text] if isinstance(input_text, str) else input_text
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{base_url}/api/embed",
|
||||||
|
json={"model": model, "input": texts},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
|
||||||
|
embeddings_data = []
|
||||||
|
raw_embeddings = data.get("embeddings", [])
|
||||||
|
for i, emb in enumerate(raw_embeddings):
|
||||||
|
embeddings_data.append({
|
||||||
|
"object": "embedding",
|
||||||
|
"embedding": emb,
|
||||||
|
"index": i,
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"object": "list",
|
||||||
|
"data": embeddings_data,
|
||||||
|
"model": model,
|
||||||
|
"usage": {"prompt_tokens": 0, "total_tokens": 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _error_event(message: str) -> str:
|
||||||
|
error = {
|
||||||
|
"id": "chatcmpl-gateway",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"model": "error",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"content": f"[Error] {message}"},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
return f"data: {json.dumps(error)}\n\ndata: [DONE]\n\n"
|
||||||
225
hub-api/services/registry.py
Normal file
225
hub-api/services/registry.py
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ModelInfo:
    """A single model exposed by a backend."""

    id: str  # model identifier as served through the OpenAI-compatible API
    capabilities: list[str]  # capability tags, e.g. ["chat"]; loader defaults to ["chat"]
    priority: int = 1  # lower value wins when several backends serve the same model id
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RateLimitConfig:
    """Per-backend request rate limits; 0 disables a limit."""

    rpm: int = 0  # max requests per minute (0 = unlimited)
    rph: int = 0  # max requests per hour (0 = unlimited)
    scope: str = "global"  # NOTE(review): never read in this module — presumably distinguishes global vs per-user limiting; confirm before relying on it
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BackendInfo:
    """Static configuration plus runtime health state for one AI backend."""

    id: str  # unique backend identifier from the config file
    type: str  # "ollama", "openai-compat", "anthropic"
    url: str  # base URL; loader stores it without a trailing slash
    models: list[ModelInfo]  # models this backend serves
    access: str = "all"  # "all" or "owner" (owner-only backends are hidden from other roles)
    rate_limit: RateLimitConfig | None = None  # None means no backend-level limiting

    # runtime state (mutated by the health-check loop)
    healthy: bool = False  # last probe returned a <500 status
    last_check: float = 0  # wall-clock time.time() of the last probe
    latency_ms: float = 0  # last probe round-trip in milliseconds (0 when down)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RateLimitState:
    """Sliding-window request timestamps backing the rpm/rph checks."""

    minute_timestamps: list[float] = field(default_factory=list)  # time.time() of requests within the last 60s
    hour_timestamps: list[float] = field(default_factory=list)  # time.time() of requests within the last 3600s
|
||||||
|
|
||||||
|
|
||||||
|
class Registry:
    """In-memory registry of AI backends.

    Responsibilities: loading backend definitions from a JSON config,
    running a periodic health-check loop, resolving model ids to healthy
    backends by role and priority, and enforcing per-backend sliding-window
    rate limits.
    """

    def __init__(self):
        # backend id -> static config + runtime health state
        self.backends: dict[str, BackendInfo] = {}
        # background health-check task (None until start_health_loop)
        self._health_task: asyncio.Task | None = None
        # backend id -> sliding-window timestamps (only for rate-limited backends)
        self._rate_limits: dict[str, RateLimitState] = {}

    async def load_backends(self, config_path: str):
        """Load backend definitions from a JSON config file.

        A missing file is tolerated (warning logged) so the gateway can
        boot with zero backends configured.
        """
        path = Path(config_path)
        if not path.exists():
            logger.warning("Backends config not found: %s", config_path)
            return

        with open(path) as f:
            data = json.load(f)

        for entry in data:
            models = [
                ModelInfo(
                    id=m["id"],
                    capabilities=m.get("capabilities", ["chat"]),
                    priority=m.get("priority", 1),
                )
                for m in entry.get("models", [])
            ]
            rl_data = entry.get("rate_limit")
            rate_limit = (
                RateLimitConfig(
                    rpm=rl_data.get("rpm", 0),
                    rph=rl_data.get("rph", 0),
                    scope=rl_data.get("scope", "global"),
                )
                if rl_data
                else None
            )
            backend = BackendInfo(
                id=entry["id"],
                type=entry["type"],
                # normalize so f"{url}/api/..." never yields a double slash
                url=entry["url"].rstrip("/"),
                models=models,
                access=entry.get("access", "all"),
                rate_limit=rate_limit,
            )
            self.backends[backend.id] = backend
            if rate_limit:
                self._rate_limits[backend.id] = RateLimitState()

        logger.info("Loaded %d backends", len(self.backends))

    def start_health_loop(self, interval: float = 30.0):
        """Start the periodic health-check task (idempotent)."""
        # Fix: guard against spawning duplicate loops on repeated calls.
        if self._health_task is not None and not self._health_task.done():
            return
        self._health_task = asyncio.create_task(self._health_loop(interval))

    def stop_health_loop(self):
        """Cancel the health-check task, if running."""
        if self._health_task:
            self._health_task.cancel()
            # Fix: drop the stale reference so start_health_loop can restart.
            self._health_task = None

    async def _health_loop(self, interval: float):
        """Probe all backends every *interval* seconds, forever."""
        while True:
            try:
                await self._check_all_backends()
            except Exception:
                # Fix: an unexpected error (e.g. client construction) must not
                # silently kill the loop for the rest of the process lifetime.
                logger.exception("Health check sweep failed")
            await asyncio.sleep(interval)

    async def _check_all_backends(self):
        """Probe every registered backend concurrently."""
        async with httpx.AsyncClient(timeout=5.0) as client:
            tasks = [
                self._check_backend(client, backend)
                for backend in self.backends.values()
            ]
            # return_exceptions keeps one failing probe from cancelling the rest
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _check_backend(self, client: httpx.AsyncClient, backend: BackendInfo):
        """Probe one backend and update its healthy/latency/last_check state."""
        try:
            start = time.monotonic()
            # Each backend type exposes a different cheap liveness endpoint.
            if backend.type == "ollama":
                resp = await client.get(f"{backend.url}/api/tags")
            elif backend.type in ("openai-compat", "anthropic"):
                resp = await client.get(f"{backend.url}/v1/models")
            else:
                resp = await client.get(f"{backend.url}/health")
            elapsed = (time.monotonic() - start) * 1000

            # 4xx (e.g. auth required) still proves the service is up.
            backend.healthy = resp.status_code < 500
            backend.latency_ms = round(elapsed, 1)
            backend.last_check = time.time()
        except Exception:
            backend.healthy = False
            backend.latency_ms = 0
            backend.last_check = time.time()
            logger.debug("Health check failed for %s", backend.id)

    def resolve_model(self, model_id: str, role: str) -> tuple[BackendInfo, ModelInfo] | None:
        """Find the best backend for a given model ID. Returns (backend, model) or None.

        Skips unhealthy backends and owner-only backends for non-owner roles;
        among the remainder the lowest priority number wins.
        """
        candidates: list[tuple[BackendInfo, ModelInfo, int]] = []

        for backend in self.backends.values():
            if not backend.healthy:
                continue
            if backend.access == "owner" and role != "owner":
                continue
            for model in backend.models:
                if model.id == model_id:
                    candidates.append((backend, model, model.priority))

        if not candidates:
            return None

        candidates.sort(key=lambda x: x[2])
        return candidates[0][0], candidates[0][1]

    def list_models(self, role: str) -> list[dict]:
        """List all available models for a given role (OpenAI /v1/models shape)."""
        result = []
        for backend in self.backends.values():
            if not backend.healthy:
                continue
            if backend.access == "owner" and role != "owner":
                continue
            for model in backend.models:
                result.append({
                    "id": model.id,
                    "object": "model",
                    "owned_by": backend.id,
                    "capabilities": model.capabilities,
                    "backend_id": backend.id,
                    # NOTE: unhealthy backends are filtered above, so listed
                    # entries always report "healthy".
                    "backend_status": "healthy" if backend.healthy else "down",
                })
        return result

    def check_rate_limit(self, backend_id: str) -> bool:
        """Check if a request to this backend is within rate limits. Returns True if allowed.

        Unknown backends and backends without a rate_limit config are always
        allowed. Expired timestamps are pruned as a side effect.
        """
        backend = self.backends.get(backend_id)
        if not backend or not backend.rate_limit:
            return True

        state = self._rate_limits.get(backend_id)
        if not state:
            return True

        now = time.time()
        rl = backend.rate_limit

        # Prune expired timestamps, then compare against the window cap.
        if rl.rpm > 0:
            state.minute_timestamps = [t for t in state.minute_timestamps if now - t < 60]
            if len(state.minute_timestamps) >= rl.rpm:
                return False

        if rl.rph > 0:
            state.hour_timestamps = [t for t in state.hour_timestamps if now - t < 3600]
            if len(state.hour_timestamps) >= rl.rph:
                return False

        return True

    def record_request(self, backend_id: str):
        """Record a request timestamp for rate limiting (no-op if unlimited)."""
        state = self._rate_limits.get(backend_id)
        if not state:
            return
        now = time.time()
        state.minute_timestamps.append(now)
        state.hour_timestamps.append(now)

    def get_health_summary(self) -> list[dict]:
        """Snapshot of every backend's health state for the status endpoint."""
        return [
            {
                "id": b.id,
                "type": b.type,
                "status": "healthy" if b.healthy else "down",
                "models": [m.id for m in b.models],
                "latency_ms": b.latency_ms,
                "last_check": b.last_check,
            }
            for b in self.backends.values()
        ]
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton shared by importers of this module.
registry = Registry()
|
||||||
Reference in New Issue
Block a user