-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathatmo_graph.py
More file actions
182 lines (169 loc) · 6.96 KB
/
atmo_graph.py
File metadata and controls
182 lines (169 loc) · 6.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import numpy as np
# NumPy 2.0 compatibility shim: define np.Inf as an alias of np.inf.
# Presumably needed because some code used alongside this module still
# refers to the np.Inf spelling -- confirm before removing.
np.Inf = np.inf
import matplotlib.pyplot as plt
def plot_histogram_heatmap(
    data,
    dim,
    n_bins=100,
    cmap="turbo",
    vlim=None,
    log_transform=False,
    log_offset=1e-6,
    xlabel=None,
    ylabel=None,
    colorbar_label="Counts",
    title=None,
    axhlines=None,
    axvlines=None,
    fontsize=12,
    figsize=(7, 6)
):
    """
    Plots a heatmap of the histogram of field values along a specified dimension.

    For every coordinate value of ``dim`` the remaining values of ``data`` are
    flattened and histogrammed over a common set of bins spanning the global
    minimum and maximum of ``data``. The stacked histograms are drawn with
    ``imshow`` (x-axis: field value, y-axis: ``dim``) and shown with
    ``plt.show()``. Nothing is returned.

    Parameters:
    - data (xarray.DataArray): Input data array.
    - dim (str): The dimension along which to compute the histogram (e.g., "level").
    - n_bins (int): Number of bins for the histogram (default: 100).
    - cmap (str): Colormap for the heatmap (default: "turbo").
    - vlim (float or None): Upper limit for the heatmap color scale (default: None).
      None or a non-finite value such as np.inf means "no limit": the upper
      bound is then chosen automatically from the histogram data.
    - log_transform (bool): Whether to apply log transformation to histogram data (default: False).
    - log_offset (float): Offset added to histogram counts to avoid log(0) (default: 1e-6).
    - xlabel (str): Label for the x-axis (default: None, built from the data's
      name and units when available).
    - ylabel (str): Label for the y-axis (default: None, the dimension name is used).
    - colorbar_label (str): Label for the colorbar (default: "Counts").
    - title (str): Title of the plot (default: None).
    - axhlines (list of tuples): List of horizontal lines to draw, each as (y, color, linestyle, label).
      Example: [(28, "white", "--", "z = 6 km")].
    - axvlines (list of tuples): List of vertical lines to draw, each as (x, color, linestyle, label).
      Example: [(0.5, "red", "--", "Threshold")].
    - fontsize (int): Font size for all text in the plot (default: 12).
    - figsize (tuple): Figure size as (width, height) in inches (default: (7, 6)).
    """
    # Prepare the dimension values and one shared set of bin edges so that
    # every row of the heatmap uses the same x-axis.
    dim_values = data[dim]
    bins = np.linspace(data.min().item(), data.max().item(), n_bins + 1)

    # Infer unit and data name if available. An unnamed DataArray reports
    # name=None, so fall back to "Data" instead of printing "None".
    unit = getattr(data, "units", "")
    data_name = getattr(data, "name", None) or "Data"

    # Compute one histogram per coordinate value of `dim`.
    histogram_data = np.array([
        np.histogram(data.sel({dim: value}).values.flatten(), bins=bins)[0]
        for value in dim_values
    ])

    # Apply log transformation if requested; the offset keeps empty bins finite.
    if log_transform:
        histogram_data = np.log10(histogram_data + log_offset)

    # A non-finite upper limit (e.g. np.inf) cannot be used for colour
    # normalisation, so treat it like None and let matplotlib autoscale.
    vmax = vlim if vlim is not None and np.isfinite(vlim) else None

    # Plot heatmap
    plt.figure(figsize=figsize)
    plt.imshow(
        histogram_data,
        aspect="auto",
        extent=[bins[0], bins[-1], dim_values[0].item(), dim_values[-1].item()],
        origin="lower",
        cmap=cmap,
        vmax=vmax,
    )

    # Add horizontal lines if specified
    if axhlines:
        for y, color, linestyle, label in axhlines:
            plt.axhline(y=y, color=color, linestyle=linestyle, label=label)

    # Add vertical lines if specified
    if axvlines:
        for x, color, linestyle, label in axvlines:
            plt.axvline(x=x, color=color, linestyle=linestyle, label=label)

    # Add labels, title and colorbar
    if title:
        plt.title(title, fontsize=fontsize)
    if xlabel:
        x_text = xlabel
    elif unit:
        x_text = f"{data_name} [{unit}]"
    else:
        # No unit known: omit the empty "[]" suffix.
        x_text = str(data_name)
    plt.xlabel(x_text, fontsize=fontsize)
    plt.ylabel(ylabel if ylabel else dim, fontsize=fontsize)
    plt.colorbar(label=colorbar_label)

    # Only draw a legend when at least one line carries a label; calling
    # legend() with nothing to show makes matplotlib emit a warning.
    handles, _ = plt.gca().get_legend_handles_labels()
    if handles:
        plt.legend(fontsize=fontsize)

    # Flip the y-axis so the first value of `dim` ends up at the top.
    plt.gca().invert_yaxis()
    plt.xticks(fontsize=fontsize)
    plt.yticks(fontsize=fontsize)
    plt.show()
