mmdetection阅读笔记:OptimizerConstructor

2024-03-04 13:05 栏目: 技术学堂 查看()

其余内容见:

mmdetection源码阅读笔记:概览

optimizer构造起来就相对比较复杂了,来看一下config文件中optimizer的配置optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001),mmdetecion还是用pytorch的Optimizer类作为优化器,所以我们要用注册器机制将pytorch中的SGD、Adam等都注册了。下面我们先回忆一下pytorch.Optimizer的构造参数,比如:{'params':[Tensor], 'lr': 1, 'momentum': 0, 'weight_decay': 0.1} ,指定了需要优化的tensor以及它对应的学习率、动量等。

最简单的情况下,构造参数为{'params':model.parameters(),'lr':1} ,即对整个model中的所有参数采用同样的学习率。可是往往我们需要对model中的一些组件进行特殊处理,比如特殊指定backbone的学习率、batchnorm层的weight_decay等。这些都在mmcv中的DefaultOptimizerConstructor实现

所以关键弄懂如何注册pytorch中的SGD等优化器、DefaultOptimizerConstructor类如何构造optimizer

# mmcv/runner/optimier/builder.py
import inspect
import torch
from mmcv import Registry, build_from_cfg

OPTIMIZERS=Registry('optimizer') # 定义一个注册器类,用来注册pytorch中的优化器

def register_torch_optimizers():
    for module_name in dir(torch.optim): # 遍历torch.optim中的类
        if module_name.startswith('__'): # '__'开头,如'__name__'、'__path__'等,表示特殊类跳过
            continue
        _optim=getattr(torch.optim, module_name) # torch.optim本质是模块,python万物皆对象,它也可以用attr属性
        if inspect.isclass(_optim) and issubclass(_optim,torch.optim.Optimizer): # 判断是否是优化器
            OPTIMIZERS.register_module(module=_optim) # 这才是注册
            
register_torch_optimizers() # 导入builder.py时,就会执行

一顿操作后,成功注册了pytorch中的优化器SGD等。可以通过dict=(type='SGD')的方式来builder optimer了。

DefaultOptimizerConstructor实现的很巧妙,官方给出了两个例子:

Example 1:
        >>> model=torch.nn.modules.Conv1d(1, 1, 1)
            # optimizer_cfg就是设置优化器类型,lr、momentum的默认配置
        >>> optimizer_cfg=dict(type='SGD', lr=0.01, momentum=0.9,weight_decay=0.0001)
            # paramwise_cfg是关键,可以精确实现对每个param的配置,如果不指定那么默认是optimizer_cfg
        >>> paramwise_cfg=dict(norm_decay_mult=0.)
        >>> optim_builder=DefaultOptimizerConstructor(optimizer_cfg, paramwise_cfg)
        >>> optimizer=optim_builder(model)

 Example 2:
        >>> # assume model have attribute model.backbone and model.cls_head
        >>> optimizer_cfg=dict(type='SGD', lr=0.01, weight_decay=0.95)
            # 如果params的name中匹配到了custom_keys中的key'.backbone',那么该params的lr为0.01*0.1
            # 同理weight_decay
            # 值得注意的是,如果params的name与多个custom_keys中的key匹配,将采用最长子字符串,如果长度还一样,按字母排序
        >>> paramwise_cfg=dict(custom_keys={'.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
        >>> optim_builder=DefaultOptimizerConstructor(
        >>>     optimizer_cfg, paramwise_cfg)
        >>> optimizer=optim_builder(model)
        >>> # Then the `lr` and `weight_decay` for model.backbone is
        >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
        >>> # model.cls_head is (0.01, 0.95).

有两个关键方法:

  • add_params方法,就是获得pytorch.Optimizer的构造参数。特别的,将由paramwise_cfg确定params对应的学习率、动量等
  • 重载了__call__方法,调用self.__call__(model)将构造并返回一个pytorch.Optimizer优化器,该优化器的构造参数由add_params方法确定,该优化器负责优化model的参数。
from .builder import OPTIMIZER_BUILDERS, OPTIMIZERS

@OPTIMIZER_BUILDERS.register_module()
class DefaultOptimizerConstructor:
"""
主要有两个参数optimizer_cfg和paramwise_cfg:
    optimizer_cfg确定优化器type、默认的lr、momentum等。其中属性base_lr、base_wd就是optimizer_cfg中的lr与weight_decay。
    paramwise_cfg确定个别模块的lr、momentum等。

以optimizer_cfg=dict(type='SGD', lr=0.01, momentum=0.9,weight_decay=0.0001)为例
如果不指定paramwise_cfg,那么调用self.__call__(model),很简单就是返回
SGD({'params':model.parameters,lr='0.01',momentum='0.9',weight_decay='0.0001'})

特别情况,需要用paramwise_cfg来指定个别模块的lr、momentum、weight_decay。比如DCN、depthwise conv、batchnorm。
所以关键的介绍paramwise_cfg,它是一个dict,包括以下key-value
   - 'custom_keys' (dict): 它的key值是字符串类型,如果custom_keys中的一个key值是一个params的name的子字符串,
      那么该params的lr将由custom_keys[key]['lr_mult']与base_lr相乘来计算,同理weight_decay。
      值得注意的是,如果params的name与多个custom_keys中的key匹配,将采用最长子字符串,如果长度还一样,按字母排序。
      此外,它的value值还是dict字典,可能包括lr_mult和decay_mult字段,同下。
    - 'bias_lr_mult'(float): 所有的bias参数(如conv.bias)的lr等于base_lr*bias_lr_mult。
      注意,norm的bias参数、DCN的offset层!不由bias_lr_mult指定。
    - 'bias_decay_mult' (float): 同上,所有的bias参数(如conv.bias)的weight_decay等于base_wd*bias_decay_mult。 
      注意,norm的bias参数、DCN的offset层、depthwise conv的bias参数!不由bias_lr_mult指定。
    - 'norm_decay_mult' (float): 确定norm的weight和bias参数的weight_decay。
    - 'dwconv_decay_mult' (float): 确定depthwise conv的weight和bias参数的weight_decay。
    - 'dcn_offset_lr_mult'(float): 确定DCN的offset层的学习率。
    - 'bypass_duplicate' (bool): 如果为True,重复的params不会被添加在optimizer

    Note:
        1.  'dcn_offset_lr_mult'会重载'bias_lr_mult'
        2.  custom_keys有最高优先级,会覆盖其他参数
    """

    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        if not isinstance(optimizer_cfg, dict):
            raise TypeError('optimizer_cfg should be a dict',
                            f'but got{type(optimizer_cfg)}')
        self.optimizer_cfg = optimizer_cfg
        self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
        self.base_lr = optimizer_cfg.get('lr', None)
        self.base_wd = optimizer_cfg.get('weight_decay', None)
        …………

    def _is_in(self, param_group, param_group_list):
        assert is_list_of(param_group_list, dict)
        param = set(param_group['params'])
        param_set = set()
        for group in param_group_list:
            param_set.update(set(group['params']))

        return not param.isdisjoint(param_set)

    def add_params(self, params, module, prefix='', is_dcn_module=None):
        """ 
        根据paramwise_cfg,将moduel中的参数放入params中
        参数:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module
            is_dcn_module (int|float|None):当前的module是否是DCN的子module
        """
        # get param-wise options
        custom_keys = self.paramwise_cfg.get('custom_keys', {})
        # first sort with alphabet order and then sort with reversed len of str
        sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)

        bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
        dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', 1.)
        bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
        dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', 1.)

        # special rules for norm layers and depth-wise conv layers
        is_norm = isinstance(module,(_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
        is_dwconv = (isinstance(module, torch.nn.Conv2d)
            and module.in_channels == module.groups)

        for name, param in module.named_parameters(recurse=False): # recurse为Fasle,不再递归遍历子module
            param_group = {'params': [param]}
            if not param.requires_grad:
                params.append(param_group)
                continue
            if bypass_duplicate and self._is_in(param_group, params):
                warnings.warn(f'{prefix} is duplicate. It is skipped since '
                              f'bypass_duplicate={bypass_duplicate}')
                continue
            # if the parameter match one of the custom keys, ignore other rules
            is_custom = False
            for key in sorted_keys:
                if key in f'{prefix}.{name}': # 如果key是name的子字符串!注意sorted_keys是按长度,再按字母排序
                    is_custom = True
                    lr_mult = custom_keys[key].get('lr_mult', 1.)
                    param_group['lr'] = self.base_lr * lr_mult
                    if self.base_wd is not None:
                        decay_mult = custom_keys[key].get('decay_mult', 1.)
                        param_group['weight_decay'] = self.base_wd * decay_mult
                    break # 找到一个就break

            if not is_custom:
                # bias_lr_mult affects all bias parameters
                # except for norm.bias dcn.conv_offset.bias
                if name == 'bias' and not (is_norm or is_dcn_module):
                    param_group['lr'] = self.base_lr * bias_lr_mult

                if (prefix.find('conv_offset') != -1 and is_dcn_module
                        and isinstance(module, torch.nn.Conv2d)):
                    # deal with both dcn_offset's bias & weight
                    param_group['lr'] = self.base_lr * dcn_offset_lr_mult

                # apply weight decay policies
                if self.base_wd is not None:
                    # norm decay
                    if is_norm:
                        param_group['weight_decay'] = self.base_wd * norm_decay_mult
                    # depth-wise conv
                    elif is_dwconv:
                        param_group['weight_decay'] = self.base_wd * dwconv_decay_mult
                    # bias lr and decay
                    elif name == 'bias' and not is_dcn_module:
                        # TODO: current bias_decay_mult will have affect on DCN
                        param_group['weight_decay'] = self.base_wd * bias_decay_mult
            params.append(param_group)

        if check_ops_exist():
            from mmcv.ops import DeformConv2d, ModulatedDeformConv2d
            is_dcn_module = isinstance(module,(DeformConv2d, ModulatedDeformConv2d))
        else:
            is_dcn_module = False
        for child_name, child_mod in module.named_children():
            child_prefix = f'{prefix}.{child_name}' if prefix else child_name
            self.add_params(
                params,
                child_mod,
                prefix=child_prefix,
                is_dcn_module=is_dcn_module)

    def __call__(self, model):
        if hasattr(model, 'module'): # 如果有module属性,说明是被DataParallel封装后的,需要取出module
            model = model.module

        optimizer_cfg = self.optimizer_cfg.copy()
        # if no paramwise option is specified, just use the global setting
        if not self.paramwise_cfg: 
            optimizer_cfg['params'] = model.parameters()
            return build_from_cfg(optimizer_cfg, OPTIMIZERS)

        # set param-wise lr and weight decay recursively
        params = []
        self.add_params(params, model) # 获得优化器的构造参数,确定model参数的lr、momentum等,保存在params中
        optimizer_cfg['params'] = params

        # OPTIMIZERS就是定义在mmcv/runner/optimier/builder.py中,注册了pytorch中的优化器
        return build_from_cfg(optimizer_cfg, OPTIMIZERS) 

最后我们只需要实例化optimizer_constructor=DefaultOptimizerConstructor(……) ,调用self.__call__方法就可获得优化器optimizer=optimizer_constructor(model)

最后,OptimizerConstructor也可以注册,这样方便拓展,完全可以自定义构造优化器的方法。

扫二维码与项目经理沟通

我们在微信上24小时期待你的声音

解答本文疑问/技术咨询/运营咨询/技术建议/互联网交流

郑重申明:某某网络以外的任何单位或个人,不得使用该案例作为工作成功展示!

平台注册入口