目录

Python实例题

题目

问题描述

解题思路

关键代码框架

难点分析

扩展方向

Python实例题

题目

基于联邦学习的隐私保护 AI 系统(分布式学习、隐私计算)

问题描述

开发一个基于联邦学习的隐私保护 AI 系统,包含以下功能:

  • 联邦学习框架:支持多种机器学习模型的联邦训练
  • 隐私保护机制:差分隐私、同态加密等技术保护数据隐私
  • 模型聚合:安全聚合各参与方的模型参数
  • 客户端管理:管理和协调多个参与训练的客户端
  • 评估与部署:评估联邦模型性能并部署到生产环境

解题思路

  • 采用横向或纵向联邦学习架构
  • 实现安全聚合协议(如 FedAvg、FedProx)
  • 应用差分隐私或同态加密保护数据隐私
  • 设计客户端 - 服务器通信协议
  • 开发模型评估和部署工具

关键代码框架

# 联邦学习服务器端
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import json
import logging
from typing import List, Dict, Any, Tuple
from cryptography.fernet import Fernet# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)class FedAvgServer:def __init__(self, model: nn.Module, clients: List[str], config: Dict[str, Any]):self.model = modelself.clients = clientsself.config = configself.global_round = 0self.client_models = {client: None for client in clients}self.client_weights = {client: 1.0 for client in clients}  # 客户端权重# 初始化加密密钥self.encryption_key = Fernet.generate_key()self.cipher_suite = Fernet(self.encryption_key)# 初始化优化器self.optimizer = optim.SGD(self.model.parameters(), lr=config['learning_rate'])def aggregate_models(self) -> None:"""聚合客户端模型"""logger.info(f"开始第 {self.global_round} 轮模型聚合")# 检查是否所有客户端都提交了模型for client, model_params in self.client_models.items():if model_params is None:logger.warning(f"客户端 {client} 未提交模型,跳过此轮")return# 计算总权重total_weight = sum(self.client_weights.values())# 初始化全局模型参数global_params = {}for name, param in self.model.named_parameters():global_params[name] = torch.zeros_like(param.data)# 加权聚合for client, model_params in self.client_models.items():weight = self.client_weights[client] / total_weightfor name, param in model_params.items():global_params[name] += param * weight# 更新全局模型with torch.no_grad():for name, param in self.model.named_parameters():param.data.copy_(global_params[name])# 增加全局轮次self.global_round += 1# 重置客户端模型self.client_models = {client: None for client in self.clients}logger.info(f"第 {self.global_round-1} 轮模型聚合完成")def encrypt_model(self, model_params: Dict[str, torch.Tensor]) -> bytes:"""加密模型参数"""# 将模型参数转换为numpy数组并序列化为JSONmodel_dict = {name: param.numpy().tolist() for name, param in model_params.items()}model_json = json.dumps(model_dict).encode('utf-8')# 加密encrypted_data = self.cipher_suite.encrypt(model_json)return encrypted_datadef decrypt_model(self, encrypted_data: bytes) -> Dict[str, torch.Tensor]:"""解密模型参数"""# 解密decrypted_data = self.cipher_suite.decrypt(encrypted_data)model_dict = json.loads(decrypted_data.decode('utf-8'))# 转换回PyTorch张量model_params = {name: torch.tensor(param) for name, param in model_dict.items()}return model_paramsdef receive_client_model(self, client_id: str, encrypted_model: bytes, client_weight: float) -> None:"""接收客户端模型"""if client_id not in self.clients:logger.warning(f"未知客户端: {client_id}")returntry:# 解密模型model_params = self.decrypt_model(encrypted_model)# 存储客户端模型self.client_models[client_id] = model_paramsself.client_weights[client_id] = client_weightlogger.info(f"收到客户端 {client_id} 的模型,权重: {client_weight}")except Exception as e:logger.error(f"接收客户端模型失败: {e}")def send_global_model(self, client_id: str) -> bytes:"""向客户端发送全局模型"""if client_id not in self.clients:logger.warning(f"未知客户端: {client_id}")return None# 获取当前全局模型参数model_params = {name: param.data for name, param in self.model.named_parameters()}# 加密并发送return self.encrypt_model(model_params)def evaluate_model(self, test_loader: DataLoader) -> Tuple[float, float]:"""评估模型性能"""self.model.eval()test_loss = 0correct = 0total = 0with torch.no_grad():for inputs, targets in test_loader:outputs = self.model(inputs)loss = nn.CrossEntropyLoss()(outputs, targets)test_loss += loss.item()_, predicted = outputs.max(1)total += targets.size(0)correct += predicted.eq(targets).sum().item()accuracy = 100.0 * correct / totalavg_loss = test_loss / len(test_loader)logger.info(f"模型评估结果: 准确率 = {accuracy:.2f}%, 平均损失 = {avg_loss:.4f}")return accuracy, avg_lossdef save_model(self, path: str) -> None:"""保存模型"""torch.save(self.model.state_dict(), path)logger.info(f"模型已保存到: {path}")
# 联邦学习客户端
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import json
import logging
from typing import Dict, Any, List, Tuple
from cryptography.fernet import Fernet# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)class FedAvgClient:def __init__(self, client_id: str, model: nn.Module, train_data: Dataset, config: Dict[str, Any]):self.client_id = client_idself.model = modelself.train_data = train_dataself.config = config# 创建数据加载器self.train_loader = DataLoader(train_data, batch_size=config['batch_size'], shuffle=True)# 初始化优化器self.optimizer = optim.SGD(self.model.parameters(), lr=config['learning_rate'])# 加密工具self.encryption_key = None  # 将从服务器接收self.cipher_suite = Nonedef set_encryption_key(self, key: bytes) -> None:"""设置加密密钥"""self.encryption_key = keyself.cipher_suite = Fernet(key)def encrypt_model(self, model_params: Dict[str, torch.Tensor]) -> bytes:"""加密模型参数"""if self.cipher_suite is None:raise ValueError("未设置加密密钥")# 将模型参数转换为numpy数组并序列化为JSONmodel_dict = {name: param.numpy().tolist() for name, param in model_params.items()}model_json = json.dumps(model_dict).encode('utf-8')# 加密encrypted_data = self.cipher_suite.encrypt(model_json)return encrypted_datadef decrypt_model(self, encrypted_data: bytes) -> Dict[str, torch.Tensor]:"""解密模型参数"""if self.cipher_suite is None:raise ValueError("未设置加密密钥")# 解密decrypted_data = self.cipher_suite.decrypt(encrypted_data)model_dict = json.loads(decrypted_data.decode('utf-8'))# 转换回PyTorch张量model_params = {name: torch.tensor(param) for name, param in model_dict.items()}return model_paramsdef update_model(self, encrypted_global_model: bytes) -> None:"""更新本地模型为全局模型"""try:# 解密全局模型global_params = self.decrypt_model(encrypted_global_model)# 更新本地模型with torch.no_grad():for name, param in self.model.named_parameters():param.data.copy_(global_params[name])logger.info(f"客户端 {self.client_id} 模型已更新")except Exception as e:logger.error(f"更新模型失败: {e}")def train(self, epochs: int) -> Tuple[Dict[str, torch.Tensor], float]:"""本地训练模型"""self.model.train()for epoch in range(epochs):epoch_loss = 0batches = 0for inputs, targets in self.train_loader:self.optimizer.zero_grad()outputs = self.model(inputs)loss = nn.CrossEntropyLoss()(outputs, targets)loss.backward()self.optimizer.step()epoch_loss += loss.item()batches += 1avg_loss = epoch_loss / batcheslogger.info(f"客户端 {self.client_id}, 轮次 {epoch+1}/{epochs}, 平均损失: {avg_loss:.4f}")# 获取训练后的模型参数model_params = {name: param.data for name, param in self.model.named_parameters()}# 返回模型参数和样本数量(作为权重)return model_params, len(self.train_data)def get_encrypted_model(self, epochs: int = 1) -> bytes:"""训练并返回加密的模型参数"""model_params, weight = self.train(epochs)encrypted_model = self.encrypt_model(model_params)return encrypted_model, weight
# 联邦学习主程序
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import Subset
import numpy as np
from typing import List, Dict, Any# 定义简单的CNN模型
class SimpleCNN(nn.Module):def __init__(self, num_classes=10):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(kernel_size=2)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(kernel_size=2)self.fc1 = nn.Linear(32 * 7 * 7, 128)self.relu3 = nn.ReLU()self.fc2 = nn.Linear(128, num_classes)def forward(self, x):x = self.pool1(self.relu1(self.conv1(x)))x = self.pool2(self.relu2(self.conv2(x)))x = x.view(-1, 32 * 7 * 7)x = self.relu3(self.fc1(x))x = self.fc2(x)return xdef split_dataset(dataset, num_clients: int, iid: bool = True) -> List[Subset]:"""分割数据集给多个客户端"""num_samples = len(dataset) // num_clientsclient_datasets = []if iid:# IID方式分割(随机分配)indices = list(range(len(dataset)))np.random.shuffle(indices)for i in range(num_clients):client_indices = indices[i * num_samples : (i + 1) * num_samples]client_datasets.append(Subset(dataset, client_indices))else:# 非IID方式分割(按标签排序后分配)# 这里简化处理,实际应用中可能需要更复杂的分割策略labels = np.array([dataset[i][1] for i in range(len(dataset))])indices = np.argsort(labels)for i in range(num_clients):client_indices = indices[i * num_samples : (i + 1) * num_samples]client_datasets.append(Subset(dataset, client_indices))return client_datasetsdef run_federated_learning(config: Dict[str, Any]):"""运行联邦学习过程"""# 设置随机种子torch.manual_seed(config['seed'])np.random.seed(config['seed'])# 加载数据集transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)test_dataset = datasets.MNIST('data', train=False, transform=transform)# 分割训练数据给客户端client_datasets = split_dataset(train_dataset, config['num_clients'], config['iid'])# 创建测试数据加载器test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)# 初始化服务器和客户端global_model = SimpleCNN()server = FedAvgServer(global_model, [f"client{i}" for i in range(config['num_clients'])], config)clients = []for i in range(config['num_clients']):client_model = SimpleCNN()# 初始时客户端模型与全局模型相同client_model.load_state_dict(global_model.state_dict())client = FedAvgClient(f"client{i}", client_model, client_datasets[i], config)clients.append(client)# 分发加密密钥给客户端for client in clients:client.set_encryption_key(server.encryption_key)# 联邦学习训练循环for round in range(config['global_rounds']):logger.info(f"===== 开始第 {round+1}/{config['global_rounds']} 轮联邦学习 =====")# 选择参与本轮的客户端selected_clients = np.random.choice(clients, size=min(config['clients_per_round'], len(clients)), replace=False)# 向客户端发送全局模型for client in selected_clients:encrypted_global_model = server.send_global_model(client.client_id)client.update_model(encrypted_global_model)# 客户端本地训练for client in selected_clients:encrypted_model, client_weight = client.get_encrypted_model(config['local_epochs'])server.receive_client_model(client.client_id, encrypted_model, client_weight)# 服务器聚合模型server.aggregate_models()# 评估全局模型if (round + 1) % config['eval_every'] == 0:accuracy, loss = server.evaluate_model(test_loader)logger.info(f"第 {round+1} 轮评估结果: 准确率 = {accuracy:.2f}%, 损失 = {loss:.4f}")# 保存最终模型server.save_model(config['model_save_path'])logger.info("联邦学习训练完成")# 配置参数
config = {'seed': 42,'num_clients': 10,'clients_per_round': 5,'global_rounds': 50,'local_epochs': 5,'batch_size': 64,'learning_rate': 0.01,'iid': True,  # 是否IID数据分布'eval_every': 5,  # 每多少轮评估一次'model_save_path': 'federated_model.pth'
}# 运行联邦学习
if __name__ == "__main__":run_federated_learning(config)

