-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplottypes.py
More file actions
275 lines (224 loc) · 10.9 KB
/
plottypes.py
File metadata and controls
275 lines (224 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
"""
Infrastructure for plot.py:
- Helper classes for groupwise plots: data subsets handling
- Parameterized plotting routines
"""
import math
import os
import traceback
import typing as tg
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
class Subset(dict):
    """
    A dict whose keys can also be read as attributes: ``s.color`` is ``s['color']``.

    Only keys that are not shadowed by real attributes are reachable this way:
    dict methods such as ``keys``, ``get`` or ``values`` win over same-named keys,
    because ``__getattr__`` is only consulted when normal lookup fails.
    """
    def __getattr__(self, attrname):
        """Subset dictionary keys become Python pseudo attributes.

        Raises:
            AttributeError: if there is no key ``attrname``. Raising AttributeError
                (rather than letting KeyError escape) keeps hasattr(),
                getattr(obj, name, default), copy and pickle working.
        """
        try:
            return self[attrname]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__} has no attribute or key {attrname!r}") from None
class Rows(Subset):
    """
    A Subset whose mandatory 'rows' entry is a callable selecting rows of a DataFrame.

    The callable's result is presumably a boolean row mask (add_xletgroup indexes
    the DataFrame with its ``.values``) — confirm with the subset definitions.
    """
    def rows_(self, df: pd.DataFrame) -> pd.Series:
        """Call the stored 'rows' callable on df and hand back whatever it returns."""
        selector = self['rows']
        return selector(df)

    def check_validity(self):
        """Raise TypeError unless a callable is stored under the 'rows' key."""
        selector = self.get('rows')
        if callable(selector):
            return
        raise TypeError(f"Rows({self}) is missing essential callable 'rows'")
class Values(Subset):
    """
    A Subset whose mandatory 'values' entry is a callable extracting data from a DataFrame.

    Note: attribute access ``v.values`` yields the built-in dict.values method,
    not this entry; use ``values_()`` or ``v['values']`` instead.
    """
    def values_(self, df: pd.DataFrame) -> tg.Any:
        """Call the stored 'values' callable on df and hand back whatever it returns."""
        extractor = self['values']
        return extractor(df)

    def check_validity(self):
        """Raise TypeError unless a callable is stored under the 'values' key."""
        extractor = self.get('values')
        if callable(extractor):
            return
        raise TypeError(f"Values({self}) is missing essential callable 'values'")
class Subsets(list):
    """
    Define and handle overlapping subsets of something, e.g. rows in pd.DataFrames.

    A list of either Rows or Values objects, e.g.:
      ss = Subsets([Rows(rows=lambda df: df.a >= 0, color='red', label="A"),
                    Rows(rows=lambda df: df.b <= 0, color='blue', label="B")])
    If you initialize it all at once, like above, the consistency will be checked:
    every subset must pass its own check_validity() and all subsets must have
    exactly the same set of keys.
    Then iterate over groups and retrieve their rows and other attributes, e.g.
      for subset in ss:
          somecall(subset['color'], subset.rows(df).somecolumn.mean())
    """
    def __init__(self, descriptor: "tg.Union[tg.Iterable[Rows], tg.Iterable[Values]]"):
        """
        Args:
            descriptor: the subsets; each must offer check_validity() and keys().

        Raises:
            ValueError: if descriptor is empty, or if a subset has extra or
                missing keys compared to the first subset.
            TypeError: (from check_validity) if a subset lacks its essential callable.
        """
        super().__init__(descriptor)
        if not self:
            raise ValueError("Subsets needs at least one subset")
        the_attrs = set(self[0].keys())
        # ----- check validity and attributes consistency of every subset:
        for idx, descr in enumerate(self):
            descr.check_validity()
            descr_attrs = set(descr.keys())
            # Compare both set differences: key sets can differ without either
            # being a strict superset of the other (e.g. {rows, color} vs {rows, label}).
            extra = descr_attrs - the_attrs
            if extra:
                raise ValueError(f"subset {idx} has extra attributes: {extra}")
            missing = the_attrs - descr_attrs
            if missing:
                raise ValueError(f"subset {idx} has attributes missing: {missing}")
