pytorch训练后pt模型中保存内容详解(yolov8n.pt为例)

admin2024-08-20  7

在 PyTorch 中,.pt 模型文件通常包含以下几类数据:

        模型参数:

                存储模型的权重和偏置参数。

        优化器状态:

                包含优化器的状态信息,以便在恢复训练时能够从中断的地方继续。

        训练状态:

                一些训练过程中的信息,例如当前的 epoch 数和训练进度。

        其他元数据:

                包括模型的配置、训练时使用的超参数等。

        在讲解pytorch pt(pth)文件中保存了什么内容之前,需要先了解pt在保存时保存了那些参数。

以YOLO系列pt保存代码来介绍说明:

1. 模型保存代码:

 def save_model(self):
        ckpt = {
            'epoch': self.epoch, #
            'best_fitness': self.best_fitness,
            'model': deepcopy(de_parallel(self.model)).half(),
            'ema': deepcopy(self.ema.ema).half(),
            'updates': self.ema.updates,
            'optimizer': self.optimizer.state_dict(),
            'train_args': vars(self.args),  # save as dict
            'date': datetime.now().isoformat(),
            'version': __version__}
        # Use dill (if exists) to serialize the lambda functions where pickle does not do this
        try:
            import dill as pickle
        except ImportError:
            import pickle
        # Save last, best and delete
        torch.save(ckpt, self.last, pickle_module=pickle)
        if self.best_fitness == self.fitness:
            torch.save(ckpt, self.best, pickle_module=pickle)
        if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
            torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle)
        del ckpt

参数说明:

        'epoch': 当前的训练轮次数。

        'best_fitness': 最佳性能指标的数值。

        'model': 深拷贝(deepcopy)并将模型参数进行半精度(half)转换后的模型。

        'ema': 深拷贝并将指数移动平均模型参数进行半精度转换后的指数移动平均模型。

        'updates': 指数移动平均模型的更新次数。

        'optimizer': 优化器的状态字典(state_dict)。

        'train_args': 训练参数的字典表示,使用vars(self.args)将self.args对象转换为字典。

        'date': 当前的日期和时间,使用datetime.now().isoformat()获取。

        'version': 代码的版本号,通过__version__获取。

        其中:model中保存的模型的结构,train_args中保存训练时的一些参数(超参数)。

通过上述功能函数可以看到pytorch保存的pt文件中的内容。

补充说明:

        torch.save()函数用于将PyTorch模型保存到磁盘上的文件中,以便以后可以重新加载和使用。它的基本语法如下:

        torch.save(obj, f, pickle_module=<module 'pickle' from '...'>, pickle_protocol=2)

                obj是要保存的对象,通常是一个模型的状态字典(state_dict())。

                f是文件的路径或文件对象,用于存储模型。

                pickle_module是用于序列化的Python模块,默认为pickle。

                pickle_protocol是序列化时使用的协议版本,默认为2。

2. 模型加载介绍

下面通过Debug来详解pt中的具体内容:

首先加载模型,代码如下:

import sys
import argparse
import os
import struct
import torch
pt_file = "./yolov8n.pt"
wts_file = "./yolov8n.wts"
# Initialize
device = 'cpu'
# Load model
modelAll = torch.load(pt_file, map_location=device)
model = modelAll['model'].float()  # load to FP32
#model = torch.load(pt_file, map_location=device)['model'].float()  # load to FP32

anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]
delattr(model.model[-1], 'anchors')
model.to(device).eval()
with open(wts_file, 'w') as f:
    f.write('{}\n'.format(len(model.state_dict().keys())))
    for k, v in model.state_dict().items():
        print("key={0}, v={1}".format(k,v))
        vr = v.reshape(-1).cpu().numpy()
        f.write('{} {} '.format(k, len(vr)))
        for vv in vr:
            f.write(' ')
            f.write(struct.pack('>f', float(vv)).hex())
        f.write('\n')

 Debug结果如下所示,分别对应save_model()中保存的内容

pytorch训练后pt模型中保存内容详解(yolov8n.pt为例),第1张

其中model(model = modelAll['model'].float())中内容如下:

pytorch训练后pt模型中保存内容详解(yolov8n.pt为例),第2张       model的类型为DetectionModel,里面包含了模型结构(model.model)以及参数信息(model.args)及构造网络时的配置参数信息(model.yaml)以及目标类别及个数、stride等信息。 

pytorch训练后pt模型中保存内容详解(yolov8n.pt为例),第3张

3. 模型权重解析保存

        model.state_dict()是一个字典,键是参数的名称,值是对应的 tensor。

        其中保存着模型的权重(Weights)和偏置值(Biases)以及运行均值和方差(例如,Batch Normalization 层的 running_mean 和 running_var,用于推理时)等信息。

        权重解析保存代码如下:

with open(wts_file, 'w') as f:
    f.write('{}\n'.format(len(model.state_dict().keys())))
    for k, v in model.state_dict().items():
        print("key={0}, v={1}".format(k,v))
        vr = v.reshape(-1).cpu().numpy()
        f.write('{} {} '.format(k, len(vr)))
        for vv in vr:
            f.write(' ')
            f.write(struct.pack('>f', float(vv)).hex())
        f.write('\n')

代码功能介绍:

  1. 使用写模式打开一个文件 wts_file,以便保存模型的参数。
  2. 将模型参数的数量写入文件。
  3. 循环遍历每个参数的键名 k 和对应的值 v。
  4. 将参数 v 重塑为一维数组,并将其从 GPU 移动到 CPU(如果适用),然后转换为 NumPy 数组。
  5. 写入参数的名称和长度。
    for vv in vr:
        f.write(' ')
        f.write(struct.pack('>f', float(vv)).hex())

        遍历每个参数值,使用大端格式(‘>’)将其转换为浮点数并写入文件.

pt解包后保存后的文件内容如下:

pytorch训练后pt模型中保存内容详解(yolov8n.pt为例),第4张

上述代码可以将pt格式模型,转化为Nvidia TensorRT部署需要的文件。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明原文出处。如若内容造成侵权/违法违规/事实不符,请联系SD编程学习网:675289112@qq.com进行投诉反馈,一经查实,立即删除!