+-
告别代码复制粘贴,傻瓜式提取 PyTorch 中间层特征
首页 专栏 人工智能 文章详情
0

告别代码复制粘贴,傻瓜式提取 PyTorch 中间层特征

超神经HyperAI 发布于 3 月 16 日
内容导读:特征提取是图像处理过程中常需要用到的一种方法,其效果好坏对模型的泛化能力有至关重要的影响。

本文首发自微信公众号「PyTorch 开发者社区」。

特征提取(Feature extraction)在机器学习、模式识别和图像处理中应用广泛。

它从初始的一组测量数据开始,建构出提供信息且不冗余的派生值,即特征值,从而促进后续的学习和泛化步骤。

在使用 PyTorch 进行模型训练的过程中,经常需要提取模型中间层的特征。解决这个问题可以用到 3 种方法。

对中间层进行特征提取的 3 大方法

1、借助模型类的属性传递

方法:修改 forward 函数,通过添加一行代码将 feature 赋值给 self 变量,即 _self.feature_map = feature_,然后打印输出即可。

备注:适用于仅提取中间层特征,不需要提取梯度的情况。

代码示例:

# Define a Convolutional Neural Network class Net(nn.Module):        def __init__(self, kernel_size=5, n_filters=16, n_layers=3):        xxx    def forward(self, x):        x = self.body(self.head(x))        self.featuremap1 = x.detach() # 核心代码        return F.relu(self.fc(x)) model_ft = Net() train_model(model_ft) feature_output1 = model_ft.featuremap1.transpose(1,0).cpu()

2、借助 hook 机制

hook 是一个可调用对象,它可以在不修改主代码的前提下插入业务。PyTorch 中的 hook 包括三种:

torch.autograd.Variable.register_hook

torch.nn.Module.register_backward_hook

torch.nn.Module.register_forward_hook

第一个是针对 Variable 对象的,后两个是针对 nn.Module 对象的。

方法:在调用阶段对 Module 使用 forward_hook 函数,可以获得所需梯度或特征。

备注:较为复杂、功能完善,需要对 PyTorch 有一定程度的了解。

3、借助 torchextractor

torchextractor 是一个独立 Python 包,具有跟 nn.Module 功能类似的提取器,只需提供模块名称,就可以在 PyTorch 中对中间层进行特征提取。

与使用 forward_hook 进行中间层特征提取相比,torchextractor 更像是一个包装程序(wrapper),不像 torchvision IntermediateLayerGetter 有那么多的 _assumption_。

在功能方面 torchextractor 主要优势在于支持嵌套模块(nested module)、自定义缓存操作,而且与 ONNX 兼容。

torchextractor 极大简化了在 PyTorch 中进行特征提取的流程,这避免了大量代码的粘贴复制,也不需要重写 forward 函数,它对初学者更友好,可用性也更强。

torchextractor 上手实践

安装

pip install torchextractor # stable pip install git+https://github.com/antoinebrl/torchextractor.git # latest

要求

Python 3.6 及以上版本

Torch 1.4.0 及以上版本

用法

import torch import torchvision import torchextractor as tx model = torchvision.models.resnet18(pretrained=True) model = tx.Extractor(model, ["layer1", "layer2", "layer3", "layer4"]) dummy_input = torch.rand(7, 3, 224, 224) model_output, features = model(dummy_input) feature_shapes = {name: f.shape for name, f in features.items()} print(feature_shapes) # { #   'layer1': torch.Size([1, 64, 56, 56]), #   'layer2': torch.Size([1, 128, 28, 28]), #   'layer3': torch.Size([1, 256, 14, 14]), #   'layer4': torch.Size([1, 512, 7, 7]), # }

完整文档请查看:

https://github.com/antoinebrl...

以上就是本期汇总的 3 个对中间层进行特征提取的方法,如果你有更好的解决思路,或者其他想要了解的 Pytorch 相关问题,欢迎在下方留言或发私信。

参考:

https://www.reddit.com/r/Mach...

https://www.zhihu.com/questio...

