Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions monai/networks/nets/autoencoderkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,7 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:

Args:
old_state_dict: state dict from the old AutoencoderKL model.
verbose: if True, print diagnostic information about key mismatches.
"""

new_state_dict = self.state_dict()
Expand Down Expand Up @@ -715,13 +716,39 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias")
new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")

# old version did not have a projection so set these to the identity
new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye(
new_state_dict[f"{block}.attn.out_proj.weight"].shape[0]
)
new_state_dict[f"{block}.attn.out_proj.bias"] = torch.zeros(
new_state_dict[f"{block}.attn.out_proj.bias"].shape
)
out_w = f"{block}.attn.out_proj.weight"
out_b = f"{block}.attn.out_proj.bias"
proj_w = f"{block}.proj_attn.weight"
proj_b = f"{block}.proj_attn.bias"

if out_w in new_state_dict:
if proj_w in old_state_dict:
new_state_dict[out_w] = old_state_dict.pop(proj_w)
if proj_b in old_state_dict:
new_state_dict[out_b] = old_state_dict.pop(proj_b)
else:
new_state_dict[out_b] = torch.zeros(
new_state_dict[out_b].shape,
dtype=new_state_dict[out_b].dtype,
device=new_state_dict[out_b].device,
)
else:
# No legacy proj_attn - initialize out_proj to identity/zero
new_state_dict[out_w] = torch.eye(
new_state_dict[out_w].shape[0],
dtype=new_state_dict[out_w].dtype,
device=new_state_dict[out_w].device,
)
new_state_dict[out_b] = torch.zeros(
new_state_dict[out_b].shape,
dtype=new_state_dict[out_b].dtype,
device=new_state_dict[out_b].device,
)
elif proj_w in old_state_dict:
# new model has no out_proj at all - discard the legacy keys so they
# don't surface as "unexpected keys" during load_state_dict
old_state_dict.pop(proj_w)
old_state_dict.pop(proj_b, None)

# fix the upsample conv blocks which were renamed postconv
for k in new_state_dict:
Expand Down
101 changes: 101 additions & 0 deletions tests/networks/nets/test_autoencoderkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,17 @@


class TestAutoEncoderKL(unittest.TestCase):
# Minimal AutoencoderKL configuration shared by the legacy state-dict
# migration tests below; tiny channel counts keep model construction fast.
_MIGRATION_PARAMS = {
    "spatial_dims": 2,  # 2D keeps the test lightweight
    "in_channels": 1,
    "out_channels": 1,
    "channels": (4, 4, 4),
    "latent_channels": 4,
    "attention_levels": (False, False, False),
    "num_res_blocks": 1,
    "norm_num_groups": 4,  # must divide the channel counts above
}

@parameterized.expand(CASES)
def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape):
net = AutoencoderKL(**input_param).to(device)
Expand Down Expand Up @@ -327,6 +338,96 @@ def test_compatibility_with_monai_generative(self):

net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False)

@staticmethod
def _new_to_old_sd(new_sd: dict, include_proj_attn: bool = True) -> dict:
"""Convert new-style state dict keys to legacy naming conventions.

Args:
new_sd: State dict with current key naming.
include_proj_attn: If True, map `.attn.out_proj.` to `.proj_attn.`.

