기존 weighted-sum merge를 Reciprocal Rank Fusion으로 교체. 정확 키워드 매치에서 RRF가 평탄화되는 문제는 boost로 보완. 신규 모듈 app/services/search_fusion.py: - FusionStrategy ABC - LegacyWeightedSum : 기존 _merge_results 동작 (A/B 비교용) - RRFOnly : 순수 RRF, k=60 - RRFWithBoost : RRF + title/tags/법령조문/high-text-score boost (default) - normalize_display_scores: SearchResult.score를 [0..1] 랭크 기반 정규화 (프론트엔드가 score*100을 % 표시하므로 RRF 원본 점수 노출 시 표시 깨짐) search.py: - ?fusion=legacy|rrf|rrf_boost 파라미터 (default rrf_boost) - _merge_results 제거 (LegacyWeightedSum에 흡수) - pre-fusion confidence: hybrid는 raw text/vector 신호로 계산 (fused score는 fusion 전략마다 스케일이 달라 일관 비교 불가) - timing에 fusion_ms 추가 - debug notes에 fusion 전략 표시 telemetry: - compute_confidence_hybrid(text_results, vector_results) 헬퍼 - record_search_event에 confidence override 파라미터 run_eval.py: - --fusion CLI 옵션, call_search 쿼리 파라미터에 전달 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
488 lines
17 KiB
Python
488 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
"""Document Server 검색 평가 스크립트 (Phase 0.2)
|
|
|
|
queries.yaml을 읽어 /api/search 엔드포인트에 호출하고
|
|
Recall@10, MRR@10, NDCG@10, Top3 hit-rate, Latency p50/p95를 계산한다.
|
|
|
|
A/B 비교 모드: --baseline-url, --candidate-url 를 각각 지정하면
|
|
두 엔드포인트에 동일 쿼리셋을 던지고 결과를 비교한다.
|
|
|
|
사용 예:
|
|
|
|
# 단일 평가
|
|
export DOCSRV_TOKEN="eyJ..."
|
|
python tests/search_eval/run_eval.py \
|
|
--base-url https://docs.hyungi.net \
|
|
--output reports/baseline_2026-04-07.csv
|
|
|
|
# A/B 비교 (같은 토큰)
|
|
python tests/search_eval/run_eval.py \
|
|
--baseline-url https://docs.hyungi.net \
|
|
--candidate-url http://localhost:8000 \
|
|
--output reports/phase1_vs_baseline.csv
|
|
|
|
토큰은 env DOCSRV_TOKEN 또는 --token 플래그로 전달.
|
|
"""
|
|
|
|
from __future__ import annotations

import argparse
import asyncio
import csv
import math
import os
import statistics
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import httpx
import yaml
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────
|
|
# 데이터 구조
|
|
# ─────────────────────────────────────────────────────────
|
|
|
|
|
|
@dataclass
class Query:
    """A single evaluation query as declared in queries.yaml.

    The field order below is also the positional constructor order.
    """

    id: str  # query identifier (reported in CSV / error listings)
    query: str  # text sent to the search API as ``q``
    category: str  # key for the per-category breakdown
    intent: str
    domain_hint: str
    relevant_ids: list[int]  # expected doc ids; empty means "no result expected"
    top3_ids: list[int] = field(default_factory=list)  # at least one must rank in top-3
    notes: str = ""  # free-form annotation, not used by any metric
