2525
2626namespace infini_train ::utils {
2727
// Maps each registered sub-module pointer to its hierarchical name (the
// InfiniTrain path filled in by PrecisionChecker::BuildNameMap), so the
// precision-check forward/backward hooks can log a module's full path
// rather than only its type name.
// NOTE(review): written by BuildNameMap and read inside the global hooks —
// assumes BuildNameMap runs before any hook fires and that there is no
// concurrent mutation; confirm against the training-loop setup order.
static std::unordered_map<const nn::Module *, std::string> g_module_name_map;
29+
2830namespace {
2931
3032// Simple MD5 implementation
@@ -263,7 +265,7 @@ void SaveNpy(const std::shared_ptr<Tensor> &tensor, const std::string &name, int
263265 const auto &output_path = PrecisionCheckEnv::Instance ().GetOutputPath ();
264266 std::string dir = output_path + " /rank_" + std::to_string (rank);
265267 std::filesystem::create_directories (dir);
266- std::string filename = dir + " /" + name + " _" + std::to_string (idx) + " _" + stage + " .npy" ;
268+ std::string filename = dir + " /" + name + (idx > 0 ? " _" + std::to_string (idx) : " " ) + " _" + stage + " .npy" ;
267269
268270 if (tensor->Dtype () == DataType::kFLOAT32 ) {
269271 tensor->SaveAsNpy (filename);
@@ -320,6 +322,9 @@ void PrecisionChecker::CheckTensors(const std::string &stage, const std::string
320322 // Output to log
321323 auto &log_stream = GetLogStream ();
322324
325+ // Format: name[_idx]_forward/backward (match .npy filename format)
326+ std::string log_name = name + (idx > 0 ? " _" + std::to_string (idx) : " " ) + " _" + stage_short;
327+
323328 if (global_config.format == " md5" ) {
324329 // MD5 format
325330 std::string md5;
@@ -338,7 +343,7 @@ void PrecisionChecker::CheckTensors(const std::string &stage, const std::string
338343 // Original precision MD5
339344 md5 = ComputeMD5 (cpu_tensor->DataPtr (), byte_size);
340345 }
341- log_stream << context_key << " " << name << " _ " << idx << " _ " << stage << " tensor[" << i << " ]: "
346+ log_stream << context_key << " " << log_name << " tensor[" << i << " ]: "
342347 << " dtype=" << DataTypeToString (cpu_tensor->Dtype ()) << " "
343348 << " shape=" << FormatShape (cpu_tensor->Dims ()) << " "
344349 << " md5=" << md5 << std::endl;
@@ -350,7 +355,7 @@ void PrecisionChecker::CheckTensors(const std::string &stage, const std::string
350355 = (config.check_nan && stats.nan_count > 0 ) || (config.check_inf && stats.inf_count > 0 );
351356 const std::string error_marker = has_error ? " <- ERROR" : " " ;
352357
353- log_stream << context_key << " " << name << " _ " << idx << " _ " << stage << " tensor[" << i << " ]: "
358+ log_stream << context_key << " " << log_name << " tensor[" << i << " ]: "
354359 << " dtype=" << DataTypeToString (cpu_tensor->Dtype ()) << " "
355360 << " shape=" << FormatShape (cpu_tensor->Dims ()) << " "
356361 << " min=" << stats.min_val << " "
@@ -388,17 +393,46 @@ void PrecisionChecker::Init(const PrecisionCheckConfig &global_config, const Con
388393 GlobalModuleHookRegistry::Instance ().RegisterModuleForwardHook (
389394 [config](nn::Module *module , const std::vector<std::shared_ptr<Tensor>> &inputs,
390395 const std::vector<std::shared_ptr<Tensor>> &outputs) {
391- CheckTensors (" Forward Output" , module ->type (), outputs, config);
396+ auto it = g_module_name_map.find (module );
397+ const std::string &name = (it != g_module_name_map.end ()) ? it->second : module ->type ();
398+ CheckTensors (" Forward Output" , name, outputs, config);
392399 });
393400
394401 // Register global module full backward hook (checks gradients on every backward)
395402 GlobalModuleHookRegistry::Instance ().RegisterModuleFullBackwardHook (
396403 [config](nn::Module *module , const std::vector<std::shared_ptr<Tensor>> &grad_outputs,
397404 const std::vector<std::shared_ptr<Tensor>> &grad_inputs) {
398- CheckTensors (" GradOutputs" , module ->type (), grad_outputs, config);
405+ auto it = g_module_name_map.find (module );
406+ const std::string &name = (it != g_module_name_map.end ()) ? it->second : module ->type ();
407+ CheckTensors (" GradOutputs" , name, grad_outputs, config);
399408 });
400409}
401410
// True for module paths that belong to the pipeline-parallel internal
// tree (names beginning with "__pp"); such modules are excluded from
// the module-name map.
static inline bool ShouldSkipNameMap(std::string_view name) {
    // Prefix test: the first four characters must spell "__pp".
    // (string_view::substr is a view, no allocation.)
    return name.substr(0, 4) == "__pp";
}
414+
415+ void PrecisionChecker::BuildNameMap (nn::Module *root_model) {
416+ const auto &global_config = PrecisionCheckEnv::Instance ().GetConfig ();
417+ if (global_config.level == PrecisionCheckLevel::OFF || root_model == nullptr ) {
418+ return ;
419+ }
420+
421+ auto named = root_model->NamedModules (/* memory=*/ nullptr , /* prefix=*/ " " , /* remove_duplicate=*/ false );
422+ g_module_name_map.clear ();
423+ g_module_name_map.reserve (named.size ());
424+
425+ for (const auto &[name, module ] : named) {
426+ if (name.empty ()) {
427+ continue ; // skip root
428+ }
429+ if (ShouldSkipNameMap (name)) {
430+ continue ; // skip PP internal tree
431+ }
432+ g_module_name_map[module .get ()] = name; // keep InfiniTrain path directly
433+ }
434+ }
435+
402436void PrecisionChecker::RegisterForFunction (autograd::Function *func, const std::string &name, const Config &config) {
403437 const std::string func_name = name.empty () ? " Function" : name;
404438
0 commit comments