forked from ramp-kits/stroke_lesions
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
69 lines (43 loc) · 1.68 KB
/
test.py
File metadata and controls
69 lines (43 loc) · 1.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
import scipy.ndimage as nd
from sklearn.base import BaseEstimator
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.dummy import DummyClassifier
from keras.layers import Input, MaxPooling3D, UpSampling3D, Conv3D, Reshape, Conv3DTranspose
import problem
import numpy as np
import matplotlib.pylab as plt
import submissions.starting_kit.keras_segmentation_classifier as classifier
# Scratch driver: fetch the training split from the local `problem` module
# and train the submission found under `module_path` on it.

# Passed to train_submission as the submission location; '.' is the
# current working directory (presumably the kit root -- confirm).
module_path = '.'

# Training split as returned by problem.get_train_data(); printed so it can
# be inspected by eye. It is later passed as `patient_ids`.
train_ids = problem.get_train_data()
print(train_ids)

# Instantiate the problem's SimplifiedSegmentationClassifier and train the
# submission located under `module_path` on the training split.
simp = problem.SimplifiedSegmentationClassifier()
clf = simp.train_submission(
    module_path=module_path,
    patient_ids=train_ids,
)

# Disabled follow-up steps, kept for reference (uncomment to use):
#   spl = problem.ImageLoader([1, 2, 3, 4])
#   n_classes = [0, 1]
#   img_loader = problem.ImageLoader(patient_ids=train_ids, n_classes=n_classes)
#   clf.fit(img_loader)
#   test_ids = problem.get_test_data()
#   score = simp.test_submission(module_path=module_path, trained_model=clf,
#                                patient_idxs=test_ids)
'''
train = problem.get_train_data()
X, y = train
pred = classifier.Classifier().predict(X)
pred_prob = classifier.Classifier().predict_proba(X)
fitit = classifier.Classifier().fit(X, y)
classifier.Classifier().predict(X)
features = classifier.Classifier()._get_features_strided(X)
y_new = classifier.Classifier()._unpack_y(y)
split_ids = problem._read_ids('.')
path = '.'
subject_id = 31970
X = np.stack([problem._read_brain_image(path, subject_id) for subject_id in split_ids])
Y = np.stack([problem._read_stroke_segmentation(path, subject_id) for subject_id in split_ids])
sys.getsizeof(test)
problem._get_pati
plt.subplot(1,2,1)
plt.title('data')
plt.imshow(train[0,:,:,100])
plt.subplot(1,2,2)
plt.title('mask')