PyTorch/PyTorch Geometric 使用 MemoryDataset 加快数据读取速度
PyTorch Geometric Library (简称 PyG) 是一个基于 PyTorch 的图神经网络库,地址是:https://github.com/rusty1s/pytorch_geometric。它包含了很多 GNN 相关论文中的方法实现和常用数据集,并且提供了简单易用的接口来生成图,因此对于复现论文来说也是相当方便。用法大多数和 PyTorch 很相近。
在我之前的一个实验中,我发现 PyTorch Geometric 的运行速度很慢,为了提升速度,尝试了许多常用的数据预加载等方法均无效。后来才发现 PyTorch Geometric 提供了 InMemoryDataset 来提前把所有数据一次性加载到内存中。
PyG 提供 2 种不同的Dataset
:
- InMemoryDataset:使用这个
Dataset
会一次性把数据全部加载到内存中。 - Dataset: 使用这个
Dataset
每次加载一个数据到内存中,比较常用。
其中后者 Dataset
,我们已经在前文图神经网络 PyTorch Geometric 入门教程 中介绍过,下面介绍 InMemoryDataset
的使用方法。
我们需要继承 InMemoryDataset
,实现 4 个方法:
raw_file_names()
:用于返回raw_dir
文件夹里的文件列表。processed_file_names()
:用于返回processed_dir
文件夹里的文件列表,我们这里返回只有一个元素的列表即可,如['data.pt']
。download()
:下载数据到raw_dir
文件夹中,一般写pass
。process()
:处理raw_dir
的文件,并保存到processed_dir
。这里需要定义一个 list,把处理后的数据添加到 list 中,把这些数据保存到self.processed_paths[0]
中。1
2data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])同理,在
__init__()
函数中需要加载data.pt
里的所有数据1
self.data, self.slices = torch.load(self.processed_paths[0])
例子如下:
1 | class Graph_2D_Memory_Dataset(InMemoryDataset): |