Language models like DialoGPT are effective at generating conversations but suffer from a well-known problem called catastrophic forgetting: after being fine-tuned on a new dataset, a model can lose much of what it had previously learned.
This experiment tests whether Elastic Weight Consolidation (EWC) can reduce that forgetting when DialoGPT is fine-tuned.
EWC Explained
Imagine you are learning to draw a tree.
You've already learned to draw a nice trunk, but now you want to add leaves.
If you erase everything to make only the leaves, you lose the trunk!
EWC is like a friend who says, "Wait! The trunk is important, don't erase it completely while adding the leaves!"
Now Back to Language Models
The model has already learned something useful (e.g., "My favorite color is blue").
If we train it on new things without precautions, it might forget what it learned before.
EWC says, "Wait, some weights are important! Let's protect them!"
How Does It Do That?
It observes which weights are most important to the model.
It marks them as "critical" using the Fisher information matrix.
The Fisher information matrix measures how sensitive the model's predictions are to small changes in its parameters, indicating which weights matter most. In simpler terms, it quantifies how much information the data carry about each parameter. In practice, EWC keeps only the diagonal of this matrix: one importance value per weight.
During new training, it prevents these weights from being changed too much.
Result?
The model learns new things but doesn't completely forget the old ones.
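As a rough numeric illustration of the idea above, the sketch below charges a cost for moving each weight, scaled by how important that weight is. The function name `ewc_penalty`, the two weights, and the importance values are all made up for illustration; they are not taken from the experiment.

```python
# Toy illustration of the EWC penalty: importance * (new - old)^2 per weight.
# All names and numbers here are made up for illustration.

def ewc_penalty(old_weights, new_weights, importance):
    """Sum of importance-weighted squared changes."""
    return sum(
        imp * (new - old) ** 2
        for old, new, imp in zip(old_weights, new_weights, importance)
    )

old_weights = [1.0, 1.0]   # "trunk" weight, "leaf" weight
importance = [10.0, 0.1]   # the trunk weight matters far more

# Move each weight by the same amount (0.5) and compare the cost.
move_trunk = ewc_penalty(old_weights, [1.5, 1.0], importance)
move_leaf = ewc_penalty(old_weights, [1.0, 1.5], importance)

print(f"penalty for moving the important weight:   {move_trunk}")
print(f"penalty for moving the unimportant weight: {move_leaf}")
```

The same step costs a hundred times more on the important weight, so an optimizer that minimizes this penalty alongside the new-task loss will prefer to change the unimportant one.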
To test this hypothesis, we followed these steps:
First, we load the DialoGPT-small model, which has been pre-trained on conversational data.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch
    import matplotlib.pyplot as plt

    # Load DialoGPT
    model_name = "microsoft/DialoGPT-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # The GPT-2 tokenizer ships without a pad token; reuse end-of-sequence for padding
    tokenizer.pad_token = tokenizer.eos_token
We created a small dataset of question-answer pairs built around facts the model should retain, such as its favorite color.
    conversations = [
        {"input": "Hello, how are you?", "response": "I'm good, thanks for asking! How about you?"},
        {"input": "What's your favorite color?", "response": "I like blue. It's a calming color."},
        {"input": "Do you remember your favorite color?", "response": "Yes, I like blue."},
        {"input": "Tell me about yourself.", "response": "I'm an AI designed to chat with you."},
        {"input": "What do you think about machine learning?", "response": "Machine learning is a powerful tool for AI development."},
    ]

    # Function to create input-output pairs
    def create_training_example(input_text, response_text):
        # DialoGPT marks the end of each turn with the end-of-sequence token
        prompt_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors="pt")
        response_ids = tokenizer.encode(response_text + tokenizer.eos_token, return_tensors="pt")
        # A causal LM reads the prompt and the response as one sequence
        input_ids = torch.cat([prompt_ids, response_ids], dim=1)
        # Labels copy the inputs; -100 hides the prompt so only the response is scored
        labels = input_ids.clone()
        labels[:, :prompt_ids.size(1)] = -100
        return input_ids, labels

    # Create the dataset
    dataset = [create_training_example(conv["input"], conv["response"]) for conv in conversations]
To apply Elastic Weight Consolidation (EWC), we estimated the diagonal of the Fisher information matrix from squared gradients; large entries identify the critical weights in the model.
    def compute_fisher_information(model, dataset, max_length=50, pad_token_id=None):
        # The GPT-2 tokenizer may define no pad token, so fall back to end-of-sequence
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id
        fisher_info = {name: torch.zeros_like(param) for name, param in model.named_parameters()}
        model.eval()
        for input_ids, output_ids in dataset:
            model.zero_grad()
            # Apply padding to standardize sequence lengths;
            # -100 keeps padded label positions out of the loss
            input_ids = torch.nn.functional.pad(input_ids, (0, max_length - input_ids.size(1)), value=pad_token_id)
            output_ids = torch.nn.functional.pad(output_ids, (0, max_length - output_ids.size(1)), value=-100)
            outputs = model(input_ids, labels=output_ids)
            loss = outputs.loss
            loss.backward()
            # Accumulate squared gradients (diagonal empirical Fisher)
            for name, param in model.named_parameters():
                if param.grad is not None:
                    fisher_info[name] += param.grad.detach() ** 2
        # Average over the dataset
        fisher_info = {name: fisher / len(dataset) for name, fisher in fisher_info.items()}
        return fisher_info
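The squared-gradient average above can be checked by hand on a model with a single parameter. The sketch below is an illustration only, not part of the experiment: it assumes a Bernoulli model with one logit `theta` (a hypothetical stand-in for a network weight) and computes the expected squared gradient of the log-likelihood exactly.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bernoulli_fisher(theta):
    """Expected squared score for a model with p(x=1) = sigmoid(theta).

    The score (gradient of the log-likelihood w.r.t. theta) is x - p,
    so the expectation is p * (1 - p)**2 + (1 - p) * (0 - p)**2.
    """
    p = sigmoid(theta)
    return p * (1 - p) ** 2 + (1 - p) * p ** 2

for theta in (0.0, 2.0, 4.0):
    print(f"theta={theta}: fisher={bernoulli_fisher(theta):.4f}")
```

The value is largest at `theta = 0`, where a small change in the parameter shifts the prediction the most, and shrinks as the prediction saturates. That is the sense in which a high Fisher value flags a weight as important.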
We implemented a loss function with the EWC penalty, limiting changes to critical weights.
    def ewc_loss(model, outputs, fisher_info, prev_params, lambda_penalty=0.05):
        base_loss = outputs.loss  # Standard loss
        ewc_penalty = 0
        for name, param in model.named_parameters():
            if name in fisher_info:
                penalty = fisher_info[name] * (param - prev_params[name]) ** 2
                ewc_penalty += penalty.sum()
        return base_loss + lambda_penalty * ewc_penalty
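To get a feel for what `lambda_penalty` does, a one-weight toy problem is enough. In the sketch below the new task pulls a weight toward 1.0 while the penalty anchors it at its old value 0.0. The function `fine_tune_one_weight`, the Fisher value of 4.0, and the learning rate are assumptions chosen for illustration, not values from the experiment.

```python
def fine_tune_one_weight(lambda_penalty, fisher=4.0, old_w=0.0, target=1.0,
                         lr=0.01, steps=2000):
    """Gradient descent on (w - target)^2 + lambda * fisher * (w - old_w)^2."""
    w = old_w
    for _ in range(steps):
        # Gradient of the new-task term plus gradient of the EWC-style penalty
        grad = 2 * (w - target) + 2 * lambda_penalty * fisher * (w - old_w)
        w -= lr * grad
    return w

for lam in (0.0, 0.05, 0.5, 5.0):
    w = fine_tune_one_weight(lam)
    print(f"lambda={lam}: final weight={w:.3f}")
```

With the defaults used here the weight settles at `1 / (1 + lambda * fisher)`: no penalty lets it reach the new target, and a large penalty keeps it close to the old value. Tuning `lambda_penalty` is a trade-off between learning the new data and keeping the old behavior.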
Now, we run 10 training epochs.
    # Assumption: fisher_info and prev_params are built here, before fine-tuning,
    # because the training loop below needs both
    fisher_info = compute_fisher_information(model, dataset)
    prev_params = {name: param.detach().clone() for name, param in model.named_parameters()}

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    losses = []
    model.train()
    for epoch in range(10):
        total_loss = 0
        for input_ids, output_ids in dataset:
            optimizer.zero_grad()
            # Apply padding to match sequence lengths;
            # -100 keeps padded label positions out of the loss
            max_length = max(input_ids.size(1), output_ids.size(1))
            input_ids = torch.nn.functional.pad(input_ids, (0, max_length - input_ids.size(1)), value=tokenizer.eos_token_id)
            output_ids = torch.nn.functional.pad(output_ids, (0, max_length - output_ids.size(1)), value=-100)
            outputs = model(input_ids, labels=output_ids)
            loss = ewc_loss(model, outputs, fisher_info, prev_params, lambda_penalty=0.05)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataset)
        losses.append(avg_loss)
        print(f"Epoch {epoch+1}, Loss: {avg_loss}")

    # Plot loss graph
    plt.plot(range(1, len(losses) + 1), losses, marker='o')
    plt.title('Model Loss During Training')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    plt.show()
We tested the model with a set of key questions to verify memory retention.
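    test_questions = [
        "Hello, how are you?",
        "What is your favorite color?",
        "Do you remember your favorite color?",
        "Tell me about yourself.",
        "What do you think about machine learning?",
    ]

    # Assumed helper: chat_with_model is called below but not defined in the listing,
    # so this is a minimal greedy single-turn version
    def chat_with_model(question, max_length=100):
        input_ids = tokenizer.encode(question + tokenizer.eos_token, return_tensors="pt")
        model.eval()
        with torch.no_grad():
            output_ids = model.generate(input_ids, max_length=max_length, pad_token_id=tokenizer.eos_token_id)
        # Decode only the tokens generated after the prompt
        return tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)

    def test_model(questions):
        results = {}
        for question in questions:
            response = chat_with_model(question)
            results[question] = response
            print(f"User: {question}")
            print(f"Model: {response}")
            print("-" * 50)
        return results

    # Run the tests
    test_results = test_model(test_questions)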
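Reading the printed replies by eye works for five questions, but a simple keyword check makes the retention test repeatable. The sketch below assumes a dictionary shaped like the one `test_model` returns. The helper `keyword_retention` and the expected keywords are hypothetical, and the sample replies are placeholders, not actual model output.

```python
def keyword_retention(results, expected_keywords):
    """Return the fraction of checked questions whose reply contains its keyword.

    results: dict mapping question -> reply (same shape as test_model's output)
    expected_keywords: dict mapping question -> keyword expected in the reply
    """
    hits = 0
    for question, keyword in expected_keywords.items():
        reply = results.get(question, "")
        if keyword.lower() in reply.lower():
            hits += 1
    return hits / len(expected_keywords) if expected_keywords else 0.0

# Placeholder replies for illustration only (not actual model output)
sample_results = {
    "What is your favorite color?": "I like blue.",
    "Do you remember your favorite color?": "placeholder reply",
}
expected = {
    "What is your favorite color?": "blue",
    "Do you remember your favorite color?": "blue",
}
print(keyword_retention(sample_results, expected))
```

Running the same check on a model fine-tuned with and without the EWC penalty would give a like-for-like comparison of how much each one retains.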
The experiment demonstrated that EWC helps retain memory in the model.
The model remembers "blue" as its favorite color even after fine-tuning.
Catastrophic forgetting has been significantly reduced.
The model still generates fragmented and incoherent sentences, suggesting that the strength of the EWC penalty needs further tuning.
My intuition was correct: EWC can be used as a hybrid approach to stabilize memory in LLMs.
The model successfully retains key information while continuing to generate responses.
However, optimizing EWC further is necessary to improve coherence and fluency.
Tune the generation parameters (temperature, top_k, etc.) for better coherence.
This experiment confirms that EWC is a promising strategy for memory improvement in conversational AI models!
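For readers unfamiliar with the two generation parameters named above, the sketch below shows what they do to a next-token distribution. It is a standalone illustration in plain Python: the function `top_k_temperature_probs` and the four logits are made up, and it does not call the model.

```python
import math

def top_k_temperature_probs(logits, temperature=1.0, top_k=None):
    """Softmax over logits after temperature scaling and optional top-k filtering."""
    scaled = [x / temperature for x in logits]
    if top_k is not None:
        # Keep the k largest logits (ties at the cutoff are all kept);
        # everything else gets probability zero
        cutoff = sorted(scaled, reverse=True)[top_k - 1]
        scaled = [x if x >= cutoff else float("-inf") for x in scaled]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5, -1.0]  # made-up scores for four candidate tokens
for temperature, top_k in [(1.0, None), (0.5, None), (1.0, 2)]:
    probs = top_k_temperature_probs(logits, temperature, top_k)
    print(temperature, top_k, [round(p, 3) for p in probs])
```

A lower temperature concentrates probability on the top token, and `top_k` removes unlikely tokens from consideration entirely. In the `transformers` library these settings are passed to `generate` together with `do_sample=True`.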