Callbacks
Logits difference callbacks
Callbacks that use the difference between probability distributions over the last layer's logits.
Useful for classification tasks.
-
class
compressors.distillation.callbacks.logits_diff.KLDivCallback(output_key: str = 'kl_div_loss', temperature: float = 1.0, student_logits_key: str = 's_logits', teacher_logits_key: str = 't_logits')
Kullback–Leibler divergence loss between temperature-softened teacher and student logits.
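The core computation can be sketched in plain Python (a minimal sketch, not the callback's actual implementation: the T² scaling follows Hinton et al.'s distillation recipe and is an assumption here, and the real callback reads `s_logits`/`t_logits` from `runner.batch`):

```python
import math

def softmax(logits, temperature=1.0):
    # Numerically safe softmax over a plain list of floats.
    m = max(logits)
    exps = [math.exp((x - m) / temperature) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_div_loss(s_logits, t_logits, temperature=1.0):
    # KL(teacher || student) on temperature-softened distributions,
    # scaled by T^2 (assumed, after Hinton et al.'s distillation setup).
    p = softmax(t_logits, temperature)
    q = softmax(s_logits, temperature)
    return temperature ** 2 * sum(
        pi * math.log(pi / qi) for pi, qi in zip(p, q)
    )
```

A higher `temperature` softens both distributions, exposing the teacher's relative preferences among wrong classes, which is the signal this loss transfers.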
Wrappers
Wrappers are useful when your callback takes something other than hidden states or logits as input,
but you don't want to modify the batch in your runner.
-
class
compressors.distillation.callbacks.wrappers.LambdaWrapperCallback(base_callback: catalyst.core.callback.Callback, lambda_fn: Callable, keys_to_apply: Union[List[str], str] = ['s_hidden_states', 't_hidden_states'])
Wraps the input to your callback with the specified function.
- Parameters
base_callback (Callback) – Base callback.
lambda_fn (Callable) – Function to apply.
keys_to_apply (Union[List[str], str], optional) – Keys in batch dict to apply function.
Defaults to [“s_hidden_states”, “t_hidden_states”].
- Raises
TypeError – When keys_to_apply is not str or list.
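The key handling described above can be sketched as follows (`apply_to_keys` is a hypothetical helper, not part of the library; it shows the non-mutating "wrapped view" idea that distinguishes wrappers from the in-place preprocessors below):

```python
def apply_to_keys(batch, lambda_fn, keys_to_apply):
    # Hypothetical helper: build a transformed view of the batch so
    # the original runner.batch stays untouched for other callbacks.
    if isinstance(keys_to_apply, str):
        keys_to_apply = [keys_to_apply]
    elif not isinstance(keys_to_apply, list):
        raise TypeError("keys_to_apply should be str or List[str]")
    wrapped = dict(batch)  # shallow copy; untouched keys are shared
    for key in keys_to_apply:
        wrapped[key] = lambda_fn(batch[key])
    return wrapped
```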
Preprocessors
In-place analogs of wrappers: they preprocess runner.batch before other callbacks are applied.
-
class
compressors.distillation.callbacks.LambdaPreprocessCallback(lambda_fn: Callable, keys_to_apply: Union[List[str], str] = ['s_hidden_states', 't_hidden_states'])
Filters output with your lambda function. In-place analog of LambdaWrapperCallback.
- Parameters
lambda_fn (Callable) – Function to apply.
keys_to_apply (Union[List[str], str], optional) – Keys in batch dict to apply function.
Defaults to [“s_hidden_states”, “t_hidden_states”].
- Raises
TypeError – When keys_to_apply is not str or list.
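A sketch of the in-place behavior, with a typical use case (`preprocess_inplace` is a hypothetical helper for illustration; models often return a tuple of per-layer hidden states, and a lambda can reduce it to the last layer before the loss callbacks run):

```python
def preprocess_inplace(batch, lambda_fn,
                       keys_to_apply=("s_hidden_states", "t_hidden_states")):
    # Hypothetical helper: transform the listed keys of runner.batch
    # directly instead of building a wrapped copy.
    keys = [keys_to_apply] if isinstance(keys_to_apply, str) else list(keys_to_apply)
    for key in keys:
        batch[key] = lambda_fn(batch[key])
    return batch

# Keep only the last hidden state of each model.
batch = {"s_hidden_states": ("s0", "s1"), "t_hidden_states": ("t0", "t1")}
preprocess_inplace(batch, lambda hs: hs[-1])
```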
-
class
compressors.distillation.callbacks.HiddenStatesSelectCallback(layers: Union[int, List[int]], hiddens_key: str = 't_hidden_states')
Selects the specified layers from the hidden states stored under hiddens_key.
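The selection step can be sketched as follows (assumed behavior inferred from the signature; useful when distilling into a student with fewer layers, so only matching teacher layers are kept):

```python
def select_hidden_states(hidden_states, layers):
    # Assumed behavior: keep only the given layer indices from the
    # per-layer hidden states stored under hiddens_key; a single int
    # selects one layer.
    if isinstance(layers, int):
        layers = [layers]
    return tuple(hidden_states[i] for i in layers)
```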
Hidden states
-
class
compressors.distillation.callbacks.hidden_states.AttentionHiddenStatesCallback(output_key: str = 'attention_loss', exclude_first_and_last: bool = True, p: int = 2)
Attention-based loss for the difference between hidden states of teacher and student model.
- Parameters
output_key – name for the loss. Defaults to attention_loss.
exclude_first_and_last – whether to exclude the first and last hidden states from the loss. Defaults to True.
p – order of the norm. Defaults to 2.
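A plain-Python sketch in the spirit of attention transfer (Zagoruyko & Komodakis); the exact form inside the callback is an assumption here:

```python
import math

def attention_loss(s_hidden, t_hidden, p=2):
    # Assumed form: square each hidden state, normalize to unit
    # length, then take the p-norm of the difference between the
    # resulting "attention maps".
    def attention_map(h):
        sq = [x * x for x in h]
        norm = math.sqrt(sum(v * v for v in sq)) or 1.0
        return [v / norm for v in sq]

    a_s = attention_map(s_hidden)
    a_t = attention_map(t_hidden)
    return sum(abs(x - y) ** p for x, y in zip(a_s, a_t)) ** (1.0 / p)
```

Squaring before normalizing means the loss compares where activation energy is concentrated, not the raw values.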
-
class
compressors.distillation.callbacks.hidden_states.CosineHiddenStatesCallback(output_key: str = 'cosine_loss', last_only: bool = True, need_mapping: bool = False, teacher_hidden_state_dim: Optional[int] = None, student_hidden_state_dim: Optional[int] = None)
Cosine loss for difference between hidden states of teacher and student model.
- Parameters
output_key – name for the loss. Defaults to cosine_loss.
last_only – compare only the last hidden states. Defaults to True.
need_mapping – add a learned mapping between student and teacher hidden sizes. Defaults to False.
teacher_hidden_state_dim – teacher hidden state dimension; used when need_mapping is True.
student_hidden_state_dim – student hidden state dimension; used when need_mapping is True.
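A minimal sketch of the loss (assumed form: one minus cosine similarity, so perfectly aligned hidden states give 0; the real callback operates on batched tensors and optionally maps dimensions first):

```python
import math

def cosine_loss(s_hidden, t_hidden):
    # 1 - cos(student, teacher): direction-only comparison, invariant
    # to the scale of either hidden state.
    dot = sum(x * y for x, y in zip(s_hidden, t_hidden))
    ns = math.sqrt(sum(x * x for x in s_hidden))
    nt = math.sqrt(sum(y * y for y in t_hidden))
    return 1.0 - dot / (ns * nt)
```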
-
class
compressors.distillation.callbacks.hidden_states.MSEHiddenStatesCallback(output_key: str = 'mse_loss', normalize: bool = False, need_mapping: bool = False, teacher_hidden_state_dim: Optional[int] = None, student_hidden_state_dim: Optional[int] = None, num_layers: Optional[int] = None)
MSE loss (aka hint loss) for the difference between hidden states of teacher and student model.
- Parameters
output_key – name for the loss. Defaults to mse_loss.
normalize – normalize hidden states before computing the loss. Defaults to False.
need_mapping – add a learned mapping between student and teacher hidden sizes. Defaults to False.
teacher_hidden_state_dim – teacher hidden state dimension; used when need_mapping is True.
student_hidden_state_dim – student hidden state dimension; used when need_mapping is True.
num_layers – number of hidden-state layers to use. Optional.
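The hint loss reduces to a mean squared error between matching hidden states (a sketch for equal-sized states; with need_mapping=True the callback would first project the student states to the teacher dimension):

```python
def mse_hidden_loss(s_hidden, t_hidden):
    # Mean squared error between student and teacher hidden states
    # of the same dimension.
    n = len(s_hidden)
    return sum((x - y) ** 2 for x, y in zip(s_hidden, t_hidden)) / n
```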
-
class
compressors.distillation.callbacks.hidden_states.PKTHiddenStatesCallback(output_key: str = 'pkt_loss', last_only: bool = True)
Probabilistic Knowledge Transfer loss for the difference between hidden states
of teacher and student model.
Proposed in https://arxiv.org/abs/1803.10837.
- Parameters
output_key – name for the loss. Defaults to pkt_loss.
last_only – compare only the last hidden states. Defaults to True.
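The idea in the paper can be sketched in plain Python (a simplified sketch, not the callback's implementation: pairwise cosine similarities within a batch are turned into conditional probability distributions, and teacher vs. student distributions are compared with a KL-style divergence):

```python
import math

def pkt_loss(s_batch, t_batch, eps=1e-7):
    # Sketch of Probabilistic Knowledge Transfer (arXiv:1803.10837).
    def cond_probs(batch):
        def cos(a, b):
            dot = sum(x * y for x, y in zip(a, b))
            na = math.sqrt(sum(x * x for x in a)) or eps
            nb = math.sqrt(sum(y * y for y in b)) or eps
            # shift cosine from [-1, 1] into [0, 1]
            return (dot / (na * nb) + 1.0) / 2.0

        sims = [[cos(a, b) for b in batch] for a in batch]
        # normalize each row into a conditional distribution
        return [[s / (sum(row) + eps) for s in row] for row in sims]

    p = cond_probs(t_batch)  # teacher
    q = cond_probs(s_batch)  # student
    return sum(
        pij * math.log((pij + eps) / (qij + eps))
        for p_row, q_row in zip(p, q)
        for pij, qij in zip(p_row, q_row)
    )
```

Because only pairwise similarities are compared, teacher and student hidden states may have different dimensions without any learned mapping.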