Large Language Models (LLMs) represent a key category within Generative AI: an advanced class of language models capable of generating coherent text, writing code, performing abstractive summarization, translating languages, and executing a wide range of natural language processing (NLP) tasks. These models have significantly transformed and modernized the NLP landscape, powering intelligent assistants such as OpenAI's ChatGPT, Google's Gemini, and DeepSeek's R1, as well as large-scale model families such as NVIDIA's Megatron. They play a critical role in augmenting human productivity across diverse applications.
The foundation of modern LLMs lies in the Transformer architecture, first introduced in the seminal 2017 paper "Attention Is All You Need" [1]. This architecture enabled highly parallelized training and long-range context modeling, catalyzing the rapid advancement of LLM capabilities. Since then, Transformer-based models have evolved into various architectural forms and training paradigms, leveraging vast corpora of internet-scale text data. These models are trained using a combination of pretraining and fine-tuning strategies, improving generalization across a wide array of language understanding and generation tasks.
A significant innovation arising from LLM development is the emergence of AI agents: systems that integrate LLMs with external tools and memory modules to autonomously perform complex tasks alongside humans [2]. At the core of many recent advancements in LLM efficiency and scalability is a specialized architecture module known as Mixture of Experts (MoE).
Although popularized by Mistral's Mixtral of Experts model in 2024 [3], the theoretical underpinnings of MoE date back to 1991 [4], when it was first proposed as a framework for modular and sparse computation in neural networks. The MoE paradigm gained renewed interest with the increasing availability of high-performance computing resources and large-scale, high-quality text datasets. In 2022, Google published the Switch Transformer [5], which further demonstrated the scalability benefits of sparse expert models. These architectures activate only a subset of model parameters during inference, dramatically improving computational efficiency while preserving model capacity.
The release of the Mixtral 8x7B [3] model reignited enthusiasm in the AI community for MoE-based architectures, showcasing how expert sparsity can be effectively utilized to scale models while maintaining or improving performance. In the following sections, we will delve deeper into the architecture and mechanisms behind Mixture of Experts, examining how it enables efficient scaling and robust performance in state-of-the-art language models.
So, what is MoE?
Mixture of Experts (MoE) is an ensemble learning paradigm introduced to scale language models for handling large-scale data and to achieve higher accuracy. The main idea is to train multiple specialized models (experts) that learn different subsets of the data, coordinated by a gating mechanism. This approach improves overall accuracy and efficiency by leveraging the strengths of each specialized model.
The key difference between a vanilla Transformer and an MoE-based Transformer lies in the feedforward layer: a vanilla Transformer uses a single dense feedforward network whose parameters are all used for every input, whereas an MoE-based Transformer uses multiple feedforward networks (experts), each of which learns from a different part of the data.
Sparse Mixture of Experts (Sparse MoE) is a type of MoE, where only a subset of the model's experts, or specialized sub-models, are active for each input. This contrasts with dense MoEs, where all experts are used for all inputs. Sparse MoE reduces computational cost and memory usage by selectively activating a smaller number of experts based on the input, making it more efficient for scaling up large models.
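To make the contrast concrete, here is a minimal, self-contained sketch (purely illustrative, not the LaMoE code shown later in this post) of routing each token to only its top-2 experts:

```python
import torch
import torch.nn.functional as F

# Toy illustration: 8 experts, route each token to its top-2 experts only.
num_experts, k, dim = 8, 2, 16
x = torch.randn(4, dim)                       # 4 tokens
gate = torch.nn.Linear(dim, num_experts)      # router producing expert logits
experts = [torch.nn.Linear(dim, dim) for _ in range(num_experts)]

logits = gate(x)                              # (4, num_experts)
topk_vals, topk_idx = logits.topk(k, dim = -1)  # keep only the k best experts per token
weights = F.softmax(topk_vals, dim = -1)        # renormalize over the selected experts

out = torch.zeros_like(x)
for token in range(x.size(0)):
    for slot in range(k):
        e = topk_idx[token, slot].item()
        out[token] += weights[token, slot] * experts[e](x[token])
# A dense MoE would instead run all 8 experts for every token and mix their outputs.
```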
LaMoE is a from-scratch implementation of a sparse MoE-based decoder model tailored for text generation tasks. To gain a deeper understanding of the operational dynamics of Mixture-of-Experts (MoE) architectures, I integrated MoE modules directly into the LLaMA-2 [6] model by replacing its standard feed-forward network (FFN) components. This approach allowed me to study the behavior and performance of MoE within the context of a state-of-the-art transformer model, rather than treating it as an isolated component.
As part of this project, I used open datasets to train LaMoE.
LaMoE is built with a robust set of tools and libraries:
Tokens are the basic units that large language models operate on; converting text into tokens allows transformer models to understand and process language effectively.
Subword-based tokenization is a bridge between word-based and character-based tokenization. The main idea is to solve the issues faced by word-based tokenization (very large vocabulary size, a large number of out-of-vocabulary (OOV) tokens, and different meanings for very similar words) and character-based tokenization (very long sequences and less meaningful individual tokens). For example, a rare word like "tokenization" can be split into the more frequent subwords "token" and "ization".
Byte-Pair Encoding (BPE) is one of the most popular subword tokenization algorithms. It was first introduced in 1994 [7] as a simple data compression technique that iteratively replaces the most frequent pair of bytes in a sequence with a single, unused byte. It has since been adapted as a tokenization algorithm and is used in many language models, such as GPT and LLaMA.
The snippet below shows a custom implementation of BPE.
```python
import os
import json
from typing import Optional

# create_word_freq_dict, compute_pair_freqs and merge_pair are helper
# functions defined elsewhere in the repository.

class BPE:

    def create_tokenizer(self, Text: list, vocab_size: int = 30_000, max_iterations: int = 100,
                         tokenizer_file_name: Optional[str] = "Tokenizer", pad_token: Optional[str] = "<pad>",
                         special_tokens: Optional[list] = []) -> dict:
        r"""
        Function to create a byte-pair encoding tokenizer in json.

        Parameters:
            Text (list): List of texts
            vocab_size (int): Maximum size of vocabulary (default: 30000)
            max_iterations (int): Maximum number of iterations to refine the vocabulary (default: 100)
            tokenizer_file_name (str): Name of the json to be saved (default: "Tokenizer")
            pad_token (str): Pad token (default: "<pad>")
            special_tokens (list): List of special tokens to be added to the vocabulary (default: [])
        """
        word_freq, vocab_count_dic = create_word_freq_dict(Text)
        splits = {word: word.split() for word in word_freq.keys()}  # 1. Initial split of every word into symbols
        pair_freqs = compute_pair_freqs(word_freq, splits)
        merges = {}
        i = 0
        while i < max_iterations:  # For max iterations
            pair_freqs = compute_pair_freqs(word_freq, splits)  # 2. Count frequencies of adjacent symbol pairs
            if not pair_freqs:
                break
            best_pair = max(pair_freqs, key = pair_freqs.get)  # 3. Pick the most frequent pair
            key = best_pair[0] + ":" + best_pair[1]
            merges[key] = best_pair[0] + best_pair[1]  # 4. Record the merge rule
            # 5. Update vocabulary counts for the merged symbol and its parts
            max_freq = pair_freqs[best_pair]
            vocab_count_dic[best_pair[0] + best_pair[1]] = max_freq
            vocab_count_dic[best_pair[0]] -= max_freq
            vocab_count_dic[best_pair[1]] -= max_freq
            if vocab_count_dic[best_pair[0]] == 0:
                vocab_count_dic.pop(best_pair[0])
            if vocab_count_dic[best_pair[1]] == 0:
                vocab_count_dic.pop(best_pair[1])
            splits = merge_pair(*best_pair, word_freq, splits)
            if len(vocab_count_dic) == vocab_size:
                break
            i += 1
        vocab = list(vocab_count_dic.keys())
        # Creating final vocabulary with <unk>, special and pad tokens.
        vocab.append("<unk>")
        vocab = vocab + special_tokens
        vocab.sort()
        vocab = [pad_token] + vocab
        vocab_dict = {key: i for i, key in enumerate(vocab)}
        Tokenizer = {"vocab_dict": vocab_dict, "merges": merges, "max_iter": max_iterations}
        with open(os.path.join('Saved', f'{tokenizer_file_name}.json'), 'w') as f:
            json.dump(Tokenizer, f)
        print(f"{os.path.join('Saved', f'{tokenizer_file_name}.json')} is successfully created.")
        del vocab
        del vocab_count_dic
        return Tokenizer
```
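A usage sketch (assuming the repository's helpers `create_word_freq_dict`, `compute_pair_freqs`, and `merge_pair` are on the import path, and that a `Saved/` directory exists for the output JSON) might look like this:

```python
# Hypothetical usage sketch; the corpus and parameter values are illustrative only.
corpus = [
    "low lower lowest",
    "new newer newest",
]

bpe = BPE()
tokenizer = bpe.create_tokenizer(
    Text = corpus,
    vocab_size = 200,
    max_iterations = 50,
    tokenizer_file_name = "Tokenizer",
    special_tokens = ["<eos>"],
)
# The resulting Saved/Tokenizer.json stores the vocabulary and the learned merge rules.
```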
This repository shows a brief implementation of BPE.
This figure shows the main architecture of the official LLaMA-2 model developed by Meta.
*Architecture of LLaMA (credits: Umar Jamil)*
Rotary Positional Embeddings (RoPE) are a form of relative positional encoding applied between two tokens, indicating the intensity of their relationship in terms of a distance parameter.

RoPE is applied only to the queries and the keys, not the values. It is applied after the vectors q and k have been multiplied by their respective W matrices in the attention mechanism.
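The attention code below calls the repository helpers `precompute_theta_pos_frequencies` and `apply_rotary_embeddings`. As a hedged sketch of how such helpers are commonly written (following the standard RoPE formulation, not necessarily the exact LaMoE code):

```python
import torch

def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    # Standard RoPE frequencies: theta_i = 10000^(-2i/d) for each pair of dimensions.
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim)).to(device)
    positions = torch.arange(seq_len, device = device).float()
    angles = torch.outer(positions, freqs)                # (seq_len, head_dim / 2)
    return torch.polar(torch.ones_like(angles), angles)   # complex numbers e^{i * angle}

def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    # x: (batch, seq_len, n_heads, head_dim) -> pair up the last dimension as complex numbers.
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs = freqs_complex.unsqueeze(0).unsqueeze(2)        # broadcast over batch and heads
    x_rotated = torch.view_as_real(x_complex * freqs)      # rotate each (pair of) dimensions by its angle
    return x_rotated.reshape(*x.shape).type_as(x).to(device)
```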
The Multi-Head Self-Attention block employs Grouped Multi-Query Attention (GQA), which provides a good compromise between quality and speed. The main objective of grouped multi-query attention is to minimize memory access/transfer on the GPU.
For inference, the attention mechanism uses the KV-cache technique. At every step of inference, we are only interested in the last token output by the model, because we already have the previous tokens. However, the model needs access to all the previous tokens to decide which token to output, since they constitute its context. The KV cache is a solution that lets the model avoid recomputing keys and values for tokens it has already seen during inference.
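As a minimal, standalone illustration of the idea (toy shapes only, not the LaMoE attention class, which pre-allocates its buffers from the model config):

```python
import torch

# Conceptual KV-cache sketch: buffers sized for the worst case, filled step by step.
max_batch, max_seq, n_kv_heads, head_dim = 2, 16, 4, 64
cache_k = torch.zeros(max_batch, max_seq, n_kv_heads, head_dim)
cache_v = torch.zeros(max_batch, max_seq, n_kv_heads, head_dim)

def step(k_new, v_new, start_pos):
    """Append this step's keys/values and return everything cached so far."""
    bsz, seq_len = k_new.shape[0], k_new.shape[1]   # seq_len == 1 during generation
    cache_k[:bsz, start_pos : start_pos + seq_len] = k_new
    cache_v[:bsz, start_pos : start_pos + seq_len] = v_new
    return cache_k[:bsz, : start_pos + seq_len], cache_v[:bsz, : start_pos + seq_len]

# At position t, attention only computes K/V for the new token;
# the previous t tokens are read back from the cache.
k_t = torch.randn(2, 1, n_kv_heads, head_dim)
v_t = torch.randn(2, 1, n_kv_heads, head_dim)
keys, values = step(k_t, v_t, start_pos = 5)        # keys: (2, 6, 4, 64)
```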
The following snippet shows the core of the MHA implementation.
```python
q = self.Wq(x)  # (batch_size, seq_len, head_dim * n_heads_q)
k = self.Wk(x)  # (batch_size, seq_len, head_dim * n_heads_kv)
v = self.Wv(x)  # (batch_size, seq_len, head_dim * n_heads_kv)

q = q.view(batch_size, seq_len, self.n_heads_q, self.head_dim)   # (batch_size, seq_len, n_heads_q, head_dim)
k = k.view(batch_size, seq_len, self.n_heads_kv, self.head_dim)  # (batch_size, seq_len, n_heads_kv, head_dim)
v = v.view(batch_size, seq_len, self.n_heads_kv, self.head_dim)  # (batch_size, seq_len, n_heads_kv, head_dim)

q = apply_rotary_embeddings(q, freqs_complex, x.device)  # (batch_size, seq_len, n_heads_q, head_dim)
k = apply_rotary_embeddings(k, freqs_complex, x.device)  # (batch_size, seq_len, n_heads_kv, head_dim)

if self.cache:
    assert start_pos is not None, "Start position is not given. Give the start position."
    self.cache_k[: batch_size, start_pos : start_pos + seq_len] = k
    self.cache_v[: batch_size, start_pos : start_pos + seq_len] = v
    keys = self.cache_k[: batch_size, : start_pos + seq_len]    # (batch_size, seq_len, n_heads_kv, head_dim)
    values = self.cache_v[: batch_size, : start_pos + seq_len]  # (batch_size, seq_len, n_heads_kv, head_dim)
else:
    keys, values = k, v

keys = repeat_kv(keys, self.repeat, 2)      # (batch_size, seq_len, n_heads, head_dim)
values = repeat_kv(values, self.repeat, 2)  # (batch_size, seq_len, n_heads, head_dim)

"""
Actual calculation:
query (batch_size, seq_len, n_heads, head_dim), keys (batch_size, kv_seq_len, n_heads, head_dim)
|| (Transpose)
query (batch_size, n_heads, seq_len, head_dim), keys (batch_size, n_heads, kv_seq_len, head_dim)
||
query (batch_size, n_heads, seq_len, head_dim), keys (batch_size, n_heads, head_dim, kv_seq_len)
|| (Matrix Multiplication, Softmax)
Attention scores (batch_size, n_heads, seq_len, kv_seq_len)
"""
attn_scores = torch.einsum("bshd, bthd -> bhst", q, keys) * self.softmax_scale  # (batch_size, n_heads, seq_len, kv_seq_len)
attn_scores = attn_scores.masked_fill(self.mask[:, :, : seq_len, : seq_len] == 0, float("-inf"))
attn_scores = F.softmax(attn_scores, dim = -1)

"""
Actual calculation:
values (batch_size, kv_seq_len, n_heads, head_dim)
|| (Transpose)
values (batch_size, n_heads, kv_seq_len, head_dim)
||
attn_scores (batch_size, n_heads, seq_len, kv_seq_len), values (batch_size, n_heads, kv_seq_len, head_dim)
|| (Matrix Multiplication)
Attention values (batch_size, n_heads, seq_len, head_dim)
|| (Transpose, reshape)
Attention values (batch_size, seq_len, n_heads * head_dim)
"""
out = torch.einsum("bhst, bthd -> bshd", attn_scores, values).contiguous().view(batch_size, seq_len, self.n_heads_q * self.head_dim)  # (batch_size, seq_len, n_heads * head_dim)
out = self.Wo(out)  # (batch_size, seq_len, dim)
```
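The snippet above also relies on a repository helper `repeat_kv`, which replicates the key/value heads so that every group of query heads has matching keys and values (the essence of grouped-query attention). A minimal sketch of such a helper, under the assumption that the head axis is dimension 2, could be:

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int, dim: int = 2) -> torch.Tensor:
    """Repeat the KV heads n_rep times along the head dimension (illustrative sketch)."""
    if n_rep == 1:
        return x
    return torch.repeat_interleave(x, repeats = n_rep, dim = dim)

# Example: 4 KV heads expanded to match 8 query heads (n_rep = 2).
k = torch.randn(1, 10, 4, 64)        # (batch, seq_len, n_heads_kv, head_dim)
print(repeat_kv(k, 2).shape)         # torch.Size([1, 10, 8, 64])
```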
Root Mean Square Normalization (RMSNorm) focuses on re-scaling invariance and regularizes the summed inputs simply according to the root mean square statistic.
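In its standard form (and matching the snippet below, where a small $\epsilon$ is added to the RMS statistic for numerical stability), the layer computes:

$$
\mathrm{RMSNorm}(x) = g \odot \frac{x}{\mathrm{RMS}(x) + \epsilon},
\qquad
\mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^{2}}
$$

where $g$ is a learnable gain vector of the same dimension as $x$.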
```python
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable gain g

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = (x / (torch.sqrt(x.pow(2).mean(-1, keepdim = True)) + self.eps)).type_as(x)
        return self.weight * norm
```
This is the key modification to the LLaMA architecture: the standard fully connected feedforward network of each decoder block is replaced with a Mixture-of-Experts layer.
An Expert is a FeedForward Neural Network with SwiGLU activation.
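Concretely, each expert computes the SwiGLU feed-forward transformation (matching the weights `w1`, `w2`, and `w3` in the code below):

$$
\mathrm{Expert}(x) = W_2\bigl(\mathrm{SiLU}(W_1 x) \odot W_3 x\bigr),
\qquad
\mathrm{SiLU}(z) = z \cdot \sigma(z)
$$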
```python
class Expert(nn.Module):

    def __init__(self, args: dataclass):
        super().__init__()
        r"""
        An Expert is a FeedForward Neural Network with SwiGLU activation.

        Args:
            args (dataclass): Model arguments.

        Returns:
            out (torch.Tensor): Output of expert.
        """
        self.w1 = nn.Linear(args.dim, args.ffn_hidden_dim, bias = False)
        self.w2 = nn.Linear(args.ffn_hidden_dim, args.dim, bias = False)
        self.w3 = nn.Linear(args.dim, args.ffn_hidden_dim, bias = False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch_size, seq_len, dim) -> (batch_size, seq_len, hidden_dim)
        x_W = self.w1(x)
        # (batch_size, seq_len, dim) -> (batch_size, seq_len, hidden_dim)
        x_V = self.w3(x)
        # (batch_size, seq_len, hidden_dim) * (batch_size, seq_len, hidden_dim) -> (batch_size, seq_len, hidden_dim)
        out = F.silu(x_W) * x_V
        # (batch_size, seq_len, hidden_dim) -> (batch_size, seq_len, dim)
        out = self.w2(out)
        return out
```
The router (or gating network) is itself a feedforward neural network, responsible for selecting the appropriate expert based on a given input. It produces a probability distribution over the available experts, which is then used to determine the most suitable expert for the input.
To introduce load balancing, trainable (Gaussian) noise is added in the router, which creates a sparse gate that routes tokens to the experts and prevents the same experts from always being picked.
```python
class NoisyTopkRouter(nn.Module):

    def __init__(self, args: dataclass) -> None:
        super(NoisyTopkRouter, self).__init__()
        r"""
        A noisy router that creates sparse gate to route the tokens to the experts.

        Args:
            args (dataclass): Model arguments

        Returns:
            out (tuple): Output of expert
        """
        self.topk = args.k
        self.gate = nn.Linear(args.dim, args.num_experts)
        self.noisy_gate = nn.Linear(args.dim, args.num_experts)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        gate_logits = self.gate(x)         # (batch_size, seq_len, num_experts)
        noisy_logits = self.noisy_gate(x)  # (batch_size, seq_len, num_experts)
        noise = torch.randn_like(noisy_logits) * F.softplus(noisy_logits)
        noisy_logits = noise + gate_logits
        topklogits, topexperts = noisy_logits.topk(self.topk, dim = -1)  # (batch_size, seq_len, topk)
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, topexperts, topklogits)
        router_output = F.softmax(sparse_logits, dim = -1)  # (batch_size, seq_len, num_experts)
        return router_output, topexperts
```
This block includes a noisy gating mechanism and a set of expert networks, wherein input tokens are dynamically routed to the top-k experts based on the routerβs output.
It also includes an auxiliary loss (also called load-balancing loss), another method for introducing load balancing, which is added to the network's regular loss. It adds a constraint that encourages the experts to have equal importance. The first component of this auxiliary loss involves summing the router outputs for each expert across the entire batch. This yields an importance score for each expert, reflecting the overall likelihood of that expert being selected independent of specific input instances. These importance scores are then used to compute the (squared) coefficient of variation (CV), which quantifies the dispersion or imbalance in expert utilization.
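Schematically (a simplified view of what the code below computes), let $G(x)_e$ denote the router output for expert $e$ on token $x$, $\mathcal{B}$ the tokens in the batch, and $\lambda$ the coefficient `aux_loss_coeff`:

$$
\mathrm{Imp}(e) = \sum_{x \in \mathcal{B}} G(x)_e,
\qquad
\mathcal{L}_{\mathrm{aux}} = \lambda \cdot \frac{\mathrm{Var}(\mathrm{Imp})}{\mathrm{Mean}(\mathrm{Imp})^{2}}
$$

that is, the squared coefficient of variation of the importance scores, scaled by $\lambda$.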
```python
class MoE(nn.Module):

    def __init__(self, args: dataclass) -> None:
        super(MoE, self).__init__()
        self.topk = args.k
        self.router = NoisyTopkRouter(args)
        self.experts = nn.ModuleList([Expert(args) for _ in range(args.num_experts)])
        self.aux_loss = args.aux_loss
        self.aux_loss_coeff = args.aux_loss_coeff

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_scores, top_experts = self.router(x)
        out = torch.zeros_like(x)

        # Reshape inputs for batch processing
        flat_x = x.view(-1, x.size(-1))
        flat_gate_score = gate_scores.view(-1, gate_scores.size(-1))

        # Process each expert in parallel
        for i, expert in enumerate(self.experts):
            # Create a mask for the inputs where the current expert is in top-k
            expert_mask = (top_experts == i).any(dim = -1)
            flat_expert_mask = expert_mask.view(-1)

            if flat_expert_mask.any():
                expert_input = flat_x[flat_expert_mask]
                expert_output = expert(expert_input)

                # Extract and apply gating scores
                gating_scores = flat_gate_score[flat_expert_mask, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores

                # Update final output additively by indexing and adding
                out[expert_mask] += weighted_output.squeeze(1)

        # Auxiliary Loss
        if self.aux_loss:
            imp = gate_scores.sum(1)
            cv = imp.var() / (imp.mean() ** 2)
            cv *= self.aux_loss_coeff
            return out, cv
        else:
            return out, None
```
This constitutes the main decoder block, which integrates Multi-Head Attention (MHA) mechanisms alongside Mixture of Experts (MoE) layers.
```python
class Block(nn.Module):

    def __init__(self, args: ModelArgs) -> None:
        super(Block, self).__init__()
        r"""
        Individual Decoder layer.

        Args:
            args (ModelArgs): Model arguments

        Returns:
            out (tuple): Output tensor from layer, auxloss of MoE
        """
        self.attention = MHA(args)
        self.moe = MoE(args)
        self.attention_norm = RMSNorm(args.dim, args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, args.norm_eps)

    def forward(self, x: torch.Tensor, freqs_complex: torch.Tensor, start_pos: Optional[int] = None) -> torch.Tensor:
        residue = x
        out = self.attention_norm(x)  # RMS normalization before attention
        out = residue = self.attention(out, freqs_complex, start_pos)[0] + residue  # MHA and Addition
        out = self.ffn_norm(out)      # RMS normalization before feedforward neural network
        out, loss = self.moe(out)     # Feedforward neural network (MoE)
        out = out + residue           # Addition
        return out, loss
```
This illustrates the full model stack of LaMoE, encompassing all constituent layers and components.
```python
class Transformer(nn.Module):

    def __init__(self, args: ModelArgs) -> None:
        super(Transformer, self).__init__()
        r"""
        Sparse-MoE based Decoder.

        Args:
            args (ModelArgs): Model arguments

        Returns:
            out (tuple): Output tensor from decoder, total auxloss of MoE
        """
        assert args.vocab_size != -1, "Vocab size must be set."

        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(Block(args))
        self.norm = RMSNorm(args.dim, args.norm_eps)
        self.output = nn.Linear(args.dim, args.vocab_size, bias = False)
        self.head_dim = args.dim // args.n_heads
        self.freqs_complex = precompute_theta_pos_frequencies(self.head_dim, args.max_seq_length * 2, device = args.device)
        self.aux = args.aux_loss
        self.inference = args.inference

    def forward(self, x: torch.Tensor, start_pos: Optional[int] = None) -> torch.Tensor:
        batch_size, seq_len = x.shape
        out = self.tok_embeddings(x)
        aux_loss = []

        if self.inference:  # During Inference
            assert start_pos is not None, "Start position is not given. Give the start position."
            assert seq_len == 1, "Only one token will be processed during inference."
            freqs_complex = self.freqs_complex[start_pos : start_pos + seq_len]
        else:  # During Training
            seq_lens = [x[i].size(0) for i in range(batch_size)]
            seq_lens_metadata = [SimpleInputMetadata.from_seqlens(seq_lens, x.device) for i in range(len(self.layers))]
            positions = seq_lens_metadata[0].positions
            freqs_complex = self.freqs_complex[positions]

        for layer in self.layers:
            out, loss = layer(out, freqs_complex, start_pos)
            if self.aux and loss is not None:
                aux_loss.append(loss)

        out = self.norm(out)
        out = self.output(out).float()

        if self.aux:  # During training
            total_aux_loss = sum(aux_loss) / len(aux_loss)
            return out, total_aux_loss
        else:  # During inference
            return out, None
```
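For completeness, here is a hedged sketch of how such a decoder is typically driven at inference time, feeding one token per step with `start_pos` so the KV cache carries the context. The repository's `inference.py` (with its tokenizer handling) is the authoritative version; `generate` below is an illustrative helper, not a repository function:

```python
import torch

@torch.no_grad()
def generate(model, prompt_ids: list, max_new_tokens: int = 50, eos_id: int = None):
    """Greedy decoding sketch. Assumes the model was built with args.inference = True
    (KV cache enabled) and that prompt_ids is non-empty."""
    device = next(model.parameters()).device
    tokens = list(prompt_ids)
    logits = None
    for pos, tok in enumerate(tokens):                             # warm up the KV cache with the prompt
        logits, _ = model(torch.tensor([[tok]], device = device), start_pos = pos)
    for _ in range(max_new_tokens):
        next_id = int(logits[:, -1].argmax(dim = -1))              # greedy pick of the next token
        if eos_id is not None and next_id == eos_id:
            break
        tokens.append(next_id)
        logits, _ = model(torch.tensor([[next_id]], device = device), start_pos = len(tokens) - 1)
    return tokens
```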
The project's GitHub repository provides clear instructions for setting up and experimenting with LaMoE:
Clone this repository.
```bash
git clone https://github.com/harishhirthi/LaMoE.git
cd LaMoE
```
Create the conda environment using environment.yml and activate it.
```bash
conda env create -f environment.yml
```
To train LaMoE using the script:
```bash
cd scripts
python train_eval.py
```
For inference using the script:
```bash
cd scripts
python inference.py
```
Ensure that you have the necessary datasets in the `data/` directory, or modify the script to point to your dataset.
Below are sample screenshots of training and inference using the scripts.
Configuration of the model used for this experiment:
```python
from dataclasses import dataclass, field
from typing import Optional

# get_default_device() is a helper defined in the repository.

@dataclass
class ModelArgs:
    dim: int = 512
    ffn_hidden_dim: int = 4 * dim
    n_layers: int = 4
    n_heads: int = 8
    n_kv_heads: Optional[int] = 4
    vocab_size: int = -1
    norm_eps: float = 1e-5
    num_experts: int = 8
    k: int = 2
    eos: str = "<eos>"
    pad: str = "<pad>"
    unk: str = "<unk>"
    aux_loss: Optional[bool] = True
    aux_loss_coeff: Optional[float] = 1e-2
    inference: Optional[bool] = False
    cache: bool = field(init = False)
    max_batch_size: int = 32
    max_seq_length: int = 300
    device: str = get_default_device()

    def __post_init__(self):
        self.cache = True if self.inference else False
```
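As a quick sanity check, here is a sketch of instantiating the model with this config and running a dummy batch in training mode (assuming the classes above and repository helpers such as `SimpleInputMetadata` and `get_default_device` are importable):

```python
import torch

args = ModelArgs()
args.vocab_size = 30_000              # must be set explicitly (defaults to -1)
model = Transformer(args).to(args.device)

# Dummy batch of token ids: (batch_size, seq_len)
dummy = torch.randint(0, args.vocab_size, (2, 64), device = args.device)
logits, aux_loss = model(dummy)       # training mode: aux_loss is the averaged load-balancing loss
print(logits.shape)                   # expected: torch.Size([2, 64, 30000])
```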
Training:
*Training using CLI*
Inference:
*Inference using CLI*
Loss:
*Visualization of loss in MLflow*
User: Science and Technology are
Generating Text , ...
Model:
Science and Technology are also used in the United States customary landowners and the Netherlands as well as the number of new classes dropped markedly with only a year with just two screens including Germany and Ireland in their first half of the year for the first time in the UK for the first time since the summer of 2003 the World Cup is held in
User: Once, upon a time
Generating Text , ...
Model:
Once upon a time scale it is a great number of times it is a positive correlation between the two points at which the observer is the time derivative of the electric field E is related to the field strength at the point where the permittivity is measured in radians per meter.
User: Physics
Generating Text , ....
Model:
Physics for example the second is the 367th greatest single quarter of the same period of 629 seconds long in 1572 seconds.
User: exit
Note: The sample output above is the result of an initial training run of 1,000 iterations, without any additional strategies such as early stopping or a learning-rate scheduler. Also, this training setup begins to overfit from around 300 iterations, which still has to be addressed. The main motive of this implementation is to get a glimpse of processing a large corpus of text and using it to train a language model scaled with the MoE architecture.
You can explore the code, and experiment with the model on the official GitHub repository:
LaMoE GitHub Repository
Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is All You Need. Advances in Neural Information Processing Systems, 30. https://arxiv.org/abs/1706.03762
Yao, S., Zhao, J., Yu, D., Du, N., Shafran, I., Narasimhan, K., & Cao, Y. (2023). ReAct: Synergizing reasoning and acting in language models. arXiv. https://arxiv.org/abs/2210.03629
Mistral AI. (2024). Mixtral of Experts: A Sparse Mixture of Experts Model. https://mistral.ai/news/mixtral-of-experts/
Jacobs, R. A., Jordan, M. I., Nowlan, S. J., & Hinton, G. E. (1991). Adaptive Mixtures of Local Experts. Neural Computation, 3(1), 79β87. https://doi.org/10.1162/neco.1991.3.1.79
Fedus, W., Zoph, B., & Shazeer, N. (2022). Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. Journal of Machine Learning Research, 23(120), 1–39. https://arxiv.org/abs/2101.03961
Touvron, H., Lavril, T., Izacard, G., Martinet, X., Lachaux, M.-A., Lacroix, T., ... & Lample, G. (2023). LLaMA: Open and Efficient Foundation Language Models. arXiv. https://arxiv.org/abs/2302.13971
Gage, P. (1994). A new algorithm for data compression. C Users Journal, 12(2), 23β38. https://dl.acm.org/doi/10.5555/177910.177914