|
|
|
|
|
|
@dataclass
class QueryResult:
    """Outcome of running one ``Query`` against one endpoint.

    ``error`` stays ``None`` on success; on failure it holds the error text and
    the metric fields are zeroed by the caller.
    """

    query: Query
    label: str  # run label: "single", "baseline" or "candidate"
    returned_ids: list[int]  # ranked doc ids from the API
    latency_ms: float
    recall_at_10: float
    mrr_at_10: float
    ndcg_at_10: float
    top3_hit: bool
    error: str | None = None
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────
|
|
# 평가 지표
|
|
# ─────────────────────────────────────────────────────────
|
|
|
|
|
|
def recall_at_k(returned: list[int], relevant: list[int], k: int = 10) -> float:
    """Fraction of the relevant documents found in the top-``k`` results.

    Args:
        returned: document ids in ranked order, as returned by the API.
        relevant: ground-truth relevant ids; duplicates are counted once.
        k: rank cut-off.

    Returns:
        ``|relevant ∩ top-k| / |relevant|``. When ``relevant`` is empty the
        query is a "no result expected" case: 1.0 if nothing was returned,
        0.0 otherwise (any result is a failure).
    """
    if not relevant:
        # Empty-expected case: returning anything at all is a failure.
        return 1.0 if not returned else 0.0
    # Dedupe so a repeated id in the YAML does not get double weight.
    relevant_set = set(relevant)
    top_k = set(returned[:k])
    return len(relevant_set & top_k) / len(relevant_set)
|
|
|
|
|
|
def mrr_at_k(returned: list[int], relevant: list[int], k: int = 10) -> float:
    """Reciprocal rank (1-based) of the first relevant id within the top-``k``.

    Yields 0.0 when ``relevant`` is empty or none of it appears in the top-``k``.
    """
    targets = set(relevant)
    if not targets:
        return 0.0
    first_rank = next(
        (pos for pos, doc_id in enumerate(returned[:k], start=1) if doc_id in targets),
        None,
    )
    return 0.0 if first_rank is None else 1.0 / first_rank
|
|
|
|
|
|
def ndcg_at_k(returned: list[int], relevant: list[int], k: int = 10) -> float:
    """Binary-relevance NDCG@k.

    Each relevant document earns gain 1 / log2(rank + 1) at its FIRST position
    in the top-``k``; later repeats of the same id earn nothing. Duplicate ids
    in ``relevant`` are collapsed. Together this keeps the score in [0, 1].
    Extra weighting via ``top3_ids`` is not applied in v0.1.

    Returns:
        NDCG in [0, 1]; 0.0 when ``relevant`` is empty.
    """
    if not relevant:
        return 0.0
    relevant_set = set(relevant)
    credited: set[int] = set()
    dcg = 0.0
    for rank, doc_id in enumerate(returned[:k], start=1):
        # Credit each relevant doc once; a repeated id would otherwise push
        # DCG above the ideal and NDCG above 1.0.
        if doc_id in relevant_set and doc_id not in credited:
            credited.add(doc_id)
            dcg += 1.0 / math.log2(rank + 1)
    # Ideal DCG: every unique relevant doc packed into ranks 1..min(|relevant|, k).
    ideal_hits = min(len(relevant_set), k)
    idcg = sum(1.0 / math.log2(r + 1) for r in range(1, ideal_hits + 1))
    return dcg / idcg if idcg > 0 else 0.0
