BatchBALD: Efficient and Diverse Batch Acquisition for Deep Bayesian Active Learning

The BatchBALD acquisition function builds on the regular BALD approach by formulating the acquisition problem over a set of data points rather than over individual points. The BALD acquisition function is as follows:

$\color{green}\mathbb{I}(y; \theta| x, \mathcal{D}_\text{train}) \color{black}= \color{red}\mathbb{H}(y | x, \mathcal{D}_\text{train}) \color{black}- \color{blue}\mathbb{E}_{p(\theta|\mathcal{D}_\text{train})}[\mathbb{H}(y|x, \theta, \mathcal{D}_\text{train})]$

The first term on the right-hand side is the entropy of the model's predictive distribution $y$ for the input data point $x$ (after training on $\mathcal{D}_\text{train}$), which is high when the trained model is uncertain about its prediction. The second term is the expected conditional entropy of the prediction under the posterior over model parameters. That is, it measures the expected uncertainty of the prediction $y$ across all possible models trained on $\mathcal{D}_\text{train}$, with each setting of the parameters weighted by its posterior probability $p(\theta|\mathcal{D}_\text{train})$. This term is low when each individual draw of parameters yields a relatively confident prediction $y$.

The left-hand side, the mutual information, is therefore high only when the overall prediction for $x$ is uncertain (first term high) while each individual parameter draw is confident about its own prediction (second term low), i.e. the posterior draws confidently disagree among themselves about the prediction. This gives an intuitive reading of Bayesian Active Learning by Disagreement (BALD), which estimates the mutual information between model predictions and parameters: the most informative points $x$ are those whose overall prediction is highly uncertain while individual draws of the model parameters disagree about how best to explain them. Labeling such a point both reduces the uncertainty about its prediction and illuminates which model parameters should be considered "best".
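As a concrete sketch of the two terms, the BALD score is typically approximated with Monte Carlo samples from the parameter posterior (e.g. MC dropout). The snippet below is a minimal illustration, not the paper's implementation; the function name and the `(K, N, C)` array layout are assumptions for this example.

```python
import numpy as np

def bald_scores(probs):
    """Approximate BALD scores from posterior samples.

    probs: array of shape (K, N, C) -- K posterior samples (e.g. MC dropout
    forward passes), N pool points, C classes; probs[k, i] is p(y | x_i, theta_k).
    Returns an array of N mutual-information scores.
    """
    eps = 1e-12  # avoid log(0)
    # Posterior-averaged predictive distribution p(y | x, D_train).
    mean_probs = probs.mean(axis=0)
    # First term: entropy of the averaged prediction (high = uncertain overall).
    predictive_entropy = -(mean_probs * np.log(mean_probs + eps)).sum(axis=-1)
    # Second term: expected entropy over posterior samples
    # (low = each draw is individually confident).
    expected_entropy = -(probs * np.log(probs + eps)).sum(axis=-1).mean(axis=0)
    return predictive_entropy - expected_entropy
```

The score behaves as the explanation above predicts: two confident draws that contradict each other yield a high score, while two confident draws that agree yield a score near zero.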