"""Backend registry.

Loads backend definitions from a JSON config, periodically probes their
health, resolves model IDs to a backend, and enforces simple per-backend
sliding-window rate limits.
"""

from __future__ import annotations

import asyncio
import json
import logging
import time
from collections.abc import Iterator
from dataclasses import dataclass, field
from pathlib import Path

import httpx

logger = logging.getLogger(__name__)

# Sliding-window lengths used by the rate limiter, in seconds.
_MINUTE = 60.0
_HOUR = 3600.0


@dataclass
class ModelInfo:
    """A model exposed by a backend."""

    id: str
    capabilities: list[str]
    priority: int = 1  # lower value wins in Registry.resolve_model()
    backend_model_id: str = ""  # actual model ID sent to backend (if different from id)


@dataclass
class RateLimitConfig:
    """Per-backend request limits; a value of 0 disables that window."""

    rpm: int = 0  # max requests in any rolling 60 s window
    rph: int = 0  # max requests in any rolling 3600 s window
    scope: str = "global"  # read from config; not used by the logic in this module


@dataclass
class BackendInfo:
    """A configured backend plus its most recently observed health state."""

    id: str
    type: str  # "ollama", "openai-compat", "anthropic"
    url: str  # base URL; load_backends() strips any trailing "/"
    models: list[ModelInfo]
    access: str = "all"  # "all" or "owner"
    rate_limit: RateLimitConfig | None = None
    # runtime state, updated by the health loop
    healthy: bool = False
    last_check: float = 0  # time.time() of the last probe
    latency_ms: float = 0


@dataclass
class RateLimitState:
    """Recent request timestamps, one list per rate-limit window.

    Entries are ``time.monotonic()`` values recorded by the registry.
    """

    minute_timestamps: list[float] = field(default_factory=list)
    hour_timestamps: list[float] = field(default_factory=list)


