Skip to content

Conversation

@kilinchange
Copy link
Collaborator

@kilinchange kilinchange commented Jan 14, 2026

device 注册设计文档:https://gxtctab8no8.feishu.cn/docx/F0CzdkVXCoaxRgxYc3AcFewtnOc?from=from_copylink

Device 现在是一个只维护 type/index 的简单类型,所有运行时方法都由 DeviceGuard/DeviceGuardImpl 提供。

本次 pr 涉及主要改动:

  • device 注册基建;
  • 所有通过 DeviceManager 拿 Device 的地方都改成直接构造
  • 所有使用 DeviceType 的地方改为 Device::DeviceType
  • 修改所有使用 Device* 的地方为 Device
  • Device type/index 方法命名修正(改成了小写开头的取值函数命名风格)
  • 所有 Device 的运行时调用 (SetDevice, Synchronize, Stream, CublasHandle) 修改为 DeviceGuard/DeviceGuardImpl 操作
  • 所有 #ifdef USE_CUDA 运行时操作使用 DeviceGuard(遗留 event、datatype 相关代码,待下次 pr 移除)

@kilinchange kilinchange changed the title tmp [WIP] feat: device registry Jan 14, 2026
@kilinchange kilinchange force-pushed the device_registry branch 6 times, most recently from ee4e27b to 2e2b4c3 Compare January 15, 2026 03:33
@kilinchange kilinchange changed the title [WIP] feat: device registry [WIP] feat: device registration Jan 15, 2026
@kilinchange kilinchange force-pushed the device_registry branch 8 times, most recently from c60d854 to ab3274a Compare January 18, 2026 13:42
@kilinchange kilinchange force-pushed the device_registry branch 4 times, most recently from f0b0494 to 2b5ad33 Compare February 6, 2026 05:12
@kilinchange
Copy link
Collaborator Author

image

@kilinchange kilinchange changed the title [WIP] feat: device registration feat: device registration Feb 9, 2026

// stream
Stream *CudaGuardImpl::GetStream(Device device) const {
std::call_once(device_stream_flags.at(device.index()), InitSingleStream, device);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

call_once生命周期是进程级别的,这里可能有点风险(但不大,p4级修复)

- Drop legacy hardware-specific branching
- Convert DeviceGuardImpl base methods to fatal-only fallbacks
- Explicitly implement supported CPU runtime behaviors
- Validate CUDA device type and index bounds in CudaGuardImpl
- Widen DeviceCount return type to prevent truncation
streams.push_back(dynamic_cast<const CudaDevice *>(device)->Stream());
comms.push_back(device_comm_map_.at(device));
streams.push_back(dynamic_cast<infini_train::core::cuda::CudaStream *>(
core::GetDeviceGuardImpl(device.type())->GetStream(device))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

获取Stream的地方很多,是不是也封装一个function,可以放在和GetDeviceGuardImpl同一个位置

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants