Meta recently released Llama 3.1, including its 405-billion-parameter model, the most capable open model to date and the first open model on the same level as GPT-4. We took this opportunity to create a synthetic dataset for classifying political spam, which we generated with Llama 3.1 405B. Now we are going to train Llama 3.1 8B on that dataset to classify political spam accurately.
This is the third post of a four-part blog series. In the last post, we deduplicated the synthetic data generated in the first post and checked its diversity. Now that we have the data, the next step is fine-tuning a model on it, which means training it to predict whether a text message is spam.
The Series:
- Create Your Own Synthetic Data With Only 5 Political Spam Texts
- How to De-duplicate and Clean Synthetic Data
- This blog post!
- Evaluating the trained model
Download the Code!
Just like in the previous blog posts, the code and data are available in the Oxen repo:
Picking a Model
With the recent release of the Llama 3.1 models, we went with Llama 3.1 8B. However, smaller models might work just as well; testing that could be a fun side project for the reader. :)
Picking a Fine-tuning Method
To fine-tune the model, we chose to use ReFT for its efficiency. We demonstrated ReFT in a previous blog, and we also had one of the paper's authors on our Arxiv Dive on the topic, so check those out if you're interested.
ReFT Hyperparameters
After trying several settings, the hyperparameter configuration we settled on was this:
- Positions to intervene on: first 8 and last 8 tokens in the sequence
- Layers to intervene on: 8, 16, 20, 24, and 31
- Low rank dimension: 4
- Weight sharing enabled between interventions on the same layer
If you don't know what that means, I'd recommend watching the Arxiv Dive on the topic, which goes through the paper and explains how it works.
Loading the Model with ReFT
We then added the ReFT interventions to set the model up for fine-tuning.
# Imports
import torch
import pyreft
from pyreft import ReftConfig, LoreftIntervention
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base model and tokenizer (weights on CPU for now)
model_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map='cpu',
    attn_implementation="flash_attention_2"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# ReFT hyperparameters
share_weights = True  # share weights between the first-8 and last-8 token interventions on a layer
positions = "f8+l8"   # intervene on the first 8 and last 8 tokens of each sequence

# One rank-4 LoReFT intervention on the residual stream (block output) of each chosen layer
reft_config = ReftConfig(
    representations=[{
        "layer": l,
        "component": "block_output",
        "low_rank_dimension": 4,
        "intervention": LoreftIntervention(
            embed_dim=model.config.hidden_size, low_rank_dimension=4
        )
    } for l in [8, 16, 20, 24, 31]]
)

# Wrap the base model with the fine-tuning interventions
reft_model = pyreft.get_reft_model(model, reft_config, set_device=False)
reft_model.print_trainable_parameters()
# trainable parameters: 163,860
# trainable %: 0.002%
Loading the Data
In the last two blog posts, we generated and filtered synthetic data for detecting political spam, which lives in this Oxen repo. Now we load the data, apply a chat template to it, and prepare it for fine-tuning.
import pandas as pd

# Use the same seed when evaluating so the train/eval split is reproducible
random_seed = 256
texts = pd.read_parquet('filtered_texts.parquet')
eval_texts = texts.sample(n=50, random_state=random_seed)
train_texts = texts.drop(eval_texts.index)

# Wrap each message in the chat template the model expects
def create_prompt(text):
    return tokenizer.apply_chat_template([
        {"role": "system", "content": "Detect whether the following text is spam or not."},
        {"role": "user", "content": text}
    ], add_generation_prompt=True, tokenize=False)

train_prompts = [create_prompt(message) for message in train_texts['message'].tolist()]
train_responses = ["Spam" if is_spam else "Ham" for is_spam in train_texts['spam'].tolist()]

# Build the pyreft data module (train dataset + data collator) using the
# intervention positions and weight sharing defined above
train_data_module = pyreft.make_multiple_position_supervised_data_module(
    tokenizer, model, train_prompts, train_responses,
    positions=positions, num_interventions=len(reft_config.representations), share_weights=share_weights
)
Fine-tuning the Model
After setting up ReFT and loading the data, we fine-tuned the model for 8 epochs with a learning rate of 4e-3. The training itself took only 11 minutes and 15 seconds!
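For reference, here is a minimal sketch of that training step using pyreft's ReftTrainerForCausalLM, following the library's standard recipe. The 8 epochs and 4e-3 learning rate match the run described above; the batch size, output directory, logging settings, and moving the model to a CUDA GPU are assumptions on our part.

import transformers
import pyreft

# Assumption: the base model was loaded on CPU above, so move the model and
# interventions to the GPU before training.
reft_model.set_device("cuda")

training_args = transformers.TrainingArguments(
    num_train_epochs=8,
    learning_rate=4e-3,
    per_device_train_batch_size=8,   # assumed batch size
    output_dir="./reft_spam_classifier",  # assumed output directory
    logging_steps=20,
    report_to=[],
)

trainer = pyreft.ReftTrainerForCausalLM(
    model=reft_model,
    tokenizer=tokenizer,
    args=training_args,
    **train_data_module,  # train dataset + data collator built above
)
trainer.train()

# Optionally save only the trained interventions (the directory name is arbitrary);
# the artifact is tiny compared to the 8B base model since only ~164K parameters were trained.
reft_model.save(save_directory="./trained_intervention")

Because ReFT only trains the low-rank interventions, the saved weights are a fraction of a megabyte rather than gigabytes, which is a big part of why this fine-tune is so fast and cheap.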
Evaluating the Model
The natural next step after fine-tuning a model is to test its performance. If you want to see how this model performed, stay tuned for our next blog post, where we evaluate how well it can tell spam apart from regular messages!
Why Oxen?
Oxen.ai makes building, iterating on, and collaborating on machine learning datasets easy.
At its core, Oxen is a lightning-fast data version control tool optimized for large unstructured datasets. On top of that are features that make working with data easier, such as data diffs, natural language queries for tabular files, workspaces, rendering images in tables, and more. We're constantly pushing out new features to make things easier for you. Oh yeah, and it's open source.
If you would like to learn more, star us on GitHub, then head to Oxen.ai and create an account.