机器学习 人工智能 图像识别 tensorflow pytorch
阅读 658 发布于 3 月 16 日
收藏
分享
本作品系原创, 采用《署名-非商业性使用-禁止演绎 4.0 国际》许可协议
超神经HyperAI
AI 行业实验媒体,站在科技与人文的交叉口,看懂人工智能。微信公众号:HyperAI
关注专栏
avatar
超神经HyperAI
声望
1k 粉丝
关注作者
0 条评论
得票数 最新
提交评论
你知道吗?

注册登录
avatar
超神经HyperAI
声望
1k 粉丝
关注作者
宣传栏
目录
内容导读:特征提取是图像处理过程中常需要用到的一种方法,其效果好坏对模型的泛化能力有至关重要的影响。

本文首发自微信公众号「PyTorch 开发者社区」。

特征提取(Feature extraction)在机器学习、模式识别和图像处理中应用广泛。

它从初始的一组测量数据开始,建构出提供信息且不冗余的派生值,即特征值,从而促进后续的学习和泛化步骤。

在使用 PyTorch 进行模型训练的过程中,经常需要提取模型中间层的特征。解决这个问题可以用到 3 种方法。

对中间层进行特征提取的 3 大方法

1、借助模型类的属性传递

方法:修改 forward 函数,通过添加一行代码将 feature 赋值给 self 变量,即 _self.feature_map = feature_,然后打印输出即可。

备注:适用于仅提取中间层特征,不需要提取梯度的情况。

代码示例:

# Define a Convolutional Neural Network class Net(nn.Module):        def __init__(self, kernel_size=5, n_filters=16, n_layers=3):        xxx    def forward(self, x):        x = self.body(self.head(x))        self.featuremap1 = x.detach() # 核心代码        return F.relu(self.fc(x)) model_ft = Net() train_model(model_ft) feature_output1 = model_ft.featuremap1.transpose(1,0).cpu()

2、借助 hook 机制

hook 是一个可调用对象,它可以在不修改主代码的前提下插入业务。PyTorch 中的 hook 包括三种:

torch.autograd.Variable.register_hook

torch.nn.Module.register_backward_hook

torch.nn.Module.register_forward_hook

第一个是针对 Variable 对象的,后两个是针对 nn.Module 对象的。

方法:在调用阶段对 Module 使用 forward_hook 函数,可以获得所需梯度或特征。

备注:较为复杂、功能完善,需要对 PyTorch 有一定程度的了解。

3、借助 torchextractor

torchextractor 是一个独立 Python 包,具有跟 nn.Module 功能类似的提取器,只需提供模块名称,就可以在 PyTorch 中对中间层进行特征提取。

与使用 forward_hook 进行中间层特征提取相比,torchextractor 更像是一个包装程序(wrapper),不像 torchvision IntermediateLayerGetter 有那么多的 _assumption_。

在功能方面 torchextractor 主要优势在于支持嵌套模块(nested module)、自定义缓存操作,而且与 ONNX 兼容。

torchextractor 极大简化了在 PyTorch 中进行特征提取的流程,这避免了大量代码的粘贴复制,也不需要重写 forward 函数,它对初学者更友好,可用性也更强。

torchextractor 上手实践

安装

pip install torchextractor # stable pip install git+https://github.com/antoinebrl/torchextractor.git # latest

要求

Python 3.6 及以上版本

Torch 1.4.0 及以上版本

用法

import torch import torchvision import torchextractor as tx model = torchvision.models.resnet18(pretrained=True) model = tx.Extractor(model, ["layer1", "layer2", "layer3", "layer4"]) dummy_input = torch.rand(7, 3, 224, 224) model_output, features = model(dummy_input) feature_shapes = {name: f.shape for name, f in features.items()} print(feature_shapes) # { #   'layer1': torch.Size([1, 64, 56, 56]), #   'layer2': torch.Size([1, 128, 28, 28]), #   'layer3': torch.Size([1, 256, 14, 14]), #   'layer4': torch.Size([1, 512, 7, 7]), # }

完整文档请查看:

https://github.com/antoinebrl...

以上就是本期汇总的 3 个对中间层进行特征提取的方法,如果你有更好的解决思路,或者其他想要了解的 Pytorch 相关问题,欢迎在下方留言或发私信。

参考:

https://www.reddit.com/r/Mach...

https://www.zhihu.com/questio...