@@ -4,7 +4,25 @@
 LOAD_PRETRAINED_MODEL = False
 SAVE_MODEL = False
 
-MODEL_PATH = r'C:\Users\Saranga\Desktop\ML4SCI\All MODELS\Classification\Model_I\model1_EfficientNetB2Backbone.pth'
-TRAIN_DATA_PATH = r'C:\Users\Saranga\Desktop\ML4SCI\Work\Model_I_subset\*\*'
-TEST_DATA_PATH = r'C:\Users\Saranga\Desktop\ML4SCI\Work\Model_I_test\*\*'
+import os
+
+# Base directory of Model_I (the directory containing this file)
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+
+# Expected dataset structure:
+# data/
+#     Model_I/
+#         axion/
+#         cdm/
+#         no_sub/
+#     Model_I_test/
+#         axion/
+#         cdm/
+#         no_sub/
+
+DATA_DIR = os.path.join(BASE_DIR, "data")
+
+TRAIN_DATA_PATH = os.path.join(DATA_DIR, "Model_I")
+TEST_DATA_PATH = os.path.join(DATA_DIR, "Model_I_test")
+
+MODEL_PATH = os.path.join(BASE_DIR, "checkpoints", "model.pth")
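
One practical note on the new layout: MODEL_PATH now points into a checkpoints/ directory that will not exist on a fresh clone, so the directory should be created before the first save. A minimal sketch of a caller, assuming the config above is importable as config (save_checkpoint is a hypothetical helper, not part of this PR):

import os

import torch

from config import MODEL_PATH


def save_checkpoint(model):
    # Hypothetical helper: create BASE_DIR/checkpoints on first use,
    # then write the model weights to the configured path.
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    torch.save(model.state_dict(), MODEL_PATH)
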
@@ -1,109 +1,142 @@
 import os
 import glob
 import numpy as np
+import matplotlib.pyplot as plt
 
 import torch
+from torchvision import utils
 from torch.utils.data import DataLoader, Dataset, random_split
-from torchvision import utils
-import matplotlib.pyplot as plt

 from config import BATCH_SIZE, TRAIN_DATA_PATH, TEST_DATA_PATH
 from utils import transforms
 
 
 class CustomDataset(Dataset):
-    def __init__(self, root_dir, transform = None):
-        root_list = glob.glob(root_dir)
-        self.class_map = {}
-        self.class_distribution = {}
-
-        for img_path in root_list:
-            class_name = img_path.split(os.sep)[-2]
-            if class_name not in self.class_distribution:
-                self.class_distribution[class_name] = 1
-            else:
-                self.class_distribution[class_name] +=1
-
-        for index, entity in enumerate(self.class_distribution):
-            self.class_map[entity] = index
-
-        # print("\nDataset Distribution:")
-        # print(self.class_distribution)
-        # print("\nClass indices:")
-        # print(self.class_map)
-
-        self.data = []
-        for img_path in root_list:
-            class_name = img_path.split(os.sep)[-2]
-            self.data.append([img_path, class_name])
+    """
+    Custom dataset for loading .npy lensing image files.
+    Expected structure:
+
+    root_dir/
+        class_name_1/
+            sample1.npy
+            sample2.npy
+        class_name_2/
+            sample1.npy
+        ...
+    """
+
+    def __init__(self, root_dir, transform=None):
+        self.root_dir = root_dir
+        self.transform = transform
+
+        # Collect all .npy files inside class folders
+        root_list = glob.glob(os.path.join(root_dir, "*", "*.npy"))
+
+        # Guard against an empty dataset
+        if len(root_list) == 0:
+            raise ValueError(
+                f"No .npy files found in {root_dir}. "
+                "Expected structure: root/class_name/*.npy"
+            )
+
+        # Build class distribution
+        self.class_distribution = {}
+        for img_path in root_list:
+            class_name = os.path.basename(os.path.dirname(img_path))
+            self.class_distribution[class_name] = \
+                self.class_distribution.get(class_name, 0) + 1
+
+        # Deterministic class mapping (important for reproducibility)
+        self.class_map = {
+            cls_name: idx
+            for idx, cls_name in enumerate(sorted(self.class_distribution.keys()))
+        }
+
+        # Store dataset entries as (path, class_name) pairs
+        self.data = [
+            (img_path, os.path.basename(os.path.dirname(img_path)))
+            for img_path in root_list
+        ]

     def __len__(self):
         return len(self.data)
 
     def __getitem__(self, idx):
         img_path, class_name = self.data[idx]
 
-        img = np.load(img_path, allow_pickle = True)
-        if class_name == 'axion':
+        # Load the numpy array from disk
+        img = np.load(img_path, allow_pickle=True)
+
+        # Special handling for the axion class
+        if class_name == "axion":
             img = img[0]
 
+        # Apply transforms if provided
         if self.transform:
-            aug = self.transform(image = img)
-            img = aug['image']
+            aug = self.transform(image=img)
+            img = aug["image"]
 
-        img = img.to(torch.float)
-        class_id = self.class_map[class_name]
-        class_id = torch.tensor(class_id)
+        # torch.as_tensor is a no-op on tensors, so this also handles the
+        # untransformed case where img is still a numpy array
+        img = torch.as_tensor(img).to(torch.float32)
+        class_id = torch.tensor(self.class_map[class_name], dtype=torch.long)
 
         return img, class_id


-def create_data_loaders(train_data_path, test_data_path, val_split = 0.1, batch_size = 128, transforms = None, class_map = False):
-
-    dataset = CustomDataset(train_data_path, transform = transforms)
-    m = len(dataset)
-    print("\nTotal training data: " + str(m))
-    try:
-        train_set,val_set=random_split(dataset,[int(m-m*val_split),int(m*val_split)])
-    except:
-        train_set,val_set=random_split(dataset,[int(m-m*val_split),int(m*val_split+1)])
-
-    test_set = CustomDataset(test_data_path, transform = transforms)
-
-    print(f"\n Number of training set examples: {len(train_set)} \n\
-    Number of validation set examples: {len(val_set)} \n\
-    Number of test set examples: {len(test_set)}")
-
-    train_loader = DataLoader(train_set, batch_size = batch_size, shuffle = True)
-    val_loader = DataLoader(val_set, batch_size = batch_size, shuffle = False)
-    test_loader = DataLoader(test_set, batch_size = batch_size, shuffle = False)
+def create_data_loaders(
+    train_data_path,
+    test_data_path,
+    val_split=0.1,
+    batch_size=128,
+    transforms=None,
+    class_map=False,
+):
+    """
+    Create train, validation, and test dataloaders.
+    """
+
+    dataset = CustomDataset(train_data_path, transform=transforms)
+    total_size = len(dataset)
+
+    print(f"\nTotal training samples: {total_size}")
+
+    # Split train/validation; computing train_size by subtraction guarantees
+    # the two sizes sum to total_size, so no try/except fallback is needed
+    val_size = int(total_size * val_split)
+    train_size = total_size - val_size
+
+    train_set, val_set = random_split(dataset, [train_size, val_size])
+
+    test_set = CustomDataset(test_data_path, transform=transforms)
+
+    print(
+        f"\nNumber of training samples: {len(train_set)}"
+        f"\nNumber of validation samples: {len(val_set)}"
+        f"\nNumber of test samples: {len(test_set)}"
+    )
+
+    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
+    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
+    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
 
     if class_map:
         return train_loader, val_loader, test_loader, dataset.class_map
 
     return train_loader, val_loader, test_loader


 if __name__ == "__main__":
 
-    train_loader, val_loader, test_loader = create_data_loaders(TRAIN_DATA_PATH, TEST_DATA_PATH,
-                                                                val_split = 0.2, batch_size = BATCH_SIZE,
-                                                                transforms = transforms)
+    # Quick sanity check
+    train_loader, val_loader, test_loader = create_data_loaders(
+        TRAIN_DATA_PATH,
+        TEST_DATA_PATH,
+        val_split=0.2,
+        batch_size=BATCH_SIZE,
+        transforms=transforms,
+    )
 
     single_batch = next(iter(train_loader))
-    print(f"\nShape of one batch of training data: {single_batch[0].shape}")
-    single_batch_grid = utils.make_grid(single_batch[0], nrow=8)
-    plt.figure(figsize = (20,700))
-    plt.imshow(single_batch_grid.permute(1, 2, 0))
-    plt.show()
+    print(f"\nShape of one batch: {single_batch[0].shape}")
+
+    grid = utils.make_grid(single_batch[0], nrow=8)
+    plt.figure(figsize=(12, 12))
+    plt.imshow(grid.permute(1, 2, 0))
+    plt.title("Sample Training Batch")
+    plt.show()
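
For reviewers trying out the new API, a minimal usage sketch. It assumes the transforms pipeline from utils and that this module is importable (its filename is not shown in the diff); the seeded-split comment is a suggestion, not part of this PR:

import torch

from config import BATCH_SIZE, TRAIN_DATA_PATH, TEST_DATA_PATH
from utils import transforms
# from dataset import create_data_loaders  # module name assumed; adjust to this file

# class_map=True additionally returns the deterministic name -> index mapping
train_loader, val_loader, test_loader, class_map = create_data_loaders(
    TRAIN_DATA_PATH,
    TEST_DATA_PATH,
    val_split=0.2,
    batch_size=BATCH_SIZE,
    transforms=transforms,
    class_map=True,
)
print(class_map)  # expected: {'axion': 0, 'cdm': 1, 'no_sub': 2}

# random_split is unseeded here; for fully reproducible splits, pass a
# generator inside create_data_loaders, e.g.:
#   random_split(dataset, [train_size, val_size],
#                generator=torch.Generator().manual_seed(42))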