|
|
|
|
|
|
def top3_hit(returned: list[int], top3_ids: list[int]) -> bool:
    """Evaluate the "must appear in the top 3" constraint.

    With no ``top3_ids`` there is nothing to check and the result is True.
    Otherwise True iff at least one of ``top3_ids`` is among the first three
    returned ids.
    """
    if top3_ids:
        return not set(returned[:3]).isdisjoint(top3_ids)
    return True
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────
|
|
# API 호출
|
|
# ─────────────────────────────────────────────────────────
|
|
|
|
|
|
async def call_search(
    client: httpx.AsyncClient,
    base_url: str,
    token: str,
    query: str,
    mode: str = "hybrid",
    limit: int = 20,
    fusion: str | None = None,
    timeout: float = 30.0,
) -> tuple[list[int], float]:
    """Call ``GET {base_url}/api/search/`` once and return the ranked ids.

    Args:
        client: shared async HTTP client.
        base_url: server root; a trailing slash is tolerated.
        token: bearer token placed in the ``Authorization`` header.
        query: search text, sent as ``q``.
        mode: ``mode`` query parameter (the CLI offers fts/trgm/vector/hybrid).
        limit: ``limit`` query parameter.
        fusion: optional ``fusion`` query parameter; omitted when falsy so the
            server's default strategy applies.
        timeout: per-request timeout in seconds (previously fixed at 30.0).

    Returns:
        ``(doc_ids, latency_ms)``: ids in server ranking order and the
        wall-clock duration of the HTTP request in milliseconds.

    Raises:
        Whatever ``response.raise_for_status()`` raises for an error status,
        and ``KeyError`` if a result item has no ``"id"``.
    """
    url = f"{base_url.rstrip('/')}/api/search/"
    headers = {"Authorization": f"Bearer {token}"}
    params: dict[str, str | int] = {"q": query, "mode": mode, "limit": limit}
    if fusion:
        params["fusion"] = fusion

    start = time.perf_counter()
    response = await client.get(url, headers=headers, params=params, timeout=timeout)
    # Only the HTTP round-trip is timed; status check and JSON decoding are not.
    latency_ms = (time.perf_counter() - start) * 1000

    response.raise_for_status()
    data = response.json()
    returned_ids = [r["id"] for r in data.get("results", [])]
    return returned_ids, latency_ms
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────
|
|
# 평가 실행
|
|
# ─────────────────────────────────────────────────────────
|
|
|
|
|
|
async def evaluate(
    queries: list[Query],
    base_url: str,
    token: str,
    label: str,
    mode: str = "hybrid",
    fusion: str | None = None,
) -> list[QueryResult]:
    """Run every query against one endpoint and score the results.

    Queries are sent one at a time over a single shared ``httpx.AsyncClient``.
    A failure on any single query (HTTP error, malformed response, ...) is
    recorded on that query's ``QueryResult`` with zeroed metrics and
    ``latency_ms=0.0`` instead of aborting the whole run.

    Args:
        queries: evaluation set.
        base_url: server root URL.
        token: bearer token.
        label: run label stored on every result ("single"/"baseline"/"candidate").
        mode: search mode forwarded to the API.
        fusion: optional fusion strategy forwarded to the API.

    Returns:
        One ``QueryResult`` per query, in input order.
    """
    results: list[QueryResult] = []

    async with httpx.AsyncClient() as client:
        for q in queries:
            try:
                returned_ids, latency_ms = await call_search(
                    client, base_url, token, q.query, mode=mode, fusion=fusion
                )
                results.append(
                    QueryResult(
                        query=q,
                        label=label,
                        returned_ids=returned_ids,
                        latency_ms=latency_ms,
                        recall_at_10=recall_at_k(returned_ids, q.relevant_ids, 10),
                        mrr_at_10=mrr_at_k(returned_ids, q.relevant_ids, 10),
                        ndcg_at_10=ndcg_at_k(returned_ids, q.relevant_ids, 10),
                        top3_hit=top3_hit(returned_ids, q.top3_ids),
                    )
                )
            except Exception as exc:  # per-query boundary: record and keep going
                results.append(
                    QueryResult(
                        query=q,
                        label=label,
                        returned_ids=[],
                        latency_ms=0.0,
                        recall_at_10=0.0,
                        mrr_at_10=0.0,
                        ndcg_at_10=0.0,
                        top3_hit=False,
                        # Prefix the type: several exceptions stringify to "",
                        # and an empty (falsy) error would look like success.
                        error=f"{type(exc).__name__}: {exc}",
                    )
                )
    return results
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────
|
|
# 결과 집계 / 출력
|
|
# ─────────────────────────────────────────────────────────
|
|
|
|
|
|
def percentile(values: list[float], p: float) -> float:
    """Linearly interpolated percentile; ``p`` is a fraction such as 0.95.

    An empty input yields 0.0.
    """
    if not values:
        return 0.0
    ordered = sorted(values)
    last = len(ordered) - 1
    pos = last * p
    lo = int(pos)
    hi = min(lo + 1, last)
    if lo == hi:
        return ordered[lo]
    frac = pos - lo
    return ordered[lo] + (ordered[hi] - ordered[lo]) * frac
