在 PyTorch 中,nn.Module 是所有神经网络模型的基类,提供了许多重要的成员函数。以下是一些常用的成员函数及其功能:
1. __init__(self)
描述:初始化模块。在用户定义的模型中,通常用来定义层和其他模块。
示例:
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv = nn.Conv2d(1, 16, 3)
2. forward(self, *input)
描述:定义前向传播逻辑。必须实现此方法,用于定义如何通过模型进行推理。
示例:
def forward(self, x):
return self.conv(x)
3. parameters(self, recurse=True)
描述:返回模型中所有可学习参数的迭代器。可以选择是否递归到子模块。
示例:
for param in model.parameters():
print(param.shape)
4. named_parameters(self, recurse=True)
描述:与 parameters() 类似,但是返回一个包含参数名称和值的元组。
示例:
for name, param in model.named_parameters():
print(name, param.shape)
5. modules(self)
描述:返回模型中所有子模块的迭代器。
示例:
for module in model.modules():
print(module)
6. named_modules(self, memo=None, prefix='')
描述:返回一个包含模块名称和实例的迭代器,可以使用 memo 防止循环引用。
示例:
f = open("model_modules.txt","w")
for k, v in model.named_modules():
f.write("{}\n".format(k))
f.write("{}\n".format(v))
f.close()
保存内容(部分,yolov8n)如下:
7. train(self, mode=True)
描述:设置模块为训练模式或评估模式。训练模式会启用 Dropout 和 BatchNorm 等层的训练行为。
示例:
model.train() # 训练模式
model.eval() # 评估模式
8. to(self, *args, **kwargs)
描述:将模型及其参数移动到指定设备(如 GPU、CPU)或转换为指定数据类型。
示例:
model.to('cuda') # 移动到 GPU
9. load_state_dict(self, state_dict, strict=True)
描述:加载模型的状态字典。可以控制是否严格匹配参数名。
示例:
model.load_state_dict(torch.load('model.pth'))
10. state_dict(self)
描述:返回模型的状态字典,包含所有可学习参数和缓冲区的状态。
示例:
with open(wts_file, 'w') as f:
f.write('{}\n'.format(len(model.state_dict().keys())))
for k, v in model.state_dict().items():
print("key={0}, v={1}".format(k,v))
vr = v.reshape(-1).cpu().numpy()
f.write('{} {} '.format(k, len(vr)))
for vv in vr:
f.write(' ')
f.write(struct.pack('>f', float(vv)).hex())
保存内容(部分)如下:
11. apply(self, fn)
描述:递归地将函数 fn 应用到模块及其子模块中。常用于初始化参数或修改子模块的行为。
示例:
def init_weights(m):
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, mean=0, std=0.01)
model.apply(init_weights)
12. forward_hooks
描述:在前向传播过程中,可以注册钩子函数以在输入和输出之间修改数据。
使用示例:
def hook_fn(module, input, output):
print(f'Input: {input}, Output: {output}')
hook = model.conv.register_forward_hook(hook_fn)
13. backward_hooks
描述:在反向传播过程中,可注册钩子函数以修改梯度。
使用示例:
def backward_hook(module, grad_input, grad_output):
print(f'Grad Input: {grad_input}, Grad Output: {grad_output}')
hook = model.conv.register_backward_hook(backward_hook)
14. trainable
描述:可以通过设置 requires_grad 属性控制哪些参数参与训练。
示例:
for param in model.parameters():
param.requires_grad = False # 冻结参数
15. extra_repr(self)
描述:可以重写此方法以添加额外的模块描述信息。通常在调用 print(model) 时会显示。
示例:
def extra_repr(self):
return f"Input size: {self.input_size}, Output size: {self.output_size}"
16.高级用法
16.1 自定义损失函数:
通过继承 nn.Module 来定义自定义损失函数。
class MyLoss(nn.Module):
def __init__(self):
super(MyLoss, self).__init__()
def forward(self, output, target):
return torch.mean((output - target) ** 2) # 均方误差
16.2 使用预训练模型:
可以利用 torchvision.models 中的预训练模型,并根据需求修改模型。
import torchvision.models as models
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, num_classes) # 替换最后一层
16.3 模型集成:
可以通过将多个模型结合在一起,创建一个更复杂的模型。
class EnsembleModel(nn.Module):
def __init__(self, model1, model2):
super(EnsembleModel, self).__init__()
self.model1 = model1
self.model2 = model2
def forward(self, x):
return (self.model1(x) + self.model2(x)) / 2 # 平均结果
16.4 序列模型:
通过 nn.Sequential 构建简单的线性网络。
model = nn.Sequential(
nn.Conv2d(1, 16, 3),
nn.ReLU(),
nn.Linear(16 * 6 * 6, 10)
)
总结:
nn.Module 提供了许多常用的功能,方便构建和管理神经网络模型。了解这些成员函数有助于更有效地使用 PyTorch 进行深度学习任务