# ai-server/server/main.py
from __future__ import annotations

import os
from typing import Any, Dict, List

import requests
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from .config import settings
from .index_store import IndexRow, JsonlIndex
from .ollama_client import OllamaClient
from .paperless_client import PaperlessClient
from .pipeline import DocumentPipeline
from .security import require_api_key
from .utils import chunk_text

app = FastAPI(title="Local AI Server", version="0.2.1")

# CORS: comma-separated origins via CORS_ORIGINS, e.g. "http://localhost:3000,https://app.example.com".
# Defaults to "*"; note that browsers reject credentialed responses with a wildcard
# origin, so set explicit origins when allow_credentials matters.
cors_origins = os.getenv("CORS_ORIGINS", "*")
origins = [o.strip() for o in cors_origins.split(",") if o.strip()] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

ollama = OllamaClient(settings.ollama_host)
index = JsonlIndex(settings.index_path)
pipeline = DocumentPipeline(ollama, settings.embedding_model, settings.boost_model)

class ChatRequest(BaseModel):
    model: str | None = None
    messages: List[Dict[str, str]]
    use_rag: bool = True
    top_k: int = 5
    force_boost: bool = False
    options: Dict[str, Any] | None = None


class SearchRequest(BaseModel):
    query: str
    top_k: int = 5


class UpsertRow(BaseModel):
    id: str
    text: str
    source: str | None = None


class UpsertRequest(BaseModel):
    rows: List[UpsertRow]
    embed: bool = True
    model: str | None = None
    batch: int = 16


class PipelineIngestRequest(BaseModel):
    doc_id: str
    text: str
    generate_html: bool = True
    translate: bool = True
    target_language: str = "ko"
    summarize: bool = False
    summary_sentences: int = 5
    summary_language: str | None = None

@app.get("/health")
def health() -> Dict[str, Any]:
return {
"status": "ok",
"base_model": settings.base_model,
"boost_model": settings.boost_model,
"embedding_model": settings.embedding_model,
"index_loaded": len(index.rows) if index else 0,
}
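# Quick smoke test (a sketch; assumes the server is reachable on localhost:8000,
# adjust host/port to your deployment):
#   curl http://localhost:8000/health
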
@app.post("/search")
def search(req: SearchRequest) -> Dict[str, Any]:
if not index.rows:
return {"results": []}
qvec = ollama.embeddings(settings.embedding_model, req.query)
results = index.search(qvec, top_k=req.top_k)
return {
"results": [
{"id": r.id, "score": float(score), "text": r.text[:400], "source": r.source}
for r, score in results
]
}
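# Example (a sketch; host/port are assumptions). Returns the top_k nearest
# chunks, with text truncated to 400 characters:
#   curl -X POST http://localhost:8000/search \
#     -H "Content-Type: application/json" \
#     -d '{"query": "invoice 2024", "top_k": 3}'
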
@app.post("/chat")
def chat(req: ChatRequest) -> Dict[str, Any]:
model = req.model
if not model:
# 언어 감지(매우 단순): 영문 비율이 높으면 영어 모델, 아니면 기본/부스팅
user_text = "\n".join(m.get("content", "") for m in req.messages if m.get("role") == "user")
ascii_letters = sum(ch.isascii() and ch.isalpha() for ch in user_text)
non_ascii_letters = sum((not ch.isascii()) and ch.isalpha() for ch in user_text)
english_ratio = ascii_letters / max(ascii_letters + non_ascii_letters, 1)
total_chars = len(user_text)
if english_ratio > settings.english_ratio_threshold:
model = settings.english_model
else:
model = settings.boost_model if (req.force_boost or total_chars > 2000) else settings.base_model
context_docs: List[str] = []
if req.use_rag and index.rows:
q = "\n".join([m.get("content", "") for m in req.messages if m.get("role") == "user"]).strip()
if q:
qvec = ollama.embeddings(settings.embedding_model, q)
hits = index.search(qvec, top_k=req.top_k)
context_docs = [r.text for r, _ in hits]
sys_prompt = ""
if context_docs:
sys_prompt = (
"당신은 문서 기반 비서입니다. 제공된 컨텍스트만 신뢰하고, 모르면 모른다고 답하세요.\n\n"
+ "\n\n".join(f"[DOC {i+1}]\n{t}" for i, t in enumerate(context_docs))
)
messages: List[Dict[str, str]] = []
if sys_prompt:
messages.append({"role": "system", "content": sys_prompt})
messages.extend(req.messages)
try:
resp = ollama.chat(model, messages, stream=False, options=req.options)
return {"model": model, "response": resp}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
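# Example (a sketch; host/port are assumptions). Omitting "model" triggers the
# heuristic routing above; "force_boost": true always selects the boost model:
#   curl -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Summarize DOC 1"}], "use_rag": true, "top_k": 5}'
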
@app.post("/index/upsert")
def index_upsert(req: UpsertRequest) -> Dict[str, Any]:
try:
if not req.rows:
return {"added": 0}
model = req.model or settings.embedding_model
new_rows = []
for r in req.rows:
vec = ollama.embeddings(model, r.text) if req.embed else []
new_rows.append({
"id": r.id,
"text": r.text,
"vector": vec,
"source": r.source or "api",
})
# convert to IndexRow and append
from .index_store import IndexRow
to_append = [IndexRow(**nr) for nr in new_rows]
added = index.append(to_append)
return {"added": added}
except Exception as e:
raise HTTPException(status_code=500, detail=f"index_upsert_error: {e}")
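# Example (a sketch; host/port are assumptions). With "embed": false the rows
# are stored without vectors and will not match vector search:
#   curl -X POST http://localhost:8000/index/upsert \
#     -H "Content-Type: application/json" \
#     -d '{"rows": [{"id": "note:1", "text": "hello world", "source": "manual"}], "embed": true}'
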
@app.post("/index/reload")
def index_reload() -> Dict[str, Any]:
total = index.reload()
return {"total": total}
@app.post("/pipeline/ingest")
def pipeline_ingest(req: PipelineIngestRequest, _: None = Depends(require_api_key)) -> Dict[str, Any]:
result = pipeline.process(
doc_id=req.doc_id,
text=req.text,
index=index,
generate_html=req.generate_html,
translate=req.translate,
target_language=req.target_language,
summarize=req.summarize,
summary_sentences=req.summary_sentences,
summary_language=req.summary_language,
)
return {"status": "ok", "doc_id": result.doc_id, "added": result.added_chunks, "chunks": result.chunks, "html_path": result.html_path}
@app.post("/pipeline/ingest_file")
async def pipeline_ingest_file(
_: None = Depends(require_api_key),
file: UploadFile = File(...),
doc_id: str = Form(...),
generate_html: bool = Form(True),
translate: bool = Form(True),
target_language: str = Form("ko"),
) -> Dict[str, Any]:
content_type = (file.content_type or "").lower()
raw = await file.read()
text = ""
if "text/plain" in content_type or file.filename.endswith(".txt"):
try:
text = raw.decode("utf-8")
except Exception:
text = raw.decode("latin-1", errors="ignore")
elif "pdf" in content_type or file.filename.endswith(".pdf"):
try:
from pypdf import PdfReader
from io import BytesIO
reader = PdfReader(BytesIO(raw))
parts: List[str] = []
for p in reader.pages:
try:
parts.append(p.extract_text() or "")
except Exception:
parts.append("")
text = "\n\n".join(parts)
except Exception as e:
raise HTTPException(status_code=400, detail=f"pdf_extract_error: {e}")
else:
raise HTTPException(status_code=400, detail="unsupported_file_type (only .txt/.pdf)")
if not text.strip():
raise HTTPException(status_code=400, detail="empty_text_after_extraction")
result = pipeline.process(
doc_id=doc_id,
text=text,
index=index,
generate_html=generate_html,
translate=translate,
target_language=target_language,
)
return {"status": "ok", "doc_id": result.doc_id, "added": result.added_chunks, "chunks": result.chunks, "html_path": result.html_path}
# Paperless webhook placeholder (to be wired with user-provided details).
class PaperlessHook(BaseModel):
    document_id: int
    title: str | None = None
    tags: List[str] | None = None


@app.post("/paperless/hook")
def paperless_hook(hook: PaperlessHook, _: None = Depends(require_api_key)) -> Dict[str, Any]:
    # Fetch text from Paperless, chunk it, embed each chunk, and upsert into the index.
    client = PaperlessClient(settings.paperless_base_url, settings.paperless_token)
    text = client.get_document_text(hook.document_id)
    parts = chunk_text(text)
    model = settings.embedding_model
    to_append = []
    for i, t in enumerate(parts):
        vec = ollama.embeddings(model, t)
        to_append.append(IndexRow(id=f"paperless:{hook.document_id}:{i}", text=t, vector=vec, source="paperless"))
    added = index.append(to_append)
    return {"status": "indexed", "document_id": hook.document_id, "chunks": added}

class PaperlessSyncRequest(BaseModel):
    page_size: int = 50
    ordering: str = "-created"
    tags: List[int] | None = None
    query: str | None = None
    limit: int = 200


@app.post("/paperless/sync")
def paperless_sync(req: PaperlessSyncRequest, _: None = Depends(require_api_key)) -> Dict[str, Any]:
    client = PaperlessClient(settings.paperless_base_url, settings.paperless_token)
    added_total = 0
    skipped = 0
    next_url: str | None = None
    fetched = 0
    while True:
        # Follow Paperless pagination: the first page goes through the client,
        # subsequent pages use the "next" URL returned by the API (reusing the
        # client's auth headers, a private helper).
        if next_url:
            resp = requests.get(next_url, headers=client._headers(), timeout=60)
            resp.raise_for_status()
            data = resp.json()
        else:
            data = client.list_documents(page_size=req.page_size, ordering=req.ordering, tags=req.tags, query=req.query)
        results = data.get("results", [])
        to_append: List[IndexRow] = []
        for doc in results:
            doc_id = doc.get("id")
            if not doc_id:
                continue
            try:
                text = client.get_document_text(int(doc_id))
                if not text:
                    skipped += 1
                    continue
                parts = chunk_text(text)
                for i, t in enumerate(parts):
                    vec = ollama.embeddings(settings.embedding_model, t)
                    to_append.append(IndexRow(id=f"paperless:{doc_id}:{i}", text=t, vector=vec, source="paperless"))
            except Exception:
                skipped += 1
                continue
        if to_append:
            added_total += index.append(to_append)
        fetched += len(results)
        if fetched >= req.limit:
            break
        next_url = data.get("next")
        if not next_url:
            break
    return {"status": "synced", "added": added_total, "skipped": skipped}

# OpenAI-compatible chat completions (minimal).
class ChatCompletionsRequest(BaseModel):
    model: str | None = None
    messages: List[Dict[str, str]]
    temperature: float | None = None
    max_tokens: int | None = None


@app.post("/v1/chat/completions")
def chat_completions(req: ChatCompletionsRequest, _: None = Depends(require_api_key)) -> Dict[str, Any]:
    chosen = req.model or settings.base_model
    opts: Dict[str, Any] = {}
    if req.temperature is not None:
        opts["temperature"] = req.temperature
    # Note: max_tokens is accepted for interface similarity but not forwarded;
    # Ollama's closest equivalent would be options["num_predict"].
    resp = ollama.chat(chosen, req.messages, stream=False, options=opts)
    # Minimal OpenAI-like response shape.
    return {
        "id": "chatcmpl-local",
        "object": "chat.completion",
        "model": chosen,
        "choices": [
            {
                "index": 0,
                "message": resp.get("message", {"role": "assistant", "content": resp.get("response", "")}),
                "finish_reason": resp.get("done_reason", "stop"),
            }
        ],
    }
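
# Example (a sketch; host/port, the "X-API-Key" header name, and the "llama3"
# model name are assumptions). The response shape is minimal (no "created" or
# "usage" fields), so strict OpenAI clients may complain; a plain HTTP call is
# the safer test:
#   curl -X POST http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" -H "X-API-Key: $API_KEY" \
#     -d '{"model": "llama3", "messages": [{"role": "user", "content": "hello"}]}'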