Better Than Trees for Tabular Data
Transformers can outperform decision trees at predicting unlabeled spreadsheet cells

TabPFN neural network diagram showing synthetic training, prediction on real-world tabular data, and attention layers.

If you have a collection of variables that represent, say, a cancer patient and you want to classify the patient’s illness as likely cancer or not, algorithms based on decision trees, such as gradient-boosted trees, typically perform better than neural networks. A transformer tailored to tabular data could change this situation.

What’s new: Noah Hollmann, Samuel Müller, and colleagues at University of Freiburg, Berlin Institute of Health, Prior Labs, and ELLIS Institute introduced Tabular Prior-data Fitted Network (TabPFN), a transformer that, given a tabular dataset, beats established decision-tree methods on classification and regression tasks. You can download the code and weights under a license based on Apache 2.0 that allows noncommercial and commercial uses.

Key insight: In a typical supervised learning process, a model given one example at a time learns to recognize patterns in a dataset. If each example is an entire dataset, it learns to recognize patterns across all those datasets. Trained in this way on enough datasets, it can generalize to new ones. Applying this idea to tabular data, a transformer — unlike a decision tree — can learn to perform classification and regression on any dataset without further training; that is, without further updating the model weights.
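
To make the idea of predicting without weight updates concrete, here is a minimal sketch. It is a toy stand-in, not TabPFN: a single fixed attention step in which the unlabeled row attends to the labeled rows and takes a weighted vote over their labels. The function name, the distance-based attention score, and the tiny tables are illustrative assumptions.

```python
import math


def in_context_predict(context_rows, context_labels, query_row, temperature=1.0):
    """Return {label: probability} for one unlabeled row, in a single pass.

    Toy stand-in for a pretrained model (an assumption, not TabPFN itself):
    nothing is fitted to the incoming table, so a new dataset needs no
    weight updates.
    """
    # Attention score of the query against every labeled row:
    # negative squared Euclidean distance, scaled by a fixed temperature.
    scores = [
        -sum((q - x) ** 2 for q, x in zip(query_row, row)) / temperature
        for row in context_rows
    ]
    # Softmax over the labeled rows (shift by the max for numerical stability).
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    # Attention-weighted vote over the labels of the context rows.
    probs = {}
    for w, label in zip(weights, context_labels):
        probs[label] = probs.get(label, 0.0) + w / total
    return probs


# Two unrelated, made-up tables handled by the same function, no training step.
patients = [[-2.0], [-1.0], [1.0], [2.0]]
diagnoses = ["benign", "benign", "malignant", "malignant"]
probs_a = in_context_predict(patients, diagnoses, [1.5])

houses = [[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0]]
price_band = ["low", "low", "high", "high"]
probs_b = in_context_predict(houses, price_band, [0.2, 0.5])

print(max(probs_a, key=probs_a.get), max(probs_b, key=probs_b.get))
```

The labeled rows play the role a prompt plays for a language model: they are input to a forward pass rather than training data. In a pretrained transformer, the attention weights would be learned across many datasets instead of fixed by hand as they are here.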

How it works: The authors generated 100 million datasets and used them to pretrain two small transformers (around 7 million and 11 million parameters, respectively) to perform classification or regression. Given a dataset of rows (say, patient data labeled with diagnoses or real-estate data labeled with prices) and one final row that’s unlabeled, the models learned to generate the missing label or value. Each dataset consisted of up to 2,048 rows (examples) and up to 160 columns (features).

  • To generate a dataset, the authors sampled hyperparameters, such as the number of rows and columns, and produced a graph in which each node is a potential column, and each edge describes how one column is related to another mathematically. They sampled the mathematical relationships randomly; for example, one column might be the sum of a second column with the sine of a third. They selected a subset of nodes at random, creating columns, and propagated random noise through them to fill the columns with values. To simulate real-world imperfections, they removed some values and added noise at random.
  • The authors modified the transformer’s attention mechanism. Where a typical transformer block contains an attention layer and a fully connected layer, the authors included a feature attention layer (in which each cell attended to the other cells in its row, that is, the other features of the same example), an example attention layer (in which each cell attended to the other cells in its column, that is, the same feature in other examples), and a fully connected layer.
  • The authors trained the model to estimate the missing label in each synthetic dataset. At inference, given a dataset (with labels) and an unlabeled example, the model predicted the label.
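
The dataset-generation step can be sketched roughly as follows. This is a simplified illustration under stated assumptions, not the authors' generator: the function names, the particular set of edge functions, the small size limits, the choice of one sampled node as the target, and the 5 percent missing-value rate are all hypothetical.

```python
import math
import random

# Candidate per-edge functions. The article mentions sums and sines;
# the others are illustrative assumptions.
EDGE_FUNCS = [math.sin, math.tanh, abs, lambda v: v]


def sample_graph(n_nodes, rng):
    """Random directed acyclic graph: each non-root node reads 1-2 earlier nodes."""
    graph = []
    for node in range(n_nodes):
        n_parents = min(node, rng.randint(1, 2)) if node > 0 else 0
        parents = rng.sample(range(node), n_parents)
        funcs = [rng.choice(EDGE_FUNCS) for _ in parents]
        graph.append((parents, funcs))
    return graph


def sample_row(graph, rng, noise_scale=0.1):
    """Propagate random noise through the graph to get one value per node."""
    values = []
    for parents, funcs in graph:
        if not parents:
            values.append(rng.gauss(0.0, 1.0))  # root node: pure noise
        else:
            # Sum of transformed parent values, plus a little added noise.
            signal = sum(f(values[p]) for f, p in zip(funcs, parents))
            values.append(signal + rng.gauss(0.0, noise_scale))
    return values


def make_synthetic_table(rng, max_rows=64, max_cols=6, missing_rate=0.05):
    """Return (X, y): feature rows with some missing cells, and a target column."""
    # Sample hyperparameters (kept tiny here; the article cites up to
    # 2,048 rows and 160 columns).
    n_rows = rng.randint(8, max_rows)
    n_cols = rng.randint(2, max_cols)
    n_nodes = n_cols + rng.randint(1, 4)  # more nodes than visible columns
    graph = sample_graph(n_nodes, rng)
    # Pick a random subset of nodes as columns; the last pick is the target.
    *feature_nodes, target_node = rng.sample(range(n_nodes), n_cols + 1)
    X, y = [], []
    for _ in range(n_rows):
        values = sample_row(graph, rng)
        row = [values[i] for i in feature_nodes]
        # Simulate real-world imperfections: blank out some cells at random.
        row = [None if rng.random() < missing_rate else v for v in row]
        X.append(row)
        y.append(values[target_node])
    return X, y


rng = random.Random(0)
X, y = make_synthetic_table(rng)
print(len(X), "rows,", len(X[0]), "feature columns")
```

Calling `make_synthetic_table` repeatedly with fresh random draws yields tables with different sizes and different column relationships, which is the property that lets a model pretrained on them see many kinds of structure.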

Results: The authors tested the system on 29 classification datasets and 28 regression datasets from the AutoML benchmark and OpenML-CTR23. Each dataset contained up to 10,000 rows, 500 columns, and 10 classes. They compared TabPFN to the popular gradient-boosted tree approaches CatBoost, LightGBM, and XGBoost.

  • To evaluate classification, the authors measured area under the curve (AUC, higher is better) and normalized the scores across the datasets to range from 0 (worst) to 1 (best). TabPFN performed best across the datasets tested, achieving an average 0.939 normalized AUC, while the best contender, CatBoost, achieved an average 0.752 normalized AUC.
  • To evaluate regression, the authors measured root mean squared error (RMSE, lower is better). They normalized the resulting scores to range from 0 (worst) to 1 (best). TabPFN achieved 0.923 normalized RMSE, while the next-best method, CatBoost, achieved 0.872 normalized RMSE.
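
The article does not spell out the normalization, so the following sketch assumes one common scheme: on each dataset, min-max scale the scores across the compared methods so the best method gets 1 and the worst gets 0, then average over datasets. The function names, method names, and numbers are placeholders, not results from the paper.

```python
def normalize_per_dataset(scores, higher_is_better=True):
    """Min-max scale one dataset's scores across methods to [0, 1], with 1 = best."""
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        return {method: 1.0 for method in scores}  # all methods tied
    scaled = {m: (s - lo) / (hi - lo) for m, s in scores.items()}
    if not higher_is_better:
        # For error metrics such as RMSE, flip so the lowest error maps to 1.
        scaled = {m: 1.0 - v for m, v in scaled.items()}
    return scaled


def mean_normalized(per_dataset_scores, higher_is_better=True):
    """Average each method's normalized score over all datasets."""
    totals = {}
    for scores in per_dataset_scores:
        normalized = normalize_per_dataset(scores, higher_is_better)
        for method, value in normalized.items():
            totals[method] = totals.get(method, 0.0) + value
    n = len(per_dataset_scores)
    return {method: total / n for method, total in totals.items()}


# Placeholder scores for three hypothetical methods (not from the paper).
auc_scores = [
    {"model_a": 0.90, "model_b": 0.80, "model_c": 0.70},
    {"model_a": 0.62, "model_b": 0.65, "model_c": 0.55},
]
rmse_scores = [
    {"model_a": 2.0, "model_b": 4.0, "model_c": 3.0},
]

auc_mean = mean_normalized(auc_scores, higher_is_better=True)
rmse_mean = mean_normalized(rmse_scores, higher_is_better=False)
print(auc_mean)
print(rmse_mean)
```

Under this assumed scheme, a method would average 1.0 only if it were the best on every dataset, so a normalized score measures how consistently a method leads rather than its raw accuracy.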

Yes, but: The authors’ method is slower at inference than decision-tree methods. To process a 10,000-row dataset, TabPFN required 0.2 seconds while CatBoost took 0.0002 seconds, a roughly 1,000-fold difference.

Why it matters: Transformers trained on large datasets of text or images can perform tasks they weren’t specifically trained for and generalize to novel datasets when performing tasks they were trained for. But when it comes to tabular data, they haven’t been competitive with decision trees. This work bridges the gap, unlocking a wide variety of new use cases for transformers. Not only does it process tabular data as well as popular tree-based methods, it doesn’t require additional training to process novel datasets.

We’re thinking: Decision trees date back to Aristotle and remain extremely useful. But a transformer-based approach could open the processing of tabular data to benefit from the ongoing innovation in transformers.
