263 lines
8.6 KiB
Python
263 lines
8.6 KiB
Python
from __future__ import annotations
|
|
|
|
from fastapi import FastAPI, HTTPException, Depends
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel
|
|
from typing import List, Dict, Any
|
|
|
|
from .config import settings
|
|
from .ollama_client import OllamaClient
|
|
from .index_store import JsonlIndex
|
|
from .security import require_api_key
|
|
from .paperless_client import PaperlessClient
|
|
from .utils import chunk_text
|
|
|
|
|
|
app = FastAPI(title="Local AI Server", version="0.2.1")

# CORS: CORS_ORIGINS is a comma-separated origin list; unset or empty means "*".
import os

cors_origins = os.getenv("CORS_ORIGINS", "*")
origins = [o.strip() for o in cors_origins.split(",") if o.strip()] or ["*"]
# Browsers refuse `Access-Control-Allow-Origin: *` on credentialed requests, and Starlette's
# CORSMiddleware then reflects the request Origin instead, which would grant every site
# credentialed access. Only allow credentials when the origins are listed explicitly.
cors_allow_credentials = "*" not in origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=cors_allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level singletons shared by all endpoints below.
ollama = OllamaClient(settings.ollama_host)
index = JsonlIndex(settings.index_path)
|
|
|
|
|
|
class ChatRequest(BaseModel):
    """Body of POST /chat: a conversation plus retrieval and model-routing switches."""

    # Explicit model name; when omitted (or empty) /chat picks one from the user text.
    model: str | None = None
    # Conversation turns, read via their "role" and "content" keys.
    messages: list[dict[str, str]]
    # Retrieve index passages and prepend them as a system prompt.
    use_rag: bool = True
    # Number of passages requested from the index when use_rag is on.
    top_k: int = 5
    # For non-English text, use the boost model regardless of length.
    force_boost: bool = False
    # Forwarded unchanged to ollama.chat(options=...).
    options: dict[str, Any] | None = None
|
|
|
|
|
|
class SearchRequest(BaseModel):
    """Body of POST /search: a free-text query against the vector index."""

    query: str  # embedded with settings.embedding_model before the lookup
    top_k: int = 5  # forwarded to index.search as top_k
|
|
|
|
class UpsertRow(BaseModel):
    """One passage to add to the index."""

    id: str  # caller-chosen identifier stored on the index row
    text: str  # passage text; embedded unless the request sets embed=False
    source: str | None = None  # provenance label; /index/upsert stores "api" when falsy
|
|
|
|
class UpsertRequest(BaseModel):
    """Body of POST /index/upsert: rows to embed and append to the index."""

    rows: list[UpsertRow]
    # When False, rows are stored with an empty vector instead of being embedded.
    embed: bool = True
    # Embedding model override; settings.embedding_model is used when omitted.
    model: str | None = None
    # NOTE(review): not read by /index/upsert in this module; rows are embedded one call each.
    batch: int = 16
|
|
|
|
|
|
@app.get("/health")
def health() -> Dict[str, Any]:
    """Report liveness, the configured model names and how many index rows are loaded."""
    row_count = len(index.rows) if index else 0
    return dict(
        status="ok",
        base_model=settings.base_model,
        boost_model=settings.boost_model,
        embedding_model=settings.embedding_model,
        index_loaded=row_count,
    )
|
|
|
|
|
|
@app.post("/search")
def search(req: SearchRequest) -> Dict[str, Any]:
    """Embed the query and return the nearest index passages (text truncated to 400 chars)."""
    # Nothing indexed: skip the embedding round-trip entirely.
    if not index.rows:
        return {"results": []}

    query_vector = ollama.embeddings(settings.embedding_model, req.query)
    hits = []
    for row, score in index.search(query_vector, top_k=req.top_k):
        hits.append(
            {
                "id": row.id,
                "score": float(score),
                "text": row.text[:400],
                "source": row.source,
            }
        )
    return {"results": hits}
|
|
|
|
|
|
def _user_text(messages: List[Dict[str, str]]) -> str:
    """Join the content of every user-role message, newline-separated."""
    return "\n".join(m.get("content", "") for m in messages if m.get("role") == "user")


def _select_model(user_text: str, force_boost: bool) -> str:
    """Pick a model from the user text with a very simple language heuristic.

    If the share of ASCII letters among all letters exceeds
    settings.english_ratio_threshold, the English model is used; otherwise the
    boost model is used when forced or when the text exceeds 2000 characters,
    and the base model in all other cases.
    """
    ascii_letters = sum(ch.isascii() and ch.isalpha() for ch in user_text)
    non_ascii_letters = sum((not ch.isascii()) and ch.isalpha() for ch in user_text)
    # max(..., 1) avoids division by zero for text with no letters (ratio is then 0).
    english_ratio = ascii_letters / max(ascii_letters + non_ascii_letters, 1)
    if english_ratio > settings.english_ratio_threshold:
        return settings.english_model
    if force_boost or len(user_text) > 2000:
        return settings.boost_model
    return settings.base_model


