"""OpenAI-compatible gateway helpers for an Ollama backend.

Converts Ollama's native ``/api/chat`` and ``/api/embed`` responses into
OpenAI-style chat-completion and embedding payloads, including the
NDJSON-to-SSE translation needed for streaming clients.
"""

from __future__ import annotations

import json
import logging
import time
from collections.abc import AsyncGenerator

import httpx

logger = logging.getLogger(__name__)


async def stream_chat(
    base_url: str,
    model: str,
    messages: list[dict],
    **kwargs,
) -> AsyncGenerator[str, None]:
    """Proxy Ollama chat streaming, converting NDJSON to OpenAI SSE format."""
    payload = {
        "model": model,
        "messages": messages,
        "stream": True,
        # Drop unset options so Ollama falls back to its own defaults.
        **{k: v for k, v in kwargs.items() if v is not None},
    }
    async with httpx.AsyncClient(timeout=120.0) as client:
        async with client.stream(
            "POST",
            f"{base_url}/api/chat",
            json=payload,
        ) as resp:
            if resp.status_code != 200:
                body = await resp.aread()
                error_msg = body.decode("utf-8", errors="replace")
                yield _error_event(f"Ollama error: {error_msg}")
                return
            # Ollama streams newline-delimited JSON; each line is one chunk.
            async for line in resp.aiter_lines():
                if not line.strip():
                    continue
                try:
                    chunk = json.loads(line)
                except json.JSONDecodeError:
                    logger.warning("Skipping malformed NDJSON line: %r", line)
                    continue
                if chunk.get("done"):
                    # OpenAI streams close with an empty-delta chunk carrying
                    # finish_reason, then the [DONE] sentinel. Ollama reports
                    # its stop cause ("stop"/"length") in done_reason.
                    final_chunk = {
                        "id": "chatcmpl-gateway",
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": model,
                        "choices": [
                            {
                                "index": 0,
                                "delta": {},
                                "finish_reason": chunk.get("done_reason", "stop"),
                            }
                        ],
                    }
                    yield f"data: {json.dumps(final_chunk)}\n\n"
                    yield "data: [DONE]\n\n"
                    return
                content = chunk.get("message", {}).get("content", "")
                if content:
                    openai_chunk = {
                        "id": "chatcmpl-gateway",
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": model,
                        "choices": [
                            {
                                "index": 0,
                                "delta": {"content": content},
                                "finish_reason": None,
                            }
                        ],
                    }
                    yield f"data: {json.dumps(openai_chunk)}\n\n"


async def complete_chat(
    base_url: str,
    model: str,
    messages: list[dict],
    **kwargs,
) -> dict:
    """Non-streaming Ollama chat; returns an OpenAI-compatible response."""
    payload = {
        "model": model,
        "messages": messages,
        "stream": False,
        **{k: v for k, v in kwargs.items() if v is not None},
    }
    async with httpx.AsyncClient(timeout=120.0) as client:
        resp = await client.post(f"{base_url}/api/chat", json=payload)
        resp.raise_for_status()
        data = resp.json()

    prompt_tokens = data.get("prompt_eval_count", 0)
    completion_tokens = data.get("eval_count", 0)
    return {
        "id": "chatcmpl-gateway",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": data.get("message", {}).get("content", ""),
                },
                # Pass through Ollama's done_reason rather than assuming "stop".
                "finish_reason": data.get("done_reason", "stop"),
            }
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }


async def generate_embedding(
    base_url: str,
    model: str,
    input_text: str | list[str],
) -> dict:
    """Ollama embedding; returns an OpenAI-compatible response."""
    # OpenAI accepts a single string or a batch; /api/embed takes a list.
    texts = [input_text] if isinstance(input_text, str) else input_text
    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(
            f"{base_url}/api/embed",
            json={"model": model, "input": texts},
        )
        resp.raise_for_status()
        data = resp.json()

    embeddings_data = [
        {"object": "embedding", "embedding": emb, "index": i}
        for i, emb in enumerate(data.get("embeddings", []))
    ]
    # /api/embed reports prompt_eval_count for the batch when available.
    prompt_tokens = data.get("prompt_eval_count", 0)
    return {
        "object": "list",
        "data": embeddings_data,
        "model": model,
        "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
    }


def _error_event(message: str) -> str:
    """Render an upstream failure as a terminal SSE chunk plus [DONE]."""
    error = {
        "id": "chatcmpl-gateway",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": "error",
        "choices": [
            {
                "index": 0,
                "delta": {"content": f"[Error] {message}"},
                "finish_reason": "stop",
            }
        ],
    }
    return f"data: {json.dumps(error)}\n\ndata: [DONE]\n\n"
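

# ---------------------------------------------------------------------------
# Usage sketch: exercises the helpers above against a local Ollama daemon.
# The base URL (http://localhost:11434) and model names ("llama3.2",
# "nomic-embed-text") are assumptions for illustration; substitute whatever
# your Ollama instance actually serves.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        base = "http://localhost:11434"
        msgs = [{"role": "user", "content": "Say hello in five words."}]

        # Streaming: each yielded item is already an SSE "data: ..." frame.
        async for event in stream_chat(base, "llama3.2", msgs):
            print(event, end="")

        # Non-streaming: a full OpenAI-shaped chat.completion dict. Note that
        # Ollama nests sampling parameters under a top-level "options" key,
        # which the **kwargs passthrough forwards verbatim.
        resp = await complete_chat(
            base, "llama3.2", msgs, options={"temperature": 0.2}
        )
        print(resp["choices"][0]["message"]["content"])

        # Embeddings: OpenAI-shaped list with one vector per input string.
        emb = await generate_embedding(base, "nomic-embed-text", ["hello", "world"])
        print(len(emb["data"]), "vectors of dim", len(emb["data"][0]["embedding"]))

    asyncio.run(_demo())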