def plot_histogram_heatmap_options(option_name):
    """
    Prints the description for a single parameter of plot_histogram_heatmap.

    If ``option_name`` is not a known parameter, a message listing all valid
    parameter names (alphabetically sorted) is printed instead. Nothing is
    returned.

    Parameters:
    - option_name (str): The name of the parameter to describe.

    Example:
    plot_histogram_heatmap_options("cmap")
    """
    # Help text per parameter: first line is the type/default summary,
    # second line the description.
    options = {
        "data": (
            "xarray.DataArray\n"
            " The input data array to be visualized. It should be an xarray.DataArray "
            "with labeled dimensions and coordinates."
        ),
        "dim": (
            "str\n"
            " The dimension along which the histogram is computed (e.g., 'Altitude'). "
            "This dimension will be represented on the y-axis of the heatmap."
        ),
        "n_bins": (
            "int, optional, default=100\n"
            " The number of bins to use for the histogram computation."
        ),
        "cmap": (
            "str, optional, default='turbo'\n"
            " The colormap to use for the heatmap visualization."
        ),
        "vlim": (
            "float or None, optional, default=None\n"
            " The maximum value for the heatmap color scale. If None, the scale is "
            "determined automatically."
        ),
        "log_transform": (
            "bool, optional, default=False\n"
            " Whether to apply a logarithmic transformation to the histogram data "
            "for better visibility of small values."
        ),
        "log_offset": (
            "float, optional, default=1e-6\n"
            " A small offset added to histogram counts to avoid issues with log(0) "
            "when log_transform=True."
        ),
        "xlabel": (
            "str or None, optional, default=None\n"
            " Label for the x-axis. If None, it will be inferred from the data if available."
        ),
        # The plotting function falls back to the dimension name for the
        # y-axis, not to anything read from the data.
        "ylabel": (
            "str or None, optional, default=None\n"
            " Label for the y-axis. If None, the name of the dimension given by 'dim' is used."
        ),
        "colorbar_label": (
            "str, optional, default='Counts'\n"
            " Label for the colorbar, which represents the histogram counts."
        ),
        "title": (
            "str or None, optional, default=None\n"
            " Title of the plot. If None, no title is displayed."
        ),
        "axhlines": (
            "list of tuples or None, optional, default=None\n"
            " A list of horizontal lines to add to the plot. Each tuple should be "
            "(y_position, color, linestyle, label)."
        ),
        "axvlines": (
            "list of tuples or None, optional, default=None\n"
            " A list of vertical lines to add to the plot. Each tuple should be "
            "(x_position, color, linestyle, label)."
        ),
        "fontsize": (
            "int, optional, default=12\n"
            " Font size for axis labels, title, and annotations."
        ),
        "figsize": (
            "tuple, optional, default=(7, 6)\n"
            " Size of the figure in inches, specified as (width, height)."
        ),
    }

    if option_name in options:
        print(f"{option_name}:\n {options[option_name]}")
    else:
        valid = ", ".join(sorted(options))
        print(f"Invalid option name: '{option_name}'.\n"
              f"Valid options are:\n {valid}")