git-subtree-dir: integrations/document-ai git-subtree-split: 9093611c9629c0de3db760878ec9929f50add5ed
97 lines
3.2 KiB
Python
Executable File
97 lines
3.2 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
NLLB 모델 테스트 (수정된 버전)
|
|
"""
|
|
|
|
import torch
|
|
import time
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
|
def test_nllb_fixed():
|
|
print("🧪 NLLB 모델 테스트 (수정된 버전)")
|
|
|
|
model_name = "facebook/nllb-200-3.3B"
|
|
|
|
# 모델 로드
|
|
print("📥 모델 로딩 중...")
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="models/nllb-200-3.3B")
|
|
model = AutoModelForSeq2SeqLM.from_pretrained(
|
|
model_name,
|
|
cache_dir="models/nllb-200-3.3B",
|
|
torch_dtype=torch.float16,
|
|
low_cpu_mem_usage=True
|
|
)
|
|
|
|
# Apple Silicon 최적화
|
|
if torch.backends.mps.is_available():
|
|
device = torch.device("mps")
|
|
model = model.to(device)
|
|
print("🚀 Apple Silicon MPS 가속 사용")
|
|
else:
|
|
device = torch.device("cpu")
|
|
print("💻 CPU 모드 사용")
|
|
|
|
# NLLB 언어 코드 (수정된 방식)
|
|
def get_lang_id(tokenizer, lang_code):
|
|
"""언어 코드를 토큰 ID로 변환"""
|
|
return tokenizer.convert_tokens_to_ids(lang_code)
|
|
|
|
def translate_text(text, src_lang, tgt_lang, description):
|
|
print(f"\n📝 {description}:")
|
|
print(f"원문: {text}")
|
|
|
|
start_time = time.time()
|
|
|
|
# 입력 텍스트 토큰화
|
|
inputs = tokenizer(text, return_tensors="pt", padding=True).to(device)
|
|
|
|
# 번역 생성 (수정된 방식)
|
|
with torch.no_grad():
|
|
generated_tokens = model.generate(
|
|
**inputs,
|
|
forced_bos_token_id=get_lang_id(tokenizer, tgt_lang),
|
|
max_length=200,
|
|
num_beams=4,
|
|
early_stopping=True,
|
|
do_sample=False,
|
|
pad_token_id=tokenizer.pad_token_id
|
|
)
|
|
|
|
# 결과 디코딩
|
|
result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
|
|
|
|
translation_time = time.time() - start_time
|
|
print(f"번역: {result}")
|
|
print(f"소요 시간: {translation_time:.2f}초")
|
|
|
|
return result, translation_time
|
|
|
|
print("\n" + "="*60)
|
|
|
|
try:
|
|
# 1. 영어 → 한국어
|
|
en_text = "Artificial intelligence is transforming the way we work and live."
|
|
translate_text(en_text, "eng_Latn", "kor_Hang", "영어 → 한국어 번역")
|
|
|
|
# 2. 일본어 → 한국어
|
|
ja_text = "人工知能は私たちの働き方と生活を変革しています。"
|
|
translate_text(ja_text, "jpn_Jpan", "kor_Hang", "일본어 → 한국어 번역")
|
|
|
|
# 3. 기술 문서 테스트
|
|
tech_text = "Machine learning algorithms require large datasets for training and validation."
|
|
translate_text(tech_text, "eng_Latn", "kor_Hang", "기술 문서 번역")
|
|
|
|
print(f"\n✅ 모든 테스트 성공!")
|
|
return True
|
|
|
|
except Exception as e:
|
|
print(f"❌ 테스트 중 오류: {e}")
|
|
return False
|
|
|
|
if __name__ == "__main__":
|
|
if test_nllb_fixed():
|
|
print("\n🎉 NLLB 모델 테스트 완료!")
|
|
print("📝 다음 단계: KoBART 요약 모델 설치")
|
|
else:
|
|
print("\n❌ 테스트 실패")
|