class PlotContext:
    """
    Everything a generic plotting routine needs in one object: where to write
    (outputdir, basename), what to plot (df, subsets, inner_subsets), and the
    matplotlib Figure/Axes to draw on.
    """
    outputdir: str
    basename: str
    df: pd.DataFrame
    subsets: tg.Optional[Subsets] = None
    inner_subsets: tg.Optional[Subsets] = None
    fig: tg.Optional[mpl.figure.Figure] = None
    ax: tg.Optional[mpl.axes.Axes] = None

    def __init__(self, outputdir, basename, df, height, width,
                 subsets=None, inner_subsets=None):
        """Store the plot inputs and create a constrained-layout Figure of size width x height."""
        self.outputdir, self.basename, self.df = outputdir, basename, df
        self.subsets, self.inner_subsets = subsets, inner_subsets
        # Updates matplotlib's global rcParams, i.e. affects all later plots too.
        mpl.rcParams.update({'font.size': 8})
        # The Figure is constructed directly (not through pyplot); draw on self.ax.
        self.fig = mpl.figure.Figure(figsize=(width, height), layout='constrained')
        self.ax = self.fig.add_subplot()

    def savefig(self):
        """Write the figure to <outputdir>/<basename>.pdf."""
        target = os.path.join(self.outputdir, self.basename + '.pdf')
        self.fig.savefig(target)

    def again_for(self, basename: str):
        """Start a new plot on the same Figure: wipe it, add fresh Axes, switch the file name."""
        fig = self.fig
        fig.clear()
        self.ax = fig.add_subplot()
        self.basename = basename
# Callback that draws a single xlet, called as
# add_op(ctx, group_base_x, xlet_data, inner_subset) and returning nothing.
AddXletOp = tg.Callable[
    [PlotContext, float, tg.Any, Subset],
    None,
]
def plot_xletgroups(ctx: PlotContext, add_op: AddXletOp, plottype: str, basename: str,
                    ylabel: str, *, ymax=None):
    """
    Draw groups of xlets (one per inner_subset) for all subsets.

    Each subset in ctx.subsets becomes one group with an x tick labeled
    subset['label']. Within the group, add_op draws one xlet per inner subset,
    offset by inner_subset['x']; the xlets are centered on the group's tick.
    The plot is saved as "<plottype>_xletgroups_<basename>" via ctx.savefig().

    Args:
        ctx: plotting context; its Figure is reset via ctx.again_for().
        add_op: callback drawing a single xlet.
        plottype: prefix of the output file name.
        basename: suffix of the output file name.
        ylabel: y axis label.
        ymax: optional upper y limit (lower limit is always 0).
    """
    ctx.again_for(f"{plottype}_xletgroups_{basename}")
    ctx.ax.set_ylim(bottom=0, top=ymax)
    ctx.ax.set_ylabel(ylabel)
    ctx.ax.grid(axis='y', linewidth=0.1)
    inner_xs = [sub['x'] for sub in ctx.inner_subsets]
    inner_x_min, inner_x_max = min(inner_xs), max(inner_xs)
    inner_width = inner_x_max - inner_x_min
    # A single inner subset gives inner_width == 0, which would stack every
    # group at x=0; use a unit width in that case so groups stay apart.
    group_unit = inner_width if inner_width > 0 else 1
    # Shift xlets so the group is centered on its tick for any inner x range
    # (equals inner_width/2 when the smallest inner x is 0).
    inner_center = (inner_x_min + inner_x_max) / 2
    xticks = []
    for subset in ctx.subsets:
        x = (subset['x'] * group_unit * 1.25)
        xticks.append(x)
        add_xletgroup(ctx, x - inner_center, add_op, subset)
    ctx.ax.set_xticks(ticks=xticks, labels=[ss['label'] for ss in ctx.subsets])
    ctx.savefig()
def add_xletgroup(ctx: PlotContext, x: float, add_xlet_op: AddXletOp, subset: Subset):
    """
    Draw one group of xlets by calling add_xlet_op once per inner subset.

    For each pair (subset, inner_subset), whichever is a Rows object selects the
    rows of ctx.df, and the other (expected to be a Values object) extracts the
    xlet data from those rows. add_xlet_op receives the group base position x,
    that data, and the inner subset.
    NOTE(review): assumes one of each pair is Rows and the other Values; other
    combinations fail at the rows_/values_ lookup — confirm with callers.
    """
    outer_selects_rows = isinstance(subset, Rows)
    outer_gives_values = isinstance(subset, Values)
    for inner in ctx.inner_subsets:
        row_source = subset if outer_selects_rows else inner
        value_source = subset if outer_gives_values else inner
        mask: pd.Series = row_source.rows_(ctx.df)
        assert isinstance(mask, pd.Series), type(mask)
        # .values: select positionally, bypassing index alignment
        selected = ctx.df[mask.values]
        add_xlet_op(ctx, x, value_source.values_(selected), inner)