def _retrieve_context(query: str, top_k: int) -> List[str]:
    """Return the text of the index rows nearest to *query*; empty when there is nothing to search."""
    if not query or not index.rows:
        return []
    qvec = ollama.embeddings(settings.embedding_model, query)
    return [row.text for row, _ in index.search(qvec, top_k=top_k)]


def _build_messages(messages: List[Dict[str, str]], context_docs: List[str]) -> List[Dict[str, str]]:
    """Prefix the conversation with a RAG system prompt when context documents exist.

    The (Korean) system prompt tells the model it is a document-based assistant that
    must trust only the supplied context and say it does not know otherwise; each
    document follows as a numbered "[DOC n]" section.
    """
    if not context_docs:
        return list(messages)
    sys_prompt = (
        "당신은 문서 기반 비서입니다. 제공된 컨텍스트만 신뢰하고, 모르면 모른다고 답하세요.\n\n"
        + "\n\n".join(f"[DOC {i+1}]\n{t}" for i, t in enumerate(context_docs))
    )
    return [{"role": "system", "content": sys_prompt}, *messages]


@app.post("/chat")
def chat(req: ChatRequest) -> Dict[str, Any]:
    """Answer a conversation with an Ollama model, optionally grounded on retrieved index passages."""
    # The user turns drive both model routing and the retrieval query.
    user_text = _user_text(req.messages)
    model = req.model or _select_model(user_text, req.force_boost)
    try:
        # Retrieval talks to the same Ollama backend, so its failures are reported like chat failures.
        context_docs = _retrieve_context(user_text.strip(), req.top_k) if req.use_rag else []
        messages = _build_messages(req.messages, context_docs)
        resp = ollama.chat(model, messages, stream=False, options=req.options)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"model": model, "response": resp}
|
|
|
|
|
|
@app.post("/index/upsert")
def index_upsert(req: UpsertRequest) -> Dict[str, Any]:
    """Embed the given rows and append them to the vector index."""
    # NOTE(review): despite the name this only calls index.append; re-sending an id adds
    # another row rather than replacing the old one unless JsonlIndex dedupes - confirm.
    # NOTE(review): req.batch is not used; each row is embedded with its own call.
    # NOTE(review): no require_api_key dependency here, unlike the /paperless/* routes.
    from .index_store import IndexRow

    if not req.rows:
        return {"added": 0}
    try:
        model = req.model or settings.embedding_model
        # Every row is embedded before index.append is called, so no rows are appended
        # when any embedding call fails.
        to_append = [
            IndexRow(
                id=row.id,
                text=row.text,
                vector=ollama.embeddings(model, row.text) if req.embed else [],
                source=row.source or "api",
            )
            for row in req.rows
        ]
        added = index.append(to_append)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"index_upsert_error: {e}") from e
    return {"added": added}
|
|
|
|
|
|
@app.post("/index/reload")
def index_reload() -> Dict[str, Any]:
    """Reload the index and return the total row count reported by index.reload()."""
    # NOTE(review): unauthenticated, unlike the /paperless/* routes - confirm intended.
    return {"total": index.reload()}
|
|
|
|
|
|
# Paperless webhook payload (shape provisional; confirm against the configured Paperless workflow).
class PaperlessHook(BaseModel):
    """Webhook body naming the Paperless document to index."""

    document_id: int  # Paperless document id; its text is fetched via PaperlessClient
    title: str | None = None  # accepted but not used by /paperless/hook
    tags: list[str] | None = None  # accepted but not used by /paperless/hook
|
|
|
|
|
|
@app.post("/paperless/hook")
def paperless_hook(hook: PaperlessHook, _: None = Depends(require_api_key)) -> Dict[str, Any]:
    """Fetch a Paperless document's text, chunk and embed it, and append the chunks to the index."""
    from .index_store import IndexRow

    client = PaperlessClient(settings.paperless_base_url, settings.paperless_token)
    text = client.get_document_text(hook.document_id)
    # A document without text yields no chunks; /paperless/sync skips such documents the
    # same way instead of handing None/"" to chunk_text.
    if not text:
        return {"status": "indexed", "document_id": hook.document_id, "chunks": 0}

    model = settings.embedding_model
    # NOTE(review): chunk ids are deterministic, but index.append is not shown to replace
    # existing rows, so a repeated webhook may add duplicates - confirm JsonlIndex behaviour.
    to_append = [
        IndexRow(
            id=f"paperless:{hook.document_id}:{i}",
            text=part,
            vector=ollama.embeddings(model, part),
            source="paperless",
        )
        for i, part in enumerate(chunk_text(text))
    ]
    added = index.append(to_append)
    return {"status": "indexed", "document_id": hook.document_id, "chunks": added}
|
|
|
|
|
|
class PaperlessSyncRequest(BaseModel):
    """Body of POST /paperless/sync: which Paperless documents to pull into the index."""

    page_size: int = 50  # forwarded to list_documents as page_size
    ordering: str = "-created"  # forwarded to list_documents as ordering
    tags: list[int] | None = None  # tag filter forwarded to list_documents
    query: str | None = None  # search query forwarded to list_documents
    limit: int = 200  # sync stops requesting pages once this many documents were listed
