_FENCE = "```"
_JSON_FENCE = "```json"
# Non-alphanumeric characters allowed in a fence language tag (e.g. "c++", "objective-c").
_TAG_EXTRA_CHARS = "+-_#."


def _extract_fenced_block(content: str) -> str:
    """Return the body of the first Markdown code fence in *content*.

    A ```json fence anywhere in the text takes priority over a generic ```
    fence. If the text contains no fence at all it is returned unchanged
    (not even stripped). An unterminated fence (e.g. output cut off by the
    token limit) yields everything after the opening fence.

    For a generic fence, a language tag on the opening line (```JSON,
    ```javascript, ...) is dropped so that it does not leak into the payload.
    """
    if _JSON_FENCE in content:
        body = content.split(_JSON_FENCE, 1)[1].split(_FENCE, 1)[0]
        return body.strip()
    if _FENCE not in content:
        return content

    # Text between the first fence and the next one (or the end of the text).
    body = content.split(_FENCE, 2)[1]
    first_line, newline, remainder = body.partition("\n")
    tag = first_line.strip()
    # Only treat the opening line as a language tag when it is followed by a
    # newline and looks like an identifier; "```{...}```" on a single line
    # must keep its payload intact.
    if newline and tag[:1].isalpha() and all(
        ch.isalnum() or ch in _TAG_EXTRA_CHARS for ch in tag
    ):
        body = remainder
    return body.strip()


def llm_generate(prompt: str, model: str = "mlx-community/Qwen3.5-35B-A3B-4bit",
                 host: str = "http://localhost:8800", json_mode: bool = False) -> str:
    """Call the MLX server's OpenAI-compatible chat completions API.

    Args:
        prompt: User prompt. A " /nothink" suffix is appended before sending.
        model: Model identifier passed through in the request body.
        host: Base URL of the server; "/v1/chat/completions" is appended.
        json_mode: When True, a system message instructing the model to answer
            with JSON only is sent ahead of the user message.

    Returns:
        The first choice's message content. If the content contains a Markdown
        code fence, only the body of the first fence (preferring a ```json
        fence) is returned, stripped of surrounding whitespace; this happens
        regardless of ``json_mode``. A null content is returned as "".

    Raises:
        requests.HTTPError: If the server answers with an error status.
        KeyError, IndexError: If the response JSON lacks the expected
            ``choices[0].message.content`` structure.
    """
    import requests

    messages = []
    if json_mode:
        messages.append({"role": "system", "content": "You must respond ONLY with valid JSON. No thinking, no explanation, no markdown."})
    # Qwen3.5: suppress thinking output via the /nothink suffix.
    # NOTE(review): Qwen3 documents this soft switch as "/no_think" - confirm
    # that the served model actually honours "/nothink".
    messages.append({"role": "user", "content": prompt + " /nothink"})

    resp = requests.post(
        f"{host}/v1/chat/completions",
        json={
            "model": model,
            "messages": messages,
            "temperature": 0.3,
            "max_tokens": 1024,
        },
        timeout=120,
    )
    resp.raise_for_status()

    content = resp.json()["choices"][0]["message"]["content"]
    # The content field is nullable in the OpenAI response schema; normalise
    # it to "" so the declared str return type holds.
    return _extract_fenced_block(content or "")