Add items() method to nn.Module for state_dict iteration #1556
alinpahontu2912 wants to merge 1 commit into dotnet:main
Conversation
Add an `items()` method to `nn.Module` that returns an enumerator of `(name, tensor)` tuples from the module's `state_dict()`. This enables easy iteration over all parameters and persistent buffers, consistent with the existing `items()` pattern in `ModuleDict` and `ParameterDict`. It addresses the core request in issue dotnet#1474 by providing the `items()` API needed for model-merging workflows (averaging parameters between models using `state_dict` + `load_state_dict`).

Changes:
- Add a virtual `items()` method to the `Module` class.
- Add the `new` keyword to `ModuleDict.items()` and `ParameterDict.items()` to properly hide the base-class method (different return types).
- Add tests for `items()` on simple and nested modules.
- Add a test demonstrating the model-merge pattern from the issue.

Closes dotnet#1474

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
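For context, a minimal usage sketch of the new API (the model layout is illustrative; it assumes the `items()` signature added in this PR, alongside TorchSharp's existing `Sequential` and `Linear` factories):

```csharp
using System;
using TorchSharp;
using static TorchSharp.torch.nn;

var model = Sequential(("lin1", Linear(10, 5)), ("lin2", Linear(5, 2)));

// Enumerate (name, tensor) pairs, mirroring PyTorch's state_dict().items().
using (var enumerator = model.items()) {
    while (enumerator.MoveNext()) {
        var (name, tensor) = enumerator.Current;
        Console.WriteLine($"{name}: [{string.Join(", ", tensor.shape)}]");
    }
}
```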
Force-pushed from 850e628 to 94470d6
Pull request overview
Adds a PyTorch-style `items()` iterator to `nn.Module` so consumers can iterate over `(name, tensor)` pairs from `state_dict()`, enabling workflows like model-parameter merging. Also updates the existing dict-like modules to hide the new base method and adds coverage in the NN test suite.
Changes:
- Add `virtual items()` to `Modules.Module` to enumerate `(name, Tensor)` entries from `state_dict()`.
- Add the `new` keyword to `ModuleDict.items()` and `ParameterDict.items()` to properly hide the base `Module.items()` with a different tuple type.
- Add xUnit tests covering `items()` for simple modules, nested modules, and a model-merge example.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| src/TorchSharp/NN/Module.cs | Introduces `Module.items()` enumerator backed by `state_dict()`. |
| src/TorchSharp/NN/ModuleDict.cs | Marks `items()` as `new` to hide base `Module.items()`. |
| src/TorchSharp/NN/ParameterDict.cs | Marks `items()` as `new` to hide base `Module.items()`. |
| test/TorchSharpTest/NN.cs | Adds tests for `Module.items()` and a merge workflow example. |
```csharp
foreach (var kv in state_dict()) {
    yield return (kv.Key, kv.Value);
}
```
Module.items() is implemented by calling state_dict(), which allocates and populates a new Dictionary<string, Tensor> on every enumeration. Since this API is intended for iteration, consider yielding directly from named_parameters(recurse: true) and named_buffers(recurse: true, include_nonpersistent: false) to avoid the intermediate dictionary allocation and duplicate traversal work.
```diff
-foreach (var kv in state_dict()) {
-    yield return (kv.Key, kv.Value);
-}
+foreach (var p in named_parameters(recurse: true)) {
+    yield return (p.name, p.Item2);
+}
+foreach (var b in named_buffers(recurse: true, include_nonpersistent: false)) {
+    yield return (b.name, b.Item2);
+}
```
```csharp
[Fact]
public void TestModuleItems()
{
    var lin = Linear(10, 5, true);
    var sd = lin.state_dict();
    var items = new List<(string, Tensor)>();

    using (var enumerator = lin.items()) {
        while (enumerator.MoveNext()) {
            items.Add(enumerator.Current);
        }
    }

    // items() should return the same entries as state_dict()
    Assert.Equal(sd.Count, items.Count);
    foreach (var (name, value) in items) {
        Assert.True(sd.ContainsKey(name));
        Assert.Equal(sd[name].shape, value.shape);
    }
}
```
The new items() API is documented to include persistent buffers as well as parameters, but the added tests only exercise parameters (Linear / Sequential). Consider adding an assertion using a module with persistent buffers (e.g., BatchNorm) to verify that items() exposes buffer entries too (e.g., running_mean / running_var) and matches state_dict() for those keys.
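A sketch of such a buffer-covering test, assuming `BatchNorm1d` registers `running_mean` and `running_var` as persistent buffers in TorchSharp (as it does in PyTorch); the test name is hypothetical:

```csharp
[Fact]
public void TestModuleItemsIncludesBuffers()
{
    // Assumption: BatchNorm1d exposes running_mean/running_var as
    // persistent buffers, so they should appear in both state_dict()
    // and items().
    var bn = BatchNorm1d(8);
    var sd = bn.state_dict();

    var names = new List<string>();
    using (var enumerator = bn.items()) {
        while (enumerator.MoveNext()) {
            names.Add(enumerator.Current.Item1);
        }
    }

    Assert.Contains("running_mean", names);
    Assert.Contains("running_var", names);
    Assert.Equal(sd.Count, names.Count);
}
```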
```csharp
using (var enumerator = model1.items()) {
    while (enumerator.MoveNext()) {
        var (name, _) = enumerator.Current;
        merged[name] = (sd1[name] + sd2[name]) / 2;
    }
}
```
In TestModelMergeUsingItemsAndStateDict, the averaging operation is performed outside a torch.no_grad() context. Since state_dict() returns parameters that typically have requires_grad=true, the arithmetic will build an autograd graph and retain references unnecessarily. Wrapping the merge computation in torch.no_grad() (or detaching/cloning the source tensors) would better reflect the recommended model-merge pattern and avoid extra graph/memory overhead.
```diff
-using (var enumerator = model1.items()) {
-    while (enumerator.MoveNext()) {
-        var (name, _) = enumerator.Current;
-        merged[name] = (sd1[name] + sd2[name]) / 2;
-    }
-}
+using (var _ = no_grad()) {
+    using (var enumerator = model1.items()) {
+        while (enumerator.MoveNext()) {
+            var (name, _) = enumerator.Current;
+            merged[name] = (sd1[name] + sd2[name]) / 2;
+        }
+    }
+}
```