From f85f109b1aca04059148e64082864fe2429abbd8 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Fri, 20 Mar 2026 00:19:28 +0800 Subject: [PATCH 1/6] =?UTF-8?q?fix(autoencoderkl):=20handle=20proj=5Fattn?= =?UTF-8?q?=E2=86=92out=5Fproj=20key=20mapping=20in=20load=5Fold=5Fstate?= =?UTF-8?q?=5Fdict?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: ytl0623 --- monai/networks/nets/autoencoderkl.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index b5a282a340..3d60370401 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -715,13 +715,25 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias") new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias") - # old version did not have a projection so set these to the identity - new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye( - new_state_dict[f"{block}.attn.out_proj.weight"].shape[0] - ) - new_state_dict[f"{block}.attn.out_proj.bias"] = torch.zeros( - new_state_dict[f"{block}.attn.out_proj.bias"].shape - ) + out_w = f"{block}.attn.out_proj.weight" + out_b = f"{block}.attn.out_proj.bias" + proj_w = f"{block}.proj_attn.weight" + proj_b = f"{block}.proj_attn.bias" + + if out_w in new_state_dict: + if proj_w in old_state_dict: + new_state_dict[out_w] = old_state_dict.pop(proj_w) + if proj_b in old_state_dict: + new_state_dict[out_b] = old_state_dict.pop(proj_b) + else: + # weights pre-date proj_attn: initialise to identity / zero + new_state_dict[out_w] = torch.eye(new_state_dict[out_w].shape[0]) + new_state_dict[out_b] = torch.zeros(new_state_dict[out_b].shape) + elif proj_w in old_state_dict: + # new model has no out_proj at all – discard the legacy keys so they + # don't surface as "unexpected keys" during load_state_dict + old_state_dict.pop(proj_w) + old_state_dict.pop(proj_b) # fix the upsample conv blocks which were renamed postconv for k in new_state_dict: From 5db2a70a0d91670a70f447ac7a16bbdf6e5e3709 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Fri, 20 Mar 2026 09:25:04 +0800 Subject: [PATCH 2/6] add unit test Signed-off-by: ytl0623 --- monai/networks/nets/autoencoderkl.py | 15 ++-- tests/networks/nets/test_autoencoderkl.py | 90 +++++++++++++++++++++++ 2 files changed, 99 insertions(+), 6 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 3d60370401..372493063c 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -680,6 +680,7 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: Args: old_state_dict: state dict from the old AutoencoderKL model. + verbose: if True, print diagnostic information about key mismatches. """ new_state_dict = self.state_dict() @@ -725,15 +726,17 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: new_state_dict[out_w] = old_state_dict.pop(proj_w) if proj_b in old_state_dict: new_state_dict[out_b] = old_state_dict.pop(proj_b) - else: - # weights pre-date proj_attn: initialise to identity / zero - new_state_dict[out_w] = torch.eye(new_state_dict[out_w].shape[0]) - new_state_dict[out_b] = torch.zeros(new_state_dict[out_b].shape) + else: + new_state_dict[out_b] = torch.zeros( + new_state_dict[out_b].shape, + dtype=new_state_dict[out_b].dtype, + device=new_state_dict[out_b].device, + ) elif proj_w in old_state_dict: - # new model has no out_proj at all – discard the legacy keys so they + # new model has no out_proj at all - discard the legacy keys so they # don't surface as "unexpected keys" during load_state_dict old_state_dict.pop(proj_w) - old_state_dict.pop(proj_b) + old_state_dict.pop(proj_b, None) # fix the upsample conv blocks which were renamed postconv for k in new_state_dict: diff --git a/tests/networks/nets/test_autoencoderkl.py b/tests/networks/nets/test_autoencoderkl.py index bbe2840164..c0cbfcf021 100644 --- a/tests/networks/nets/test_autoencoderkl.py +++ b/tests/networks/nets/test_autoencoderkl.py @@ -327,6 +327,96 @@ def test_compatibility_with_monai_generative(self): net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False) + @staticmethod + def _new_to_old_sd(new_sd: dict, include_proj_attn: bool = True) -> dict: + old_sd: dict = {} + for k, v in new_sd.items(): + if ".attn.to_q." in k: + old_sd[k.replace(".attn.to_q.", ".to_q.")] = v.clone() + elif ".attn.to_k." in k: + old_sd[k.replace(".attn.to_k.", ".to_k.")] = v.clone() + elif ".attn.to_v." in k: + old_sd[k.replace(".attn.to_v.", ".to_v.")] = v.clone() + elif ".attn.out_proj." in k: + if include_proj_attn: + old_sd[k.replace(".attn.out_proj.", ".proj_attn.")] = v.clone() + elif "postconv" in k: + old_sd[k.replace("postconv", "conv")] = v.clone() + else: + old_sd[k] = v.clone() + return old_sd + + @skipUnless(has_einops, "Requires einops") + def test_load_old_state_dict_proj_attn_copied_to_out_proj(self): + params = {**self._MIGRATION_PARAMS, "include_fc": True} + src = AutoencoderKL(**params).to(device) + old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=True) + + # record the tensor values that were stored under proj_attn + expected = {k.replace(".proj_attn.", ".attn.out_proj."): v for k, v in old_sd.items() if ".proj_attn." in k} + self.assertGreater(len(expected), 0, "No proj_attn keys in old state dict – check model config") + + dst = AutoencoderKL(**params).to(device) + dst.load_old_state_dict(old_sd) + + for new_key, expected_val in expected.items(): + torch.testing.assert_close( + dst.state_dict()[new_key], + expected_val.to(device), + msg=f"Weight mismatch for {new_key}", + ) + + @skipUnless(has_einops, "Requires einops") + def test_load_old_state_dict_missing_proj_attn_initialises_identity(self): + params = {**self._MIGRATION_PARAMS, "include_fc": True} + src = AutoencoderKL(**params).to(device) + old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=False) + + dst = AutoencoderKL(**params).to(device) + dst.load_old_state_dict(old_sd) + loaded = dst.state_dict() + + out_proj_weights = [k for k in loaded if "attn.out_proj.weight" in k] + out_proj_biases = [k for k in loaded if "attn.out_proj.bias" in k] + self.assertGreater(len(out_proj_weights), 0, "No out_proj keys found – check model config") + + for k in out_proj_weights: + n = loaded[k].shape[0] + torch.testing.assert_close( + loaded[k], + torch.eye(n, device=device), + msg=f"{k} should be an identity matrix", + ) + for k in out_proj_biases: + torch.testing.assert_close( + loaded[k], + torch.zeros_like(loaded[k]), + msg=f"{k} should be all-zeros", + ) + + @skipUnless(has_einops, "Requires einops") + def test_load_old_state_dict_proj_attn_discarded_when_no_out_proj(self): + params = {**self._MIGRATION_PARAMS, "include_fc": False} + src = AutoencoderKL(**params).to(device) + old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=False) + + # inject synthetic proj_attn keys (mimic an old checkpoint) + attn_blocks = [k.replace(".to_q.weight", "") for k in old_sd if k.endswith(".to_q.weight")] + self.assertGreater(len(attn_blocks), 0, "No attention blocks found – check model config") + for block in attn_blocks: + ch = old_sd[f"{block}.to_q.weight"].shape[0] + old_sd[f"{block}.proj_attn.weight"] = torch.randn(ch, ch) + old_sd[f"{block}.proj_attn.bias"] = torch.randn(ch) + + dst = AutoencoderKL(**params).to(device) + dst.load_old_state_dict(old_sd) + + loaded = dst.state_dict() + self.assertFalse( + any("out_proj" in k for k in loaded), + "out_proj should not exist in a model built with include_fc=False", + ) + if __name__ == "__main__": unittest.main() From 703436fa60c3e7eba2a39bea4dfc2d15e27e96e7 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Fri, 20 Mar 2026 09:48:51 +0800 Subject: [PATCH 3/6] add missing class attribute Signed-off-by: ytl0623 --- tests/networks/nets/test_autoencoderkl.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/networks/nets/test_autoencoderkl.py b/tests/networks/nets/test_autoencoderkl.py index c0cbfcf021..8fd69077eb 100644 --- a/tests/networks/nets/test_autoencoderkl.py +++ b/tests/networks/nets/test_autoencoderkl.py @@ -169,6 +169,17 @@ class TestAutoEncoderKL(unittest.TestCase): + _MIGRATION_PARAMS = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": 1, + "norm_num_groups": 4, + } + @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape): net = AutoencoderKL(**input_param).to(device) @@ -329,6 +340,15 @@ def test_compatibility_with_monai_generative(self): @staticmethod def _new_to_old_sd(new_sd: dict, include_proj_attn: bool = True) -> dict: + """Convert new-style state dict keys to legacy naming conventions. + + Args: + new_sd: State dict with current key naming. + include_proj_attn: If True, map `.attn.out_proj.` to `.proj_attn.`. + + Returns: + State dict with legacy key names. + """ old_sd: dict = {} for k, v in new_sd.items(): if ".attn.to_q." in k: @@ -354,7 +374,7 @@ def test_load_old_state_dict_proj_attn_copied_to_out_proj(self): # record the tensor values that were stored under proj_attn expected = {k.replace(".proj_attn.", ".attn.out_proj."): v for k, v in old_sd.items() if ".proj_attn." in k} - self.assertGreater(len(expected), 0, "No proj_attn keys in old state dict – check model config") + self.assertGreater(len(expected), 0, "No proj_attn keys in old state dict - check model config") dst = AutoencoderKL(**params).to(device) dst.load_old_state_dict(old_sd) From 7ecd0198d2f8d320d2adf8c73badfba095b3fa77 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Fri, 20 Mar 2026 09:57:56 +0800 Subject: [PATCH 4/6] add identity initialization Signed-off-by: ytl0623 --- monai/networks/nets/autoencoderkl.py | 12 ++++++++++++ tests/networks/nets/test_autoencoderkl.py | 4 ++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 372493063c..fd7d1b854a 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -732,6 +732,18 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: dtype=new_state_dict[out_b].dtype, device=new_state_dict[out_b].device, ) + else: + # No legacy proj_attn – initialize out_proj to identity/zero + new_state_dict[out_w] = torch.eye( + new_state_dict[out_w].shape[0], + dtype=new_state_dict[out_w].dtype, + device=new_state_dict[out_w].device, + ) + new_state_dict[out_b] = torch.zeros( + new_state_dict[out_b].shape, + dtype=new_state_dict[out_b].dtype, + device=new_state_dict[out_b].device, + ) elif proj_w in old_state_dict: # new model has no out_proj at all - discard the legacy keys so they # don't surface as "unexpected keys" during load_state_dict diff --git a/tests/networks/nets/test_autoencoderkl.py b/tests/networks/nets/test_autoencoderkl.py index 8fd69077eb..d955efe88d 100644 --- a/tests/networks/nets/test_autoencoderkl.py +++ b/tests/networks/nets/test_autoencoderkl.py @@ -398,7 +398,7 @@ def test_load_old_state_dict_missing_proj_attn_initialises_identity(self): out_proj_weights = [k for k in loaded if "attn.out_proj.weight" in k] out_proj_biases = [k for k in loaded if "attn.out_proj.bias" in k] - self.assertGreater(len(out_proj_weights), 0, "No out_proj keys found – check model config") + self.assertGreater(len(out_proj_weights), 0, "No out_proj keys found - check model config") for k in out_proj_weights: n = loaded[k].shape[0] @@ -422,7 +422,7 @@ def test_load_old_state_dict_proj_attn_discarded_when_no_out_proj(self): # inject synthetic proj_attn keys (mimic an old checkpoint) attn_blocks = [k.replace(".to_q.weight", "") for k in old_sd if k.endswith(".to_q.weight")] - self.assertGreater(len(attn_blocks), 0, "No attention blocks found – check model config") + self.assertGreater(len(attn_blocks), 0, "No attention blocks found - check model config") for block in attn_blocks: ch = old_sd[f"{block}.to_q.weight"].shape[0] old_sd[f"{block}.proj_attn.weight"] = torch.randn(ch, ch) From 89abbf1a9d0214e4cbe2cac0d14f4d6d8591d694 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Fri, 20 Mar 2026 14:31:18 +0800 Subject: [PATCH 5/6] minor fixes Signed-off-by: ytl0623 --- monai/networks/nets/autoencoderkl.py | 2 +- tests/networks/nets/test_autoencoderkl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index fd7d1b854a..11b4fcfc9e 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -733,7 +733,7 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: device=new_state_dict[out_b].device, ) else: - # No legacy proj_attn – initialize out_proj to identity/zero + # No legacy proj_attn - initialize out_proj to identity/zero new_state_dict[out_w] = torch.eye( new_state_dict[out_w].shape[0], dtype=new_state_dict[out_w].dtype, diff --git a/tests/networks/nets/test_autoencoderkl.py b/tests/networks/nets/test_autoencoderkl.py index d955efe88d..8940b670ec 100644 --- a/tests/networks/nets/test_autoencoderkl.py +++ b/tests/networks/nets/test_autoencoderkl.py @@ -404,7 +404,7 @@ def test_load_old_state_dict_missing_proj_attn_initialises_identity(self): n = loaded[k].shape[0] torch.testing.assert_close( loaded[k], - torch.eye(n, device=device), + torch.eye(n, dtype=loaded[k].dtype, device=device), msg=f"{k} should be an identity matrix", ) for k in out_proj_biases: From 1cf6ff4b512d6873f91ffef7f471ae453f0f9e0d Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Fri, 20 Mar 2026 14:57:23 +0800 Subject: [PATCH 6/6] reformatted Signed-off-by: ytl0623 --- tests/networks/nets/test_autoencoderkl.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/tests/networks/nets/test_autoencoderkl.py b/tests/networks/nets/test_autoencoderkl.py index 8940b670ec..af0c55d6ec 100644 --- a/tests/networks/nets/test_autoencoderkl.py +++ b/tests/networks/nets/test_autoencoderkl.py @@ -381,9 +381,7 @@ def test_load_old_state_dict_proj_attn_copied_to_out_proj(self): for new_key, expected_val in expected.items(): torch.testing.assert_close( - dst.state_dict()[new_key], - expected_val.to(device), - msg=f"Weight mismatch for {new_key}", + dst.state_dict()[new_key], expected_val.to(device), msg=f"Weight mismatch for {new_key}" ) @skipUnless(has_einops, "Requires einops") @@ -403,16 +401,10 @@ def test_load_old_state_dict_missing_proj_attn_initialises_identity(self): for k in out_proj_weights: n = loaded[k].shape[0] torch.testing.assert_close( - loaded[k], - torch.eye(n, dtype=loaded[k].dtype, device=device), - msg=f"{k} should be an identity matrix", + loaded[k], torch.eye(n, dtype=loaded[k].dtype, device=device), msg=f"{k} should be an identity matrix" ) for k in out_proj_biases: - torch.testing.assert_close( - loaded[k], - torch.zeros_like(loaded[k]), - msg=f"{k} should be all-zeros", - ) + torch.testing.assert_close(loaded[k], torch.zeros_like(loaded[k]), msg=f"{k} should be all-zeros") @skipUnless(has_einops, "Requires einops") def test_load_old_state_dict_proj_attn_discarded_when_no_out_proj(self): @@ -433,8 +425,7 @@ def test_load_old_state_dict_proj_attn_discarded_when_no_out_proj(self): loaded = dst.state_dict() self.assertFalse( - any("out_proj" in k for k in loaded), - "out_proj should not exist in a model built with include_fc=False", + any("out_proj" in k for k in loaded), "out_proj should not exist in a model built with include_fc=False" )