难点分析

  • 隐私保护与模型性能平衡:在保护隐私的同时保持模型准确性
  • 通信效率:减少客户端与服务器之间的通信开销
  • 异构设备处理:处理不同性能客户端的参与
  • 安全聚合协议:实现安全的模型参数聚合
  • 恶意参与者检测:识别和处理恶意参与方

扩展方向

  • 实现更高级的隐私保护技术(如差分隐私、同态加密)
  • 添加自适应学习率调整机制
  • 支持增量训练和持续学习
  • 开发联邦学习可视化监控界面
  • 实现跨平台联邦学习(移动端、边缘设备)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/85949.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/85949.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/85949.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

点点(小红书AI搜索):生活场景的智能搜索助手

1. 产品概述 点点是小红书于2024年12月正式推出的AI搜索助手,由上海生动诗章科技有限公司开发,定位为生活场景搜索工具,聚焦交通、美食、旅游、购物等日常需求,旨在通过即时信息和真实用户分享帮助用户“精准避坑”。 核心特点 …

软件工程概述:核心概念、模型与方法全解析

一、软件工程定义与诞生背景 定义 将系统化、规范化、可度量的方法应用于软件开发、运行和维护的过程(IEEE标准)。 核心目标:在可控成本下,生产高质量、可维护、满足需求的软件产品。 - 软件开发:需求 → 设计 → 编码…

