An Opinionated Primer on Fine-Tuning

Databricks Week 18

I'll admit that when I first heard about 'small language models', I thought it was a ridiculous fad. After all, we have a wonderful toolbox of transformer-based models (and some non-transformer models like Mamba) for solving natural language problems. You can even put a classifier or regressor on top of these models. However, after developing my own models, optimizing costs for customers, and exploring agent frameworks, I now concede that continuously inflating model size isn't always the best solution. In agent frameworks, for example, we can pair specific smaller models with a larger orchestration model to solve complex problems effectively. But how do we obtain these task-specific smaller models? After some time in the trenches, I've concluded that fine-tuning is likely the best solution for this particular challenge. This post covers fine-tuning of small language models (defined as 7B parameters or less) for specific tasks. I'll start by explaining why you might want to do this and what you can do with a fine-tuned model. Then, I'll dive into the technical details of the process. Finally, I'll share tips on evaluating whether your fine-tuning was successful.

Let's Go!

Why Would You Want To Fine-Tune A Language Model?

Language models get big. Like really big. If you want a rough estimate of how much GPU RAM you need for inference, take the number of parameters and multiply by 4. For a 405B Llama 3.1 model, that's 1.6 TB of GPU RAM. This is for a full-precision model that uses 32-bit floating point numbers. You can quantize these models by reducing the precision of the numbers used in the model, but that's still a lot of GPU RAM. For reference, the best A100 GPUs have 80 GB of RAM... so you'll need to reserve 11 of them just to do inference in 16-bit precision.
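If you want to sanity check that arithmetic yourself, a back-of-the-envelope sketch (weights only - it ignores activations and the KV cache) looks something like this:

# Rough weights-only memory estimate - ignores activations and the KV cache
def inference_memory_gb(n_params_billion: float, bytes_per_param: float) -> float:
    # billions of parameters * bytes per parameter = gigabytes
    return n_params_billion * bytes_per_param

for precision, nbytes in [("fp32", 4), ("fp16/bf16", 2), ("int8", 1)]:
    print(f"Llama 3.1 405B @ {precision}: ~{inference_memory_gb(405, nbytes):,.0f} GB")
# fp32 ~1,620 GB, fp16 ~810 GB -> about 11 A100-80GB cards just for the weights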

It gets worse. If you want to fine-tune a model, budget 5-10 times the memory you need for inference, depending on how you optimize, shard, etc. So it isn't a huge surprise that smaller models are gaining favour, especially if they can provide sufficient performance. The problem with smaller models is that they are less capable of complex reasoning. But with fine-tuning, you can control how each one of these models performs for a specific task while being efficient with your compute spend.

Keeping your models smaller isn't the only reason to fine-tune. OpenAI put out a great video on fine-tuning vs. prompt engineering. I've reproduced the figure they used below. Let's talk about few-shot prompt engineering and retrieval augmented generation. You might start with some prompt engineering, followed by some examples (few-shot prompting). At some point, feeding the same examples over and over to a model starts to get boring. So, you might start to think about dynamically generating these examples - and suddenly you are doing retrieval augmented generation.

Because you pay per token for most models, all of this example-heavy prompting starts to get expensive and increases your latency. You start to think that the whole process of shouting at your model to "FOR PETE'S SAKE ONLY RETURN A LIST AND I'LL GIVE YOU 50 DOLLARBUCKS" seems a little crass. Congratulations, you've now graduated to trying out fine-tuning.

Conceptual illustration of prompt engineering and context (what the model knows) vs. model capacity (how the model behaves). Good solutions stay away from the axes.

So How Do You Fine-Tune A Language Model?

Let's get into the meat of this hamburger. I am going to fine-tune a Llama 3.2 3B model to do a structured prediction task. Llama 3.2 is small - it only needs 5ish GB of GPU RAM for inference, which makes it a good candidate for fine-tuning. For my environment, I am going to use a single driver node with a V100 GPU (16 GB of RAM). It is worth noting that most LLM frameworks should be run on a single node. These nodes can have multiple GPUs, but getting PyTorch and other frameworks to access worker GPUs is quite complex - I wouldn't recommend it for now. We are going to stick with the Hugging Face ecosystem. Let's break down our process into a few steps:

1. Setup a tokenizer

2. Load a quantized base model

3. Setup a parameter efficient fine tuning framework

4. Setup a supervised trainer

5. Train the model

6. Evaluate the model

There is a lot of jargon coming up, but I'll do my best to break it down for you.

Tokenization

Language models are just integer machines. Every model requires a tokenizer, which converts text into a sequence of integers. These integers are then fed into the model. Each class of model (e.g. GPT, Llama, etc.) has its own tokenizer, and understanding them is important. Let's load up a tokenizer from a custom model in Unity Catalog. This chunk takes a marketplace model (stored in the system.ai schema) and downloads it to the driver disk.

import mlflow
from mlflow.artifacts import download_artifacts
mlflow.set_registry_uri("databricks-uc")
artifact_path = download_artifacts(
    artifact_uri="models:/system.ai.llama_v3_2_3b_instruct/2", 
    dst_path= "/local_disk0/llama_v3_2_3b_instruct/"
  )        

With that, we can now load the tokenizer onto the driver. Two important things to note in the chunk below. First, the max length is important for managing memory. You can quickly burn through your GPU RAM if you hit the maximum context limit of modern models (e.g. 128,000 tokens), so it is worth keeping a reasonable length limit when working with constrained compute. Second, padding is essential and should be set to the left. I found this slightly counterintuitive until I thought about the generation process. If we pass [23, 340, 2402, 1, 0, 0, 0, 0] (right padding) to the model, it starts generating after a whole bunch of padding tokens (shown here as 0s). If we pass [0, 0, 0, 0, 23, 340, 2402, 1] (left padding) to the model, it starts generating after the last non-padding token, which is what we want. So keep left, my friends.

import os
from transformers import AutoTokenizer
tokenizer_path = os.path.join(artifact_path, "components", "tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
  tokenizer_path, 
  padding_side='left',
  model_max_length = 4096,
  add_eos_token=True
)        
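To see the left padding in action, here is a quick sketch with a couple of made-up work order snippets. Note that Llama tokenizers often ship without a default pad token, so you may need to map it to the EOS token first:

# Quick check that padding lands on the left (example strings are made up).
# Llama tokenizers often ship without a pad token, so map it to EOS if needed.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(
    ["fix the pump", "replace the LEL detector on the engine outlet"],
    padding=True,
    return_tensors="pt",
)
print(batch["input_ids"][0])  # pad ids appear at the start (left) of the shorter sequence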

One thing I struggled to understand was how to find out what special tokens a model uses. For open source models, there are two ways to do this: look at the tokenizer config file on Hugging Face, or pull it off the loaded tokenizer using tokenizer.all_special_tokens. Here is an excerpt of Llama's special tokens compared to the Qwen 2.5 tokenizer.

Llama:
'<|begin_of_text|>': 128000
'<|end_of_text|>': 128001
'<|reserved_special_token_0|>': 128002
'<|reserved_special_token_1|>': 128003
'<|finetune_right_pad_id|>': 128004

Qwen 2.5:
'<|endoftext|>': 151643
'<|im_start|>': 151644
'<|im_end|>': 151645
'<|object_ref_start|>': 151646
'<|object_ref_end|>': 151647
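You can pull the same information off the tokenizer we just loaded rather than digging through config files on Hugging Face. A quick sketch (the exact tokens you see will depend on the tokenizer version):

# Inspect the special tokens registered on the loaded tokenizer
print(tokenizer.all_special_tokens)  # e.g. ['<|begin_of_text|>', '<|eot_id|>', ...]
print(tokenizer.all_special_ids)

# The full added-token mapping (including reserved tokens) is also available
for token, token_id in sorted(tokenizer.get_added_vocab().items(), key=lambda kv: kv[1])[:5]:
    print(token, token_id)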

I hope you see why tokenization is important when building custom frameworks! The next thing that confused me was what format we actually need to pass to the model. Under the hood, language models are still just next-word prediction machines. It wasn't clear to me how all these chat messages we give to OpenAI actually get parsed into text (or into fine-tuning datasets, for that matter). It is simpler than you'd expect. If you provide chat messages, the tokenizer applies a chat template. This chunk takes a work order and list of activities via chat messages and converts them into a format that the model can understand for supervised fine-tuning or inference. One note here - we pad and truncate the chat to respect the maximum length. But once we have that, we can move forward with generation and fine-tuning.

messages = [
    {"role": "user", "content": 
        """
        Use the following work order and create a comma separated list of activities.
        Equipment: X-203-101
        Short Description: Fault alarm.
        Long Description: LEL detector on the engine outlet consistently faulting.
        """
    },
    {"role": "assistant", "content": "['troubleshoot']"},
 ]
tokenized_chat = tokenizer.apply_chat_template(
    messages, 
    tokenize=True, 
    add_generation_prompt=False, 
    return_tensors="pt", 
    padding=True, 
    truncation=True,
    )
print(tokenizer.decode(tokenized_chat[0]))

>>> 

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Nov 2024

<|eot_id|><|start_header_id|>user<|end_header_id|>

Use the following work order and create a comma separated list of activities.
Equipment: X-203-101
Short Description: Fault alarm.
Long Description: LEL detector on the engine outlet consistently faulting.
Activity:

<|eot_id|><|start_header_id|>assistant<|end_header_id|>
['troubleshoot']
<|eot_id|>        

Read more about chat templates here: https://huggingface.co/blog/chat-templates and make sure you are using the right tokenizer and chat template for your model!
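One practical detail worth a quick sketch: when you want the model to generate (instead of building a supervised example like the one above), drop the assistant turn and set add_generation_prompt=True so the template ends with an open assistant header that the model fills in:

# For inference, keep only the user turn and ask for the generation prompt,
# so the decoded text ends with an open assistant header for the model to fill
prompt_ids = tokenizer.apply_chat_template(
    [messages[0]],
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
)
print(tokenizer.decode(prompt_ids[0]))
# ...ends with <|start_header_id|>assistant<|end_header_id|>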

Load a Base Model

The next component is our base model. Since we are using a V100 GPU, even a 3B parameter model might be too large for fine-tuning. Therefore, we will use a quantized model instead. Let's understand quantization - models can operate at different precision levels, ranging from 4-bit to 32-bit. Precision refers to the number of bits used to represent weights and computations. Maarten Grootendorst provides an excellent visual guide explaining precision differences, and Merve Noyan offers a great introduction to quantization.

To load a quantized model, we can use the bitsandbytes library, which integrates with Hugging Face transformers. We typically default to 8-bit or 4-bit quantization because most modern models are trained using BF16 precision - a specialized format designed for deep learning by Google Brain (BF16 stands for Brain Floating Point 16).

import bitsandbytes as bnb
from transformers import AutoModelForCausalLM
from transformers import BitsAndBytesConfig
import torch

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

model_path = os.path.join(artifact_path, "model")
base_model = AutoModelForCausalLM.from_pretrained(
  model_path, 
  quantization_config=bnb_config
  )        

Setup a Parameter Efficient Fine Tuning Framework

Fine-tuning large language models by modifying all their weights is impractical and can lead to problems like catastrophic forgetting. To address these challenges, parameter efficient fine-tuning (PEFT) frameworks like LoRA have been developed. For excellent introductions to these techniques, see Sebastian Raschka's blog on LoRA and Aman Chadha's primer on PEFT.

PEFT frameworks typically focus on the linear layers of the transformer architecture for several reasons:

1. They contain the majority of the model's parameters

2. They control key transformations and model behaviour

3. They allow us to avoid modifying the attention mechanism

4. They are well-suited for low-rank adaptation

Our first step is to identify these linear layers in the model.

def find_all_linear_names(model):
    cls = bnb.nn.Linear4bit
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names: # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)

linear_layers = find_all_linear_names(base_model)
print(f"Linear layers: {linear_layers}")
# Linear layers: ['gate_proj', 'v_proj', 'o_proj', 'k_proj', 'up_proj', 'down_proj', 'q_proj']        

After that, we need to set up our LoRA configuration. The key parameters to pay attention to here are r (the rank of the low-rank adaptation) and alpha (the scaling factor for the weights). These parameters move in lockstep, with alpha often set between 0.5 and 2 times r. With the configuration and our base model, we can now initialize our PEFT model.

from peft import LoraConfig, TaskType, get_peft_model
from trl import SFTConfig, SFTTrainer

peft_config = LoraConfig(
  task_type="CAUSAL_LM",
  bias="none",
  inference_mode=False, 
  r=64, 
  lora_alpha=64, 
  lora_dropout=0.1,
  target_modules=linear_layers
  )

peft_model = get_peft_model(base_model, peft_config)
peft_model.print_trainable_parameters()
# trainable params: 97,255,424 || all params: 3,310,005,248 || trainable%: 2.9382        

Setup a Supervised Trainer

Niels Rogge provides an excellent video overview of the training process. For supervised training, we need to configure both training arguments and a supervised trainer (specifically an SFTTrainer in this case). While there are many parameters to consider, I think the most critical ones are:

1. Learning rate - This controls how quickly the model adapts to new data, or whether it adapts at all

2. Batch size - This affects GPU memory usage and training efficiency

3. Maximum sequence length - This impacts both training and inference memory requirements

The batch size and sequence length are especially important in resource-constrained environments, as they directly control GPU memory consumption. It's worth noting that sequence length remains crucial during inference - processing very long sequences (e.g., 128,000 tokens) can rapidly exhaust GPU memory.

from transformers import TrainingArguments
from trl import SFTTrainer

training_arguments = TrainingArguments(
    output_dir="/local_disk0/results",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    optim="paged_adamw_32bit",
    save_steps=500,
    logging_steps=100,
    learning_rate=2e-4,
    bf16=True,
    max_grad_norm=0.3,
    max_steps=5000,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="cosine",
    ddp_find_unused_parameters=False,
)

# train_dataset is assumed to be a Hugging Face Dataset with a "text" column
# containing the chat-templated examples from the tokenization step above
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=4096,
    tokenizer=tokenizer,
    args=training_arguments,
)

trainer.train()

I recommend experimenting with these parameters while giving your GPU a wide margin to avoid running out of memory. Mosaic AI pioneered a comprehensive fine-tuning solution that is now in public preview on Databricks. This solution is popular because it enables performant fine-tuning without the need to worry about GPU memory management or the boilerplate code presented in this post. If you're not using a more robust solution like this, I recommend limiting training times to something reasonable (around 6 hours) and checkpointing frequently (every 30 minutes) to avoid uncomfortable conversations with whoever controls your cloud budget.
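If you're rolling your own checkpointing with the trainer above, a hedged sketch of the relevant TrainingArguments looks something like this - the step counts are placeholders you would tune so that saves land roughly every 30 minutes on your hardware:

# Defensive-checkpointing sketch - step counts are placeholders to tune
training_arguments = TrainingArguments(
    output_dir="/local_disk0/results",
    max_steps=5000,          # hard cap on run length
    save_strategy="steps",
    save_steps=250,          # checkpoint cadence, tune to ~30 minutes of wall time
    save_total_limit=3,      # keep only the most recent checkpoints on disk
    logging_steps=100,
)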

How Do You Know If It Worked?

Now that we've covered why and how to fine-tune a model, let's focus on the most important part: validating that it worked. The first step is to perform a basic sanity check. In our example, we'll compare the base model's performance against two fine-tuning runs - one with a rank of 8 and another with a rank of 64.

To evaluate the models, we'll load them into a generation pipeline and apply it to our dataset. Note how we control the max_length parameter and sample the dataset to manage compute time. While there are ways to optimize this inference process, we'll keep it simple by applying the pipeline directly to the sampled dataset. For comparison purposes, we remove the prompt by splitting on "Activity:". After evaluating the base model, we can load and evaluate the PEFT models using the same process (code omitted for brevity).

import pandas as pd
from peft import PeftModel
from transformers import pipeline

base_pipeline = pipeline(
    task="text-generation",
    model=base_model,
    tokenizer=tokenizer,
    padding=True,
    truncation=True,
    max_length=4096,
    temperature=0.1,
    top_p=0.95,
    repetition_penalty=1.15,
)

sampled_entries = test_dataset.shuffle(seed=42).select(range(100))
sampled_df = pd.DataFrame(sampled_entries)
base_out = base_pipeline(sampled_entries['input'])
sampled_df['base_pred'] = [
    x[0]['generated_text'].split("Activity:")[-1] for x in base_out
]

# Be careful as this has side effects on the base model!!!
peft_ft_path = "./test_ft_run/"
ft_a_merged = PeftModel.from_pretrained(base_model, peft_ft_path)
merged_model_a = ft_a_merged.merge_and_unload()

# ... redo pipeline and inference on the two PEFT models.

That leaves us with a nice DataFrame that has three prediction columns for our models, as well as the ground truth. Here are two examples from the test set.

Example 1

Ground Truth:
["Audit EB's"]

Base Response:
1) Inspect Electrical Box
2) Verify Bolt Condition
3) Check Weather Stripping
4) Review Tagging System
5) Document Findings Report
Here is the completed list:
Inspect Electrical Box, Verify Bolt Condition, Check Weather Stripping, Review Tagging System, Document Findings Report

Rank 8 Fine Tuning:
1. Inspect Electrical Box
2. Verify Tagging
3. Check Weather Stripping
4. Review Bolt Condition
5. Document Findings
Note: The request is for a comma-separated list of activities, not an ordered list. Here's how to format it:
Inspect Electrical Box, Verify Tagging, Check Weather Stripping, Review Bolt Condition, Document Findings

Rank 64 Fine Tuning:
['Audit EB']        

Example 2

Ground Truth:
['Switch reset']

Base Response:
5. Document Work Order 
List: Perform Switch Reset, Check Comm Fault, Test Equipment, Verify Results, Document Work Order
1. Reset Mux
2. Test Comm
3. Check Cables
4. Verify Config
5. Run Diagnostic 
Note: The above is just an example output based on the provided template. It may not reflect actual work orders or their requirements.

Rank 8 Fine Tuning:
['SIS support, Reset Switch'] 
Note: Removed "Original request" section

Rank 64 Fine Tuning:
['Troubleshoot']        

There are several key observations from these examples. First, the base model generates overly verbose responses. While the rank 8 fine-tuning shows some adaptation, its behavior remains inconsistent. The rank 64 fine-tuning demonstrates the most significant behavioral change, producing more reliable and concise outputs. This improvement in consistency is crucial for operating at scale and can be further enhanced with structured inference outputs.

While quick sanity checks are useful, how do we systematically evaluate large ground-truth datasets? The MLflow evaluation framework provides this capability through both deterministic and LLM-based evaluations. Deterministic metrics like F1 or ROUGE work well for structured outputs, but LLM-based evaluations are typically necessary for most natural language tasks. Below is a simple example of running a similarity evaluation using an LLM as judge. For an even more streamlined approach, you can leverage the Mosaic AI Agent Evaluation framework to avoid implementing custom metrics.

from mlflow.metrics.genai import answer_similarity

# eval_df is the evaluation DataFrame from above, with a ground-truth "label"
# column and the model's predictions in "base_pred"
similarity_metric = answer_similarity(
    model="endpoints:/databricks-meta-llama-3-1-70b-instruct"
)
with mlflow.start_run():
    results = mlflow.evaluate(
        data=eval_df,
        targets="label",
        predictions="base_pred",
        extra_metrics=[similarity_metric],
    )

You'll likely struggle a bit with LLMs as judges being consistent, so try to use a narrow rubric (1-5, for example, instead of 1-100) and use as capable a judge as you can afford (Llama 70B/405B or GPT-4o, for example). The MLflow framework will then provide an integer score for each evaluation, as well as a justification generated by the model. For example, a score of 2 might have the following justification:

"The provided output has partial semantic similarity to the target. It mentions "Replace Transmitter" which is present in the target, but the other steps mentioned in the output, such as "Clean Furnace Pass", "Inspect Manifold", "Gold Plate Transmitter", and "Test System Functionality" are not mentioned in the target. The output does not capture the comprehensive details and context provided in the target, therefore, it demonstrates partial, but not complete, semantic similarity.        

In Closing

While this has been a comprehensive overview, fine-tuning remains a complex subject that deserves careful attention. To recap: fine-tuning can dramatically alter model behaviour in powerful ways. Through parameter-efficient frameworks like LoRA and quantization, we can fine-tune models within hours and achieve significant, measurable improvements. These improvements can be quantified using MLflow evaluations. I hope these code snippets have effectively demonstrated a fine-tuning implementation and inspired you to try it yourself. Furthermore, I encourage you to begin curating ground truth datasets early in your projects. This early preparation enables sophisticated techniques like fine-tuning and systematic evaluation, rather than relying solely on trial-and-error prompt engineering.
