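"""Optuna hyperparameter sweep for Lightning models.

Loads a base experiment config and a sweep config, samples hyperparameters
per trial with the configured Optuna sampler, trains one Lightning model per
trial, and optimizes the trial's final validation loss.

Usage:
    python sweep.py --config <base_config.yaml> --sweep_config <sweep_config.yaml> --device 0
"""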
import time
from argparse import ArgumentParser

import optuna
import torch
from lightning import Trainer, seed_everything
from omegaconf import OmegaConf

from src.exp.exp_basic import model_dict, exp_dict, datamodule_dict
from src.utils.utils_lightning import get_callbacks, get_logger
# Map sampler names used in the sweep config to Optuna sampler classes.
SAMPLER = {
    'TPESampler': optuna.samplers.TPESampler,
}
if __name__ == "__main__":
    parser = ArgumentParser(description='Optuna hyperparameter sweep')
    parser.add_argument('--config', type=str,
                        default='/home/dynamical_embedding/configs/lorenz/lstm/direct/lorenz_BiLSTM_init3_lead1_tune.yaml',
                        help='path to the base experiment config')
    parser.add_argument('--sweep_config', type=str,
                        default='/home/dynamical_embedding/configs/lorenz/sweep/lstm/lorenz_BiLSTM_init3_lead1_sweep.yaml',
                        help='path to the sweep config describing the search space')
    parser.add_argument('--device', nargs="*", type=int, default=[0],
                        help='device indices passed to the Lightning Trainer')
    args = parser.parse_args()

    # Load the base config and derive a unique output directory name.
    cfg = OmegaConf.load(args.config)
    cfg.general.output_dir = "{}_{}_{}_{}".format(
        cfg.general.task_name,
        cfg.data.data_name,
        cfg.model.model_name,
        cfg.general.cust_name,
    )
    sweep_cfg = OmegaConf.load(args.sweep_config)
    seed_everything(cfg.general.seed, workers=True)
    def objective(trial: optuna.trial.Trial) -> float:
        """Sample hyperparameters, train one model, and return its val_loss."""
        hparams_dict = {}
        # We tune three groups: model, data (batch settings), and training hyperparameters.
        # model
        if hasattr(sweep_cfg, 'model'):
            for key, value in sweep_cfg.model.items():
                cfg.model[key] = trial.suggest_categorical(key, value)
                hparams_dict[key] = cfg.model[key]
        # data
        if hasattr(sweep_cfg, 'data'):
            for key, value in sweep_cfg.data.items():
                cfg.data[key] = trial.suggest_categorical(key, value)
                hparams_dict[key] = cfg.data[key]
        # train
        if hasattr(sweep_cfg, 'train'):
            for key, value in sweep_cfg.train.hparams.items():
                cfg.train.hparams[key] = trial.suggest_categorical(key, value)
                hparams_dict[key] = cfg.train.hparams[key]
        print(f'Current hparams: {hparams_dict}')
        # Instantiate the datamodule and resolve the optimizer class by name.
        datamodule = datamodule_dict[cfg.general.task_name](cfg.data)
        optimizer = getattr(torch.optim, cfg.train.optimizer)

        # Build the LightningModule around the selected network.
        model = exp_dict[cfg.general.task_name](
            net=model_dict[cfg.model.model_name].Model(cfg.model),
            optimizer=optimizer,
            scheduler=None,
            compile=False,
            **cfg.train.hparams,
        )

        # Train the model for this trial.
        callbacks = get_callbacks(cfg.train.callbacks)
        trainer = Trainer(
            max_epochs=cfg.train.max_epochs,
            min_epochs=cfg.train.min_epochs,
            accelerator=cfg.train.accelerator,
            devices=args.device,
            deterministic=True,
            default_root_dir=cfg.general.output_dir,
            callbacks=list(callbacks.values()),
            logger=get_logger(cfg.general.output_dir),
        )
        trainer.logger.log_hyperparams(hparams_dict)
        st = time.time()
        trainer.fit(model, datamodule=datamodule)
        print(f"Training time: {time.time() - st:.1f}s")
        return trainer.callback_metrics["val_loss"].item()
    # Optuna study setup: resume from the SQLite storage if it already exists.
    sampler = SAMPLER[sweep_cfg.sampler.sampler_name](
        seed=cfg.general.seed,
        n_startup_trials=sweep_cfg.sampler.n_startup_trials,
    )
    study = optuna.create_study(
        direction=sweep_cfg.direction,
        sampler=sampler,
        study_name=sweep_cfg.study_name,
        storage=f"sqlite:///{cfg.general.output_dir}_sweep.db",
        load_if_exists=True,
    )
    study.optimize(objective, n_trials=sweep_cfg.n_trials)

    print(f"Number of finished trials: {len(study.trials)}")
    print("Best trial:")
    trial = study.best_trial
    print(f"Value: {trial.value}")
    print("Params:")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")
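
# A minimal sweep-config sketch covering the fields this script reads; the
# keys mirror the accesses above, but the values are illustrative only:
#
#   study_name: lorenz_BiLSTM_sweep
#   direction: minimize
#   n_trials: 20
#   sampler:
#     sampler_name: TPESampler
#     n_startup_trials: 5
#   model:
#     hidden_size: [64, 128, 256]
#   data:
#     batch_size: [32, 64]
#   train:
#     hparams:
#       lr: [0.0001, 0.001]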