"""PostgreSQL 연결 — SQLAlchemy async engine + session factory""" import logging import re import time from pathlib import Path from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase from core.config import settings logger = logging.getLogger("migration") engine = create_async_engine( settings.database_url, echo=False, pool_size=10, max_overflow=20, ) async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) class Base(DeclarativeBase): pass # NOTE: 모든 pending migration은 단일 트랜잭션으로 실행됨. # DDL이 많거나 대량 데이터 변경이 포함된 migration은 장시간 lock을 유발할 수 있음. _MIGRATION_VERSION_RE = re.compile(r"^(\d+)_") _MIGRATION_LOCK_KEY = 938475 def _parse_migration_files(migrations_dir: Path) -> list[tuple[int, str, Path]]: """migration 파일 스캔 → (version, name, path) 리스트, 버전순 정렬""" files = [] for p in sorted(migrations_dir.glob("*.sql")): m = _MIGRATION_VERSION_RE.match(p.name) if not m: continue version = int(m.group(1)) files.append((version, p.name, p)) # 중복 버전 검사 seen: dict[int, str] = {} for version, name, _ in files: if version in seen: raise RuntimeError( f"migration 버전 중복: {seen[version]} vs {name} (version={version})" ) seen[version] = name files.sort(key=lambda x: x[0]) return files def _validate_sql_content(name: str, sql: str) -> None: """migration SQL에 BEGIN/COMMIT이 포함되어 있으면 에러 (외부 트랜잭션 깨짐 방지)""" # 주석(-- ...) 라인 제거 후 검사 lines = [ line for line in sql.splitlines() if not line.strip().startswith("--") ] stripped = "\n".join(lines).upper() for keyword in ("BEGIN", "COMMIT", "ROLLBACK"): # 단어 경계로 매칭 (예: BEGIN_SOMETHING은 제외) if re.search(rf"\b{keyword}\b", stripped): raise RuntimeError( f"migration {name}에 {keyword} 포함됨 — " f"migration SQL에는 트랜잭션 제어문을 넣지 마세요" ) async def _run_migrations(conn) -> None: """미적용 migration 실행 (호출자가 트랜잭션 관리)""" from sqlalchemy import text # schema_migrations 테이블 생성 await conn.execute(text(""" CREATE TABLE IF NOT EXISTS schema_migrations ( version INT PRIMARY KEY, name TEXT NOT NULL, applied_at TIMESTAMPTZ DEFAULT NOW() ) """)) # advisory lock 획득 (트랜잭션 끝나면 자동 해제) await conn.execute(text( f"SELECT pg_advisory_xact_lock({_MIGRATION_LOCK_KEY})" )) # 적용 이력 조회 result = await conn.execute(text("SELECT version FROM schema_migrations")) applied = {row[0] for row in result} # migration 파일 스캔 migrations_dir = Path(__file__).resolve().parent.parent.parent / "migrations" if not migrations_dir.is_dir(): logger.info("[migration] migrations/ 디렉토리 없음, 스킵") return files = _parse_migration_files(migrations_dir) pending = [(v, name, path) for v, name, path in files if v not in applied] if not pending: logger.info("[migration] 미적용 migration 없음") return start = time.monotonic() logger.info(f"[migration] {len(pending)}건 적용 시작") for version, name, path in pending: sql = path.read_text(encoding="utf-8") _validate_sql_content(name, sql) logger.info(f"[migration] {name} 실행 중...") await conn.execute(text(sql)) await conn.execute( text("INSERT INTO schema_migrations (version, name) VALUES (:v, :n)"), {"v": version, "n": name}, ) logger.info(f"[migration] {name} 완료") elapsed = time.monotonic() - start logger.info(f"[migration] 전체 {len(pending)}건 완료 ({elapsed:.1f}s)") async def init_db(): """DB 연결 확인 + pending migration 실행""" from sqlalchemy import text async with engine.begin() as conn: await conn.execute(text("SELECT 1")) try: await _run_migrations(conn) except Exception as e: logger.error(f"[migration] 실패: {e} — 전체 트랜잭션 롤백") raise async def get_session() -> AsyncSession: """FastAPI Depends용 세션 제공""" async with async_session() as session: yield session