Use Case: Reliable Question-Answering with Llama-3¶
Large language models (LLMs) commonly fail on question-answering (QA) tasks due to challenges such as out-of-domain data, ambiguity, and hallucinations. Our uncertainty-aware variant of the LLM estimates uncertainty with every prediction, significantly improving accuracy in a selective question-answering setting.
In this tutorial, we will demonstrate how to wrap Llama 3, show quantitative results, and present a chat interface demo.
How to Wrap¶
We briefly show how to call the risk-aware version of Llama 3 with different wrappers.
First, we define RiskAwareLlama, which adapts the original Llama model to work with our wrapper.
import torch
import torch.nn.functional as F
import torchmetrics
import lightning.pytorch as pl
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model

# Llama (a thin wrapper around the decoder and lm_head) and get_wrapper_name
# are helpers from this project and are not shown here.

class RiskAwareLlama(pl.LightningModule):
    def __init__(self,
                 wrapper,
                 auth_token,
                 model_name="meta-llama/Meta-Llama-3-8B",
                 lora_config=None,
                 lr=3e-4,
                 dist=None):
        super().__init__()
        if model_name:
            self.model_name = model_name
            try:
                self.hf_model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=auth_token)
                self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token,
                                                               use_fast=False, add_eos_token=True, add_bos_token=True)
            except Exception as e:
                raise e
        self.wrapper = wrapper
        self.wrapper_name = get_wrapper_name(wrapper)
        self.lr = lr
        self.is_LoRA = lora_config is not None
        self.lora_config = lora_config
        self.prepare_tokenizer()
        self.prepare_model()
        self.vocab_size = self.hf_model.config.vocab_size
        self.dist = dist
        # The submodules we need live on self.model / self.embed_tokens,
        # so the top-level Hugging Face handle can be released.
        del self.hf_model
        torch.cuda.empty_cache()
        # Llama 3 vocabulary plus the added [PAD] token.
        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=128256 + 1)

    def prepare_tokenizer(self):
        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.tokenizer.padding_side = "left"

    def prepare_model(self):
        # Add one embedding row for the new [PAD] token and freeze the embeddings.
        pad_token_id = self.hf_model.model.embed_tokens.num_embeddings
        self.embed_tokens = self.hf_model.resize_token_embeddings(self.hf_model.model.embed_tokens.num_embeddings + 1)
        for param in self.embed_tokens.parameters():
            param.requires_grad = False
        self.hf_model.config.pad_token_id = pad_token_id
        self.padding_idx = pad_token_id
        if self.is_LoRA:
            self.hf_model = get_peft_model(self.hf_model, self.lora_config)
            self.model = Llama(self.hf_model.model.model, self.hf_model.model.lm_head)
        else:
            self.model = Llama(self.hf_model.model, self.hf_model.lm_head)

    def forward(self, **kwargs):
        # self.wrapped_model is created once the model has been wrapped (see below).
        return self.wrapped_model(**kwargs)

    def training_step(self, batch, batch_idx):
        data = batch
        input_ids = data["input_ids"]
        shift_labels = data["shift_labels"]
        attention_mask = data["attention_mask"]
        one_hot_labels = data["one_hot_labels"][:, 0, ...]
        shift_response_mask = data["response_mask"][:, 0, ...][..., 1:].contiguous().bool()
        shift_result_mask = data["result_mask"][:, 0, ...][..., 1:].contiguous().bool()
        input_embeds = self.embed_tokens(input_ids)
        if self.wrapper_name == "vote":
            final_y_pred = self.wrapped_model(inputs_embeds=input_embeds[:, 0, ...], attention_mask=attention_mask[:, 0, ...], return_risk=False, tile_and_reduce=False)
        elif self.wrapper_name == "sculpt":
            y_pred, y_risk = self.wrapped_model(inputs_embeds=input_embeds[:, 0, ...], attention_mask=attention_mask[:, 0, ...], return_risk=True)
            final_y_pred = self.dist.sample((y_pred, y_risk))
        elif self.wrapper_name == "sample":
            final_y_pred = self.wrapped_model(inputs_embeds=input_embeds[:, 0, ...], attention_mask=attention_mask[:, 0, ...], return_risk=False)
        elif self.wrapper_name == "none":
            final_y_pred = self.model(inputs_embeds=input_embeds[:, 0, ...], attention_mask=attention_mask[:, 0, ...])
        shift_logits = final_y_pred[..., :-1, :].contiguous()
        loss = F.cross_entropy(shift_logits[shift_response_mask], one_hot_labels[shift_response_mask])
        self.log("train/loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        data = batch
        input_ids = data["input_ids"]
        shift_labels = data["shift_labels"]
        attention_mask = data["attention_mask"]
        one_hot_labels = data["one_hot_labels"][:, 0, ...]
        shift_result_mask = data["result_mask"][:, 0, ...][..., 1:].contiguous().bool()
        shift_response_mask = data["response_mask"][:, 0, ...][..., 1:].contiguous().bool()
        input_embeds = self.embed_tokens(input_ids)
        if self.wrapper_name == "vote":
            final_y_pred = self.wrapped_model(inputs_embeds=input_embeds[:, 0, ...], attention_mask=attention_mask[:, 0, ...], return_risk=False, tile_and_reduce=False)
        elif self.wrapper_name == "sculpt":
            y_pred, y_risk = self.wrapped_model(inputs_embeds=input_embeds[:, 0, ...], attention_mask=attention_mask[:, 0, ...], return_risk=True)
            final_y_pred = self.dist.sample((y_pred, y_risk))
        elif self.wrapper_name == "sample":
            final_y_pred = self.wrapped_model(inputs_embeds=input_embeds[:, 0, ...], attention_mask=attention_mask[:, 0, ...], return_risk=False)
        elif self.wrapper_name == "none":
            final_y_pred = self.model(inputs_embeds=input_embeds[:, 0, ...], attention_mask=attention_mask[:, 0, ...])
        shift_logits = final_y_pred[..., :-1, :].contiguous()
        loss = F.cross_entropy(shift_logits[shift_response_mask], one_hot_labels[shift_response_mask])
        self.log("val/loss", loss, prog_bar=True, on_epoch=True)
        return loss
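Note that forward and the step methods above call self.wrapped_model, which only exists once the model has been wrapped (the model.wrap() call in the script below). A minimal sketch of what such a method could look like, assuming the capsa wrapper object is applied directly to the inner Llama module (an illustration, not the library's exact implementation):

    def wrap(self):
        # Illustrative sketch: applying the capsa wrapper to the inner model
        # produces the risk-aware wrapped_model used by forward() and the steps.
        if self.wrapper is not None:
            self.wrapped_model = self.wrapper(self.model)
        else:
            self.wrapped_model = self.model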
Then we instantiate the model with the arguments specific to each wrapper:
import os
import json
import time
import torch
import torch.nn.functional as F
import wandb
from lightning.pytorch.tuner import Tuner
from lightning.pytorch.trainer import Trainer
from pytorch_lightning.loggers import WandbLogger
from lightning.pytorch.strategies import SingleDeviceStrategy
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
from datasets import load_dataset
from utils import Dataset
from peft import LoraConfig, TaskType
from config import *
from capsa_torch import sample, vote, sculpt
from capsa_transformers import RiskAwareLlama
seed = int(time.time())
print("Seed:",seed)
torch.manual_seed(seed)
dataset = load_dataset(DATASET_NAME,DATASET_SPLIT)
train_dataset = dataset["train"]
val_dataset = dataset["test"]
test_dataset = dataset["test"]
# Read the Hugging Face access token from the environment rather than hard-coding it.
AUTH_TOKEN = os.environ.get("HF_TOKEN")

if WRAPPER == 'sculpt':
    dist = sculpt.Normal
    wrapper = sculpt.Wrapper(symbolic_trace=SYMBOLIC_TRACE, n_layers=N_LAYERS, verbose=2, distribution=dist)
    lora_config = LoraConfig(r=R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT)
    model = RiskAwareLlama(wrapper=wrapper, model_name=MODEL_NAME, auth_token=AUTH_TOKEN, lr=LEARNING_RATE, dist=dist, lora_config=lora_config)
    model.wrap()
elif WRAPPER == 'vote':
    wrapper = vote.Wrapper(param_filter=PARAM_FILTER, symbolic_trace=SYMBOLIC_TRACE, n_voters=N_VOTERS, finetune=FINETUNE, verbose=2, weight_noise=WEIGHT_NOISE, alpha=ALPHA)
    model = RiskAwareLlama(wrapper=wrapper, model_name=MODEL_NAME, auth_token=AUTH_TOKEN, lr=LEARNING_RATE)
    model.wrap()
elif WRAPPER == 'sample':
    wrapper = sample.Wrapper(symbolic_trace=SYMBOLIC_TRACE, n_samples=N_SAMPLES, verbose=2)
    lora_config = LoraConfig(r=R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT)
    model = RiskAwareLlama(wrapper=wrapper, model_name=MODEL_NAME, auth_token=AUTH_TOKEN, lr=LEARNING_RATE, lora_config=lora_config)
    model.wrap()
else:
    lora_config = LoraConfig(r=R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT)
    model = RiskAwareLlama(wrapper=None, model_name=MODEL_NAME, auth_token=AUTH_TOKEN, lr=LEARNING_RATE, lora_config=lora_config)
train_dataloader = torch.utils.data.DataLoader(dataset=Dataset(train_dataset,tokenizer=model.tokenizer,max_length=MAX_LENGTH), batch_size=BATCH_SIZE,shuffle=True,num_workers=BATCH_SIZE,drop_last=True)
val_dataloader = torch.utils.data.DataLoader(dataset=Dataset(val_dataset,tokenizer=model.tokenizer,max_length=MAX_LENGTH), batch_size=BATCH_SIZE, shuffle=False,num_workers=BATCH_SIZE,drop_last=True)
cuda_device = torch.device("cuda:0")
cpu_device = torch.device("cpu")
model = model.to(cuda_device)
if WRAPPER != 'none':
    # One forward pass on a real batch initializes the wrapper before training.
    with torch.no_grad():
        batch = next(iter(train_dataloader))
        input_ids = batch["input_ids"].to(cuda_device)
        attention_mask = batch["attention_mask"].to(cuda_device)
        input_embeds = model.embed_tokens(input_ids)
        logits = model.wrapped_model(inputs_embeds=input_embeds[:, 0, ...], attention_mask=attention_mask[:, 0, ...], return_risk=False)
model = model.to(cpu_device)

if WRAPPER == 'vote':
    # For vote, only the wrapper-specific parameters and the norm/head weights are trained.
    for name, param in model.named_parameters():
        if "be_bias" in name or "be_alpha" in name or "be_gamma" in name or "model.model.norm.weight" in name or "model.lm_head.weight" in name:
            param.requires_grad = True
            print(f"Grad enabled for {name}")
        else:
            param.requires_grad = False
# Callbacks
checkpoint_callback = ModelCheckpoint(dirpath=os.getenv("HOME")+"/repos/capsa-transformers/checkpoints/"+str(WRAPPER), save_top_k=1, monitor="val/loss",save_on_train_epoch_end=True,mode="min")
lr_monitor = LearningRateMonitor(logging_interval='epoch',log_weight_decay=True)
wandb_logger = WandbLogger(project="Tree-of-thoughts",log_model = False)
wandb_logger.experiment.config.update(config)
# Training
trainer = Trainer(callbacks=[checkpoint_callback,lr_monitor],accelerator="gpu",max_epochs=EPOCH,precision=PRECISION,enable_progress_bar=True,strategy=STRATEGY,logger=wandb_logger,gradient_clip_val=GRADIENT_CLIP_VAL,enable_checkpointing=True,val_check_interval=0.25,num_sanity_val_steps=0)
trainer.fit(model, train_dataloader,val_dataloader)#,ckpt_path=os.getenv("HOME")+"/repos/capsa-transformers/checkpoints/")
exit()
The configuration values are defined in the config.py file:
# General hyperparameters
DATASET_NAME = "mosaicml/dolly_hhrlhf"  # alternatives: "gsm8k", "MU-NLPC/Calc-gsm8k"
DATASET_SPLIT = None  # e.g. "main" when using gsm8k
WRAPPER = "sculpt"
MODEL_NAME = "meta-llama/Meta-Llama-3-8B"  # alternative: "meta-llama/Llama-2-7b-chat-hf"
LEARNING_RATE = 2e-4
EPOCH = 5
MAX_LENGTH = 400
BATCH_SIZE = 5
SYMBOLIC_TRACE = False
PRECISION= "bf16-mixed"
GRADIENT_CLIP_VAL = 1.0
STRATEGY = "deepspeed_stage_2"
# LoRA specific hyperparameters
R = 64
LORA_ALPHA = 64
LORA_DROPOUT = 0
# Vote specific hyperparameters
ALPHA = 8
FINETUNE = False
N_VOTERS = 5
WEIGHT_NOISE = 0.2
PARAM_FILTER = "q_proj|v_proj"
# Sculpt specific hyperparameters
N_LAYERS = 2
# Sample specific hyperparameters
N_SAMPLES = 5
config = {
    "wrapper": WRAPPER,
    "learning_rate": LEARNING_RATE,
    "model_name": MODEL_NAME,
    "dataset": DATASET_NAME,
    "epochs": EPOCH,
    "batch_size": BATCH_SIZE,
    "symbolic_trace": SYMBOLIC_TRACE,
    "precision": PRECISION,
    "gradient_clip_val": GRADIENT_CLIP_VAL
}

# track hyperparameters and run metadata
if WRAPPER == 'sculpt':
    config_specific = {
        "n_layers": N_LAYERS,
        "r": R,
        "lora_alpha": LORA_ALPHA,
        "lora_dropout": LORA_DROPOUT
    }
elif WRAPPER == 'vote':
    config_specific = {
        "weight_noise": WEIGHT_NOISE,
        "finetune": FINETUNE,
        "n_voters": N_VOTERS,
        "gradient_clip_val": GRADIENT_CLIP_VAL,
        "param_filter": PARAM_FILTER
    }
elif WRAPPER == 'sample':
    config_specific = {
        "n_samples": N_SAMPLES,
        "r": R,
        "lora_alpha": LORA_ALPHA,
        "lora_dropout": LORA_DROPOUT
    }
else:
    config_specific = {}

config.update(config_specific)
Now the Llama model is wrapped and good to go.
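Once wrapped, the model can be queried for both logits and a risk estimate in a single call. A minimal sanity-check sketch, assuming the sculpt configuration above; the prompt and the mean-over-vocabulary reduction of the risk tensor are illustrative choices, not part of the library:

# Illustrative sanity check for a sculpt-wrapped model (not the library's generation API).
model = model.to(cuda_device).eval()
prompt = "Question: What is the capital of France?\nAnswer:"
inputs = model.tokenizer(prompt, return_tensors="pt").to(cuda_device)

with torch.no_grad():
    input_embeds = model.embed_tokens(inputs["input_ids"])
    # With return_risk=True the wrapped model returns the logits together with a risk estimate.
    logits, risk = model.wrapped_model(inputs_embeds=input_embeds,
                                       attention_mask=inputs["attention_mask"],
                                       return_risk=True)

next_token_id = logits[0, -1].argmax()   # greedy choice for the next token
next_token_risk = risk[0, -1].mean()     # one illustrative scalar summary of its risk
print(model.tokenizer.decode(next_token_id), float(next_token_risk))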
Results¶
We presented our work Uncertainty-aware Language Modeling for Selective Question Answering at the 2024 AAAI workshop. The following results are on the TruthfulQA question-answering benchmark and use Llama 2-Chat 7B.
In general, we show that higher logit probabilities do not correspond to better question-answering ability, despite often being misinterpreted as a measure of confidence. Our methods report a reliable measure of confidence: higher confidence corresponds to higher accuracy.
As the figure shows, using the baseline logit probability as a confidence measure yields a maximum accuracy gain of only 4.7%, for questions in the 4th-lowest percentile. In comparison, our uncertainty-aware models attain accuracies of 100%, above 90%, above 80%, and above 70% when answering 1%, 8%, 9%, and 35% of the questions, respectively, and achieve higher accuracy across all confidence percentiles. We further observe that, after generating 10 candidate answers for each question in the benchmark, the answers with the highest uncertainty were consistently incorrect. We also note that the converted model can output correct answers by repeatedly generating predictions until one in the 99th confidence percentile is found.
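The selective-QA evaluation behind these numbers can be reproduced in a few lines: sort the questions by the model's reported confidence (for example, 1 minus the risk score), keep only the most confident fraction, and measure accuracy on that subset. A minimal sketch, assuming per-question confidence scores and correctness labels are already available; the names below are illustrative and not part of the library:

import numpy as np

def selective_accuracy(confidence, correct, coverage):
    # Accuracy when answering only the `coverage` fraction of most-confident questions.
    confidence = np.asarray(confidence, dtype=float)
    correct = np.asarray(correct, dtype=bool)
    n_keep = max(1, int(round(coverage * len(confidence))))
    keep = np.argsort(-confidence)[:n_keep]  # indices of the most confident questions
    return correct[keep].mean()

# Example usage: accuracy at 10%, 35%, and 100% coverage.
# confidence, correct = ...  # one score and one bool per benchmark question
# for c in (0.10, 0.35, 1.00):
#     print(c, selective_accuracy(confidence, correct, c))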
Demo¶
Here we present a demo built with Gradio that features real-time chat inference. The system outputs both the answer, as you would expect from any chatbot, and a risk score associated with each token.
Link coming soon to try it yourself!
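While the hosted link is pending, a minimal sketch of a similar interface follows. It assumes a generate_with_risk(prompt) helper that returns the answer tokens and a per-token risk list; both the helper and the formatting are illustrative, not the demo's actual code:

import gradio as gr

def generate_with_risk(prompt):
    # Placeholder: in the real demo this would run the wrapped Llama model
    # and return the decoded answer tokens plus one risk value per token.
    tokens = ["Paris", "is", "the", "capital", "."]
    risks = [0.02, 0.01, 0.01, 0.03, 0.01]
    return tokens, risks

def chat(message, history):
    tokens, risks = generate_with_risk(message)
    answer = " ".join(tokens)
    risk_report = ", ".join(f"{t}: {r:.2f}" for t, r in zip(tokens, risks))
    return f"{answer}\n\nPer-token risk: {risk_report}"

gr.ChatInterface(chat).launch()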
Conclusion¶
Our wrapping process modifies the LLM to output uncertainty estimates along with predictions, without needing external models or major architecture changes.
On question answering benchmarks, the uncertainty-aware models can achieve very high accuracy (>90%) by answering only the most confident subset of questions.
Using the uncertainty estimates performs much better than relying solely on the raw model probabilities, which often turn out to be unreliable confidence scores.