[📕 Paper] [🤗 Alignment dataset & Harmful dataset]
Fine-tuning-as-a-service allows users to upload data to a service provider (e.g., OpenAI) for fine-tuning the base model. The fine-tuned model is then deployed on the server to serve customized user needs. Such a procedure usually contains two sequential stages: i) safety alignment stage -- the model is safety-aligned with safety data; ii) fine-tuning stage -- the aligned model produced by the first stage is fine-tuned on user-provided data.
However, this scenario exposes a serious safety issue, because users might intentionally or unintentionally upload harmful data that breaks the safety alignment of the victim LLM. Specifically, the model suffers from a harmful fine-tuning attack: the customized LLM forgets the alignment knowledge and exhibits harmful behavior after fine-tuning on partially harmful data. See the following figure for an illustration.
Tcell is the proposed alignment-stage defense against the harmful fine-tuning attack. Tcell strengthens the aligned model's robustness by fully exploiting the alignment/harmful datasets. The high-level idea is to align the gradient over the alignment dataset with the gradient over the harmful dataset (in short, align the harmful gradient and the safety gradient).
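To make the idea concrete, here is a toy, self-contained sketch of the perturb-and-compare regularizer that the trainer below implements. The linear model, random batches, and hyperparameter values are illustrative stand-ins only, not the repository's code.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
model = torch.nn.Linear(16, 2)          # tiny stand-in for the LLM
beta, rho = 0.1, 1.0                    # perturbation step size / regularizer weight

align_x, align_y = torch.randn(8, 16), torch.randint(0, 2, (8,))   # alignment (safety) batch
harm_x,  harm_y  = torch.randn(8, 16), torch.randint(0, 2, (8,))   # harmful batch

def grads(x, y):
    # per-parameter gradients of the cross-entropy loss on (x, y)
    loss = F.cross_entropy(model(x), y)
    return torch.autograd.grad(loss, list(model.parameters()))

def global_norm(gs):
    return torch.norm(torch.stack([g.norm() for g in gs]))

g_align = grads(align_x, align_y)       # safety gradient at the current weights
g_harm  = grads(harm_x,  harm_y)        # harmful gradient at the current weights

# Perturb along the normalized safety gradient, then along the normalized harmful
# gradient; if the two gradients are well aligned, the extra harmful step barely
# moves the alignment loss, so the gap below stays small.
with torch.no_grad():
    for p, g in zip(model.parameters(), g_align):
        p -= beta * g / global_norm(g_align)
loss_after_safety_step = F.cross_entropy(model(align_x), align_y)

with torch.no_grad():
    for p, g in zip(model.parameters(), g_harm):
        p -= beta * g / global_norm(g_harm)
loss_after_harmful_step = F.cross_entropy(model(align_x), align_y)

regularizer = rho * (loss_after_safety_step.detach() - loss_after_harmful_step) ** 2
print(float(regularizer))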
We implement a customized trainer (TcellTrainer) on top of the original HuggingFace Trainer. To achieve Tcell, we append several forward/backward passes according to the pseudo-algorithm.
Specifically, in training_step(), we use the following logic:
# first backward pass: gradient over the safety (alignment) batch
with self.compute_loss_context_manager():
    loss1 = self.compute_loss(model, alignment_inputs)
if self.use_apex:
    with amp.scale_loss(loss1, self.optimizer) as scaled_loss:
        scaled_loss.backward()
else:
    self.accelerator.backward(loss1)
alignment_grads = {name: param.grad.data.clone() for name, param in model.named_parameters() if param.requires_grad}
# global norm of the safety gradient, used to normalize the perturbation step below
grad_norm = torch.norm(torch.stack([g.norm() for g in alignment_grads.values()]))
model.zero_grad()  # clear the accumulated gradient before the next backward pass

# second backward pass: gradient over the harmful batch
with self.compute_loss_context_manager():
    loss2 = self.compute_loss(model, harmful_inputs)
if self.use_apex:
    with amp.scale_loss(loss2, self.optimizer) as scaled_loss:
        scaled_loss.backward()
else:
    self.accelerator.backward(loss2)
harmful_grads = {name: param.grad.data.clone() for name, param in model.named_parameters() if param.requires_grad}
# global norm of the harmful gradient
grad_norm2 = torch.norm(torch.stack([g.norm() for g in harmful_grads.values()]))
model.zero_grad()

# take a step along the normalized safety gradient
with torch.no_grad():
    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data -= self.args.beta * alignment_grads[name] / grad_norm

# compute the safety gradient over the safety-perturbed model
with self.compute_loss_context_manager():
    loss3 = self.compute_loss(model, alignment_inputs)
if self.use_apex:
    with amp.scale_loss(loss3, self.optimizer) as scaled_loss:
        scaled_loss.backward()
else:
    self.accelerator.backward(loss3)
perturb_safety_grads = {name: param.grad.data.clone() for name, param in model.named_parameters() if param.requires_grad}
model.zero_grad()

# take a further step along the normalized harmful gradient
with torch.no_grad():
    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data -= self.args.beta * harmful_grads[name] / grad_norm2
del harmful_grads

# compute the gradient-alignment regularizer and backpropagate it
with self.compute_loss_context_manager():
    loss4 = self.compute_loss(model, alignment_inputs)
similarity_loss = self.args.rho * torch.norm(loss3.detach() - loss4) ** 2
self.accelerator.backward(similarity_loss)
# finally, sum the gradient components; the optimizer step then uses param.grad
for name, param in model.named_parameters():
    if param.requires_grad:
        param.grad.data = alignment_grads[name] + perturb_safety_grads[name] + param.grad.data
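For context, below is a minimal, hypothetical sketch of how such a training_step override could be wired into a HuggingFace Trainer subclass. The constructor, the harmful-batch pairing, and the class name are assumptions for illustration; the actual TcellTrainer in this repo may differ.

import itertools
from transformers import Trainer

class TcellTrainerSketch(Trainer):
    def __init__(self, *args, harmful_dataloader=None, **kwargs):
        super().__init__(*args, **kwargs)
        # pair every alignment batch with a harmful batch by cycling the harmful dataloader
        self.harmful_iter = itertools.cycle(harmful_dataloader)

    def training_step(self, model, inputs, num_items_in_batch=None):  # exact signature varies by transformers version
        model.train()
        alignment_inputs = self._prepare_inputs(inputs)
        harmful_inputs = self._prepare_inputs(next(self.harmful_iter))

        # ... the multi-pass logic shown above goes here, leaving the combined
        # gradient in param.grad ...

        # recompute the alignment loss for logging/returning; the real code would
        # reuse loss1 from the passes above
        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, alignment_inputs)
        return loss.detach()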
Please open an issue if you encounter any problems when reproducing the results.
The package requirements are listed in requirements.txt. Run the following commands to install the packages with Anaconda and pip.
conda create -n hts python=3.12.12
source activate hts
pip install -r requirements.txt
For safety alignment, please download the safety alignment dataset from this link, and put the JSON file under the data/ directory.
For the fine-tuning tasks, we first need to run the following scripts to prepare the supervised fine-tuning data.
cd eval/sst2
python build_dataset.py
cd ../gsm8k
python build_dataset.py
cd ../agnews
python build_dataset.py
cd ../..
Llama3-8B is a gated repo, which needs a formal request to get access to the model. Check out https://huggingface.co/meta-llama/Meta-Llama-3-8B.
After obtaining permission from Meta, you should be able to access the model, but you first need to enter your access token in the file huggingface_token.txt.
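For reference, here is one way the token file could be used to authenticate programmatically via the standard huggingface_hub login; this snippet is illustrative only, and the repo's own scripts may load the token differently.

from huggingface_hub import login

# read the access token from huggingface_token.txt and log in to the Hugging Face Hub
with open("huggingface_token.txt") as f:
    login(token=f.read().strip())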
We prepare scripts for reproducing all the experiments in the paper (check out the script directory). We recommend using Slurm to reproduce the results, as the logging files will be automatically organized into the script directory (if you don't use Slurm, just replace sbatch with bash in our examples).
We first run SFT to produce the instruction-tuned model
cd script/alignment
sbatch IFT_new.sh
cd ../../
Then we run Tcell to produce the safety-aligned model
cd script/alignment
sbatch Tcell.sh
Then we fine-tune the model on a total of 1000 samples, 10% of which are harmful data and the rest of which come from the SST2 dataset.
cd ../tcell_finetune
sbatch hft.sh
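For intuition, the following toy sketch shows what such a 10% poisoned mixture looks like. The records are placeholders and the sampling logic is illustrative only; the actual mixture is built by the repo's own data pipeline, not by this snippet.

import random

random.seed(0)
n_total, poison_ratio = 1000, 0.1
n_harmful = int(n_total * poison_ratio)          # 100 harmful samples

# stand-in records; the real pipeline draws from SST2 and the harmful dataset
benign  = [{"source": "sst2",    "id": i} for i in range(10000)]
harmful = [{"source": "harmful", "id": i} for i in range(5000)]

mixture = random.sample(benign, n_total - n_harmful) + random.sample(harmful, n_harmful)
random.shuffle(mixture)
print(len(mixture), sum(r["source"] == "harmful" for r in mixture))  # 1000 100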
If you want to reproduce all the results in the paper, we prepare all the commands under the script directory, i.e., the files starting with run_. All you need to do is copy-paste the commands into your terminal.
