3D Multi-model segmentation algorithm
- First, clone this repository and create the environment.
conda create -n moe python=3.10.13
conda activate moe
Note: Before proceeding to the next step, make sure CUDA 11.8 is installed and its path has been added to your environment variables.
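The CUDA requirement above can also be checked from a script, which helps when setting up several machines. A minimal sketch, assuming a bash-compatible shell and that `nvcc -V` prints a line such as `Cuda compilation tools, release 11.8, V11.8.89`. `nvcc_release` and `require_cuda_release` are hypothetical helper names, not part of this repository:

```bash
# Hypothetical helper: print the CUDA release number (e.g. "11.8")
# found in `nvcc -V` output read from stdin.
nvcc_release() {
  sed -n 's/.*release \([0-9][0-9]*\.[0-9][0-9]*\).*/\1/p' | head -n 1
}

# Hypothetical helper: return 0 if the nvcc on PATH reports the expected
# release (default 11.8), otherwise print a message to stderr and return 1.
require_cuda_release() {
  local want="${1:-11.8}" got
  if ! command -v nvcc >/dev/null 2>&1; then
    echo "nvcc not found on PATH; add the CUDA bin directory to PATH" >&2
    return 1
  fi
  got="$(nvcc -V | nvcc_release)"
  if [ "$got" != "$want" ]; then
    echo "Found CUDA ${got:-unknown}, expected $want" >&2
    return 1
  fi
  echo "CUDA $got OK"
}
```

For example, `require_cuda_release 11.8 || exit 1` near the top of a setup script stops early when the wrong toolkit is on `PATH`.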
# Check CUDA version
nvcc -V  # Should report CUDA version 11.8
- Install dependencies and the mamba package.
# Install PyTorch with CUDA 11.8
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
# Alternatively, install PyTorch with CUDA 12.6 (run only one of the two commands)
pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126
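Only one of the two PyTorch commands should be run, matching the installed CUDA toolkit. A minimal sketch of picking between them, assuming a bash-compatible shell. `torch_install_cmd` is a hypothetical helper that prints (but does not run) the command, and it only covers the two combinations listed in this README:

```bash
# Hypothetical helper: print the PyTorch install command for a CUDA release.
# Only the CUDA 11.8 and 12.6 combinations from this README are covered.
torch_install_cmd() {
  case "$1" in
    11.8) echo "pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118" ;;
    12.6) echo "pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126" ;;
    *) echo "no PyTorch command listed for CUDA '$1'" >&2; return 1 ;;
  esac
}
```

For example, `eval "$(torch_install_cmd 11.8)"` runs the CUDA 11.8 command. Afterwards, `python -c "import torch; print(torch.__version__, torch.version.cuda)"` should report the CUDA version of the chosen wheel.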
# Install dependencies
pip install -r requirements.txt
Main directories and files are listed below:
.
├── configs # Training-related configurations
│ ├── datasets # Dataset configurations
│ ├── models # Model configurations
│ └── trainers # Training hyperparameters
├── networks # Comparison models
│ ├── ...
│ └── unetr.py
├── runs # Logs and checkpoints save root
├── scripts # Training/testing/parameter calculation scripts
│ ├── ...
│ ├── calc_params_flops.sh
│ ├── test_synapse.sh
│ └── train_synapse.sh
├── train.py
├── test.py
└── main.py
Training and testing are run with the provided scripts and configurations.
# Example command to run the script
sh scripts/train_synapse.sh
# or
sh scripts/test_synapse.sh
# Other training scripts
sh scripts/train_DLCBL.sh
sh scripts/train_auto_lym.sh
sh scripts/train_HL.sh
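To run several of these scripts back to back, a small wrapper can save each script's console output and stop at the first failure. A minimal sketch, assuming a bash-compatible shell; `run_training_scripts` and its log-directory argument are hypothetical. It only captures console output; checkpoints and logs that the scripts themselves write under `runs` are not touched by the wrapper:

```bash
# Hypothetical helper: run each script with `sh` in order, writing its
# stdout/stderr to LOGDIR/<script name>.log; return 1 at the first failure.
run_training_scripts() {
  local logdir="$1"; shift
  local script name
  mkdir -p "$logdir"
  for script in "$@"; do
    name="$(basename "$script" .sh)"
    echo "==> $script"
    if ! sh "$script" > "$logdir/$name.log" 2>&1; then
      echo "$script failed; see $logdir/$name.log" >&2
      return 1
    fi
  done
}
```

Usage: `run_training_scripts batch_logs scripts/train_DLCBL.sh scripts/train_auto_lym.sh scripts/train_HL.sh`, where `batch_logs` is an arbitrary directory name.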
For datasets with all labeled cases listed in dataset.json (training and/or validation), you can run one fold with:
python main.py --data_config Dataset800_HL --model_config maanet --trainer_config sgd_trainer_1_batch --fold 0
Options:
- --fold: fold index in [0, 4] (with the default 5 folds)
- --num_folds: number of folds, default 5
- --splits_file: default splits_final.json (auto-generated in the dataset folder if not found)
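Which cases land in which fold is decided by the split file. The name `splits_final.json` matches the file nnU-Net uses for its cross-validation splits, but this repository's exact split procedure is not documented here and may shuffle cases first. The sketch below only illustrates k-fold assignment under a simplified round-robin rule (case i goes to fold i mod num_folds), assuming a bash-compatible shell; `fold_split` is a hypothetical helper:

```bash
# Hypothetical helper: simplified round-robin k-fold split (no shuffling).
# Usage: fold_split FOLD NUM_FOLDS CASE...   (assumes NUM_FOLDS >= 1)
# Prints the validation cases on line 1 and the training cases on line 2.
fold_split() {
  local fold="$1" num_folds="$2"; shift 2
  case "$fold" in ''|*[!0-9]*) echo "fold must be a non-negative integer" >&2; return 1 ;; esac
  if [ "$fold" -ge "$num_folds" ]; then
    echo "fold must be in [0, $((num_folds - 1))]" >&2
    return 1
  fi
  local i=0 val="" train="" case_id
  for case_id in "$@"; do
    # Case i belongs to fold i % NUM_FOLDS; that fold uses it for validation.
    if [ $((i % num_folds)) -eq "$fold" ]; then
      val="$val $case_id"
    else
      train="$train $case_id"
    fi
    i=$((i + 1))
  done
  echo "${val# }"
  echo "${train# }"
}
```

Each case appears in exactly one validation fold, so the five folds together cover every labeled case once.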
Run all folds by looping in shell:
for fold in 0 1 2 3 4; do
python main.py --data_config Dataset800_HL --model_config maanet --trainer_config sgd_trainer_1_batch --fold "${fold}" --logdir hl_maanet_5fold
done
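The loop above trains the folds one after another. On a machine with several GPUs, folds can instead run in parallel, one per GPU, using the standard `CUDA_VISIBLE_DEVICES` variable. A minimal sketch, assuming a bash-compatible shell and that `main.py` trains on whichever GPU is visible to it; `run_folds_on_gpus` is a hypothetical helper that appends `--fold N` to the command it is given:

```bash
# Hypothetical helper: run folds 0..NUM_FOLDS-1 of a command, at most NGPUS
# at a time, pinning each fold to one GPU. Returns 1 if any fold fails.
# Usage: run_folds_on_gpus NGPUS NUM_FOLDS COMMAND [ARGS...]
run_folds_on_gpus() {
  local ngpus="$1" num_folds="$2"; shift 2
  local fold=0 status=0 pids="" pid
  while [ "$fold" -lt "$num_folds" ]; do
    # Fold f runs in the background on GPU f % NGPUS.
    CUDA_VISIBLE_DEVICES=$((fold % ngpus)) "$@" --fold "$fold" &
    pids="$pids $!"
    fold=$((fold + 1))
    # Once every GPU is busy (or no folds remain), wait for the whole batch.
    if [ $((fold % ngpus)) -eq 0 ] || [ "$fold" -eq "$num_folds" ]; then
      for pid in $pids; do
        wait "$pid" || status=1
      done
      pids=""
    fi
  done
  return "$status"
}
```

For example, with 2 GPUs: `run_folds_on_gpus 2 5 python main.py --data_config Dataset800_HL --model_config maanet --trainer_config sgd_trainer_1_batch --logdir hl_maanet_5fold`. Waiting for each batch keeps the sketch simple at the cost of some idle GPU time when folds finish unevenly.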
To make git pull rebase instead of merge:
- Current repository only: git config pull.rebase true
- Globally: git config --global pull.rebase true