Memory Layers for More-Factual Output
Meta researchers build Llama-style models that recall details without needing more computing resources

Dual line graphs showing factual QA accuracy and NLL against memory size for NQ and TQA datasets in AI models.

Improving a large language model’s factual accuracy typically requires making it bigger, which, in turn, involves more computation. Researchers devised an architecture that enables models to recall relevant details without significantly increasing the amount of computation required.

What’s new: Vincent-Pierre Berges, Barlas Oğuz, and colleagues at Meta augmented transformers with trainable memory layers that efficiently store and retrieve information related to a prompt. The training code is available under a CC BY-NC license, which permits noncommercial uses.

Memory layer basics: Memory layers were introduced in 2015 and were applied to transformers a few years later. They hold vectors, which may capture details like names or dates that were learned through training, and retrieve them according to a given input. Computing the output of a memory layer is similar to computing that of a self-attention layer. Both involve vectors that represent queries, keys, and values, and both compute the similarity between queries and keys and then weight the values by that similarity. However, while a self-attention layer computes queries, keys, and values from linear transformations of the input, a memory layer (which computes queries the same way) learns a set of keys, and a corresponding value for each key, through training.
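
The comparison with self-attention can be made concrete with a toy sketch. This is not the authors' code: the function name memory_layer, the tiny hand-picked parameters, and the use of a softmax to turn similarities into weights are all assumptions for illustration. The point is only that the query comes from the input, while the keys and values are fixed tables that would be learned in training.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    # Subtract the maximum for numerical stability.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def memory_layer(x, w_query, keys, values):
    """Toy memory lookup (hypothetical helper, not the paper's implementation)."""
    # The query is a linear transformation of the input, as in self-attention.
    query = [dot(row, x) for row in w_query]
    # Keys are a learned table, not computed from the input.
    weights = softmax([dot(query, key) for key in keys])  # softmax is an assumption
    # Output: the learned values, weighted by similarity.
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]

# Hand-picked toy parameters; in a real layer these would be trained.
w_query = [[1.0, 0.0], [0.0, 1.0]]
keys = [[10.0, 0.0], [0.0, 10.0]]
values = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
output = memory_layer([1.0, 0.0], w_query, keys, values)
```

Here the input matches the first key far better than the second, so the output is dominated by the first value.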

Key insight: Memory layers can be scaled to millions of keys, but computing the similarity between a query and so many keys is computationally expensive. One solution is to represent each key as a combination of two half-keys drawn from two much smaller sets. For example, two sets of 1,000 half-keys each can represent 1 million possible keys. Comparing a query to these smaller sets is much more efficient, making it practical to scale up memory layers dramatically.
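
A small sketch of why the half-key trick is cheap, using the 1,000-by-1,000 example above. The half-key size of 4 and the random vectors are assumptions chosen for illustration. It relies on the fact that the dot product between a query and a concatenated key equals the sum of the two half dot products, so full-key scores never need to be computed against all 1 million keys.

```python
import random

random.seed(0)
half_dim = 4    # hypothetical half-key size
n_half = 1000   # half-keys per set, as in the example above

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def random_vector(size):
    return [random.gauss(0, 1) for _ in range(size)]

set_a = [random_vector(half_dim) for _ in range(n_half)]
set_b = [random_vector(half_dim) for _ in range(n_half)]
query = random_vector(2 * half_dim)

# Full keys the two sets can express vs. half-key comparisons per query.
num_full_keys = n_half * n_half          # 1,000,000
num_half_comparisons = n_half + n_half   # 2,000

# Scoring one concatenated key directly...
i, j = 123, 456
full_score = dot(query, set_a[i] + set_b[j])
# ...matches the sum of the two half-query/half-key scores.
split_score = dot(query[:half_dim], set_a[i]) + dot(query[half_dim:], set_b[j])
```

The two scores agree up to floating-point rounding, which is what lets the layer score 2,000 half-keys instead of 1 million full keys.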

How it works: The authors pretrained Llama-style models of several sizes (from 134 million to 8 billion parameters) on data similar to Llama 2’s and Llama 3’s pretraining datasets. They replaced the fully connected layers with memory layers in three transformer blocks. These layers shared parameters and held up to 16 million values (an extra 64 billion parameters total). The memory layers performed these steps:

  • Given a query (a prompt that has been embedded by preceding transformer layers), split it into two vectors half the size.
  • Compute similarity scores between each half-query and each half-key in the corresponding set of half-keys (there are two sets, one per half-query). In each set, identify the k highest-scoring half-keys.
  • Concatenate the highest-scoring half-keys, pairing each from the first set with each from the second, to produce k² full keys.
  • Sum the similarity scores of the two half-keys that make up each full key. Choose the k highest-scoring full keys.
  • Compute the index of each full key based on the indices of the corresponding half-keys.
  • Retrieve the values that correspond to the full keys.
  • Output the summed values weighted by the similarity scores.
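
The steps above can be sketched for a single query in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: the function name product_key_lookup, the row-major index formula, the softmax over the selected scores, and the toy half-keys and values are all hypothetical, and details such as batching are left out.

```python
import heapq
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def top_k(scores, k):
    """Return (score, index) pairs for the k highest scores."""
    return heapq.nlargest(k, ((s, i) for i, s in enumerate(scores)))

def product_key_lookup(query, half_keys_a, half_keys_b, values, k):
    # Split the query into two half-queries.
    half = len(query) // 2
    q_a, q_b = query[:half], query[half:]
    # Score each half-query against its own set and keep the top k of each.
    best_a = top_k([dot(q_a, key) for key in half_keys_a], k)
    best_b = top_k([dot(q_b, key) for key in half_keys_b], k)
    # Form k * k candidate full keys. Each score is the sum of two half scores;
    # each index is the row-major position in the implicit grid of full keys
    # (an assumed convention).
    n_b = len(half_keys_b)
    candidates = [(sa + sb, ia * n_b + ib) for sa, ia in best_a for sb, ib in best_b]
    best = heapq.nlargest(k, candidates)
    # Turn the selected scores into weights (softmax is an assumption).
    top = max(score for score, _ in best)
    exps = [math.exp(score - top) for score, _ in best]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Retrieve the matching values and return their weighted sum.
    dim = len(values[0])
    output = [sum(w * values[idx][d] for w, (_, idx) in zip(weights, best))
              for d in range(dim)]
    return output, [idx for _, idx in best]

# Toy data: 3 half-keys per set give 9 implicit full keys, one value each.
half_keys_a = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
half_keys_b = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
values = [[float(i)] for i in range(9)]
query = [2.0, 0.0, 0.0, 3.0]
output, indices = product_key_lookup(query, half_keys_a, half_keys_b, values, k=2)
# indices -> [1, 7]
```

In this toy case only 6 half-key comparisons and 4 candidate sums are needed, yet the selected indices are the same two that scoring all 9 full keys would pick.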

Results: The authors compared a model (8 billion parameters) with memory layers to a similar model without memory layers, both trained on 1 trillion tokens.

  • They used nine question-answering datasets for evaluation. The model with memory layers achieved higher performance on seven of them. For example, on MMLU, the memory model achieved 63.04 percent accuracy, while the unmodified transformer achieved 59.68 percent accuracy. 
  • In general, the memory model performed worse than Llama 3.1 8B trained on 15 trillion tokens. For example, Llama 3.1 8B achieved 66 percent accuracy on MMLU.

Why it matters: Memory layers didn’t catch on in the early days of large language models (LLMs), but they can improve the output of today’s much bigger models. LLMs outfitted with memory layers require less data and computation for pretraining than conventional models to achieve the same result, at least with respect to answering factual questions.

We’re thinking: While retrieval-augmented generation can help LLMs deliver more-factual output by retrieving facts from a database, the authors add trainable parameters for this purpose.
