Data

Utils for preparing data.

class compressors.distillation.data.LogitsDataset(dataset: torch.utils.data.dataset.Dataset, model: torch.nn.modules.module.Module, batched: bool = True, get_logits_fn: Optional[Callable] = None, merge_logits_with_batch_fn: Optional[Callable] = None, **data_loader_kwargs)

Dataset wrapper that augments items of a base dataset with logits from a trained model (see the usage sketch after the parameter list).

Parameters
  • dataset – base dataset

  • model – trained model

  • batched – if True, logits are computed batch-wise using a DataLoader.

  • get_logits_fn – function that extracts logits from the model for a given batch.

  • merge_logits_with_batch_fn – function that merges the computed logits with the corresponding data from the dataset.

  • **data_loader_kwargs – kwargs for the DataLoader, for example {"batch_size": 32}.
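
A minimal usage sketch (the toy dataset and teacher model below are illustrative, and it assumes the default get_logits_fn / merge_logits_with_batch_fn can handle tuple batches from the base dataset):

    import torch
    from torch.utils.data import TensorDataset

    from compressors.distillation.data import LogitsDataset

    # Toy base dataset and an already "trained" teacher model.
    features = torch.randn(100, 16)
    labels = torch.randint(0, 4, (100,))
    base_dataset = TensorDataset(features, labels)
    teacher = torch.nn.Linear(16, 4)

    # Wrap the dataset so each item is augmented with teacher logits,
    # computed batch-wise; batch_size is forwarded to the DataLoader
    # through **data_loader_kwargs.
    logits_dataset = LogitsDataset(
        dataset=base_dataset,
        model=teacher,
        batched=True,
        batch_size=32,
    )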

compressors.distillation.data.label_smoothing.probability_shift(logits: torch.Tensor, labels: torch.Tensor) → torch.Tensor

From "Preparing Lessons: Improve Knowledge Distillation with Better Supervision" (https://arxiv.org/abs/1911.07471). Swaps the value at the argmax position with the value at the correct-label position in the logits, so that the correct class always carries the highest logit.

Parameters
  • logits – logits from teacher model

  • labels – correct labels

Returns

shifted logits, used as smoothed labels for the student
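
The operation can be sketched as a batched swap (a hypothetical re-implementation for illustration, not necessarily the library's exact code):

    import torch

    def probability_shift_sketch(
        logits: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        # For each row, swap the largest logit with the logit at the
        # ground-truth label, so the correct class has the highest value.
        shifted = logits.clone()
        rows = torch.arange(logits.size(0))
        argmax = logits.argmax(dim=-1)
        # Both assignments read from the untouched original tensor, so the
        # swap is correct even when argmax already equals the label.
        shifted[rows, labels] = logits[rows, argmax]
        shifted[rows, argmax] = logits[rows, labels]
        return shifted

    teacher_logits = torch.tensor([[2.0, 5.0, 1.0]])
    correct = torch.tensor([0])
    print(probability_shift_sketch(teacher_logits, correct))
    # tensor([[5., 2., 1.]])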