
[Continual Learning in Python] Reproducing RWalk, a Classic ECCV 2018 Algorithm

Source: https://blog.csdn.net/weixin_43935696/article/details/144942305

Python Implementation of a Continual Learning Baseline and the Classic RWalk Algorithm


1 Continual Learning and Catastrophic Forgetting

Continual learning is a machine learning setting modeled on how humans learn: a model is trained on a sequence of tasks and is expected to keep acquiring new knowledge without losing what it has already learned. Most deep learning models, however, suffer from "catastrophic forgetting" when trained this way: performance on earlier tasks collapses as soon as a new task is learned, because the parameters that encoded the old tasks are overwritten during training on the new one.

Mitigating catastrophic forgetting is the central problem in continual learning research. Many kinds of methods have been proposed, including regularization-based, replay-based, and architecture-based approaches; among the regularization-based methods, EWC (Elastic Weight Consolidation) is a classic example.
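For reference, EWC's regularized objective (quoted from the original EWC paper by Kirkpatrick et al., 2017, not from the code in this article) anchors each parameter to its value after the previous task, weighted by its Fisher information:

$$\tilde{L}(\theta) = L_B(\theta) + \sum_i \frac{\lambda}{2} F_i \left( \theta_i - \theta_{A,i}^{*} \right)^2$$

Here $L_B$ is the loss on the new task $B$, $\theta_{A,i}^{*}$ are the parameters learned on the old task $A$, $F_i$ is the $i$-th diagonal entry of the Fisher information matrix, and $\lambda$ controls how strongly old knowledge is protected. RWalk, covered below, builds on this idea.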

2 The PermutedMNIST Dataset and Model

PermutedMNIST is a classic benchmark dataset in continual learning. Each task is created by applying a fixed random permutation to the pixels of the MNIST images, so every task is a classification problem determined by its permutation, while all tasks share the same label space.
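As a minimal sketch of how one task is generated (illustrative only; the full dataset class used in this article appears below), a fixed random permutation is drawn once per task and applied to the flattened pixels of every image:

import torch

perm = torch.randperm(28 * 28)              # one fixed permutation defines one task
img = torch.rand(28, 28)                    # stand-in for a single MNIST image
task_img = img.view(-1)[perm].view(28, 28)  # the same permutation is applied to every image of that task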

For the model, a simple fully connected network is typically used. The network contains a few hidden layers of a chosen width with ReLU activations, and the size of the output layer matches the number of label classes.

The model's parameters are updated as it trains on each task in turn, which lets us measure how severe catastrophic forgetting is and, once an algorithm is introduced, test how much it improves the model's continual learning ability.

import random
import torch
from torchvision import datasets
import os
from torch.utils.data import DataLoader
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
import warnings
warnings.filterwarnings("ignore")
# Set seeds
random.seed(2024)
torch.manual_seed(2024)
np.random.seed(2024)

# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


class PermutedMNIST(datasets.MNIST):
    def __init__(self, root="./data/mnist", train=True, permute_idx=None):
        super(PermutedMNIST, self).__init__(root, train, download=True)
        assert len(permute_idx) == 28 * 28
        # Apply the task's fixed pixel permutation to every (train or test) image
        if self.train:
            self.data = torch.stack([img.float().view(-1)[permute_idx] / 255
                                     for img in self.data])
        else:
            self.data = torch.stack([img.float().view(-1)[permute_idx] / 255
                                     for img in self.data])

    def __getitem__(self, index):
        if self.train:
            img, target = self.data[index], self.train_labels[index]
        else:
            img, target = self.data[index], self.test_labels[index]
        return img.view(1, 28, 28), target

    def get_sample(self, sample_size):
        random.seed(2024)
        sample_idx = random.sample(range(len(self)), sample_size)
        return [img.view(1, 28, 28) for img in self.data[sample_idx]]


def worker_init_fn(worker_id):
    # Keep each DataLoader worker's random seed deterministic
    random.seed(2024 + worker_id)
    np.random.seed(2024 + worker_id)


def get_permute_mnist(num_task, batch_size):
    random.seed(2024)
    train_loader = {}
    test_loader = {}
    root_dir = './data/permuted_mnist'
    os.makedirs(root_dir, exist_ok=True)
    for i in range(num_task):
        permute_idx = list(range(28 * 28))
        random.shuffle(permute_idx)
        train_dataset_path = os.path.join(root_dir, f'train_dataset_{i}.pt')
        test_dataset_path = os.path.join(root_dir, f'test_dataset_{i}.pt')
        if os.path.exists(train_dataset_path) and os.path.exists(test_dataset_path):
            train_dataset = torch.load(train_dataset_path)
            test_dataset = torch.load(test_dataset_path)
        else:
            train_dataset = PermutedMNIST(train=True, permute_idx=permute_idx)
            test_dataset = PermutedMNIST(train=False, permute_idx=permute_idx)
            torch.save(train_dataset, train_dataset_path)
            torch.save(test_dataset, test_dataset_path)
        train_loader[i] = DataLoader(train_dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     # num_workers=1,
                                     worker_init_fn=worker_init_fn,
                                     pin_memory=True)
        test_loader[i] = DataLoader(test_dataset,
                                    batch_size=batch_size,
                                    shuffle=False,
                                    # num_workers=1,
                                    worker_init_fn=worker_init_fn,
                                    pin_memory=True)
    return train_loader, test_loader


class MLP(nn.Module):
    def __init__(self, input_size=28 * 28, num_classes_per_task=10, hidden_size=[400, 400, 400]):
        super(MLP, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        # Track the number of output classes
        self.total_classes = num_classes_per_task
        self.num_classes_per_task = num_classes_per_task
        # Define the network layers
        self.fc1 = nn.Linear(input_size, hidden_size[0])
        self.fc2 = nn.Linear(hidden_size[0], hidden_size[1])
        self.fc_before_last = nn.Linear(hidden_size[1], hidden_size[2])
        self.fc_out = nn.Linear(hidden_size[2], self.total_classes)

    def forward(self, input, task_id=-1):
        x = F.relu(self.fc1(input))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc_before_last(x))
        x = self.fc_out(x)
        return x

3 Baseline Code

The baseline, which uses no continual learning algorithm, simply trains on the tasks one after another: the dataset of each task is loaded in turn and the model is trained on it independently, with no attempt to preserve what was learned on earlier tasks.


class Baseline:
    def __init__(self, num_classes_per_task=10, num_tasks=10, batch_size=256, epochs=2, neurons=0):
        self.num_classes_per_task = num_classes_per_task
        self.num_tasks = num_tasks
        self.batch_size = batch_size
        self.epochs = epochs
        self.neurons = neurons
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.input_size = 28 * 28
        # Initialize model
        self.model = MLP(num_classes_per_task=self.num_classes_per_task).to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        # Get dataset
        self.train_loaders, self.test_loaders = get_permute_mnist(self.num_tasks, self.batch_size)

    def evaluate(self, test_loader, task_id):
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                # Move data to GPU in batches
                images = images.view(-1, self.input_size)
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                outputs = self.model(images, task_id)
                predicted = torch.argmax(outputs, dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
        return 100.0 * correct / total

    def train_task(self, train_loader, optimizer, task_id):
        self.model.train()
        for images, labels in train_loader:
            images = images.view(-1, self.input_size)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            optimizer.zero_grad()
            outputs = self.model(images, task_id)
            loss = self.criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    def run(self):
        all_avg_acc = []
        for task_id in range(self.num_tasks):
            train_loader = self.train_loaders[task_id]
            self.model = self.model.to(self.device)
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
            for epoch in range(self.epochs):
                self.train_task(train_loader, optimizer, task_id)
            task_acc = []
            for eval_task_id in range(task_id + 1):
                accuracy = self.evaluate(self.test_loaders[eval_task_id], eval_task_id)
                task_acc.append(accuracy)
            mean_avg = np.round(np.mean(task_acc), 2)
            print(f"Task {task_id}: Task Acc = {task_acc},AVG={mean_avg}")
            all_avg_acc.append(mean_avg)
        avg_acc = np.mean(all_avg_acc)
        print(f"Task AVG Acc: {all_avg_acc},AVG = {avg_acc}")


if __name__ == '__main__':
    print('Baseline' + "=" * 50)
    random.seed(2024)
    torch.manual_seed(2024)
    np.random.seed(2024)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    baseline = Baseline(num_classes_per_task=10, num_tasks=3, batch_size=256, epochs=2)
    baseline.run()

Baseline==================================================

Task 0: Task Acc = [96.78],AVG=96.78

Task 1: Task Acc = [85.19, 97.0],AVG=91.1

Task 2: Task Acc = [52.66, 89.14, 97.27],AVG=79.69

Task AVG Acc: [96.78, 91.1, 79.69],AVG = 89.19

As the results show, accuracy on old tasks drops as new tasks are learned: after training on Task 2, accuracy on the first task has fallen to 52.66% and accuracy on the second task to 89.14%.
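The RWalk paper quantifies this as "forgetting": roughly, the gap between the best accuracy a task had previously reached and its accuracy after training on the latest task (stated here as a paraphrase of the paper's definition, with $a_{l,j}$ the accuracy on task $j$ after training task $l$):

$$f_j^k = \max_{l \in \{1, \dots, k-1\}} a_{l,j} - a_{k,j}$$

For Task 0 above, that gap after Task 2 is about 96.78 − 52.66 ≈ 44 percentage points.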

4 The RWalk Algorithm

4.1 Algorithm Principle

RWalk is an incremental learning framework that combines the Fisher information matrix with parameter importance accumulated along the optimization path. This combination balances retaining old tasks (avoiding catastrophic forgetting) against learning new ones well (reducing intransigence).

In the paper "Chaudhry A, Dokania P K, Ajanthan T, et al. Riemannian walk for incremental learning: Understanding forgetting and intransigence. Proceedings of the European Conference on Computer Vision (ECCV), 2018: 532-547", the Riemannian Walk (RWalk) algorithm computes the importance weights and the loss function as follows:

  1. Computing the importance weights:

    • Fisher information matrix update (running average):

      $$F_t^\theta = \alpha F_t^\theta + (1 - \alpha) F_{t-1}^\theta$$

      where $F_t^\theta$ is the Fisher information matrix at iteration $t$ and $\alpha$ is a hyperparameter.

    • Accumulated parameter importance score:

      $$s_{t_1}^{t_2}(\theta_i) = \sum_{t=t_1}^{t_2} \frac{\Delta L_t^{t+\Delta t}(\theta_i)}{\tfrac{1}{2} F_t^{\theta_i} \, \Delta \theta_i(t)^2 + \epsilon}$$

      where $\Delta L_t^{t+\Delta t}(\theta_i)$ is the change in loss attributed to parameter $\theta_i$ between steps $t$ and $t + \Delta t$, $F_t^{\theta_i}$ is the Fisher information of $\theta_i$ at step $t$, $\Delta \theta_i(t) = \theta_i(t + \Delta t) - \theta_i(t)$, and $\epsilon$ is a small positive constant.

  2. The loss function:

    • Final objective (RWalk):

      $$\tilde{L}_k(\theta) = L_k(\theta) + \lambda \sum_{i=1}^{P} \left( F_{k-1}^{\theta_i} + s_{t_0}^{t_{k-1}}(\theta_i) \right) \left( \theta_i - \theta_{k-1}^{i} \right)^2$$

      where $L_k(\theta)$ is the loss of the $k$-th task, $\lambda$ is a hyperparameter, $F_{k-1}^{\theta_i}$ is the Fisher information of $\theta_i$ at the end of task $k-1$, $s_{t_0}^{t_{k-1}}(\theta_i)$ is the importance score of $\theta_i$ accumulated from iteration $t_0$ to iteration $t_{k-1}$, and $\theta_{k-1}^{i}$ is the value of $\theta_i$ at the end of task $k-1$.
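To make the score update concrete, here is a minimal sketch of one accumulation step for a single parameter tensor. It is an illustration of the formula above, not the code from Section 4.2 (which uses a simpler squared-displacement path term); estimating the loss change with the first-order term $-\nabla L \cdot \Delta\theta$ is an assumption, one common way to approximate $\Delta L$ without an extra forward pass.

import torch

def accumulate_score(score, fisher, grad, delta_theta, eps=1e-3):
    # First-order estimate of the per-parameter loss change over this step:
    # delta_L ~ -grad * delta_theta (positive when the step reduces the loss).
    delta_loss = -grad * delta_theta
    # s += delta_L / (0.5 * F * delta_theta^2 + eps), element-wise per parameter.
    return score + delta_loss / (0.5 * fisher * delta_theta ** 2 + eps)

# Toy usage with illustrative names (not from this article's classes).
theta_old = torch.randn(5)
theta_new = theta_old + 0.01 * torch.randn(5)
grad = torch.randn(5)            # gradient of the loss at theta_old
fisher = grad ** 2               # diagonal Fisher estimate from squared gradients
score = torch.zeros(5)
score = accumulate_score(score, fisher, grad, theta_new - theta_old)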

4.2 Implementation


import torch
import torch.nn as nn
import random
import warnings
import numpy as np
warnings.filterwarnings("ignore")

# Set seeds
random.seed(2024)
torch.manual_seed(2024)
np.random.seed(2024)

# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True  # Enable for GPU efficiency


class RWalk:
    def __init__(self, num_classes_per_task=10, num_tasks=10, batch_size=256, epochs=2, neurons=0):
        self.num_classes_per_task = num_classes_per_task
        self.num_tasks = num_tasks
        self.batch_size = batch_size
        self.epochs = epochs
        self.neurons = neurons
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.input_size = 28 * 28
        self.model = MLP(num_classes_per_task=self.num_classes_per_task).to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.scaler = torch.cuda.amp.GradScaler()  # Enable mixed precision
        self.importance_dict = {}
        self.previous_params = {}
        self.path_integral = {}
        self.train_loaders, self.test_loaders = get_permute_mnist(self.num_tasks, self.batch_size)
        self.update_params()

    def evaluate(self, test_loader, task_id):
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.view(-1, self.input_size)
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                outputs = self.model(images, task_id)
                predicted = torch.argmax(outputs, dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
        return 100.0 * correct / total

    def train_task(self, train_loader, optimizer, task_id):
        self.model.train()
        for images, labels in train_loader:
            images = images.view(-1, self.input_size)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            optimizer.zero_grad()
            outputs = self.model(images, task_id)
            if task_id > 0:
                loss = self.rwalk_multi_objective_loss(outputs, labels)
            else:
                loss = self.criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    def compute_importance(self, data_loader, task_id):
        # EWC++: estimate the diagonal Fisher information from squared gradients
        importance_dict = {name: torch.zeros_like(param, device=self.device)
                           for name, param in self.model.named_parameters() if 'task' not in name}
        self.model.eval()
        for images, labels in data_loader:
            images = images.view(-1, self.input_size)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            self.model.zero_grad()
            outputs = self.model(images, task_id=task_id)
            loss = self.criterion(outputs, labels)
            loss.backward()
            for name, param in self.model.named_parameters():
                if name in importance_dict and param.requires_grad:
                    importance_dict[name] += param.grad ** 2 / len(data_loader)
        # Update the Fisher matrix with a moving average
        for name in importance_dict:
            if name in self.importance_dict:
                self.importance_dict[name] = 0.9 * self.importance_dict[name] + 0.1 * importance_dict[name]
            else:
                self.importance_dict[name] = importance_dict[name]

    def update_path_integral(self):
        # Accumulate parameter importance along the optimization path
        for name, param in self.model.named_parameters():
            if name in self.path_integral:
                self.path_integral[name] += (param.detach() - self.previous_params[name]) ** 2
            else:
                self.path_integral[name] = (param.detach() - self.previous_params[name]) ** 2

    def update_params(self):
        for name, param in self.model.named_parameters():
            self.previous_params[name] = param.clone().detach()

    def update(self, dataset, task_id):
        self.compute_importance(dataset, task_id)
        self.update_path_integral()
        self.update_params()

    def rwalk_multi_objective_loss(self, outputs, labels, lambda_=100):
        regularization_loss = 0.0
        for name, param in self.model.named_parameters():
            if name in self.importance_dict and name in self.previous_params and name in self.path_integral:
                fisher_importance = self.importance_dict[name]
                path_penalty = self.path_integral[name]
                previous_param = self.previous_params[name]
                regularization_loss += ((fisher_importance + path_penalty) *
                                        (param - previous_param).pow(2)).sum()
        loss = self.criterion(outputs, labels)
        total_loss = loss + lambda_ * regularization_loss
        return total_loss

    def run(self):
        all_avg_acc = []
        for task_id in range(self.num_tasks):
            train_loader = self.train_loaders[task_id]
            self.model = self.model.to(self.device)
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
            for epoch in range(self.epochs):
                self.train_task(train_loader, optimizer, task_id)
            self.update(train_loader, task_id)
            task_acc = []
            for eval_task_id in range(task_id + 1):
                accuracy = self.evaluate(self.test_loaders[eval_task_id], eval_task_id)
                task_acc.append(accuracy)
            mean_avg = np.round(np.mean(task_acc), 2)
            all_avg_acc.append(mean_avg)
            print(f"Task {task_id}: Task Acc = {task_acc},AVG={mean_avg}")
        avg_acc = np.mean(all_avg_acc)
        print(f"Task AVG Acc: {all_avg_acc}, AVG = {avg_acc}")


if __name__ == '__main__':
    print('RWalk' + "=" * 50)
    for _ in range(1):
        random.seed(2024)
        torch.manual_seed(2024)
        np.random.seed(2024)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        rwalk = RWalk(num_classes_per_task=10, num_tasks=3, batch_size=256, epochs=2)
        rwalk.run()

RWalk==================================================
Task 0: Task Acc = [96.78],AVG=96.78
Task 1: Task Acc = [94.91, 95.73],AVG=95.32
Task 2: Task Acc = [86.88, 89.66, 93.76],AVG=90.1
Task AVG Acc: [96.78, 95.32, 90.1], AVG = 94.06666666666666

After each new task is learned, the accuracy on the old tasks drops only slightly, showing that the algorithm effectively mitigates catastrophic forgetting.
