Squashed 'integrations/document-ai/' content from commit 9093611
git-subtree-dir: integrations/document-ai git-subtree-split: 9093611c9629c0de3db760878ec9929f50add5ed
This commit is contained in:
96
src/test_nllb_fixed.py
Executable file
96
src/test_nllb_fixed.py
Executable file
@@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
NLLB 모델 테스트 (수정된 버전)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import time
|
||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||
|
||||
def test_nllb_fixed():
|
||||
print("🧪 NLLB 모델 테스트 (수정된 버전)")
|
||||
|
||||
model_name = "facebook/nllb-200-3.3B"
|
||||
|
||||
# 모델 로드
|
||||
print("📥 모델 로딩 중...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="models/nllb-200-3.3B")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
model_name,
|
||||
cache_dir="models/nllb-200-3.3B",
|
||||
torch_dtype=torch.float16,
|
||||
low_cpu_mem_usage=True
|
||||
)
|
||||
|
||||
# Apple Silicon 최적화
|
||||
if torch.backends.mps.is_available():
|
||||
device = torch.device("mps")
|
||||
model = model.to(device)
|
||||
print("🚀 Apple Silicon MPS 가속 사용")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
print("💻 CPU 모드 사용")
|
||||
|
||||
# NLLB 언어 코드 (수정된 방식)
|
||||
def get_lang_id(tokenizer, lang_code):
|
||||
"""언어 코드를 토큰 ID로 변환"""
|
||||
return tokenizer.convert_tokens_to_ids(lang_code)
|
||||
|
||||
def translate_text(text, src_lang, tgt_lang, description):
|
||||
print(f"\n📝 {description}:")
|
||||
print(f"원문: {text}")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# 입력 텍스트 토큰화
|
||||
inputs = tokenizer(text, return_tensors="pt", padding=True).to(device)
|
||||
|
||||
# 번역 생성 (수정된 방식)
|
||||
with torch.no_grad():
|
||||
generated_tokens = model.generate(
|
||||
**inputs,
|
||||
forced_bos_token_id=get_lang_id(tokenizer, tgt_lang),
|
||||
max_length=200,
|
||||
num_beams=4,
|
||||
early_stopping=True,
|
||||
do_sample=False,
|
||||
pad_token_id=tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
# 결과 디코딩
|
||||
result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
|
||||
|
||||
translation_time = time.time() - start_time
|
||||
print(f"번역: {result}")
|
||||
print(f"소요 시간: {translation_time:.2f}초")
|
||||
|
||||
return result, translation_time
|
||||
|
||||
print("\n" + "="*60)
|
||||
|
||||
try:
|
||||
# 1. 영어 → 한국어
|
||||
en_text = "Artificial intelligence is transforming the way we work and live."
|
||||
translate_text(en_text, "eng_Latn", "kor_Hang", "영어 → 한국어 번역")
|
||||
|
||||
# 2. 일본어 → 한국어
|
||||
ja_text = "人工知能は私たちの働き方と生活を変革しています。"
|
||||
translate_text(ja_text, "jpn_Jpan", "kor_Hang", "일본어 → 한국어 번역")
|
||||
|
||||
# 3. 기술 문서 테스트
|
||||
tech_text = "Machine learning algorithms require large datasets for training and validation."
|
||||
translate_text(tech_text, "eng_Latn", "kor_Hang", "기술 문서 번역")
|
||||
|
||||
print(f"\n✅ 모든 테스트 성공!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 테스트 중 오류: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
if test_nllb_fixed():
|
||||
print("\n🎉 NLLB 모델 테스트 완료!")
|
||||
print("📝 다음 단계: KoBART 요약 모델 설치")
|
||||
else:
|
||||
print("\n❌ 테스트 실패")
|
||||
Reference in New Issue
Block a user