Ref
https://youtu.be/gPOsqI58iF0?si=SrsRyyBT5FIeTqLj
https://youtu.be/aI8cyr-gH6M?si=tocxBJSAWsIR2CHQ
https://youtu.be/pzh2oc6shic?si=rp5moj9sS7XfGzYx
PEFT
#️⃣PEFT(parameter-efficient fine tuning)란
PEFT는 사전 학습된 모델을 특정 작업에 맞게 조정하는 과정인 미세 조정(fine tuning)의 효율성을 높이는 방법
Pretrained LLM의 대부분의 파라미터가 frozen인 상태에서 일부만 fine tuning (전이학습처럼)
방법
- 재매개변수화(Reparameterization): 모델의 가중치를 새로운 매개변수로 매핑하는 함수를 학습. 이 함수를 사용하여 모델의 일부 매개변수만 업데이트할 수 있음
- 부분 미세 조정(Partial fine tuning): 모델의 일부 매개변수만 업데이트하는 방법을 사용 이 방법은 재매개변수화를 사용하는 방법보다 더 간단하지만, 정확도가 약간 떨어질 수 있음
Reinforcement Learning
#️⃣ Reinforcement Learning (RL, 강화학습)
강화학습은 모델이 환경과 인터랙트하게하고, reward나 penalty의 형태로 피드백을 받아 학습하는 것
모델은 누적된 보상이 최대화 되도록 결정을 내리게 학습됨
- 에이전트가 취한 액션이 desired outcome의 경우 rewards를 받음
- 에이전트가 취한 액션이 undesired outcome의 경우에는 penalties를 받음
- 에이전트는 rewards가 최대화 되도록 학습
Model | Data | RL |
---|---|---|
fine-tuned model | REWARD model | |
pre-trained model | Dataset for fine-tuning | |
Transformer model | Dataset for pre-training |
#️⃣ Reward Model
에이전트에게 취한 액션에 따라 rewards를 주는 모델
에이전트의 학습 프로세스에 가이드를 주기 위함
- 에이전트는 액션을 취함으로써 환경과 인터랙션
- 시간이 지남에 따라 sum of rewards가 최대가 되도록 선택하는 전략인 "policy"를 학습하기 위한 것
- 텍스트 생성 모델 학습의 경우, 어떻게 생성 단어/문장에 reward를 assign할지 명학하지 않기때문에 rewarde model이 필요한것
Reward Model의 학습 방법
reward model은 주어진 state-action 쌍에 대한 reward를 예측하도록 학습됨
state: 현재 문장
action: 생성될 다음 단어
reward: 생성된 문장이 얼마나 좋은지 측정 (새문장이 나머지 텍스트들과 얼마나 어울리는지)
#️⃣ 강화학습 과정
- 베이스모델을 fine tuning하고, reward model 학습
- 강화학습으로 베이스 모델을 further optimization
- (이것이 Proximal Policy Optimization(PPO)가 Transformer Reinforcement Learning(TRL)하는 과정임)
- Policy Rollout : 베이스모델이 action을 생성
- Reward Calculation : 생성된 action이 reward 모델로 전달되고 각 액션에 대한 reward를 계산
- Policy Optimization : reward가 베이스모델을 업데이트하는데 사용됨
- Repeat : 위의 과정이 반복됨. 베이스모델은 이어서 action을 생성하고, reward모델은 reward를 계산하고, 베이스모델은 계속 기대 reward가 maximize되도록 업데이트를 반복
#️⃣ Transformer Reinforcement Learning (TRL)
Supervised Fine-tuning (SFT):
지도 학습으로 라벨링된 large dataset으로 모델을 학습시키는 과정
pre-trained 모델을 파인튜닝해서 특정 태스크에 잘 동작하도록 하는 것이 목표
이것은 강화학습으로 further optimization을 하기 위한 기반이 됨
# imports
from datasets import load_dataset
from trl import SFTTrainer
# get dataset
dataset = load_dataset("imdb", split="train")
# get trainer
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=512,
)
# train
trainer.train()
Reward Modeling (RM):
모델 파인튜닝 이후 reward 모델 생성 과정
Reward 모델은 강화학습 과정동안 메인모델에 피드백을 주도록 사용됨
메인 모델에 의해 생성된 output의 퀄리티를 예측하도록 학습됨
# imports
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer
# load model and dataset - dataset needs to be in a specific format
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
...
# load trainer
trainer = RewardTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
)
# train
trainer.train()
- RM은 function이 될 수도 있고, BERT모델을 사용해서 reward를 계산할 수도 있음
Proximal Policy Optimization (PPO):
PPO는 모델을 further optimization하는데 사용되는 강화학습 알고리즘
reward 모델의 피드백으로 메인모델의 파라미터를 업데이트함
기대 reward를 증가시키는 방향으로 policy(모델의 behaviour)를 개선시키는 것이 목표
# imports
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch
# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)
tokenizer = AutoTokenizer.from_pretrained('gpt2')
# initialize trainer
ppo_config = PPOConfig(
batch_size=1,
)
# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")
# get model response
response_tensor = respond_to_batch(model, query_tensor)
# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)
# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]
# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
SFT 단계에서는 최초 policy를 제공하고, RM 단계에서는 policy의 퀄리티를 평가하는 방법을 제공, 마지막으로 PPO 단계에서는 reward 모델의 피드백으로 policy를 최적화함
#️⃣ Reward Function
reward function은 주어진 state s에 서로다른 액션의 desirability를 quantify하고, 학습 과정을 가이드함
#️⃣ Policy
policy는 에이전트가 어떤 액션을 취할지 결정하는 전략
즉 state에서 action으로의 매핑을 의미함
policy의 목적은 expected cumulative reward가 maximize되는 optimal policy를 찾는 것
각 state에서 가능한 각 action의 reward를 예측
트랜스포머는 강화학습에서 policy를 최적화하는데 사용됨
(sequential data와 long-term dependenciy를 다룰 수 있기 때문)
#️⃣ Non-Markovian Rewards
non-markovian rewards는 state의 sequence나 여러 액션에 의존하는 reward
- 트랜스포머 구조는 특히 non-Markkovian reward에 이점이 있음
- episode동안 마주치는 state의 시퀀스에 대해 지연과 의존성을 특징으로 가지기 때문
Marcov Decision Process
강화학습은 이전 state의 지식 없이 미래 state는 오직 현재 state에 의해 예측된다는 Markov property 문제가 있음
=> 현재 state는 미래를 예측하기위한 모든 필요 정보를 압축
하지만 현실 데이터는 단지 현재 state만이 아니라, 일련의 과거에 의존해 미래 state가 결정되곤함
그래서 non-Markovian rewards가 등장
Transformers in Reinforcement Learning: A Survey (arxiv.org)
Direct Preference Optimization (DPO)
#️⃣ DPO (direct preference optimization)란
Direct Preference Optimization: Your Language Model is Secretly a Reward Model (Link)
- Explicit reward estimation
- Reinforcement Learning
single maximum likelihood objective를 사용해 policy를 학습
RLHF vs. DPO 차이
RLHF | DPO |
---|---|
reward 학습 후 RL로 최적화 | reward modeling을 건너뛰고, 바로 preference data로 모델을 최적화 |
DPO의 핵심은 reward function의 analytical mapping을 최적 policy로 만드는 것
이는 reward function에 대한 loss function를 policy에 대한 loss function로 변환할 수 있게 함
즉 DPO는 강화학습 없이 Preference로 언어 모델을 학습시키는 심플한 학습 패러다임
PPO에 기반한 DPO는 RLHF 알고리즘과 비슷하거나 나은 퍼포먼스를 보임
# 5. initialize the DPO trainer
dpo_trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=0.1,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
# 6. train
dpo_trainer.train()
Reinforcement Learning from Human Feedback (RLHF)
#️⃣ RLHF (reinforcement learning from human feedback)란
강화 학습(RLHF)은 에이전트의 행동에 대한 인간의 피드백을 사용하여 보상 함수를 학습하는 방법
학습 과정
- 처음에 에이전트는 human feedback 없이 학습
- 액션에 대해 human feedback을 받음
- 에이전트가 이 피드백을 reward function을 갱신하는데 사용
- 이 갱신된 reward function으로 에이전트가 재학습
'AI > LLM' 카테고리의 다른 글
Hugging Face StackLLaMA - RLHF로 LLaMA 모델 훈련 (기술 블로그 정리) (0) | 2024.01.06 |
---|