Files
ai-server/src/test_nllb_fixed.py
hyungi 397efb86dc Squashed 'integrations/document-ai/' content from commit 9093611
git-subtree-dir: integrations/document-ai
git-subtree-split: 9093611c9629c0de3db760878ec9929f50add5ed
2025-08-13 08:38:41 +09:00

97 lines
3.2 KiB
Python
Executable File

#!/usr/bin/env python3
"""
NLLB 모델 테스트 (수정된 버전)
"""
import torch
import time
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
def test_nllb_fixed():
print("🧪 NLLB 모델 테스트 (수정된 버전)")
model_name = "facebook/nllb-200-3.3B"
# 모델 로드
print("📥 모델 로딩 중...")
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="models/nllb-200-3.3B")
model = AutoModelForSeq2SeqLM.from_pretrained(
model_name,
cache_dir="models/nllb-200-3.3B",
torch_dtype=torch.float16,
low_cpu_mem_usage=True
)
# Apple Silicon 최적화
if torch.backends.mps.is_available():
device = torch.device("mps")
model = model.to(device)
print("🚀 Apple Silicon MPS 가속 사용")
else:
device = torch.device("cpu")
print("💻 CPU 모드 사용")
# NLLB 언어 코드 (수정된 방식)
def get_lang_id(tokenizer, lang_code):
"""언어 코드를 토큰 ID로 변환"""
return tokenizer.convert_tokens_to_ids(lang_code)
def translate_text(text, src_lang, tgt_lang, description):
print(f"\n📝 {description}:")
print(f"원문: {text}")
start_time = time.time()
# 입력 텍스트 토큰화
inputs = tokenizer(text, return_tensors="pt", padding=True).to(device)
# 번역 생성 (수정된 방식)
with torch.no_grad():
generated_tokens = model.generate(
**inputs,
forced_bos_token_id=get_lang_id(tokenizer, tgt_lang),
max_length=200,
num_beams=4,
early_stopping=True,
do_sample=False,
pad_token_id=tokenizer.pad_token_id
)
# 결과 디코딩
result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
translation_time = time.time() - start_time
print(f"번역: {result}")
print(f"소요 시간: {translation_time:.2f}")
return result, translation_time
print("\n" + "="*60)
try:
# 1. 영어 → 한국어
en_text = "Artificial intelligence is transforming the way we work and live."
translate_text(en_text, "eng_Latn", "kor_Hang", "영어 → 한국어 번역")
# 2. 일본어 → 한국어
ja_text = "人工知能は私たちの働き方と生活を変革しています。"
translate_text(ja_text, "jpn_Jpan", "kor_Hang", "일본어 → 한국어 번역")
# 3. 기술 문서 테스트
tech_text = "Machine learning algorithms require large datasets for training and validation."
translate_text(tech_text, "eng_Latn", "kor_Hang", "기술 문서 번역")
print(f"\n✅ 모든 테스트 성공!")
return True
except Exception as e:
print(f"❌ 테스트 중 오류: {e}")
return False
if __name__ == "__main__":
if test_nllb_fixed():
print("\n🎉 NLLB 모델 테스트 완료!")
print("📝 다음 단계: KoBART 요약 모델 설치")
else:
print("\n❌ 테스트 실패")