In a previous blog post, we presented the results of our fine-tuned Mixtral response generation model, Mixtral FT RG, as well as our fine-tuned Mixtral TruthChecker model, Mixtral FT TC.
Mixtral FT RG delivered a respectable 4.5% base hallucination rate, about half that of raw Mixtral-Instruct. Meanwhile, Mixtral FT TC achieved 96.0% accuracy in truth-checking, a significant improvement over raw Mixtral-Instruct's 62.5%. These results show the impact of fine-tuning models for specific tasks. In this post, we share our insights from the process of fine-tuning Mixtral.
Fine-tuning Libraries
To avoid building from scratch, we looked for a fine-tuning library with Mixtral support. We initially tested Transformers and TRL from the Hugging Face team, but were unable to obtain good results: the fine-tuned models often performed worse than the raw pre-trained Mixtral model. We later found success with the Axolotl library, which offers built-in support for various models including Mixtral, as well as useful features such as quantization, LoRA, DeepSpeed, and Flash Attention.
Fine-tuning Mixtral for response generation
Data preprocessing
To fine-tune for the response generation task, we started by compiling a high-quality Q&A dataset of 2400 questions and answers grounded in a given knowledge base. The questions and answers were either synthetically generated by experts or taken from live customer conversations, while the knowledge base snippets were extracted by our RAG pipeline. We randomly sampled 2100 examples for training and held out the remaining 300 for testing.
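For illustration, the split can be as simple as shuffling and slicing. This is a minimal sketch: the file name and JSONL format are assumptions for the example, not our actual pipeline.

import json
import random

# Minimal sketch of the 2100/300 train-test split described above.
# "qa_dataset.jsonl" and its schema are hypothetical.
with open("qa_dataset.jsonl") as f:
    examples = [json.loads(line) for line in f]

random.seed(42)  # fixed seed so the held-out test set is reproducible
random.shuffle(examples)
train_set, test_set = examples[:2100], examples[2100:]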
To format the training examples, we applied a built-in prompt template in Axolotl called context_qa.load_v2, which combines in-context articles with a question and answer. The template is simple and does not use a system message; in our past experience fine-tuning language models for a specific task, a simple prompt usually works well. The context_qa.load_v2 template has the following structure:
def build_prompt(sample: dict) -> str:
    # context_qa.load_v2: plain field labels, no system message
    return (
        "Context: " + sample["context"]
        + "\nQuestion: " + sample["question"]
        + "\nAnswer: " + sample["answer"]
    )
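For instance, applied to a made-up sample (values invented for illustration), the template produces a flat three-field prompt:

sample = {
    "context": "Refunds are available within 14 days of purchase.",
    "question": "How long do I have to request a refund?",
    "answer": "You can request a refund within 14 days of purchase.",
}
print(build_prompt(sample))
# Context: Refunds are available within 14 days of purchase.
# Question: How long do I have to request a refund?
# Answer: You can request a refund within 14 days of purchase.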
Training configuration
Axolotl provides a comprehensive configuration file of Mixtral fine-tuning parameters. We mostly kept the default parameters, with minor modifications to num_epochs, micro_batch_size, and eval_steps for our use case. We tried both Mixtral and Mixtral-Instruct as base models for fine-tuning, and found Mixtral-Instruct to be slightly better.
Some important parameters to note are listed below; a configuration sketch follows the list.
adapter=qlora: using QLoRA significantly reduces the amount of GPU memory required for training while maintaining performance
sequence_len=4096: the majority of our training examples fit in a context window of 4096 tokens
sample_packing=true: combines multiple short examples in the same input sequence to increase training efficiency
train_on_inputs=false: masks the prompt tokens out of the training labels, so the loss is computed only on the answer
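Axolotl reads these settings from a YAML config file. Below is a minimal sketch of one, assuming the key names used in Axolotl's published example configs; it reproduces only the handful of parameters mentioned in this post, so a real config would need more fields (LoRA ranks, learning rate, and so on).

import yaml

# Subset of an Axolotl fine-tuning config, written out as YAML.
# Only the parameters discussed in this post are shown; dataset path
# and exact key names are assumptions based on Axolotl's examples.
config = {
    "base_model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "adapter": "qlora",
    "load_in_4bit": True,
    "sequence_len": 4096,
    "sample_packing": True,
    "train_on_inputs": False,
    "micro_batch_size": 4,
    "num_epochs": 5,
    "datasets": [{"path": "qa_dataset.jsonl", "type": "context_qa.load_v2"}],
}

with open("mixtral-rg-qlora.yml", "w") as f:
    yaml.safe_dump(config, f)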
Training time and compute
For fine-tuning, we used one A100 80GB GPU. Training with QLoRA in 4-bit precision required about 48GB of GPU memory with a micro_batch_size of 4. We fine-tuned the model for 5 epochs, which took about 6.5 hours. Overall, the training process was stable, with the loss steadily decreasing as shown in the graph.
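To reproduce a run like this, the config above can be launched with Axolotl's CLI; here is one way to invoke it from Python. The command line mirrors Axolotl's documented accelerate-based entry point, but check it against the version you have installed.

import subprocess

# Launch Axolotl training on the config written above.
subprocess.run(
    ["accelerate", "launch", "-m", "axolotl.cli.train", "mixtral-rg-qlora.yml"],
    check=True,  # raise if the training process exits with an error
)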
To reduce the GPU memory needed, training can fit on one A100 40GB GPU by using a micro_batch_size of 2 with gradient_accumulation_steps of 2, which keeps the effective batch size at 4 (see the sketch below).
To speed up training, one could use LoRA instead of QLoRA. However, LoRA requires more GPU memory to train.
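As a sketch of the 40GB variant, reusing the hypothetical config dict from above:

# Halve the per-step batch and compensate with gradient accumulation,
# so the effective batch size stays at 2 * 2 = 4.
config.update({
    "micro_batch_size": 2,
    "gradient_accumulation_steps": 2,
})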
At this stage, we obtained Mixtral FT RG, a fine-tuned model which specializes in providing relevant and accurate responses to the user based on a knowledge base.
Fine-tuning Mixtral for truth-checking
Data preprocessing
To compile the dataset for hallucination detection, we used Mixtral FT RG to produce responses on its test set, as well as on a subset of the most recent live conversations. This resulted in 2100 examples of user messages and corresponding Mixtral responses based on the extracted knowledge base snippets. According to AutoEval, Mixtral FT RG has a 4.5% base hallucination rate on this dataset.
To fine-tune Mixtral for truth-checking, we used AutoEval's JSON output format, which includes fields indicating whether any relevance or groundedness mistakes were found and, if so, what the mistakes were. With this output format, we trained Mixtral to generate labels and reasoning, while simultaneously identifying and correcting any potential mistakes in the response.
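The exact schema is not spelled out in this post, so the example below is purely illustrative: the field names and values are hypothetical, chosen only to match the description above.

# Hypothetical TruthChecker training target in the spirit of AutoEval's
# JSON output format; the real field names may differ.
target = {
    "relevance_mistake_found": False,
    "groundedness_mistake_found": True,
    "mistakes": [
        "The response promises a 30-day refund window, but the knowledge "
        "base snippet specifies 14 days."
    ],
    "corrected_response": "You can request a refund within 14 days of purchase.",
}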
Data augmentation
Since Mixtral FT RG’s base hallucination rate is low, the training set contains few hallucination examples: 4.5% of 2100 responses is only about 95 cases. Therefore, we augmented the training set with about 2000 responses from GPT-3.5-Turbo, which has a higher hallucination rate of 21.7%, contributing roughly 430 additional hallucination cases that the model can learn from during training.
Training configuration
We applied the same training parameters for TruthChecker as for the response generation model, except for sequence_len: since the TruthChecker prompt is longer, we increased it to 8192 tokens. We trained for 5 epochs using 4-bit QLoRA, which required one A100 80GB GPU and took about 14 hours.
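In terms of the earlier config sketch, the TruthChecker run only overrides the context length (and would point at the truth-checking dataset instead); this reuses the config dict and yaml import from the sketch above.

# TruthChecker variant: same hyperparameters, longer context window
# to fit the knowledge snippets, response, and JSON verdict.
tc_config = dict(config, sequence_len=8192)

with open("mixtral-tc-qlora.yml", "w") as f:
    yaml.safe_dump(tc_config, f)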
The resulting TruthChecker model, Mixtral FT TC, showed a considerable improvement in accuracy over the raw Mixtral-Instruct model. At a 4.5% rejection rate (the share of responses the checker rejects), it reduces the net hallucination rate to 2.0%, comparable to GPT-4's factual accuracy.
Conclusion
We have successfully fine-tuned and integrated Mixtral models into our products. Since resources, documentation, and understanding of how to fine-tune Mixtral are still limited, we have shared our knowledge and experience here for the community to learn from. Our experiments show the improvement that fine-tuning can bring to an already performant base model. By fine-tuning Mixtral for response generation, we reduced our automated chatbot's hallucination rate from 8.0% to 4.5%. Moreover, by fine-tuning Mixtral for truth-checking, we further decreased the residual hallucination rate to 2.0%. We hope these results encourage developers to keep fine-tuning open-source models, which can achieve performance similar to some of the best closed models.