
Add items() method to nn.Module for state_dict iteration#1556

Open
alinpahontu2912 wants to merge 1 commit into dotnet:main from alinpahontu2912:feature/module-items-and-merge

Conversation

@alinpahontu2912
Member

Fixes #1474

Add items() method to nn.Module that returns an enumerator of (name, tensor) tuples from the module's state_dict. This enables easy iteration over all parameters and persistent buffers, consistent with the existing items() pattern in ModuleDict and ParameterDict.

This provides the items() API needed for model merging workflows (averaging parameters between models using state_dict + load_state_dict).

Changes:

  • Add virtual items() method to Module class
  • Add 'new' keyword to ModuleDict.items() and ParameterDict.items() to properly hide the base class method (different return types)
  • Add tests for items() on simple and nested modules
  • Add test demonstrating the model merge pattern from the issue
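For context, a minimal sketch of the merge workflow this PR enables. This is not code from the PR's diff; module constructors and the `no_grad()`/`load_state_dict()` calls follow standard TorchSharp conventions, and the exact names are assumptions:

```csharp
// Sketch: averaging two models' parameters via items() + state_dict()
// + load_state_dict(). Assumes the usual TorchSharp static usings.
using System.Collections.Generic;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;

var model1 = Linear(10, 5, true);
var model2 = Linear(10, 5, true);

var sd1 = model1.state_dict();
var sd2 = model2.state_dict();
var merged = new Dictionary<string, Tensor>();

using (no_grad()) {
    // items() enumerates (name, tensor) pairs from the state_dict
    using var e = model1.items();
    while (e.MoveNext()) {
        var (name, _) = e.Current;
        merged[name] = (sd1[name] + sd2[name]) / 2;
    }
}

model1.load_state_dict(merged);
```

The `no_grad()` scope keeps the averaging arithmetic out of the autograd graph, which is also what the review comments below recommend.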


Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Contributor

Copilot AI left a comment


Pull request overview

Adds a PyTorch-style items() iterator to nn.Module so consumers can iterate (name, tensor) pairs from state_dict()—enabling workflows like model-parameter merging—while updating existing dict-like modules to hide the new base method and adding coverage in the NN test suite.

Changes:

  • Add virtual items() to Modules.Module to enumerate (name, Tensor) entries from state_dict().
  • Add new keyword to ModuleDict.items() and ParameterDict.items() to properly hide the base Module.items() with a different tuple type.
  • Add xUnit tests covering items() for simple modules, nested modules, and a model-merge example.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

File Description
src/TorchSharp/NN/Module.cs Introduces Module.items() enumerator backed by state_dict().
src/TorchSharp/NN/ModuleDict.cs Marks items() as new to hide base Module.items().
src/TorchSharp/NN/ParameterDict.cs Marks items() as new to hide base Module.items().
test/TorchSharpTest/NN.cs Adds tests for Module.items() and a merge workflow example.


Comment on lines +522 to +524
foreach (var kv in state_dict()) {
    yield return (kv.Key, kv.Value);
}

Copilot AI Mar 20, 2026


Module.items() is implemented by calling state_dict(), which allocates and populates a new Dictionary<string, Tensor> on every enumeration. Since this API is intended for iteration, consider yielding directly from named_parameters(recurse: true) and named_buffers(recurse: true, include_nonpersistent: false) to avoid the intermediate dictionary allocation and duplicate traversal work.

Suggested change
foreach (var kv in state_dict()) {
    yield return (kv.Key, kv.Value);
}
foreach (var (name, param) in named_parameters(recurse: true)) {
    yield return (name, param);
}
foreach (var (name, buffer) in named_buffers(recurse: true, include_nonpersistent: false)) {
    yield return (name, buffer);
}

Comment on lines +3312 to +3331
[Fact]
public void TestModuleItems()
{
    var lin = Linear(10, 5, true);
    var sd = lin.state_dict();
    var items = new List<(string, Tensor)>();

    using (var enumerator = lin.items()) {
        while (enumerator.MoveNext()) {
            items.Add(enumerator.Current);
        }
    }

    // items() should return the same entries as state_dict()
    Assert.Equal(sd.Count, items.Count);
    foreach (var (name, value) in items) {
        Assert.True(sd.ContainsKey(name));
        Assert.Equal(sd[name].shape, value.shape);
    }
}

Copilot AI Mar 20, 2026


The new items() API is documented to include persistent buffers as well as parameters, but the added tests only exercise parameters (Linear / Sequential). Consider adding an assertion using a module with persistent buffers (e.g., BatchNorm) to verify that items() exposes buffer entries too (e.g., running_mean / running_var) and matches state_dict() for those keys.
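One possible shape for such a test, following the style of the tests in this PR. This is a hedged sketch: the `BatchNorm1d` constructor and the `running_mean`/`running_var` buffer keys follow PyTorch naming conventions and are assumed to carry over to TorchSharp:

```csharp
[Fact]
public void TestModuleItemsIncludesBuffers()
{
    // BatchNorm1d carries persistent buffers in addition to parameters.
    var bn = BatchNorm1d(5);
    var sd = bn.state_dict();

    var names = new HashSet<string>();
    using (var enumerator = bn.items()) {
        while (enumerator.MoveNext()) {
            names.Add(enumerator.Current.Item1);
        }
    }

    // Buffer entries should appear alongside weight/bias,
    // and items() should match state_dict() key-for-key.
    Assert.Contains("running_mean", names);
    Assert.Contains("running_var", names);
    Assert.Equal(sd.Count, names.Count);
}
```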

Comment on lines +3365 to +3368
using (var enumerator = model1.items()) {
    while (enumerator.MoveNext()) {
        var (name, _) = enumerator.Current;
        merged[name] = (sd1[name] + sd2[name]) / 2;

Copilot AI Mar 20, 2026


In TestModelMergeUsingItemsAndStateDict, the averaging operation is performed outside a torch.no_grad() context. Since state_dict() returns parameters that typically have requires_grad=true, the arithmetic will build an autograd graph and retain references unnecessarily. Wrapping the merge computation in torch.no_grad() (or detaching/cloning the source tensors) would better reflect the recommended model-merge pattern and avoid extra graph/memory overhead.

Suggested change
using (var enumerator = model1.items()) {
    while (enumerator.MoveNext()) {
        var (name, _) = enumerator.Current;
        merged[name] = (sd1[name] + sd2[name]) / 2;
using (var _ = no_grad()) {
    using (var enumerator = model1.items()) {
        while (enumerator.MoveNext()) {
            var (name, _) = enumerator.Current;
            merged[name] = (sd1[name] + sd2[name]) / 2;
        }



Development

Successfully merging this pull request may close these issues.

Merge items().

2 participants