-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
95 lines (81 loc) · 2.57 KB
/
main.py
File metadata and controls
95 lines (81 loc) · 2.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import os
from datetime import datetime
from loguru import logger
from configs._base_.yaml_parser import Config, load_config
from metric.tensorboard_logger import MyLogger
from utils.prepration import build_trainer
def parse_args(argv=None):
    """Parse command-line arguments and merge them into the YAML config.

    The YAML file named by ``--config`` is loaded first. Every option that is
    passed explicitly on the command line then overrides the matching YAML
    entry; options that are not passed leave the YAML value in place.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            ``argparse`` read ``sys.argv[1:]``, exactly as before.

    Returns:
        The ``Config`` built from the YAML file with CLI overrides applied.
        ``model["name"]``, ``train["loss_type"]`` and ``train["epoch"]`` are
        always set, falling back to ``"resnet18"``, ``"ce"`` and ``20`` when
        neither the command line nor the YAML file provides them.
    """
    parser = argparse.ArgumentParser(description="Model Parameter Config")
    # Hyper-parameters that may be passed on the command line. The overriding
    # options all default to None so that "not passed" can be distinguished
    # from an explicit value and the YAML setting is not clobbered.
    parser.add_argument(
        "--config",
        type=str,
        default="./configs/base_config.yaml",
        help="Path to config file",
    )
    parser.add_argument("--lr", type=float, help="Learning rate")
    parser.add_argument("--batch_size", type=int, help="Batch size")
    parser.add_argument("--device", type=str, help="Training device (cuda/cpu)")
    parser.add_argument(
        "--name",
        default=None,
        choices=["resnet18", "DESSN"],
        type=str,
        help="Backbone name (default: YAML value, else resnet18)",
    )
    parser.add_argument(
        "--loss_type",
        default=None,
        choices=["ce", "focal"],
        type=str,
        help="Configure loss (default: YAML value, else ce)",
    )
    parser.add_argument(
        "--epoch",
        default=None,
        type=int,
        help="Number of training epochs (default: YAML value, else 20)",
    )
    args = parser.parse_args(argv)

    # Load the YAML configuration.
    config_dict = load_config(args.config)
    config = Config(**config_dict)

    # Override YAML values with the command-line values that were given.
    # Compare against None instead of relying on truthiness so that explicit
    # falsy values such as ``--lr 0`` are honoured rather than dropped.
    overrides = (
        (config.train, "lr", args.lr),
        (config.train, "batch_size", args.batch_size),
        (config.train, "device", args.device),
        (config.model, "name", args.name),
        (config.train, "loss_type", args.loss_type),
        (config.train, "epoch", args.epoch),
    )
    for section, key, value in overrides:
        if value is not None:
            section[key] = value

    # Historical defaults, applied only where neither the CLI nor the YAML
    # file supplies a value, because the entry point indexes these keys
    # directly when naming log files and selecting the trainer.
    # NOTE(review): assumes config.train / config.model are plain dicts
    # (they support item assignment above) — confirm against Config.
    fallbacks = (
        (config.model, "name", "resnet18"),
        (config.train, "loss_type", "ce"),
        (config.train, "epoch", 20),
    )
    for section, key, default in fallbacks:
        if key not in section:
            section[key] = default

    return config
def main():
    """Run one training session from the command-line/YAML configuration.

    Builds the config via ``parse_args``, checks that the requested backbone
    has a trainer, adds a DEBUG-level loguru file sink, creates the
    TensorBoard writer, then builds the trainer and starts training.

    Raises:
        NotImplementedError: if ``config.model["name"]`` is not a backbone
            this script can build a trainer for.
    """
    config = parse_args()
    model_name = config.model["name"]

    # Validate the backbone before any log file or TensorBoard directory is
    # created, so an unsupported choice (the CLI also accepts "DESSN") does
    # not leave empty run artifacts behind.
    if model_name != "resnet18":
        raise NotImplementedError(f"Model {model_name!r} is not implemented")

    # NOTE(review): minute resolution means two runs started within the same
    # minute share one TensorBoard directory — consider adding seconds.
    time_str = datetime.now().strftime("%Y%m%d_%H%M")

    # "{time}" is left as a literal placeholder for loguru to expand; every
    # other field is joined with "-" so each part of the name stays readable.
    run_tag = "-".join(
        [
            model_name,
            config.train["loss_type"],
            config.data["dataset"],
            config.train["optimizer"],
        ]
    )
    logger.add(
        "logs/{time}-" + run_tag + ".log",
        rotation="50 MB",
        level="DEBUG",
    )
    logger.info(config)

    os.makedirs("tensorboard_log", exist_ok=True)
    writer = MyLogger(logdir=os.path.join("tensorboard_log", time_str))

    ### GET TRAINER ###
    # Resnet of 8 downsampling + BIT + bitemporal feature differencing + a small CNN
    trainer = build_trainer(config, writer=writer)

    ### START TRAIN ###
    trainer.start_train()


if __name__ == "__main__":
    main()