Skip to content

Commit 9c1f280

Browse files
author
martinkilbinger
committed
plot_data_1d: more flexible ticks; optional second x-axis
1 parent b44994d commit 9c1f280

2 files changed

Lines changed: 89 additions & 34 deletions

File tree

cs_util/plots.py

Lines changed: 88 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import matplotlib
1212
import matplotlib.pylab as plt
13+
import matplotlib.ticker as ticker
1314
import numpy as np
1415

1516

@@ -203,7 +204,7 @@ def plot_data_1d(
203204
xlabel,
204205
ylabel,
205206
out_path=None,
206-
create_figure=True,
207+
ax=None,
207208
xlog=False,
208209
ylog=False,
209210
log=False,
@@ -217,6 +218,9 @@ def plot_data_1d(
217218
ylim=None,
218219
shift_x=False,
219220
close_fig=True,
221+
second_x_axis=None,
222+
second_x_label=None,
223+
second_x_every=1,
220224
):
221225
"""Plot Data 1D.
222226
@@ -230,8 +234,8 @@ def plot_data_1d(
230234
title and labels
231235
out_path : string, optional
232236
output file path, default is ``None``
233-
create_figure : bool, optional
234-
create figure if ``True`` (default)
237+
ax : matplotlib.axes, optional
238+
use this axis object if given; it not (default) create a new figure
235239
xlog, ylog : bool, optional, default is ``False``
236240
logscale on x, y if True
237241
labels : list, optional, default is ``None``
@@ -254,6 +258,12 @@ def plot_data_1d(
254258
shift datasets by small amount along x if ``True``; default is ``False``
255259
close_fig : bool, optional
256260
closes figure if True (default)
261+
second_x_axis : array of float, optional, default is ``None``
262+
values for second x-axis on top, if not None
263+
second_x_label : string, optional, default is ``None``
264+
label for second x-axis on top
265+
second_x_every: int, optional
266+
plot only one in every `every` point on second x-axis; default is 1
257267
258268
"""
259269
if labels is None:
@@ -273,8 +283,10 @@ def plot_data_1d(
273283
if markers is None:
274284
markers = ["o"] * len(x)
275285

276-
if create_figure:
277-
figure(figsize=(10, 10))
286+
if ax is None:
287+
fig, ax = plt.subplots(figsize=(10, 10))
288+
else:
289+
fig = ax.figure
278290

279291
for idx in range(len(x)):
280292
this_x = x[idx]
@@ -284,15 +296,15 @@ def plot_data_1d(
284296
else:
285297
raise ValueError("shift_x without log not implemented yet")
286298
if np.isnan(yerr[idx]).all():
287-
eb = plt.plot(
299+
eb = ax.plot(
288300
this_x,
289301
y[idx],
290302
label=labels[idx],
291303
color=colors[idx],
292304
linestyle=linestyles[idx],
293305
)
294306
else:
295-
eb = plt.errorbar(
307+
eb = ax.errorbar(
296308
this_x,
297309
y[idx],
298310
yerr=yerr[idx],
@@ -305,40 +317,83 @@ def plot_data_1d(
305317
)
306318
eb[-1][0].set_linestyle(eb_linestyles[idx])
307319

308-
plt.axhline(color="k", linestyle="dashed", linewidth=linewidths[0] / 2)
320+
ax.axhline(color="k", linestyle="dashed", linewidth=linewidths[0] / 2)
309321

310322
if xlog:
311-
plt.xscale("log")
312-
plt.xticks(
313-
[0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500],
314-
labels=[
315-
"0.1",
316-
"0.2",
317-
"0.5",
318-
"1",
319-
"2",
320-
"5",
321-
"10",
322-
"20",
323-
"50",
324-
"100",
325-
"200",
326-
"500",
327-
],
328-
)
323+
ax.set_xscale("log")
324+
#plt.gca().xaxis.set_major_locator(ticker.LogLocator(base=10, subs=(1,2,5), numticks=15)
325+
ax.xaxis.set_major_locator(ticker.LogLocator(base=10, subs=(1,2,5), numticks=15))
326+
ax.xaxis.set_major_formatter(ticker.LogFormatter(labelOnlyBase=False))
327+
329328
if ylog:
330-
plt.yscale("log")
329+
ax.set_yscale("log")
331330

332331
if xlim:
333-
plt.xlim(xlim)
332+
ax.set_xlim(xlim)
334333
if ylim:
335-
plt.ylim(ylim)
334+
ax.set_ylim(ylim)
336335

337-
plt.title(title)
338-
plt.xlabel(xlabel)
339-
plt.ylabel(ylabel)
336+
ax.set_title(title)
337+
ax.set_xlabel(xlabel)
338+
ax.set_ylabel(ylabel)
340339
if do_legend:
341-
plt.legend()
340+
ax.legend()
341+
342+
# Add second x-axis on top if requested
343+
if second_x_axis is not None:
344+
ax2 = ax.twiny()
345+
346+
# Set second x-axis with provided values
347+
ax2.set_xlim(ax.get_xlim())
348+
if xlog:
349+
ax2.set_xscale('log')
350+
351+
if second_x_label is not None:
352+
ax2.set_xlabel(second_x_label)
353+
354+
# Create tick positions that correspond to the main x-axis values
355+
# Map from main x values to second x-axis values
356+
main_x_values = x[0] if len(x) > 0 else []
357+
if len(main_x_values) > 0 and len(second_x_axis) == len(main_x_values):
358+
# Find values within the plot range
359+
xlim_current = ax.get_xlim()
360+
tick_positions = []
361+
tick_labels = []
362+
363+
for i, main_x_val in enumerate(main_x_values):
364+
if (
365+
xlim_current[0] <= main_x_val <= xlim_current[1]
366+
):
367+
tick_positions.append(main_x_val)
368+
if i % second_x_every == 0:
369+
my_label = f'{second_x_axis[i]:.2g}'
370+
else:
371+
my_label = ""
372+
tick_labels.append(my_label)
373+
374+
if tick_positions:
375+
ax2.set_xticks(tick_positions)
376+
ax2.set_xticklabels(tick_labels)
377+
ax2.tick_params(axis='x', labelrotation=45)
342378

343379
if out_path:
344380
savefig(out_path, close_fig=close_fig)
381+
382+
383+
def log_ticks(x):
384+
385+
x = np.asarray(x)
386+
xmin, xmax = np.nanmin(x), np.nanmax(x)
387+
388+
# figure out the exponent range
389+
exp_min = int(np.floor(np.log10(xmin)))
390+
exp_max = int(np.ceil(np.log10(xmax)))
391+
392+
ticks = []
393+
for exp in range(exp_min, exp_max + 1):
394+
for base in [1, 2, 5]:
395+
val = base * 10**exp
396+
if xmin <= val <= xmax:
397+
ticks.append(val)
398+
399+
return np.array(ticks)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "cs_util"
3-
version = "0.1.5"
3+
version = "0.1.7"
44
description = "Utility library for CosmoStat"
55
authors = [
66
{ name = "Martin Kilbinger", email = "martin.kilbinger@cea.fr" },

0 commit comments

Comments
 (0)