Avoiding Your Teacher's Mistakes: Training Neural Networks with Controlled Weak Supervision

This post is about a project I worked on with Aliaksei Severyn, Sascha Rothe, and Jaap Kamps during my internship at Google Research.

Deep neural networks have shown impressive results on many tasks in computer vision, natural language processing, and information retrieval. However, their success depends on the availability of large amounts of labeled data, and for many tasks such data is not available. Hence, unsupervised and semi-supervised methods are becoming increasingly attractive.

Using weak or noisy supervision is a straightforward way to increase the size of the training data. In one of my previous posts, I talked about how to beat your teacher: how to train a neural network using only the output of a heuristic model as the supervision signal, such that the network eventually works better than that heuristic. Here, I assume the common semi-supervised setup in which, besides a large amount of unlabeled (or weakly labeled) data, there is a small amount of training data with strong (true) labels, and I'll talk about how to learn from a weak teacher while avoiding its mistakes.

In this setup, the usual approach is to pre-train the network on the weakly labeled data and then fine-tune it with the true labels. However, these two independent stages do not make full use of the information in the true labels. For instance, in the pre-training stage there is no way to control how much the weakly labeled data contributes to learning, even though weak labels can vary in quality.

In this post, I'm going to talk about our proposed semi-supervised method, which leverages a small amount of data with true labels along with a large amount of data with weak labels. The method has three main components:

  • A weak annotator, which can be a heuristic model, a weak classifier, or even humans via crowdsourcing, and which is employed to annotate a massive amount of unlabeled data.
  • A target network, which learns the main task from a large set of instances annotated by the weak annotator.
  • A confidence network, which is trained on a small human-labeled set to estimate confidence scores for instances annotated by the weak annotator. We train the target network and the confidence network in a multi-task fashion.

In a joint learning process, the target network and the confidence network try to learn a suitable representation of the data, and this representation layer is shared between them as a two-way communication channel. The target network learns to predict the label of a given input under the supervision of the weak annotator. At the same time, the output of the confidence network, i.e., the confidence scores, determines the magnitude of the weight updates to the target network with respect to the loss computed on labels from the weak annotator, during the back-propagation phase of the target network. This way, the confidence network helps the target network avoid the mistakes of its teacher, i.e., the weak annotator, by down-weighting the updates from weak labels that do not look reliable to the confidence network.

Our setup requires running a weak annotator to label a large amount of unlabeled data, which is done at pre-processing time. For many tasks, it is possible to use a simple heuristic, or implicit human feedback to generate weak labels. This set is then used to train the target network.
In contrast, a small expert-labeled set is used to train the confidence network, which estimates how good the weak annotations are, i.e., it controls the effect of weak labels on updating the parameters of the target network. Our method can be applied to different neural architectures and different tasks, as long as a meaningful weak annotator is available.

Model Architecture:

The high-level architecture of the model in its two training modes is shown in the figure below (full supervision mode on the left, weak supervision mode on the right):

Faded parts of the network are disabled during training in the corresponding mode. Red dotted arrows show gradient propagation. Parameters of the parts of the network in red frames are updated in the backward pass, while parameters of the parts in blue frames are fixed during training.

More formally, the goal of the weak annotator is to provide weak labels \tilde{y}_i for all instances \tau_i \in U \cup V. We assume that the weak labels \tilde{y}_i are imperfect estimates of the true labels y_i, where y_i are available for set V but not for set U.
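To make the role of the weak annotator concrete, here is a minimal Python sketch for a sentiment-style task, assuming a toy lexicon-based heuristic as the weak annotator. The word lists, the helper `weak_label`, and the tiny sets `U` and `V` are hypothetical placeholders, not the annotator or data from our experiments. Note that every instance in U \cup V receives a weak label, true labels exist only for V, and the heuristic can be wrong (see the second example in V).

```python
# Hypothetical lexicon-based weak annotator for binary sentiment.
# The word lists are illustrative placeholders only.
POSITIVE = {"good", "great", "excellent", "love", "enjoyable"}
NEGATIVE = {"bad", "awful", "boring", "hate", "terrible"}


def weak_label(text: str) -> float:
    """Return a weak label in [0, 1]: the fraction of lexicon hits that are positive.

    Returns 0.5 (no opinion) when the text contains no lexicon words.
    """
    tokens = text.lower().split()
    pos = sum(token in POSITIVE for token in tokens)
    neg = sum(token in NEGATIVE for token in tokens)
    if pos + neg == 0:
        return 0.5
    return pos / (pos + neg)


# U: large set without true labels; V: small set that also has true labels.
U = ["a great and enjoyable movie", "boring plot and awful acting", "it was a movie"]
V = [("i love this great film", 1.0), ("great cast but a terrible script", 0.0)]

# The weak annotator labels every instance in U ∪ V; true labels exist only for V.
weak_U = [(text, weak_label(text)) for text in U]
weak_V = [(text, weak_label(text), y_true) for text, y_true in V]
```

In the V example "great cast but a terrible script", the heuristic says 0.5 while the true label is 0.0: exactly the kind of disagreement the confidence network is trained to detect.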

The goal of the confidence network is to estimate the confidence score \tilde{c}_j of training instances. It is learned on triplets from training set V: input \tau_j, its weak label \tilde{y}_j, and its true label y_j. The score \tilde{c}_j is then used to control the effect of weakly annotated training instances on updating the parameters of the target network during backpropagation. The target network is in charge of the main task we want to learn; in other words, it approximates the underlying function that predicts the correct labels.
Given a data instance \tau_i and its weak label \tilde{y}_i from training set U, the target network aims to predict the label \hat{y}_i. The target network's parameter updates are based on the noisy labels assigned by the weak annotator, but the magnitude of each gradient update is based on the output of the confidence network.

Both networks are trained in a multi-task fashion, alternating between the full supervision mode and the weak supervision mode. In the full supervision mode, the parameters of the confidence network are updated using batches of instances from training set V. As depicted on the left side of the figure, each training instance is passed through the representation layer, which maps inputs to vectors. These vectors are concatenated with their corresponding weak labels \tilde{y}_j generated by the weak annotator. The confidence network then estimates \tilde{c}_j, the probability of taking data instance j into account when training the target network.
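The full supervision forward pass can be sketched as follows. This is a minimal illustration under my own assumptions: the shared representation layer is a single tanh layer, the confidence head is a logistic unit over the concatenation [r; \tilde{y}_j], and all weights are toy values. The function names `representation` and `confidence` are hypothetical, and the actual architectures in the paper are task-specific.

```python
import math


def representation(x, W):
    """Shared representation layer, assumed here to be r = tanh(W x)."""
    return [math.tanh(sum(w * xj for w, xj in zip(row, x))) for row in W]


def confidence(r, y_weak, w_c, b_c):
    """Confidence head: a logistic unit over the concatenation [r; y_weak]."""
    features = r + [y_weak]  # concatenate the representation with the weak label
    z = sum(w * f for w, f in zip(w_c, features)) + b_c
    return 1.0 / (1.0 + math.exp(-z))  # c_tilde, a probability in (0, 1)


# Toy usage with made-up weights: 2-d input, 2-unit representation.
W = [[0.5, -0.3], [0.1, 0.8]]
r = representation([1.0, 2.0], W)
c_tilde = confidence(r, y_weak=0.5, w_c=[0.4, -0.6, 1.2], b_c=0.0)
```

Because the weak label is one of the confidence head's inputs, the same input can receive different confidence scores depending on what the weak annotator said about it.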

In the weak supervision mode, the parameters of the target network are updated using training set U. As shown on the right side of the figure, each training instance is passed through the same representation layer and is then processed by the supervision layer, the part of the target network that predicts the label for the main task. We also pass the learned representation of each training instance, along with its corresponding label generated by the weak annotator, to the confidence network to estimate the confidence score of that instance, \tilde{c}_i. A confidence score is computed for every instance from set U, and these scores weight the gradient updates of the target network parameters, i.e., they scale the step size during back-propagation.
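The effect of the confidence score on the step size can be illustrated with a single SGD update of a logistic supervision layer. This sketch assumes a binary task with logistic loss and keeps the shared representation r fixed for brevity (in the full model, gradients also flow into the representation layer); the function name `weak_step` and the learning rate are hypothetical. Because \tilde{c}_i is a constant multiplier on the instance loss, it multiplies the gradient, and hence the update, by the same factor.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def weak_step(v, b, r, y_weak, c_tilde, lr=0.1):
    """One confidence-weighted SGD step for a logistic supervision layer.

    v, b: supervision-layer weights and bias (hypothetical toy parameters)
    r: shared representation of instance i (held fixed in this sketch)
    y_weak: weak label in [0, 1] from the weak annotator
    c_tilde: confidence score from the confidence network, used as a constant
    """
    y_hat = sigmoid(sum(vk * rk for vk, rk in zip(v, r)) + b)
    # For logistic loss, dL_i/dz = y_hat - y_weak, so dL_i/dv_k = (y_hat - y_weak) * r_k.
    err = y_hat - y_weak
    # The loss is c_tilde * L_i, so every gradient (and thus every step) is scaled by c_tilde.
    new_v = [vk - lr * c_tilde * err * rk for vk, rk in zip(v, r)]
    new_b = b - lr * c_tilde * err
    return new_v, new_b


# Same instance, three confidence levels: the update shrinks as confidence drops.
for c in (1.0, 0.5, 0.0):
    print(c, weak_step([0.0, 0.0], 0.0, [1.0, 2.0], y_weak=1.0, c_tilde=c))
```

With \tilde{c}_i = 0 the parameters do not move at all, which is how an unreliable weak label is effectively filtered out.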

Note that the representation layer is shared between both networks. Besides the regularization effect of layer sharing, which leads to better generalization, sharing this layer lets the confidence network benefit from the size of set U and the target network benefit from the quality of set V.

Model Training:

Our optimization objective is composed of two terms: (1) the confidence network loss \mathcal{L}_c, which captures the quality of the output of the confidence network, and (2) the target network loss \mathcal{L}_t, which expresses the quality on the main task.

Both networks are trained by alternating between the weak supervision and the full supervision modes. In the full supervision mode, the parameters of the confidence network are updated using training instances drawn from training set V. We use a cross-entropy loss for the confidence network to capture the difference between the predicted confidence score of instance j, i.e., \tilde{c}_j, and the target score c_j:

\begin{equation*}
\mathcal{L}_c = \sum_{j \in V} \left[ - c_j \log(\tilde{c}_j) - (1 - c_j) \log(1 - \tilde{c}_j) \right] \tag{1}
\end{equation*}
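Equation (1) is the standard binary cross-entropy, summed over V, and the sketch below computes it in plain Python. The post only says that c_j is derived from the difference between the true and weak labels; the rule c_j = 1 - |y_j - \tilde{y}_j| (for labels scaled to [0, 1]) used in the hypothetical helper `target_confidence` is an assumption for illustration, not necessarily the rule used in the paper.

```python
import math


def target_confidence(y_true, y_weak):
    """Hypothetical target score c_j = 1 - |y_j - y~_j| for labels scaled to [0, 1]."""
    return 1.0 - abs(y_true - y_weak)


def confidence_loss(c_targets, c_preds, eps=1e-12):
    """Eq. (1): binary cross-entropy between c_j and c~_j, summed over instances of V."""
    total = 0.0
    for c, c_hat in zip(c_targets, c_preds):
        c_hat = min(max(c_hat, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -c * math.log(c_hat) - (1.0 - c) * math.log(1.0 - c_hat)
    return total


# Toy usage: the second weak label is off by 0.5, so its target c_j is 0.5.
pairs = [(1.0, 1.0), (0.0, 0.5)]  # (true label, weak label)
c_targets = [target_confidence(y, y_w) for y, y_w in pairs]
loss = confidence_loss(c_targets, [0.8, 0.6])  # made-up confidence-network outputs
```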

The target confidence score c_j is calculated based on the difference between the true and weak labels with respect to the main task. In the weak supervision mode, the parameters of the target network are updated using training instances from U. We use a weighted loss function, \mathcal{L}_t, to capture the difference between the label \hat{y}_i predicted by the target network and the target label \tilde{y}_i:

\begin{equation*}
\mathcal{L}_t = \sum_{i \in U} \tilde{c}_i \, \mathcal{L}_i, \tag{2}
\end{equation*}

where \mathcal{L}_i is the task-specific loss on training instance i and \tilde{c}_i is the confidence score of the weakly annotated instance i, estimated by the confidence network. Note that \tilde{c}_i is treated as a constant during the weak supervision mode and there is no gradient propagation to the confidence network in the backward pass (as depicted on the right side of the figure).
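Equation (2) can be sketched as follows, assuming (for illustration only) a binary task whose task-specific loss \mathcal{L}_i is the logistic loss against the weak label; the function names are hypothetical. Passing the confidence scores in as plain floats mirrors the point above: they are constants, so no gradient reaches the confidence network. In an autodiff framework, the same effect is usually obtained by wrapping the scores in a stop-gradient operation, for example `tf.stop_gradient` in TensorFlow or `jax.lax.stop_gradient` in JAX.

```python
import math


def logistic_loss(y_hat, y_weak, eps=1e-12):
    """Task-specific loss L_i, assumed here to be binary cross-entropy vs. the weak label."""
    y_hat = min(max(y_hat, eps), 1.0 - eps)  # clip to avoid log(0)
    return -y_weak * math.log(y_hat) - (1.0 - y_weak) * math.log(1.0 - y_hat)


def weighted_target_loss(y_hats, y_weaks, c_tildes):
    """Eq. (2): L_t = sum_i c~_i * L_i over a batch drawn from U.

    c_tildes are plain floats copied out of the confidence network, i.e. constants:
    nothing in this computation can send a gradient back to the confidence network.
    """
    return sum(c * logistic_loss(y_hat, y_w)
               for y_hat, y_w, c in zip(y_hats, y_weaks, c_tildes))
```

An instance with \tilde{c}_i close to 0 contributes almost nothing to \mathcal{L}_t, no matter how large its own loss \mathcal{L}_i is.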

We minimize the two loss functions jointly by randomly alternating between the full and weak supervision modes (for example, using a 1:10 ratio). During training, depending on the chosen supervision mode, we sample a batch of training instances either from V with replacement or from U without replacement (since we can generate weak labels for as much data in U as we need). Since in our setups usually |U| >> |V|, the training process oversamples the instances from V.
The key point here is that the "main task" and the "confidence scoring" task are always defined to be closely related, so sharing the representation benefits the confidence network as a form of implicit data augmentation, compensating for the small amount of data with true labels.
Besides, we noticed that updating the representation layer with respect to the loss of the other network acts as a regularizer for each network and helps generalization for both the target and the confidence network: since the shared layer has to capture both related tasks, there is less chance of overfitting.
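The alternating schedule described above can be sketched as a generator. It assumes that the 1:10 ratio is realized randomly (each step is a full supervision step with probability 1/11), that sampling U without replacement means walking through a shuffled copy of U and reshuffling once it is exhausted, and that batches from V are drawn with replacement. The function name `training_schedule` and its parameters are hypothetical.

```python
import random


def training_schedule(U, V, num_steps, batch_size, weak_per_full=10, seed=0):
    """Yield (mode, batch) pairs, randomly alternating between the two modes.

    On average, one full supervision step (batch from V, sampled with replacement)
    is taken per `weak_per_full` weak supervision steps (batch from U, sampled
    without replacement by walking through a shuffled copy of U).
    """
    rng = random.Random(seed)
    order = list(U)  # shuffled working copy of U
    rng.shuffle(order)
    pos = 0
    p_full = 1.0 / (1.0 + weak_per_full)
    for _ in range(num_steps):
        if rng.random() < p_full:
            # Full supervision mode: the small set V is oversampled, with replacement.
            yield "full", [rng.choice(V) for _ in range(batch_size)]
        else:
            # Weak supervision mode: take the next unseen slice of the shuffled U.
            if pos + batch_size > len(order):  # pass over U finished: reshuffle
                rng.shuffle(order)
                pos = 0
            yield "weak", order[pos:pos + batch_size]
            pos += batch_size
```

Each yielded mode would then trigger either a confidence-network update with \mathcal{L}_c or a target-network update with \mathcal{L}_t.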

We apply our semi-supervised method to two different tasks: document ranking and sentiment classification. Our experimental results suggest that, in both tasks, the proposed method is more effective at leveraging large amounts of weakly labeled data than traditional fine-tuning. We also show that explicitly controlling the weight updates of the target network with the confidence network leads to faster convergence, since the filtered supervision signals are more reliable and less noisy.

To learn more about our model and the results, please take a look at this paper:

  • M. Dehghani, A. Severyn, S. Rothe, and J. Kamps. "Avoiding Your Teacher's Mistakes: Training Neural Networks with Controlled Weak Supervision", arXiv preprint arXiv:1711.00313 (2017).
