@@ -124,30 +124,76 @@ def plot_histogram(self):
124124 return out
125125
126126 def _plot_violin (
127- self , df_long , stats_values , metrics , title_map , out_prefix , include_strip = False
127+ self ,
128+ df_long ,
129+ stats_values ,
130+ metrics ,
131+ title_map ,
132+ out_prefix ,
133+ include_strip = False ,
134+ palette_index = 0 ,
128135 ):
129136 type_palette = self .config .color_palette
137+ # Build a palette for the present types, optionally rotated by palette_index.
138+ present_types = list (df_long ["type" ].unique ())
139+ ordered_present_types = [
140+ t for t in self .__get_reduced_logic_order () if t in present_types
141+ ] or present_types
142+
143+ plot_palette = Utils .rotate_palette_map (
144+ type_palette , ordered_present_types , index = palette_index
145+ )
146+
147+ # Special case: if only one type is present, allow selecting a different
148+ # color by index across the full configured palette, so different
149+ # violin plots can visually alternate even for the same type.
150+ if (
151+ len (ordered_present_types ) == 1
152+ and isinstance (type_palette , dict )
153+ and isinstance (palette_index , int )
154+ ):
155+ only_type = ordered_present_types [0 ]
156+ full_order = self .config .logic_order or list (type_palette .keys ())
157+ if not full_order :
158+ full_colors = list (type_palette .values ())
159+ else :
160+ full_colors = [type_palette .get (t , "#808080" ) for t in full_order ]
161+ if len (full_colors ) > 0 :
162+ idx = palette_index % len (full_colors )
163+ plot_palette = {only_type : full_colors [idx ]}
164+
130165 number_of_types = df_long ["type" ].unique ().size
131166
132- fig , axes = plt .subplots (
133- nrows = 2 if len (metrics ) > 3 else 1 ,
134- ncols = 3 ,
135- figsize = (11 , 8 ) if len (metrics ) > 3 else (11 , 4 ),
136- sharex = False ,
137- sharey = False ,
138- )
139- axes = axes .flatten ()
140- plt .subplots_adjust (hspace = 0.05 , wspace = 0.05 )
167+ if len (metrics ) == 3 :
168+ # Create a layout with the third plot centered by spanning both columns
169+ fig = plt .figure (figsize = (8 , 8 ))
170+ gs = fig .add_gridspec (nrows = 2 , ncols = 2 )
171+ axes_list = [
172+ fig .add_subplot (gs [0 , 0 ]),
173+ fig .add_subplot (gs [0 , 1 ]),
174+ fig .add_subplot (gs [1 , :]),
175+ ]
176+ fig .subplots_adjust (hspace = 0.05 , wspace = 0.05 )
177+ else :
178+ fig , axes = plt .subplots (
179+ nrows = 3 if len (metrics ) > 3 else 2 ,
180+ ncols = 2 ,
181+ figsize = (8 , 11 ) if len (metrics ) > 3 else (8 , 7 ),
182+ sharex = False ,
183+ sharey = False ,
184+ )
185+ axes_list = axes .flatten ().tolist ()
186+ plt .subplots_adjust (hspace = 0.05 , wspace = 0.05 )
141187 i = 1
142188
143- for ax , agg in zip (axes , metrics ):
189+ for ax , agg in zip (axes_list , metrics ):
144190 y_max = df_long [df_long ["aggregation" ] == agg ]["value" ].max () * 1.8
145191 violin = sns .violinplot (
146192 x = "type" ,
147193 y = "value" ,
148194 data = df_long [df_long ["aggregation" ] == agg ],
149195 hue = "type" ,
150- palette = type_palette ,
196+ palette = plot_palette ,
151197 bw_method = 0.5 ,
152198 edgecolor = "black" ,
153199 linewidth = 1 ,
@@ -160,7 +206,7 @@ def _plot_violin(
160206 x = "type" ,
161207 y = "value" ,
162208 hue = "type" ,
163- palette = type_palette ,
209+ palette = plot_palette ,
164210 data = df_long [df_long ["aggregation" ] == agg ],
165211 width = 0.12 ,
166212 showcaps = True ,
@@ -181,7 +227,7 @@ def _plot_violin(
181227 hue = "type" ,
182228 data = df_long [df_long ["aggregation" ] == agg ],
183229 alpha = 0.3 ,
184- palette = type_palette ,
230+ palette = plot_palette ,
185231 size = 3 ,
186232 marker = "d" ,
187233 edgecolor = "black" ,
@@ -241,16 +287,24 @@ def _plot_violin(
241287 minor_tick_interval = major_interval / 5
242288 ax .yaxis .set_minor_locator (ticker .MultipleLocator (minor_tick_interval ))
243289
244- for j in range (len (metrics ), len (axes )):
245- axes [j ].set_visible (False )
290+ for j in range (len (metrics ), len (axes_list )):
291+ axes_list [j ].set_visible (False )
246292
247293 fig .tight_layout ()
294+ # If exactly 3 metrics, keep the third axis centered without stretching.
295+ if len (metrics ) == 3 :
296+ ref = axes_list [0 ].get_position ()
297+ bottom = axes_list [2 ].get_position ()
298+ w = ref .width
299+ h = ref .height
300+ x = 0.5 - w / 2
301+ axes_list [2 ].set_position ([x , bottom .y0 , w , h ])
248302 out = self .__get_file_name (out_prefix )
249303 plt .savefig (out )
250304 plt .close ()
251305 return out
252306
253- def plot_violin_engcompl (self , include_strip = False ):
307+ def plot_violin_engcompl (self , include_strip = False , palette_index = 0 ):
254308 df_filtered = self .data [self .data ["translation" ] == "self" ]
255309 metrics = df_filtered .filter (like = ".agg." ).columns .tolist ()
256310 metrics = metrics + ["stats.asth" , "stats.entropy.lops_tops" ]
@@ -274,11 +328,13 @@ def plot_violin_engcompl(self, include_strip=False):
274328 self .title_map ,
275329 "viol_engcompl" ,
276330 include_strip ,
331+ palette_index ,
277332 )
278333
279- def plot_violin_reqtext (self , include_strip = False ):
334+ def plot_violin_reqtext (self , include_strip = False , palette_index = 0 ):
280335 df_filtered = self .data [self .data ["translation" ] == "self" ]
281- metrics = df_filtered .filter (like = ".req_" ).columns .tolist ()
336+ # metrics = df_filtered.filter(like=".req_").columns.tolist()
337+ metrics = ["stats.req_word_count" , "stats.req_sentence_count" ]
282338 df_long = pd .melt (
283339 df_filtered ,
284340 id_vars = ["id" , "type" ],
@@ -293,14 +349,23 @@ def plot_violin_reqtext(self, include_strip=False):
293349 .reset_index ()
294350 )
295351 return self ._plot_violin (
296- df_long , stats_values , metrics , self .title_map , "viol_req" , include_strip
352+ df_long ,
353+ stats_values ,
354+ metrics ,
355+ self .title_map ,
356+ "viol_req" ,
357+ include_strip ,
358+ palette_index ,
297359 )
298360
299361 def plot_pairplot (self ):
300362 type_palette = self .config .color_palette
301363 df = self .data [self .data ["translation" ] == "self" ]
302364 metrics = df .filter (like = ".agg." ).columns .tolist ()
303365 df_pairplot = df [metrics + ["type" , "stats.asth" , "stats.entropy.lops_tops" ]]
366+ req_metrics = df .filter (like = ".req_" ).columns .tolist ()
367+ if req_metrics :
368+ df_pairplot = pd .concat ([df_pairplot , df [req_metrics ]], axis = 1 )
304369
305370 unique_types = df_pairplot ["type" ].nunique ()
306371 markers = ["o" , "s" , "D" , "^" , "v" , "P" ][:unique_types ]
0 commit comments