|
|
|
|
|
|
def print_summary(label: str, results: list[QueryResult]) -> dict[str, float]:
    """Print overall and per-category metrics for one run and return the aggregates.

    Queries whose ``relevant_ids`` is empty are "failure cases": the correct
    answer is an empty result. They are excluded from the Recall/MRR/NDCG/Top-3
    averages and reported separately as failure-case precision.

    Returns:
        ``{}`` when ``results`` is empty; otherwise a dict with keys ``n``,
        ``recall_at_10``, ``mrr_at_10``, ``ndcg_at_10``, ``top3_hit_rate``,
        ``latency_p50``, ``latency_p95`` and ``failure_precision``.
    """
    n = len(results)
    if n == 0:
        return {}

    # Failure cases (relevant_ids == []) are excluded from the averaged metrics.
    scored = [r for r in results if r.query.relevant_ids]
    failure_cases = [r for r in results if not r.query.relevant_ids]

    avg_recall = statistics.mean([r.recall_at_10 for r in scored]) if scored else 0.0
    avg_mrr = statistics.mean([r.mrr_at_10 for r in scored]) if scored else 0.0
    avg_ndcg = statistics.mean([r.ndcg_at_10 for r in scored]) if scored else 0.0
    top3_rate = sum(1 for r in scored if r.top3_hit) / len(scored) if scored else 0.0

    # Errored queries carry latency 0.0; the > 0 filter keeps them out.
    latencies = [r.latency_ms for r in results if r.latency_ms > 0]
    p50 = percentile(latencies, 0.50)
    p95 = percentile(latencies, 0.95)

    # A failure case is correct only if the request SUCCEEDED and returned
    # nothing. Errored requests also have returned_ids == [] and must not be
    # counted as correct. `is None` (not truthiness) also catches errors whose
    # message is an empty string.
    failure_correct = sum(
        1 for r in failure_cases if r.error is None and not r.returned_ids
    )
    failure_precision = (
        failure_correct / len(failure_cases) if failure_cases else 0.0
    )

    print(f"\n=== {label} (n={n}, scored={len(scored)}) ===")
    print(f" Recall@10 : {avg_recall:.3f}")
    print(f" MRR@10 : {avg_mrr:.3f}")
    print(f" NDCG@10 : {avg_ndcg:.3f}")
    print(f" Top-3 hit : {top3_rate:.3f}")
    print(f" Latency p50: {p50:.0f} ms")
    print(f" Latency p95: {p95:.0f} ms")
    if failure_cases:
        print(
            f" Failure-case precision: {failure_correct}/{len(failure_cases)}"
            f" ({failure_precision:.2f}) — empty result expected"
        )

    # Per-category breakdown (scored queries only).
    by_cat: dict[str, list[QueryResult]] = {}
    for r in scored:
        by_cat.setdefault(r.query.category, []).append(r)
    print(" by category:")
    for cat, items in sorted(by_cat.items()):
        cat_recall = statistics.mean([r.recall_at_10 for r in items])
        cat_ndcg = statistics.mean([r.ndcg_at_10 for r in items])
        print(
            f" {cat:<22} n={len(items):>2} recall={cat_recall:.2f} ndcg={cat_ndcg:.2f}"
        )

    # Error listing; `is not None` so an exception with an empty message is shown.
    errors = [r for r in results if r.error is not None]
    if errors:
        print(f" ERRORS ({len(errors)}):")
        for r in errors:
            print(f" [{r.query.id}] {r.error}")

    return {
        "n": n,
        "recall_at_10": avg_recall,
        "mrr_at_10": avg_mrr,
        "ndcg_at_10": avg_ndcg,
        "top3_hit_rate": top3_rate,
        "latency_p50": p50,
        "latency_p95": p95,
        "failure_precision": failure_precision,
    }
