More Thinking Solves Harder Problems: AI Can Learn From Simple Tasks to Solve Hard Problems

Animated charts showing how AI can learn from simple tasks to harder versions of the same task

In machine learning, an easy task and a more difficult version of the same task — say, a maze that covers a smaller or larger area — often are learned separately. A new study shows that recurrent neural networks can generalize from one to the other.

What’s new: Avi Schwarzschild and colleagues at the University of Maryland showed that, at inference, increasing recurrence in a neural network (sending the output of a portion of the network back through the same block repeatedly before allowing it to move through the rest of the network) can enable it to perform well on a harder version of a task it was trained to do.

Key insight: A network’s internal representation of input data should improve incrementally each time it passes through a recurrent block. With more passes, the network should be able to solve more difficult versions of the task at hand.
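
As a rough illustration of this insight, the sketch below repeats one fixed update rule on a toy maze. It is an analogy, not the authors' network: the hand-written dead-end-filling rule `fill_pass` stands in for a learned recurrent block, and the grid encoding (`S`, `E`, `.`, `#`) and all function names are assumptions made for this example. Each pass walls off cells that are dead ends, so a maze with a longer dead end needs more passes of the same rule.

```python
def fill_pass(grid):
    """Apply one fixed local rule: turn open dead-end cells into walls."""
    rows, cols = len(grid), len(grid[0])
    new = [list(row) for row in grid]
    for r in range(rows):
        for c in range(cols):
            if grid[r][c] != ".":
                continue  # walls, start, and end are never changed
            blocked = 0
            for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                rr, cc = r + dr, c + dc
                outside = not (0 <= rr < rows and 0 <= cc < cols)
                if outside or grid[rr][cc] == "#":
                    blocked += 1
            if blocked >= 3:  # only one way in or out: a dead end
                new[r][c] = "#"
    return ["".join(row) for row in new]


def run_passes(grid, passes):
    """Send the grid through the same rule `passes` times (the 'recurrence')."""
    for _ in range(passes):
        grid = fill_pass(grid)
    return grid


# Hypothetical toy mazes: the solution path is the top row, the rest is a dead end.
small = ["S.E",
         ".##"]
large = ["S.E",
         ".##",
         ".##",
         ".##"]

print(run_passes(small, 1))  # dead end of length 1 is gone after one pass
print(run_passes(large, 1))  # longer dead end: one pass is not enough
print(run_passes(large, 3))  # three passes of the same rule remove it
```

In this toy setting, the rule never changes. Only the number of passes grows with the size of the problem, which mirrors the idea of adding iterations at inference.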

How it works: The authors added recurrence to ResNets prior to training by duplicating the first residual block and sharing its weights among all residual blocks. (As non-recurrent baselines, they used ResNets of equivalent or greater depth without shared weights.) They trained and tested separate networks on each of three tasks:

  • Mazes: The network received an image of a two-dimensional maze and generated an image that highlighted the path from start to finish. The authors trained a network with 20 residual blocks on 9x9 grids and tested it on 13x13 grids.
  • Chess: The network received an image of chess pieces on a board and generated an image that showed the origin and destination squares of the best move. The authors trained a network with 20 residual blocks on chess puzzles with standardized difficulty ratings below 1,385, then tested it on those with ratings above that number.
  • Prefix strings: The network received a binary string and generated a binary string of equal length in which each bit was the cumulative sum of the input, modulo two (for example, input 01011, output 01101). The authors trained a network with 10 residual blocks on 32-bit strings and tested it on 44-bit strings.
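
For reference, the target output of the prefix-strings task can be computed directly. The following is a minimal sketch of the task definition only, not of the network; the function name `prefix_parity` is an assumption made for this example.

```python
from itertools import accumulate


def prefix_parity(bits: str) -> str:
    """Return a string of equal length whose i-th bit is the sum of
    the first i input bits, modulo two."""
    running_sums = accumulate(int(b) for b in bits)
    return "".join(str(total % 2) for total in running_sums)


print(prefix_parity("01011"))  # 01101, matching the example above
```

The running sums of 0, 1, 0, 1, 1 are 0, 1, 1, 2, 3, which reduce to 0, 1, 1, 0, 1 modulo two.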

Results: In tests, the recurrent networks generally improved their performance on the more complex problems with each pass through the loop — up to a limit — and outperformed the corresponding nonrecurrent networks. The authors presented their results most precisely for prefix strings, in which the recurrent networks achieved 24.96 percent accuracy with 9 residual blocks, 31.02 percent with 10 residual blocks, and 35.22 percent with 11 residual blocks. The nonrecurrent networks of matching depth achieved 22.17 percent, 24.78 percent, and 22.79 percent accuracy respectively. The performance improvement was similar on mazes and chess.
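
To make the prefix-string comparison concrete, this short sketch computes the gap between the recurrent and nonrecurrent accuracies quoted above. The dictionary names are assumptions made for this example; the numbers are the ones reported in the text.

```python
# Accuracy (percent) on prefix strings, keyed by number of residual blocks.
recurrent = {9: 24.96, 10: 31.02, 11: 35.22}
nonrecurrent = {9: 22.17, 10: 24.78, 11: 22.79}

# Gap in percentage points at each depth.
gaps = {
    blocks: round(recurrent[blocks] - nonrecurrent[blocks], 2)
    for blocks in recurrent
}

for blocks, gap in gaps.items():
    print(f"{blocks} blocks: +{gap:.2f} percentage points")
```

The gap widens from 2.79 to 6.24 to 12.43 percentage points as the block count rises from 9 to 11.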

Why it matters: Forcing a network to re-use blocks can enhance its performance on harder versions of a task. This work also opens an avenue for interpreting recurrent neural networks by increasing the number of passes through a given block and studying changes in the output.

We’re thinking: Many algorithms in computing use iteration to refine a representation, such as belief propagation in probabilistic graphical models. It’s exciting to find that a neural network can learn weights that work in a similarly iterative way, computing a better representation with each pass through the loop.