LVS+Keepalived+nginx

LVSKeepalivednginx 1 安装依赖 sudo yum install ipvsadm keepalived -y 查询是否安装成功 rpm -q -a keepalived 2 配置虚拟IP并安装ipvsadm /etc/sysconfig/network-scripts cp ifcfg-ens33 ifcfg-ens33:1 修改里面配置文件 TYPE"Ethernet" PROXY_METHOD"n…

数据分析实操篇:京东淘宝商品实时数据获取与分析

在电商行业蓬勃发展的当下,数据已然成为驱动决策的核心要素。无论是商家精准把控市场需求、制定营销策略,还是消费者做出明智的购物抉择,都离不开对电商平台商品数据的深入剖析。京东和淘宝作为国内电商领域的两大巨头,汇聚了海量…

微信小程序扫码添加音频播放报错{errCode:10001, errMsg:“errCode:602,err:error,not found param“}

主要流程代码如下: let innerAudioContext wx.createInnerAudioContext() // 提示音 innerAudioContext.autoplay true innerAudioContext.src ../images/scan.mp3 innerAudioContext.onError(function(res){ console.log(onError 开始监听:,res) }) innerAudi…

SVN合并指南,从dev合并部分revision到release指南

dev合并到release 1.进入release的工作区,右击选择Merge 点击Next 2.确认merge来源分支和当前分支 点击Show Log,挑选需要合并的单号 3. 选择需要合并的commit 注意点击Hide no-mergeable revisions,来隐藏掉已经合并的commit 4.选择需…

