Skip to content

Commit 916eed9

Browse files
committed
optimize violine plots
1 parent 656aafd commit 916eed9

3 files changed

Lines changed: 125 additions & 21 deletions

File tree

tlparser/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def visualize_data(file, latest, selfonly, plot):
226226
plot_methods = {
227227
"hist": viz.plot_histogram,
228228
"viol": viz.plot_violin_engcompl,
229-
"viol_req": viz.plot_violin_reqtext,
229+
"viol_req": viz.plot_violin_reqtext(palette_index=2),
230230
"pair": viz.plot_pairplot,
231231
"chord": viz.plot_chord,
232232
"sankey": viz.plot_sankey,

tlparser/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,45 @@ def get_latest_excel(folder):
249249
file = os.path.join(folder, latest_file)
250250
return file
251251

252+
@staticmethod
253+
def rotate_palette_map(
254+
palette,
255+
types,
256+
*,
257+
index: int = 0,
258+
default_color: str = "#808080",
259+
):
260+
"""Return a dict mapping of types to colors, rotated by index.
261+
262+
- palette: mapping type -> color (preferred). If not a dict, attempts to
263+
treat it as a sequence of colors aligned to the provided types.
264+
- types: ordered sequence of type labels to include.
265+
- index: rotation amount (wraps; negative allowed). 0 keeps original.
266+
- default_color: color used when a type is missing from palette.
267+
"""
268+
types = list(types)
269+
if not types:
270+
return {}
271+
272+
# Build color list aligned to the given types
273+
if isinstance(palette, dict):
274+
colors = [palette.get(t, default_color) for t in types]
275+
else:
276+
# Fallback: assume palette is a sequence of colors
277+
palette_seq = list(palette or [])
278+
if not palette_seq:
279+
palette_seq = [default_color] * len(types)
280+
# Repeat or trim to match types length
281+
times = (len(types) + len(palette_seq) - 1) // len(palette_seq)
282+
colors = (palette_seq * times)[: len(types)]
283+
284+
if len(colors) > 0 and isinstance(index, int):
285+
k = index % len(colors)
286+
if k:
287+
colors = colors[k:] + colors[:k]
288+
289+
return dict(zip(types, colors))
290+
252291
@staticmethod
253292
def lighten_color(hex_color, opacity=0.6):
254293
hex_color = hex_color.lstrip("#")

tlparser/viz.py

Lines changed: 85 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -124,30 +124,76 @@ def plot_histogram(self):
124124
return out
125125

126126
def _plot_violin(
127-
self, df_long, stats_values, metrics, title_map, out_prefix, include_strip=False
127+
self,
128+
df_long,
129+
stats_values,
130+
metrics,
131+
title_map,
132+
out_prefix,
133+
include_strip=False,
134+
palette_index=0,
128135
):
129136
type_palette = self.config.color_palette
137+
# Build a palette for the present types, optionally rotated by palette_index.
138+
present_types = list(df_long["type"].unique())
139+
ordered_present_types = [
140+
t for t in self.__get_reduced_logic_order() if t in present_types
141+
] or present_types
142+
143+
plot_palette = Utils.rotate_palette_map(
144+
type_palette, ordered_present_types, index=palette_index
145+
)
146+
147+
# Special case: if only one type is present, allow selecting a different
148+
# color by index across the full configured palette, so different
149+
# violin plots can visually alternate even for the same type.
150+
if (
151+
len(ordered_present_types) == 1
152+
and isinstance(type_palette, dict)
153+
and isinstance(palette_index, int)
154+
):
155+
only_type = ordered_present_types[0]
156+
full_order = self.config.logic_order or list(type_palette.keys())
157+
if not full_order:
158+
full_colors = list(type_palette.values())
159+
else:
160+
full_colors = [type_palette.get(t, "#808080") for t in full_order]
161+
if len(full_colors) > 0:
162+
idx = palette_index % len(full_colors)
163+
plot_palette = {only_type: full_colors[idx]}
164+
130165
number_of_types = df_long["type"].unique().size
131166

132-
fig, axes = plt.subplots(
133-
nrows=2 if len(metrics) > 3 else 1,
134-
ncols=3,
135-
figsize=(11, 8) if len(metrics) > 3 else (11, 4),
136-
sharex=False,
137-
sharey=False,
138-
)
139-
axes = axes.flatten()
140-
plt.subplots_adjust(hspace=0.05, wspace=0.05)
167+
if len(metrics) == 3:
168+
# Create a layout with the third plot centered by spanning both columns
169+
fig = plt.figure(figsize=(8, 8))
170+
gs = fig.add_gridspec(nrows=2, ncols=2)
171+
axes_list = [
172+
fig.add_subplot(gs[0, 0]),
173+
fig.add_subplot(gs[0, 1]),
174+
fig.add_subplot(gs[1, :]),
175+
]
176+
fig.subplots_adjust(hspace=0.05, wspace=0.05)
177+
else:
178+
fig, axes = plt.subplots(
179+
nrows=3 if len(metrics) > 3 else 2,
180+
ncols=2,
181+
figsize=(8, 11) if len(metrics) > 3 else (8, 7),
182+
sharex=False,
183+
sharey=False,
184+
)
185+
axes_list = axes.flatten().tolist()
186+
plt.subplots_adjust(hspace=0.05, wspace=0.05)
141187
i = 1
142188

143-
for ax, agg in zip(axes, metrics):
189+
for ax, agg in zip(axes_list, metrics):
144190
y_max = df_long[df_long["aggregation"] == agg]["value"].max() * 1.8
145191
violin = sns.violinplot(
146192
x="type",
147193
y="value",
148194
data=df_long[df_long["aggregation"] == agg],
149195
hue="type",
150-
palette=type_palette,
196+
palette=plot_palette,
151197
bw_method=0.5,
152198
edgecolor="black",
153199
linewidth=1,
@@ -160,7 +206,7 @@ def _plot_violin(
160206
x="type",
161207
y="value",
162208
hue="type",
163-
palette=type_palette,
209+
palette=plot_palette,
164210
data=df_long[df_long["aggregation"] == agg],
165211
width=0.12,
166212
showcaps=True,
@@ -181,7 +227,7 @@ def _plot_violin(
181227
hue="type",
182228
data=df_long[df_long["aggregation"] == agg],
183229
alpha=0.3,
184-
palette=type_palette,
230+
palette=plot_palette,
185231
size=3,
186232
marker="d",
187233
edgecolor="black",
@@ -241,16 +287,24 @@ def _plot_violin(
241287
minor_tick_interval = major_interval / 5
242288
ax.yaxis.set_minor_locator(ticker.MultipleLocator(minor_tick_interval))
243289

244-
for j in range(len(metrics), len(axes)):
245-
axes[j].set_visible(False)
290+
for j in range(len(metrics), len(axes_list)):
291+
axes_list[j].set_visible(False)
246292

247293
fig.tight_layout()
294+
# If exactly 3 metrics, keep the third axis centered without stretching.
295+
if len(metrics) == 3:
296+
ref = axes_list[0].get_position()
297+
bottom = axes_list[2].get_position()
298+
w = ref.width
299+
h = ref.height
300+
x = 0.5 - w / 2
301+
axes_list[2].set_position([x, bottom.y0, w, h])
248302
out = self.__get_file_name(out_prefix)
249303
plt.savefig(out)
250304
plt.close()
251305
return out
252306

253-
def plot_violin_engcompl(self, include_strip=False):
307+
def plot_violin_engcompl(self, include_strip=False, palette_index=0):
254308
df_filtered = self.data[self.data["translation"] == "self"]
255309
metrics = df_filtered.filter(like=".agg.").columns.tolist()
256310
metrics = metrics + ["stats.asth", "stats.entropy.lops_tops"]
@@ -274,11 +328,13 @@ def plot_violin_engcompl(self, include_strip=False):
274328
self.title_map,
275329
"viol_engcompl",
276330
include_strip,
331+
palette_index,
277332
)
278333

279-
def plot_violin_reqtext(self, include_strip=False):
334+
def plot_violin_reqtext(self, include_strip=False, palette_index=0):
280335
df_filtered = self.data[self.data["translation"] == "self"]
281-
metrics = df_filtered.filter(like=".req_").columns.tolist()
336+
# metrics = df_filtered.filter(like=".req_").columns.tolist()
337+
metrics = ["stats.req_word_count", "stats.req_sentence_count"]
282338
df_long = pd.melt(
283339
df_filtered,
284340
id_vars=["id", "type"],
@@ -293,14 +349,23 @@ def plot_violin_reqtext(self, include_strip=False):
293349
.reset_index()
294350
)
295351
return self._plot_violin(
296-
df_long, stats_values, metrics, self.title_map, "viol_req", include_strip
352+
df_long,
353+
stats_values,
354+
metrics,
355+
self.title_map,
356+
"viol_req",
357+
include_strip,
358+
palette_index,
297359
)
298360

299361
def plot_pairplot(self):
300362
type_palette = self.config.color_palette
301363
df = self.data[self.data["translation"] == "self"]
302364
metrics = df.filter(like=".agg.").columns.tolist()
303365
df_pairplot = df[metrics + ["type", "stats.asth", "stats.entropy.lops_tops"]]
366+
req_metrics = df.filter(like=".req_").columns.tolist()
367+
if req_metrics:
368+
df_pairplot = pd.concat([df_pairplot, df[req_metrics]], axis=1)
304369

305370
unique_types = df_pairplot["type"].nunique()
306371
markers = ["o", "s", "D", "^", "v", "P"][:unique_types]

0 commit comments

Comments
 (0)