Fig 1.
(A) Illustration of the general connectivity task and its global nature. (B) Example of an input image and the correct labelling for that image in the edge-connected pixel task. (C) Implementation of the single-hidden-layer architecture for detecting whether the center pixel is connected to the edge. Each hidden neuron checks whether a certain pattern connecting the pixel to the edge is present in the image. (D) Tag-propagation solution in a recurrent network for detecting whether the center pixel is connected to the edge. We start with any on pixels touching the edge tagged as connected. At subsequent time steps, the tag is passed to any neighboring pixel that is also on. (E) Schematic of the setup for the analytical implementation of the tag-propagation algorithm in neural network weights. The same setup is repeated at all pixels in the image. (F) Output of the analytical tag-propagation network at progressive time steps for several example images.
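The propagation rule of panels D-F can be read as an iterative flood fill. A minimal NumPy sketch (function and variable names are ours, not from the paper), assuming a binary image with 1 = on pixel and 4-connectivity:

```python
import numpy as np

def tag_propagation(image, n_steps):
    """Sketch of the tag-propagation rule in Fig 1D (4-connectivity).

    image: 2D array of 0/1; 1 = on pixel.
    Returns a 0/1 array; 1 = tagged as connected to the edge.
    """
    tag = np.zeros_like(image)
    # Step 0: tag any on pixel lying on the image border.
    tag[0, :], tag[-1, :] = image[0, :], image[-1, :]
    tag[:, 0], tag[:, -1] = image[:, 0], image[:, -1]
    for _ in range(n_steps):
        # Pass the tag to each tagged pixel's 4 neighbors...
        spread = np.zeros_like(tag)
        spread[1:, :] |= tag[:-1, :]
        spread[:-1, :] |= tag[1:, :]
        spread[:, 1:] |= tag[:, :-1]
        spread[:, :-1] |= tag[:, 1:]
        # ...but only on pixels can accept it.
        tag = tag | (spread & image)
    return tag
```

After enough steps (at most the number of pixels on the longest path), the tag map matches the correct labelling shown in panel B.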
Fig 2.
Feedforward network performance on the edge-connected pixel task.
(A) Training procedure for neural networks. Weights are learned iteratively: at each iteration the weights are updated using the gradient of the loss function, calculated via backpropagation. The network illustrated here is a deep feedforward network. The image is given as input to the first layer, and the activation of units at the last layer is the network’s output. (B) Schematic of the two-part procedure used to obtain good solutions despite the stochasticity of learning and the large parameter spaces. First, hyperoptimization is performed over the training parameters for each model and layer combination (top). Each line corresponds to a network being trained, colored differently to indicate the different hyperparameters used. The worst-performing models are eliminated at regular intervals (see Methods for full details). Second (bottom), 50 models are trained from random initialization, all using the optimal parameters found in the hyperoptimization procedure. (C) Performance of the best trained solution for the deep feedforward network across layers vs. the recurrent tag-propagation solution. X-axis corresponds to number of layers in the network for the feedforward solution, and number of time steps the network is allowed to run for the recurrent solutions, which have only one layer of recurrently connected neurons. (D) Illustration of the splitting of pixels, and the associated errors, into three groups: path, distractor, and off. (E) Breakdown of errors for two naïve solutions: “all on,” outputting all on pixels as connected, and “all off,” outputting no pixels as connected. (F-G) Decomposition of error by pixel type for each model. (F) Tag propagation implemented via a recurrent network. X-axis corresponds to the number of time steps the network is allowed to run. Blue dots and line correspond to errors on path pixels, red dots and line correspond to errors on distractor pixels. Inset shows the same data with a larger axis range. (G)
X-axis corresponds to number of layers in the network. Blue dots and line correspond to errors on path pixels, red dots and line correspond to errors on distractor pixels. Dashed line corresponds to tag-propagation error on path pixels for reference. (H) Schematic of the input-augmented architectures, where the full image is added as input at each layer, allowing the tag-propagation solution to be part of the parameter space. (I) Decomposition of error by pixel type for the input-augmented network. Same plotting conventions as in G.
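The pixel grouping of panel D can be made concrete as follows; a sketch (names are ours), where `image` gives the on/off pixels, `label` is the ground-truth connectivity, and `prediction` is a network's binary labelling:

```python
import numpy as np

def error_by_pixel_type(image, label, prediction):
    """Split per-pixel errors into the three groups of Fig 2D.

    Path pixels are on and edge-connected; distractor pixels are on but
    not edge-connected; off pixels are the rest. Returns the error rate
    within each group.
    """
    wrong = prediction != label
    groups = {
        "path": (image == 1) & (label == 1),
        "distractor": (image == 1) & (label == 0),
        "off": image == 0,
    }
    return {name: float(wrong[mask].mean()) if mask.any() else 0.0
            for name, mask in groups.items()}
```

Under this split, the “all on” baseline of panel E has zero path error but 100% distractor error, and “all off” has the reverse.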
Fig 3.
Masked recurrent network performance on the edge-connected pixel task.
(A) Schematic of the addition of constraints on the feedforward parameter space, generating an increasingly restrictive parameter space that still contains the efficient tag-propagation solution. From left to right: weight sharing across the layers, creating an unrolled recurrent network (left); masking of the shared weights to a grid around each pixel (i.e., enforcing locality in the operations), creating unrolled recurrent networks with sparse weights (middle); masking to just the nearest neighbors of each pixel (right). (B) Decomposition of error by pixel type for each model. X-axis corresponds to number of layers in the network. Blue dots and line correspond to errors on path pixels, red dots and line correspond to errors on distractor pixels. Dashed line corresponds to tag-propagation error on path pixels for reference. Dotted blue and red lines indicate input-augmented architecture error on path (blue) and distractor (red) pixels. Each subpanel corresponds to a different network architecture. From left to right: unrolled recurrent, square mask, neighbors-only mask.
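The nearest-neighbor mask of panel A (right) can be written down directly. A sketch under our own flattening convention (row-major pixel ordering; not from the paper), where the effective recurrent weights are the elementwise product of a dense weight matrix with this mask:

```python
import numpy as np

def neighbor_mask(height, width):
    """Binary mask over an (H*W) x (H*W) recurrent weight matrix that
    keeps only connections between each pixel and its 4 nearest
    neighbors (plus itself), as in the rightmost architecture of
    Fig 3A."""
    n = height * width
    mask = np.zeros((n, n), dtype=bool)
    for r in range(height):
        for c in range(width):
            i = r * width + c           # row-major pixel index
            mask[i, i] = True           # self-connection
            for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                rr, cc = r + dr, c + dc
                if 0 <= rr < height and 0 <= cc < width:
                    mask[i, rr * width + cc] = True
    return mask
```

Applying the mask at every forward pass (W * mask) keeps the parameter count linear in the number of pixels while still allowing the tag-propagation solution.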
Fig 4.
Performance comparison of uniform and hybrid networks on the edge-connected pixel task.
(A) Schematic of a general hybrid network combining a recurrent and a feedforward model. For any time step t, the switching network combines the current output of the recurrent network yr and the output of the feedforward network yf to form the hybrid output yh. Note that the feedforward output can only be used by the switching network when t is greater than or equal to the number of layers in the model. (B) Performance of the best solution of each architecture type after hyperoptimization, plotted against the number of neurons, to motivate why combining these two architectures into a hybrid network can provide better performance across a range of computational times. The feedforward architecture (input augmented) provides a superior neuron/performance tradeoff at low computational time, while the recurrent architecture (masked neighbors) provides a superior neuron/performance tradeoff at high computational time. Colors correspond to different network architectures. Small circles correspond to recurrent networks run for five time steps and feedforward networks with five layers. Large circles correspond to architectures that use 25 time steps or layers. Solid colored lines connect models of the same architecture with different numbers of layers or time steps. Dotted lines with arrows highlight the masked and feedforward architectures with equivalent layers or time steps for ease of comparison. (C) Performance of a hybrid architecture designed to switch on the fly based on how many time steps are available to output a labelling. Here we consider the simplest switching network, which can choose to output either yr or yf at a given time step. The model combines the tag-propagation solution and trained input-augmented networks; the budget for each curve is the number of neurons the full model requires. The blue curve is the error profile for tag propagation alone, while the orange curve shows the result of combining tag propagation with the 2-layer feedforward (abbreviated FF) model.
The green curve shows tag propagation combined with the 10-layer feedforward network, and the purple curve shows the combination with both the 2-layer and 10-layer networks. Note that the figure only shows path error; the feedforward solutions will also have some error on the distractor pixels. (D) Distributions over computation time: the blue distribution is evenly split over short, medium, and long computation times (2, 5, 10, and 30 steps), and the orange distribution is evenly split over medium and long computation times (10 and 30 steps). (E) Performance of hybrid networks on the distributions of allowed computation time illustrated in panel D, i.e., the fraction of runs in which the network is limited to a certain number of steps. The x-axis corresponds to the number of neurons allowed when constructing the hybrid solution. The first point on all curves allows only a fully recurrent architecture and thus is not hybrid; this is indicated by a square marker. The rest of the x-axis corresponds to a neuron budget allowing hybrid solutions with increased feedforward composition. Additional hybrid models were only included in the plots if they improved performance over models with fewer neurons (see Methods). (F) Per-pixel switching hybrid network learned from observing the output of four networks (tag propagation; input augmented with 2, 5, and 10 layers) on 10,000 test samples. The heatmap shows the percent of times a given network correctly classified a pixel.
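A minimal version of the on-the-fly switch in panel C can be sketched as follows. This is our own illustrative policy, not the paper's switching network: run tag propagation every step, and once a feedforward model's depth fits within the elapsed time, prefer the deepest model that fits (assuming, as in panel C, that deeper feedforward models are more accurate):

```python
def hybrid_output(t, y_r, feedforward_outputs):
    """Choose the hybrid output at time step t (sketch of Fig 4C).

    y_r: the current recurrent (tag-propagation) labelling.
    feedforward_outputs: dict mapping a model's number of layers to its
    labelling; a model is usable only once t >= its number of layers.
    """
    usable = [depth for depth in feedforward_outputs if depth <= t]
    # Fall back to the recurrent labelling if no feedforward model
    # has had enough time steps to produce its output.
    return feedforward_outputs[max(usable)] if usable else y_r
```

The neuron budget in panel C is then the total size of all models kept available to the switch.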
Fig 5.
(A) Randomly generated examples of the competitive foraging task. (B) The animal’s and competitors’ propagation networks. Each one implements the tag-propagation algorithm, with open-space pixels corresponding to on pixels and barriers corresponding to off pixels. Unlike in the edge-connected pixel task, the source pixels change in every sample to correspond to the locations of the animal and its competitors, respectively. See Methods for further implementation details. (C) Architecture of the trained decision network. The time series of animal and competitor ranges are concatenated with the food pixel and then used as the input to a recurrent network. The decision traces shown in panel E are the projection of the trained network onto the two-dimensional readout. (D) Sample outputs of the trained propagation networks. These are the best networks trained with 7, 10, and 12 layers, respectively. Errors in the trained networks are marked in red. We show results only for the competitors’ network, as the propagation task is the same for both the animal and its competitors; only the input corresponding to the initial location changes. (E) The generalized tag-propagation algorithm implements a correct version of the decision trace and is described in the Methods. For comparison, we show the decision trace output by the trained decision network. Propagation is shown after ten time steps, and each example is labelled with the correct decision after ten steps: “stay,” “run,” or “out of range.” Note that if the food location is in range of both groups of animals, the decision is based on which animal can reach the food first.
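The per-agent propagation in panel B is the same rule as in Fig 1, but seeded at agent locations rather than at the image border. A NumPy sketch (names are ours), assuming binary open-space maps and 4-connectivity:

```python
import numpy as np

def range_propagation(open_space, sources, n_steps):
    """Sketch of the per-agent propagation of Fig 5B: tags spread from
    the source pixels through open space (on pixels), with barriers
    (off pixels) blocking them.

    open_space: 2D array of 0/1; 1 = open space.
    sources: list of (row, col) agent locations.
    """
    tag = np.zeros_like(open_space)
    # Seed the tags at the agents' locations instead of the border.
    for r, c in sources:
        tag[r, c] = open_space[r, c]
    for _ in range(n_steps):
        spread = np.zeros_like(tag)
        spread[1:, :] |= tag[:-1, :]
        spread[:-1, :] |= tag[1:, :]
        spread[:, 1:] |= tag[:, :-1]
        spread[:, :-1] |= tag[:, 1:]
        tag |= spread & open_space
    return tag
```

Running this once seeded at the animal's location and once at the competitors' locations yields the two range maps whose time series feed the decision network in panel C.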
Table 1.
We trained four U-Net architectures of increasing size, each time doubling the number of channels in every layer. For each model, we report the best performance across 10 trained models on 10,240 test samples.
Table 2.
U-Net architecture used for the experiments in Table 1. D specifies the base number of channels; we ran experiments with D = 2, 4, 8, 16.
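The channel scaling in Tables 1-2 can be summarized in a few lines. A sketch assuming the standard U-Net convention that channel counts double at each deeper encoder level (the level count below is illustrative; the exact layer list is given in Table 2):

```python
def unet_channels(D, n_levels):
    """Encoder channel counts for a U-Net with base width D, doubling
    at each level (standard U-Net convention; see Table 2 for the exact
    architecture used in the experiments)."""
    return [D * 2 ** level for level in range(n_levels)]

# The four models of Table 1 scale the whole plan by doubling D
# itself: D = 2, 4, 8, 16.
```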
Table 3.
Hybrid models considered in Fig 4C. Models were only included in the plot if adding neurons decreased the loss. The models were assumed to be given the allowed computational time at initialization, enabling them to switch optimally.