30 lines
1.0 KiB
Python
30 lines
1.0 KiB
Python
from __future__ import annotations
|
|
|
|
import requests
|
|
from typing import List, Dict, Any
|
|
|
|
|
|
class OllamaClient:
    """Minimal HTTP client for the ``/api/embeddings`` and ``/api/chat`` endpoints.

    The base URL is normalised once in the constructor and stored on
    ``self.host``; every request is sent with ``requests.post``.
    """

    def __init__(self, host: str) -> None:
        """Normalise *host* into a base URL and store it on ``self.host``.

        Surrounding whitespace is stripped, ``http://`` is prepended when no
        ``http://``/``https://`` scheme is present, and trailing slashes are
        removed so endpoint paths can be appended directly.

        Args:
            host: Server address, with or without a scheme,
                e.g. ``"localhost:11434"`` or ``"https://example.com/"``.
        """
        host = host.strip()
        # URL schemes are case-insensitive, so compare on a lowercased copy;
        # otherwise "HTTP://host" would be turned into "http://HTTP://host".
        if not host.lower().startswith(("http://", "https://")):
            host = "http://" + host
        self.host = host.rstrip("/")

    def embeddings(self, model: str, text: str) -> List[float]:
        """Request an embedding for *text* and return it.

        Args:
            model: Model name sent in the request payload.
            text: Input sent as the ``"prompt"`` field.

        Returns:
            The value of the ``"embedding"`` key of the JSON response.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            KeyError: If the response JSON has no ``"embedding"`` key.
        """
        url = f"{self.host}/api/embeddings"
        resp = requests.post(url, json={"model": model, "prompt": text}, timeout=120)
        resp.raise_for_status()
        data = resp.json()
        return data["embedding"]

    def chat(self, model: str, messages: List[Dict[str, str]], stream: bool = False, options: Dict[str, Any] | None = None) -> Dict[str, Any]:
        """Send a chat request and return the response as a single dict.

        Args:
            model: Model name sent in the request payload.
            messages: Chat history, sent unchanged as ``"messages"``.
            stream: Forwarded to the server as ``"stream"``. When true, the
                reply is read as newline-delimited JSON chunks and merged
                into one dict (see ``_merge_stream_chunks``).
            options: Optional extra settings; sent as ``"options"`` only when
                non-empty.

        Returns:
            The decoded JSON response, or the merged chunks when streaming.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        import json  # local import: only needed to decode streamed chunks

        url = f"{self.host}/api/chat"
        payload: Dict[str, Any] = {"model": model, "messages": messages, "stream": stream}
        if options:
            payload["options"] = options
        # The context manager guarantees the connection is released even when
        # the streamed body is only partially consumed because of an error.
        with requests.post(url, json=payload, timeout=600, stream=stream) as resp:
            resp.raise_for_status()
            if not stream:
                return resp.json()
            # A streamed body holds one JSON document per line, so calling
            # resp.json() on it would fail; decode each non-empty line instead.
            chunks = [json.loads(line) for line in resp.iter_lines() if line]
        return self._merge_stream_chunks(chunks)

    @staticmethod
    def _merge_stream_chunks(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Fold decoded stream chunks into one non-streaming-style response.

        Top-level keys are taken from the latest chunk that carries them.
        The ``"content"`` strings of every chunk's ``"message"`` dict are
        concatenated in order; other message keys keep their latest value.
        """
        merged: Dict[str, Any] = {}
        message: Dict[str, Any] = {}
        content_parts: List[str] = []
        for chunk in chunks:
            part = chunk.get("message")
            if isinstance(part, dict):
                content_parts.append(part.get("content") or "")
                message.update(part)
            merged.update(chunk)
        if message:
            message["content"] = "".join(content_parts)
            merged["message"] = message
        return merged
|
|
|