# run_classification.py
# Few-shot classification experiment script (strategies: maml, proto, ft).
import argparse
import os
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.utils.data
from datasets import PlainClsDataset, MetaClsDataset
from models.clsnet import ClassificationNet
from models.maml import MetaClsLearner
def _build_parser():
    """Create the command-line parser for the few-shot classification script.

    Every option takes a typed value and has a default, so the options are
    declared as a table of ``(flag, type, default, help)`` rows and registered
    in a single loop. Row order is the order shown by ``--help``.
    """
    option_table = (
        # data-related arguments
        ('--seed', int, 2023, 'random seed.'),
        ('--dataset', str, '', 'name of corpus.'),
        ('--data_path', str, '', 'path to load dataset.'),
        ('--batch_size', int, 200, 'input batch size for training.'),
        # optimization-related arguments
        ('--strategy', str, 'proto', 'could be one of [maml, proto, ft].'),
        ('--arch', str, 'conv', 'could be one of [conv, mlp].'),
        ('--epochs', int, 100, 'number of epochs to train.'),
        ('--optimizer', str, 'adam', 'choice of optimizer.'),
        ('--lr', float, 0.005, 'learning rate.'),
        ('--grad_clip', float, 20.0, 'gradient clipping.'),
        ('--weight_decay', float, 1.2e-6, 'some l2 regularization.'),
        # few-shot setting arguments
        ('--meta_batch_size', int, 5, 'number of tasks processed at each update'),
        ('--meta_lr', float, 3e-4, 'meta-level outer learning rate'),
        ('--update_lr', float, 3e-3, 'task-level inner update learning rate'),
        ('--update_step', int, 5, 'task-level inner update steps'),
        ('--update_step_test', int, 10, 'update steps for fine-tuning'),
        ('--num_ways', int, 5, 'number of classes for each N-way-K-shot task.'),
        ('--num_shots', int, 10, 'number of support samples in each N-way-K-shot task.'),
        ('--num_queries', int, 15, 'number of query samples in each N-way-K-shot task.'),
    )
    cli = argparse.ArgumentParser(description='run experiment for few-shot classification.')
    for flag, value_type, fallback, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)
    return cli


# Parsed once at module level; main() reads this global namespace.
parser = _build_parser()
args = parser.parse_args()
def _describe_model(model):
    """Print the model followed by its trainable-parameter count."""
    print('\nModel : {}'.format(model))
    trainable = (p for p in model.parameters() if p.requires_grad)
    # This sums element counts per parameter, although the label says "tensors".
    num = sum(np.prod(p.shape) for p in trainable)
    print('Total trainable tensors:', num)


def _episode_labels(num_ways, per_class, device):
    """Return the label vector ``[0]*per_class + [1]*per_class + ...`` on ``device``.

    Assumes the samples of an episode are grouped class by class -- TODO confirm
    against MetaClsDataset.
    """
    return torch.arange(num_ways).unsqueeze(1).repeat(1, per_class).view(-1).to(device)


def _report_test_accuracy(accs, num_tasks):
    """Print mean and standard deviation of the per-task test accuracies."""
    accs = np.array(accs)
    print("The average accuracy on {} test tasks: {:.6f} \u00B1 {:.6f}".format(
        num_tasks, accs.mean(), accs.std()))


def _checkpoint_path(cfg, save_dir, tag):
    """Build the checkpoint file path for the current dataset/shots/strategy."""
    return os.path.join(
        save_dir,
        f'{cfg.dataset}_{cfg.num_shots}shot_{cfg.strategy}_{tag}_model.pth'
    )


def _build_meta_dataset(cfg, mode):
    """Create the episodic dataset for ``mode`` ('train' or 'test').

    NOTE(review): ``cfg.num_queries`` is not forwarded to MetaClsDataset, so the
    dataset presumably uses its own query count -- confirm it matches
    ``--num_queries``, which the reshapes below rely on.
    """
    return MetaClsDataset(
        dataset_name=cfg.dataset,
        data_path=cfg.data_path,
        mode=mode,
        num_ways=cfg.num_ways,
        num_shots=cfg.num_shots
    )


