Learning PySyft

A PySyft learning series

1. Environment setup

PySyft is not yet a stable framework, so this series uses the classic 0.2.9 release, since the 0.2.x line still has complete documentation on GitHub.

(1) python-3.7

The latest 0.2.x release (0.2.9) does not work with Python 3.6, which is often recommended online; some of its dependencies will report that they require python>=3.7.x.

conda create -n env_pysyft python=3.7

(2) pytorch-1.4.0

conda activate env_pysyft
conda install pytorch==1.4.0 torchvision==0.5.0 cpuonly -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/

(3) syft-0.2.9

pip install syft==0.2.9  -i https://pypi.tuna.tsinghua.edu.cn/simple

(4) protobuf-3.20.3

# syft needs an older protobuf, while pytorch-1.4.0 pulls in a newer one
# the required range is 3.19.0 <= protobuf <= 3.20.x
pip uninstall protobuf
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple protobuf==3.20.3

At this point the pysyft-0.2.9 environment is fully set up.
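
As a quick optional sanity check (a minimal sketch, not part of the original setup), the installed versions can be printed and compared against the ones above:

# optional: confirm the installed versions (expected: torch 1.4.0, syft 0.2.9, protobuf 3.20.3)
import pkg_resources
for pkg in ["torch", "syft", "protobuf"]:
    print(pkg, pkg_resources.get_distribution(pkg).version)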

2. Demo verification

import torch
from torch.nn import Parameter
import torch.nn as nn
import torch.nn.functional as F

import syft as sy
# create a local hook that extends torch so tensors can be sent to remote workers
hook = sy.TorchHook(torch)
# create a virtual worker named "bob" that will hold the data
bob = sy.VirtualWorker(hook, id="bob")

x = torch.tensor([1., 2, 3, 4, 5], requires_grad=True)
y = torch.tensor([1., 1, 1, 1, 1], requires_grad=True)
# x_ptr is only a pointer; the data of x now lives on bob
x_ptr = x.send(bob)
# y_ptr is only a pointer; the data of y now lives on bob
y_ptr = y.send(bob)
# define the objective to optimize (computed remotely on bob)
z = (x_ptr + y_ptr).sum()
# issue the backward pass remotely
z.backward()
# fetch the value of the objective back from bob
print(z.get())
# fetch x (together with its gradient) back from bob
x = x_ptr.get()
print(x.grad)

3. A simple federated learning example

import torch
from torch import nn, optim
import syft as sy
from syft.federated.floptimizer import Optims

hook = sy.TorchHook(torch)
# create a couple of workers
workers = ['bob', 'alice']
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")


# A Toy Dataset
data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)

# get pointers to training data on each worker by
# sending some training data to bob and alice
data_bob = data[0:2]
target_bob = target[0:2]

data_alice = data[2:]
target_alice = target[2:]

# Initialize A Toy Model
model = nn.Linear(2,1)

data_bob = data_bob.send(bob)
data_alice = data_alice.send(alice)
target_bob = target_bob.send(bob)
target_alice = target_alice.send(alice)

# organize pointers into a list
datasets = [(data_bob, target_bob), (data_alice, target_alice)]

optims = Optims(workers, optim=optim.Adam(params=model.parameters(), lr=0.1))


def train():
    # Training Logic
    for iter in range(10):

        # NEW) iterate through each worker's dataset
        for data, target in datasets:
            # NEW) send model to the correct worker
            model.send(data.location)

            # Call the optimizer for the worker using get_optim
            opt = optims.get_optim(data.location.id)
            # print(data.location.id)

            # 1) erase previous gradients (if they exist)
            opt.zero_grad()

            # 2) make a prediction
            pred = model(data)

            # 3) calculate how much we missed
            loss = ((pred - target) ** 2).sum()

            # 4) figure out which weights caused us to miss
            loss.backward()

            # 5) change those weights
            opt.step()

            # NEW) get model (with gradients)
            model.get()

            # 6) print our progress
            print(loss.get().data)  # NEW) slight edit... need to call .get() on loss

if __name__ == '__main__':
    train()

4. FedAvg

import torch
import copy
import syft as sy
from torch import nn, optim

hook = sy.TorchHook(torch)

# create a couple of workers
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
secure_worker = sy.VirtualWorker(hook, id="secure_worker")


# A Toy Dataset
data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)

# get pointers to training data on each worker by
# sending some training data to bob and alice
bobs_data = data[0:2].send(bob)
bobs_target = target[0:2].send(bob)

alices_data = data[2:].send(alice)
alices_target = target[2:].send(alice)

# Initialize A Toy Model
model = nn.Linear(2,1)

bobs_model = model.copy().send(bob)
alices_model = model.copy().send(alice)

bobs_opt = optim.SGD(params=bobs_model.parameters(), lr=0.1)
alices_opt = optim.SGD(params=alices_model.parameters(), lr=0.1)

iterations = 10
worker_iters = 5

for a_iter in range(iterations):

    bobs_model = model.copy().send(bob)
    alices_model = model.copy().send(alice)

    bobs_opt = optim.SGD(params=bobs_model.parameters(), lr=0.1)
    alices_opt = optim.SGD(params=alices_model.parameters(), lr=0.1)

    for wi in range(worker_iters):
        # Train Bob's Model
        bobs_opt.zero_grad()
        bobs_pred = bobs_model(bobs_data)
        bobs_loss = ((bobs_pred - bobs_target) ** 2).sum()
        bobs_loss.backward()

        bobs_opt.step()
        bobs_loss = bobs_loss.get().data

        # Train Alice's Model
        alices_opt.zero_grad()
        alices_pred = alices_model(alices_data)
        alices_loss = ((alices_pred - alices_target) ** 2).sum()
        alices_loss.backward()

        alices_opt.step()
        alices_loss = alices_loss.get().data

    # move both local models to the trusted secure_worker and average them there
    alices_model.move(secure_worker)
    bobs_model.move(secure_worker)
    with torch.no_grad():
        model.weight.set_(((alices_model.weight.data + bobs_model.weight.data) / 2).get())
        model.bias.set_(((alices_model.bias.data + bobs_model.bias.data) / 2).get())

    print("Bob:" + str(bobs_loss) + " Alice:" + str(alices_loss))


preds = model(data)
loss = ((preds - target) ** 2).sum()

print(preds)
print(target)
print(loss.data)

5. Training a CNN with FedSGD

Differences from single-machine training:

  1. The dataloader

    # in practice this splits MNIST into two shards: one sent to bob, one to alice
    # <-- this is now a FederatedDataLoader
    federated_train_loader = sy.FederatedDataLoader(
        datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ]))
        .federate((bob, alice)),  # <-- NEW: we distribute the dataset across all the workers
        batch_size=args.batch_size, shuffle=True, **kwargs)
  2. Before training on each batch, the model is sent to the corresponding client, and the actual training runs on that client (see the snippet below).

# for batch_idx, (data, target) ...
# this loop effectively merges the two MNIST shards into one big federated dataset
# and then iterates over it batch by batch,
# so one global epoch looks like:
# epoch_1 && batch_id_1 && model.send(bob) && loss.backward()
# epoch_1 && batch_id_2 && model.send(bob) && loss.backward()
# ...
# epoch_1 && batch_id_1 && model.send(alice) && loss.backward()
# epoch_1 && batch_id_2 && model.send(alice) && loss.backward()
# ...
# <-- now it is a distributed dataset
for batch_idx, (data, target) in enumerate(federated_train_loader):
    model.send(data.location)  # <-- NEW: send the model to the right location
    data, target = data.to(device), target.to(device)

Full code

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import syft as sy  # <-- NEW: import the PySyft library
hook = sy.TorchHook(torch)  # <-- NEW: hook PyTorch, i.e. add extra functionalities to support Federated Learning
bob = sy.VirtualWorker(hook, id="bob")  # <-- NEW: define remote worker bob
alice = sy.VirtualWorker(hook, id="alice")  # <-- NEW: and alice
epochs = 10

class Arguments():
    def __init__(self):
        self.batch_size = 64
        self.test_batch_size = 1000
        self.epochs = epochs
        self.lr = 0.01
        self.momentum = 0.5
        self.no_cuda = False
        self.seed = 1
        self.log_interval = 30
        self.save_model = False

args = Arguments()

use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

federated_train_loader = sy.FederatedDataLoader(  # <-- this is now a FederatedDataLoader
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
    .federate((bob, alice)),  # <-- NEW: we distribute the dataset across all the workers, it's now a FederatedDataset
    batch_size=args.batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(args, model, device, federated_train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(federated_train_loader):  # <-- now it is a distributed dataset
        model.send(data.location)  # <-- NEW: send the model to the right location
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        model.get()  # <-- NEW: get the model back
        if batch_idx % args.log_interval == 0:
            loss = loss.get()  # <-- NEW: get the loss back
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(federated_train_loader) * args.batch_size,
                100. * batch_idx / len(federated_train_loader), loss.item()))


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr)  # TODO momentum is not supported at the moment

for epoch in range(1, args.epochs + 1):
    train(args, model, device, federated_train_loader, optimizer, epoch)
    test(args, model, device, test_loader)

if (args.save_model):
    torch.save(model.state_dict(), "mnist_cnn.pt")

6. Training a CNN with FedAvg

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np

import syft as sy  # <-- NEW: import the PySyft library
hook = sy.TorchHook(torch)  # <-- NEW: hook PyTorch, i.e. add extra functionalities to support Federated Learning
bob = sy.VirtualWorker(hook, id="bob")  # <-- NEW: define remote worker bob
alice = sy.VirtualWorker(hook, id="alice")  # <-- NEW: and alice

workers = {}
workers['bob'] = bob
workers['alice'] = alice

secure_worker = sy.VirtualWorker(hook, id="secure_worker")

epochs = 10
local_epochs = 1

class Arguments():
    def __init__(self):
        self.batch_size = 64
        self.test_batch_size = 1000
        self.epochs = epochs
        self.lr = 0.01
        self.momentum = 0.5
        self.no_cuda = False
        self.seed = 1
        self.log_interval = 30
        self.save_model = False

args = Arguments()

use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

federated_train_loader = sy.FederatedDataLoader(  # <-- this is now a FederatedDataLoader
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
    .federate((bob, alice)),  # <-- NEW: we distribute the dataset across all the workers, it's now a FederatedDataset
    batch_size=args.batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(args, models, device, federated_train_loader, epoch):
    for client_id in federated_train_loader.workers:
        # send a copy of the global model to the current client
        model = models['global_model'].copy().send(workers[client_id])
        model.train()
        optimizer = optim.SGD(params=model.parameters(), lr=0.1)
        # build a dataloader over this client's shard only
        one_client_train_loader = federated_train_loader.federated_dataset[client_id]
        dataset = sy.BaseDataset(one_client_train_loader.data, one_client_train_loader.targets)
        dataset = sy.FederatedDataset([dataset])
        one_client_train_loader = sy.FederatedDataLoader(dataset, batch_size=32, shuffle=False, drop_last=False)
        # local training on the current client
        for local_epoch in range(local_epochs):
            loss_per_local_epoch = []
            for batch_idx, (data, target) in enumerate(one_client_train_loader):  # <-- now it is a distributed dataset
                optimizer.zero_grad()
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = F.nll_loss(output, target)
                loss.backward()
                optimizer.step()
                loss = loss.get()
                loss_per_local_epoch.append(loss.item())
            # print once per client and local epoch
            print('client: {} , Train Global_Epoch: {}, Local_Epoch: {}, avg_loss: {} '.format(
                client_id, epoch, local_epoch, np.mean(loss_per_local_epoch)))
        models[client_id] = model.copy().get()
    return models


def fedAvg(arg, models, federated_train_loader):
    global_model = models['global_model']
    global_state_dict = global_model.state_dict()

    # average each parameter tensor over all clients
    for key in global_state_dict.keys():
        one_layer_weight_or_bias = global_state_dict[key].zero_()

        for client_id in federated_train_loader.workers:
            model = models[client_id]
            one_layer_weight_or_bias += model.state_dict()[key].data.clone()

        global_state_dict[key] = one_layer_weight_or_bias / len(federated_train_loader.workers)

    global_model.load_state_dict(global_state_dict)
    models['global_model'] = global_model
    return models


def test(args, models, device, test_loader):
    model = models['global_model']
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


global_model = Net().to(device)
global_optimizer = optim.SGD(global_model.parameters(), lr=args.lr)  # TODO momentum is not supported at the moment
bobs_model = global_model.copy().send(bob)
alices_model = global_model.copy().send(alice)
models = {}
models['bob'] = bobs_model
models['alice'] = alices_model
models['global_model'] = global_model


for epoch in range(1, args.epochs + 1):
    models = train(args, models, device, federated_train_loader, epoch)
    models = fedAvg(args, models, federated_train_loader)
    test(args, models, device, test_loader)

if (args.save_model):
    torch.save(global_model.state_dict(), "mnist_cnn.pt")

7. Implementing FedAvg with secure multi-party computation (SMPC)
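
Before the full script, here is a minimal standalone sketch of the SMPC primitive that the aggregation step below relies on: fix_precision() encodes floats as fixed-point integers, and share(...) splits them into additive secret shares held by bob and alice, with james acting as the crypto provider. Arithmetic runs on the shares, and the plaintext only reappears after get().float_precision().

import torch
import syft as sy

hook = sy.TorchHook(torch)
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
james = sy.VirtualWorker(hook, id="james")  # crypto provider

x = torch.tensor([0.1, 0.2, 0.3])
y = torch.tensor([0.4, 0.5, 0.6])

# encode as fixed precision, then split into additive shares held by bob and alice
x_enc = x.fix_precision().share(bob, alice, crypto_provider=james)
y_enc = y.fix_precision().share(bob, alice, crypto_provider=james)

# the addition is computed on the shares; neither worker sees the plaintext values
z_enc = x_enc + y_enc

# reconstruct and decode locally
print(z_enc.get().float_precision())  # approximately [0.5, 0.7, 0.9]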

import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader


class Parser:
    """Parameters for training"""

    def __init__(self):
        self.epochs = 10
        self.lr = 0.001
        self.test_batch_size = 8
        self.batch_size = 8
        self.log_interval = 10
        self.seed = 1
        self.download_boston_housing = False


args = Parser()

torch.manual_seed(args.seed)
kwargs = {}

# download the dataset (only needed once; requires tensorflow)
if args.download_boston_housing:
    import tensorflow as tf

    boston_housing = tf.keras.datasets.boston_housing
    pickle_data = boston_housing.load_data()
    pickle.dump(pickle_data, open('./data/BostonHousing/boston_housing.pickle', 'wb'))

# load the dataset
with open('./data/BostonHousing/boston_housing.pickle', 'rb') as f:
    ((X, y), (X_test, y_test)) = pickle.load(f)

X = torch.from_numpy(X).float()
y = torch.from_numpy(y).float()
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test).float()
# preprocessing
mean = X.mean(0, keepdim=True)
dev = X.std(0, keepdim=True)
mean[:, 3] = 0.  # the feature at column 3 is binary,
dev[:, 3] = 1.   # so we don't standardize it
X = (X - mean) / dev
X_test = (X_test - mean) / dev
train = TensorDataset(X, y)
test = TensorDataset(X_test, y_test)
train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = DataLoader(test, batch_size=args.test_batch_size, shuffle=True, **kwargs)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(13, 32)
        self.fc2 = nn.Linear(32, 24)
        self.fc3 = nn.Linear(24, 1)

    def forward(self, x):
        x = x.view(-1, 13)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


bobs_model = Net()
alices_model = Net()

bobs_optimizer = optim.SGD(bobs_model.parameters(), lr=args.lr)
alices_optimizer = optim.SGD(alices_model.parameters(), lr=args.lr)

models = [bobs_model, alices_model]
params = [list(bobs_model.parameters()), list(alices_model.parameters())]
optimizers = [bobs_optimizer, alices_optimizer]


import syft as sy

hook = sy.TorchHook(torch)
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
james = sy.VirtualWorker(hook, id="james")

compute_nodes = [bob, alice]

remote_dataset = (list(), list())

# distribute the training batches across the remote workers bob and alice
for batch_idx, (data, target) in enumerate(train_loader):
    worker_i_index = batch_idx % len(compute_nodes)
    worker_i = compute_nodes[worker_i_index]
    data = data.send(worker_i)
    target = target.send(worker_i)
    remote_dataset[worker_i_index].append((data, target))


# one training step on a single client, using a single batch of data
def update(data, target, model, optimizer):
    model.send(data.location)
    optimizer.zero_grad()
    pred = model(data)
    loss = F.mse_loss(pred.view(-1), target)
    loss.backward()
    optimizer.step()
    return model


# full training loop for one epoch
def train(epoch):
    # iterate over the batches held by the workers
    for data_index in range(len(remote_dataset[0]) - 1):
        # update remote models
        # for each remote worker, take one batch and do one training step
        for remote_index in range(len(compute_nodes)):
            data, target = remote_dataset[remote_index][data_index]
            models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])

        # encrypted aggregation
        new_params = list()
        # aggregate each parameter tensor separately
        for param_i in range(len(params[0])):
            spdz_params = list()
            # collect this parameter from every worker
            for remote_index in range(len(compute_nodes)):
                # encrypt each worker's parameter with additive secret sharing
                spdz_params.append(params[remote_index][param_i].fix_precision().share(bob, alice, crypto_provider=james).get())

            # spdz_params now holds one encrypted tensor per worker
            new_param = (spdz_params[0] + spdz_params[1]).get().float_precision() / 2
            new_params.append(new_param)

        # cleanup
        with torch.no_grad():
            for model in params:
                for param in model:
                    param *= 0

            for model in models:
                model.get()

            # write the aggregated parameters back into every worker's model
            for remote_index in range(len(compute_nodes)):
                for param_index in range(len(params[remote_index])):
                    params[remote_index][param_index].set_(new_params[param_index])


def test():
    models[0].eval()
    test_loss = 0
    for data, target in test_loader:
        output = models[0](data)
        test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item()  # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability

    test_loss /= len(test_loader.dataset)

    print('Test set: Average loss: {:.4f}\n'.format(test_loss))


import time
t = time.time()

for epoch in range(args.epochs):
    print(f"Epoch {epoch + 1}")
    train(epoch)
    test()

total_time = time.time() - t
print('Total', round(total_time, 2), 's')


test()