《计算机网络:自顶向下方法(第8版)》Chapter 8 课后题

复习题 8.1节 R1. 机密性是攻击者截获原始明文消息的密文加密后无法确定原始明文的属性。消息完整性是接收方可以检测发送的消息(无论是否加密)在传输过程中是否又被更改的属性。 因此,这两者是不同的概念,可以独立存在。一个在传…

抖音小程序开发:ttml和传统html的区别

1 传统 Web 中 HTML 的角色 HyperText Markup Language&#xff1a;用来描述页面结构——标题、段落、图片、表单…… 只负责“放什么元素、排在什么层级”&#xff0c;真正的行为靠 JS&#xff0c;视觉靠 CSS。 <!-- 传统网页&#xff1a;结构 class 交给 CSS --> &…

Unity2D 街机风太空射击游戏 学习记录 #12QFramework引入

概述 这是一款基于Unity引擎开发的2D街机风太空射击游戏&#xff0c;笔者并不是游戏开发人&#xff0c;作者是siki学院的凉鞋老师。 笔者只是学习项目&#xff0c;记录学习&#xff0c;同时也想帮助他人更好的学习这个项目 作者会记录学习这一期用到的知识&#xff0c;和一些…

Proteus如何创建第一个工程

视频教程&#xff1a; [最详细]Proteus新建第一个工程与快捷键设置 操作步骤 1打开Proteus后&#xff0c;左上角点击文件然后点击新建工程。 2新建工程后&#xff0c;弹出以下界面&#xff0c;选择修改自己的工程名字&#xff0c;&#xff08;工程名的后缀“.pdspsj”不要修改…

现代浏览器剪贴板操作指南 + 示例页面 navigator.clipboard 详解与实战

