""" 트랜잭션 관리 유틸리티 데이터 일관성을 위한 트랜잭션 관리 및 데코레이터 """ import functools from typing import Any, Callable, Optional, TypeVar, Generic from contextlib import contextmanager from sqlalchemy.orm import Session from sqlalchemy.exc import SQLAlchemyError import asyncio from .logger import get_logger logger = get_logger(__name__) T = TypeVar('T') class TransactionManager: """트랜잭션 관리 클래스""" def __init__(self, db: Session): self.db = db @contextmanager def transaction(self, rollback_on_exception: bool = True): """ 트랜잭션 컨텍스트 매니저 Args: rollback_on_exception: 예외 발생 시 롤백 여부 """ try: logger.debug("트랜잭션 시작") yield self.db self.db.commit() logger.debug("트랜잭션 커밋 완료") except Exception as e: if rollback_on_exception: self.db.rollback() logger.warning(f"트랜잭션 롤백 - 에러: {str(e)}") else: logger.error(f"트랜잭션 에러 (롤백 안함) - 에러: {str(e)}") raise @contextmanager def savepoint(self, name: Optional[str] = None): """ 세이브포인트 컨텍스트 매니저 Args: name: 세이브포인트 이름 """ savepoint_name = name or f"sp_{id(self)}" try: # 세이브포인트 생성 savepoint = self.db.begin_nested() logger.debug(f"세이브포인트 생성: {savepoint_name}") yield self.db # 세이브포인트 커밋 savepoint.commit() logger.debug(f"세이브포인트 커밋: {savepoint_name}") except Exception as e: # 세이브포인트 롤백 savepoint.rollback() logger.warning(f"세이브포인트 롤백: {savepoint_name} - 에러: {str(e)}") raise def execute_in_transaction(self, func: Callable[..., T], *args, **kwargs) -> T: """ 함수를 트랜잭션 내에서 실행 Args: func: 실행할 함수 *args: 함수 인자 **kwargs: 함수 키워드 인자 Returns: 함수 실행 결과 """ with self.transaction(): return func(*args, **kwargs) async def execute_in_transaction_async(self, func: Callable[..., T], *args, **kwargs) -> T: """ 비동기 함수를 트랜잭션 내에서 실행 Args: func: 실행할 비동기 함수 *args: 함수 인자 **kwargs: 함수 키워드 인자 Returns: 함수 실행 결과 """ with self.transaction(): if asyncio.iscoroutinefunction(func): return await func(*args, **kwargs) else: return func(*args, **kwargs) def transactional(rollback_on_exception: bool = True): """ 트랜잭션 데코레이터 Args: rollback_on_exception: 예외 발생 시 롤백 여부 """ def decorator(func: Callable) -> Callable: @functools.wraps(func) def wrapper(*args, **kwargs): # 첫 번째 인자가 Session인지 확인 if args and isinstance(args[0], Session): db = args[0] transaction_manager = TransactionManager(db) try: with transaction_manager.transaction(rollback_on_exception): return func(*args, **kwargs) except Exception as e: logger.error(f"트랜잭션 함수 실행 실패: {func.__name__} - {str(e)}") raise else: # Session이 없으면 일반 함수로 실행 return func(*args, **kwargs) return wrapper return decorator def async_transactional(rollback_on_exception: bool = True): """ 비동기 트랜잭션 데코레이터 Args: rollback_on_exception: 예외 발생 시 롤백 여부 """ def decorator(func: Callable) -> Callable: @functools.wraps(func) async def wrapper(*args, **kwargs): # 첫 번째 인자가 Session인지 확인 if args and isinstance(args[0], Session): db = args[0] transaction_manager = TransactionManager(db) try: with transaction_manager.transaction(rollback_on_exception): if asyncio.iscoroutinefunction(func): return await func(*args, **kwargs) else: return func(*args, **kwargs) except Exception as e: logger.error(f"비동기 트랜잭션 함수 실행 실패: {func.__name__} - {str(e)}") raise else: # Session이 없으면 일반 함수로 실행 if asyncio.iscoroutinefunction(func): return await func(*args, **kwargs) else: return func(*args, **kwargs) return wrapper return decorator class BatchProcessor: """배치 처리를 위한 트랜잭션 관리""" def __init__(self, db: Session, batch_size: int = 1000): self.db = db self.batch_size = batch_size self.transaction_manager = TransactionManager(db) def process_in_batches( self, items: list, process_func: Callable, commit_per_batch: bool = True ): """ 아이템들을 배치 단위로 처리 Args: items: 처리할 아이템 리스트 process_func: 각 아이템을 처리할 함수 commit_per_batch: 배치마다 커밋 여부 """ total_items = len(items) processed_count = 0 failed_count = 0 logger.info(f"배치 처리 시작 - 총 {total_items}개 아이템, 배치 크기: {self.batch_size}") for i in range(0, total_items, self.batch_size): batch = items[i:i + self.batch_size] batch_num = (i // self.batch_size) + 1 try: if commit_per_batch: with self.transaction_manager.transaction(): self._process_batch(batch, process_func) else: self._process_batch(batch, process_func) processed_count += len(batch) logger.debug(f"배치 {batch_num} 처리 완료 - {len(batch)}개 아이템") except Exception as e: failed_count += len(batch) logger.error(f"배치 {batch_num} 처리 실패 - {str(e)}") # 개별 아이템 처리 시도 if commit_per_batch: self._process_batch_individually(batch, process_func) # 전체 커밋 (배치마다 커밋하지 않은 경우) if not commit_per_batch: try: self.db.commit() logger.info("전체 배치 처리 커밋 완료") except Exception as e: self.db.rollback() logger.error(f"전체 배치 처리 커밋 실패: {str(e)}") raise logger.info(f"배치 처리 완료 - 성공: {processed_count}, 실패: {failed_count}") return { "total_items": total_items, "processed_count": processed_count, "failed_count": failed_count, "success_rate": (processed_count / total_items) * 100 if total_items > 0 else 0 } def _process_batch(self, batch: list, process_func: Callable): """배치 처리""" for item in batch: process_func(item) def _process_batch_individually(self, batch: list, process_func: Callable): """배치 내 아이템을 개별적으로 처리 (에러 복구용)""" for item in batch: try: with self.transaction_manager.savepoint(): process_func(item) except Exception as e: logger.warning(f"개별 아이템 처리 실패: {str(e)}") class DatabaseLock: """데이터베이스 레벨 락 관리""" def __init__(self, db: Session): self.db = db @contextmanager def advisory_lock(self, lock_id: int): """ PostgreSQL Advisory Lock Args: lock_id: 락 ID """ try: # Advisory Lock 획득 result = self.db.execute(f"SELECT pg_advisory_lock({lock_id})") logger.debug(f"Advisory Lock 획득: {lock_id}") yield finally: # Advisory Lock 해제 self.db.execute(f"SELECT pg_advisory_unlock({lock_id})") logger.debug(f"Advisory Lock 해제: {lock_id}") @contextmanager def table_lock(self, table_name: str, lock_mode: str = "ACCESS EXCLUSIVE"): """ 테이블 레벨 락 Args: table_name: 테이블명 lock_mode: 락 모드 """ try: # 테이블 락 획득 self.db.execute(f"LOCK TABLE {table_name} IN {lock_mode} MODE") logger.debug(f"테이블 락 획득: {table_name} ({lock_mode})") yield except Exception as e: logger.error(f"테이블 락 실패: {table_name} - {str(e)}") raise class TransactionStats: """트랜잭션 통계 수집""" def __init__(self): self.stats = { "total_transactions": 0, "successful_transactions": 0, "failed_transactions": 0, "rollback_count": 0, "savepoint_count": 0 } def record_transaction_start(self): """트랜잭션 시작 기록""" self.stats["total_transactions"] += 1 def record_transaction_success(self): """트랜잭션 성공 기록""" self.stats["successful_transactions"] += 1 def record_transaction_failure(self): """트랜잭션 실패 기록""" self.stats["failed_transactions"] += 1 def record_rollback(self): """롤백 기록""" self.stats["rollback_count"] += 1 def record_savepoint(self): """세이브포인트 기록""" self.stats["savepoint_count"] += 1 def get_stats(self) -> dict: """통계 반환""" total = self.stats["total_transactions"] if total > 0: self.stats["success_rate"] = (self.stats["successful_transactions"] / total) * 100 self.stats["failure_rate"] = (self.stats["failed_transactions"] / total) * 100 else: self.stats["success_rate"] = 0 self.stats["failure_rate"] = 0 return self.stats.copy() def reset_stats(self): """통계 초기화""" for key in self.stats: if key not in ["success_rate", "failure_rate"]: self.stats[key] = 0 # 전역 트랜잭션 통계 인스턴스 transaction_stats = TransactionStats()