본문으로 바로가기

논문 구현을 하다보면 backbone model로 pre-trained model을 사용하는 경우가 많다. 방대한 양의 데이터를 처음부터 직접 학습시키는데 많은 시간이 소요되기 때문이다. 따라서 기존의 학습된 모델에 논문 저자의 아이디어를 적용시키는 경우가 많다.

모든 코드를 직접 작성해 모델의 architecture를 구성하고 그 위에 weight parameter를 입히는 방법도 있다. 하지만 block module과 같이 새로운 layer를 간단하게 추가하여 실험 환경을 빠르게 구축하는 방법도 있다.

 

import torch
from torch import nn

# 학습된 모델 불러오기
qwer = torch.hub.load('facebookresearch/pytorchvideo', model = 'slow_r50', pretrained = True)

위와 같이 모델을 불러오게 되면 architecture 구성을 볼 수 있다. 각 block, layer 등에 접근하는 방법은 아래와 같다.

새로운 layer를 삽입하기 위해 구글링을 하며 다양한 방법을 사용해 보았는데 성공한 방법은 아래와 같다.

import gcblock
from gcblock import ContextBlock

qwer.blocks[1] = nn.Sequential(qwer.blocks[1], ContextBlock(inplanes = 256, ratio = 4))

선언해준 block을 내가 원하는 block이나 layer에 직접 접근하여 sequential로 묶어주는 방법이다.

기존의 weight parameter 값이 사라졌는지 확인했는데 무사히 남아있었고 정상적으로 작동하긴 한다. 다만, 이 방법이 옳거나 효율적인지는 모르겠지만 자주 사용할 것 같아서 글로 남겨본다.


실패 사례

class로 선언하여 시도하였는데 다 밑으로 몰리는 결과가 나온다...

import gcblock
from gcblock import ContextBlock

pretrained = torch.hub.load('facebookresearch/pytorchvideo', model = 'slow_r50', pretrained = True)

class plz(nn.Module):
    def __init__(self, pretrain_model, ContextBlock):
        super().__init__()
        self.pre_block_1 = pretrain_model.blocks[0]
        self.pre_block_2 = pretrain_model.blocks[1]
        self.pre_block_3 = pretrain_model.blocks[2]
        self.pre_block_4 = pretrain_model.blocks[3]
        self.pre_block_5 = pretrain_model.blocks[4]
        self.pre_block_6 = pretrain_model.blocks[5]
        
        self.GC_block_1 = ContextBlock(inplanes = 256, ratio = 4)
        self.GC_block_2 = ContextBlock(inplanes = 512, ratio = 4)
        self.GC_block_3 = ContextBlock(inplanes = 1024, ratio = 4)
        
    def forward(self, x):
        x = self.pre_block_1(x)
        x = self.pre_block_2(x)
        x = self.GC_block_1(x)
        x = self.pre_block_3(x)
        x = self.GC_block_2(x)
        x = self.pre_block_4(x)
        x = self.GC_block_3(x)
        x = self.pre_block_5(x)
        x = self.pre_block_6(x)
        
        return x
        
asdf = plz(pretrained, ContextBlock)