|
|
|
|
|
|
def _fetch_paperless_page(
    client: PaperlessClient, req: PaperlessSyncRequest, next_url: str | None
) -> Dict[str, Any]:
    """Return one page of the Paperless document listing: the first page, or the one at *next_url*."""
    if not next_url:
        return client.list_documents(
            page_size=req.page_size, ordering=req.ordering, tags=req.tags, query=req.query
        )
    # Imported lazily: only needed when following pagination links.
    import requests as _rq

    # The "next" link already encodes filters and page; reuse the client's auth headers.
    resp = _rq.get(next_url, headers=client._headers(), timeout=60)
    resp.raise_for_status()
    return resp.json()


def _paperless_doc_rows(doc_id: Any, text: str) -> List[Any]:
    """Chunk and embed one document's text, returning IndexRow objects for all of its chunks."""
    from .index_store import IndexRow

    return [
        IndexRow(
            id=f"paperless:{doc_id}:{i}",
            text=part,
            vector=ollama.embeddings(settings.embedding_model, part),
            source="paperless",
        )
        for i, part in enumerate(chunk_text(text))
    ]


@app.post("/paperless/sync")
def paperless_sync(req: PaperlessSyncRequest, _: None = Depends(require_api_key)) -> Dict[str, Any]:
    """Page through Paperless documents (at most req.limit) and append their chunks to the index."""
    import logging

    log = logging.getLogger(__name__)
    client = PaperlessClient(settings.paperless_base_url, settings.paperless_token)
    added_total = 0
    skipped = 0
    fetched = 0
    next_url: str | None = None

    while fetched < req.limit:
        data = _fetch_paperless_page(client, req, next_url)
        # Trim the page so no more than req.limit documents are processed in total.
        results = (data.get("results") or [])[: req.limit - fetched]
        to_append: List[Any] = []
        for doc in results:
            doc_id = doc.get("id")
            if not doc_id:
                continue
            try:
                text = client.get_document_text(int(doc_id))
                if not text:
                    skipped += 1
                    continue
                # Build all of a document's rows before queuing them, so a failure part-way
                # through never leaves a partially indexed document behind.
                to_append.extend(_paperless_doc_rows(doc_id, text))
            except Exception:
                # Best effort: one bad document must not abort the whole sync.
                log.exception("paperless sync: failed to index document %s", doc_id)
                skipped += 1
        if to_append:
            added_total += index.append(to_append)
        fetched += len(results)
        next_url = data.get("next")
        if not next_url:
            break

    # NOTE(review): rows are appended, not replaced; re-running a sync may duplicate chunks
    # unless JsonlIndex dedupes by id - confirm.
    return {"status": "synced", "added": added_total, "skipped": skipped}
|
|
|
|
|
|
# Minimal OpenAI-compatible chat completions API.
class ChatCompletionsRequest(BaseModel):
    """Subset of the OpenAI chat.completions request body understood by this server."""

    # NOTE(review): other OpenAI fields (e.g. "stream") are dropped by pydantic's default
    # extra-field handling, so streaming clients receive a single non-streamed body.
    model: str | None = None  # /v1/chat/completions falls back to settings.base_model
    messages: list[dict[str, str]]  # OpenAI-style role/content turns
    temperature: float | None = None  # sampling temperature (OpenAI field name)
    max_tokens: int | None = None  # completion length cap (OpenAI field name)
|
|
|
|
|
|
@app.post("/v1/chat/completions")
def chat_completions(req: ChatCompletionsRequest, _: None = Depends(require_api_key)) -> Dict[str, Any]:
    """Minimal, non-streaming OpenAI-compatible chat completion backed by Ollama."""
    import time

    chosen = req.model or settings.base_model
    opts: Dict[str, Any] = {}
    if req.temperature is not None:
        opts["temperature"] = req.temperature
    # Ollama has no max_tokens field; its equivalent is the num_predict option.
    # Assumes OllamaClient.chat forwards `options` to Ollama unchanged (as /chat relies on).
    if req.max_tokens is not None:
        opts["num_predict"] = req.max_tokens
    resp = ollama.chat(chosen, req.messages, stream=False, options=opts)

    # Minimal OpenAI-like response shape.
    completion: Dict[str, Any] = {
        "id": "chatcmpl-local",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": chosen,
        "choices": [
            {
                "index": 0,
                # Prefer the chat-style "message"; fall back to a generate-style "response" string.
                "message": resp.get("message", {"role": "assistant", "content": resp.get("response", "")}),
                "finish_reason": resp.get("done_reason", "stop"),
            }
        ],
    }
    # Ollama reports token counts as prompt_eval_count / eval_count; expose them as OpenAI usage.
    prompt_tokens = resp.get("prompt_eval_count")
    completion_tokens = resp.get("eval_count")
    if prompt_tokens is not None and completion_tokens is not None:
        completion["usage"] = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
    return completion
|
|
|