diff --git a/mmrazor/engine/__init__.py b/mmrazor/engine/__init__.py index 7435fa822..e8e3e6a51 100644 --- a/mmrazor/engine/__init__.py +++ b/mmrazor/engine/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .hooks import DumpSubnetHook, EstimateResourcesHook +from .hooks import (DistillationLossDetachHook, DumpSubnetHook, + EstimateResourcesHook) from .optimizers import SeparateOptimWrapperConstructor from .runner import (AutoSlimGreedySearchLoop, DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop, EvolutionSearchLoop, @@ -12,5 +13,5 @@ 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'EstimateResourcesHook', 'SelfDistillValLoop', - 'AutoSlimGreedySearchLoop', 'SubnetValLoop' + 'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'DistillationLossDetachHook' ] diff --git a/mmrazor/engine/hooks/__init__.py b/mmrazor/engine/hooks/__init__.py index d25c7c993..9364a2700 100644 --- a/mmrazor/engine/hooks/__init__.py +++ b/mmrazor/engine/hooks/__init__.py @@ -1,6 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .distillation_loss_detach_hook import DistillationLossDetachHook from .dump_subnet_hook import DumpSubnetHook from .estimate_resources_hook import EstimateResourcesHook from .visualization_hook import RazorVisualizationHook -__all__ = ['DumpSubnetHook', 'EstimateResourcesHook', 'RazorVisualizationHook'] +__all__ = [ + 'DumpSubnetHook', 'EstimateResourcesHook', 'RazorVisualizationHook', + 'DistillationLossDetachHook' +] diff --git a/mmrazor/engine/hooks/distillation_loss_detach_hook.py b/mmrazor/engine/hooks/distillation_loss_detach_hook.py new file mode 100644 index 000000000..26e36b08e --- /dev/null +++ b/mmrazor/engine/hooks/distillation_loss_detach_hook.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.hooks import Hook +from mmengine.model import is_model_wrapper + +from mmrazor.registry import HOOKS + + +@HOOKS.register_module() +class DistillationLossDetachHook(Hook): + + priority = 'LOW' + + def __init__(self, detach_epoch) -> None: + self.detach_epoch = detach_epoch + + def before_train_epoch(self, runner) -> None: + if runner.epoch >= self.detach_epoch: + model = runner.model + # TODO: refactor after mmengine using model wrapper + if is_model_wrapper(model): + model = model.module + assert hasattr(model, 'distill_loss_detach') + + runner.logger.info('Distillation stop now!') + model.distill_loss_detach = True diff --git a/mmrazor/models/algorithms/distill/configurable/fpn_teacher_distill.py b/mmrazor/models/algorithms/distill/configurable/fpn_teacher_distill.py index 9d87d9139..6d433b0dd 100644 --- a/mmrazor/models/algorithms/distill/configurable/fpn_teacher_distill.py +++ b/mmrazor/models/algorithms/distill/configurable/fpn_teacher_distill.py @@ -49,9 +49,11 @@ def loss( batch_inputs, data_samples, mode='loss') losses.update(add_prefix(student_losses, 'student')) - # Automatically compute distill losses based on `loss_forward_mappings` - # The required data already exists in the recorders. - distill_losses = self.distiller.compute_distill_losses() - losses.update(add_prefix(distill_losses, 'distill')) + if not self.distill_loss_detach: + # Automatically compute distill losses based on + # `loss_forward_mappings`. + # The required data already exists in the recorders. + distill_losses = self.distiller.compute_distill_losses() + losses.update(add_prefix(distill_losses, 'distill')) return losses diff --git a/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py b/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py index edcc2a898..50fbe643e 100644 --- a/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py +++ b/mmrazor/models/algorithms/distill/configurable/single_teacher_distill.py @@ -67,9 +67,9 @@ def __init__(self, self.set_module_inplace_false(teacher, 'self.teacher') if teacher_ckpt: - # avoid loaded parameters be overwritten - self.teacher.init_weights() _ = load_checkpoint(self.teacher, teacher_ckpt) + # avoid loaded parameters be overwritten + self.teacher._is_init = True self.teacher_trainable = teacher_trainable if not self.teacher_trainable: for param in self.teacher.parameters(): @@ -89,6 +89,9 @@ def __init__(self, self.distiller.prepare_from_student(self.student) self.distiller.prepare_from_teacher(self.teacher) + # may be modified by distill loss scheduler hook + self.distill_loss_detach = False + @property def student(self) -> nn.Module: """Alias for ``architecture``.""" @@ -135,10 +138,12 @@ def loss( _ = self.student( batch_inputs, data_samples, mode='loss') - # Automatically compute distill losses based on `loss_forward_mappings` - # The required data already exists in the recorders. - distill_losses = self.distiller.compute_distill_losses() - losses.update(add_prefix(distill_losses, 'distill')) + if not self.distill_loss_detach: + # Automatically compute distill losses based on + # `loss_forward_mappings`. + # The required data already exists in the recorders. + distill_losses = self.distiller.compute_distill_losses() + losses.update(add_prefix(distill_losses, 'distill')) return losses diff --git a/mmrazor/models/architectures/connectors/__init__.py b/mmrazor/models/architectures/connectors/__init__.py index dd7f77f01..fd4c91e77 100644 --- a/mmrazor/models/architectures/connectors/__init__.py +++ b/mmrazor/models/architectures/connectors/__init__.py @@ -5,11 +5,13 @@ from .factor_transfer_connectors import Paraphraser, Translator from .fbkd_connector import FBKDStudentConnector, FBKDTeacherConnector from .mgd_connector import MGDConnector +from .norm_connector import NormConnector from .ofd_connector import OFDTeacherConnector from .torch_connector import TorchFunctionalConnector, TorchNNConnector __all__ = [ 'ConvModuleConnector', 'Translator', 'Paraphraser', 'BYOTConnector', 'FBKDTeacherConnector', 'FBKDStudentConnector', 'TorchFunctionalConnector', - 'CRDConnector', 'TorchNNConnector', 'OFDTeacherConnector', 'MGDConnector' + 'CRDConnector', 'TorchNNConnector', 'OFDTeacherConnector', 'MGDConnector', + 'NormConnector' ] diff --git a/mmrazor/models/architectures/connectors/norm_connector.py b/mmrazor/models/architectures/connectors/norm_connector.py new file mode 100644 index 000000000..5d65da7dc --- /dev/null +++ b/mmrazor/models/architectures/connectors/norm_connector.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional + +import torch +from mmcv.cnn import build_norm_layer + +from mmrazor.registry import MODELS +from .base_connector import BaseConnector + + +@MODELS.register_module() +class NormConnector(BaseConnector): + + def __init__(self, in_channels, norm_cfg, init_cfg: Optional[Dict] = None): + super(NormConnector, self).__init__(init_cfg) + _, self.norm = build_norm_layer(norm_cfg, in_channels) + + def forward_train(self, feature: torch.Tensor) -> torch.Tensor: + return self.norm(feature) diff --git a/mmrazor/models/distillers/configurable_distiller.py b/mmrazor/models/distillers/configurable_distiller.py index acfc81d3b..74c99250f 100644 --- a/mmrazor/models/distillers/configurable_distiller.py +++ b/mmrazor/models/distillers/configurable_distiller.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings from inspect import signature -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union +import torch from mmengine.model import BaseModel from torch import nn @@ -139,15 +140,24 @@ def prepare_from_teacher(self, model: nn.Module) -> None: def build_connectors( self, - connectors: Optional[Dict[str, Dict]] = None, + connectors: Optional[Union[Dict[str, List], Dict[str, Dict]]] = None, ) -> nn.ModuleDict: """Initialize connectors.""" distill_connecotrs = nn.ModuleDict() if connectors: for connector_name, connector_cfg in connectors.items(): - connector = MODELS.build(connector_cfg) - distill_connecotrs[connector_name] = connector + if isinstance(connector_cfg, dict): + connector = MODELS.build(connector_cfg) + distill_connecotrs[connector_name] = connector + else: + assert isinstance(connector_cfg, list) + module_list = [] + for cfg in connector_cfg: + connector = MODELS.build(cfg) + module_list.append(connector) + distill_connecotrs[connector_name] = nn.Sequential( + *module_list) return distill_connecotrs @@ -204,12 +214,20 @@ def compute_distill_losses(self) -> LossResults: losses = dict() for loss_name, forward_mappings in self.loss_forward_mappings.items(): forward_kwargs = dict() + is_empty = False for forward_key, record in forward_mappings.items(): - forward_var = self.get_record(**record) + try: + forward_var = self.get_record(**record) + except AssertionError: + is_empty = True + break forward_kwargs[forward_key] = forward_var loss_module = self.distill_losses[loss_name] - loss = loss_module(**forward_kwargs) # type: ignore + if not is_empty: + loss = loss_module(**forward_kwargs) # type: ignore + else: + loss = torch.tensor(0.) # add computed loss result. losses[loss_name] = loss diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py index 8c779d15a..5844afb33 100644 --- a/mmrazor/models/losses/__init__.py +++ b/mmrazor/models/losses/__init__.py @@ -8,6 +8,7 @@ from .decoupled_kd import DKDLoss from .factor_transfer_loss import FTLoss from .fbkd_loss import FBKDLoss +from .fgd_loss import FGDLoss from .kd_soft_ce_loss import KDSoftCELoss from .kl_divergence import KLDivergence from .l1_loss import L1Loss @@ -22,5 +23,6 @@ 'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD', 'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss', 'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss', - 'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss', 'MGDLoss' + 'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss', 'MGDLoss', + 'FGDLoss' ] diff --git a/mmrazor/models/losses/cwd.py b/mmrazor/models/losses/cwd.py index 875a3944b..3c8ca1195 100644 --- a/mmrazor/models/losses/cwd.py +++ b/mmrazor/models/losses/cwd.py @@ -17,11 +17,7 @@ class ChannelWiseDivergence(nn.Module): loss_weight (float): Weight of loss. Defaults to 1.0. """ - def __init__( - self, - tau=1.0, - loss_weight=1.0, - ): + def __init__(self, tau=1.0, loss_weight=1.0): super(ChannelWiseDivergence, self).__init__() self.tau = tau self.loss_weight = loss_weight diff --git a/mmrazor/models/losses/fgd_loss.py b/mmrazor/models/losses/fgd_loss.py new file mode 100644 index 000000000..b1601f95a --- /dev/null +++ b/mmrazor/models/losses/fgd_loss.py @@ -0,0 +1,242 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import constant_init, kaiming_init + +from mmrazor.registry import MODELS + + +@MODELS.register_module() +class FGDLoss(nn.Module): + """PyTorch version of 'Focal and Global Knowledge Distillation for + Detectors'. + + + Args: + in_channels (int): Channels of the input feature map. + temp (float, optional): Temperature coefficient. Defaults to 0.5. + alpha_fgd (float, optional): Weight of fg_loss. Defaults to 0.001. + beta_fgd (float, optional): Weight of bg_loss. Defaults to 0.0005. + gamma_fgd (float, optional): Weight of mask_loss. Defaults to 0.001. + lambda_fgd (float, optional): Weight of relation_loss. + Defaults to 0.000005. + """ + + def __init__( + self, + in_channels, + temp=0.5, + alpha_fgd=0.001, + beta_fgd=0.0005, + gamma_fgd=0.001, + lambda_fgd=0.000005, + ): + super(FGDLoss, self).__init__() + self.temp = temp + self.alpha_fgd = alpha_fgd + self.beta_fgd = beta_fgd + self.gamma_fgd = gamma_fgd + self.lambda_fgd = lambda_fgd + + self.conv_mask_s = nn.Conv2d(in_channels, 1, kernel_size=1) + self.conv_mask_t = nn.Conv2d(in_channels, 1, kernel_size=1) + self.channel_add_conv_s = nn.Sequential( + nn.Conv2d(in_channels, in_channels // 2, kernel_size=1), + nn.LayerNorm([in_channels // 2, 1, 1]), nn.ReLU(inplace=True), + nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)) + self.channel_add_conv_t = nn.Sequential( + nn.Conv2d(in_channels, in_channels // 2, kernel_size=1), + nn.LayerNorm([in_channels // 2, 1, 1]), nn.ReLU(inplace=True), + nn.Conv2d(in_channels // 2, in_channels, kernel_size=1)) + + self.reset_parameters() + + def forward(self, preds_S, preds_T, gt_info, batch_img_metas): + """Forward function. + + Args: + preds_S (Tensor): Bs*C*H*W, student's feature map. + preds_T (Tensor): Bs*C*H*W, teacher's feature map. + gt_info (Tensor): Bs*nt*4, ground truth bounding boxes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + """ + assert preds_S.shape[-2:] == preds_T.shape[-2:] + N, C, H, W = preds_S.shape + gt_bboxes = gt_info[:, :, 1:] # xyxy + + spatial_attention_t, channel_attention_t = self.get_attention( + preds_T, self.temp) + spatial_attention_s, channel_attention_s = self.get_attention( + preds_S, self.temp) + + mask_fg = torch.zeros_like(spatial_attention_t) + mask_bg = torch.ones_like(spatial_attention_t) + wmin, wmax, hmin, hmax = [], [], [], [] + for i in range(N): + new_boxx = torch.ones_like(gt_bboxes[i]) + new_boxx[:, 0] = gt_bboxes[i][:, 0] / batch_img_metas[i][ + 'batch_input_shape'][1] * W + new_boxx[:, 2] = gt_bboxes[i][:, 2] / batch_img_metas[i][ + 'batch_input_shape'][1] * W + new_boxx[:, 1] = gt_bboxes[i][:, 1] / batch_img_metas[i][ + 'batch_input_shape'][0] * H + new_boxx[:, 3] = gt_bboxes[i][:, 3] / batch_img_metas[i][ + 'batch_input_shape'][0] * H + + wmin.append(torch.floor(new_boxx[:, 0]).int()) + wmax.append(torch.ceil(new_boxx[:, 2]).int()) + hmin.append(torch.floor(new_boxx[:, 1]).int()) + hmax.append(torch.ceil(new_boxx[:, 3]).int()) + + height = hmax[i].view(1, -1) + 1 - hmin[i].view(1, -1) + width = wmax[i].view(1, -1) + 1 - wmin[i].view(1, -1) + area = 1.0 / height.float() / width.float() + + for j in range(len(gt_bboxes[i])): + mask_fg[i][hmin[i][j]:hmax[i][j]+1, + wmin[i][j]:wmax[i][j]+1] = \ + torch.max(mask_fg[i][hmin[i][j]:hmax[i][j]+1, + wmin[i][j]:wmax[i][j]+1], area[0][j]) + + mask_bg[i] = torch.where(mask_fg[i] > 0, + torch.zeros_like(mask_bg[i]), + torch.ones_like(mask_bg[i])) + if torch.sum(mask_bg[i]): + mask_bg[i] /= torch.sum(mask_bg[i]) + + fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, mask_fg, + mask_bg, channel_attention_t, + spatial_attention_t) + mask_loss = self.get_mask_loss(channel_attention_s, + channel_attention_t, + spatial_attention_s, + spatial_attention_t) + rela_loss = self.get_rela_loss(preds_S, preds_T) + + loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \ + + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss + + return loss + + def get_attention(self, preds, temp): + """Calculate spatial and channel attention. + + Args: + preds (Tensor): Model prediction with shape (N, C, H, W). + temp (float): Temperature coefficient. + """ + N, C, H, W = preds.shape + + value = torch.abs(preds) + # Bs*W*H + fea_map = value.mean(axis=1, keepdim=True) + spatial_attention = (H * W * F.softmax( + (fea_map / temp).view(N, -1), dim=1)).view(N, H, W) + + # Bs*C + channel_map = value.mean( + axis=2, keepdim=False).mean( + axis=2, keepdim=False) + channel_attention = C * F.softmax(channel_map / temp, dim=1) + + return spatial_attention, channel_attention + + def get_fea_loss(self, preds_S, preds_T, mask_fg, mask_bg, + channel_attention_t, spatial_attention_t): + loss_mse = nn.MSELoss(reduction='sum') + + mask_fg = mask_fg.unsqueeze(dim=1) + mask_bg = mask_bg.unsqueeze(dim=1) + + channel_attention_t = channel_attention_t.unsqueeze(dim=-1).unsqueeze( + dim=-1) + spatial_attention_t = spatial_attention_t.unsqueeze(dim=1) + + fea_t = torch.mul(preds_T, torch.sqrt(spatial_attention_t)) + fea_t = torch.mul(fea_t, torch.sqrt(channel_attention_t)) + fea_t_fg = torch.mul(fea_t, torch.sqrt(mask_fg)) + fea_t_bg = torch.mul(fea_t, torch.sqrt(mask_bg)) + + fea_s = torch.mul(preds_S, torch.sqrt(spatial_attention_t)) + fea_s = torch.mul(fea_s, torch.sqrt(channel_attention_t)) + fea_s_fg = torch.mul(fea_s, torch.sqrt(mask_fg)) + fea_s_bg = torch.mul(fea_s, torch.sqrt(mask_bg)) + + loss_fg = loss_mse(fea_s_fg, fea_t_fg) / len(mask_fg) + loss_bg = loss_mse(fea_s_bg, fea_t_bg) / len(mask_bg) + + return loss_fg, loss_bg + + def get_mask_loss(self, channel_attention_s, channel_attention_t, + spatial_attention_s, spatial_attention_t): + + mask_loss = torch.sum( + torch.abs( + (channel_attention_s - + channel_attention_t))) / len(channel_attention_s) + torch.sum( + torch.abs( + (spatial_attention_s - + spatial_attention_t))) / len(spatial_attention_s) + + return mask_loss + + def spatial_pool(self, x, is_student_input): + batch, channel, width, height = x.size() + input_x = x + # [N, C, H * W] + input_x = input_x.view(batch, channel, height * width) + # [N, 1, C, H * W] + input_x = input_x.unsqueeze(1) + # [N, 1, H, W] + if is_student_input: + context_mask = self.conv_mask_s(x) + else: + context_mask = self.conv_mask_t(x) + # [N, 1, H * W] + context_mask = context_mask.view(batch, 1, height * width) + # [N, 1, H * W] + context_mask = F.softmax(context_mask, dim=2) + # [N, 1, H * W, 1] + context_mask = context_mask.unsqueeze(-1) + # [N, 1, C, 1] + context = torch.matmul(input_x, context_mask) + # [N, C, 1, 1] + context = context.view(batch, channel, 1, 1) + + return context + + def get_rela_loss(self, preds_S, preds_T): + loss_mse = nn.MSELoss(reduction='sum') + + context_s = self.spatial_pool(preds_S, is_student_input=True) + context_t = self.spatial_pool(preds_T, is_student_input=False) + + out_s = preds_S + out_t = preds_T + + channel_add_s = self.channel_add_conv_s(context_s) + out_s = out_s + channel_add_s + + channel_add_t = self.channel_add_conv_t(context_t) + out_t = out_t + channel_add_t + + rela_loss = loss_mse(out_s, out_t) / len(out_s) + + return rela_loss + + def last_zero_init(self, m): + if isinstance(m, nn.Sequential): + constant_init(m[-1], val=0) + else: + constant_init(m, val=0) + + def reset_parameters(self): + kaiming_init(self.conv_mask_s, mode='fan_in') + kaiming_init(self.conv_mask_t, mode='fan_in') + self.conv_mask_s.inited = True + self.conv_mask_t.inited = True + + self.last_zero_init(self.channel_add_conv_s) + self.last_zero_init(self.channel_add_conv_t) diff --git a/mmrazor/models/losses/kl_divergence.py b/mmrazor/models/losses/kl_divergence.py index d79d74c49..ab59a5c97 100644 --- a/mmrazor/models/losses/kl_divergence.py +++ b/mmrazor/models/losses/kl_divergence.py @@ -2,9 +2,26 @@ import torch.nn as nn import torch.nn.functional as F +from mmrazor.models.losses.utils import weighted_loss from mmrazor.registry import MODELS +@weighted_loss +def kl_div(preds_S, preds_T, tau: float = 1.0): + """Calculate the KL divergence between `preds_S` and `preds_T`. + + Args: + preds_S (torch.Tensor): The student model prediction with shape (N, C). + preds_T (torch.Tensor): The teacher model prediction with shape (N, C). + tau (float): Temperature coefficient. + """ + softmax_pred_T = F.softmax(preds_T / tau, dim=1) + logsoftmax_preds_S = F.log_softmax(preds_S / tau, dim=1) + loss = (tau**2) * F.kl_div( + logsoftmax_preds_S, softmax_pred_T, reduction='none') + return loss + + @MODELS.register_module() class KLDivergence(nn.Module): """A measure of how one probability distribution Q is different from a @@ -45,7 +62,12 @@ def __init__( f'but gets {reduction}.' self.reduction = reduction - def forward(self, preds_S, preds_T): + def forward(self, + preds_S, + preds_T, + weight=None, + avg_factor=None, + reduction_override=None): """Forward computation. Args: @@ -53,14 +75,27 @@ def forward(self, preds_S, preds_T): shape (N, C, H, W) or shape (N, C). preds_T (torch.Tensor): The teacher model prediction with shape (N, C, H, W) or shape (N, C). + weight (torch.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Options are "none", "mean", "sum" and "batchmean". Return: torch.Tensor: The calculated loss value. """ + assert reduction_override in (None, 'none', 'mean', 'sum', 'batchmean') + reduction = ( + reduction_override if reduction_override else self.reduction) if self.teacher_detach: preds_T = preds_T.detach() - softmax_pred_T = F.softmax(preds_T / self.tau, dim=1) - logsoftmax_preds_S = F.log_softmax(preds_S / self.tau, dim=1) - loss = (self.tau**2) * F.kl_div( - logsoftmax_preds_S, softmax_pred_T, reduction=self.reduction) + loss = kl_div( + preds_S, + preds_T, + tau=self.tau, + weight=weight, + reduction=reduction, + avg_factor=avg_factor) return self.loss_weight * loss diff --git a/mmrazor/models/losses/utils.py b/mmrazor/models/losses/utils.py new file mode 100644 index 000000000..72c336ed2 --- /dev/null +++ b/mmrazor/models/losses/utils.py @@ -0,0 +1,129 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools +from typing import Callable, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor + + +def reduce_loss(loss: Tensor, reduction: str) -> Tensor: + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are "none", "mean" and "sum". + + Return: + Tensor: Reduced loss tensor. + """ + # special case for batchmean in kl_div + if reduction == 'batchmean': + return loss.sum() / loss.size()[0] + + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + elif reduction_enum == 2: + return loss.sum() + + +def weight_reduce_loss(loss: Tensor, + weight: Optional[Tensor] = None, + reduction: str = 'mean', + avg_factor: Optional[float] = None) -> Tensor: + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Optional[Tensor], optional): Element-wise weights. + Defaults to None. + reduction (str, optional): Same as built-in losses of PyTorch. + Defaults to 'mean'. + avg_factor (Optional[float], optional): Average factor when + computing the mean of losses. Defaults to None. + + Returns: + Tensor: Processed loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + # Avoid causing ZeroDivisionError when avg_factor is 0.0, + # i.e., all labels of an image belong to ignore index. + eps = torch.finfo(torch.float32).eps + loss = loss.sum() / (avg_factor + eps) + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss(loss_func: Callable) -> Callable: + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, avg_factor=2) + tensor(1.5000) + """ + + @functools.wraps(loss_func) + def wrapper(pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + reduction: str = 'mean', + avg_factor: Optional[int] = None, + **kwargs) -> Tensor: + """ + Args: + pred (Tensor): The prediction. + target (Tensor): Target bboxes. + weight (Optional[Tensor], optional): The weight of loss for each + prediction. Defaults to None. + reduction (str, optional): Options are "none", "mean" and "sum". + Defaults to 'mean'. + avg_factor (Optional[int], optional): Average factor that is used + to average the loss. Defaults to None. + + Returns: + Tensor: Loss tensor. + """ + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + return wrapper diff --git a/mmrazor/models/task_modules/delivery/function_outputs_delivery.py b/mmrazor/models/task_modules/delivery/function_outputs_delivery.py index 15c361e38..e832331c2 100644 --- a/mmrazor/models/task_modules/delivery/function_outputs_delivery.py +++ b/mmrazor/models/task_modules/delivery/function_outputs_delivery.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import functools from types import FunctionType from typing import Callable @@ -152,7 +153,7 @@ def wrap_func(*args, **kwargs): assert len(self.data_queue) < self.data_queue.maxlen,\ 'push into an full queue' outputs = origin_func(*args, **kwargs) - self.data_queue.append(outputs) + self.data_queue.append(copy.deepcopy(outputs)) return outputs return wrap_func diff --git a/mmrazor/models/task_modules/delivery/method_outputs_delivery.py b/mmrazor/models/task_modules/delivery/method_outputs_delivery.py index fa9f6c4a4..f18c6e8dd 100644 --- a/mmrazor/models/task_modules/delivery/method_outputs_delivery.py +++ b/mmrazor/models/task_modules/delivery/method_outputs_delivery.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import functools from types import FunctionType, ModuleType from typing import Callable @@ -149,7 +150,7 @@ def wrap_method(*args, **kwargs): assert len(self.data_queue) < self.data_queue.maxlen,\ 'push into an full queue' outputs = origin_method(*args, **kwargs) - self.data_queue.append(outputs) + self.data_queue.append(copy.deepcopy(outputs)) return outputs return wrap_method diff --git a/tests/test_engine/test_hooks/test_distillation_loss_detach_hook.py b/tests/test_engine/test_hooks/test_distillation_loss_detach_hook.py new file mode 100644 index 000000000..eefdcc7cb --- /dev/null +++ b/tests/test_engine/test_hooks/test_distillation_loss_detach_hook.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase +from unittest.mock import Mock + +from mmrazor.engine import DistillationLossDetachHook + + +class TestDistillationLossDetachHook(TestCase): + + def setUp(self): + self.hook = DistillationLossDetachHook(detach_epoch=5) + runner = Mock() + runner.model = Mock() + runner.model.distill_loss_detach = False + + runner.epoch = 0 + # runner.max_epochs = 10 + self.runner = runner + + def test_before_train_epoch(self): + max_epochs = 10 + target = [False] * 5 + [True] * 5 + for epoch in range(max_epochs): + self.hook.before_train_epoch(self.runner) + self.assertEquals(self.runner.model.distill_loss_detach, + target[epoch]) + self.runner.epoch += 1 diff --git a/tests/test_models/test_architectures/test_connectors/test_connectors.py b/tests/test_models/test_architectures/test_connectors/test_connectors.py index 1f7137022..a2a2dcadc 100644 --- a/tests/test_models/test_architectures/test_connectors/test_connectors.py +++ b/tests/test_models/test_architectures/test_connectors/test_connectors.py @@ -5,7 +5,7 @@ from mmrazor.models import (BYOTConnector, ConvModuleConnector, CRDConnector, FBKDStudentConnector, FBKDTeacherConnector, - MGDConnector, Paraphraser, + MGDConnector, NormConnector, Paraphraser, TorchFunctionalConnector, TorchNNConnector, Translator) @@ -143,3 +143,11 @@ def test_mgd_connector(self): assert s_output1.shape == torch.Size([1, 16, 8, 8]) assert s_output2.shape == torch.Size([1, 32, 8, 8]) + + def test_norm_connector(self): + s_feat = torch.randn(2, 3, 2, 2) + norm_cfg = dict(type='BN', affine=False, track_running_stats=False) + norm_connector = NormConnector(3, norm_cfg) + output = norm_connector.forward_train(s_feat) + + assert output.shape == torch.Size([2, 3, 2, 2]) diff --git a/tests/test_models/test_distillers/test_configurable_distill.py b/tests/test_models/test_distillers/test_configurable_distill.py index ae454389e..62e0292d5 100644 --- a/tests/test_models/test_distillers/test_configurable_distill.py +++ b/tests/test_models/test_distillers/test_configurable_distill.py @@ -2,13 +2,31 @@ import copy from unittest import TestCase +import torch +import torch.nn as nn from mmengine import ConfigDict from mmrazor.models import ConfigurableDistiller +from mmrazor.registry import MODELS + + +class ToyDistillLoss(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, arg1, arg2): + return arg1 + arg2 class TestConfigurableDistiller(TestCase): + def setUp(self): + MODELS.register_module(module=ToyDistillLoss, force=True) + + def tearDown(self): + MODELS.module_dict.pop('ToyDistillLoss') + def test_init(self): recorders_cfg = ConfigDict( @@ -65,3 +83,33 @@ def test_init(self): with self.assertRaisesRegex(TypeError, 'from_student should be a bool'): _ = ConfigurableDistiller(**distiller_kwargs_) + + def test_connector_list(self): + recorders_cfg = ConfigDict( + conv=dict(type='ModuleOutputs', source='conv')) + norm_cfg = dict(type='BN', affine=False, track_running_stats=False) + + distiller_kwargs = ConfigDict( + student_recorders=recorders_cfg, + teacher_recorders=recorders_cfg, + distill_losses=dict(loss_toy=dict(type='ToyDistillLoss')), + loss_forward_mappings=dict( + loss_toy=dict( + arg1=dict( + from_student=True, + recorder='conv', + connector='loss_1_sfeat'), + arg2=dict(from_student=False, recorder='conv'), + )), + connectors=dict(loss_1_sfeat=[ + dict( + type='ConvModuleConnector', + in_channel=3, + out_channel=4, + act_cfg=None), + dict(type='NormConnector', norm_cfg=norm_cfg, in_channels=4) + ])) + + distiller = ConfigurableDistiller(**distiller_kwargs) + connectors = distiller.connectors + self.assertIsInstance(connectors['loss_1_sfeat'], nn.Sequential) diff --git a/tests/test_models/test_losses/test_distillation_losses.py b/tests/test_models/test_losses/test_distillation_losses.py index 77233b81f..a15b78c8e 100644 --- a/tests/test_models/test_losses/test_distillation_losses.py +++ b/tests/test_models/test_losses/test_distillation_losses.py @@ -1,16 +1,89 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Sequence, Union from unittest import TestCase import torch +from mmdet.models.utils import unpack_gt_instances +from mmdet.testing import demo_mm_inputs from mmengine.structures import BaseDataElement +from torch import Tensor from mmrazor import digit_version from mmrazor.models import (ABLoss, ActivationLoss, ATLoss, CRDLoss, DKDLoss, - FBKDLoss, FTLoss, InformationEntropyLoss, + FBKDLoss, FGDLoss, FTLoss, InformationEntropyLoss, KDSoftCELoss, MGDLoss, OFDLoss, OnehotLikeLoss, PKDLoss) +# copied from mmyolo +def gt_instances_preprocess(batch_gt_instances: Union[Tensor, Sequence], + batch_size: int) -> Tensor: + """Split batch_gt_instances with batch size, from [all_gt_bboxes, 6] to. + + [batch_size, number_gt, 5]. If some shape of single batch smaller than + gt bbox len, then using [-1., 0., 0., 0., 0.] to fill. + + Args: + batch_gt_instances (Sequence[Tensor]): Ground truth + instances for whole batch, shape [all_gt_bboxes, 6] + batch_size (int): Batch size. + + Returns: + Tensor: batch gt instances data, shape [batch_size, number_gt, 5] + """ + if isinstance(batch_gt_instances, Sequence): + max_gt_bbox_len = max( + [len(gt_instances) for gt_instances in batch_gt_instances]) + # fill [0., 0., 0., 0., 0.] if some shape of + # single batch not equal max_gt_bbox_len + batch_instance_list = [] + for index, gt_instance in enumerate(batch_gt_instances): + bboxes = gt_instance.bboxes + labels = gt_instance.labels + batch_instance_list.append( + torch.cat((labels[:, None], bboxes), dim=-1)) + + if bboxes.shape[0] >= max_gt_bbox_len: + continue + + fill_tensor = bboxes.new_full( + [max_gt_bbox_len - bboxes.shape[0], 5], 0) + batch_instance_list[index] = torch.cat( + (batch_instance_list[index], fill_tensor), dim=0) + + return torch.stack(batch_instance_list) + else: + # faster version + # format of batch_gt_instances: + # [img_ind, cls_ind, x1, y1, x2, y2] + + # sqlit batch gt instance [all_gt_bboxes, 6] -> + # [batch_size, max_gt_bbox_len, 5] + assert isinstance(batch_gt_instances, Tensor) + if len(batch_gt_instances) > 0: + gt_images_indexes = batch_gt_instances[:, 0] + max_gt_bbox_len = gt_images_indexes.unique( + return_counts=True)[1].max() + # fill [0., 0., 0., 0., 0.] if some shape of + # single batch not equal max_gt_bbox_len + batch_instance = torch.zeros((batch_size, max_gt_bbox_len, 5), + dtype=batch_gt_instances.dtype, + device=batch_gt_instances.device) + + for i in range(batch_size): + match_indexes = gt_images_indexes == i + gt_num = match_indexes.sum() + if gt_num: + batch_instance[i, :gt_num] = batch_gt_instances[ + match_indexes, 1:] + else: + batch_instance = torch.zeros((batch_size, 0, 5), + dtype=batch_gt_instances.dtype, + device=batch_gt_instances.device) + + return batch_instance + + class TestLosses(TestCase): @classmethod @@ -211,3 +284,18 @@ def test_mgd_loss(self): feats_S, feats_T = torch.rand(2, 256, 4, 4), torch.rand(2, 256, 4, 4) loss = mgd_loss(feats_S, feats_T) self.assertTrue(loss.numel() == 1) + + def test_fgd_loss(self): + fgd_loss = FGDLoss(in_channels=3) + packed_inputs = demo_mm_inputs(2, [[3, 320, 128], [3, 125, 320]]) + batch_gt_instances, batch_gt_instances_ignore, batch_img_metas =\ + unpack_gt_instances(packed_inputs['data_samples']) + gt_info = gt_instances_preprocess(batch_gt_instances, 2) + for meta in batch_img_metas: + meta.update({'batch_input_shape': meta['img_shape']}) + + preds_S = torch.rand(2, 3, 80, 32) + preds_T = torch.rand(2, 3, 80, 32) + + loss = fgd_loss(preds_S, preds_T, gt_info, batch_img_metas) + self.assertTrue(loss.numel() == 1)