Coding/TIL & 배운것들

사전 학습 & 파인 튜닝

코딩짜는 머글 2024. 11. 11. 16:11

사전 학습(Pre-training) 이란?

대구모의 텍스트 데이터셋을 사용해 모델이 일반적인 언어 이해 능력을 학습하는 과정이다. 

 

▼ 특징 

  • 대규모 데이터셋 사용 : 인터넷에서 수집한 방대한 양의 텍스트 데이터로 모델을 학습시킨다. 예를 들어, BERT 는 수십억 개의 문장으로 사전 학습되어있다. 
  • 일반적인 언어 이해 : 모델은 텍스트 내 단어의 의미, 문장 구조, 문맥 등 언어의 전반적인 특징을 학습한다.
  • 작업 비특화 : 특정 작업에 맞춰진 학습이 아닌, 전반적인 언어 이해에 초점을 맞춘다. 

 

▼ 목적

사전 학습을 통해 모델은 다양한 텍스트에서 언어의 기본적인 규칙을 배우고, 이후에 특정 작업에 빠르게 적응할 수 있는 기반을 다진다. Hugging Face에서 제공하는 대부분의 모델들은 이 단계까지 완료된 상태로 제공된다. 

 

 

BERT의 사전 학습 예시 

Masked Language Modeling (MLM) 
문장의 일부 단어를 마스킹(masking)한 후, 이를 예측하도록 모델을 학습시킨다. 이 과정을 통해 BERT는 문맥을 양방향으로 이해할 수 있다.
Nest Sentence Prediction (NSP)
두 문장이 주어졌을 때, 두 번째 문장이 첫 번째 문장 뒤에 자연스럽게 이어지는지를 예측한다. 이를 통해 문장 간의 관계를 이해하는 능력을 학습한다. 

 

 

 

파인 튜닝(Fine-tuning)이란?

사전 학습된 모델을 특정 작업에 맞게 추가로 학습시키는 과정이다. 예를 들어, BERT 모델을 감정 분석에 사용하려면, BERT의 사전 학습된 가중치를 유지하면서 감정 분석 작업에 맞게 모델의 가중치를 조정한다.

 

 

▼ 특징 

  • 작업 특화 : 파인 튜닝은 특정 작업(텍스트 분류, 번역, 질의 응답 등)에 맞춰 모델을 최적화하는 과정이다.
  • 사전 학습 가중치 활용 : 사전 학습된 모델의 언어 이해 능력을 바탕으로, 새로운 작업에 적응할 수 있도록 일부 가중치만 조정한다.
  • 적은 데이터로도 가능 : 사전 학습 덕분에, 파인 튜닝은 비교적 적은 양의 데이터로도 효과적인 학습이 가능하다.

 

▼ 목적

특정 작업에서 최상의 성능을 발휘하도록 모델을 조정하는 과정이다. 사전 학습 덕분에, 파인 튜닝은 더 빠르고 적은 데이터로 이루어질 수 있다. 대부분의 NLP 모델은 파인 튜닝 과정을 거쳐 실제 애플리케이션에서 사용된다. 

 

 

 

 

▼ IMDB 데이터셋을 활용한 BERT 파인 튜닝 실습

# 라이브러리 설치
pip install transformers datasets torch
pip install accelerate -U
# 필요 라이브러리
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
import torch
# 모델 평가를 위한 라이브러리
import numpy as np
from sklearn.metrics import accuracy_score
# IMDb 데이터셋 로드
dataset = load_dataset("imdb")

# 훈련 및 테스트 데이터셋 분리
train_dataset = dataset['train'].shuffle(seed=42).select(range(1000))  # 1000개 샘플로 축소
test_dataset = dataset['test'].shuffle(seed=42).select(range(500))  # 500개 샘플로 축소

# BERT 토크나이저 로드
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# 데이터셋 토크나이징 함수 정의
def tokenize_function(examples):
    return tokenizer(examples['text'], padding="max_length", truncation=True)

 

→ examples['text'] : 데이터셋에서 리뷰 부분을 가져오는 작업

→ tokenize : BERT모델이 이해할 수 있는 형태, 즉 텍스트를 토큰 형태로 변환해줌

→ padding='max_length' : 모든 텍스트를 최대길이로 padding 해준다. BERT모델에서 기본 최대 길이는 512토큰 이므로 padding된 토큰은 모델에서 무시되고 모든 입력이 동일한 길이를 갖도록 할 수 있다. 

→ truncation : 텍스트의 길이가 최대 길이를 초과한 경우 초과된 부분을 잘라내는 작업 

# 데이터셋 토크나이징 적용
train_dataset = train_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)

# 모델 입력으로 사용하기 위해 데이터셋 포맷 설정
train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

 

→ batched : batch 단위로 함수를 적용하여 여러 샘플을 한 번에 처리. 속도를 높이고 메모리를 최적화

→ set_format : 데이터셋을 특정 형식으로 변환. 여기서는 PyTorch tensor로 변환

→ columns : BERT 모델에 필요한 데이터컬럼만 선택

# BERT 모델 로드
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# 훈련 인자 설정
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    evaluation_strategy="epoch",
    save_steps=10_000,
    save_total_limit=2,
)

 

→ output_dir : 훈련 중 생성된 모델 파일 및 로그를 저장할 디렉토리

→ num_train_epochs : 훈련 에포크설정

per_device_train_batch_size : 각 디바이스에서 훈련할때 사용할 batch 크기

→ evaluation_strategy : 에폭 단위로 평가를 해서 에폭이 끝날 때무다 평가를 진행

→ save_steps : 몇 스텝마다 모델을 저장할지를 설정 

→ save_total_limit : 저장할 최대 체크 포인트의 개수  

# 트레이너 설정
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

# 모델 훈련
trainer.train()
trainer.evaluate()

 

→ args : 학습을 위한 설정값

# 평가 지표 함수 정의
def compute_metrics(p):
    preds = np.argmax(p.predictions, axis=1)  # 예측된 클래스
    labels = p.label_ids  # 실제 레이블
    acc = accuracy_score(labels, preds)  # 정확도 계산
    return {'accuracy': acc}

# 이미 훈련된 트레이너에 compute_metrics를 추가하여 평가
trainer.compute_metrics = compute_metrics

# 모델 평가 및 정확도 확인
eval_result = trainer.evaluate()
print(f"Accuracy: {eval_result['eval_accuracy']:.4f}")  # .4f: 소수점 4번째 자리까지 표현

# 정확도
Accuracy: 0.8960