def _run_maml(cfg, device, test_loader, save_dir):
    """Meta-train with MAML, then evaluate by fine-tuning on each test task."""
    # Single source of truth for the epoch count and the final checkpoint name.
    num_meta_epochs = 8

    train_set = _build_meta_dataset(cfg, 'train')
    # drop_last: each batch is reshaped to exactly `meta_batch_size` tasks below,
    # so a trailing partial batch would raise or be reshaped across task boundaries.
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=cfg.meta_batch_size, shuffle=True, num_workers=4,
        drop_last=True
    )
    model = ClassificationNet(
        input_dim=len(train_set.vocab),
        num_classes=cfg.num_ways,
        arch=cfg.arch
    )
    model.to(device)
    _describe_model(model)

    learner = MetaClsLearner(cfg, model, device)
    print('\nApplied few-shot algorithm: model-agnostic meta-learning')
    print('\n===> Meta-training stage ===<')
    for epoch in range(num_meta_epochs):
        print('\n===> Epoch: {} <==='.format(epoch))
        train_loss = []
        train_acc = []
        model.train()
        for idx, (sup_batches, qry_batches) in enumerate(tqdm(train_loader)):
            sup_batches, qry_batches = sup_batches.to(device), qry_batches.to(device)
            sup_batches = sup_batches.view(cfg.meta_batch_size, cfg.num_ways * cfg.num_shots, -1)
            qry_batches = qry_batches.view(cfg.meta_batch_size, cfg.num_ways * cfg.num_queries, -1)
            loss_qry, acc_qry = learner(sup_batches, qry_batches)
            train_loss.append(loss_qry)
            train_acc.append(acc_qry)
            if (idx + 1) % 100 == 0:
                print('Avg Train Loss: {}, Avg Train Acc: {}'.format(np.mean(train_loss), np.mean(train_acc)))
        # save model weights once per epoch
        torch.save(
            learner.model.state_dict(),
            _checkpoint_path(cfg, save_dir, f'epoch{epoch + 1}')
        )

    print('\n===> Meta-testing stage ===<')
    ckpt_path = _checkpoint_path(cfg, save_dir, f'epoch{num_meta_epochs}')
    learner.model.load_state_dict(torch.load(ckpt_path, map_location=device))
    accs_all_test = []
    for (sup_batch, qry_batch) in tqdm(test_loader):
        sup_batch = sup_batch.squeeze(0).to(device)
        qry_batch = qry_batch.squeeze(0).to(device)
        sup_batch = sup_batch.view(cfg.num_ways * cfg.num_shots, -1)
        qry_batch = qry_batch.view(cfg.num_ways * cfg.num_queries, -1)
        acc = learner.finetunning(sup_batch, qry_batch)
        accs_all_test.append(acc.item())
    _report_test_accuracy(accs_all_test, len(test_loader))