在 Web 开发中&#xff0c;复制文本到剪贴板是一个常见需求。多年来&#xff0c;过去我们依赖 Flash 或 document.execCommand(copy) 来实现。它曾是我们的得力助手&#xff0c;但这些方法存在兼容性差、安全性低的问题。现已经被正式标记为废弃&#xff08;Deprecated&#xf…

OpenCV-Python学习笔记

2 OpenCV中的Gui特性 2-1 图像入门 目标 学习如何读取图像、显示图像和保存图像 学习api函数&#xff1a;cv.imread()、cv.imshow()、cv.imwrite() 学习使用Matplotlib显示图像 使用OpenCV 读取图像 在OpenCV中&#xff0c;使用函数cv.imread()读取图像。 img cv.imread(le…

2025年- H87-Lc195--287.寻找重复数(技巧,二分查找OR动态规划)--Java版

1.题目描述 2.思路 3.代码实现 public class H287 {public int findDuplicate(int[] nums) {// 重复数字可能的最小值int min1;// 重复数字可能的最大值&#xff0c;数组长度为 n&#xff0c;数字范围是 [1, n-1]int maxnums.length-1;while(min<max) {// 防止溢出&#xf…

PVE使用ubuntu-cloud-24.img创建虚拟机并制作模板

前言 在使用pve时,虽然它已经可以克隆虚拟机,已经极大提升了创建虚拟机速度,但创建完成后,不可避免还是要配置下网络,因为服务器要使用静态IP,克隆出的机器需要重新设置新的IP地址,有没有连这一步都省了方法呢?有,就是Cloud-Init 创建虚拟机模板 1. 下载ubuntu-clo…

LLM增强检索---GraphRAG + LangGraph项目实战

专栏&#xff1a;大模型垂直应用技术​ ​​​​ 个人主页:云端筑梦狮 大模型应用落地亟需解决的核心问题有一个是&#xff1a;如何与私域数据交互。私域数据主要的问题是&#xff1a;需要有效地将企业数据整合进大语言模型中。由于大模型的上下文处理能力有限&#xff0c;一…

修改SSH端口实战

本次实战主要涉及SSH端口的修改和配置。首先&#xff0c;对master、slave1和slave2三台云主机的SSH配置文件进行修改&#xff0c;指定新的SSH端口&#xff0c;并重启SSH服务。接着&#xff0c;通过FinalShell工具修改连接端口&#xff0c;实现SSH连接到三台云主机。然后&#x…

PyTorch中的permute, transpose, view, reshape和flatten函数详解(已解决)

1.permute permute函数用于重新排列张量的维度。它接受一个元组作为参数&#xff0c;表示新的维度顺序。例如&#xff0c;如果我们有一个形状为(2, 3)的二维张量&#xff0c;我们可以使用permute函数将其维度重新排列为(3, 2)&#xff0c;如下所示&#xff1a; >>> …

开疆智能ModbusTCP转EtherCAT网关连接甘纳数据采集系统配置案例

本案例是通过开疆智能研发的ModbusTCP转EtherCAT网关连接ModbusTCP主站与甘纳数据采集系统的配置案例&#xff0c;具体配置如下。 配置过程 首先设置ModbusTCP主站&#xff0c;这里以信捷PLC为例 IP设定 要走Modbus-TCP协议&#xff0c;要把设备IP设在同一网段且地址不同&am…

Neo4j常用语法-path

在 Neo4j 中&#xff0c;Path&#xff08;路径&#xff09; 是连接两个或多个节点的关系序列&#xff0c;是图查询的核心概念之一。理解 Path 的用法对于复杂图分析至关重要 关键特性&#xff1a; 有向性&#xff1a;路径中的关系具有方向 可变长度&#xff1a;路径可以包含 0 …

从 Cluely 融资看“AI 协同开发”认证:软件考试应该怎么升级?

AI 工具大爆发&#xff0c;软件考试却还停在“纯手写”时代&#xff1f; 2025 年 6 月&#xff0c;一个标语写着 “Cheat on Everything”&#xff08;对&#xff0c;意思就是“什么都能开挂”&#xff09;的 AI 初创公司——Cluely&#xff0c;正式宣布获得由 a16z 领投的 1 …