[Figure: Charts showing benchmark results on medium-sized datasets]

While neural networks perform well on image, text, and audio datasets, they lag behind decision trees and their variants on tabular datasets. New research looked into why.

What’s new: Léo Grinsztajn, Edouard Oyallon, and Gaël Varoquaux at France’s National Institute for Research in Digital Science and Technology and Sorbonne University trained a variety of neural networks and tree models on tabular datasets. Performance on their tabular data learning benchmark revealed dataset characteristics that favor each class of models.

Key insight: Previous work found that no single neural network architecture performed best across a variety of tabular datasets, while a tree-based approach outperformed every neural network on most of them. Training and testing different models on many transformed versions of the data can reveal principles to guide the choice of architecture for a given dataset.

How it works: The authors compiled datasets, trained a variety of models (using a variety of hyperparameters), and evaluated their performance. Then they applied transformations to the data, retrained the models, and tested them again to see how the transformations affected model performance.
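The train-and-evaluate loop at the core of this setup relies on random hyperparameter search, which can be sketched in a few lines. This is a minimal sketch of generic random search, not the authors' code: `random_search`, `toy_score`, and the two-parameter space are hypothetical stand-ins, and in a real experiment the scoring callback would train a model and return its held-out score.

```python
import random
from typing import Callable, Dict, List, Tuple

# Hypothetical search space: each hyperparameter name maps to a sampler.
Space = Dict[str, Callable[[random.Random], object]]


def random_search(
    train_and_score: Callable[[dict], float],
    space: Space,
    n_iter: int = 400,
    seed: int = 0,
) -> Tuple[dict, float, List[Tuple[dict, float]]]:
    """Draw n_iter random configurations and keep the best-scoring one.

    train_and_score is assumed to train a model with the given
    hyperparameters and return a held-out score (higher is better).
    """
    rng = random.Random(seed)
    history: List[Tuple[dict, float]] = []
    best_cfg: dict = {}
    best_score = float("-inf")
    for _ in range(n_iter):
        cfg = {name: sample(rng) for name, sample in space.items()}
        score = train_and_score(cfg)
        history.append((cfg, score))
        if score > best_score:
            best_cfg, best_score = cfg, score
    return best_cfg, best_score, history


# Toy example: a made-up score that peaks at max_depth=6, learning_rate=0.1.
space: Space = {
    "max_depth": lambda r: r.randint(1, 10),
    "learning_rate": lambda r: 10 ** r.uniform(-3, 0),  # log-uniform in [0.001, 1]
}


def toy_score(cfg: dict) -> float:
    return -abs(cfg["max_depth"] - 6) - abs(cfg["learning_rate"] - 0.1)


best_cfg, best_score, history = random_search(toy_score, space, n_iter=400, seed=0)
print(best_cfg, round(best_score, 4))
```

Keeping the full history, rather than only the winner, makes it possible to plot the best score found after each number of draws and see how much each model family depends on tuning.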

  • The authors collected 45 tabular datasets spanning classification problems, such as predicting whether electricity prices will rise or fall, and regression problems, such as estimating housing prices. Each dataset comprised more than 3,000 real-world examples and resisted simple modeling (that is, logistic or linear regression models trained on it performed at least 5 percent worse than a ResNet or gradient boosting trees).
  • The authors trained tree-based models (random forests, gradient boosting machines, XGBoost, and various ensembles) and deep-learning models (a vanilla neural network, ResNet, and two Transformer-based models). They trained each model 400 times, searching randomly through a predefined hyperparameter space. They evaluated classification models by test-set accuracy and regression models by R², the proportion of the variance in the ground-truth targets that a model's predictions explain.
  • In one transformation of the data, they used a random forest model to rank a dataset’s features by importance and trained models on varying proportions of informative versus uninformative features. In another, they smoothed labels such as 0 or 1 into values such as 0.2 or 0.8.
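The "resisted simple modeling" rule and the R² metric are easy to state in code. This sketch assumes the 5 percent gap is measured relative to the stronger model's score and that a dataset is discarded only when the linear baseline comes within 5 percent of both reference models; `r2_score` and `is_too_easy` are hypothetical helper names, and the scores in the example are made up.

```python
from typing import Sequence


def r2_score(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot.

    1.0 is a perfect fit; 0.0 matches always predicting the mean;
    negative values are worse than predicting the mean.
    """
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def is_too_easy(linear: float, resnet: float, gbt: float, gap: float = 0.05) -> bool:
    """Assumed rule: the dataset is 'too easy' (and dropped) if the linear
    baseline's score is within `gap`, relative to each stronger model's
    score, of both the ResNet and the gradient-boosted trees.
    Assumes positive scores (e.g., accuracy or a positive R²)."""
    def relative_gap(strong: float) -> float:
        return (strong - linear) / strong

    return relative_gap(resnet) < gap and relative_gap(gbt) < gap


print(r2_score([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # 1.0: perfect predictions
print(r2_score([1.0, 2.0, 3.0], [2.0, 2.0, 2.0]))  # 0.0: same as predicting the mean
print(is_too_easy(linear=0.80, resnet=0.82, gbt=0.83))  # True: dataset dropped
print(is_too_easy(linear=0.70, resnet=0.80, gbt=0.82))  # False: dataset kept
```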
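Varying the proportion of informative versus uninformative features takes two operations: dropping the features a random forest ranks lowest, and appending pure-noise columns. The sketch below assumes the importance scores have already been computed (a fitted scikit-learn random forest, for example, exposes them as `feature_importances_`); the helper names and the choice of standard-normal noise are illustrative assumptions, not the authors' exact procedure.

```python
import random
from typing import List, Sequence


def keep_top_features(rows: Sequence[Sequence[float]],
                      importances: Sequence[float],
                      keep_fraction: float) -> List[List[float]]:
    """Keep only the most important `keep_fraction` of columns.

    Columns are ranked by importance; the kept ones stay in original order.
    """
    n_features = len(importances)
    n_keep = max(1, round(keep_fraction * n_features))
    ranked = sorted(range(n_features), key=lambda j: importances[j], reverse=True)
    kept = sorted(ranked[:n_keep])
    return [[row[j] for j in kept] for row in rows]


def add_noise_features(rows: Sequence[Sequence[float]], n_extra: int,
                       seed: int = 0) -> List[List[float]]:
    """Append n_extra columns of standard-normal noise (uninformative features)."""
    rng = random.Random(seed)
    return [list(row) + [rng.gauss(0.0, 1.0) for _ in range(n_extra)] for row in rows]


rows = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
importances = [0.10, 0.50, 0.05, 0.35]  # made-up scores for illustration
print(keep_top_features(rows, importances, keep_fraction=0.5))  # [[2.0, 4.0], [6.0, 8.0]]
print([len(r) for r in add_noise_features(rows, n_extra=3)])    # [7, 7]
```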
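One way to turn hard 0/1 labels into values such as 0.2 or 0.8, assumed here rather than taken from the paper, is a Gaussian kernel smoother: each training label is replaced by a distance-weighted average of all training labels, so a 0 surrounded by 1s drifts upward. `smooth_labels` and its `length_scale` parameter are hypothetical names; a larger length scale produces smoother targets.

```python
import math
from typing import List, Sequence


def smooth_labels(X: Sequence[Sequence[float]], y: Sequence[float],
                  length_scale: float) -> List[float]:
    """Replace each label with a Gaussian-kernel-weighted mean of all labels.

    Points closer in Euclidean distance get more weight; larger
    length_scale values blur the labels more.
    """
    smoothed = []
    for xi in X:
        weights = [
            math.exp(-sum((a - b) ** 2 for a, b in zip(xi, xj)) / (2 * length_scale ** 2))
            for xj in X
        ]
        total = sum(weights)  # >= 1 because each point has weight 1 on itself
        smoothed.append(sum(w * yj for w, yj in zip(weights, y)) / total)
    return smoothed


X = [[0.0], [1.0], [2.0]]
y = [0.0, 1.0, 1.0]
print([round(v, 2) for v in smooth_labels(X, y, length_scale=1.0)])  # [0.43, 0.73, 0.92]
print(smooth_labels(X, y, length_scale=0.01))  # tiny scale: labels unchanged
```

This naive version compares every pair of examples, so its cost grows quadratically with dataset size.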

Results: Averaged across all tasks, the best tree models performed 20 to 30 percent better than the best deep learning models. ResNets fell even farther behind trees and Transformers as the number of uninformative features rose. In another experiment, training on smoothed labels degraded the performance of trees more than that of neural networks, which suggests that tree-based methods are better at learning irregular mappings from features to labels.

Why it matters: Deep learning isn’t the best approach to all datasets and problems. If you have tabular data, give trees a try!

We’re thinking: The authors trained their models on datasets of either 10,000 or 50,000 training examples. Smaller or larger datasets might yield different results.
