在使用 PyTorch 自定义神经网络时,我们经常会看到如下代码:
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
# 模型结构定义
你是否想过:
super().__init__()
?nn.Module.__init__(self)
有什么区别?本文将从 Python 面向对象的机制、PyTorch 模型构造原理、代码实践和错误示例等多个角度,带你深入理解这一行代码背后的意义。
super()
是什么?super()
是 Python 提供的内置函数,用于调用父类方法。它常用于类的初始化过程中,尤其在继承链中有多个父类时。
在 Python 中,类的继承是通过“方法解析顺序(MRO)”决定的,super()
会按照这个顺序自动调用合适的父类方法,而不用我们手动指定。
class Base:
def __init__(self):
print("Base init")
class Sub(Base):
def __init__(self):
super().__init__() # 推荐
print("Sub init")
super().__init__()
?在 PyTorch 中,我们的模型类通常继承自 nn.Module
。这是因为 PyTorch 的很多功能(如参数注册、模块嵌套、模型保存等)都依赖于 nn.Module
的初始化机制。
super().__init__()
的后果:model.parameters()
正确识别;to(device)
、eval()
)可能失效。super()
):class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 1)
super()
):class MyModel(nn.Module):
def __init__(self):
self.linear = nn.Linear(10, 1) # 会报错或行为异常
运行这段错误代码可能导致:
AttributeError: 'MyModel' object has no attribute '_parameters'
因为父类 nn.Module
的初始化没有被执行,底层所需的属性未创建。
super()
vs nn.Module.__init__(self)
写法 | 优点 | 缺点 |
---|---|---|
super().__init__() | 简洁、支持多继承、推荐写法 | 无 |
nn.Module.__init__(self) | 写法直观 | 不支持多继承,容易出错 |
推荐:始终使用 super()
。
是的!super()
是 Python 的通用特性,不仅限于 PyTorch。在任何需要继承和初始化父类的场景下都应该使用它,特别是在涉及多个父类时。
super().__init__()
是初始化父类的标准做法,保证继承机制正常工作;nn.Module
的初始化,否则模型行为异常;super()
代替硬编码父类名,更灵活、安全、可维护;nn.Module
的类都应该使用 super().__init__()
;nn.Module.__init__(self)
;super()
的原理,对写出可维护、高质量的代码至关重要。