EXAONE 7.8B가 복잡한 JSON 분류를 안정적으로 못함. 키워드 매칭으로 일정/메일/문서/확인 요청을 사전 감지하여 분류기를 건너뛰고 바로 도구로 라우팅. 날짜 계산(오늘/내일/이번주)도 코드에서 처리. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
393 lines
18 KiB
Python
393 lines
18 KiB
Python
"""Worker — EXAONE 분류 → direct/route/clarify/tools 분기 (cancel-safe + fallback)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from time import time
|
|
|
|
from config import settings
|
|
from db.database import log_completion, log_request
|
|
from models.schemas import JobStatus
|
|
from services.backend_registry import backend_registry
|
|
from services.conversation import conversation_store
|
|
from services.job_manager import Job, job_manager
|
|
from services.state_stream import state_stream
|
|
from services.synology_sender import send_to_synology
|
|
from tools.registry import execute_tool
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Minimum number of seconds between "processing" heartbeats pushed while
# response chunks are being streamed.
HEARTBEAT_INTERVAL = 4.0
# Number of seconds between "processing" heartbeats while waiting on a
# non-streaming completion call.
CLASSIFY_HEARTBEAT = 2.0
# Maximum number of characters of the rewritten prompt passed to the reasoner.
MAX_PROMPT_LENGTH = 1000
# Callback text longer than this many characters is truncated before it is
# sent to Synology.
SYNOLOGY_MAX_LEN = 4000
# Serialized tool output longer than this many characters is truncated before
# it is embedded in the formatting prompt.
MAX_TOOL_PAYLOAD = 2000
# Number of seconds a single tool execution may take before it is treated as
# timed out.
TOOL_TIMEOUT = 10.0
|
|
|
|
|
|
async def _complete_with_heartbeat(adapter, message: str, job_id: str, *, messages=None, beat_msg="처리 중...") -> str:
    """Run ``adapter.complete_chat`` while pushing periodic heartbeats.

    The completion runs in its own task. While it is pending, a
    ``"processing"`` event carrying ``beat_msg`` is pushed to ``state_stream``
    every ``CLASSIFY_HEARTBEAT`` seconds.

    Args:
        adapter: Backend object exposing an awaitable
            ``complete_chat(message, messages=...)``.
        message: Prompt text forwarded to the backend.
        job_id: Job identifier used as the heartbeat target.
        messages: Optional value forwarded unchanged as ``messages=``.
        beat_msg: Text placed in each heartbeat event.

    Returns:
        Whatever ``complete_chat`` returned, or ``""`` if the inner task ended
        up cancelled without this coroutine itself being cancelled.

    Raises:
        Exception: Any exception raised by ``complete_chat`` is re-raised.
    """

    async def _call():
        return await adapter.complete_chat(message, messages=messages)

    task = asyncio.create_task(_call())
    try:
        while True:
            # Wake up as soon as the completion finishes instead of always
            # sleeping a full heartbeat interval; asyncio.wait() neither
            # raises nor cancels the task when the timeout elapses.
            done, _ = await asyncio.wait({task}, timeout=CLASSIFY_HEARTBEAT)
            if done:
                break
            await state_stream.push(job_id, "processing", {"message": beat_msg})
    finally:
        # If we leave early (caller cancelled, heartbeat push failed), do not
        # leave the backend call running unattended.
        if not task.done():
            task.cancel()

    if task.cancelled():
        return ""
    return task.result()
|
|
|
|
|
|
async def _stream_with_cancel(adapter, message: str, job: Job, collected: list[str], *, messages=None) -> bool:
    """Stream a chat response chunk by chunk, stopping if the job is cancelled.

    Each chunk is appended to ``collected`` and pushed to ``state_stream`` as a
    ``"result"`` event. After a chunk, a ``"processing"`` heartbeat is pushed
    if at least ``HEARTBEAT_INTERVAL`` seconds passed since the previous one.

    Args:
        adapter: Backend object exposing ``stream_chat(message, messages=...)``
            that can be iterated with ``async for``.
        message: Prompt text forwarded to the backend.
        job: Job whose ``status`` is checked before each chunk is handled and
            whose ``id`` is the event target.
        collected: List that receives every streamed chunk, in order.
        messages: Optional value forwarded unchanged as ``messages=``.

    Returns:
        ``True`` when the stream was consumed to the end, ``False`` when the
        job was found cancelled and streaming stopped early.
    """
    loop = asyncio.get_running_loop()
    last_heartbeat = loop.time()
    stream = adapter.stream_chat(message, messages=messages)

    try:
        async for chunk in stream:
            if job.status == JobStatus.cancelled:
                return False
            collected.append(chunk)
            await state_stream.push(job.id, "result", {"content": chunk})

            now = loop.time()
            if now - last_heartbeat >= HEARTBEAT_INTERVAL:
                await state_stream.push(job.id, "processing", {"message": "응답 생성 중..."})
                last_heartbeat = now
    finally:
        # Leaving the loop early (cancel or error) would otherwise keep the
        # stream open until garbage collection; close it now when the object
        # supports it. This is a no-op for an already exhausted generator.
        aclose = getattr(stream, "aclose", None)
        if aclose is not None:
            await aclose()

    return True
|
|
|
|
|
|
def _parse_classification(raw: str) -> dict:
    """Extract the classifier's JSON decision from its raw text reply.

    The reply is scanned left to right for JSON objects. The first top-level
    object that contains an ``"action"`` key is returned as-is. Text before or
    after the object (code fences, explanations, stray braces) is ignored, so
    a valid decision is still found when the model adds commentary.

    Args:
        raw: Raw text returned by the classifier; ``None`` is treated as ``""``.

    Returns:
        The parsed object when one with an ``"action"`` key is found.
        Otherwise ``{"action": "direct", "response": <text>, "prompt": ""}``,
        where ``<text>`` is the reply with backticks and ``json`` fence
        markers removed, so the raw text itself serves as the direct answer.
    """
    raw = (raw or "").strip()
    decoder = json.JSONDecoder()

    pos = raw.find("{")
    while pos != -1:
        try:
            candidate, end = decoder.raw_decode(raw, pos)
        except json.JSONDecodeError:
            # Not a JSON object at this brace; try the next one.
            pos = raw.find("{", pos + 1)
            continue
        if isinstance(candidate, dict) and "action" in candidate:
            return candidate
        # Valid object without "action": skip past it entirely so its nested
        # objects are not mistaken for a decision.
        pos = raw.find("{", end)

    # No usable JSON: treat the reply as a direct answer, minus leftover
    # markdown / code-fence characters.
    cleaned = raw.replace("```json", "").replace("```", "").replace("`json", "").replace("`", "").strip()
    return {"action": "direct", "response": cleaned, "prompt": ""}
|
|
|
|
|
|
async def _send_callback(job: Job, text: str) -> None:
    """Send ``text`` to Synology when the job's callback is ``"synology"``.

    Jobs with any other callback value are ignored. Text longer than
    ``SYNOLOGY_MAX_LEN`` characters is cut at that length and an omission
    marker is appended before sending.
    """
    if job.callback != "synology":
        return

    too_long = len(text) > SYNOLOGY_MAX_LEN
    payload = text[:SYNOLOGY_MAX_LEN] + "\n\n...(생략됨)" if too_long else text
    await send_to_synology(payload)
|
|
|
|
|
|
def _pre_route(message: str) -> dict | None:
|
|
"""키워드 기반 사전 라우팅. EXAONE 7.8B 분류기 보완."""
|
|
from datetime import datetime, timedelta
|
|
msg = message.lower().strip()
|
|
now = datetime.now()
|
|
|
|
# 캘린더 키워드
|
|
cal_keywords = ["일정", "캘린더", "스케줄", "약속", "미팅", "회의"]
|
|
if any(k in msg for k in cal_keywords):
|
|
# 생성 요청
|
|
if any(k in msg for k in ["잡아", "만들", "등록", "추가", "넣어"]):
|
|
return None # EXAONE이 파라미터 추출해야 함
|
|
# 오늘
|
|
if "오늘" in msg:
|
|
return {"action": "tools", "tool": "calendar", "operation": "today", "params": {}}
|
|
# 내일
|
|
if "내일" in msg:
|
|
tmr = (now + timedelta(days=1)).strftime("%Y-%m-%d")
|
|
return {"action": "tools", "tool": "calendar", "operation": "search", "params": {"date_from": tmr, "date_to": tmr}}
|
|
# 이번주
|
|
if "이번" in msg and ("주" in msg or "week" in msg):
|
|
monday = now - timedelta(days=now.weekday())
|
|
sunday = monday + timedelta(days=6)
|
|
return {"action": "tools", "tool": "calendar", "operation": "search", "params": {"date_from": monday.strftime("%Y-%m-%d"), "date_to": sunday.strftime("%Y-%m-%d")}}
|
|
# 기본: 오늘
|
|
return {"action": "tools", "tool": "calendar", "operation": "today", "params": {}}
|
|
|
|
# 메일 키워드
|
|
if any(k in msg for k in ["메일", "이메일", "mail", "편지"]):
|
|
query = ""
|
|
days = 7
|
|
if "오늘" in msg:
|
|
days = 1
|
|
return {"action": "tools", "tool": "email", "operation": "search", "params": {"query": query, "days": days}}
|
|
|
|
# 문서 키워드
|
|
if any(k in msg for k in ["문서", "도큐먼트", "자료", "파일"]) and any(k in msg for k in ["찾아", "검색", "확인"]):
|
|
# 검색어 추출: 키워드 제거 후 남은 텍스트
|
|
query = msg
|
|
for rm in ["문서", "도큐먼트", "자료", "파일", "찾아줘", "찾아", "검색", "확인", "해줘", "줘", "좀"]:
|
|
query = query.replace(rm, "")
|
|
query = query.strip()
|
|
if query:
|
|
return {"action": "tools", "tool": "document", "operation": "search", "params": {"query": query}}
|
|
|
|
# pending_draft 확인 응답
|
|
if msg in ("확인", "예", "yes", "ㅇㅇ", "응", "네", "좋아", "ok"):
|
|
return {"action": "tools", "tool": "calendar", "operation": "create_confirmed", "params": {}}
|
|
|
|
if msg in ("취소", "아니", "no", "ㄴㄴ"):
|
|
return {"action": "direct", "response": "알겠어, 취소했어!", "prompt": ""}
|
|
|
|
return None
|
|
|
|
|
|
async def run(job: Job) -> None:
    """Process one job: pre-route or classify, dispatch, then report and log.

    Flow:
      1. Log the request and push an ``ack`` event.
      2. Try keyword pre-routing with ``_pre_route``; if it yields nothing,
         ask the classifier backend and parse its reply with
         ``_parse_classification`` (a failed call falls back to ``direct``).
      3. Dispatch on ``action``:
         ``tools``   - run the tool; ``create_draft`` results are stored as a
                       pending draft awaiting confirmation,
                       ``create_confirmed`` executes the stored draft, other
                       results are rephrased by the classifier backend;
         ``clarify`` - send the classifier's follow-up question;
         ``route``   - stream from the reasoner (only when the pipeline is
                       enabled and the reasoner is healthy), falling back to
                       the classifier backend if the reasoner raises;
         otherwise   - answer directly, using the classifier's text or a fresh
                       stream from the classifier backend.
      4. Mark the job completed or failed, send the optional callback, and
         record completion metrics.

    Cancellation is checked between stages via ``job.status``. An
    ``asyncio.CancelledError`` marks the job cancelled and is not re-raised.
    Any other exception marks the job failed and notifies the user.
    ``state_stream.push_done`` is always called at the end.

    Args:
        job: The job to process.
    """
    start_time = time()
    user_id = job.callback_meta.get("user_id", "api")
    classify_model = None
    reasoning_model = None
    rewritten_message = ""

    # Request logging is best-effort and must not block processing.
    try:
        await log_request(job.id, job.message, "classify", job.created_at)
    except Exception:
        logger.warning("Failed to log request for job %s", job.id, exc_info=True)

    try:
        # --- ACK ---
        await state_stream.push(job.id, "ack", {"message": "요청을 확인했습니다."})
        job_manager.set_status(job.id, JobStatus.processing)

        if job.status == JobStatus.cancelled:
            return

        classify_model = backend_registry.classifier.model

        # --- Keyword pre-routing (skips the classifier) ---
        pre = _pre_route(job.message)
        classify_latency = 0

        if pre:
            classification = pre
            logger.info("Job %s pre-routed: %s.%s", job.id, pre.get("tool", ""), pre.get("operation", pre.get("action", "")))
            # The only "direct" result _pre_route produces is the cancel
            # reply. Drop any pending draft so a later "yes" cannot create an
            # event the user already declined.
            if pre.get("action") == "direct" and conversation_store.get_pending_draft(user_id):
                conversation_store.clear_pending_draft(user_id)
        else:
            # --- Classifier call ---
            from datetime import datetime

            now_str = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
            history = conversation_store.format_for_prompt(user_id)
            classify_input = f"[현재 시간]\n{now_str}\n\n"
            if history:
                classify_input += f"[대화 이력]\n{history}\n\n"
            classify_input += f"[현재 메시지]\n{job.message}"

            await state_stream.push(job.id, "processing", {"message": "메시지를 분석하고 있습니다..."})
            classify_start = time()

            try:
                raw_result = await _complete_with_heartbeat(
                    backend_registry.classifier, classify_input, job.id,
                    beat_msg="메시지를 분석하고 있습니다..."
                )
            except Exception:
                # An empty reply parses to a "direct" action with no text,
                # which makes the direct branch stream a fresh answer.
                logger.warning("Classification failed for job %s, falling back to direct", job.id, exc_info=True)
                raw_result = ""

            classify_latency = (time() - classify_start) * 1000
            classification = _parse_classification(raw_result)

        action = classification.get("action", "direct")
        response_text = classification.get("response", "")
        route_prompt = classification.get("prompt", "")

        logger.info("Job %s classified as '%s'", job.id, action)

        # Conversation history: user message.
        conversation_store.add(user_id, "user", job.message)

        collected: list[str] = []

        if job.status == JobStatus.cancelled:
            return

        if action == "tools":
            # === TOOLS: execute a tool ===
            tool_name = classification.get("tool", "")
            operation = classification.get("operation", "")
            params = classification.get("params", {})

            logger.info("Job %s tool call: %s.%s(%s)", job.id, tool_name, operation, params)
            await state_stream.push(job.id, "processing", {"message": f"🔧 {tool_name} 도구를 사용하고 있습니다..."})

            if operation == "create_confirmed":
                # The payload comes from the stored pending draft, not from
                # the classification params.
                draft = conversation_store.get_pending_draft(user_id)
                if not draft:
                    response = "확인할 일정이 없습니다. 다시 요청해주세요."
                    collected.append(response)
                    await state_stream.push(job.id, "result", {"content": response})
                    conversation_store.add(user_id, "assistant", response)
                else:
                    try:
                        result = await asyncio.wait_for(execute_tool(tool_name, operation, draft), timeout=TOOL_TIMEOUT)
                    except asyncio.TimeoutError:
                        result = {"ok": False, "tool": tool_name, "operation": operation, "data": [], "summary": "", "error": "⚠️ 서비스 응답 시간이 초과되었습니다."}
                    # The draft is consumed whether the tool succeeded or not.
                    conversation_store.clear_pending_draft(user_id)
                    response = result.get("summary", "") if result["ok"] else result.get("error", "⚠️ 오류")
                    collected.append(response)
                    await state_stream.push(job.id, "result", {"content": response})
                    conversation_store.add(user_id, "assistant", response)
            else:
                # Regular tool execution.
                try:
                    result = await asyncio.wait_for(execute_tool(tool_name, operation, params), timeout=TOOL_TIMEOUT)
                except asyncio.TimeoutError:
                    result = {"ok": False, "tool": tool_name, "operation": operation, "data": [], "summary": "", "error": "⚠️ 서비스 응답 시간이 초과되었습니다."}

                if not result["ok"]:
                    response = result.get("error", "⚠️ 서비스를 사용할 수 없습니다.")
                    collected.append(response)
                    await state_stream.push(job.id, "result", {"content": response})
                else:
                    if operation == "create_draft":
                        # Store the draft and ask the user to confirm.
                        conversation_store.set_pending_draft(user_id, result["data"])
                        response = result["summary"] + "\n\n'확인' 또는 '취소'로 답해주세요."
                        collected.append(response)
                        await state_stream.push(job.id, "result", {"content": response})
                    else:
                        # Hand the tool data to the classifier backend to
                        # phrase it in natural language.
                        tool_json = json.dumps(result["data"], ensure_ascii=False)
                        if len(tool_json) > MAX_TOOL_PAYLOAD:
                            tool_json = tool_json[:MAX_TOOL_PAYLOAD] + "...(truncated)"
                        format_input = f"[도구 결과]\n{tool_json}\n\n위 데이터를 바탕으로 사용자에게 자연스럽고 간결하게 답해."
                        try:
                            response = await _complete_with_heartbeat(
                                backend_registry.classifier, format_input, job.id,
                                beat_msg="결과를 정리하고 있습니다..."
                            )
                            # The formatted reply may itself be JSON; if so,
                            # use its "response" field.
                            try:
                                parsed = json.loads(response)
                                response = parsed.get("response", response)
                            except (json.JSONDecodeError, AttributeError):
                                pass
                        except Exception:
                            # Formatting failed: fall back to the tool's own
                            # summary.
                            response = result.get("summary", "결과를 조회했습니다.")
                        collected.append(response)
                        await state_stream.push(job.id, "result", {"content": response})

                conversation_store.add(user_id, "assistant", "".join(collected))

        elif action == "clarify":
            # === CLARIFY: ask a follow-up question ===
            collected.append(response_text)
            await state_stream.push(job.id, "result", {"content": response_text})
            conversation_store.add(user_id, "assistant", response_text)

        elif action == "route" and settings.pipeline_enabled and backend_registry.is_healthy("reasoner"):
            # === ROUTE: reasoner backend ===
            reasoning_model = backend_registry.reasoner.model
            rewritten_message = (route_prompt or job.message)[:MAX_PROMPT_LENGTH]
            job.rewritten_message = rewritten_message

            if job.callback != "synology":
                await state_stream.push(job.id, "rewrite", {"content": rewritten_message})
            else:
                await send_to_synology("📝 더 깊이 살펴볼게요...", raw=True)

            if job.status == JobStatus.cancelled:
                return

            await state_stream.push(job.id, "processing", {"message": "Gemma 4가 응답을 생성하고 있습니다..."})

            try:
                ok = await _stream_with_cancel(backend_registry.reasoner, rewritten_message, job, collected)
                if not ok:
                    return
            except Exception:
                logger.warning("Reasoner failed for job %s, falling back to EXAONE", job.id, exc_info=True)
                if job.status == JobStatus.cancelled:
                    return
                await state_stream.push(job.id, "processing", {"message": "모델 전환 중..."})
                reasoning_model = classify_model
                # Discard whatever the reasoner streamed before failing so the
                # stored and delivered answer is the fallback model's only.
                collected.clear()
                ok = await _stream_with_cancel(backend_registry.classifier, job.message, job, collected)
                if not ok:
                    return

            if collected:
                conversation_store.add(user_id, "assistant", "".join(collected))

        else:
            # === DIRECT: answer from the classifier backend ===
            if response_text:
                # The classifier already produced the answer.
                collected.append(response_text)
                await state_stream.push(job.id, "result", {"content": response_text})
            else:
                # No usable classifier answer: stream one directly.
                await state_stream.push(job.id, "processing", {"message": "응답을 생성하고 있습니다..."})
                ok = await _stream_with_cancel(backend_registry.classifier, job.message, job, collected)
                if not ok:
                    return

            if collected:
                conversation_store.add(user_id, "assistant", "".join(collected))

        # --- Complete ---
        if not collected:
            job_manager.set_status(job.id, JobStatus.failed)
            await state_stream.push(job.id, "error", {"message": "응답을 받지 못했습니다."})
            status = "failed"
            await _send_callback(job, "⚠️ 응답을 받지 못했습니다. 다시 시도해주세요.")
        else:
            job_manager.set_status(job.id, JobStatus.completed)
            await state_stream.push(job.id, "done", {"message": "완료"})
            status = "completed"
            await _send_callback(job, "".join(collected))

        # --- DB logging (best-effort) ---
        latency_ms = (time() - start_time) * 1000
        try:
            await log_completion(
                job.id, status, len("".join(collected)), latency_ms, time(),
                rewrite_model=classify_model,
                reasoning_model=reasoning_model,
                rewritten_message=rewritten_message,
                rewrite_latency_ms=classify_latency,
            )
        except Exception:
            logger.warning("Failed to log completion for job %s", job.id, exc_info=True)

    except asyncio.CancelledError:
        job_manager.set_status(job.id, JobStatus.cancelled)
        await state_stream.push(job.id, "error", {"message": "작업이 취소되었습니다."})
        try:
            await log_completion(job.id, "cancelled", 0, (time() - start_time) * 1000, time())
        except Exception:
            logger.warning("Failed to log cancellation for job %s", job.id, exc_info=True)
    except Exception:
        logger.exception("Worker failed for job %s", job.id)
        job_manager.set_status(job.id, JobStatus.failed)
        await state_stream.push(job.id, "error", {"message": "내부 오류가 발생했습니다."})
        if job.callback == "synology":
            try:
                await send_to_synology("⚠️ 처리 중 오류가 발생했습니다. 다시 시도해주세요.", raw=True)
            except Exception:
                logger.warning("Failed to notify Synology about failure of job %s", job.id, exc_info=True)
        try:
            await log_completion(job.id, "failed", 0, (time() - start_time) * 1000, time())
        except Exception:
            logger.warning("Failed to log failure for job %s", job.id, exc_info=True)
    finally:
        await state_stream.push_done(job.id)
|