1010
1111import matplotlib
1212import matplotlib .pylab as plt
13+ import matplotlib .ticker as ticker
1314import numpy as np
1415
1516
@@ -203,7 +204,7 @@ def plot_data_1d(
203204 xlabel ,
204205 ylabel ,
205206 out_path = None ,
206- create_figure = True ,
207+ ax = None ,
207208 xlog = False ,
208209 ylog = False ,
209210 log = False ,
@@ -217,6 +218,9 @@ def plot_data_1d(
217218 ylim = None ,
218219 shift_x = False ,
219220 close_fig = True ,
221+ second_x_axis = None ,
222+ second_x_label = None ,
223+ second_x_every = 1 ,
220224):
221225 """Plot Data 1D.
222226
@@ -230,8 +234,8 @@ def plot_data_1d(
230234 title and labels
231235 out_path : string, optional
232236 output file path, default is ``None``
233- create_figure : bool , optional
234- create figure if ``True`` (default)
237+ ax : matplotlib.axes , optional
238+ use this axis object if given; it not (default) create a new figure
235239 xlog, ylog : bool, optional, default is ``False``
236240 logscale on x, y if True
237241 labels : list, optional, default is ``None``
@@ -254,6 +258,12 @@ def plot_data_1d(
254258 shift datasets by small amount along x if ``True``; default is ``False``
255259 close_fig : bool, optional
256260 closes figure if True (default)
261+ second_x_axis : array of float, optional, default is ``None``
262+ values for second x-axis on top, if not None
263+ second_x_label : string, optional, default is ``None``
264+ label for second x-axis on top
265+ second_x_every: int, optional
266+ plot only one in every `every` point on second x-axis; default is 1
257267
258268 """
259269 if labels is None :
@@ -273,8 +283,10 @@ def plot_data_1d(
273283 if markers is None :
274284 markers = ["o" ] * len (x )
275285
276- if create_figure :
277- figure (figsize = (10 , 10 ))
286+ if ax is None :
287+ fig , ax = plt .subplots (figsize = (10 , 10 ))
288+ else :
289+ fig = ax .figure
278290
279291 for idx in range (len (x )):
280292 this_x = x [idx ]
@@ -284,15 +296,15 @@ def plot_data_1d(
284296 else :
285297 raise ValueError ("shift_x without log not implemented yet" )
286298 if np .isnan (yerr [idx ]).all ():
287- eb = plt .plot (
299+ eb = ax .plot (
288300 this_x ,
289301 y [idx ],
290302 label = labels [idx ],
291303 color = colors [idx ],
292304 linestyle = linestyles [idx ],
293305 )
294306 else :
295- eb = plt .errorbar (
307+ eb = ax .errorbar (
296308 this_x ,
297309 y [idx ],
298310 yerr = yerr [idx ],
@@ -305,40 +317,83 @@ def plot_data_1d(
305317 )
306318 eb [- 1 ][0 ].set_linestyle (eb_linestyles [idx ])
307319
308- plt .axhline (color = "k" , linestyle = "dashed" , linewidth = linewidths [0 ] / 2 )
320+ ax .axhline (color = "k" , linestyle = "dashed" , linewidth = linewidths [0 ] / 2 )
309321
310322 if xlog :
311- plt .xscale ("log" )
312- plt .xticks (
313- [0.1 , 0.2 , 0.5 , 1 , 2 , 5 , 10 , 20 , 50 , 100 , 200 , 500 ],
314- labels = [
315- "0.1" ,
316- "0.2" ,
317- "0.5" ,
318- "1" ,
319- "2" ,
320- "5" ,
321- "10" ,
322- "20" ,
323- "50" ,
324- "100" ,
325- "200" ,
326- "500" ,
327- ],
328- )
323+ ax .set_xscale ("log" )
324+ #plt.gca().xaxis.set_major_locator(ticker.LogLocator(base=10, subs=(1,2,5), numticks=15)
325+ ax .xaxis .set_major_locator (ticker .LogLocator (base = 10 , subs = (1 ,2 ,5 ), numticks = 15 ))
326+ ax .xaxis .set_major_formatter (ticker .LogFormatter (labelOnlyBase = False ))
327+
329328 if ylog :
330- plt . yscale ("log" )
329+ ax . set_yscale ("log" )
331330
332331 if xlim :
333- plt . xlim (xlim )
332+ ax . set_xlim (xlim )
334333 if ylim :
335- plt . ylim (ylim )
334+ ax . set_ylim (ylim )
336335
337- plt . title (title )
338- plt . xlabel (xlabel )
339- plt . ylabel (ylabel )
336+ ax . set_title (title )
337+ ax . set_xlabel (xlabel )
338+ ax . set_ylabel (ylabel )
340339 if do_legend :
341- plt .legend ()
340+ ax .legend ()
341+
342+ # Add second x-axis on top if requested
343+ if second_x_axis is not None :
344+ ax2 = ax .twiny ()
345+
346+ # Set second x-axis with provided values
347+ ax2 .set_xlim (ax .get_xlim ())
348+ if xlog :
349+ ax2 .set_xscale ('log' )
350+
351+ if second_x_label is not None :
352+ ax2 .set_xlabel (second_x_label )
353+
354+ # Create tick positions that correspond to the main x-axis values
355+ # Map from main x values to second x-axis values
356+ main_x_values = x [0 ] if len (x ) > 0 else []
357+ if len (main_x_values ) > 0 and len (second_x_axis ) == len (main_x_values ):
358+ # Find values within the plot range
359+ xlim_current = ax .get_xlim ()
360+ tick_positions = []
361+ tick_labels = []
362+
363+ for i , main_x_val in enumerate (main_x_values ):
364+ if (
365+ xlim_current [0 ] <= main_x_val <= xlim_current [1 ]
366+ ):
367+ tick_positions .append (main_x_val )
368+ if i % second_x_every == 0 :
369+ my_label = f'{ second_x_axis [i ]:.2g} '
370+ else :
371+ my_label = ""
372+ tick_labels .append (my_label )
373+
374+ if tick_positions :
375+ ax2 .set_xticks (tick_positions )
376+ ax2 .set_xticklabels (tick_labels )
377+ ax2 .tick_params (axis = 'x' , labelrotation = 45 )
342378
343379 if out_path :
344380 savefig (out_path , close_fig = close_fig )
381+
382+
383+ def log_ticks (x ):
384+
385+ x = np .asarray (x )
386+ xmin , xmax = np .nanmin (x ), np .nanmax (x )
387+
388+ # figure out the exponent range
389+ exp_min = int (np .floor (np .log10 (xmin )))
390+ exp_max = int (np .ceil (np .log10 (xmax )))
391+
392+ ticks = []
393+ for exp in range (exp_min , exp_max + 1 ):
394+ for base in [1 , 2 , 5 ]:
395+ val = base * 10 ** exp
396+ if xmin <= val <= xmax :
397+ ticks .append (val )
398+
399+ return np .array (ticks )
0 commit comments