Finetuning the Atlas Model on a Custom Dataset for MCQ Answering

This project took the pretrained Atlas model published by Meta and finetuned it on a custom dataset of multiple-choice questions (MCQs) written from three chapters of a statistical learning textbook. The dataset was used to finetune models of different sizes, with hyperparameter tuning for each.

Overview of Steps Taken

  1. Prepared a dataset of 400 questions from Chapters 2, 3, and 4 of James et al., "An Introduction to Statistical Learning with Applications in Python"
  2. Split the dataset into train, validation, and test sets (75% / 10% / 15%)
  3. Split the train set further into subsets of different sample sizes
  4. Finetuned the base and large models on each sample size
  5. Experimentally tuned the hyperparameters total_steps, temperature, text_maxlength, dropout, n_context, and learning rate
  6. Best accuracy: 85.37%, with the large model
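Steps 2 and 3 can be sketched in Python as below. This is a minimal sketch, not the project's own script, and it rests on several assumptions: the questions are already loaded as a list of dicts, a fixed random seed makes the split reproducible, the smaller training subsets are nested prefixes of one shuffle, and each subset directory gets a copy of the validation split as `eval_<n>.jsonl` (the `<n>/train_<n>.jsonl` naming mirrors the paths in the training script). The subset sizes, taken from the x-axis of the results figure, and all function names are hypothetical.

```python
import json
import random
from pathlib import Path

# Assumed subset sizes, taken from the x-axis of the results figure.
SAMPLE_SIZES = [15, 30, 45, 75, 105, 150, 180, 225, 300]


def split_dataset(records, seed=0, train_frac=0.75, valid_frac=0.10):
    """Shuffle once with a fixed seed, then cut into train/validation/test (75/10/15 by default)."""
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    n_train = round(len(shuffled) * train_frac)
    n_valid = round(len(shuffled) * valid_frac)
    train = shuffled[:n_train]
    valid = shuffled[n_train:n_train + n_valid]
    test = shuffled[n_train + n_valid:]  # the remainder (15% by default)
    return train, valid, test


def nested_subsets(train, sizes=SAMPLE_SIZES):
    """Prefixes of the shuffled train split, so each smaller subset is contained in the larger ones."""
    return {n: train[:n] for n in sizes if n <= len(train)}


def write_jsonl(path, records):
    """Write one JSON object per line, creating parent directories as needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def write_splits(out_dir, train, valid, test, sizes=SAMPLE_SIZES):
    """Assumed layout: <out_dir>/<n>/train_<n>.jsonl and eval_<n>.jsonl, plus <out_dir>/test.jsonl."""
    out_dir = Path(out_dir)
    for n, subset in nested_subsets(train, sizes).items():
        write_jsonl(out_dir / str(n) / f"train_{n}.jsonl", subset)
        write_jsonl(out_dir / str(n) / f"eval_{n}.jsonl", valid)  # same validation set for every size
    write_jsonl(out_dir / "test.jsonl", test)


if __name__ == "__main__":
    # Hypothetical stand-in for the 400 prepared questions.
    questions = [{"id": str(i)} for i in range(400)]
    train, valid, test = split_dataset(questions)
    print(len(train), len(valid), len(test))  # 300 40 60
```

Calling `write_splits("sub_dataset", train, valid, test)` would then produce `sub_dataset/15/train_15.jsonl`, `sub_dataset/15/eval_15.jsonl`, and so on. Taking prefixes of a single shuffle, rather than sampling each size independently, means every larger subset contains the smaller ones, so differences between sample sizes are not confounded by which questions happened to be drawn.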

Dataset Preparation

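Using the textbook as a reference, we created a dataset split across two files: one holding the questions, options, and answers, and another holding the source passage for each question, linked by id. The first example below is from the QnA file and the second from the passages file.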

[
  {
    "id": "35",
    "question": "What is the term used to refer to the two unknown constants that represent the slope and intercept in the mathematical representation of the linear model?",
    "options": {
      "A": "coefficients",
      "B": "variables",
      "C": "weights",
      "D": "bounds"
    },
    "answer": "A",
    "contributed_by": "group 5"
  }
]
[
  {
    "id": "35",
    "title": "Simple Linear Regression",
    "section": "3.1",
    "text": "β0 and β1 are two unknown constants that represent the intercept and slope terms in the linear model. Together, β0 and β1 are known as the model coefficients or parameters."
  }
]
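The two files are shown above as JSON arrays, while the training script reads `.jsonl` files (one JSON object per line). A small check before conversion can catch inconsistencies such as a stray `"contributed by"` key next to `"contributed_by"`, or a numeric id `35` in one file and the string `"35"` in the other. The sketch below is an illustrative assumption, not the project's tooling: the expected key set, the four option letters A to D, and the function names are hypothetical.

```python
import json

# Assumed schema for a question record; "contributed_by" is optional metadata.
REQUIRED_KEYS = {"id", "question", "options", "answer"}
OPTIONAL_KEYS = {"contributed_by"}
OPTION_LETTERS = ["A", "B", "C", "D"]


def validate(questions, passages):
    """Return a list of human-readable problems; an empty list means the files are consistent."""
    problems = []
    passage_ids = {str(p["id"]) for p in passages}
    for q in questions:
        qid = str(q.get("id"))  # compare ids as strings so 35 and "35" match
        keys = set(q)
        missing = REQUIRED_KEYS - keys
        unexpected = keys - REQUIRED_KEYS - OPTIONAL_KEYS
        if missing:
            problems.append(f"{qid}: missing keys {sorted(missing)}")
        if unexpected:
            problems.append(f"{qid}: unexpected keys {sorted(unexpected)}")
        options = q.get("options", {})
        if sorted(options) != OPTION_LETTERS:
            problems.append(f"{qid}: options are {sorted(options)}, expected {OPTION_LETTERS}")
        if q.get("answer") not in options:
            problems.append(f"{qid}: answer {q.get('answer')!r} is not one of the option letters")
        if qid not in passage_ids:
            problems.append(f"{qid}: no passage with this id")
    return problems


def to_jsonl(records):
    """Serialize records as JSON Lines: one object per line."""
    return "".join(json.dumps(r, ensure_ascii=False) + "\n" for r in records)


if __name__ == "__main__":
    questions = json.loads("""[{"id": 35, "contributed by": "group 5", "question": "...",
        "options": {"A": "coefficients", "B": "variables", "C": "weights", "D": "bounds"},
        "answer": "A", "contributed_by": "group 5"}]""")
    passages = json.loads('[{"id": "35", "title": "Simple Linear Regression", "section": "3.1", "text": "..."}]')
    for problem in validate(questions, passages):
        print(problem)  # 35: unexpected keys ['contributed by']
```

Once `validate` returns an empty list, `to_jsonl` output can be written to the `.jsonl` files passed to `--train_data`, `--eval_data`, and `--passages`.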

Atlas Model Training

The Atlas repository was cloned into the HPC environment, and the model was finetuned on the prepared dataset:

git clone https://github.com/facebookresearch/atlas.git

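With the datasets prepared, the model was finetuned while varying the following hyperparameters (the corresponding train.py flags are in parentheses):

1. Total Steps (--total_steps)

2. Temperature (--temperature_gold, --temperature_score)

3. Text MaxLength (--text_maxlength)

4. Dropout (--dropout)

5. N-Context (--n_context)

6. Learning Rate (--lr)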

Different model sizes were also tested: the base and large Atlas models (the script below uses the base model, with a google/t5-base-lm-adapt reader), each trained on sample sizes ranging from 15 up to 300, the size of the full training split.

#!/bin/bash

# SLURM job settings
#SBATCH --mail-user=abraham.mathew@sjsu.edu
#SBATCH --mail-type=BEGIN,END,FAIL
#SBATCH --job-name=gpuTest_016018990
#SBATCH --output=gpuTest_%j.out
#SBATCH --error=gpuTest_%j.err
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=1
#SBATCH --time=48:00:00
##SBATCH --mem-per-cpu=2000
# Request one GPU; on most Slurm clusters no GPU is allocated without a --gres request
#SBATCH --gres=gpu:1
#SBATCH --partition=gpu

# Load Python and Slurm modules
module load python-3.10.8-gcc-11.2.0-c5b5yhp slurm

# Set up proxy
export http_proxy=http://172.16.1.2:3128
export https_proxy=http://172.16.1.2:3128

# Change to project directory
cd /home/016018990/CMPE259/project/mcqs/atlas

# Activate Python virtual environment
source /home/016018990/CMPE259/project/scripts/atlas-2/bin/activate

# Set up directories and files
DATA_DIR=/scratch/cmpe259-fa23/016018990/sub_dataset
EXPERIMENTS_DIR=/home/016018990/CMPE259/project/mcqs/atlas/experiments
TRAIN_FILE="${DATA_DIR}/15/train_15.jsonl"
EVAL_FILE="${DATA_DIR}/15/eval_15.jsonl"
PASSAGES_FILE="${DATA_DIR}/unified_passages.jsonl"
SAVE_DIR=${DATA_DIR}/output/
EXPERIMENT_NAME="experiment-train-15"

# Set up training parameters
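# Checkpoint every TRAIN_STEPS steps (passed to --save_freq); training length is set by --total_steps
TRAIN_STEPS=30
# bf16 needs a GPU with native bfloat16 support (Ampere or newer); use fp32 on older GPUs such as the P100
PRECISION="bf16"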
port=$(shuf -i 15000-16000 -n 1)

# Run training script
python3 /home/016018990/CMPE259/project/mcqs/atlas/train.py \
  --shuffle \
  --train_retriever \
  --gold_score_mode ppmean \
  --use_gradient_checkpoint_reader \
  --use_gradient_checkpoint_retriever \
  --precision ${PRECISION} \
  --shard_optim \
  --shard_grads \
  --temperature_gold 0.1 \
  --temperature_score 0.1 \
  --refresh_index -1 \
  --target_maxlength 16 \
  --reader_model_type google/t5-base-lm-adapt \
  --dropout 0.1 \
  --lr 5e-5 \
  --lr_retriever 1e-5 \
  --scheduler linear \
  --weight_decay 0.01 \
  --text_maxlength 512 \
  --model_path "/home/016018990/atlas/atlas_data/models/atlas/base/" \
  --train_data ${TRAIN_FILE} \
  --eval_data ${EVAL_FILE} \
  --per_gpu_batch_size 1 \
  --n_context 30 \
  --retriever_n_context 30 \
  --name ${EXPERIMENT_NAME} \
  --checkpoint_dir ${SAVE_DIR} \
  --eval_freq 30 \
  --log_freq 4 \
  --total_steps 2000 \
  --warmup_steps 50 \
  --save_freq ${TRAIN_STEPS} \
  --main_port $port \
  --write_results \
  --task multiple_choice \
  --multiple_choice_train_permutations all \
  --multiple_choice_eval_permutations cyclic \
  --index_mode flat \
  --passages "${PASSAGES_FILE}" \
  --query_side_retriever_training \
  --save_index_path ${SAVE_DIR}/${EXPERIMENT_NAME}/saved_index \
  --save_index_n_shards 1
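To run the hyperparameter and sample-size experiments listed above without hand-editing the script for each run, the per-run flags can be generated programmatically. The sketch below is an assumption about how such a sweep could be organized, not the project's code: the grid values come from the results table where possible, the two sample sizes are an illustrative subset, and the function and experiment names are hypothetical. It relies on `argparse` behaviour (Atlas parses its options with `argparse`): when a plain option is given twice, the last value wins, so generated flags appended to the end of the fixed `train.py` command override the hard-coded ones.

```python
from itertools import product

# Illustrative grid: values from the results table (initial and best settings).
GRID = {
    "total_steps": [300, 600],
    "n_context": [30, 50],
    "dropout": [0.1],
    "lr": [1e-5],
}
SAMPLE_SIZES = [15, 300]  # assumed subset of the sizes that were tried


def experiment_configs(grid, sample_sizes, data_dir):
    """Yield (experiment_name, extra_args) for every sample size and grid combination."""
    keys = sorted(grid)
    for n, values in product(sample_sizes, product(*(grid[k] for k in keys))):
        params = dict(zip(keys, values))
        # Keep retriever and reader context sizes equal, as the script does (30 and 30).
        params["retriever_n_context"] = params["n_context"]
        name = f"experiment-train-{n}-" + "-".join(f"{k}{params[k]}" for k in keys)
        args = [
            "--name", name,
            "--train_data", f"{data_dir}/{n}/train_{n}.jsonl",
            "--eval_data", f"{data_dir}/{n}/eval_{n}.jsonl",
        ]
        for key, value in sorted(params.items()):
            args += [f"--{key}", str(value)]
        yield name, args


if __name__ == "__main__":
    data_dir = "/scratch/cmpe259-fa23/016018990/sub_dataset"
    for name, args in experiment_configs(GRID, SAMPLE_SIZES, data_dir):
        print(name, " ".join(args))
```

Each `args` list could be appended to the `python3 train.py` command in a per-run copy of the SLURM script. Note that the script's `--save_index_path` still embeds the fixed `EXPERIMENT_NAME`, so it would need to be set per run as well to keep runs from overwriting each other's saved index.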

Results Discussion

[Figure: accuracy vs. training sample size]
Parameter                   Initial   Best
Total steps                 300       600
Temperature (gold/score)    0.1       0.1
Text max length             512       512
N context                   30        50
Learning rate               1e-5      1e-5
Dropout                     0.1       0.1

After finetuning the large model on 300 samples with the best parameters above, the highest accuracy we obtained was 85.37%.

Key Insights

Context Matters: Providing more retrieved passages to the reader (a higher n_context, 50 in the best configuration versus 30 initially) significantly boosted accuracy, highlighting the importance of context for performance.

Hyperparameter Tweaking: Lower dropout and learning rates consistently gave better results across model sizes. Moderate temperature values worked best, while extreme values reduced accuracy. The best text length depended on the dataset size and model. Increasing total steps (more training iterations, and therefore more compute time) had inconsistent effects across experiments.

Higher Is Not Always Better: Raising parameters such as dropout rate and learning rate beyond certain thresholds did not always improve performance, so finding the right balance for each parameter was crucial.

Automation for Efficiency and Reliability: Automating steps such as data splitting kept the process consistent across experiments and reduced manual errors, making the results more reliable.

Streamlined Evaluation: An automated post-processing step extracted evaluation metrics from the output files, which streamlined evaluation and made it quick to compare runs.
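Such a post-processing step could look roughly like the sketch below, which computes accuracy from a JSON Lines file of predictions and ranks experiments by it. This is an illustration under explicit assumptions, not the project's script: the field names `answer` and `generation` and the one-directory-per-experiment layout are hypothetical placeholders, to be adjusted to whatever the files written with `--write_results` actually contain.

```python
import json
from pathlib import Path

# Hypothetical field names; adjust to the keys present in the actual result files.
GOLD_KEY = "answer"
PRED_KEY = "generation"


def accuracy_from_jsonl(lines, gold_key=GOLD_KEY, pred_key=PRED_KEY):
    """Fraction of records whose predicted option letter matches the gold letter (NaN if empty)."""
    total = correct = 0
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        record = json.loads(line)
        total += 1
        gold = str(record[gold_key]).strip().upper()
        pred = str(record[pred_key]).strip().upper()
        correct += gold == pred
    return correct / total if total else float("nan")


def summarize(result_files):
    """Map experiment name (the file's parent directory, an assumed layout) to accuracy, best first."""
    scores = {}
    for path in map(Path, result_files):
        with path.open(encoding="utf-8") as f:
            scores[path.parent.name] = accuracy_from_jsonl(f)
    return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))


if __name__ == "__main__":
    sample = ['{"answer": "A", "generation": "A"}', '{"answer": "B", "generation": "C"}']
    print(accuracy_from_jsonl(sample))  # 0.5
```

Normalizing case and whitespace before comparing keeps a prediction like `" a"` from being scored as wrong purely because of formatting.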