def add_boxplotlet(ctx: PlotContext, x: float,
                   xlet_data: tg.Any, inner_subset: Subset):
    """
    Draw one xlet as a boxplot of the non-missing entries of xlet_data.

    The box sits at x + inner_subset['x'], is 0.8 wide and is filled with
    inner_subset's 'color' (default "mediumblue"). Whiskers span the 10th to
    90th percentile, fliers are hidden, the median line is grey and the mean
    is a small orange dot. The tick label is left empty: labels are set per
    group by the caller.
    """
    fill_color = inner_subset.get('color', "mediumblue")
    position = x + inner_subset['x']
    present = xlet_data[xlet_data.notna()]
    box_style = dict(facecolor=fill_color)
    mean_style = dict(marker="o", markersize=3,
                      markerfacecolor="orange", markeredgecolor="orange")
    ctx.ax.boxplot(
        [present],
        notch=False, whis=(10, 90),
        positions=[position], labels=[""],
        widths=0.8, capwidths=0.2,
        showfliers=False, showmeans=True,
        patch_artist=True, boxprops=box_style,
        medianprops=dict(color='grey'),
        meanprops=mean_style)
def add_nonzerofractionbarplotlet(ctx: PlotContext, x: float, xlet_data: tg.Any, inner_subset: Subset):
    """
    Draw one bar whose height is the percentage of entries in xlet_data that are nonzero.

    The bar is centered at x + inner_subset['x'], 0.8 wide, with an empty label
    and colored by inner_subset's 'color' (default "mediumblue").
    NOTE(review): if xlet_data is a pandas Series, NaN compares unequal to 0
    and is therefore counted as nonzero.
    """
    bar_color = inner_subset.get('color', "mediumblue")
    bar_x = x + inner_subset['x']
    nonzero_share = (xlet_data != 0).sum() / len(xlet_data)
    ctx.ax.bar(x=bar_x, height=100 * nonzero_share, width=0.8, label="", color=bar_color)
def add_nonzerofractionbarplotlet_with_errorbars(ctx: PlotContext, x: float, xlet_data: tg.Any, inner_subset: Subset):
    """
    Draw one bar showing the percentage of nonzero entries, with asymmetric error bars.

    xlet_data needs an index with at least two levels. Grouping by index level 0:
    - lower end: percentage of groups whose nonzero count equals the number of
      distinct level-1 values across all of xlet_data ("all nonzero");
    - upper end: percentage of groups with at least one nonzero entry.
    Bar height and both ends are rounded to 3 decimals.
    NOTE(review): the ends only bracket the bar height when every level-0 group
    holds each level-1 value exactly once (balanced design); otherwise the error
    lengths can turn negative — confirm with callers.
    """
    bar_color = inner_subset.get('color', "mediumblue")
    bar_x = x + inner_subset['x']
    bar_height = round(100 * ((xlet_data != 0).sum() / len(xlet_data)), 3)
    # ----- error bar ends from per-group nonzero counts:
    outer_keys = xlet_data.index.get_level_values(0)
    inner_keys = xlet_data.index.get_level_values(1)
    nonzero_per_group = xlet_data.groupby(outer_keys).apply(lambda grp: (grp != 0).sum())
    n_groups = outer_keys.nunique()
    groups_all_nonzero = (nonzero_per_group == inner_keys.nunique()).sum()
    groups_any_nonzero = (nonzero_per_group >= 1).sum()
    err_top = round(100 * groups_any_nonzero / n_groups, 3)
    err_bottom = round(100 * groups_all_nonzero / n_groups, 3)
    error_style = {"elinewidth": 0.7, "capsize": 1, "capthick": 0.5}
    ctx.ax.bar(x=bar_x, height=bar_height,
               yerr=[[bar_height - err_bottom], [err_top - bar_height]],
               width=0.8, label="", color=bar_color, error_kw=error_style)
