Code Review
This pull request implements a comprehensive "Strict Resume" feature for Transformers models, enabling the restoration of full training state including optimizer, scheduler, scaler, RNG states, and data progress. Key changes involve implementing load_training_state and read_training_progress across the model, server, and client layers, alongside dataloader enhancements to support sample-level skipping for map-style datasets. Feedback highlights several critical improvements: ensuring deterministic RNG in distributed settings by avoiding unseeded random states, addressing the deprecated use of StopIteration in generators, improving security by using weights_only=True during checkpoint loading, and removing an accidental BOM character in the client generator. Additionally, a more robust approach for re-initializing the dataloader is suggested to avoid modifying private PyTorch attributes.
```python
if emitted >= target_total:
    return

for idx in np.random.RandomState().permutation(len(self.dataset)).tolist():
```
Using np.random.RandomState() without a seed creates a new, non-deterministic random state. In a distributed training environment, this can lead to different ranks generating different permutations during the retry/backfill phase, causing data divergence across ranks and potentially breaking the training process. It is recommended to use a seeded random state or the global np.random state if it has been properly seeded during initialization.
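A minimal sketch of the suggested fix, assuming the class stores a `seed` value at initialization (the function name and the `epoch` offset are illustrative, not from the PR):

```python
import numpy as np

def deterministic_permutation(n: int, seed: int, epoch: int = 0) -> list:
    # Seeding with (seed + epoch) makes every rank compute the same
    # permutation for a given epoch, while still reshuffling across epochs.
    rng = np.random.RandomState(seed + epoch)
    return rng.permutation(n).tolist()
```

Because every rank constructs the RNG from the same seed, the backfill order is identical across the process group.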
```python
self.dataloader.__initialized = False
self._rebuild_sampler_stack()
self.dataloader.__initialized = True
```
Accessing and modifying the private attribute __initialized of torch.utils.data.DataLoader is brittle and relies on internal implementation details of PyTorch that could change. A safer approach to update the sampler stack after the dataloader has been created is to simply re-instantiate the underlying self.dataloader using the stored self.dataloader_params.
```diff
- self.dataloader.__initialized = False
- self._rebuild_sampler_stack()
- self.dataloader.__initialized = True
+ if self.dataloader is not None:
+     self.dataloader = None
+ self._lazy_init_dataloader()
```
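The lazy re-initialization pattern can be sketched as follows; the class, method, and attribute names mirror the suggestion above but are assumptions about the surrounding code:

```python
class LazyLoaderHolder:
    """Rebuilds its dataloader from stored params instead of mutating
    private attributes of an existing DataLoader instance."""

    def __init__(self, loader_factory, **dataloader_params):
        self.loader_factory = loader_factory
        self.dataloader_params = dataloader_params
        self.dataloader = None

    def _lazy_init_dataloader(self):
        if self.dataloader is None:
            # Re-instantiate from the stored constructor arguments.
            self.dataloader = self.loader_factory(**self.dataloader_params)
        return self.dataloader

    def reset(self):
        # Drop the stale instance; the next access rebuilds it cleanly,
        # picking up any sampler-stack changes made in the meantime.
        self.dataloader = None
        self._lazy_init_dataloader()
```

This avoids any dependence on `DataLoader.__initialized`, which is name-mangled and internal to PyTorch.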
```diff
@@ -39,12 +45,11 @@ def __iter__(self):
 else:
     raise StopIteration(f'Max retries exceeded: {self.max_retries}, no valid data found.')
```
Manually raising StopIteration inside a generator (a function using yield) is disallowed by PEP 479: since Python 3.7 it is automatically converted into a RuntimeError rather than ending iteration. Since this represents an error condition (max retries exceeded), it is better to raise a RuntimeError or ValueError directly to provide a clear error message to the user.
```diff
- raise StopIteration(f'Max retries exceeded: {self.max_retries}, no valid data found.')
+ raise RuntimeError(f'Max retries exceeded: {self.max_retries}, no valid data found.')
```
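The PEP 479 behavior can be demonstrated with a toy pair of generators (standalone illustration, not code from the PR):

```python
def bad_gen():
    yield 1
    # Under PEP 479 (Python 3.7+), this StopIteration is converted into a
    # RuntimeError("generator raised StopIteration") at the call site.
    raise StopIteration("max retries exceeded")

def good_gen():
    yield 1
    # Raising RuntimeError directly surfaces the error message as intended.
    raise RuntimeError("max retries exceeded")
```

In both cases the caller sees a RuntimeError, but only `good_gen` preserves the original error message, which is why raising RuntimeError directly is the clearer choice.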
```python
if hasattr(self.strategy, 'load_optimizer_checkpoint'):
    self.strategy.load_optimizer_checkpoint(self.model, optimizer_config.optimizer, optimizer_path)
else:
    state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=False)
```
Using weights_only=False when loading checkpoints via torch.load can be a security risk if the checkpoint file is untrusted, as it allows the execution of arbitrary code during unpickling. Since these are standard state dictionaries (optimizer, scheduler, RNG), they should be compatible with weights_only=True in modern PyTorch versions. This applies to lines 998, 1007, and 1029 as well.
```diff
- state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=False)
+ state_dict = torch.load(optimizer_path, map_location='cpu', weights_only=True)
```