def _run_proto(cfg, device, test_loader, save_dir):
    """Train a prototypical network episodically, then evaluate on test tasks."""
    train_set = _build_meta_dataset(cfg, 'train')
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=1, shuffle=True, num_workers=0
    )
    model = ClassificationNet(
        input_dim=len(train_set.vocab),
        num_classes=cfg.num_ways,
        arch=cfg.arch
    )
    model.to(device)
    _describe_model(model)

    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()
    ckpt_path = _checkpoint_path(cfg, save_dir, 'best')
    # Query labels are identical for every episode, so build them once.
    qry_labels = _episode_labels(cfg.num_ways, cfg.num_queries, device)

    print('\nApplied few-shot algorithm: prototypical network')
    print('\n===> Meta-training stage ===<')
    best_train_acc = 0.
    checkpoint_saved = False
    for epoch in range(5):
        print('\n===> Epoch: {} <==='.format(epoch))
        train_loss = []
        train_acc = []
        model.train()
        for idx, (sup_batch, qry_batch) in enumerate(tqdm(train_loader)):
            sup_batch = sup_batch.squeeze(0).to(device)
            qry_batch = qry_batch.squeeze(0).to(device)
            sup_feats, _ = model(sup_batch.view(cfg.num_ways * cfg.num_shots, -1))
            # Class prototype = mean support feature of each class.
            sup_proto = sup_feats.view(cfg.num_ways, cfg.num_shots, -1).mean(1)
            qry_feats, _ = model(qry_batch.view(cfg.num_ways * cfg.num_queries, -1))
            # Negative squared Euclidean distance to each prototype acts as the logit.
            qry_dists = -(qry_feats.unsqueeze(1) - sup_proto.unsqueeze(0)).pow(2).sum(-1)
            loss = criterion(qry_dists, qry_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
            num_correct = torch.eq(qry_dists.argmax(-1), qry_labels).float().sum()
            acc = num_correct / (cfg.num_ways * cfg.num_queries)
            train_acc.append(acc.item())
            if (idx + 1) % 500 == 0:
                # Accuracy is averaged over the last 500 episodes only
                # (train_acc is cleared below); the loss average is per epoch.
                window_acc = np.mean(train_acc)
                if window_acc > best_train_acc:
                    best_train_acc = window_acc
                    torch.save(model.state_dict(), ckpt_path)
                    checkpoint_saved = True
                print('Avg Train Loss: {}, Avg Train Acc: {}'.format(np.mean(train_loss), np.mean(train_acc)))
                train_acc.clear()
    if not checkpoint_saved:
        # Fewer than 500 episodes per epoch (or no improvement) means nothing was
        # written above; save the final weights so the test stage never fails on a
        # missing file or silently loads a stale checkpoint from an earlier run.
        torch.save(model.state_dict(), ckpt_path)

    print('\n===> Meta-testing stage ===<')
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()
    accs_all_test = []
    for (sup_batch, qry_batch) in tqdm(test_loader):
        sup_batch = sup_batch.squeeze(0).to(device)
        qry_batch = qry_batch.squeeze(0).to(device)
        with torch.no_grad():
            sup_feats, _ = model(sup_batch.view(cfg.num_ways * cfg.num_shots, -1))
            qry_feats, _ = model(qry_batch.view(cfg.num_ways * cfg.num_queries, -1))
        sup_proto = sup_feats.view(cfg.num_ways, cfg.num_shots, -1).mean(1)
        qry_dists = (qry_feats.unsqueeze(1) - sup_proto.unsqueeze(0)).pow(2).sum(-1)
        # Nearest prototype wins (distances are not negated here, hence argmin).
        num_correct = torch.eq(qry_dists.argmin(-1), qry_labels).float().sum()
        acc = num_correct / (cfg.num_ways * cfg.num_queries)
        accs_all_test.append(acc.item())
    _report_test_accuracy(accs_all_test, len(test_loader))


def _run_ft(cfg, device, test_loader, save_dir):
    """Pre-train a plain classifier, then fit a fresh linear head per test task."""
    train_set = PlainClsDataset(cfg.dataset, cfg.data_path, 'train')
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=4
    )
    model = ClassificationNet(
        input_dim=len(train_set.vocab),
        num_classes=len(train_set.class_id),
        arch=cfg.arch
    )
    model.to(device)
    _describe_model(model)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay
    )
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer=optimizer,
        gamma=0.2,
        step_size=30
    )
    criterion = nn.CrossEntropyLoss()
    ckpt_path = _checkpoint_path(cfg, save_dir, 'best')

    print('\nApplied few-shot algorithm: baseline')
    print('\n===> Pre-training stage ===<')
    best_train_acc = 0.
    checkpoint_saved = False
    for epoch in tqdm(range(cfg.epochs)):
        train_loss = []
        train_acc = []
        model.train()
        for batch_data, batch_labels in train_loader:
            batch_data = batch_data.to(device)
            batch_labels = batch_labels.to(device)
            _, output = model(batch_data)
            loss = criterion(output, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
            num_correct = torch.eq(output.argmax(-1), batch_labels).float().sum()
            acc = num_correct / batch_labels.shape[0]
            train_acc.append(acc.item())
        print('Epoch: {}, Avg Train Loss: {}, Avg Train Acc: {}'.format(
            epoch, np.mean(train_loss), np.mean(train_acc)))
        if np.mean(train_acc) > best_train_acc:
            best_train_acc = np.mean(train_acc)
            torch.save(model.state_dict(), ckpt_path)
            checkpoint_saved = True
        scheduler.step()
    if not checkpoint_saved:
        # e.g. --epochs 0: make sure the checkpoint loaded below belongs to this run.
        torch.save(model.state_dict(), ckpt_path)

    print("\n===> Fine-tuning on meta-test set ===<")
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    sup_labels = _episode_labels(cfg.num_ways, cfg.num_shots, device)
    qry_labels = _episode_labels(cfg.num_ways, cfg.num_queries, device)
    model.eval()
    accs_all_test = []
    for (sup_batch, qry_batch) in tqdm(test_loader):
        sup_batch = sup_batch.squeeze(0).to(device)
        qry_batch = qry_batch.squeeze(0).to(device)
        # The backbone is frozen: features are extracted once without gradients.
        with torch.no_grad():
            sup_feats, _ = model(sup_batch.view(cfg.num_ways * cfg.num_shots, -1))
            qry_feats, _ = model(qry_batch.view(cfg.num_ways * cfg.num_queries, -1))
        # Size the head from the actual feature width instead of a per-arch
        # constant (was hard-coded to 2984, which only fits the CNN backbone).
        temp_clf = nn.Linear(sup_feats.shape[-1], cfg.num_ways).to(device)
        temp_optim = torch.optim.Adam(temp_clf.parameters())
        for _ in range(100):
            sup_outputs = temp_clf(sup_feats)
            loss = criterion(sup_outputs, sup_labels)
            temp_optim.zero_grad()
            loss.backward()
            temp_optim.step()
        qry_outputs = temp_clf(qry_feats)
        num_correct = torch.eq(qry_outputs.argmax(-1), qry_labels).float().sum()
        acc = num_correct / (cfg.num_ways * cfg.num_queries)
        accs_all_test.append(acc.item())
        del temp_optim, temp_clf
    _report_test_accuracy(accs_all_test, len(test_loader))


def main():
    """Run one few-shot classification experiment configured by the global ``args``.

    Seeds NumPy and PyTorch, builds the meta-test loader and the checkpoint
    directory, then dispatches on ``args.strategy``:

    * ``'maml'``  -- model-agnostic meta-learning,
    * ``'proto'`` -- prototypical network,
    * ``'ft'``    -- supervised pre-training plus a per-task linear head.

    Raises:
        NotImplementedError: if ``args.strategy`` is none of the above.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)

    # data loading pipeline
    test_set = _build_meta_dataset(args, 'test')
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=1, shuffle=False, num_workers=2
    )

    # define the directory to save models
    arch_dirs = {'conv': 'CNN', 'mlp': 'MLP'}
    save_dir = f'./checkpoints/classification/{arch_dirs.get(args.arch, args.arch)}'
    os.makedirs(save_dir, exist_ok=True)

    # three typical few-shot learning algorithms
    runners = {'maml': _run_maml, 'proto': _run_proto, 'ft': _run_ft}
    runner = runners.get(args.strategy)
    if runner is None:
        raise NotImplementedError(f"The '{args.strategy}' strategy has not been implemented!")
    runner(args, device, test_loader, save_dir)
# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()