Examples

Minimal Example

Imports

from itertools import chain

from catalyst.callbacks import AccuracyCallback
from catalyst.contrib.datasets import MNIST
import torch
from torch.utils.data import DataLoader

from compressors.distillation.runners import EndToEndDistilRunner
from compressors.models import MLP
from compressors.utils.data import TorchvisionDatasetWrapper as Wrp

# The teacher is built with one more layer (4) than the student (3).
teacher = MLP(num_layers=4)
student = MLP(num_layers=3)

# MNIST train/validation splits; only the train split requests a download.
datasets = {
    "train": Wrp(MNIST("./data", train=True, download=True)),
    "valid": Wrp(MNIST("./data", train=False)),
}

# One DataLoader per split, keyed like `datasets`; only "train" is shuffled.
loaders = {
    dl_key: DataLoader(dataset, shuffle=dl_key == "train", batch_size=32)
    for dl_key, dataset in datasets.items()
}

# A single optimizer receives the parameters of both the teacher and the student.
optimizer = torch.optim.Adam(chain(teacher.parameters(), student.parameters()))

# NOTE(review): num_train_teacher_epochs=5 is larger than num_epochs=4 below —
# confirm how the runner counts teacher epochs against the total.
runner = EndToEndDistilRunner(hidden_state_loss="mse", num_train_teacher_epochs=5)

runner.train(
    model={"teacher": teacher, "student": student},
    loaders=loaders,
    optimizer=optimizer,
    num_epochs=4,
    callbacks=[AccuracyCallback(input_key="logits", target_key="targets")],
    # Validation accuracy is the tracked metric; higher is better, so do not minimize.
    valid_metric="accuracy01",
    minimize_valid_metric=False,
    logdir="./logs",
)

Minimal Complex Example

First of all imports:

import torch
from torch import nn
from torch.utils.data import DataLoader

from catalyst.contrib.datasets import MNIST
from catalyst.callbacks import AccuracyCallback, CriterionCallback
from catalyst.runners import SupervisedRunner

from compressors.utils.data import TorchvisionDatasetWrapper as Wrp
from compressors.models import BaseDistilModel
from compressors.distillation.callbacks import MSEHiddenStatesCallback, HiddenStatesSelectCallback, KLDivCallback, MetricAggregationCallback
from compressors.distillation.runners import DistilRunner

Now we can create a tiny model class. The main and only difference from an ordinary PyTorch model is that forward should also support the output_hidden_states and return_dict arguments.

If output_hidden_states is set to True, the model should also output a tuple of hidden states.

If return_dict is set to True, the model output should be a dict.

class ExampleModel(BaseDistilModel):
    """Small fully connected classifier used as a distillation teacher or student.

    The stack is ``Linear(28 * 28, hidden_dim)`` first, ``Linear(hidden_dim, 10)``
    last, and ``Linear(hidden_dim, hidden_dim)`` in between. ReLU is applied
    after every layer except the last one.

    Args:
        num_layers: Total number of linear layers. With ``num_layers=1`` the
            only layer is the input projection, so the output has
            ``hidden_dim`` features instead of 10.
        hidden_dim: Width of the hidden layers.
    """

    def __init__(self, num_layers: int = 4, hidden_dim: int = 128):
        # The parent initializer must run before any submodule is assigned,
        # otherwise registering ``self.layers`` fails.
        super().__init__()
        self.num_layers = num_layers

        layers = []
        for layer_idx in range(num_layers):
            if layer_idx == 0:
                layers.append(nn.Linear(28 * 28, hidden_dim))
            elif layer_idx == num_layers - 1:
                layers.append(nn.Linear(hidden_dim, 10))
            else:
                layers.append(nn.Linear(hidden_dim, hidden_dim))

        # ModuleList takes one iterable of modules, not unpacked positional args.
        self.layers = nn.ModuleList(layers)

    def forward(
        self,
        inp,
        output_hidden_states: bool = False,
        return_dict: bool = False,
    ):
        """Run the layer stack and optionally expose per-layer outputs.

        Args:
            inp: Input tensor whose last dimension is ``28 * 28``, as required
                by the first linear layer.
            output_hidden_states: If True, also return a tuple holding the
                output of every linear layer, taken before the ReLU.
            return_dict: If True, return a dict with key ``"logits"`` and,
                when requested, ``"hidden_states"``.

        Returns:
            The logits tensor, a ``(logits, hidden_states)`` tuple, or a dict,
            depending on the two flags.
        """
        cur_hidden = inp
        hiddens = []

        for layer_idx, layer in enumerate(self.layers):
            cur_hidden = layer(cur_hidden)
            if output_hidden_states:  # accumulate pre-activation hidden states
                hiddens.append(cur_hidden)

            if layer_idx != self.num_layers - 1:  # every layer except the last
                cur_hidden = torch.relu(cur_hidden)

        logits = cur_hidden
        if return_dict:
            output = {"logits": logits}
            if output_hidden_states:
                output["hidden_states"] = tuple(hiddens)
            return output

        if output_hidden_states:
            return logits, tuple(hiddens)

        return logits

