Files
hyungi_document_server/tests/test_query_rewriter.py
T
hyungi ecd2350c15 feat(search): Phase 2Q Diagnose Phase 2 — multi-query retrieval fusion
phase-2q-query-rewrite-diagnose.md v6 plan §5.5 + §7 Phase 2.
Phase 1B 3e6866b (scaffold + dispatcher) 위 retrieval 합성 wire-up.

신규:
- search_pipeline._rrf_fuse_variants() — N variant ranked list RRF 합성.
  fusion_service.RRFOnly 알고리즘 동일 (k=60), 첫 등장 variant representative 보존.
- search_pipeline.search_with_rewrite() — variant N 별 retrieval+fusion 후
  unified RRF (cap 60) → reranker 1회 (query=원본 q) → diversity+freshness+display.
  · per-variant K = 50//3 = 16 (PHASE2Q_PRODUCTION_TOPK//N, A1 채택)
  · variant 별 retrieval asyncio.gather 병렬
  · chunks_by_doc merge (variant 무관 unified reranker input)
  · production fusion_service.get_strategy() + rerank_chunks() 재사용
- 상수: PHASE2Q_PRODUCTION_TOPK=50, PHASE2Q_UNIFIED_CAP=60, PHASE2Q_RRF_K=60.

수정:
- search_pipeline.run_search() — rewrite_backend param 추가. hybrid + cand_<slug> 시
  search_with_rewrite() 위임. baseline/None 시 기존 single-query path 그대로 (invariant).
- app/api/search.py — Phase 1B scaffold discard call 제거. run_search 에 rewrite_backend
  전달. ValueError → 400 (unknown_rewrite_backend 우선 분기) / RuntimeError → 503
  (rewrite_llm_unavailable).
- tests/test_query_rewriter.py — Phase 2 test 9개 추가:
  · _rrf_fuse_variants 6 (single / overlap accumulation / representative / cap limit /
    empty / rank position)
  · search_pipeline import 1 + run_search rewrite_backend default=None signature 1
  · PHASE2Q_* constants 1
  · DATABASE_URL dummy 주입 (api.search import → SQLAlchemy engine init 회피)

30/30 unit test PASS (Phase 1B 21 + Phase 2 9).

baseline 회귀 0 invariant:
- run_search(rewrite_backend=None) → 기존 path 100% 그대로 (분기 first line guard)
- run_search(rewrite_backend=baseline) → 동일
- mode != hybrid → multi-query path 비활성 (text-only/vector-only/trgm 영향 0)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-23 22:41:50 +00:00

326 lines
12 KiB
Python

"""Phase 2Q Diagnose — unit tests for the query_rewriter scaffold/dispatcher (Phase 1B)
and the multi-query retrieval fusion wired into search_pipeline (Phase 2).

Guardrails covered (plan v6 §5 + §7):
1. `_resolve_rewrite_backend` — slug resolution, unknown slug -> ValueError, baseline -> None
2. `_cache_key` — deterministic + NFKC normalization + separated by backend slug
3. `_extract_variants` — valid shape / wrong count / type mismatch / empty / non-list
4. cache set/get roundtrip, TTL expiry (manually inserted stale entry), returned-copy isolation
5. `allowed_slugs` — 1:1 with LLM_BACKEND_MAP keys
6. `search_pipeline._rrf_fuse_variants` — RRF fusion of per-variant ranked lists
7. `search_pipeline` wiring — query_rewriter import, run_search(rewrite_backend=None)
   signature, PHASE2Q_* constants
"""
from __future__ import annotations

import asyncio
import logging
import os
import sys
import time

import pytest

# logs/llm_gate.log is owned by root (written by the production FastAPI daemon), so
# importing the app as the unprivileged test user raises PermissionError when a
# FileHandler tries to open it. Wrap FileHandler so that case degrades to a NullHandler.
# NOTE(review): this patch is module-global and is never restored, so it stays in effect
# for every test that runs after this module is imported, not only for this file. It
# also replaces the FileHandler *class* with a plain function, so
# `isinstance(h, logging.FileHandler)` or subclassing logging.FileHandler raises
# TypeError while the patch is active. Confirm that is acceptable for the whole suite.
_orig_file_handler = logging.FileHandler