Returns:
State dict with legacy key names.
"""
old_sd: dict = {}
for k, v in new_sd.items():
if ".attn.to_q." in k:
old_sd[k.replace(".attn.to_q.", ".to_q.")] = v.clone()
elif ".attn.to_k." in k:
old_sd[k.replace(".attn.to_k.", ".to_k.")] = v.clone()
elif ".attn.to_v." in k:
old_sd[k.replace(".attn.to_v.", ".to_v.")] = v.clone()
elif ".attn.out_proj." in k:
if include_proj_attn:
old_sd[k.replace(".attn.out_proj.", ".proj_attn.")] = v.clone()
elif "postconv" in k:
old_sd[k.replace("postconv", "conv")] = v.clone()
else:
old_sd[k] = v.clone()
return old_sd

@skipUnless(has_einops, "Requires einops")
def test_load_old_state_dict_proj_attn_copied_to_out_proj(self):
    """Legacy `proj_attn` weights must land verbatim in `attn.out_proj`."""
    params = {**self._MIGRATION_PARAMS, "include_fc": True}
    source_net = AutoencoderKL(**params).to(device)
    legacy_sd = self._new_to_old_sd(source_net.state_dict(), include_proj_attn=True)

    # map each legacy proj_attn entry to the new key it should populate
    expected = {}
    for key, tensor in legacy_sd.items():
        if ".proj_attn." in key:
            expected[key.replace(".proj_attn.", ".attn.out_proj.")] = tensor
    self.assertGreater(len(expected), 0, "No proj_attn keys in old state dict - check model config")

    target_net = AutoencoderKL(**params).to(device)
    target_net.load_old_state_dict(legacy_sd)
    migrated = target_net.state_dict()

    for new_key, expected_val in expected.items():
        torch.testing.assert_close(
            migrated[new_key], expected_val.to(device), msg=f"Weight mismatch for {new_key}"
        )

@skipUnless(has_einops, "Requires einops")
def test_load_old_state_dict_missing_proj_attn_initialises_identity(self):
    """Without legacy proj_attn keys, out_proj gets identity weight and zero bias."""
    params = {**self._MIGRATION_PARAMS, "include_fc": True}
    source_net = AutoencoderKL(**params).to(device)
    legacy_sd = self._new_to_old_sd(source_net.state_dict(), include_proj_attn=False)

    target_net = AutoencoderKL(**params).to(device)
    target_net.load_old_state_dict(legacy_sd)
    loaded = target_net.state_dict()

    weight_keys = [key for key in loaded if "attn.out_proj.weight" in key]
    bias_keys = [key for key in loaded if "attn.out_proj.bias" in key]
    self.assertGreater(len(weight_keys), 0, "No out_proj keys found - check model config")

    for key in weight_keys:
        identity = torch.eye(loaded[key].shape[0], dtype=loaded[key].dtype, device=device)
        torch.testing.assert_close(loaded[key], identity, msg=f"{key} should be an identity matrix")
    for key in bias_keys:
        torch.testing.assert_close(loaded[key], torch.zeros_like(loaded[key]), msg=f"{key} should be all-zeros")

@skipUnless(has_einops, "Requires einops")
def test_load_old_state_dict_proj_attn_discarded_when_no_out_proj(self):
    """Legacy proj_attn keys are silently dropped when the new model has no out_proj."""
    params = {**self._MIGRATION_PARAMS, "include_fc": False}
    source_net = AutoencoderKL(**params).to(device)
    legacy_sd = self._new_to_old_sd(source_net.state_dict(), include_proj_attn=False)

    # inject synthetic proj_attn keys (mimic an old checkpoint)
    suffix = ".to_q.weight"
    attn_blocks = [key.removesuffix(suffix) for key in legacy_sd if key.endswith(suffix)]
    self.assertGreater(len(attn_blocks), 0, "No attention blocks found - check model config")
    for block in attn_blocks:
        width = legacy_sd[f"{block}.to_q.weight"].shape[0]
        legacy_sd[f"{block}.proj_attn.weight"] = torch.randn(width, width)
        legacy_sd[f"{block}.proj_attn.bias"] = torch.randn(width)

    target_net = AutoencoderKL(**params).to(device)
    target_net.load_old_state_dict(legacy_sd)

    self.assertFalse(
        any("out_proj" in key for key in target_net.state_dict()),
        "out_proj should not exist in a model built with include_fc=False",
    )


if __name__ == "__main__":
    # Allow running this test module directly (outside a pytest/CI runner).
    unittest.main()
Loading