-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathddp_trainer.py
More file actions
159 lines (144 loc) · 6.16 KB
/
ddp_trainer.py
File metadata and controls
159 lines (144 loc) · 6.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""Trainer Module to assist with Distributed Data Parallel Training (DDP).
"""
import os
import torch
from time import time
from typing import Callable
from torch.utils.data import DataLoader
import numpy as np
from torch.nn.parallel import DistributedDataParallel as DDP
from math import sqrt
class Trainer:
    """Trainer class to assist with DistributedDataParallel (DDP) training.

    Expects to be launched via ``torchrun``, which supplies the
    ``LOCAL_RANK`` and ``RANK`` environment variables. Training state is
    periodically snapshotted so an interrupted run can be resumed from
    the same save path.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        valid_data: DataLoader,
        loss_func: Callable,  # from torch.nn.functional.*
        optimizer: torch.optim.Optimizer,
        max_run_time: float,
        snapshot_name: str,
    ) -> None:
        """Set up the trainer, resume from a snapshot if one exists, wrap in DDP.

        Args:
            model: network to train; moved onto this process's local device.
            train_data: training loader. Its sampler must support
                ``set_epoch`` (e.g. DistributedSampler) — called every epoch.
            valid_data: validation loader. NOTE(review): assumed to be
                sharded per-rank like train_data — confirm against caller.
            loss_func: callable ``loss_func(output, targets) -> Tensor``.
            optimizer: optimizer bound to ``model``'s parameters.
            max_run_time: wall-clock training budget, in hours.
            snapshot_name: file name of the snapshot under ``training_saves/``.
        """
        # Torchrun assigns many environment variables
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        self.model = model.to(self.local_rank)
        self.train_data = train_data
        self.valid_data = valid_data
        self.loss_func = loss_func
        self.optimizer = optimizer
        # Hours to seconds, training will stop at this time
        self.max_run_time = max_run_time * 60**2
        self.save_path = "training_saves/" + snapshot_name
        # Ensure the snapshot directory exists so torch.save doesn't fail.
        os.makedirs(os.path.dirname(self.save_path) or ".", exist_ok=True)
        self.epochs_run = 0    # current epoch tracker
        self.run_time = 0.0    # current run_time tracker (seconds)
        self.train_loss_history = list()
        self.valid_loss_history = list()
        self.epoch_times = list()
        # float("inf") rather than np.Inf: np.Inf was removed in NumPy 2.0.
        self.lowest_loss = float("inf")
        self.train_loss = float("inf")
        self.valid_loss = float("inf")
        # Loading in existing training session if the save destination already exists
        if os.path.exists(self.save_path):
            print("Loading snapshot")
            self._load_snapshot(self.save_path)
            if self.train_loss_history:
                self.train_loss = self.train_loss_history[-1]
                self.valid_loss = self.valid_loss_history[-1]
        # Key DDP Wrapper
        self.model = DDP(self.model, device_ids=[self.local_rank])

    def _load_snapshot(self, snapshot_path):
        """Restore model weights and training bookkeeping from a snapshot file.

        Must run BEFORE the DDP wrap: the snapshot stores the bare
        (``module``) state dict, and weights are mapped onto this rank's
        device.
        """
        loc = f"cuda:{self.local_rank}"
        snapshot = torch.load(snapshot_path, map_location=loc)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        self.run_time = snapshot['RUN_TIME']
        self.train_loss_history = snapshot['TRAIN_HISTORY']
        self.valid_loss_history = snapshot['VALID_HISTORY']
        self.epoch_times = snapshot['EPOCH_TIMES']
        self.lowest_loss = snapshot['LOWEST_LOSS']
        print(f"Resuming training from save at Epoch {self.epochs_run}")

    def _calc_validation_loss(self, source, targets) -> float:
        """Compute the loss on one validation batch without updating weights.

        Returns the scalar loss as a Python float; the model is restored
        to train mode before returning.
        """
        self.model.eval()
        # no_grad: skip building the autograd graph — validation never
        # backpropagates, so this saves memory and compute.
        with torch.no_grad():
            output = self.model(source)
            loss = self.loss_func(output, targets)
        self.model.train()
        return float(loss.item())

    def _run_batch(self, source, targets) -> float:
        """Run one optimization step on a training batch; return the loss."""
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = self.loss_func(output, targets)
        loss.backward()
        self.optimizer.step()
        return float(loss.item())

    def _run_epoch(self):
        """Run one full pass over the training and validation loaders.

        Appends the per-batch-averaged train/validation losses to the
        history lists and refreshes ``self.train_loss``/``self.valid_loss``.
        """
        # Peek at one batch just to report the batch size.
        b_sz = len(next(iter(self.train_data))[0])
        print(
            f"\n[GPU{self.global_rank}] Epoch: {self.epochs_run} | Batch_SZ: {b_sz} ", end="")
        print(
            f"| Steps: {len(self.train_data)} ", end="")
        print(
            f"| T_loss: {self.train_loss:.3f} | V_loss: {self.valid_loss:.3f}")
        # Reshuffle the distributed sampler so shards differ each epoch.
        self.train_data.sampler.set_epoch(self.epochs_run)
        train_loss = 0
        valid_loss = 0
        # Train Loop
        for source, targets in self.train_data:
            source = source.to(self.local_rank)
            targets = targets.to(self.local_rank)
            train_loss += self._run_batch(source, targets)
        # Calculating Validation loss
        for source, targets in self.valid_data:
            source = source.to(self.local_rank)
            targets = targets.to(self.local_rank)
            valid_loss += self._calc_validation_loss(source, targets)
        # Update loss history
        self.train_loss_history.append(train_loss/len(self.train_data))
        self.valid_loss_history.append(valid_loss/len(self.valid_data))
        self.train_loss, self.valid_loss = self.train_loss_history[-1], self.valid_loss_history[-1]

    def _save_snapshot(self):
        """Save model weights (unwrapped) plus training bookkeeping to disk."""
        snapshot = {
            # .module: unwrap DDP so the state dict reloads on a bare model.
            "MODEL_STATE": self.model.module.state_dict(),
            "EPOCHS_RUN": self.epochs_run,
            "RUN_TIME": self.run_time,
            "TRAIN_HISTORY": self.train_loss_history,
            "VALID_HISTORY": self.valid_loss_history,
            "EPOCH_TIMES": self.epoch_times,
            "LOWEST_LOSS": self.lowest_loss
        }
        torch.save(snapshot, self.save_path)
        print(f"Training snapshot saved at {self.save_path}")

    def train(self):
        """Run epochs until the wall-clock budget is exhausted (max 1000 per call).

        Rank 0 snapshots whenever validation loss improves, and writes a
        ``*_metrics.pt`` summary when training stops.
        """
        for _ in range(self.epochs_run, self.epochs_run + 1000):
            start = time()
            self._run_epoch()
            elapsed_time = time() - start
            self.run_time += elapsed_time
            self.epoch_times.append(elapsed_time)
            start = time()
            self.epochs_run += 1
            # Track the best validation loss on EVERY rank so bookkeeping
            # stays consistent across processes; only the single global
            # rank 0 writes the snapshot (local_rank 0 exists once per
            # node and would race on a shared filesystem).
            if self.valid_loss_history[-1] < self.lowest_loss:
                self.lowest_loss = self.valid_loss_history[-1]
                if self.global_rank == 0:
                    self._save_snapshot()
            # Fold the save time into this epoch's accounting.
            elapsed_time = time() - start
            self.run_time += elapsed_time
            self.epoch_times[-1] += elapsed_time
            print(
                f'\nCurrent Train Time: {self.run_time//60**2} hours & {((self.run_time%60.0**2)/60.0):.2f} minutes')
            if (self.run_time > self.max_run_time):
                print(
                    f"Training completed -> Total train time: {self.run_time:.2f} seconds")
                break
        # Saving import metrics to analyze training on local machine
        if (self.global_rank == 0):
            train_metrics = {
                "EPOCHS_RUN": self.epochs_run,
                "RUN_TIME": self.run_time,
                "TRAIN_HISTORY": self.train_loss_history,
                "VALID_HISTORY": self.valid_loss_history,
                "EPOCH_TIMES": self.epoch_times,
                "LOWEST_LOSS": self.lowest_loss
            }
            # splitext instead of [:-3]: don't silently assume a ".pt" suffix.
            torch.save(train_metrics, os.path.splitext(self.save_path)[0] + "_metrics.pt")