def _safe_file_handler(filename, *args, **kwargs):  # type: ignore
    """Return a real FileHandler, or a NullHandler if opening `filename` is not permitted.

    Arguments are forwarded unchanged to the original logging.FileHandler.
    Only PermissionError is absorbed; any other error still propagates.
    """
    try:
        return _orig_file_handler(filename, *args, **kwargs)
    except PermissionError:
        # Discard records for unwritable log files instead of failing at import time.
        return logging.NullHandler()


logging.FileHandler = _safe_file_handler  # type: ignore[assignment]

# The Phase 2 tests import search_pipeline, which pulls in api.search and triggers
# SQLAlchemy engine initialization. Without DATABASE_URL that raises ArgumentError and
# test collection fails, so inject a dummy URL (no real connection is opened). Because
# this uses setdefault, a DATABASE_URL already set in the environment takes precedence.
os.environ.setdefault("DATABASE_URL", "postgresql+asyncpg://test:test@localhost:5432/test")

# tests/ -> project root -> app/ : makes the `services.*` and `api.*` packages importable.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "app"))
from services.search import query_rewriter
from services.search.query_rewriter import (
EXPECTED_N_VARIANTS,
LLM_BACKEND_MAP,
PROMPT_VERSION,
_cache_key,
_extract_variants,
_resolve_rewrite_backend,
allowed_slugs,
)
# ─── 1. _resolve_rewrite_backend ──────────────────────────
def test_resolve_baseline_returns_none():
    """Both an absent backend (None) and the explicit "baseline" slug mean "no rewrite"."""
    for slug in (None, "baseline"):
        assert _resolve_rewrite_backend(slug) is None
def test_resolve_known_slugs():
    """Each candidate slug resolves to its endpoint / model / sampling configuration."""
    macmini = _resolve_rewrite_backend("cand_multi_query_macmini")
    assert macmini is not None
    assert all(field in macmini for field in ("endpoint", "model", "sampling"))
    assert macmini["model"] == "gemma-4-26b-a4b-it-8bit"

    macbook = _resolve_rewrite_backend("cand_multi_query_macbook")
    assert macbook is not None
    assert macbook["model"] == "mlx-community/Qwen3.6-27B-8bit"
    # The Qwen sampling config carries no response_format (recorded in Phase 0 inspect #9).
    assert "response_format" not in macbook["sampling"]
def test_resolve_unknown_slug_raises():
    """Unregistered slugs are rejected with ValueError, even ones sharing a real prefix."""
    cases = (
        ("cand_bogus", "unknown_rewrite_backend"),
        # Looks like a real candidate (same prefix) but is not registered.
        ("cand_multi_query_other", None),
    )
    for slug, pattern in cases:
        with pytest.raises(ValueError, match=pattern):
            _resolve_rewrite_backend(slug)
def test_allowed_slugs_matches_map():
    """allowed_slugs() mirrors the LLM_BACKEND_MAP keys, in order, including baseline."""
    assert allowed_slugs() == list(LLM_BACKEND_MAP.keys())
    for required in ("baseline", "cand_multi_query_macmini", "cand_multi_query_macbook"):
        assert required in allowed_slugs()
# ─── 2. _cache_key ────────────────────────────────────────
def test_cache_key_deterministic():
    """The same (query, backend) pair always yields the same 32-character key."""
    query, slug = "산업안전보건법 제6장", "cand_multi_query_macmini"
    first = _cache_key(query, slug)
    second = _cache_key(query, slug)
    assert first == second
    assert len(first) == 32  # sha256[:32]
def test_cache_key_nfkc_normalize_and_strip_lower():
    """Surface variants of one query must share a single cache key.

    Covers surrounding whitespace, letter case, and Unicode compatibility forms.
    The previous version never exercised NFKC despite the test name. Full-width
    Latin letters (U+FF21..) fold to ASCII under NFKC, which exercises it.
    """
    slug = "cand_multi_query_macmini"
    base = _cache_key("ASME Section VIII", slug)
    # Whitespace padding and upper/lower case differences -> same key.
    assert _cache_key(" asme section viii ", slug) == base
    assert _cache_key("ASME SECTION VIII", slug) == base
    # Full-width "ＡＳＭＥ" NFKC-normalizes to "ASME" -> same key.
    assert _cache_key("ＡＳＭＥ Section VIII", slug) == base
