在使用 PyTorch 自定义神经网络时,我们经常会看到如下代码:
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
# 模型结构定义你是否想过:
super().__init__()?nn.Module.__init__(self) 有什么区别?本文将从 Python 面向对象的机制、PyTorch 模型构造原理、代码实践和错误示例等多个角度,带你深入理解这一行代码背后的意义。
super() 是什么?super() 是 Python 提供的内置函数,用于调用父类方法。它常用于类的初始化过程中,尤其在继承链中有多个父类时。
在 Python 中,类的继承是通过“方法解析顺序(MRO)”决定的,super() 会按照这个顺序自动调用合适的父类方法,而不用我们手动指定。
class Base:
def __init__(self):
print("Base init")
class Sub(Base):
def __init__(self):
super().__init__() # 推荐
print("Sub init")super().__init__()?在 PyTorch 中,我们的模型类通常继承自 nn.Module。这是因为 PyTorch 的很多功能(如参数注册、模块嵌套、模型保存等)都依赖于 nn.Module 的初始化机制。
super().__init__() 的后果:model.parameters() 正确识别;to(device)、eval())可能失效。super()):class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 1)super()):class MyModel(nn.Module):
def __init__(self):
self.linear = nn.Linear(10, 1) # 会报错或行为异常运行这段错误代码可能导致:
AttributeError: 'MyModel' object has no attribute '_parameters'因为父类 nn.Module 的初始化没有被执行,底层所需的属性未创建。
super() vs nn.Module.__init__(self)写法 | 优点 | 缺点 |
|---|---|---|
super().__init__() | 简洁、支持多继承、推荐写法 | 无 |
nn.Module.__init__(self) | 写法直观 | 不支持多继承,容易出错 |
推荐:始终使用 super()。
是的!super() 是 Python 的通用特性,不仅限于 PyTorch。在任何需要继承和初始化父类的场景下都应该使用它,特别是在涉及多个父类时。
super().__init__() 是初始化父类的标准做法,保证继承机制正常工作;nn.Module 的初始化,否则模型行为异常;super() 代替硬编码父类名,更灵活、安全、可维护;nn.Module 的类都应该使用 super().__init__();nn.Module.__init__(self);super() 的原理,对写出可维护、高质量的代码至关重要。