|
|
|
|
|
|
def write_csv(results: list[QueryResult], output_path: Path) -> None:
    """Write one CSV row per query result to ``output_path``.

    Parent directories are created as needed. Id lists are ``;``-joined and
    only the first ten returned ids are kept. Floats are written with fixed
    precision and ``top3_hit`` as "1"/"0".
    """
    header = [
        "label", "id", "category", "intent", "domain_hint", "query",
        "relevant_ids", "returned_ids_top10", "latency_ms",
        "recall_at_10", "mrr_at_10", "ndcg_at_10", "top3_hit", "error",
    ]

    def to_row(res):
        q = res.query
        return [
            res.label,
            q.id,
            q.category,
            q.intent,
            q.domain_hint,
            q.query,
            ";".join(str(i) for i in q.relevant_ids),
            ";".join(str(i) for i in res.returned_ids[:10]),
            f"{res.latency_ms:.1f}",
            f"{res.recall_at_10:.3f}",
            f"{res.mrr_at_10:.3f}",
            f"{res.ndcg_at_10:.3f}",
            "1" if res.top3_hit else "0",
            res.error or "",
        ]

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", newline="", encoding="utf-8") as fh:
        sink = csv.writer(fh)
        sink.writerow(header)
        sink.writerows(to_row(res) for res in results)
    print(f"\nCSV written: {output_path}")
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────
|
|
# 로딩
|
|
# ─────────────────────────────────────────────────────────
|
|
|
|
|
|
def load_queries(yaml_path: Path) -> list[Query]:
    """Load the evaluation set from a YAML file with a top-level ``queries`` list.

    Each entry must provide ``id``, ``query``, ``category``, ``intent`` and
    ``domain_hint``. ``relevant_ids``, ``top3_ids`` and ``notes`` are optional;
    a missing or null value becomes empty.

    Document ids are coerced to ``int`` to match the ``list[int]`` annotations
    on ``Query`` and on the ids returned by ``call_search``. Without this, a
    quoted id such as ``"12"`` would never compare equal to the API's id and
    would silently score zero. (Assumes the API returns integer ids; confirm.)

    Raises:
        ValueError: the file is empty / has no ``queries`` list, or an id is
            not an integer.
        KeyError: a required field is missing from an entry.
    """
    with yaml_path.open(encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # safe_load yields None for an empty file; report that clearly instead of
    # failing later with "'NoneType' object is not subscriptable".
    entries = data.get("queries") if isinstance(data, dict) else None
    if not isinstance(entries, list):
        raise ValueError(f"{yaml_path}: expected a top-level 'queries' list")

    def _ids(value: Any) -> list[int]:
        # None / missing -> [], otherwise normalise every id to int.
        return [int(v) for v in (value or [])]

    return [
        Query(
            id=q["id"],
            query=q["query"],
            category=q["category"],
            intent=q["intent"],
            domain_hint=q["domain_hint"],
            relevant_ids=_ids(q.get("relevant_ids")),
            top3_ids=_ids(q.get("top3_ids")),
            notes=q.get("notes", "") or "",
        )
        for q in entries
    ]
|
|
|
|
|
|
# ─────────────────────────────────────────────────────────
|
|
# CLI
|
|
# ─────────────────────────────────────────────────────────
|
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    fusion_choices = ["legacy", "rrf", "rrf_boost"]
    parser = argparse.ArgumentParser(description="Document Server 검색 평가")
    parser.add_argument(
        "--queries",
        type=Path,
        default=Path(__file__).parent / "queries.yaml",
        help="평가셋 YAML 경로",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="단일 평가용 URL (예: https://docs.hyungi.net)",
    )
    parser.add_argument(
        "--baseline-url",
        type=str,
        default=None,
        help="A/B 비교용 baseline URL",
    )
    parser.add_argument(
        "--candidate-url",
        type=str,
        default=None,
        help="A/B 비교용 candidate URL",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="hybrid",
        choices=["fts", "trgm", "vector", "hybrid"],
        help="검색 mode 파라미터",
    )
    parser.add_argument(
        "--fusion",
        type=str,
        default=None,
        choices=fusion_choices,
        help="hybrid 모드 fusion 전략 (Phase 0.5+, 미지정 시 서버 기본값)",
    )
    # Per-side overrides: allow comparing two fusion strategies, even on the
    # same server URL. Each falls back to --fusion when omitted.
    parser.add_argument(
        "--baseline-fusion",
        type=str,
        default=None,
        choices=fusion_choices,
        help="A/B 모드 baseline 쪽 fusion 전략 (미지정 시 --fusion 값)",
    )
    parser.add_argument(
        "--candidate-fusion",
        type=str,
        default=None,
        choices=fusion_choices,
        help="A/B 모드 candidate 쪽 fusion 전략 (미지정 시 --fusion 값)",
    )
    parser.add_argument(
        "--token",
        type=str,
        default=os.environ.get("DOCSRV_TOKEN"),
        help="Bearer 토큰 (env DOCSRV_TOKEN)",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="CSV 출력 경로 (지정하면 raw 결과 저장)",
    )
    return parser


def _print_delta(baseline: dict[str, float], candidate: dict[str, float]) -> None:
    """Print candidate-minus-baseline deltas for the headline metrics."""
    print("\n=== Δ (candidate - baseline) ===")
    for k in (
        "recall_at_10",
        "mrr_at_10",
        "ndcg_at_10",
        "top3_hit_rate",
        "latency_p50",
        "latency_p95",
    ):
        delta = candidate[k] - baseline[k]
        sign = "+" if delta >= 0 else ""
        print(f" {k:<16}: {sign}{delta:.3f}")


def main() -> int:
    """CLI entry point.

    Returns:
        Process exit code: 0 on success, 2 on a usage/input error (missing
        token, missing URLs, or an empty query set).
    """
    args = _build_parser().parse_args()

    if not args.token:
        print("ERROR: --token 또는 env DOCSRV_TOKEN 필요", file=sys.stderr)
        return 2

    if not args.base_url and not (args.baseline_url and args.candidate_url):
        print(
            "ERROR: --base-url 또는 (--baseline-url + --candidate-url) 둘 중 하나 필요",
            file=sys.stderr,
        )
        return 2

    queries = load_queries(args.queries)
    print(f"Loaded {len(queries)} queries from {args.queries}")
    if not queries:
        # print_summary returns {} for an empty run, which would break the
        # delta table below; bail out with a clear message instead.
        print(f"ERROR: {args.queries} 에 평가 쿼리가 없음", file=sys.stderr)
        return 2

    print(f"Mode: {args.mode}", end="")
    if args.fusion:
        print(f" / fusion: {args.fusion}", end="")
    print()

    all_results: list[QueryResult] = []

    if args.base_url:
        print(f"\n>>> evaluating: {args.base_url}")
        results = asyncio.run(
            evaluate(queries, args.base_url, args.token, "single", mode=args.mode, fusion=args.fusion)
        )
        print_summary("single", results)
        all_results.extend(results)
    else:
        baseline_fusion = args.baseline_fusion or args.fusion
        candidate_fusion = args.candidate_fusion or args.fusion

        # Only annotate the header when a side-specific override was given.
        baseline_note = f" (fusion: {baseline_fusion})" if args.baseline_fusion else ""
        print(f"\n>>> baseline: {args.baseline_url}{baseline_note}")
        baseline_results = asyncio.run(
            evaluate(
                queries, args.baseline_url, args.token, "baseline",
                mode=args.mode, fusion=baseline_fusion,
            )
        )
        baseline_summary = print_summary("baseline", baseline_results)

        candidate_note = f" (fusion: {candidate_fusion})" if args.candidate_fusion else ""
        print(f"\n>>> candidate: {args.candidate_url}{candidate_note}")
        candidate_results = asyncio.run(
            evaluate(
                queries, args.candidate_url, args.token, "candidate",
                mode=args.mode, fusion=candidate_fusion,
            )
        )
        candidate_summary = print_summary("candidate", candidate_results)

        _print_delta(baseline_summary, candidate_summary)

        all_results.extend(baseline_results)
        all_results.extend(candidate_results)

    if args.output:
        write_csv(all_results, args.output)

    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
|