
Resume from ckpt #135

Draft
kevssim wants to merge 22 commits into modelscope:main from kevssim:resume_from_ckpt

Conversation

Collaborator

@kevssim kevssim commented Mar 31, 2026

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

Write the detailed information that belongs to this PR.

Experiment results

Paste your experiment results here (if needed).

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements a comprehensive "Strict Resume" feature for Transformers models, enabling the restoration of full training state including optimizer, scheduler, scaler, RNG states, and data progress. Key changes involve implementing load_training_state and read_training_progress across the model, server, and client layers, alongside dataloader enhancements to support sample-level skipping for map-style datasets. Feedback highlights several critical improvements: ensuring deterministic RNG in distributed settings by avoiding unseeded random states, addressing the deprecated use of StopIteration in generators, improving security by using weights_only=True during checkpoint loading, and removing an accidental BOM character in the client generator. Additionally, a more robust approach for re-initializing the dataloader is suggested to avoid modifying private PyTorch attributes.
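The full-state restore summarized above can be sketched without the repo's actual API. This is a hypothetical minimal sketch: `save_training_state` / `load_training_state` are illustrative names (the PR's real `load_training_state` signature is not shown here), and Python's `random` module stands in for the torch/numpy RNG sources the PR also captures:

```python
import random

# Hypothetical sketch of the "strict resume" idea: capture every source of
# nondeterminism alongside the training step, then restore it exactly.
def save_training_state(step):
    return {
        'step': step,
        'python_rng': random.getstate(),
        # The real code would also store torch.get_rng_state(), the numpy
        # state, optimizer/scheduler/scaler state_dicts, and data progress.
    }

def load_training_state(state):
    random.setstate(state['python_rng'])
    return state['step']

random.seed(0)
state = save_training_state(step=100)
a = random.random()
load_training_state(state)      # rewind the RNG to the checkpointed state
b = random.random()
assert a == b                   # the resumed run reproduces the same draw
```

Restoring RNG state is what makes the resume "strict": without it, data shuffling and dropout diverge from the original run even if the optimizer state matches.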

if emitted >= target_total:
    return

for idx in np.random.RandomState().permutation(len(self.dataset)).tolist():

high

Using np.random.RandomState() without a seed creates a new, non-deterministic random state. In a distributed training environment, this can lead to different ranks generating different permutations during the retry/backfill phase, causing data divergence across ranks and potentially breaking the training process. It is recommended to use a seeded random state or the global np.random state if it has been properly seeded during initialization.
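One way to apply this recommendation is to derive the permutation seed from a fixed base seed plus the current epoch, so every rank computes the same order while still reshuffling between epochs. A minimal sketch, with `base_seed` and `epoch` as illustrative parameter names not taken from the PR:

```python
import numpy as np

# Hypothetical fix: a seeded RandomState makes the permutation a pure
# function of (base_seed, epoch), identical on every rank.
def deterministic_permutation(n, base_seed=42, epoch=0):
    rng = np.random.RandomState(base_seed + epoch)
    return rng.permutation(n).tolist()

# Two "ranks" computing independently agree on the order:
assert deterministic_permutation(8, epoch=3) == deterministic_permutation(8, epoch=3)
# Different epochs use a different seed, so the order still varies over time.
```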

Comment on lines +137 to +139
self.dataloader.__initialized = False
self._rebuild_sampler_stack()
self.dataloader.__initialized = True

medium

Accessing and modifying the private attribute __initialized of torch.utils.data.DataLoader is brittle and relies on internal implementation details of PyTorch that could change. A safer approach to update the sampler stack after the dataloader has been created is to simply re-instantiate the underlying self.dataloader using the stored self.dataloader_params.

Suggested change
- self.dataloader.__initialized = False
- self._rebuild_sampler_stack()
- self.dataloader.__initialized = True
+ if self.dataloader is not None:
+     self.dataloader = None
+ self._lazy_init_dataloader()
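The "re-instantiate instead of poking private attributes" pattern can be illustrated without a torch dependency. In this sketch, `FakeDataLoader` is a stand-in for `torch.utils.data.DataLoader` that mimics its post-initialization mutation guard, and `Wrapper`/`rebuild` are hypothetical names, not the PR's actual classes:

```python
# Stand-in for DataLoader: refuses attribute changes after construction,
# which is the behavior the private __initialized flag enforces in torch.
class FakeDataLoader:
    def __init__(self, dataset, batch_size=1):
        self.dataset = dataset
        self.batch_size = batch_size
        self._initialized = True

    def __setattr__(self, name, value):
        if getattr(self, '_initialized', False) and name != '_initialized':
            raise ValueError(f'{name} cannot be set after initialization')
        super().__setattr__(name, value)

class Wrapper:
    def __init__(self, dataset, **params):
        # Keep the construction params around so the loader can be rebuilt.
        self.dataloader_params = dict(params, dataset=dataset)
        self.dataloader = None
        self._lazy_init_dataloader()

    def _lazy_init_dataloader(self):
        if self.dataloader is None:
            self.dataloader = FakeDataLoader(**self.dataloader_params)

    def rebuild(self):
        # Safe alternative to toggling the private flag: discard and recreate.
        self.dataloader = None
        self._lazy_init_dataloader()

w = Wrapper(list(range(4)), batch_size=2)
old = w.dataloader
w.rebuild()
assert w.dataloader is not old  # a fresh, fully initialized loader
```

The rebuild costs one extra object construction but never touches name-mangled internals, so it survives PyTorch version changes.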

@@ -39,12 +45,11 @@ def __iter__(self):
else:
    raise StopIteration(f'Max retries exceeded: {self.max_retries}, no valid data found.')

medium

Manually raising StopIteration inside a generator (a function using yield) is deprecated since PEP 479 and will be converted into a RuntimeError in Python 3.7+. Since this represents an error condition (max retries exceeded), it is better to raise a RuntimeError or ValueError directly to provide a clear error message to the user.

Suggested change
- raise StopIteration(f'Max retries exceeded: {self.max_retries}, no valid data found.')
+ raise RuntimeError(f'Max retries exceeded: {self.max_retries}, no valid data found.')
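The PEP 479 behavior the reviewer describes is easy to demonstrate: a `StopIteration` raised inside a generator body surfaces to the caller as a `RuntimeError` on Python 3.7+, so the error is never catchable under its intended type. The generator below is a toy stand-in, not the PR's actual `__iter__`:

```python
# A StopIteration raised in a generator body is converted by the
# interpreter into RuntimeError (PEP 479, default since Python 3.7).
def retry_loop():
    yield 'sample'
    raise StopIteration('Max retries exceeded')

g = retry_loop()
assert next(g) == 'sample'
try:
    next(g)
except RuntimeError:
    pass  # surfaces as RuntimeError, not StopIteration

# Raising RuntimeError directly, as suggested, keeps the error type honest:
def retry_loop_fixed():
    yield 'sample'
    raise RuntimeError('Max retries exceeded')
```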

if hasattr(self.strategy, 'load_optimizer_checkpoint'):
    self.strategy.load_optimizer_checkpoint(self.model, optimizer_config.optimizer, optimizer_path)
else:
    state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=False)

medium

Using weights_only=False when loading checkpoints via torch.load can be a security risk if the checkpoint file is untrusted, as it allows the execution of arbitrary code during unpickling. Since these are standard state dictionaries (optimizer, scheduler, RNG), they should be compatible with weights_only=True in modern PyTorch versions. This applies to lines 998, 1007, and 1029 as well.

Suggested change
- state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=False)
+ state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=True)
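The unpickling risk behind this recommendation can be shown with plain `pickle`, which is what `torch.load` uses under the hood when `weights_only=False`. The `MaliciousCheckpoint` class here is a toy illustration; a real attack could invoke any callable:

```python
import pickle

# Why weights_only=False is risky: plain unpickling calls whatever a
# crafted object's __reduce__ returns. This toy "checkpoint" merely runs
# print() during load, but the payload could be any callable.
class MaliciousCheckpoint:
    def __reduce__(self):
        return (print, ('arbitrary code executed during unpickling',))

blob = pickle.dumps(MaliciousCheckpoint())
result = pickle.loads(blob)  # executes print(...) instead of restoring data
assert result is None        # print's return value; no checkpoint came back

# torch.load(..., weights_only=True) defends against exactly this by using
# a restricted unpickler that only admits tensors and plain containers.
```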