def test_cache_key_differs_by_backend_slug():
    """One query cached under two different backends must not collide."""
    macmini_key, macbook_key = (
        _cache_key("query", slug)
        for slug in ("cand_multi_query_macmini", "cand_multi_query_macbook")
    )
    assert macmini_key != macbook_key
def test_cache_key_includes_prompt_version(monkeypatch):
    """PROMPT_VERSION is part of the cache key, so bumping it invalidates old entries.

    The previous version only checked the constant and the key length. It never
    verified that the version feeds into the key. This version simulates a prompt
    bump and asserts that the key changes.
    """
    assert PROMPT_VERSION == "v1"
    key_v1 = _cache_key("query", "cand_multi_query_macmini")
    assert len(key_v1) == 32

    # Assumes _cache_key reads the query_rewriter module global at call time.
    # monkeypatch restores the original value after the test.
    monkeypatch.setattr(query_rewriter, "PROMPT_VERSION", "v2")
    key_v2 = _cache_key("query", "cand_multi_query_macmini")
    assert len(key_v2) == 32
    assert key_v2 != key_v1
# ─── 3. _extract_variants ─────────────────────────────────
def test_extract_variants_valid_shape():
    """A well-formed JSON payload yields the variants verbatim and in order."""
    payload = '{"variants": ["원본", "한국어 변형", "english"]}'
    expected = ["원본", "한국어 변형", "english"]
    assert _extract_variants(payload, expected_n=3) == expected
def test_extract_variants_strips_whitespace():
    """Leading/trailing whitespace, including a JSON-escaped newline, is trimmed per variant."""
    payload = '{"variants": [" 원본 ", "한국어\\n", " english "]}'
    trimmed = _extract_variants(payload, expected_n=3)
    assert trimmed == ["원본", "한국어", "english"]
def test_extract_variants_wrong_count_returns_none():
    """Both too few and too many variants are rejected."""
    too_few = '{"variants": ["only_one"]}'
    too_many = '{"variants": ["a", "b", "c", "d"]}'
    for payload in (too_few, too_many):
        assert _extract_variants(payload, expected_n=3) is None
def test_extract_variants_missing_key_returns_none():
    """A list stored under any key other than "variants" is not accepted."""
    assert _extract_variants('{"queries": ["a", "b", "c"]}', expected_n=3) is None
def test_extract_variants_non_list_returns_none():
    """A scalar string under "variants" is rejected; a list is required."""
    assert _extract_variants('{"variants": "single string"}', expected_n=3) is None
def test_extract_variants_empty_string_returns_none():
    """One empty variant invalidates the whole payload."""
    assert _extract_variants('{"variants": ["a", "", "c"]}', expected_n=3) is None
def test_extract_variants_non_string_element_returns_none():
    """A non-string element (here an int) invalidates the whole payload."""
    assert _extract_variants('{"variants": ["a", 123, "c"]}', expected_n=3) is None
def test_extract_variants_invalid_json_returns_none():
    """Unparseable text yields None rather than raising."""
    assert _extract_variants("not json at all", expected_n=3) is None
def test_extract_variants_markdown_fence_fallback():
    """A ```json fenced payload is unwrapped (the production parse_json_response is reused)."""
    fenced = '```json\n{"variants": ["a", "b", "c"]}\n```'
    assert _extract_variants(fenced, expected_n=3) == ["a", "b", "c"]
# ─── 4. cache set / get ───────────────────────────────────
@pytest.mark.asyncio
async def test_cache_set_get_roundtrip():
    """A miss returns None; a value stored with _set_cached is then returned by _get_cached."""
    # Isolation: start from an empty module-global cache.
    query_rewriter._CACHE.clear()
    key = _cache_key("__test_unique_key__", "cand_multi_query_macmini")

    miss = await query_rewriter._get_cached(key)
    assert miss is None

    await query_rewriter._set_cached(key, ["v0", "v1", "v2"])
    hit = await query_rewriter._get_cached(key)
    assert hit == ["v0", "v1", "v2"]
@pytest.mark.asyncio
async def test_cache_ttl_expiry():
    """An entry whose expiry is already past reads as a miss and is removed on that read."""
    query_rewriter._CACHE.clear()
    key = "ttl_test_key"
    already_expired = time.time() - 1.0
    # Insert an (expire_at, variants) entry directly, bypassing _set_cached.
    query_rewriter._CACHE[key] = (already_expired, ["a", "b", "c"])

    assert await query_rewriter._get_cached(key) is None
    # Lazy deletion: the expired read must also have evicted the entry.
    assert key not in query_rewriter._CACHE
@pytest.mark.asyncio
async def test_cache_returns_copy_not_reference():
    """Mutating a list returned by _get_cached must not leak back into the internal cache."""
    query_rewriter._CACHE.clear()
    key = "copy_test_key"
    await query_rewriter._set_cached(key, ["a", "b", "c"])

    returned = await query_rewriter._get_cached(key)
    returned.append("mutated")

    assert await query_rewriter._get_cached(key) == ["a", "b", "c"]
# ─── 5. constants ─────────────────────────────────────────
def test_constants_match_plan_v6():
    """query_rewriter tunables match the values pinned by plan v6."""
    assert PROMPT_VERSION == "v1"
    assert EXPECTED_N_VARIANTS == 3
    pinned = {
        "LLM_REWRITE_TIMEOUT_MS": 15000,  # 15 s per rewrite call
        "CACHE_TTL": 86400,  # 24 h, presumably in seconds
        "CACHE_MAXSIZE": 1000,
    }
    for attr, expected in pinned.items():
        assert getattr(query_rewriter, attr) == expected
# ─── 6. Phase 2 — _rrf_fuse_variants 합성 알고리즘 ────────
def _mk_search_result(doc_id: int, score: float = 1.0, match_reason: str = "test"):
    """Build a minimal api.search.SearchResult (a BaseModel) for the fusion tests.

    Only id, score and match_reason vary. file_format must be a str, so "pdf" is used.
    """
    from api.search import SearchResult

    fields = dict(
        id=doc_id,
        title=f"doc-{doc_id}",
        ai_domain=None,
        ai_summary=None,
        file_format="pdf",
        score=score,
        snippet=None,
        match_reason=match_reason,
    )
    return SearchResult(**fields)
def test_rrf_fuse_variants_single_variant_preserves_order():
    """With one variant, fusion keeps the input order and scores strictly decrease."""
    from services.search.search_pipeline import _rrf_fuse_variants

    ranked = [_mk_search_result(doc_id) for doc_id in (10, 20, 30)]
    fused = _rrf_fuse_variants([ranked], k=60, limit=10)

    assert [hit.id for hit in fused] == [10, 20, 30]
    # RRF scores: 1/(60+1) > 1/(60+2) > 1/(60+3)
    top, middle, bottom = fused[0].score, fused[1].score, fused[2].score
    assert top > middle > bottom
    assert "multi_query_rrf" in fused[0].match_reason
def test_rrf_fuse_variants_accumulates_overlapping_doc_ids():
    """A doc_id ranked by several variants accumulates RRF score and rises to the top.

    Also pins the full ranking and that each doc_id appears only once. The previous
    set-based check could not detect duplicates.
    """
    from services.search.search_pipeline import _rrf_fuse_variants

    v1 = [_mk_search_result(i) for i in (10, 20, 30)]
    v2 = [_mk_search_result(i) for i in (40, 10, 50)]  # 10 appears in both variants
    out = _rrf_fuse_variants([v1, v2], k=60, limit=10)

    # Expected RRF scores with k=60:
    #   10 = 1/61 (v1 rank 1) + 1/62 (v2 rank 2)   <- accumulated
    #   40 = 1/61 (v2 rank 1)
    #   20 = 1/62 (v1 rank 2)
    #   30 = 1/63 (v1 rank 3), 50 = 1/63 (v2 rank 3) -> tie, relative order not asserted
    ids = [r.id for r in out]
    assert len(ids) == 5  # deduplicated: five distinct docs, each once
    assert ids[0] == 10  # accumulated score ranks first
    assert ids[1] == 40  # best single-variant hit (rank 1 in v2)
    assert ids[2] == 20
    assert set(ids[3:]) == {30, 50}
