Fig 1.

Bottom-up and top-down perception.

One classical view treats perception as a primarily bottom-up process, in which sensory data x is transformed into perceptual representations z through a cascade of feedforward feature detectors. In contrast, predictive coding (PC) suggests that the brain solves perception by modelling how perceptual representations z generate sensory data x, which is a fundamentally top-down process. Standard PC does include bottom-up information flow, but that flow conveys only prediction errors (transformed by the synaptic weights). In our hybrid predictive coding (HPC) models, however, bottom-up information conveys both predictions and errors (along with top-down information). In HPC, sensory data x predicts perceptual representations z at fast, amortised time scales, and perceptual representations z predict sensory data x at slow, iterative time scales. Our “fast and slow” model casts this integration of bottom-up and top-down signals in a common framework, allowing derivation of a testable process theory.

Fig 2.

Hybrid predictive coding combines two phases of inference as follows.

(A) At stimulus onset, data x is propagated up the hierarchy in a feedforward manner, utilising the amortised functions fϕ(⋅). These predictions set the initial conditions for μ, which parameterise posterior beliefs about the sensory data. These predictions are associated with error units that track the difference between variables at one level and the variables at the level above under transforms fϕ(⋅). These errors are not utilised for inference but are used to update the amortised parameters ϕ during learning (weight updates). (B) The initial values for μ are then used to predict the activity at the layer below, transformed by the generative functions fθ(⋅). These predictions incur prediction errors ε, which are then used to update beliefs μ. This process is repeated N times, after which perceptual inference is complete.
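
The two phases described in this caption can be sketched for a single latent layer. This is a minimal illustration, not the authors' implementation: the tanh amortised mapping, the linear generative mapping, the step size, and the names `hybrid_inference`, `phi` and `theta` are all assumptions made for the example, and learning of the amortised parameters is left out.

```python
import numpy as np

def hybrid_inference(x, phi, theta, n_iters=50):
    """Two-phase inference for one latent layer (illustrative assumptions only)."""
    # Phase A: amortised, bottom-up pass sets the initial beliefs mu = f_phi(x).
    # The tanh nonlinearity is an assumption of this sketch.
    mu = np.tanh(phi @ x)
    # Step size of 1 / (largest singular value)^2 keeps gradient descent on
    # this quadratic objective from increasing the error.
    step = 1.0 / np.linalg.norm(theta, 2) ** 2
    errors = []
    for _ in range(n_iters):
        # Phase B: top-down prediction f_theta(mu) (assumed linear here)
        # and the prediction error it incurs against the data.
        eps = x - theta @ mu
        errors.append(float(eps @ eps))
        # Update beliefs by descending the gradient of 0.5 * ||eps||^2.
        mu = mu + step * (theta.T @ eps)
    return mu, errors

rng = np.random.default_rng(0)
theta = rng.normal(size=(8, 3))  # hypothetical generative weights (latent -> data)
phi = rng.normal(size=(3, 8))    # hypothetical amortised weights (data -> latent)
x = rng.normal(size=8)           # stand-in for a sensory input
mu, errors = hybrid_inference(x, phi, theta)
```

In this sketch the squared prediction error never increases across iterations, so a better amortised initialisation leaves less work for the iterative phase.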

Fig 3.

Simultaneous classification and generation.

(A) Classification accuracy on the MNIST dataset for hybrid predictive coding, standard predictive coding and amortised inference. Each line is the average classification accuracy across three seeds; the shaded area corresponds to the standard deviation. The x-axis denotes the number of batches. (B) Generative loss. The panel shows the averaged mean-squared error between the lowest level of the hierarchy (which is fixed to the sensory data during testing) and the top-down predictions from the superordinate layer, plotted against batches, for HPC and standard PC. This metric provides a measure of how well each model is able to generate data. The seeds used are the same as those used in panel (A) (i.e., the data is from the same run). (C) Illustrative samples taken from HPC at the end of learning. These images are generated by activating a single node in the highest layer (corresponding to a single digit), and performing top-down predictions in a layer-wise fashion. The images correspond to the predicted nodes at the lowest layer. (D) As in (C) but for standard predictive coding.
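
The sampling procedure in panels (C) and (D) and the loss in panel (B) might look roughly like the sketch below. The layer sizes, the tanh activation, the random weights, and the function names `generate_from_label` and `generative_loss` are hypothetical choices for illustration; only the one-hot activation of a top-layer node, the layer-wise top-down pass, and the mean-squared error come from the caption.

```python
import numpy as np

def generate_from_label(label, weights):
    """Activate one top-layer node and propagate predictions down layer by layer."""
    n_classes = weights[0].shape[1]
    activity = np.zeros(n_classes)
    activity[label] = 1.0            # single active node for the chosen digit
    for w in weights:                # weights ordered from highest layer to lowest
        activity = np.tanh(w @ activity)  # assumed activation function
    return activity                  # predicted nodes at the lowest layer

def generative_loss(x, prediction):
    """Mean-squared error between the data and the top-down prediction."""
    return float(np.mean((x - prediction) ** 2))

rng = np.random.default_rng(0)
# Hypothetical 10 -> 16 -> 784 hierarchy (784 = 28 x 28 MNIST pixels).
weights = [rng.normal(size=(16, 10)), rng.normal(size=(784, 16))]
image = generate_from_label(3, weights)
```

With trained weights in place of the random ones, `image.reshape(28, 28)` would be the kind of sample shown in the panels.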

Fig 4.

Fast inference.

(A) Classification accuracy of the hybrid predictive coding model and the bottom-up, amortised predictions as a function of the number of batches. The asymptotic convergence demonstrates that placing an uncertainty-aware threshold on the number of iterations has no influence on (asymptotic) model performance. Plotted are average accuracies over 5 seeds; shaded regions are the standard deviation. (B) Average number of iterations (for iterative inference) as a function of test batch. Amortised predictions provide increasingly accurate estimates of model variables, reducing the need for costly iterative inference.

Fig 5.

Classification accuracy under fixed iterations.

(A) 10 iterations. The accuracy of HPC and the amortised predictions is mostly unaffected by the reduced number of iterations, whereas standard predictive coding fails to classify at all. (B) 25 iterations. The classification accuracy of standard predictive coding slowly decreases over batches, illustrating a common pathology observed in these simulations. (C) 50 iterations. Standard predictive coding approximately matches the performance of hybrid predictive coding, but begins to decline later in training. (D) 100 iterations. There are no significant differences between the accuracies of hybrid and standard predictive coding. Together, these results demonstrate that hybrid predictive coding enables effective inference and maintains higher performance with substantially fewer inference iterations than standard predictive coding. Plotted are mean accuracies over 5 random network initializations. Shaded areas are the standard deviation.

Fig 6.

Accuracy as a function of dataset size.

(A) 100 examples. The accuracy of hybrid predictive coding is lower than with the full dataset, but still high given the relatively small amount of data the network has been exposed to (0.17 percent). The accuracy of the amortised predictions is significantly worse. (B) 500 examples. (C) 1000 examples. (D) 5000 examples. Together, these results demonstrate that bottom-up, amortised inference is far more sensitive to the time spent training, compared to the full hybrid architecture. Importantly, the poor performance of amortised inference at the start of learning does not negatively impact the speed at which iterative inference learns. Plotted are the mean accuracies over 5 seeds. Shaded areas represent the standard deviation.

Fig 7.

Additional properties of the HPC model.

(A) Example evolution of the label entropy over the course of an inference phase. The initial amortised guess has relatively high entropy (uncertainty over labels), which progressively reduces during iterative inference. This is consistent with the view that the iterative inference phase refines the initial amortised guesses. (B) The number of inference steps required over an example training run. Due to the superior initialization provided by the amortised connections, far fewer iterative inference steps are required. (C) Adaptive computation time based on task difficulty. On a well-learned task, the number of inference iterations required decays towards 0. However, when there is a change in data distribution, additional inference iterations are adaptively utilized to classify the new, more challenging, stimuli.