class Registry:
    """In-memory catalogue of backends with health checks and rate limiting."""

    def __init__(self):
        self.backends: dict[str, BackendInfo] = {}
        self._health_task: asyncio.Task | None = None
        self._rate_limits: dict[str, RateLimitState] = {}

    # ------------------------------------------------------------------ config

    async def load_backends(self, config_path: str):
        """Load backend definitions from the JSON file at *config_path*.

        The file is expected to hold a JSON array of backend objects. A missing
        file is logged and ignored. Each backend is added to (or replaces its
        entry in) ``self.backends``; rate-limited backends get fresh limiter
        state.

        Raises:
            json.JSONDecodeError: if the file is not valid JSON.
            KeyError: if a backend or model entry lacks a required key.
        """
        path = Path(config_path)
        if not path.exists():
            logger.warning("Backends config not found: %s", config_path)
            return

        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with path.open(encoding="utf-8") as f:
            data = json.load(f)

        for entry in data:
            backend = self._parse_backend(entry)
            self.backends[backend.id] = backend
            if backend.rate_limit is not None:
                self._rate_limits[backend.id] = RateLimitState()
            else:
                # Drop limiter state left over from an earlier load of this id.
                self._rate_limits.pop(backend.id, None)

        logger.info("Loaded %d backends", len(self.backends))

    @staticmethod
    def _parse_backend(entry: dict) -> BackendInfo:
        """Build a BackendInfo from one entry of the backends config."""
        models = [
            ModelInfo(
                id=m["id"],
                capabilities=m.get("capabilities", ["chat"]),
                priority=m.get("priority", 1),
                backend_model_id=m.get("backend_model_id", ""),
            )
            for m in entry.get("models", [])
        ]
        rl_data = entry.get("rate_limit")
        rate_limit = (
            RateLimitConfig(
                rpm=rl_data.get("rpm", 0),
                rph=rl_data.get("rph", 0),
                scope=rl_data.get("scope", "global"),
            )
            if rl_data
            else None
        )
        return BackendInfo(
            id=entry["id"],
            type=entry["type"],
            url=entry["url"].rstrip("/"),
            models=models,
            access=entry.get("access", "all"),
            rate_limit=rate_limit,
        )

    # ------------------------------------------------------------------ health

    def start_health_loop(self, interval: float = 30.0):
        """Start the periodic health-check task.

        Must be called while an event loop is running (uses
        ``asyncio.create_task``). A previously started loop is cancelled first,
        so repeated calls never leave an orphaned, uncancellable task behind.
        """
        self.stop_health_loop()
        self._health_task = asyncio.create_task(self._health_loop(interval))

    def stop_health_loop(self):
        """Cancel the health-check task, if one was started."""
        if self._health_task:
            self._health_task.cancel()
            self._health_task = None

    async def _health_loop(self, interval: float):
        """Probe all backends, then sleep *interval* seconds; repeat until cancelled."""
        while True:
            try:
                await self._check_all_backends()
            except Exception:
                # One failed sweep must not silently end all future health checks.
                logger.exception("Health check sweep failed")
            await asyncio.sleep(interval)

    async def _check_all_backends(self):
        """Probe every backend concurrently with a shared 5 s-timeout client."""
        async with httpx.AsyncClient(timeout=5.0) as client:
            tasks = [
                self._check_backend(client, backend)
                for backend in self.backends.values()
            ]
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _check_backend(self, client: httpx.AsyncClient, backend: BackendInfo):
        """Probe one backend and record its health, latency and check time.

        Any response with status < 500 counts as healthy (a 4xx still proves
        the server answered). A request exception marks the backend down with
        zero latency.
        """
        if backend.type == "ollama":
            probe_url = f"{backend.url}/api/tags"
        elif backend.type in ("openai-compat", "anthropic"):
            probe_url = f"{backend.url}/v1/models"
        else:
            probe_url = f"{backend.url}/health"

        start = time.monotonic()
        try:
            resp = await client.get(probe_url)
        except Exception:
            backend.healthy = False
            backend.latency_ms = 0
            backend.last_check = time.time()
            logger.debug("Health check failed for %s", backend.id, exc_info=True)
            return

        elapsed_ms = (time.monotonic() - start) * 1000
        backend.healthy = resp.status_code < 500
        backend.latency_ms = round(elapsed_ms, 1)
        backend.last_check = time.time()

    # ------------------------------------------------------------------ lookup

    def _visible_backends(self, role: str) -> Iterator[BackendInfo]:
        """Yield healthy backends that *role* is allowed to use."""
        for backend in self.backends.values():
            if not backend.healthy:
                continue
            if backend.access == "owner" and role != "owner":
                continue
            yield backend

    def resolve_model(self, model_id: str, role: str) -> tuple[BackendInfo, ModelInfo] | None:
        """Find the best backend for a given model ID.

        Only healthy backends accessible to *role* are considered. The matching
        model with the lowest ``priority`` value wins; ties go to whichever
        appears first in ``self.backends`` / ``backend.models``.

        Returns (backend, model) or None."""
        candidates = (
            (backend, model)
            for backend in self._visible_backends(role)
            for model in backend.models
            if model.id == model_id
        )
        # min() keeps the first of equal keys, matching a stable sort + [0].
        return min(candidates, key=lambda pair: pair[1].priority, default=None)

    def list_models(self, role: str) -> list[dict]:
        """List all available models for a given role.

        Only healthy backends are included, so ``backend_status`` is always
        ``"healthy"``; the key is kept so the entry shape stays unchanged.
        """
        return [
            {
                "id": model.id,
                "object": "model",
                "owned_by": backend.id,
                "capabilities": model.capabilities,
                "backend_id": backend.id,
                "backend_status": "healthy",
            }
            for backend in self._visible_backends(role)
            for model in backend.models
        ]

    # ------------------------------------------------------------- rate limits

    def _limiter(self, backend_id: str) -> tuple[RateLimitConfig, RateLimitState] | None:
        """Return (config, state) if *backend_id* is rate limited, else None."""
        backend = self.backends.get(backend_id)
        if backend is None or backend.rate_limit is None:
            return None
        state = self._rate_limits.get(backend_id)
        if state is None:
            return None
        return backend.rate_limit, state

    @staticmethod
    def _prune(state: RateLimitState, rl: RateLimitConfig, now: float) -> None:
        """Drop timestamps outside their window; empty windows that are disabled."""
        state.minute_timestamps = (
            [t for t in state.minute_timestamps if now - t < _MINUTE] if rl.rpm > 0 else []
        )
        state.hour_timestamps = (
            [t for t in state.hour_timestamps if now - t < _HOUR] if rl.rph > 0 else []
        )

    def check_rate_limit(self, backend_id: str) -> bool:
        """Check if a request to this backend is within rate limits.

        Returns True if allowed; unknown or unlimited backends are always
        allowed. Checking does not record anything — see ``record_request``.
        """
        limiter = self._limiter(backend_id)
        if limiter is None:
            return True
        rl, state = limiter
        # Monotonic clock: immune to wall-clock jumps (NTP, manual changes).
        self._prune(state, rl, time.monotonic())
        if rl.rpm > 0 and len(state.minute_timestamps) >= rl.rpm:
            return False
        if rl.rph > 0 and len(state.hour_timestamps) >= rl.rph:
            return False
        return True

    def record_request(self, backend_id: str):
        """Record a request timestamp for rate limiting.

        Timestamps are kept only for enabled windows and expired entries are
        pruned here as well, so memory stays bounded even when a backend sets
        only ``rpm`` or only ``rph``.
        """
        limiter = self._limiter(backend_id)
        if limiter is None:
            return
        rl, state = limiter
        now = time.monotonic()
        self._prune(state, rl, now)
        if rl.rpm > 0:
            state.minute_timestamps.append(now)
        if rl.rph > 0:
            state.hour_timestamps.append(now)

    # ------------------------------------------------------------------ status

    def get_health_summary(self) -> list[dict]:
        """Return one status dict per backend (healthy or not)."""
        return [
            {
                "id": b.id,
                "type": b.type,
                "status": "healthy" if b.healthy else "down",
                "models": [m.id for m in b.models],
                "latency_ms": b.latency_ms,
                "last_check": b.last_check,
            }
            for b in self.backends.values()
        ]


registry = Registry()