def test_rrf_fuse_variants_first_variant_representative():
    """When variants share a doc_id, the SearchResult from the earliest variant is kept."""
    from services.search.search_pipeline import _rrf_fuse_variants

    variants = [
        [_mk_search_result(10, match_reason="from_v1")],
        [_mk_search_result(10, match_reason="from_v2")],
    ]
    out = _rrf_fuse_variants(variants, k=60, limit=10)

    assert len(out) == 1
    survivor = out[0]
    assert survivor.id == 10
    assert "from_v1" in survivor.match_reason  # first occurrence preserved
    assert "multi_query_rrf" in survivor.match_reason
def test_rrf_fuse_variants_respects_limit_cap():
    """`limit` truncates the fused list after ranking, so the best-scored docs survive.

    The previous version only checked the length. Truncating before fusion (or keeping
    arbitrary docs) would also have passed.
    """
    from services.search.search_pipeline import _rrf_fuse_variants

    v1 = [_mk_search_result(i) for i in range(100, 130)]  # 30 docs
    v2 = [_mk_search_result(i) for i in range(200, 230)]  # 30 more docs, all distinct from v1
    out = _rrf_fuse_variants([v1, v2], k=60, limit=5)

    assert len(out) == 5
    # Ranks 1-2 of each variant (1/61, 1/62) outscore everything else, so they must be
    # kept. The fifth slot is one of the two 1/63 docs (tie, order not asserted).
    assert {r.id for r in out[:4]} == {100, 200, 101, 201}
    assert out[4].id in {102, 202}
def test_rrf_fuse_variants_empty_lists_returns_empty():
    """No variants at all, or only empty variants, fuse to an empty list."""
    from services.search.search_pipeline import _rrf_fuse_variants

    for variants in ([], [[], [], []]):
        assert _rrf_fuse_variants(variants, k=60, limit=10) == []
def test_rrf_fuse_variants_rank_position_matters():
    """RRF uses only rank positions; the incoming per-variant scores are ignored.

    The previous version gave every doc score 1.0, so it could not tell rank-based
    fusion from score-based merging. Here 99 carries a much higher raw score than 10.
    A score-based merge would rank 99 first; RRF must still rank 10 first.
    """
    from services.search.search_pipeline import _rrf_fuse_variants

    v1 = [_mk_search_result(10, score=0.01)]  # 10 at rank 1
    v2 = [_mk_search_result(99, score=0.99), _mk_search_result(10, score=0.01)]  # 10 at rank 2
    out = _rrf_fuse_variants([v1, v2], k=60, limit=10)

    # 10 = 1/61 + 1/62 (ranked by both variants), 99 = 1/61 (one variant only).
    assert [r.id for r in out] == [10, 99]
# ─── 7. Phase 2 — search_pipeline import + run_search signature ───
def test_search_pipeline_imports_query_rewriter():
    """search_pipeline exposes query_rewriter and the Phase 2 entry points (dispatch wired)."""
    from services.search import search_pipeline

    for attr in ("query_rewriter", "search_with_rewrite", "_rrf_fuse_variants"):
        assert hasattr(search_pipeline, attr)
def test_run_search_has_rewrite_backend_param():
    """run_search accepts rewrite_backend, defaulting to None."""
    import inspect

    from services.search.search_pipeline import run_search

    params = inspect.signature(run_search).parameters
    assert "rewrite_backend" in params
    # A None default keeps existing callers on the baseline path (zero-regression invariant).
    assert params["rewrite_backend"].default is None
def test_phase2q_constants():
    """Fusion constants pinned by plan v6 §5.5."""
    from services.search.search_pipeline import (
        PHASE2Q_PRODUCTION_TOPK,
        PHASE2Q_RRF_K,
        PHASE2Q_UNIFIED_CAP,
    )

    assert (PHASE2Q_PRODUCTION_TOPK, PHASE2Q_RRF_K, PHASE2Q_UNIFIED_CAP) == (50, 60, 60)
    # Per-variant top-K = 50 // 3 = 16 (option A1 adopted).
    per_variant_k = PHASE2Q_PRODUCTION_TOPK // EXPECTED_N_VARIANTS
    assert per_variant_k == 16