def add_zerofractionbarplotlet(ctx: PlotContext, x: float, xlet_data: tg.Any, inner_subset: Subset):
    """
    Draw one bar whose height is the percentage of entries in xlet_data equal to zero.

    The bar is centered at x + inner_subset['x'], 0.8 wide, with an empty label
    and colored by inner_subset's 'color' (default "mediumblue").
    """
    bar_color = inner_subset.get('color', "mediumblue")
    bar_x = x + inner_subset['x']
    zero_share = (xlet_data == 0).sum() / len(xlet_data)
    ctx.ax.bar(x=bar_x, height=100 * zero_share, width=0.8, label="", color=bar_color)
def add_zerofractionbarplotlet_with_errorbars(ctx: PlotContext, x: float, xlet_data: tg.Any, inner_subset: Subset):
    """
    Draw one bar showing the percentage of zero entries, with asymmetric error bars.

    xlet_data needs an index with at least two levels. Grouping by index level 0:
    - lower end: percentage of groups whose zero count equals the number of
      distinct level-1 values across all of xlet_data ("all zero");
    - upper end: percentage of groups with at least one zero entry.
    Bar height and both ends are rounded to 3 decimals.
    NOTE(review): the ends only bracket the bar height when every level-0 group
    holds each level-1 value exactly once (balanced design); otherwise the error
    lengths can turn negative — confirm with callers.
    """
    bar_color = inner_subset.get('color', "mediumblue")
    bar_x = x + inner_subset['x']
    bar_height = round(100 * ((xlet_data == 0).sum() / len(xlet_data)), 3)
    # ----- error bar ends from per-group zero counts:
    outer_keys = xlet_data.index.get_level_values(0)
    inner_keys = xlet_data.index.get_level_values(1)
    zero_per_group = xlet_data.groupby(outer_keys).apply(lambda grp: (grp == 0).sum())
    n_groups = outer_keys.nunique()
    groups_all_zero = (zero_per_group == inner_keys.nunique()).sum()
    groups_any_zero = (zero_per_group >= 1).sum()
    err_top = round(100 * groups_any_zero / n_groups, 3)
    err_bottom = round(100 * groups_all_zero / n_groups, 3)
    error_style = {"elinewidth": 0.7, "capsize": 1, "capthick": 0.5}
    ctx.ax.bar(x=bar_x, yerr=[[bar_height - err_bottom], [err_top - bar_height]],
               height=bar_height, width=0.8, label="", color=bar_color,
               error_kw=error_style)
def plot_boxplots(ctx: PlotContext, which: str, *, ymax=None):
    """
    Make a plot with one boxplot for each subset.

    For every subset in ctx.subsets, the column `which` of the rows it selects
    is drawn via add_boxplot(). The result is saved as "boxplots_<which>".
    ymax optionally caps the y axis (which always starts at 0).
    """
    ctx.again_for(f"boxplots_{which}")
    axes = ctx.ax
    axes.set_ylim(bottom=0, top=ymax)
    axes.set_ylabel(which)
    axes.grid(axis='y', linewidth=0.1)
    for subset in ctx.subsets:
        column_values = ctx.df.loc[subset.rows_(ctx.df), which]
        add_boxplot(ctx, column_values, subset)
    ctx.savefig()
def add_boxplot(ctx, vals, descr):
    """
    Insert a single boxplot into a larger plot.

    Everything is drawn on ctx.ax at x position descr['x']:
    - a yellow box labeled descr['label'], whiskers at the 10th/90th
      percentile, no fliers, and a mediumblue mean marker;
    - the text "n=<count>" at y=0;
    - a red vertical line spanning mean +/- standard error of the mean,
      placed 0.1 to the right of the box.

    Args:
        ctx: PlotContext providing the target Axes (ctx.ax).
        vals: values for this box (a pd.Series as passed by plot_boxplots);
            NaN entries are ignored.
        descr: subset descriptor providing 'x' and 'label'.
    """
    # Ignore NaN consistently: matplotlib's boxplot does not drop it, and
    # mean()/std() skip it while len() would still count it (skewing n and SE).
    vals = vals.dropna()
    xpos = descr['x']
    n = len(vals)
    ctx.ax.boxplot(
        [vals],
        notch=False, whis=(10, 90),
        positions=[xpos], labels=[descr['label']],
        widths=0.7, capwidths=0.2,
        showfliers=False, showmeans=True,
        patch_artist=True, boxprops=dict(facecolor="yellow"),
        medianprops=dict(color='black'),
        meanprops=dict(marker="o", markerfacecolor="mediumblue", markeredgecolor="mediumblue"))
    # ----- add "n=123" at the bottom:
    ctx.ax.text(xpos, 0, "n=%d" % n,
                verticalalignment='bottom', horizontalalignment='center', color="mediumblue",
                fontsize=7)
    # ----- add error bar for the mean:
    mymean = vals.mean()
    # standard error of the mean; NaN when there are no values (avoids dividing by 0)
    se = vals.std() / math.sqrt(n) if n > 0 else math.nan
    # Draw on ctx.ax, not via plt: PlotContext builds its Figure directly, so
    # pyplot's "current Axes" belongs to another figure and the line never
    # appeared in the saved plot.
    ctx.ax.vlines(xpos + 0.1, mymean - se, mymean + se,
                  colors='red', linestyles='solid', linewidth=0.7)