Now we are all set. Let’s begin by defining our teacher and student models.

# The teacher gets one more layer (4) than the student (3).
teacher = ExampleModel(num_layers=4)
student = ExampleModel(num_layers=3)

Here is data preprocessing:

# MNIST train/validation splits; only the train split requests a download.
datasets = {
    "train": Wrp(MNIST("./data", train=True, download=True)),
    "valid": Wrp(MNIST("./data", train=False)),
}

# One DataLoader per split, keyed like `datasets`; only "train" is shuffled.
loaders = {
    dl_key: DataLoader(dataset, shuffle=dl_key == "train", batch_size=32)
    for dl_key, dataset in datasets.items()
}

Now we are ready to train our teacher. This is just a simple supervised learning pipeline.

# Only the teacher's parameters are optimized in this stage.
optimizer = torch.optim.Adam(teacher.parameters(), lr=1e-2)

runner = SupervisedRunner()

runner.train(
    model=teacher,
    loaders=loaders,
    optimizer=optimizer,
    criterion=nn.CrossEntropyLoss(),
    callbacks=[AccuracyCallback(input_key="logits", target_key="targets")],
    # Validation accuracy is the tracked metric; higher is better, so do not minimize.
    valid_metric="accuracy01",
    minimize_valid_metric=False,
    num_epochs=5,
)

Here begins distillation code.

First of all, let’s define our losses: in addition to CrossEntropyLoss, we will compute MSELoss between the hidden states of the teacher and student models, and KLDivLoss between their output distributions of class probabilities.

But our teacher model has more layers than the student model and therefore has more hidden states. So we will take only the last two hidden states of the teacher model. We can do this with HiddenStatesSelectCallback by setting layers=[2, 3].

# Selects hidden states at indices 2 and 3, i.e. the last two of the 4-layer teacher.
select_last_hidden_states = HiddenStatesSelectCallback(layers=[2, 3])

Now we can simply initialize MSEHiddenStatesCallback for MSE loss and KLDivCallback for KL-divergence loss

# Both callbacks are created with their default arguments.
mse_callback = MSEHiddenStatesCallback()  # MSE loss on hidden states
kl_callback = KLDivCallback()  # KL-divergence loss on output distributions

Here we can initialize our DistilRunner and set output_hidden_states=True as we are using hidden_states in loss

# Hidden states are needed because the MSE loss is computed on them.
runner = DistilRunner(output_hidden_states=True)

We only need to provide the student's parameters to the optimizer.

# Only the student's parameters are handed to the optimizer here.
optimizer = torch.optim.Adam(student.parameters(), lr=1e-2)

Now we can run distillation! We also add MetricAggregationCallback to callbacks, as we need final loss to be sum of the several losses.

\[L_\text{loss} = w_1\cdot L_\text{task loss} + w_2\cdot L_\text{KL} + w_3\cdot L_\text{MSE}\]

We also set the weight of each loss in the sum.

runner.train(
    model={"teacher": teacher, "student": student},
    loaders=loaders,
    optimizer=optimizer,  # was misspelled "optimier", so it was not passed as the optimizer
    criterion=nn.CrossEntropyLoss(),
    callbacks=[
        # Accuracy and the task loss are both computed on the student's logits.
        AccuracyCallback(input_key="s_logits", target_key="targets"),
        CriterionCallback(input_key="s_logits"),
        mse_callback,
        select_last_hidden_states,
        kl_callback,
        # Final loss is the weighted sum of the KL, MSE and task losses.
        MetricAggregationCallback({
            "kl_loss": 0.2,
            "mse_loss": 0.2,
            "loss": 0.6,
        }),
    ],
    # Validation accuracy is the tracked metric; higher is better, so do not minimize.
    valid_metric="accuracy01",
    minimize_valid_metric=False,
    num_epochs=5,
)

NLP

CV