Callbacks

Logits difference callbacks

Callbacks that use the difference between the probability distributions produced by the last layers of the student and teacher models. Useful for classification tasks.

class compressors.distillation.callbacks.logits_diff.KLDivCallback(output_key: str = 'kl_div_loss', temperature: float = 1.0, student_logits_key: str = 's_logits', teacher_logits_key: str = 't_logits')
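The entry above gives no formula. A minimal sketch, assuming the common Hinton-style formulation: both logit vectors are softened by `temperature`, the loss is KL(teacher || student) averaged over the batch, and it is scaled by `temperature ** 2`. The helpers `log_softmax` and `kl_div_loss` are hypothetical and not part of compressors; whether `KLDivCallback` uses exactly this scaling and reduction is an assumption.

```python
import math
from typing import List


def log_softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled log-softmax, shifted by the max for numerical stability."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]


def kl_div_loss(
    s_logits: List[List[float]],
    t_logits: List[List[float]],
    temperature: float = 1.0,
) -> float:
    """Batch-mean KL(teacher || student) over softened distributions, times T**2.

    Hypothetical helper: the T**2 factor follows Hinton et al.; whether
    KLDivCallback applies it is an assumption.
    """
    total = 0.0
    for s_row, t_row in zip(s_logits, t_logits):
        log_p_s = log_softmax(s_row, temperature)
        log_p_t = log_softmax(t_row, temperature)
        # sum_i p_t[i] * (log p_t[i] - log p_s[i])
        total += sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
    return total / len(s_logits) * temperature ** 2
```

In PyTorch the same quantity is usually written as `F.kl_div(F.log_softmax(s_logits / T, dim=-1), F.softmax(t_logits / T, dim=-1), reduction="batchmean") * T ** 2`.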

Wrappers

Wrappers are useful when your callback takes inputs other than hidden states or logits, but you don’t want to modify the batch in your runner.

class compressors.distillation.callbacks.wrappers.LambdaWrapperCallback(base_callback: catalyst.core.callback.Callback, lambda_fn: Callable, keys_to_apply: Union[List[str], str] = ['s_hidden_states', 't_hidden_states'])

Wraps the inputs of your callback with the specified function.

Parameters
  • base_callback (Callback) – Base callback.

  • lambda_fn (Callable) – Function to apply.

  • keys_to_apply (Union[List[str], str], optional) – Keys in the batch dict to apply the function to. Defaults to [“s_hidden_states”, “t_hidden_states”].

Raises

TypeError – If keys_to_apply is neither a str nor a list.
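A minimal sketch of the idea behind the wrapper, assuming the wrapped callback reads its inputs from a dict-like batch: the function is applied to the selected entries of a copy, so the batch held by the runner stays unchanged. `wrapped_inputs` is a hypothetical helper, not the library's implementation, and how `LambdaWrapperCallback` actually hands the transformed values to `base_callback` is not shown here.

```python
from typing import Any, Callable, Dict, List, Union


def wrapped_inputs(
    batch: Dict[str, Any],
    lambda_fn: Callable[[Any], Any],
    keys_to_apply: Union[List[str], str] = ["s_hidden_states", "t_hidden_states"],
) -> Dict[str, Any]:
    """Return a shallow copy of `batch` with `lambda_fn` applied to `keys_to_apply`.

    Hypothetical helper illustrating the wrapper idea: the runner's batch
    itself is never modified (the default list is never mutated either).
    """
    if isinstance(keys_to_apply, str):
        keys_to_apply = [keys_to_apply]
    elif not isinstance(keys_to_apply, list):
        raise TypeError("keys_to_apply should be a str or a list of str")
    new_batch = dict(batch)  # shallow copy; untouched keys are shared
    for key in keys_to_apply:
        new_batch[key] = lambda_fn(batch[key])
    return new_batch
```

For instance, `lambda_fn=lambda hs: hs[-1:]` would let a loss callback see only the last hidden state while other callbacks still see all of them.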

Preprocessors

In-place analogs of wrappers. They preprocess runner.batch before the callbacks are applied.

class compressors.distillation.callbacks.LambdaPreprocessCallback(lambda_fn: Callable, keys_to_apply: Union[List[str], str] = ['s_hidden_states', 't_hidden_states'])

Filters the outputs with your lambda function. In-place analog of LambdaWrapperCallback.

Parameters
  • lambda_fn (Callable) – Function to apply.

  • keys_to_apply (Union[List[str], str], optional) – Keys in the batch dict to apply the function to. Defaults to [“s_hidden_states”, “t_hidden_states”].

Raises

TypeError – If keys_to_apply is neither a str nor a list.
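A minimal sketch of the in-place variant, under the same dict-like batch assumption: here the entries are overwritten, so every callback that runs later sees the transformed values. `preprocess_batch_inplace` is a hypothetical helper, not the library's code.

```python
from typing import Any, Callable, Dict, List, Union


def preprocess_batch_inplace(
    batch: Dict[str, Any],
    lambda_fn: Callable[[Any], Any],
    keys_to_apply: Union[List[str], str] = ["s_hidden_states", "t_hidden_states"],
) -> None:
    """Overwrite batch[key] with lambda_fn(batch[key]) for every key (hypothetical helper)."""
    if isinstance(keys_to_apply, str):
        keys_to_apply = [keys_to_apply]
    elif not isinstance(keys_to_apply, list):
        raise TypeError("keys_to_apply should be a str or a list of str")
    for key in keys_to_apply:
        # In place: every callback that runs afterwards sees the new value.
        batch[key] = lambda_fn(batch[key])
```

The trade-off against the wrapper is scope: an in-place change affects all later callbacks, while a wrapper affects only the one callback it wraps.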

class compressors.distillation.callbacks.HiddenStatesSelectCallback(layers: Union[int, List[int]], hiddens_key: str = 't_hidden_states')
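This entry has no description. Judging only from its signature, it appears to keep a subset of layers from the hidden states stored under `hiddens_key` (by default the teacher's), for example so that a deeper teacher can be matched layer by layer against a shallower student. A minimal sketch under that assumption; `select_hidden_states` is hypothetical, and treating a single int as a one-element selection is also an assumption.

```python
from typing import Any, Dict, List, Union


def select_hidden_states(
    batch: Dict[str, Any],
    layers: Union[int, List[int]],
    hiddens_key: str = "t_hidden_states",
) -> None:
    """Keep only the hidden states at the given layer indices (hypothetical helper)."""
    if isinstance(layers, int):
        layers = [layers]  # assumption: a single int selects a one-element list
    hidden_states = batch[hiddens_key]
    batch[hiddens_key] = [hidden_states[i] for i in layers]
```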

Hidden states

class compressors.distillation.callbacks.hidden_states.AttentionHiddenStatesCallback(output_key: str = 'attention_loss', exclude_first_and_last: bool = True, p: int = 2)

Parameters
  • output_key – Name of the loss. Defaults to attention_loss.

  • exclude_first_and_last – If set to True, the first and last hidden states are not used. Attention loss is usually applied this way. Defaults to True.
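The entry above does not give the formula. A minimal sketch, assuming an attention-transfer style loss (Zagoruyko and Komodakis, "Paying More Attention to Attention"): each feature map is collapsed over channels into a spatial map of |activation| ** p, L2-normalized, and the squared difference between the student and teacher maps is averaged per layer and summed over layers. Whether AttentionHiddenStatesCallback computes exactly this, and that `p` is this exponent, are assumptions; the helper names are hypothetical.

```python
import math
from typing import List

FeatureMap = List[List[float]]  # [channel][spatial position], flattened


def attention_map(feature_map: FeatureMap, p: int = 2) -> List[float]:
    """Sum |activation|**p over channels, then L2-normalize the spatial map."""
    num_positions = len(feature_map[0])
    att = [sum(abs(channel[j]) ** p for channel in feature_map) for j in range(num_positions)]
    norm = math.sqrt(sum(a * a for a in att)) or 1.0  # guard against all-zero maps
    return [a / norm for a in att]


def attention_loss(
    s_hidden_states: List[FeatureMap],
    t_hidden_states: List[FeatureMap],
    p: int = 2,
    exclude_first_and_last: bool = True,
) -> float:
    """Per-layer mean squared difference of attention maps, summed over layers."""
    if exclude_first_and_last:
        s_hidden_states, t_hidden_states = s_hidden_states[1:-1], t_hidden_states[1:-1]
    total = 0.0
    for s_fm, t_fm in zip(s_hidden_states, t_hidden_states):
        s_att, t_att = attention_map(s_fm, p), attention_map(t_fm, p)
        total += sum((s - t) ** 2 for s, t in zip(s_att, t_att)) / len(s_att)
    return total
```

Because the channels are summed out, student and teacher maps only need the same number of spatial positions, not the same channel count.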

class compressors.distillation.callbacks.hidden_states.CosineHiddenStatesCallback(output_key: str = 'cosine_loss', last_only: bool = True, need_mapping: bool = False, teacher_hidden_state_dim: Optional[int] = None, student_hidden_state_dim: Optional[int] = None)

Cosine loss for the difference between the hidden states of the teacher and student models.

Parameters
  • output_key – Name of the loss. Defaults to cosine_loss.

  • last_only – If set to True, only the last hidden state is used. Cosine loss is usually applied this way. Defaults to True.
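A minimal sketch of a cosine hidden-state loss, assuming the loss is 1 - cos(student, teacher) averaged over all compared vectors (the cosine embedding loss used in DistilBERT). Hidden states are modelled as plain lists, one list of vectors per layer. The mapping options (`need_mapping`, `teacher_hidden_state_dim`, `student_hidden_state_dim`) are left out; presumably they project the student's states to the teacher's size when the dimensions differ, which is an assumption. The helper names are hypothetical.

```python
import math
from typing import List

Layer = List[List[float]]  # one hidden vector per token/position


def cosine_similarity(a: List[float], b: List[float], eps: float = 1e-8) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)


def cosine_hidden_loss(
    s_hidden_states: List[Layer],
    t_hidden_states: List[Layer],
    last_only: bool = True,
) -> float:
    """Mean of 1 - cos(student, teacher) over all compared vectors."""
    if last_only:
        s_hidden_states, t_hidden_states = s_hidden_states[-1:], t_hidden_states[-1:]
    losses = [
        1.0 - cosine_similarity(s_vec, t_vec)
        for s_layer, t_layer in zip(s_hidden_states, t_hidden_states)
        for s_vec, t_vec in zip(s_layer, t_layer)
    ]
    return sum(losses) / len(losses)
```

In PyTorch, `torch.nn.CosineEmbeddingLoss` with a target of 1 computes the same 1 - cos quantity.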

class compressors.distillation.callbacks.hidden_states.HiddenStatesSelectCallback(layers: Union[int, List[int]], hiddens_key: str = 't_hidden_states')

class compressors.distillation.callbacks.hidden_states.LambdaPreprocessCallback(lambda_fn: Callable, keys_to_apply: Union[List[str], str] = ['s_hidden_states', 't_hidden_states'])

Filters the outputs with your lambda function. In-place analog of LambdaWrapperCallback.

Parameters
  • lambda_fn (Callable) – Function to apply.

  • keys_to_apply (Union[List[str], str], optional) – Keys in the batch dict to apply the function to. Defaults to [“s_hidden_states”, “t_hidden_states”].

Raises

TypeError – If keys_to_apply is neither a str nor a list.

class compressors.distillation.callbacks.hidden_states.MSEHiddenStatesCallback(output_key: str = 'mse_loss', normalize: bool = False, need_mapping: bool = False, teacher_hidden_state_dim: Optional[int] = None, student_hidden_state_dim: Optional[int] = None, num_layers: Optional[int] = None)

MSE loss (also known as hint loss) for the difference between the hidden states of the teacher and student models.

Parameters
  • output_key – Name of the loss. Defaults to mse_loss.
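A minimal sketch of an MSE (hint) loss between student and teacher hidden states, assuming a list-of-layers layout (one list of vectors per layer) and equal hidden sizes. Treating `normalize=True` as L2-normalizing each vector before the squared error is an assumption, as are the helper names; the mapping and `num_layers` options are not modelled.

```python
import math
from typing import List

Layer = List[List[float]]  # one hidden vector per token/position


def l2_normalize(vec: List[float], eps: float = 1e-8) -> List[float]:
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]


def mse_hidden_loss(
    s_hidden_states: List[Layer],
    t_hidden_states: List[Layer],
    normalize: bool = False,
) -> float:
    """Mean squared error over every element of every compared layer."""
    total, count = 0.0, 0
    for s_layer, t_layer in zip(s_hidden_states, t_hidden_states):
        for s_vec, t_vec in zip(s_layer, t_layer):
            if normalize:  # assumption: per-vector L2 normalization
                s_vec, t_vec = l2_normalize(s_vec), l2_normalize(t_vec)
            for s, t in zip(s_vec, t_vec):
                total += (s - t) ** 2
                count += 1
    return total / count
```

With normalization the loss compares only directions, so a student whose states are a scaled copy of the teacher's gets zero loss.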

class compressors.distillation.callbacks.hidden_states.PKTHiddenStatesCallback(output_key: str = 'pkt_loss', last_only: bool = True)

Probabilistic Knowledge Transfer (PKT) loss for the difference between the hidden states of the teacher and student models. Proposed in https://arxiv.org/abs/1803.10837.

Parameters
  • output_key – Name of the loss. Defaults to pkt_loss.

  • last_only – If set to True, only the last hidden state is used. PKT loss is usually applied this way. Defaults to True.
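A minimal sketch of PKT following the formulation in the cited paper: the feature vectors of the samples in a batch are L2-normalized, their pairwise cosine similarities are shifted into [0, 1] and row-normalized into conditional probability distributions, and the loss is KL(teacher || student) averaged over all pairs. The `eps` value and the helper names are assumptions, and whether PKTHiddenStatesCallback reduces the loss the same way is not stated here.

```python
import math
from typing import List


def l2_normalize(vec: List[float], eps: float = 1e-8) -> List[float]:
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]


def similarity_distribution(features: List[List[float]]) -> List[List[float]]:
    """Row-normalized, [0, 1]-shifted cosine similarities between batch samples."""
    normed = [l2_normalize(f) for f in features]
    dist = []
    for x in normed:
        # (cos + 1) / 2 maps cosine similarity from [-1, 1] into [0, 1]
        row = [(sum(a * b for a, b in zip(x, y)) + 1.0) / 2.0 for y in normed]
        row_sum = sum(row)  # always > 0: every entry is at least 0.5 or the diagonal is 1
        dist.append([v / row_sum for v in row])
    return dist


def pkt_loss(
    s_features: List[List[float]], t_features: List[List[float]], eps: float = 1e-7
) -> float:
    """Mean over all pairs of p_t * log(p_t / p_s), i.e. KL(teacher || student)."""
    p_s = similarity_distribution(s_features)
    p_t = similarity_distribution(t_features)
    n = len(s_features)
    total = 0.0
    for row_t, row_s in zip(p_t, p_s):
        for pt, ps in zip(row_t, row_s):
            total += pt * math.log((pt + eps) / (ps + eps))
    return total / (n * n)
```

Since only similarities between samples are compared, student and teacher features may have different dimensions.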