PyTorch/PyTorch Geometric 使用 MemoryDataset 加快数据读取速度

PyTorch/PyTorch Geometric 使用 MemoryDataset 加快数据读取速度

PyTorch Geometric Library (简称 PyG) 是一个基于 PyTorch 的图神经网络库,地址是:https://github.com/rusty1s/pytorch_geometric。它包含了很多 GNN 相关论文中的方法实现和常用数据集,并且提供了简单易用的接口来生成图,因此对于复现论文来说也是相当方便。用法大多数和 PyTorch 很相近。

在我之前的一个实验中,我发现 PyTorch Geometric 的运行速度很慢,为了提升速度,尝试了许多常用的数据预加载等方法均无效。后来才发现 PyTorch Geometric 提供了 InMemoryDataset 来提前把所有数据一次性加载到内存中。

PyG 提供 2 种不同的Dataset

  • InMemoryDataset:使用这个Dataset会一次性把数据全部加载到内存中。
  • Dataset: 使用这个Dataset每次加载一个数据到内存中,比较常用。

其中后者 Dataset,我们已经在前文图神经网络 PyTorch Geometric 入门教程 中介绍过,下面介绍 InMemoryDataset 的使用方法。

我们需要继承 InMemoryDataset,实现 4 个方法:

  • raw_file_names():用于返回 raw_dir 文件夹里的文件列表。

  • processed_file_names():用于返回 processed_dir 文件夹里的文件列表,我们这里返回只有一个元素的列表即可,如 ['data.pt']

  • download():下载数据到 raw_dir 文件夹中,一般写pass

  • process():处理raw_dir 的文件,并保存到processed_dir 。这里需要定义一个 list,把处理后的数据添加到 list 中,把这些数据保存到 self.processed_paths[0] 中。

    1
    2
    data, slices = self.collate(data_list)
    torch.save((data, slices), self.processed_paths[0])

    同理,在 __init__()函数中需要加载 data.pt里的所有数据

    1
    self.data, self.slices = torch.load(self.processed_paths[0])

例子如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class Graph_2D_Memory_Dataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
super(Graph_2D_Memory_Dataset, self).__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])

# return file list of self.raw_dir
@property
def raw_file_names(self):
all_filenames = glob.glob(os.path.join(self.raw_dir, "*.mat"))
# get all file names
file_names = [f.split(os.sep)[-1] for f in all_filenames]
return file_names

# get all file names in self.processed_dir
@property
def processed_file_names(self):
return ['data.pt']


def download(self):
pass


# convert the mat files of self.raw_dir to torch_geometric.Data format, save the result files in self.processed_dir
# this method will only execute one time at the first running.
def process(self):
data_list = []
for raw_path in self.raw_paths:

content = sio.loadmat(raw_path)
feature = torch.tensor(content["feature"]).float()
edge_index = torch.tensor(
np.array(content["edges"]).astype(np.int32), dtype=torch.long
)
# 构建 2D Graph
pos = torch.tensor(content["pseudo"], dtype=torch.float32)[:, 1:3]
# pos = torch.tensor(np.array(content["pseudo"]), dtype=torch.float32)
label_idx = torch.tensor(int(content["label"]), dtype=torch.long)
data = Data(
x=feature, edge_index=edge_index, pos=pos, y=label_idx.unsqueeze(0)
)

if self.pre_filter is not None and not self.pre_filter(data):
continue

if self.pre_transform is not None:
data = self.pre_transform(data)

data_list.append(data)
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])

评论