import asyncio
import logging

import httpx

from config import settings

logger = logging.getLogger(__name__)


class OllamaClient:
    """Async HTTP client for an Ollama server, with an MLX server as text-generation fallback.

    Endpoints, model names and the request timeout come from ``config.settings``
    (``OLLAMA_BASE_URL``, ``OLLAMA_TIMEOUT``, ``OLLAMA_EMBED_MODEL``,
    ``OLLAMA_TEXT_MODEL``, ``MLX_BASE_URL``, ``MLX_TEXT_MODEL``). One shared
    ``httpx.AsyncClient`` is created lazily and reused until :meth:`close`.
    """

    def __init__(self):
        self.base_url = settings.OLLAMA_BASE_URL
        # Overall timeout from settings; establishing the connection is capped at 10 s.
        self.timeout = httpx.Timeout(float(settings.OLLAMA_TIMEOUT), connect=10.0)
        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared client, (re)creating it if it is missing or closed."""
        # There is no await between the check and the assignment, so coroutines
        # on the same event loop cannot interleave here and create two clients.
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(timeout=self.timeout)
        return self._client

    async def close(self) -> None:
        """Close the shared client if it is open; safe to call more than once."""
        if self._client is not None and not self._client.is_closed:
            await self._client.aclose()
            self._client = None

    async def generate_embedding(self, text: str) -> list[float]:
        """Return the embedding of ``text`` computed by ``OLLAMA_EMBED_MODEL``.

        Raises:
            httpx.HTTPError: on transport failures, timeouts, or a non-2xx
                response (via ``raise_for_status``).
        """
        client = await self._get_client()
        response = await client.post(
            f"{self.base_url}/api/embeddings",
            json={"model": settings.OLLAMA_EMBED_MODEL, "prompt": text},
        )
        response.raise_for_status()
        return response.json()["embedding"]

    async def batch_embeddings(
        self, texts: list[str], concurrency: int = 5
    ) -> list[list[float]]:
        """Embed ``texts`` concurrently and return the vectors in input order.

        At most ``concurrency`` requests are in flight at once. If any request
        fails, ``asyncio.gather`` propagates the first exception to the caller.

        Raises:
            ValueError: if ``concurrency`` is less than 1. A zero-permit
                semaphore would otherwise block forever.
        """
        if concurrency < 1:
            raise ValueError(f"concurrency must be >= 1, got {concurrency}")
        semaphore = asyncio.Semaphore(concurrency)

        async def _embed(text: str) -> list[float]:
            async with semaphore:
                return await self.generate_embedding(text)

        # gather preserves argument order, so results line up with ``texts``.
        return await asyncio.gather(*(_embed(t) for t in texts))

    async def generate_text(
        self,
        prompt: str,
        system: str | None = None,
        *,
        temperature: float = 0.3,
        max_tokens: int = 2048,
    ) -> str:
        """Generate a chat completion for ``prompt``, preferring Ollama over MLX.

        Args:
            prompt: User message content.
            system: Optional system message; omitted when empty or None.
            temperature: Sampling temperature sent to whichever backend answers.
            max_tokens: Output token limit. Sent to Ollama as ``num_predict`` and
                to MLX as ``max_tokens``.

        Returns:
            The assistant message text.

        Raises:
            Whatever the MLX fallback request raises if both backends fail. The
            Ollama failure is logged and kept as the exception's ``__context__``.
        """
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})

        # Primary backend: Ollama on the custom-built PC; fall back to the MLX server.
        try:
            return await self._chat_ollama(messages, temperature, max_tokens)
        except Exception:
            # The fallback is deliberate best-effort, but the primary failure
            # must not vanish silently.
            logger.warning("Ollama chat request failed; falling back to MLX", exc_info=True)
            # Still inside ``except`` so a fallback failure chains to the Ollama error.
            return await self._chat_mlx(messages, temperature, max_tokens)

    async def _chat_ollama(
        self, messages: list[dict], temperature: float, max_tokens: int
    ) -> str:
        """Send a non-streaming request to Ollama ``/api/chat`` and return the reply text."""
        client = await self._get_client()
        response = await client.post(
            f"{self.base_url}/api/chat",
            json={
                "model": settings.OLLAMA_TEXT_MODEL,
                "messages": messages,
                "stream": False,
                "think": False,
                "options": {"temperature": temperature, "num_predict": max_tokens},
            },
        )
        response.raise_for_status()
        return response.json()["message"]["content"]

    async def _chat_mlx(
        self, messages: list[dict], temperature: float, max_tokens: int
    ) -> str:
        """POST to the MLX server's ``/chat/completions`` and return the first choice's text."""
        client = await self._get_client()
        response = await client.post(
            f"{settings.MLX_BASE_URL}/chat/completions",
            json={
                "model": settings.MLX_TEXT_MODEL,
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
            },
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    async def check_health(self) -> dict:
        """Probe both backends concurrently and report their reachability.

        Each probe uses a short-lived client with a 5 s timeout (3 s to connect),
        so a hung backend cannot stall the check. Request and parse errors are
        reported as ``{"status": "disconnected"}`` rather than raised.

        Returns:
            ``{"ollama": {...}, "mlx": {...}}``. A connected Ollama entry also
            lists model names from ``/api/tags``. A connected MLX entry echoes
            the configured ``MLX_TEXT_MODEL``.
        """
        short_timeout = httpx.Timeout(5.0, connect=3.0)
        # Both probes swallow their own errors, so gather does not raise here.
        ollama_status, mlx_status = await asyncio.gather(
            self._probe_ollama(short_timeout),
            self._probe_mlx(short_timeout),
        )
        return {"ollama": ollama_status, "mlx": mlx_status}

    async def _probe_ollama(self, timeout: httpx.Timeout) -> dict:
        """Query Ollama ``/api/tags``; return connected plus model names, else disconnected."""
        try:
            async with httpx.AsyncClient(timeout=timeout) as c:
                response = await c.get(f"{self.base_url}/api/tags")
                response.raise_for_status()
                models = response.json().get("models", [])
            return {"status": "connected", "models": [m["name"] for m in models]}
        except Exception:
            logger.debug("Ollama health probe failed", exc_info=True)
            return {"status": "disconnected"}

    async def _probe_mlx(self, timeout: httpx.Timeout) -> dict:
        """GET the MLX server's ``/health``; any HTTP response counts as reachable."""
        try:
            async with httpx.AsyncClient(timeout=timeout) as c:
                # NOTE(review): the status code is deliberately not checked, so a
                # reachable server without a working /health still reports
                # connected. Confirm whether that is intended.
                await c.get(f"{settings.MLX_BASE_URL}/health")
        except Exception:
            logger.debug("MLX health probe failed", exc_info=True)
            return {"status": "disconnected"}
        return {"status": "connected", "model": settings.MLX_TEXT_MODEL}


ollama_client = OllamaClient()