Use Case: Reliable Question-Answering with Llama-3

Large language models (LLMs) commonly fail on question-answering (QA) tasks due to challenges such as out-of-domain data, ambiguity, and hallucinations. Our uncertainty-aware variant of the LLM estimates uncertainty with every prediction, significantly improving accuracy in a selective question-answering setting.

In this tutorial, we will demonstrate how to wrap Llama 3, show quantitative results, and present a chat interface demo.

How to Wrap

We briefly show how to call the risk-aware version of Llama 3 with different wrappers.

First, we define RiskAwareLlama, a LightningModule that adapts the original Llama model to our wrappers.


import torch
import torch.nn.functional as F
import torchmetrics
import lightning.pytorch as pl
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model

# `Llama` (a thin module holding the decoder stack and lm_head) and
# `get_wrapper_name` are helpers assumed to be defined elsewhere in this package.


class RiskAwareLlama(pl.LightningModule):

    def __init__(self,
                 wrapper,
                 auth_token,
                 model_name="meta-llama/Meta-Llama-3-8B",
                 lora_config=None,
                 lr=3e-4,
                 dist=None
                 ):

    
        super().__init__()


        if model_name:
            self.model_name = model_name
            self.hf_model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=auth_token)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token,
                                                           use_fast=False, add_eos_token=True, add_bos_token=True)
            
        self.wrapper = wrapper
        self.wrapper_name = get_wrapper_name(wrapper)
        self.lr = lr
        self.is_LoRA = lora_config is not None
        self.lora_config = lora_config

        self.prepare_tokenizer()
        self.prepare_model()


        self.vocab_size = self.hf_model.config.vocab_size

        self.dist = dist

        # Drop the top-level HF module attribute so the same parameters are not
        # registered twice on the LightningModule; self.model keeps references
        # to the submodules it actually needs
        del self.hf_model
        torch.cuda.empty_cache()

        # Llama 3 vocabulary (128256) plus the added [PAD] token
        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=128256 + 1)

    
    def prepare_tokenizer(self):
        # Add a dedicated [PAD] token and pad on the left (standard for causal generation)
        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.tokenizer.padding_side = "left"


    def prepare_model(self):
        # The new [PAD] token takes the next free embedding index
        pad_token_id = self.hf_model.model.embed_tokens.num_embeddings

        # Grow the embedding table by one row to make room for [PAD]
        self.embed_tokens = self.hf_model.resize_token_embeddings(
            self.hf_model.model.embed_tokens.num_embeddings + 1)

        # Keep the (resized) embedding table frozen
        for param in self.embed_tokens.parameters():
            param.requires_grad = False

        self.hf_model.config.pad_token_id = pad_token_id
        self.padding_idx = pad_token_id

        if self.is_LoRA:
            self.hf_model = get_peft_model(self.hf_model, self.lora_config)
            self.model = Llama(self.hf_model.model.model,self.hf_model.model.lm_head)
        else:
            self.model = Llama(self.hf_model.model,self.hf_model.lm_head)



    def forward(self, **kwargs):
        # self.wrapped_model is created by wrap() once the capsa wrapper has
        # been applied to the inner Llama module (see the sketch after the class)
        return self.wrapped_model(**kwargs)



    def training_step(self, batch, batch_idx):
        data = batch

        input_ids = data["input_ids"]
        shift_labels = data["shift_labels"]
        attention_mask = data["attention_mask"]
        one_hot_labels = data["one_hot_labels"][:,0,...]
        shift_response_mask = data["response_mask"][:,0,...][..., 1:].contiguous().bool()
        shift_result_mask = data["result_mask"][:,0,...][..., 1:].contiguous().bool()


        input_embeds = self.embed_tokens(input_ids)

        # Each wrapper has a slightly different calling convention; sculpt also
        # returns a per-logit risk tensor that is passed to self.dist.sample
        if self.wrapper_name == "vote":
            final_y_pred = self.wrapped_model(inputs_embeds = input_embeds[:,0,...],attention_mask = attention_mask[:,0,...],return_risk=False,tile_and_reduce=False)
        elif self.wrapper_name == "sculpt":
            y_pred,y_risk = self.wrapped_model(inputs_embeds = input_embeds[:,0,...],attention_mask = attention_mask[:,0,...],return_risk=True)
            final_y_pred = self.dist.sample((y_pred, y_risk))
        elif self.wrapper_name == "sample":
            final_y_pred = self.wrapped_model(inputs_embeds = input_embeds[:,0,...],attention_mask = attention_mask[:,0,...],return_risk=False)
        elif self.wrapper_name == "none":
            final_y_pred = self.model(inputs_embeds = input_embeds[:,0,...],attention_mask = attention_mask[:,0,...])


        # Standard causal-LM shift: logits at position t predict token t+1;
        # the loss is restricted to response tokens
        shift_logits = final_y_pred[..., :-1, :].contiguous()

        loss = F.cross_entropy(shift_logits[shift_response_mask], one_hot_labels[shift_response_mask])
        
        self.log("train/loss", loss, prog_bar=True)

        return loss


    def validation_step(self, batch, batch_idx):
        # Mirrors training_step, but logs the validation loss per epoch
        data = batch

        input_ids = data["input_ids"]
        shift_labels = data["shift_labels"]
        attention_mask = data["attention_mask"]
        one_hot_labels = data["one_hot_labels"][:,0,...]
        shift_result_mask = data["result_mask"][:,0,...][..., 1:].contiguous().bool()
        shift_response_mask = data["response_mask"][:,0,...][..., 1:].contiguous().bool()

        input_embeds = self.embed_tokens(input_ids)


        if self.wrapper_name == "vote":
            final_y_pred = self.wrapped_model(inputs_embeds = input_embeds[:,0,...],attention_mask = attention_mask[:,0,...],return_risk=False,tile_and_reduce=False)
        elif self.wrapper_name == "sculpt":
            y_pred,y_risk = self.wrapped_model(inputs_embeds = input_embeds[:,0,...],attention_mask = attention_mask[:,0,...],return_risk=True)
            final_y_pred = self.dist.sample((y_pred, y_risk))
        elif self.wrapper_name == "sample":
            final_y_pred = self.wrapped_model(inputs_embeds = input_embeds[:,0,...],attention_mask = attention_mask[:,0,...],return_risk=False)
        elif self.wrapper_name == "none":
            final_y_pred = self.model(inputs_embeds = input_embeds[:,0,...],attention_mask = attention_mask[:,0,...])

        shift_logits = final_y_pred[..., :-1, :].contiguous()

        loss = F.cross_entropy(shift_logits[shift_response_mask], one_hot_labels[shift_response_mask])
        
        self.log("val/loss", loss, prog_bar=True,on_epoch=True)

        return loss
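
The forward method above calls self.wrapped_model, which is created by the wrap() call used in the training script below. That method is not shown here; a minimal sketch of what it might look like, assuming the usual capsa pattern of applying the wrapper directly to a torch module, is:

    def wrap(self):
        # Hypothetical sketch: apply the capsa wrapper to the inner Llama
        # module to obtain a risk-aware model; fall back to the plain model
        # when no wrapper was configured.
        if self.wrapper is not None:
            self.wrapped_model = self.wrapper(self.model)
        else:
            self.wrapped_model = self.model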



Then we instantiate the model with wrapper-specific arguments:

import json
from lightning.pytorch.tuner import Tuner
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.strategies import SingleDeviceStrategy
from lightning.pytorch.callbacks import ModelCheckpoint,LearningRateMonitor
from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

from datasets import load_dataset
from utils import Dataset
from peft import LoraConfig, TaskType
from config import *

import wandb
import time
import os

import torch
import torch.nn.functional as F

from capsa_torch import sample,vote,sculpt
from capsa_transformers import RiskAwareLlama




seed = int(time.time())
print("Seed:",seed)
torch.manual_seed(seed)



dataset = load_dataset(DATASET_NAME,DATASET_SPLIT)

train_dataset = dataset["train"]
val_dataset = dataset["test"]
test_dataset = dataset["test"]


# Read your Hugging Face access token from the environment instead of hard-coding it
HF_TOKEN = os.environ["HF_TOKEN"]

if WRAPPER == 'sculpt':
    dist = sculpt.Normal
    wrapper = sculpt.Wrapper(symbolic_trace=SYMBOLIC_TRACE, n_layers=N_LAYERS, verbose=2, distribution=dist)
    lora_config = LoraConfig(r=R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT)
    model = RiskAwareLlama(wrapper=wrapper, model_name=MODEL_NAME, auth_token=HF_TOKEN,
                           lr=LEARNING_RATE, dist=dist, lora_config=lora_config)
    model.wrap()

elif WRAPPER == 'vote':
    wrapper = vote.Wrapper(param_filter=PARAM_FILTER, symbolic_trace=SYMBOLIC_TRACE, n_voters=N_VOTERS,
                           finetune=FINETUNE, verbose=2, weight_noise=WEIGHT_NOISE, alpha=ALPHA)
    model = RiskAwareLlama(wrapper=wrapper, model_name=MODEL_NAME, auth_token=HF_TOKEN, lr=LEARNING_RATE)
    model.wrap()

elif WRAPPER == 'sample':
    wrapper = sample.Wrapper(symbolic_trace=SYMBOLIC_TRACE, n_samples=N_SAMPLES, verbose=2)
    lora_config = LoraConfig(r=R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT)
    model = RiskAwareLlama(wrapper=wrapper, model_name=MODEL_NAME, auth_token=HF_TOKEN,
                           lr=LEARNING_RATE, lora_config=lora_config)
    model.wrap()

else:
    lora_config = LoraConfig(r=R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT)
    model = RiskAwareLlama(wrapper=None, model_name=MODEL_NAME, auth_token=HF_TOKEN,
                           lr=LEARNING_RATE, lora_config=lora_config)




train_dataloader = torch.utils.data.DataLoader(
    dataset=Dataset(train_dataset, tokenizer=model.tokenizer, max_length=MAX_LENGTH),
    batch_size=BATCH_SIZE, shuffle=True, num_workers=BATCH_SIZE, drop_last=True)
val_dataloader = torch.utils.data.DataLoader(
    dataset=Dataset(val_dataset, tokenizer=model.tokenizer, max_length=MAX_LENGTH),
    batch_size=BATCH_SIZE, shuffle=False, num_workers=BATCH_SIZE, drop_last=True)

cuda_device = torch.device("cuda:0")
cpu_device = torch.device("cpu")

model = model.to(cuda_device)

# One warm-up forward pass under torch.no_grad() lets the wrapped model build
# its internal structure from a real batch (only needed when a wrapper is used)
if WRAPPER != 'none':
    with torch.no_grad():

        batch = next(iter(train_dataloader))

        input_ids = batch["input_ids"].to(cuda_device)
        attention_mask = batch["attention_mask"].to(cuda_device)

        input_embeds = model.embed_tokens(input_ids)
        logits = model.wrapped_model(inputs_embeds= input_embeds[:,0,...],attention_mask = attention_mask[:,0,...],return_risk=False)

model = model.to(cpu_device)


if WRAPPER == 'vote':
    for name,param in model.named_parameters():
        if "be_bias" in name or "be_alpha" in name or "be_gamma" in name or "model.model.norm.weight" in name or "model.lm_head.weight" in name:
            param.requires_grad = True
            print(f"Grad enabled for {name}")
        else:
            param.requires_grad = False



# Callbacks
checkpoint_callback = ModelCheckpoint(
    dirpath=os.getenv("HOME") + "/repos/capsa-transformers/checkpoints/" + str(WRAPPER),
    save_top_k=1,
    monitor="val/loss",
    save_on_train_epoch_end=True,
    mode="min",
)
lr_monitor = LearningRateMonitor(logging_interval='epoch',log_weight_decay=True)
wandb_logger = WandbLogger(project="Tree-of-thoughts",log_model = False)
wandb_logger.experiment.config.update(config)

# Training
trainer = Trainer(
    callbacks=[checkpoint_callback, lr_monitor],
    accelerator="gpu",
    max_epochs=EPOCH,
    precision=PRECISION,
    enable_progress_bar=True,
    strategy=STRATEGY,
    logger=wandb_logger,
    gradient_clip_val=GRADIENT_CLIP_VAL,
    enable_checkpointing=True,
    val_check_interval=0.25,
    num_sanity_val_steps=0,
)
trainer.fit(model, train_dataloader, val_dataloader)  # add ckpt_path=os.getenv("HOME")+"/repos/capsa-transformers/checkpoints/" to resume

exit()

The configuration values are defined in the config.py file:

# General hyperparameters
DATASET_NAME = "mosaicml/dolly_hhrlhf"#"gsm8k"#"MU-NLPC/Calc-gsm8k"
DATASET_SPLIT = None #"main"
WRAPPER = "sculpt"
MODEL_NAME = "meta-llama/Meta-Llama-3-8B" #"meta-llama/Llama-2-7b-chat-hf" 
LEARNING_RATE = 2e-4
EPOCH = 5
MAX_LENGTH = 400
BATCH_SIZE = 5
SYMBOLIC_TRACE = False
PRECISION= "bf16-mixed"
GRADIENT_CLIP_VAL = 1.0
STRATEGY = "deepspeed_stage_2"

# LoRA specific hyperparameters
R = 64
LORA_ALPHA = 64
LORA_DROPOUT = 0

# Vote specific hyperparameters
ALPHA = 8
FINETUNE = False
N_VOTERS = 5
WEIGHT_NOISE = 0.2
PARAM_FILTER = "q_proj|v_proj"

# Sculpt specific hyperparameters
N_LAYERS = 2

# Sample specific hyperparameters
N_SAMPLES = 5

config={
    "wrapper": WRAPPER,
    "learning_rate": LEARNING_RATE,
    "model_name": MODEL_NAME,
    "dataset": DATASET_NAME,
    "epochs": EPOCH,
    "batch_size": BATCH_SIZE,
    "symbolic_trace": SYMBOLIC_TRACE,
    "precision": PRECISION,
    "gradient_clip_val": GRADIENT_CLIP_VAL
    }

# track hyperparameters and run metadata
if WRAPPER == 'sculpt':
    config_specific={
    "n_layers": N_LAYERS,
    "r": R,
    "lora_alpha": LORA_ALPHA,
    "lora_dropout": LORA_DROPOUT
    }
elif WRAPPER == 'vote':
    config_specific={
    "weight_noise": WEIGHT_NOISE,
    "finetune": FINETUNE,
    "n_voters": N_VOTERS,
    "gradient_clip_val": GRADIENT_CLIP_VAL,
    "param_filter": PARAM_FILTER
    }
elif WRAPPER == 'sample':
    config_specific={
    "n_samples": N_SAMPLES,
    "r": R,
    "lora_alpha": LORA_ALPHA,
    "lora_dropout": LORA_DROPOUT
    }
else:
    config_specific={}

config.update(config_specific)


Now the Llama model is wrapped and good to go.
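
Once fine-tuned, the wrapped model can be queried for a risk estimate alongside its logits. A minimal single-prompt sketch (illustrative only; it assumes a wrapper that, like sculpt above, returns a (logits, risk) pair when return_risk=True, and it omits the generation loop):

model = model.to(cuda_device).eval()

prompt = "Question: What is the capital of France?\nAnswer:"
inputs = model.tokenizer(prompt, return_tensors="pt").to(cuda_device)

with torch.no_grad():
    input_embeds = model.embed_tokens(inputs["input_ids"])
    logits, risk = model.wrapped_model(inputs_embeds=input_embeds,
                                       attention_mask=inputs["attention_mask"],
                                       return_risk=True)

next_ids = logits.argmax(dim=-1)                                  # greedy token choices
token_risk = risk.gather(-1, next_ids.unsqueeze(-1)).squeeze(-1)  # risk of those tokens
print(model.tokenizer.decode(next_ids[0]))
print(token_risk[0])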

Results

We presented our work Uncertainty-aware Language Modeling for Selective Question Answering at a 2024 AAAI workshop. The following results are on the TruthfulQA question-answering benchmark and use Llama 2-Chat 7B.

In general, we show that higher logit probabilities do not correspond to better question-answering ability, even though they are often misread as a measure of confidence. Our methods report a reliable measure of confidence: increased confidence corresponds to increased accuracy.

[Figure ../_images/llm.png: accuracy by confidence percentile for baseline logit probability vs. uncertainty-aware models]

As the figure shows, using the baseline logit probability as a confidence measure yields a maximum accuracy increase of only 4.7%, for questions in the 4th lowest percentile. In comparison, our uncertainty-aware models attain accuracy rates of 100%, over 90%, over 80%, and over 70% when answering 1%, 8%, 9%, and 35% of the questions, respectively, and achieve higher accuracy across all confidence percentiles. We further observe that, after generating 10 candidate answers for each question in the benchmark, the answers with the highest uncertainty were consistently incorrect. We also note that the converted model is able to output correct answers if it repeatedly generates predictions until one in the 99th confidence percentile is found.
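
The selective-answering numbers above follow the usual recipe: rank questions by the model's confidence (for the wrapped models, one minus the predicted risk; for the baseline, the answer's logit probability), answer only those above a percentile threshold, and measure accuracy on the answered subset. A small illustrative sketch, not the paper's evaluation code, where confidence and correct are assumed per-question arrays:

import numpy as np

def selective_accuracy(confidence, correct, percentiles=(50, 75, 90, 99)):
    confidence = np.asarray(confidence)
    correct = np.asarray(correct, dtype=bool)
    results = {}
    for p in percentiles:
        threshold = np.percentile(confidence, p)
        answered = confidence >= threshold        # answer only the most confident questions
        results[p] = (answered.mean(),            # coverage: fraction of questions answered
                      correct[answered].mean())   # accuracy on the answered subset
    return results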

Demo

Here we present a demo built with Gradio that features real-time chat inference. The system outputs the answer, as you would expect from any chatbot, and also provides a risk score for each token.

[Figure ../_images/chat.png: chat interface showing answers together with per-token risk scores]

Link coming soon to try it yourself!
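
Until the public link is available, the sketch below gives an idea of how such an interface can be put together with Gradio. It is illustrative only: generate_with_risk is a hypothetical stand-in for a call to the wrapped model that decodes the answer together with per-token risk scores.

import gradio as gr

def generate_with_risk(question: str):
    # Hypothetical stand-in: the real demo would call the wrapped Llama model
    # (e.g. with return_risk=True) and decode tokens with their risk scores.
    tokens = ["Paris", " is", " the", " capital", " of", " France", "."]
    risks = [0.04, 0.01, 0.01, 0.02, 0.01, 0.03, 0.01]
    return tokens, risks

def chat(question):
    tokens, risks = generate_with_risk(question)
    answer = "".join(tokens)
    # HighlightedText accepts (text, score) pairs, so risk can be shown per token
    return answer, list(zip(tokens, risks))

with gr.Blocks() as demo:
    question = gr.Textbox(label="Question")
    answer = gr.Textbox(label="Answer")
    risk_view = gr.HighlightedText(label="Per-token risk")
    question.submit(chat, inputs=question, outputs=[answer, risk_view])

demo.launch()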

Conclusion

  • Our wrapping process modifies the LLM to output uncertainty estimates along with predictions, without needing external models or major architecture changes.

  • On question answering benchmarks, the uncertainty-aware models can achieve very high accuracy (>90%) by answering only the most confident subset of questions.

  • Using the uncertainty estimates performs much better than relying solely on the raw model probabilities, which often turn out to be unreliable confidence scores.