def plot_lowess(x: pd.Series, xlabel: str, y: pd.Series, ylabel: str,
                outputdir: str, name_suffix: str, *,
                frac=0.67, show=True, xmax=None, ymax=None):
    """
    Plot a scatter plot plus a local linear regression (LOWESS) line.

    Args:
        x, y: data series; the LOWESS fit uses y as a function of x.
        xlabel, ylabel: axis labels.
        outputdir: directory of the output PDF.
        name_suffix: suffix of the output file name; the file name itself is
            produced by plotfilename() from this function's name.
        frac: fraction of the data used for each local regression.
        show: whether to draw the individual (x, y) points.
        xmax, ymax: optional upper axis limits (both axes start at 0).
    """
    import statsmodels.nonparametric.smoothers_lowess as sml
    # ----- compute lowess line:
    # delta: statsmodels interpolates instead of refitting within this x distance
    delta = 0.01 * (x.max() - x.min())
    line_xy = sml.lowess(y.to_numpy(), x.to_numpy(), frac=frac, delta=delta,
                         is_sorted=False)
    fig = plt.figure()
    try:
        # ----- plot labeling:
        plt.xlim(left=0, right=xmax)
        plt.xlabel(xlabel)
        plt.ylim(bottom=0, top=ymax)
        plt.ylabel(ylabel)
        plt.grid(axis='both', linewidth=0.1)
        # ----- plot points:
        if show:
            plt.scatter(x, y, s=2, c="darkred")
        # ----- plot lowess line (column 0: x, column 1: fitted y):
        plt.plot(line_xy[:, 0], line_xy[:, 1])
        # ----- save:
        # plotfilename() inspects the call stack, so it must be called directly here.
        plt.savefig(plotfilename(outputdir, name_suffix=name_suffix))
    finally:
        # Release the pyplot figure; otherwise every call leaves one open.
        plt.close(fig)
def funcname(levels_up: int) -> str:
    """
    Return the name of the function `levels_up` stack frames above funcname itself.

    funcname(0) is "funcname", funcname(1) is its caller, and so on. If the stack
    is shallower than requested, the outermost available frame's name is returned.
    """
    # extract_stack(limit=k) yields the k innermost frames, oldest first.
    frames = traceback.extract_stack(limit=levels_up + 1)
    target = frames[0]
    return target.name
def plotfilename(outputdir: str, name_suffix="", nesting=0) -> str:
    """
    Return the PDF path for a plot named after the calling function.

    The name comes from the function nesting+1 frames above plotfilename
    (nesting=0: plotfilename's direct caller), with a leading "plot_" removed,
    e.g. plot_lowess -> "lowess". The result is
    os.path.join(outputdir, "<name>[_<name_suffix>].pdf").
    """
    # Must call funcname directly from this frame: its count includes funcname's
    # own frame and this one, hence the 2.
    caller = funcname(2 + nesting)
    # Strip only a leading "plot_"; str.replace would also cut it out mid-name
    # (e.g. "plot_boxplot_x" -> "boxx").
    if caller.startswith('plot_'):
        caller = caller[len('plot_'):]
    if name_suffix:
        name_suffix = "_" + name_suffix
    # os.path.join (as PlotContext.savefig uses) also avoids writing "/<name>.pdf"
    # into the filesystem root when outputdir is "".
    return os.path.join(outputdir, caller + name_suffix + ".pdf")