
Some thoughts on super() in PyTorch, illustrated with code

2022-04-23 20:47:00 NuerNuer

import torch.nn as nn

class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()  # call the parent class's (nn.Module's) __init__
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x
# This class must define a forward method; otherwise calling it from inside
# TransformerEncoderBlock raises an error.
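The behavior of ResidualAdd can be sketched without torch at all. In this minimal pure-Python analogue (the class name `Residual` and the use of `__call__` in place of `forward` are illustrative, not PyTorch API), the wrapper stores a callable and adds the original input back onto its output:

```python
class Residual:
    """Toy stand-in for ResidualAdd: wraps a callable and adds the input back."""
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        # same shape as ResidualAdd.forward: res = x; x = fn(x); x += res
        return self.fn(x) + x


double_plus_input = Residual(lambda v: v * 2)
print(double_plus_input(3))  # 3*2 + 3 = 9
```

The skip connection is just "apply fn, then add the untouched input"; everything else in ResidualAdd is nn.Module bookkeeping.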

class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )

# Here the parent class of FeedForwardBlock is nn.Sequential, and super().__init__()
# explicitly calls the parent class's initializer, so FeedForwardBlock inherits the
# parent's attributes and methods. That is why, even though FeedForwardBlock defines
# no forward method of its own, the parent's forward (nn.Sequential's) is invoked
# when the block is used.
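The same inheritance trick can be shown with plain Python. In this sketch (`ToySequential` and `ToyFeedForward` are made-up stand-ins, not PyTorch classes), the subclass defines no `__call__` of its own, yet instances are still callable because `super().__init__()` hands the layers to the parent:

```python
class ToySequential:
    """Toy stand-in for nn.Sequential: stores callables, applies them in order."""
    def __init__(self, *layers):
        self.layers = layers

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class ToyFeedForward(ToySequential):
    # No __call__ of its own: by handing its layers to super().__init__(),
    # it simply reuses the parent's __call__ (the analogue of inheriting
    # nn.Sequential's forward).
    def __init__(self, scale):
        super().__init__(lambda v: v * scale, lambda v: v + 1)


ff = ToyFeedForward(3)
print(ff(2))  # (2 * 3) + 1 = 7
```

This mirrors FeedForwardBlock exactly: the subclass's only job is to decide which children to build; running them is entirely the parent's code.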

class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 ** kwargs):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )),
        )

# This part is a little more involved, mainly because the nn.Sequential inside
# ResidualAdd is easy to confuse with the parent class of TransformerEncoderBlock;
# the two are unrelated. Here a ResidualAdd object is constructed first; then,
# because TransformerEncoderBlock also explicitly calls its parent's initializer,
# it inherits the parent's attributes and methods, and it is nn.Sequential's
# forward that runs when the block is called.
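To make the nesting concrete, here is the same structure in pure Python (all class names are illustrative stand-ins, not PyTorch API). Note that the inner `ToySequential` is merely an argument passed to `ToyResidual`; it has nothing to do with the parent class of the block:

```python
class ToySequential:
    """Stores callables and applies them in order, like nn.Sequential."""
    def __init__(self, *layers):
        self.layers = layers

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class ToyResidual:
    """Wraps a callable and adds the input back, like ResidualAdd."""
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        return self.fn(x) + x


class ToyEncoderBlock(ToySequential):
    # Mirrors TransformerEncoderBlock: the children handed to super().__init__()
    # are ToyResidual objects, each wrapping an inner ToySequential.
    def __init__(self):
        super().__init__(
            ToyResidual(ToySequential(lambda v: v * 2)),
            ToyResidual(ToySequential(lambda v: v + 10)),
        )


blk = ToyEncoderBlock()
# first residual: 5*2 + 5 = 15; second: (15 + 10) + 15 = 40
print(blk(5))  # 40
```

The parent's `__call__` iterates over the two ToyResidual children in order; each child internally runs its own inner ToySequential, which is exactly how the attention and feed-forward sub-blocks chain inside TransformerEncoderBlock.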


# For reference, the core of the parent class nn.Sequential as defined in the
# PyTorch source:

class Sequential(Module):

    def __init__(self, *args: Any):
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def forward(self, input):
        for module in self:
            input = module(input)
        return input

# From this excerpt of nn.Sequential we can see that its initialization arguments
# are already-instantiated objects, so every class whose instances are passed in
# must have a forward method; otherwise the line input = module(input) raises an
# error. That is exactly why ResidualAdd must define forward.
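The failure mode is easy to demonstrate with a pure-Python sketch (class names are made up for illustration). The real nn.Sequential calls `module(input)`, which nn.Module's `__call__` routes to `forward`; here we call `forward` directly to show that a child without one breaks the chain:

```python
class ToySequential:
    """Applies each stored module's forward in turn, like nn.Sequential."""
    def __init__(self, *modules):
        self.modules = modules

    def forward(self, x):
        for module in self.modules:
            x = module.forward(x)  # every child must implement forward
        return x


class HasForward:
    def forward(self, x):
        return x + 1


class NoForward:
    pass


ok = ToySequential(HasForward(), HasForward())
print(ok.forward(0))  # 2

broken = ToySequential(NoForward())
try:
    broken.forward(0)
except AttributeError:
    print("NoForward has no forward method, so the chain fails")
```

In real PyTorch the symptom is the same: a child module without forward raises when nn.Sequential (or nn.Module's `__call__`) tries to invoke it.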

References:
https://blog.csdn.net/ZEdwin/article/details/117296675
"Making a Python class directly callable" (csldh's CSDN blog): https://blog.csdn.net/a__int__/article/details/104600972

Copyright notice
This article was written by [NuerNuer]. When reposting, please include the original link. Thank you.
https://yzsam.com/2022/04/202204210545522862.html