Distillation

Philosophy and main idea

This method was proposed in the paper Distilling the Knowledge in a Neural Network. We train one teacher model for a task and then transfer its knowledge to another, student network. In the classical setup the teacher is a big and slow network, the student is a smaller model, and we want the student to perform as well as the teacher.

Let’s take a look at a simple image classification example:

../_images/distillation.png

The simplest way is to look at the probability distributions of the teacher model and the student model and force the student to produce the same distribution as the teacher.

../_images/distributions.png

The most common way to measure the difference between distributions is the KL-divergence.

\[KL = \sum_x p_1(x) \log\left(\frac{p_1(x)}{p_2(x)}\right)\]
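As a sketch, the sum above can be computed directly for two discrete distributions (pure Python; the function name and example probabilities here are illustrative, not part of the compressors API):

```python
import math

def kl_divergence(p1, p2):
    """KL(p1 || p2) for two discrete distributions given as lists of probabilities."""
    return sum(p * math.log(p / q) for p, q in zip(p1, p2) if p > 0)

teacher = [0.7, 0.2, 0.1]  # hypothetical teacher class probabilities
student = [0.5, 0.3, 0.2]  # hypothetical student class probabilities
print(kl_divergence(teacher, teacher))  # identical distributions -> 0.0
print(kl_divergence(teacher, student))  # positive when the distributions differ
```

Note that the KL-divergence is zero only when the two distributions match exactly, so minimizing it pushes the student’s distribution toward the teacher’s.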

You can use KLDivLoss to apply this loss. It takes the student’s and teacher’s logits.

class compressors.distillation.losses._kl_loss.KLDivLoss(temperature: float = 1.0)

Or just add a KLDivCallback to the DistilRunner if you are using Catalyst.

class compressors.distillation.callbacks.KLDivCallback(output_key: str = 'kl_div_loss', temperature: float = 1.0, student_logits_key: str = 's_logits', teacher_logits_key: str = 't_logits')
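The temperature parameter shared by KLDivLoss and KLDivCallback softens both distributions before comparing them. A minimal pure-Python sketch of what such a temperature-scaled KL loss computes (the function names and logits here are illustrative, not the library’s implementation):

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax over logits divided by temperature; a higher temperature
    gives a softer (closer to uniform) distribution."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def soft_kl(teacher_logits, student_logits, temperature=1.0):
    """KL(teacher || student) on temperature-softened distributions."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

t_logits = [5.0, 1.0, 0.5]  # hypothetical teacher logits
s_logits = [3.0, 2.0, 1.0]  # hypothetical student logits
# A higher temperature softens both distributions and shrinks the gap:
print(soft_kl(t_logits, s_logits, temperature=1.0))
print(soft_kl(t_logits, s_logits, temperature=4.0))
```

Softening with a temperature above 1 exposes the teacher’s relative confidences over the wrong classes, which is exactly the “dark knowledge” the student is meant to learn.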

We can also add a difference between hidden states.

../_images/flatten_hiddens.png

For example, we can use the \(L_2\)-norm for these vectors with MSEHiddenStatesLoss

class compressors.distillation.losses.MSEHiddenStatesLoss(normalize: bool = False, need_mapping: bool = False, teacher_hidden_state_dim: Optional[int] = None, student_hidden_state_dim: Optional[int] = None, num_layers: Optional[int] = None)

or callback

class compressors.distillation.callbacks.MSEHiddenStatesCallback(output_key: str = 'mse_loss', normalize: bool = False, need_mapping: bool = False, teacher_hidden_state_dim: Optional[int] = None, student_hidden_state_dim: Optional[int] = None, num_layers: Optional[int] = None)

MSE loss, aka Hint loss, for the difference between the hidden states of the teacher and student models.

Parameters

output_key – name for the loss. Defaults to mse_loss.
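As an illustration of what such a hint loss computes, here is a minimal sketch for two flattened hidden-state vectors (pure Python; the function name and vectors are ours, not the library’s):

```python
def mse_hidden_states(student_hidden, teacher_hidden, normalize=False):
    """Mean squared error between flattened hidden-state vectors.
    With normalize=True each vector is first scaled to unit L2 norm."""
    if normalize:
        def unit(v):
            norm = sum(x * x for x in v) ** 0.5
            return [x / norm for x in v]
        student_hidden = unit(student_hidden)
        teacher_hidden = unit(teacher_hidden)
    return sum((s - t) ** 2 for s, t in zip(student_hidden, teacher_hidden)) / len(student_hidden)

s_hidden = [0.2, 0.4, 0.1, 0.3]  # hypothetical student hidden state
t_hidden = [0.1, 0.5, 0.2, 0.2]  # hypothetical teacher hidden state
print(mse_hidden_states(s_hidden, t_hidden))  # ≈ 0.01
```

When the student’s hidden dimension differs from the teacher’s, a learned linear mapping is needed first; this is what the need_mapping and *_hidden_state_dim parameters above are for.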

But hidden states are not always vectors. We can also have tensors, like here, with shapes \((C\times H\times W)\).

../_images/hiddens.png

You can also apply the \(L_2\)-norm here, but experiments show that you can use something more complex, such as Attention Transfer.

In compressors this method is implemented in

class compressors.distillation.losses.AttentionLoss(p: int = 2)

and

class compressors.distillation.callbacks.AttentionHiddenStatesCallback(output_key: str = 'attention_loss', exclude_first_and_last: bool = True, p: int = 2)
Parameters
  • output_key – name for the loss. Defaults to attention_loss.

  • exclude_first_and_last – If set to True, doesn’t use the first and last hidden states. Attention loss is usually applied this way. Defaults to True.
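A rough pure-Python sketch of the attention-transfer idea: a \(C\times H\times W\) activation tensor is collapsed into an \(H\times W\) attention map by summing \(|a|^p\) over channels, the maps are flattened and L2-normalized, and the loss is the squared distance between them (the names and tiny tensors below are illustrative, not the compressors implementation):

```python
def attention_map(tensor, p=2):
    """Collapse a C x H x W tensor (nested lists) into a flattened,
    L2-normalized H*W attention map: sum of |activation|**p over channels."""
    channels = len(tensor)
    h, w = len(tensor[0]), len(tensor[0][0])
    flat = [
        sum(abs(tensor[c][i][j]) ** p for c in range(channels))
        for i in range(h) for j in range(w)
    ]
    norm = sum(x * x for x in flat) ** 0.5
    return [x / norm for x in flat]

def attention_loss(student_tensor, teacher_tensor, p=2):
    """Mean squared distance between normalized attention maps."""
    s_map = attention_map(student_tensor, p)
    t_map = attention_map(teacher_tensor, p)
    return sum((s - t) ** 2 for s, t in zip(s_map, t_map)) / len(s_map)

# Tiny hypothetical activations with C=2, H=2, W=2:
student_act = [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.5], [0.5, 0.0]]]
teacher_act = [[[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]
print(attention_loss(student_act, teacher_act))  # positive: spatial patterns differ
```

Because the maps are normalized, this loss compares where the activations are concentrated spatially rather than their absolute magnitudes, which is why it works across teacher and student networks of different widths.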

Submodules