Search

Lite-HRNet: A Lightweight High-Resolution Network

리뷰완료날짜
리뷰어
카테고리
인공지능
출판
CVPR 2021
파일
https://arxiv.org/abs/2104.06403
태그
backbone
본 연구는 백본 네트워크에 관한 연구로써 적은 파라미터 수로 좋은 성능을 내는 네트워크에 대해 소개한다. 대표적으로 conditional channel weighting이라는 효율적인 유닛을 소개하며 이는 shuffle block의 pointwise (1x1) convolution을 대체한다. 가령 다중 해상 피쳐인 64x64x40과 32x31x80에 대해 conditional channel weighting unit은 기존의 shuffle block의 전체 계산 복잡도를80%까지 줄일 수 있다.
일반적인 컨볼루션 커널의 가중치의 경우 모델의 파라미터로써 학습을 진행하지만 이와 달리 제안된 schema weight의 경우 lightweight unit을 통해 입력 맵에서 conditioned 되고 채널 전체에서 계산된다. 그러므로 모든 채널맵에 대한 정보를 모두 포함하고 있고 채널 weight를 통해 정보를 교환하는 bridge로써 활용된다. 게다가 병렬 다중 해상 채널 맵으로 부터 가중치가 계산되기 때문에 HRNet이 강건하고 가중치에 많은 정보를 포함할 수 있도록 한다. 본 연구에서는 이 네트워크를 Lite-HRNet이라 이름 붙였다.
요약
HRNet에 shuffle block을 적용하여 경량 네트워크에서 좋은 성능을 보이는 naive Lite-HRNet을 보였다. 이에 대한 증명으로 MobileNet과 ShuffleNet, 그리고 Small HRNet과 비교 실험을 진행하였다.
shuffle block 내에 있는 1x1 convolution을 대체하기 위해 conditional channel weighting unit을 적용하였다.
Lite-HRNet은 COCO와 MPII 데이터에서 복잡도와 정확도 면에서 가장 높은 성능을 보였다.

접근방식

Naive Lite-HRNet

Shuffle blocks

ShuffleNet V2에 포함된 shuffle block은 먼저 채널을 두개의 파티션으로 분리한다. 첫번째 파티션은 1x1 convolution과 3x3 depth-wise convolution, 그리고 1x1 convolution으로 구성되고 마지막 출력에서 다른 파티션과 합쳐지는 방식으로 네트워크가 설계되었다. 그리고 마지막으로 합쳐진 채널이 셔플되었다.
class ShuffleUnit(nn.Module): """InvertedResidual block for ShuffleNetV2 backbone. Args: in_channels (int): The input channels of the block. out_channels (int): The output channels of the block. stride (int): Stride of the 3x3 convolution layer. Default: 1 conv_cfg (dict): Config dict for convolution layer. Default: None, which means using conv2d. norm_cfg (dict): Config dict for normalization layer. Default: dict(type='BN'). act_cfg (dict): Config dict for activation layer. Default: dict(type='ReLU'). with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. Default: False. """ def __init__(self, in_channels, out_channels, stride=1, conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), with_cp=False): super().__init__() self.stride = stride self.with_cp = with_cp branch_features = out_channels // 2 self.branch2 = nn.Sequential( ConvModule( in_channels if (self.stride > 1) else branch_features, branch_features, kernel_size=1, stride=1, padding=0, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg), ConvModule( branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1, groups=branch_features, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=None), ConvModule( branch_features, branch_features, kernel_size=1, stride=1, padding=0, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)) def forward(self, x): x1, x2 = x.chunk(2, dim=1) out = torch.cat((x1, self.branch2(x2)), dim=1) out = channel_shuffle(out, 2) return out
Python
복사

HRNet

HRNet은 첫번째 단계에서 high-resolution의 컨볼루션 stem으로 시작하여 새로운 단계로써 high-resolution의 steam을 하나씩 추가하는 방식으로 설계되어 있다. 각 단계에서 resolution간의 정보가 반복적으로 교환된다. 본 연구에서는 Small HRNet 디자인을 따랐고 본 네트워크에 몇개의 레이어를 활용하였다. Small HRNet의 레이어 구조는 아래와 같다.

Simple combination

본 연구에서는 Small HRNet의 stem 안에 있는 두 개의 3x3 convolution와 모든 residual block을 shuffle block로 대체하였다. 또한 multi-resolution fusion의 일반적인 컨볼루션들은 separable convolution으로 대체하여 naive Lite-HRNet을 구성하였다.

Lite-HRNet

1x1 convolution is costly

1x1 convolution은 각 포지션에 대해 벡터곱을 진행한다. 1x1 convolution은 shuffle operation으로써 채널 사이의 정보를 교환하는 역할을 하며 이에 반헤 depthwise convolution은 채널 사이 정보에 대해 영향을 미치지 않는다.
1x1 convolution의 경우 Θ(C2)\Theta(C^2)의 시간이 걸리며 CC는 채널 수를 의미한다. 이에 반해 depthwise convolution은 Θ(9C)\Theta(9C)의 복잡도를 갖는다. shuffle block의 경우 두 개의 1x1 convolution을 사용하기 때문에 Θ(2C2)>Θ(9C)\Theta(2C^2)>\Theta(9C)로 1x1 convolution이 depthwise convolution보다 훨씬 많은 계산량을 요구하게 된다.

Conditional channel weighting

naive Lite-HRNet에 포함된 1x1 convolution을 대체하기 위해 본 연구에서 제안한 element-wise weighting operation이며 s번째 단계에서는 s개의 브랜치를 가짐.
Ys=WsXs\mathsf{Y}_s = \mathsf{W}_s \odot \mathsf{X}_s
element-wise 곱의 경우 Θ(C)\Theta(C)의 시간 복잡도를 가지며 이는 shuffle block안의 1x1 convolution보다 훨씬 적은 수치이다.

Cross-resolution weight computation

s번째 단계에선 s개의 병렬 resolution이 있고 이에 대응하는 s개의 weight map이 있다. 본 연구에서는 resolution사이의 모든 채널로 부터 s 개의 weight map을 계산하는 함수를 H()\mathcal{H}(\cdot)라 표기하였다.
(X1,X2,...,Xs)Conv.ReLUConv.sigmoid(W1,W2,...,Ws)(\mathsf{X'_1, X'_2,...,X_s}) \rightarrow Conv. \rightarrow ReLU \rightarrow Conv. \rightarrow sigmoid \rightarrow (\mathsf{W'_1, W'_2, ..., W'_s})

Spatial weight computation

Connection