In this article, we will see how to run inference with the LLaMA model that we built in the previous article: loading the pretrained weights and generating text completions for a set of prompts.
Previous Article :- Building a LLaMA Model from Scratch
from typing import Optional
import torch
import time
from pathlib import Path
import json
from sentencepiece import SentencePieceProcessor
from tqdm import tqdm

from model import ModelArgs, llamaModel
Import all the necessary libraries, and also import ModelArgs and llamaModel from model.py, the file that contains the code we implemented in the previous article.
class LLaMA:

    def __init__(self, model: llamaModel, tokenizer: SentencePieceProcessor, model_args: ModelArgs):
        self.model = model
        self.tokenizer = tokenizer
        self.args = model_args
In the LLaMA class, create the __init__ method with the following parameters:
model :- An instance of the llamaModel class, which represents the language model.
tokenizer :- An instance of the SentencePieceProcessor class, which handles tokenization and detokenization of text.
model_args :- An instance of the ModelArgs class, which contains configuration and hyperparameters for the model.
    @staticmethod
    def build(checkpoints_dir: str, tokenizer_path: str, load_model: bool, max_seq_len: int, max_batch_size: int, device: str):
        prev_time = time.time()
        if load_model:
            checkpoints = sorted(Path(checkpoints_dir).glob("*.pth"))
            assert len(checkpoints) > 0, f"no checkpoint files found in {checkpoints_dir}"
            ckpt_path = checkpoints[0]
            print(f'Loading checkpoint "{ckpt_path}"')
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            print(f"Loaded checkpoint in {time.time() - prev_time:.2f}s")
            prev_time = time.time()

        with open(Path(checkpoints_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            device=device,
            **params
        )

        tokenizer = SentencePieceProcessor()
        tokenizer.load(tokenizer_path)
        model_args.vocab_size = tokenizer.vocab_size()

        if device == "cuda":
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
        else:
            torch.set_default_tensor_type(torch.BFloat16Tensor)

        model = llamaModel(model_args).to(device)

        if load_model:
            # The only unmatched key in the checkpoint is rope.freqs. Remove it
            del checkpoint['rope.freqs']
            model.load_state_dict(checkpoint, strict=True)
            print(f"Loaded state dict in {time.time() - prev_time:.2f}s")

        return LLaMA(model, tokenizer, model_args)
This is the second method inside the LLaMA class. The @staticmethod decorator indicates that build is a static method, which does not depend on class or instance-specific data.
Checkpoint Loading: If load_model is True, it searches for .pth files (PyTorch model checkpoints) in the specified directory and asserts that there is at least one checkpoint file. It then loads the first checkpoint found into memory, mapping it to the CPU, and prints the time taken to load the checkpoint.
Parameter Loading: Reads model parameters from a params.json file in the checkpoints directory. Constructs a ModelArgs object with the loaded parameters, along with the provided max_seq_len, max_batch_size, and device.
Tokenizer Loading: Initializes a SentencePieceProcessor tokenizer, loads it from the specified path, and sets the vocabulary size in model_args based on the loaded tokenizer (a short tokenizer example follows this walkthrough).
Tensor Type Configuration: This sets the default tensor type based on the specified device (either CUDA half-precision or BFloat16 for CPUs).
Model Initialization: Instantiates the llamaModel with the given model_args and moves it to the specified device. If load_model is True, removes the rope.freqs key from the checkpoint (to avoid mismatches) and loads the state dictionary into the model.
Return Statement: Creates and returns an instance of the LLaMA class, initialized with the model, tokenizer, and model arguments.
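As a quick illustration of the tokenizer loaded in the build method above (assuming the LLaMA-2 tokenizer.model file is available locally): encoding with add_bos=True prepends the BOS token (id 1 in LLaMA-2), and decode maps the ids back to text.

from sentencepiece import SentencePieceProcessor

tokenizer = SentencePieceProcessor()
tokenizer.load("tokenizer.model")  # the LLaMA-2 SentencePiece model

ids = tokenizer.encode("Love is", out_type=int, add_bos=True, add_eos=False)
print(ids)                         # a list of token ids starting with the BOS id (1)
print(tokenizer.decode(ids))       # "Love is"
print(tokenizer.vocab_size())      # 32000 for the LLaMA-2 tokenizer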
During inference we pass only one token at a time (the KV cache holds the previous context), which avoids unnecessary computation. The model outputs a score for every token in the vocabulary, the logits, and we need a strategy to choose the next token from them.
Suppose we have the sentence "Love is _____" and we have to fill in the last word. We can come up with many different words: kind, eternal, painful, pure, unconditional, etc. The choice of the next word depends on our knowledge, education, and experience.
LLMs face the same problem: predicting the next token depends on their training and on the strategy they use to pick it, such as the Greedy Strategy, Beam Search, Temperature, Random Sampling, Top K, Top P, etc.
The output of self-attention is a sequence; with the KV cache it is just a single token. After normalization, we pass it through a linear layer that transforms the embedding output by self-attention into a list of numbers, one for each token in the vocabulary: the logits. If the vocabulary size is 1000, we get a list of 1000 numbers. After applying Softmax, these numbers become the probabilities of each token being the next one.
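A minimal sketch of this last step, with made-up sizes (embedding dimension 4096 and a vocabulary of 1000 tokens instead of LLaMA's 32000):

import torch
import torch.nn as nn

dim, vocab_size = 4096, 1000              # hypothetical sizes, for illustration only
output_layer = nn.Linear(dim, vocab_size, bias=False)

embedding = torch.randn(1, 1, dim)        # (batch, seq_len=1, dim): a single token thanks to the KV cache
logits = output_layer(embedding)          # (1, 1, vocab_size): one score per token in the vocabulary
probs = torch.softmax(logits, dim=-1)     # probabilities over the vocabulary, summing to 1
print(probs.shape)                        # torch.Size([1, 1, 1000])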
Now, out of these many probabilities, how do we choose which one will be the next token? This is where decoding strategies come in.
Greedy Strategy → At every step we select the token with the highest probability, which is appended to the input to generate the next token. If the initial tokens happen to be the wrong ones, it is very likely that the following ones will be wrong as well, hence the poor performance.
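In code, greedy decoding is just an argmax over the logits; this is exactly what text_completion does later in this article when the temperature is 0:

import torch

logits = torch.tensor([-2.5, -3.0, -0.6])  # scores for a toy 3-token vocabulary
next_token = torch.argmax(logits, dim=-1)  # always pick the highest-scoring token
print(next_token.item())                   # 2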
Beam Search → At every step we keep alive the top K paths and all the others are discarded. This increases inference time, since at every step we must explore K possible options. Generally, it performs better than the Greedy Strategy.
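Beam search is not part of the inference code we build in this article, but the bookkeeping is easy to sketch. The example below uses a hypothetical toy_log_probs function in place of a real model and keeps the K best partial sequences at every step:

import torch

def toy_log_probs(seq: list[int]) -> torch.Tensor:
    # Stand-in for a language model: returns log-probabilities over a 5-token vocabulary
    torch.manual_seed(sum(seq) + len(seq))   # deterministic per prefix, just for the demo
    return torch.log_softmax(torch.randn(5), dim=-1)

def beam_search(start: list[int], k: int = 2, steps: int = 3):
    beams = [(start, 0.0)]                   # each beam is (sequence, cumulative log-probability)
    for _ in range(steps):
        candidates = []
        for seq, score in beams:
            log_probs = toy_log_probs(seq)
            for token, lp in enumerate(log_probs.tolist()):
                candidates.append((seq + [token], score + lp))
        # Keep only the top K highest-scoring paths; all the others are discarded
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:k]
    return beams

print(beam_search([0], k=2, steps=3))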
Temperature → The idea is to scale the logits before applying the softmax. A low temperature makes the model more confident (the gap between low and high probabilities increases), while a high temperature makes the model less confident (the gap between low and high probabilities shrinks).
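The effect is easy to see on a toy set of logits (the same ones used in the random-sampling example below): dividing by a temperature below 1 sharpens the distribution, dividing by a temperature above 1 flattens it.

import torch

logits = torch.tensor([-2.5, -3.0, -0.6])
print(torch.softmax(logits / 0.5, dim=-1))  # T=0.5 -> tensor([0.0217, 0.0080, 0.9703]) (more peaked)
print(torch.softmax(logits / 1.0, dim=-1))  # T=1.0 -> tensor([0.1206, 0.0731, 0.8063])
print(torch.softmax(logits / 2.0, dim=-1))  # T=2.0 -> tensor([0.2291, 0.1784, 0.5924]) (flatter)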
Random Sampling → We sample from the distribution that is output by the Softmax.
logits = torch.Tensor([-2.5, -3, -0.6])
distribution = torch.softmax(logits, dim=0)
distribution

# OUTPUT --> tensor([0.1206, 0.0731, 0.8063])
The first token will be chosen with a probability of 12.06%, the second with a probability of 7.31%, and the third with a probability of 80.63%. The higher the probability, the more likely a token is to be chosen. The problem is that, with a small but non-zero probability, we may end up choosing tokens that are total nonsense.
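For example, sampling many times from the distribution above with torch.multinomial roughly reproduces those percentages, and the low-probability tokens do get picked from time to time:

import torch

torch.manual_seed(0)
distribution = torch.tensor([0.1206, 0.0731, 0.8063])
samples = torch.multinomial(distribution, num_samples=1000, replacement=True)
print(torch.bincount(samples, minlength=3))  # roughly 120, 73 and 806 picks out of 1000 draws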
Top K → With this we keep only the top K highest probabilities, so that tokens with very low probability will never be chosen. The problem is that, given the following distributions, low-probability tokens can still make their way into the top K tokens (K=2).
Distribution 1 :- 0.5, 0.4, 0.05, 0.025, 0.025
Distribution 2:- 0.9, 0.05, 0.025, 0.020, 0.005
Top P → With this we keep only the tokens with the highest probabilities such that their cumulative probability is greater than or equal to the parameter P. This way we get more tokens for distributions that are more "flat" and fewer tokens for distributions with a more prominent mode, as the quick check below shows.
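A quick check of both points on the two distributions above (with K=2 and P=0.8, using the same masking rule as the _sample_top_p method shown later): top-k always keeps exactly two tokens, even the 0.05 one in Distribution 2, while top-p keeps two tokens for the flatter Distribution 1 but only one for the peaked Distribution 2.

import torch

def top_p_keep(probs: torch.Tensor, p: float) -> torch.Tensor:
    # Keep the smallest set of highest-probability tokens whose cumulative probability reaches p
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    keep = cumulative - sorted_probs <= p    # same shifted-cumulative rule as _sample_top_p
    return sorted_idx[keep]

dist1 = torch.tensor([0.5, 0.4, 0.05, 0.025, 0.025])    # flatter distribution
dist2 = torch.tensor([0.9, 0.05, 0.025, 0.020, 0.005])  # peaked distribution

print(torch.topk(dist1, k=2).indices, top_p_keep(dist1, p=0.8))  # top-k: 2 tokens, top-p: 2 tokens
print(torch.topk(dist2, k=2).indices, top_p_keep(dist2, p=0.8))  # top-k: 2 tokens, top-p: 1 token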
In LLaMA, the top-p strategy was used, so we will do the same.
    def text_completion(self, prompts: list[str], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None):
        if max_gen_len is None:
            max_gen_len = self.args.max_seq_len - 1
        # Convert each prompt into tokens
        prompt_tokens = [self.tokenizer.encode(prompt, out_type=int, add_bos=True, add_eos=False) for prompt in prompts]
        # Make sure the batch size is not too large
        batch_size = len(prompt_tokens)
        assert batch_size <= self.args.max_batch_size, f"batch size must be less than or equal to {self.args.max_batch_size}"
        # Make sure the longest prompt fits in the sequence length
        max_prompt_len = max(len(prompt) for prompt in prompt_tokens)
        assert max_prompt_len <= self.args.max_seq_len, f"prompt length must be less than or equal to {self.args.max_seq_len}"
        total_len = min(self.args.max_seq_len, max_gen_len + max_prompt_len)

        # Create the matrix that will contain the generated tokens, initialized with padding tokens
        pad_id = self.tokenizer.pad_id()
        tokens = torch.full((batch_size, total_len), pad_id, dtype=torch.long, device=self.args.device)
        for k, t in enumerate(prompt_tokens):
            # Populate the initial positions with the prompt tokens
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=self.args.device)

        eos_reached = torch.tensor([False] * batch_size, device=self.args.device)
        prompt_tokens_mask = tokens != pad_id  # True if the token is a prompt token, False otherwise
        cur_iterator = tqdm(range(1, total_len), desc="Generating tokens")
        for cur_pos in cur_iterator:
            with torch.no_grad():
                # Feed only the previous token; the KV cache holds the rest of the context
                logits = self.model.forward(tokens[:, cur_pos-1:cur_pos], cur_pos)
            if temperature > 0:
                # Apply the temperature before the softmax, then sample with top-p
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = self._sample_top_p(probs, top_p)
            else:
                # Greedy: select the token with the highest probability
                next_token = torch.argmax(logits[:, -1], dim=-1)
            next_token = next_token.reshape(-1)
            # Only replace the token if it is a padding position (i.e. we are past the prompt)
            next_token = torch.where(prompt_tokens_mask[:, cur_pos], tokens[:, cur_pos], next_token)
            tokens[:, cur_pos] = next_token
            # EOS is reached only if the EOS token was generated at a non-prompt position
            eos_reached |= (~prompt_tokens_mask[:, cur_pos]) & (next_token == self.tokenizer.eos_id())
            if all(eos_reached):
                break

        out_tokens = []
        out_text = []
        for prompt_index, current_prompt_tokens in enumerate(tokens.tolist()):
            # Cut the sequence at the EOS token, if present
            if self.tokenizer.eos_id() in current_prompt_tokens:
                eos_idx = current_prompt_tokens.index(self.tokenizer.eos_id())
                current_prompt_tokens = current_prompt_tokens[:eos_idx]
            out_tokens.append(current_prompt_tokens)
            out_text.append(self.tokenizer.decode(current_prompt_tokens))
        return (out_tokens, out_text)
Tokenizing Prompts: Each prompt is tokenized using the tokenizer, including the beginning-of-sequence (BOS) token but excluding the end-of-sequence (EOS) token.
Ensuring Batch and Prompt Size Constraints: Ensure that the batch size does not exceed the maximum allowed batch size, and that the longest prompt does not exceed the maximum sequence length. Then calculate the total length of the sequence to be generated.
Initializing Token Matrix: Initialize a matrix filled with padding tokens. Populate the matrix with the prompt tokens.
Token Generation Loop: Iterate over each position in the sequence to be generated. Obtain the logits for the current position, apply temperature scaling, and sample the next token with top-p sampling (or greedy sampling if the temperature is 0). Update the token matrix with the generated token and check for EOS tokens to determine whether generation can be stopped early.
Processing Output Tokens: Converts the generated token matrix to a list of tokens. Truncates each sequence at the EOS token if present. Decodes the token sequences into text. Returns the tokens and corresponding text completions.
    def _sample_top_p(self, probs, p):
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)  # (B, vocab_size)
        probs_sum = torch.cumsum(probs_sort, dim=-1)  # (B, vocab_size)
        mask = probs_sum - probs_sort > p  # (B, vocab_size)
        probs_sort[mask] = 0.0
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_token = torch.multinomial(probs_sort, num_samples=1)
        next_token = torch.gather(probs_idx, -1, next_token)
        return next_token
First, the probabilities are sorted in descending order, and we also keep the original indices of the sorted probabilities. These indices are used to map back to the original token positions after sampling.
Then the cumulative sum of the sorted probabilities is computed along the last dimension, and a boolean mask marks the tokens whose cumulative probability, excluding the current token, exceeds the threshold P.
After setting some probabilities to zero, the remaining probabilities are re-normalized so that they sum up to 1.
Finally, next_token samples one token from the re-normalized probabilities with torch.multinomial, and torch.gather uses the stored indices to recover the original token index corresponding to the sampled position.
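To see why probs_sort is subtracted from the cumulative sum before the comparison, here is the intermediate math on a small, already-sorted toy distribution with p = 0.7 (a toy example, not part of the inference code):

import torch

p = 0.7
probs_sort = torch.tensor([[0.5, 0.3, 0.1, 0.1]])      # batch of 1, already sorted
probs_sum = torch.cumsum(probs_sort, dim=-1)           # tensor([[0.5000, 0.8000, 0.9000, 1.0000]])
mask = probs_sum - probs_sort > p                      # tensor([[False, False, True, True]])
# Without the subtraction, 0.8 > 0.7 would already mask the second token and the kept
# set {0.5} would never reach the cumulative probability p.
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
print(probs_sort)                                      # tensor([[0.6250, 0.3750, 0.0000, 0.0000]])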
if __name__ == '__main__':
    torch.manual_seed(0)

    allow_cuda = False
    device = 'cuda' if torch.cuda.is_available() and allow_cuda else 'cpu'

    prompts = [
        "Simply put, the theory of relativity states that ",
        "If Google was an Italian company founded in Milan, it would",
        # Few shot prompt
        """Translate English to French:
        sea otter => loutre de mer
        peppermint => menthe poivrée
        plush girafe => girafe peluche
        cheese =>""",
        # Zero shot prompt
        """Tell me if the following person is actually Doraemon disguised as human:
        Name: Umar Jamil
        Decision: """
    ]

    model = LLaMA.build(
        checkpoints_dir='llama-2-7b/',
        tokenizer_path='tokenizer.model',
        load_model=True,
        max_seq_len=1024,
        max_batch_size=len(prompts),
        device=device
    )

    out_tokens, out_texts = (model.text_completion(prompts, max_gen_len=64))
    assert len(out_texts) == len(prompts)
    for i in range(len(out_texts)):
        print(f'{out_texts[i]}')
        print('-' * 50)
The script sets up the environment, initializes the LLaMA model, generates text completions for a set of input prompts, and prints the resulting texts. The key components involve configuring the device, defining input prompts, building the model, and handling the output generation and display.
Download the download.sh file from my GitHub to download the weights of the LLaMA models.