test(search): Phase 0.2 평가셋 + 평가 스크립트
22개 쿼리(6개 카테고리)와 Recall/MRR/NDCG@10 + latency p50/p95 측정 스크립트 추가. wiggly-weaving-puppy 플랜 Phase 0.2 산출물. - queries.yaml: 정확키워드/한국어자연어/crosslingual/뉴스/실패 케이스 실제 코퍼스(2026-04-07, 753 docs) 기반 정답 doc_id 매핑 - run_eval.py: 단일 평가 + A/B 비교 모드, CSV 저장 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
473
tests/search_eval/run_eval.py
Normal file
473
tests/search_eval/run_eval.py
Normal file
@@ -0,0 +1,473 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Document Server 검색 평가 스크립트 (Phase 0.2)
|
||||
|
||||
queries.yaml을 읽어 /api/search 엔드포인트에 호출하고
|
||||
Recall@10, MRR@10, NDCG@10, Top3 hit-rate, Latency p50/p95를 계산한다.
|
||||
|
||||
A/B 비교 모드: --baseline-url, --candidate-url 를 각각 지정하면
|
||||
두 엔드포인트에 동일 쿼리셋을 던지고 결과를 비교한다.
|
||||
|
||||
사용 예:
|
||||
|
||||
# 단일 평가
|
||||
export DOCSRV_TOKEN="eyJ..."
|
||||
python tests/search_eval/run_eval.py \
|
||||
--base-url https://docs.hyungi.net \
|
||||
--output reports/baseline_2026-04-07.csv
|
||||
|
||||
# A/B 비교 (같은 토큰)
|
||||
python tests/search_eval/run_eval.py \
|
||||
--baseline-url https://docs.hyungi.net \
|
||||
--candidate-url http://localhost:8000 \
|
||||
--output reports/phase1_vs_baseline.csv
|
||||
|
||||
토큰은 env DOCSRV_TOKEN 또는 --token 플래그로 전달.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import csv
|
||||
import math
|
||||
import os
|
||||
import statistics
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────
|
||||
# 데이터 구조
|
||||
# ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class Query:
|
||||
id: str
|
||||
query: str
|
||||
category: str
|
||||
intent: str
|
||||
domain_hint: str
|
||||
relevant_ids: list[int]
|
||||
top3_ids: list[int] = field(default_factory=list)
|
||||
notes: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryResult:
|
||||
query: Query
|
||||
label: str # "baseline" or "candidate"
|
||||
returned_ids: list[int]
|
||||
latency_ms: float
|
||||
recall_at_10: float
|
||||
mrr_at_10: float
|
||||
ndcg_at_10: float
|
||||
top3_hit: bool
|
||||
error: str | None = None
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────
|
||||
# 평가 지표
|
||||
# ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def recall_at_k(returned: list[int], relevant: list[int], k: int = 10) -> float:
|
||||
"""top-k 안에 들어간 정답 비율. 정답 0개면 1.0(빈 케이스는 별도 fail metric)."""
|
||||
if not relevant:
|
||||
return 1.0 if not returned else 0.0 # 비어야 정상인 케이스: 결과 있으면 fail
|
||||
top_k = set(returned[:k])
|
||||
hits = sum(1 for doc_id in relevant if doc_id in top_k)
|
||||
return hits / len(relevant)
|
||||
|
||||
|
||||
def mrr_at_k(returned: list[int], relevant: list[int], k: int = 10) -> float:
|
||||
"""top-k 안 첫 정답의 reciprocal rank. 정답 없으면 0."""
|
||||
if not relevant:
|
||||
return 0.0
|
||||
relevant_set = set(relevant)
|
||||
for rank, doc_id in enumerate(returned[:k], start=1):
|
||||
if doc_id in relevant_set:
|
||||
return 1.0 / rank
|
||||
return 0.0
|
||||
|
||||
|
||||
def ndcg_at_k(returned: list[int], relevant: list[int], k: int = 10) -> float:
|
||||
"""binary relevance 기반 NDCG@k. top3_ids 같은 가중치는 v0.1에선 무시."""
|
||||
if not relevant:
|
||||
return 0.0
|
||||
relevant_set = set(relevant)
|
||||
dcg = 0.0
|
||||
for rank, doc_id in enumerate(returned[:k], start=1):
|
||||
if doc_id in relevant_set:
|
||||
# binary gain = 1, DCG = 1 / log2(rank+1)
|
||||
dcg += 1.0 / math.log2(rank + 1)
|
||||
# ideal DCG: 정답을 1..min(len(relevant), k) 위치에 모두 채운 경우
|
||||
ideal_hits = min(len(relevant), k)
|
||||
idcg = sum(1.0 / math.log2(r + 1) for r in range(1, ideal_hits + 1))
|
||||
return dcg / idcg if idcg > 0 else 0.0
|
||||
|
||||
|
||||
def top3_hit(returned: list[int], top3_ids: list[int]) -> bool:
|
||||
"""top3_ids가 비어있으면 True (체크 안함). 있으면 그 중 하나라도 top-3에 들어와야 함."""
|
||||
if not top3_ids:
|
||||
return True
|
||||
top3 = set(returned[:3])
|
||||
return any(doc_id in top3 for doc_id in top3_ids)
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────
|
||||
# API 호출
|
||||
# ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def call_search(
|
||||
client: httpx.AsyncClient,
|
||||
base_url: str,
|
||||
token: str,
|
||||
query: str,
|
||||
mode: str = "hybrid",
|
||||
limit: int = 20,
|
||||
) -> tuple[list[int], float]:
|
||||
"""검색 API 호출 → (doc_ids, latency_ms)."""
|
||||
url = f"{base_url.rstrip('/')}/api/search/"
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
params = {"q": query, "mode": mode, "limit": limit}
|
||||
|
||||
import time
|
||||
|
||||
start = time.perf_counter()
|
||||
response = await client.get(url, headers=headers, params=params, timeout=30.0)
|
||||
latency_ms = (time.perf_counter() - start) * 1000
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
returned_ids = [r["id"] for r in data.get("results", [])]
|
||||
return returned_ids, latency_ms
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────
|
||||
# 평가 실행
|
||||
# ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def evaluate(
|
||||
queries: list[Query],
|
||||
base_url: str,
|
||||
token: str,
|
||||
label: str,
|
||||
mode: str = "hybrid",
|
||||
) -> list[QueryResult]:
|
||||
"""전체 쿼리셋 평가."""
|
||||
results: list[QueryResult] = []
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
for q in queries:
|
||||
try:
|
||||
returned_ids, latency_ms = await call_search(
|
||||
client, base_url, token, q.query, mode=mode
|
||||
)
|
||||
results.append(
|
||||
QueryResult(
|
||||
query=q,
|
||||
label=label,
|
||||
returned_ids=returned_ids,
|
||||
latency_ms=latency_ms,
|
||||
recall_at_10=recall_at_k(returned_ids, q.relevant_ids, 10),
|
||||
mrr_at_10=mrr_at_k(returned_ids, q.relevant_ids, 10),
|
||||
ndcg_at_10=ndcg_at_k(returned_ids, q.relevant_ids, 10),
|
||||
top3_hit=top3_hit(returned_ids, q.top3_ids),
|
||||
)
|
||||
)
|
||||
except Exception as exc:
|
||||
results.append(
|
||||
QueryResult(
|
||||
query=q,
|
||||
label=label,
|
||||
returned_ids=[],
|
||||
latency_ms=0.0,
|
||||
recall_at_10=0.0,
|
||||
mrr_at_10=0.0,
|
||||
ndcg_at_10=0.0,
|
||||
top3_hit=False,
|
||||
error=str(exc),
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────
|
||||
# 결과 집계 / 출력
|
||||
# ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def percentile(values: list[float], p: float) -> float:
|
||||
if not values:
|
||||
return 0.0
|
||||
s = sorted(values)
|
||||
k = (len(s) - 1) * p
|
||||
f = int(k)
|
||||
c = min(f + 1, len(s) - 1)
|
||||
if f == c:
|
||||
return s[f]
|
||||
return s[f] + (s[c] - s[f]) * (k - f)
|
||||
|
||||
|
||||
def print_summary(label: str, results: list[QueryResult]) -> dict[str, float]:
|
||||
"""전체 + 카테고리별 요약 출력. 집계 dict 반환."""
|
||||
n = len(results)
|
||||
if n == 0:
|
||||
return {}
|
||||
|
||||
# 실패 케이스(relevant_ids=[])는 평균 recall/mrr/ndcg에서 제외
|
||||
scored = [r for r in results if r.query.relevant_ids]
|
||||
failure_cases = [r for r in results if not r.query.relevant_ids]
|
||||
|
||||
avg_recall = statistics.mean([r.recall_at_10 for r in scored]) if scored else 0.0
|
||||
avg_mrr = statistics.mean([r.mrr_at_10 for r in scored]) if scored else 0.0
|
||||
avg_ndcg = statistics.mean([r.ndcg_at_10 for r in scored]) if scored else 0.0
|
||||
top3_rate = sum(1 for r in scored if r.top3_hit) / len(scored) if scored else 0.0
|
||||
|
||||
latencies = [r.latency_ms for r in results if r.latency_ms > 0]
|
||||
p50 = percentile(latencies, 0.50)
|
||||
p95 = percentile(latencies, 0.95)
|
||||
|
||||
# 실패 케이스: 결과 0건이어야 정상
|
||||
failure_correct = sum(1 for r in failure_cases if not r.returned_ids)
|
||||
failure_precision = (
|
||||
failure_correct / len(failure_cases) if failure_cases else 0.0
|
||||
)
|
||||
|
||||
print(f"\n=== {label} (n={n}, scored={len(scored)}) ===")
|
||||
print(f" Recall@10 : {avg_recall:.3f}")
|
||||
print(f" MRR@10 : {avg_mrr:.3f}")
|
||||
print(f" NDCG@10 : {avg_ndcg:.3f}")
|
||||
print(f" Top-3 hit : {top3_rate:.3f}")
|
||||
print(f" Latency p50: {p50:.0f} ms")
|
||||
print(f" Latency p95: {p95:.0f} ms")
|
||||
if failure_cases:
|
||||
print(
|
||||
f" Failure-case precision: {failure_correct}/{len(failure_cases)}"
|
||||
f" ({failure_precision:.2f}) — empty result expected"
|
||||
)
|
||||
|
||||
# 카테고리별
|
||||
by_cat: dict[str, list[QueryResult]] = {}
|
||||
for r in scored:
|
||||
by_cat.setdefault(r.query.category, []).append(r)
|
||||
print(" by category:")
|
||||
for cat, items in sorted(by_cat.items()):
|
||||
cat_recall = statistics.mean([r.recall_at_10 for r in items])
|
||||
cat_ndcg = statistics.mean([r.ndcg_at_10 for r in items])
|
||||
print(
|
||||
f" {cat:<22} n={len(items):>2} recall={cat_recall:.2f} ndcg={cat_ndcg:.2f}"
|
||||
)
|
||||
|
||||
# 에러 케이스
|
||||
errors = [r for r in results if r.error]
|
||||
if errors:
|
||||
print(f" ERRORS ({len(errors)}):")
|
||||
for r in errors:
|
||||
print(f" [{r.query.id}] {r.error}")
|
||||
|
||||
return {
|
||||
"n": n,
|
||||
"recall_at_10": avg_recall,
|
||||
"mrr_at_10": avg_mrr,
|
||||
"ndcg_at_10": avg_ndcg,
|
||||
"top3_hit_rate": top3_rate,
|
||||
"latency_p50": p50,
|
||||
"latency_p95": p95,
|
||||
"failure_precision": failure_precision,
|
||||
}
|
||||
|
||||
|
||||
def write_csv(results: list[QueryResult], output_path: Path) -> None:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with output_path.open("w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(
|
||||
[
|
||||
"label",
|
||||
"id",
|
||||
"category",
|
||||
"intent",
|
||||
"domain_hint",
|
||||
"query",
|
||||
"relevant_ids",
|
||||
"returned_ids_top10",
|
||||
"latency_ms",
|
||||
"recall_at_10",
|
||||
"mrr_at_10",
|
||||
"ndcg_at_10",
|
||||
"top3_hit",
|
||||
"error",
|
||||
]
|
||||
)
|
||||
for r in results:
|
||||
writer.writerow(
|
||||
[
|
||||
r.label,
|
||||
r.query.id,
|
||||
r.query.category,
|
||||
r.query.intent,
|
||||
r.query.domain_hint,
|
||||
r.query.query,
|
||||
";".join(map(str, r.query.relevant_ids)),
|
||||
";".join(map(str, r.returned_ids[:10])),
|
||||
f"{r.latency_ms:.1f}",
|
||||
f"{r.recall_at_10:.3f}",
|
||||
f"{r.mrr_at_10:.3f}",
|
||||
f"{r.ndcg_at_10:.3f}",
|
||||
"1" if r.top3_hit else "0",
|
||||
r.error or "",
|
||||
]
|
||||
)
|
||||
print(f"\nCSV written: {output_path}")
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────
|
||||
# 로딩
|
||||
# ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def load_queries(yaml_path: Path) -> list[Query]:
|
||||
with yaml_path.open(encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
queries: list[Query] = []
|
||||
for q in data["queries"]:
|
||||
queries.append(
|
||||
Query(
|
||||
id=q["id"],
|
||||
query=q["query"],
|
||||
category=q["category"],
|
||||
intent=q["intent"],
|
||||
domain_hint=q["domain_hint"],
|
||||
relevant_ids=q.get("relevant_ids", []) or [],
|
||||
top3_ids=q.get("top3_ids", []) or [],
|
||||
notes=q.get("notes", "") or "",
|
||||
)
|
||||
)
|
||||
return queries
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────
|
||||
# CLI
|
||||
# ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="Document Server 검색 평가")
|
||||
parser.add_argument(
|
||||
"--queries",
|
||||
type=Path,
|
||||
default=Path(__file__).parent / "queries.yaml",
|
||||
help="평가셋 YAML 경로",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
type=str,
|
||||
default=None,
|
||||
help="단일 평가용 URL (예: https://docs.hyungi.net)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--baseline-url",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A/B 비교용 baseline URL",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--candidate-url",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A/B 비교용 candidate URL",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
type=str,
|
||||
default="hybrid",
|
||||
choices=["fts", "trgm", "vector", "hybrid"],
|
||||
help="검색 mode 파라미터",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token",
|
||||
type=str,
|
||||
default=os.environ.get("DOCSRV_TOKEN"),
|
||||
help="Bearer 토큰 (env DOCSRV_TOKEN)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="CSV 출력 경로 (지정하면 raw 결과 저장)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.token:
|
||||
print("ERROR: --token 또는 env DOCSRV_TOKEN 필요", file=sys.stderr)
|
||||
return 2
|
||||
|
||||
if not args.base_url and not (args.baseline_url and args.candidate_url):
|
||||
print(
|
||||
"ERROR: --base-url 또는 (--baseline-url + --candidate-url) 둘 중 하나 필요",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 2
|
||||
|
||||
queries = load_queries(args.queries)
|
||||
print(f"Loaded {len(queries)} queries from {args.queries}")
|
||||
print(f"Mode: {args.mode}")
|
||||
|
||||
all_results: list[QueryResult] = []
|
||||
|
||||
if args.base_url:
|
||||
print(f"\n>>> evaluating: {args.base_url}")
|
||||
results = asyncio.run(
|
||||
evaluate(queries, args.base_url, args.token, "single", mode=args.mode)
|
||||
)
|
||||
print_summary("single", results)
|
||||
all_results.extend(results)
|
||||
else:
|
||||
print(f"\n>>> baseline: {args.baseline_url}")
|
||||
baseline_results = asyncio.run(
|
||||
evaluate(queries, args.baseline_url, args.token, "baseline", mode=args.mode)
|
||||
)
|
||||
baseline_summary = print_summary("baseline", baseline_results)
|
||||
|
||||
print(f"\n>>> candidate: {args.candidate_url}")
|
||||
candidate_results = asyncio.run(
|
||||
evaluate(
|
||||
queries, args.candidate_url, args.token, "candidate", mode=args.mode
|
||||
)
|
||||
)
|
||||
candidate_summary = print_summary("candidate", candidate_results)
|
||||
|
||||
# 델타
|
||||
print("\n=== Δ (candidate - baseline) ===")
|
||||
for k in (
|
||||
"recall_at_10",
|
||||
"mrr_at_10",
|
||||
"ndcg_at_10",
|
||||
"top3_hit_rate",
|
||||
"latency_p50",
|
||||
"latency_p95",
|
||||
):
|
||||
delta = candidate_summary[k] - baseline_summary[k]
|
||||
sign = "+" if delta >= 0 else ""
|
||||
print(f" {k:<16}: {sign}{delta:.3f}")
|
||||
|
||||
all_results.extend(baseline_results)
|
||||
all_results.extend(candidate_results)
|
||||
|
||||
if args.output:
|
||||
write_csv(all_results, args.output)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user