3 changes: 3 additions & 0 deletions README.md
@@ -255,6 +255,9 @@ After installation completes, run the training script.
- In Wan2.1, the ici_fsdp_parallelism axis is used for sequence parallelism, and the ici_tensor_parallelism axis is used for head parallelism.
- You can enable both, keeping in mind that Wan2.1 has 40 attention heads, and 40 must be evenly divisible by ici_tensor_parallelism.
- For sequence parallelism, the code pads the sequence so its length divides evenly across devices. Try different ici_fsdp_parallelism values; we currently find 2 and 4 to work best.
- On GPU, enabling the cudnn_te_flash attention kernel is recommended for optimal performance.
- Best performance is achieved with batch parallelism, which can be enabled via the ici_fsdp_batch_parallelism axis. Note that this strategy does not support fractional batch sizes.
- ici_fsdp_batch_parallelism and ici_fsdp_parallelism can be combined to allow fractional batch sizes. However, padding is not currently supported by the cudnn_te_flash attention kernel, so the sequence length must be divisible by the number of devices on the ici_fsdp_parallelism axis (see the sketch after this list).
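
A minimal sketch of those divisibility constraints with illustrative values. The names below mirror the config keys discussed above, but this is standalone Python arithmetic, not config parsing, and the numbers are examples only:

```python
# Illustrative checks for the notes above; values are examples, not defaults.
num_heads = 40                  # Wan2.1 self-attention heads
ici_tensor_parallelism = 4      # head parallelism degree
ici_fsdp_parallelism = 2        # sequence parallelism degree
ici_fsdp_batch_parallelism = 2  # batch parallelism degree
global_batch_size = 8
seq_len = 75600                 # example latent sequence length

# Head parallelism: the 40 heads must divide evenly across the tensor axis.
assert num_heads % ici_tensor_parallelism == 0

# Batch parallelism: fractional per-device batch sizes are not supported.
assert global_batch_size % ici_fsdp_batch_parallelism == 0

# cudnn_te_flash does not support padding, so the sequence length must
# divide evenly across the sequence-parallel devices.
assert seq_len % ici_fsdp_parallelism == 0
```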

You should eventually see a training run as:

13 changes: 7 additions & 6 deletions src/maxdiffusion/common_types.py
@@ -36,6 +36,7 @@
# Physical axis names for device meshes.
DATA = "data"
FSDP = "fsdp"
CONTEXT = "context"
TENSOR = "tensor"
# Logical axis names for model parameters and activations.
BATCH = "activation_batch"
@@ -67,18 +68,18 @@
### Common axis rules for ring attention ###
RING_ATTENTION_AXIS_RULES = [
[SELF_ATTN_HEAD, None],
[SELF_ATTN_Q_LENGTH, FSDP],
[SELF_ATTN_KV_LENGTH, FSDP],
[SELF_ATTN_Q_LENGTH, CONTEXT],
[SELF_ATTN_KV_LENGTH, CONTEXT],
[CROSS_ATTN_HEAD, None],
[CROSS_ATTN_Q_LENGTH, FSDP],
[CROSS_ATTN_KV_LENGTH, FSDP],
[CROSS_ATTN_Q_LENGTH, CONTEXT],
[CROSS_ATTN_KV_LENGTH, CONTEXT],
]

SEQUENCE_PARALLEL_AXIS_RULES = [
[SELF_ATTN_HEAD, None],
[SELF_ATTN_Q_LENGTH, FSDP],
[SELF_ATTN_Q_LENGTH, CONTEXT],
[SELF_ATTN_KV_LENGTH, None],
[CROSS_ATTN_HEAD, None],
[CROSS_ATTN_Q_LENGTH, FSDP],
[CROSS_ATTN_Q_LENGTH, CONTEXT],
[CROSS_ATTN_KV_LENGTH, None],
]
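
These rule lists only pair a logical activation axis with a physical mesh axis; the actual resolution happens elsewhere. As a rough standalone illustration (plain JAX, not maxdiffusion's sharding code), a rule such as `[SELF_ATTN_Q_LENGTH, CONTEXT]` amounts to splitting the query-length dimension of attention activations across the `context` mesh axis defined in the configs below; the shapes and mesh layout here are assumptions:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Four-axis mesh matching mesh_axes = ['data', 'fsdp', 'context', 'tensor'],
# with every available device placed on the 'context' axis for simplicity.
devices = np.array(jax.devices()).reshape(1, 1, -1, 1)
mesh = Mesh(devices, ("data", "fsdp", "context", "tensor"))

# Query activations shaped (batch, heads, q_length, head_dim): q_length is
# split across 'context'; heads and head_dim stay replicated.
q = np.zeros((2, 40, 1024, 128), dtype=np.float32)
q = jax.device_put(q, NamedSharding(mesh, P("data", None, "context", None)))
print(q.sharding)
```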
26 changes: 14 additions & 12 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -151,7 +151,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
mesh_axes: ['data', 'fsdp', 'context', 'tensor']

# batch : batch dimension of data and activations
# hidden :
@@ -166,31 +166,33 @@ mesh_axes: ['data', 'fsdp', 'tensor']
# conv_in : conv.shape[2] weight
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
['batch', 'data'],
['activation_batch', 'data'],
['activation_self_attn_heads', ['fsdp', 'tensor']],
['activation_cross_attn_q_length', ['fsdp', 'tensor']],
['activation_length', 'fsdp'],
['batch', ['data', 'fsdp']],
['activation_batch', ['data', 'fsdp']],
['activation_length', 'context'],
['activation_self_attn_heads', ['context', 'tensor']],
['activation_cross_attn_q_length', ['context', 'tensor']],
['activation_heads', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
['embed', ['context', 'fsdp']],
['heads', 'tensor'],
['norm', 'tensor'],
['conv_batch', ['data','fsdp']],
['conv_batch', ['data', 'context', 'fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
['conv_out', 'context'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
data_sharding: [['data', 'fsdp', 'context', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_fsdp_parallelism: 1
dcn_context_parallelism: -1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_fsdp_parallelism: 1
ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False
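
As the comment block in both configs says, the DCN axis sizes should multiply out to the number of slices and the ICI axis sizes to the number of devices per slice, with at most one -1 per group filled in automatically. A minimal sketch of that rule as described in the comment (an illustration only, not the actual fill_unspecified_mesh_axes implementation shown in max_utils.py below):

```python
import math

def fill_axes(axes, target, label):
  """Fill a single -1 entry so the product of axis sizes equals target."""
  if axes.count(-1) > 1:
    raise ValueError(f"only one {label} axis may be -1")
  known = math.prod(a for a in axes if a != -1)
  if -1 in axes:
    if target % known != 0:
      raise ValueError(f"{label} axes {axes} cannot be completed to {target}")
    axes = [target // known if a == -1 else a for a in axes]
  if math.prod(axes) != target:
    raise ValueError(f"product of {label} axes {axes} must equal {target}")
  return axes

# Defaults above on a single 4-device slice: the context axis absorbs all devices.
print(fill_axes([1, 1, -1, 1], target=4, label="ICI"))  # -> [1, 1, 4, 1]
print(fill_axes([1, 1, -1, 1], target=1, label="DCN"))  # single slice -> [1, 1, 1, 1]
```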
25 changes: 14 additions & 11 deletions src/maxdiffusion/configs/base_wan_27b.yml
@@ -139,7 +139,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
mesh_axes: ['data', 'fsdp', 'context', 'tensor']

# batch : batch dimension of data and activations
# hidden :
@@ -154,30 +154,33 @@ mesh_axes: ['data', 'fsdp', 'tensor']
# conv_in : conv.shape[2] weight
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
['batch', 'data'],
['activation_batch', 'data'],
['activation_length', 'fsdp'],

['batch', ['data', 'fsdp']],
['activation_batch', ['data', 'fsdp']],
['activation_length', 'context'],
['activation_self_attn_heads', ['context', 'tensor']],
['activation_cross_attn_q_length', ['context', 'tensor']],
['activation_heads', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
['embed', ['context', 'fsdp']],
['heads', 'tensor'],
['norm', 'tensor'],
['conv_batch', ['data','fsdp']],
['conv_batch', ['data', 'context', 'fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
['conv_out', 'context'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
data_sharding: [['data', 'fsdp', 'context', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_fsdp_parallelism: 1
dcn_context_parallelism: -1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_fsdp_parallelism: 1
ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False
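
The `data_sharding: [['data', 'fsdp', 'context', 'tensor']]` entry in both configs shards the leading (batch) dimension of each input batch across all four mesh axes jointly. A small standalone JAX illustration of that joint split, with made-up shapes rather than the actual input pipeline:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()).reshape(1, 1, -1, 1)  # data, fsdp, context, tensor
mesh = Mesh(devices, ("data", "fsdp", "context", "tensor"))

# A tuple of mesh axes in one PartitionSpec slot splits that dimension over
# all of them at once, mirroring data_sharding above.
spec = P(("data", "fsdp", "context", "tensor"))
batch = jax.device_put(np.zeros((8, 1024), dtype=np.float32), NamedSharding(mesh, spec))
print(batch.sharding)
```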
5 changes: 4 additions & 1 deletion src/maxdiffusion/generate_wan.py
@@ -248,7 +248,10 @@ def run(config, pipeline=None, filename_prefix=""):

def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
flax.config.update("flax_always_shard_variable", False)
try:
  flax.config.update("flax_always_shard_variable", False)
except Exception:  # the flag may not exist in the installed flax version
  pass
run(pyconfig.config)


35 changes: 24 additions & 11 deletions src/maxdiffusion/max_utils.py
@@ -268,17 +268,30 @@ def create_device_mesh(config, devices=None, logging=True):
max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")

multi_slice_env = num_slices > 1

dcn_parallelism = [
    config.dcn_data_parallelism,
    config.dcn_fsdp_parallelism,
    config.dcn_tensor_parallelism,
]
ici_parallelism = [
    config.ici_data_parallelism,
    config.ici_fsdp_parallelism,
    config.ici_tensor_parallelism,
]
if "dcn_context_parallelism" in config.get_keys() and "ici_context_parallelism" in config.get_keys():
  dcn_parallelism = [
      config.dcn_data_parallelism,
      config.dcn_fsdp_parallelism,
      config.dcn_context_parallelism,
      config.dcn_tensor_parallelism,
  ]
  ici_parallelism = [
      config.ici_data_parallelism,
      config.ici_fsdp_parallelism,
      config.ici_context_parallelism,
      config.ici_tensor_parallelism,
  ]
else:
  dcn_parallelism = [
      config.dcn_data_parallelism,
      config.dcn_fsdp_parallelism,
      config.dcn_tensor_parallelism,
  ]
  ici_parallelism = [
      config.ici_data_parallelism,
      config.ici_fsdp_parallelism,
      config.ici_tensor_parallelism,
  ]

# Find possible unspecified parallelisms
ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
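
Once the per-axis sizes (with any -1 filled in) are resolved, they are turned into a device mesh. A single-slice sketch of that final step using JAX's mesh utilities, as an illustration rather than the multi-slice logic in create_device_mesh:

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Resolved ICI axis sizes for (data, fsdp, context, tensor); multi-slice
# setups would additionally use the DCN sizes.
ici_parallelism = [1, 1, len(jax.devices()), 1]  # e.g. [1, 1, 4, 1] on 4 devices
device_array = mesh_utils.create_device_mesh(ici_parallelism, devices=jax.devices())
mesh = Mesh(device_array, ("data", "fsdp", "context", "tensor"))
print(mesh.shape)  # e.g. {'data': 1, 'fsdp': 1, 'context': 4, 'tensor': 1}
```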
Expand Down
Loading