Loss functions¶
class compressors.distillation.losses.AttentionLoss(p: int = 2)¶
forward(s_hidden_states: Tuple[torch.FloatTensor], t_hidden_states: Tuple[torch.FloatTensor]) → torch.FloatTensor¶
Defines the computation performed at every call. Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
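The docstring above does not show what `AttentionLoss` actually computes. The name and the `p` parameter suggest an attention-transfer-style loss over per-layer hidden states; the sketch below is an assumption, not the library's implementation (the helper name `attention_loss` and the mean-over-hidden-dim reduction are hypothetical):

```python
import torch
import torch.nn.functional as F
from typing import Tuple

def attention_loss(
    s_hidden_states: Tuple[torch.FloatTensor, ...],
    t_hidden_states: Tuple[torch.FloatTensor, ...],
    p: int = 2,
) -> torch.FloatTensor:
    """Hypothetical sketch of an attention-transfer loss over hidden states."""
    loss = 0.0
    for s, t in zip(s_hidden_states, t_hidden_states):
        # Collapse the hidden dimension into a per-position activation map
        # using the p-th power of absolute activations, then L2-normalize it.
        s_att = F.normalize(s.abs().pow(p).mean(dim=-1), dim=-1)
        t_att = F.normalize(t.abs().pow(p).mean(dim=-1), dim=-1)
        loss = loss + F.mse_loss(s_att, t_att)
    return loss / len(s_hidden_states)
```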
class compressors.distillation.losses.CRDLoss(student_dim: int, teacher_dim: int, n_data: int, feature_dim: int = 128, nce_k: int = 16384, nce_t: float = 0.07, nce_m: float = 0.5)¶
The CRD loss consists of two symmetric parts: (a) using the teacher as anchor, choose the positive and negatives over the student side; (b) using the student as anchor, choose the positive and negatives over the teacher side.
- Parameters
student_dim – the dimension of the student’s feature
teacher_dim – the dimension of the teacher’s feature
n_data – the number of samples in the training set; the memory buffer therefore has size n_data x feature_dim
feature_dim – the dimension of the projection space
nce_k – the number of negatives paired with each positive
nce_t – the temperature
nce_m – the momentum for updating the memory buffer
forward(f_s, f_t, idx, contrast_idx=None)¶
Forward pass.
- Parameters
f_s – the feature of the student network, size [batch_size, student_dim]
f_t – the feature of the teacher network, size [batch_size, teacher_dim]
idx – the indices of the positive samples in the dataset, size [batch_size]
contrast_idx – the indices of the negative samples, size [batch_size, nce_k]
- Returns
The contrastive loss
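The real `CRDLoss` draws its `nce_k` negatives from a momentum-updated memory buffer of `n_data` entries. The symmetric anchoring idea itself can be sketched with in-batch negatives only; this is a simplification, not the library's implementation, and it assumes both features have already been projected to a shared `feature_dim`:

```python
import torch
import torch.nn.functional as F

def symmetric_contrastive_sketch(f_s, f_t, nce_t: float = 0.07):
    """Simplified in-batch version of CRD's two symmetric contrastive parts."""
    # The real CRDLoss samples nce_k negatives from a momentum memory buffer
    # instead of using the current batch, and projects f_s/f_t to feature_dim.
    f_s = F.normalize(f_s, dim=-1)
    f_t = F.normalize(f_t, dim=-1)
    logits = f_s @ f_t.t() / nce_t          # [batch, batch] similarity matrix
    labels = torch.arange(f_s.size(0))      # positives sit on the diagonal
    loss_s_anchor = F.cross_entropy(logits, labels)      # student as anchor
    loss_t_anchor = F.cross_entropy(logits.t(), labels)  # teacher as anchor
    return 0.5 * (loss_s_anchor + loss_t_anchor)
```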
class compressors.distillation.losses.KLDivLoss(temperature: float = 1.0)¶
forward(s_logits, t_logits)¶
Defines the computation performed at every call. Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
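A minimal sketch of such a module, assuming the standard temperature-scaled KL distillation loss (the `T**2` rescaling and `batchmean` reduction are assumptions about this library; `KLDivLossSketch` is a hypothetical name). It also shows calling the instance rather than `.forward()`, as the note above requires:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class KLDivLossSketch(nn.Module):
    """Hedged sketch of what a temperature-scaled KLDivLoss may compute."""
    def __init__(self, temperature: float = 1.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, s_logits, t_logits):
        T = self.temperature
        # Soften both distributions, compute KL(teacher || student), and
        # rescale by T^2 so gradient magnitudes stay comparable across T.
        return F.kl_div(
            F.log_softmax(s_logits / T, dim=-1),
            F.softmax(t_logits / T, dim=-1),
            reduction="batchmean",
        ) * T ** 2

criterion = KLDivLossSketch(temperature=2.0)
s_logits = torch.randn(4, 10)
# Call the instance, not criterion.forward(...), so registered hooks run.
loss = criterion(s_logits, s_logits.detach())
```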
class compressors.distillation.losses.MSEHiddenStatesLoss(normalize: bool = False, need_mapping: bool = False, teacher_hidden_state_dim: Optional[int] = None, student_hidden_state_dim: Optional[int] = None, num_layers: Optional[int] = None)¶
forward(s_hidden_states: Union[torch.FloatTensor, Tuple[torch.FloatTensor]], t_hidden_states: Union[torch.FloatTensor, Tuple[torch.FloatTensor]]) → torch.FloatTensor¶
Defines the computation performed at every call. Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
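The constructor signature suggests an optional learned linear mapping from the student's hidden size to the teacher's (`need_mapping`) before taking the MSE. A hedged sketch under that assumption (the class name `MSEHiddenStatesSketch` and the per-layer averaging are hypothetical, not the library's code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MSEHiddenStatesSketch(nn.Module):
    """Sketch: optional linear map student_dim -> teacher_dim, then MSE."""
    def __init__(self, normalize: bool = False, need_mapping: bool = False,
                 teacher_hidden_state_dim=None, student_hidden_state_dim=None):
        super().__init__()
        self.normalize = normalize
        self.proj = (nn.Linear(student_hidden_state_dim, teacher_hidden_state_dim)
                     if need_mapping else nn.Identity())

    def forward(self, s_hidden_states, t_hidden_states):
        # Accept a single tensor or a tuple of per-layer tensors.
        if torch.is_tensor(s_hidden_states):
            s_hidden_states = (s_hidden_states,)
            t_hidden_states = (t_hidden_states,)
        loss = 0.0
        for s, t in zip(s_hidden_states, t_hidden_states):
            s = self.proj(s)
            if self.normalize:
                s, t = F.normalize(s, dim=-1), F.normalize(t, dim=-1)
            loss = loss + F.mse_loss(s, t)
        return loss / len(s_hidden_states)
```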
compressors.distillation.losses.kl_div_loss(s_logits: torch.FloatTensor, t_logits: torch.FloatTensor, temperature: float = 1.0) → torch.FloatTensor¶
KL-divergence loss.
- Parameters
s_logits (FloatTensor) – output of the student model.
t_logits (FloatTensor) – output of the teacher model.
temperature (float, optional) – temperature for softening the teacher distribution. Defaults to 1.0.
- Returns
Divergence between the student and teacher distributions.
- Return type
FloatTensor
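To make the definition concrete, the KL divergence between the temperature-softened distributions can be written out by hand and checked against PyTorch's built-in helper (the choice that `temperature` divides the logits is the standard convention and an assumption about this library):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T = 2.0  # temperature; assumed to divide the logits, as is standard
s_logits, t_logits = torch.randn(4, 10), torch.randn(4, 10)

log_p_s = F.log_softmax(s_logits / T, dim=-1)
p_t = F.softmax(t_logits / T, dim=-1)

# KL(p_t || p_s) = sum_i p_t[i] * (log p_t[i] - log p_s[i]), averaged over the batch
manual = (p_t * (p_t.log() - log_p_s)).sum(dim=-1).mean()

# torch's helper with reduction="batchmean" computes the same quantity
builtin = F.kl_div(log_p_s, p_t, reduction="batchmean")
assert torch.allclose(manual, builtin)
```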
compressors.distillation.losses.mse_loss(s_hidden_states: Tuple[torch.FloatTensor], t_hidden_states: Tuple[torch.FloatTensor], normalize: bool = False) → torch.FloatTensor¶
MSE loss for hidden states.
- Parameters
s_hidden_states (Tuple[FloatTensor]) – student hidden states
t_hidden_states (Tuple[FloatTensor]) – teacher hidden states
normalize (bool, optional) – whether to L2-normalize the hidden states before computing the loss. Defaults to False.
- Returns
The MSE loss between student and teacher hidden states.
- Return type
FloatTensor
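A hedged functional sketch, mainly to illustrate what `normalize` buys you: after L2 normalization the loss depends only on the direction of the hidden states, not their scale. The helper name `mse_hidden_loss` and the averaging over layers are assumptions, not the library's code:

```python
import torch
import torch.nn.functional as F

def mse_hidden_loss(s_hidden_states, t_hidden_states, normalize=False):
    # Hypothetical re-implementation: per-layer MSE, optionally on
    # L2-normalized hidden states, averaged over the tuple of layers.
    total = 0.0
    for s, t in zip(s_hidden_states, t_hidden_states):
        if normalize:
            s, t = F.normalize(s, dim=-1), F.normalize(t, dim=-1)
        total = total + F.mse_loss(s, t)
    return total / len(s_hidden_states)

h = torch.tensor([[1.0, 2.0, 2.0]])
# With normalize=True the loss is scale-invariant: 3*h has the same direction as h.
plain = mse_hidden_loss((h,), (3.0 * h,))
scaled = mse_hidden_loss((h,), (3.0 * h,), normalize=True)
```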
compressors.distillation.losses.pkt_loss(s_hidden_states: torch.FloatTensor, t_hidden_states: torch.FloatTensor, eps: float = 1e-07) → torch.FloatTensor¶
Loss between distributions over feature similarities, computed with a cosine-similarity kernel.
- Parameters
s_hidden_states (FloatTensor) – student hidden states
t_hidden_states (FloatTensor) – teacher hidden states
eps (float, optional) – small constant for numerical stability. Defaults to 1e-7.
- Returns
loss
- Return type
FloatTensor
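This matches the probabilistic knowledge transfer (PKT) idea: each sample's cosine similarities to the rest of the batch are turned into a conditional distribution, and the student's distributions are pulled toward the teacher's with a KL term. The sketch below is an assumption about the implementation (the `[0, 1]` shift of the kernel and the mean reduction are choices, and `pkt_loss_sketch` is a hypothetical name):

```python
import torch
import torch.nn.functional as F

def pkt_loss_sketch(s_hidden_states, t_hidden_states, eps: float = 1e-7):
    """Hedged sketch of a PKT loss with a cosine-similarity kernel."""
    s = F.normalize(s_hidden_states, dim=-1)
    t = F.normalize(t_hidden_states, dim=-1)
    # Pairwise cosine similarities, shifted from [-1, 1] into [0, 1].
    s_sim = (s @ s.t() + 1.0) / 2.0
    t_sim = (t @ t.t() + 1.0) / 2.0
    # Turn each row into a conditional distribution over the batch.
    s_cond = s_sim / s_sim.sum(dim=1, keepdim=True)
    t_cond = t_sim / t_sim.sum(dim=1, keepdim=True)
    # KL divergence between teacher and student conditionals (eps for stability).
    return (t_cond * torch.log((t_cond + eps) / (s_cond + eps))).mean()
```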