Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions python/lsst/analysis/tools/actions/plot/histPlot.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,16 @@ class HistPlot(PlotAction):
"A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
default="rubin",
)
panelsPerRow = Field[int](
doc="Maximum number of histogram panels to place in each row. Set to 1 to stack panels vertically.",
default=2,
)

def validate(self):
super().validate()
if self.panelsPerRow < 1:
msg = "panelsPerRow must be at least 1."
raise FieldValidationError(self.__class__.panelsPerRow, self, msg)

def getInputSchema(self) -> KeyedDataSchema:
for panel in self.panels: # type: ignore
Expand Down Expand Up @@ -292,14 +302,14 @@ def makePlot(
nth_panel = len(self.panels)
nth_col = ncols
nth_row = nrows - 1
label_font_size = max(6, 10 - nrows)
label_font_size = max(6, 10 - ((len(self.panels) + 1) // 2))
for panel, ax in zip(self.panels, axs):
nth_panel -= 1
nth_col = ncols - 1 if nth_col == 0 else nth_col - 1
if nth_panel == 0 and nrows * ncols - len(self.panels) > 0:
nth_col -= 1
# Set font size for legend based on number of panels being plotted.
legend_font_size = max(4, int(7 - len(self.panels[panel].hists) / 2 - nrows // 2)) # type: ignore
legend_font_size = max(4, int(7 - len(self.panels[panel].hists) / 2 - len(self.panels) / 4))
nums, meds, mads, stats_dict = self._makePanel(
data,
panel,
Expand Down Expand Up @@ -348,7 +358,7 @@ def _makeAxes(self, fig):
if num_panels <= 1:
ncols = 1
else:
ncols = 2
ncols = min(self.panelsPerRow, num_panels)
nrows = int(np.ceil(num_panels / ncols))

gs = GridSpec(nrows, ncols, left=0.12, right=0.88, bottom=0.1, top=0.88, wspace=0.41, hspace=0.45)
Expand Down Expand Up @@ -477,7 +487,7 @@ def _makePanel(
nHist += 1

if nHist > 0:
ax.legend(fontsize=legend_font_size, loc="upper left", frameon=False)
ax.legend(fontsize=legend_font_size, loc="upper left", frameon=False, borderaxespad=1.1)
ax.set_xlim(panel_range)
# The following accommodates spacing for ranges with large numbers
# but small-ish dynamic range (example use case: RA 300-301).
Expand Down Expand Up @@ -624,7 +634,9 @@ def _addReferenceLines(self, ax, panel, panel_range, meds, legend_font_size=7):
if ax2.get_ylim()[1] < 1.05 * y_max:
ax.set_ylim(ax.get_ylim()[0], 1.05 * y_max)
ax2.set_ylim(ax.get_ylim())
ax2.legend(fontsize=legend_font_size, handlelength=1.5, loc="upper right", frameon=False)
ax2.legend(
fontsize=legend_font_size, handlelength=1.5, loc="upper right", frameon=False, borderaxespad=1.1
)

return ax

Expand Down Expand Up @@ -652,14 +664,20 @@ def _addStatisticsPanel(
# set up new legend handles and labels
legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)

def format_stat(x):
if isinstance(x, int):
return str(x) if abs(x) <= 9999 else f"{x:.2e}"
else:
return f"{x:.2f}" if abs(x) < 999 else f"{x:.2e}"

legend_labels = (
([""] * (len(handles) + 1))
+ [stats_dict["statLabels"][0]]
+ [f"{x:.3g}" if abs(x) > 0.01 else f"{x:.2e}" for x in stats_dict["stat1"]]
+ [format_stat(x) for x in stats_dict["stat1"]]
+ [stats_dict["statLabels"][1]]
+ [f"{x:.3g}" if abs(x) > 0.01 else f"{x:.2e}" for x in stats_dict["stat2"]]
+ [format_stat(x) for x in stats_dict["stat2"]]
+ [stats_dict["statLabels"][2]]
+ [f"{x:.3g}" if abs(x) > 0.01 else f"{x:.2e}" for x in stats_dict["stat3"]]
+ [format_stat(x) for x in stats_dict["stat3"]]
)
# Replace "e+0" with "e" and "e-0" with "e-" to save space.
legend_labels = [label.replace("e+0", "e") for label in legend_labels]
Expand Down
1 change: 1 addition & 0 deletions python/lsst/analysis/tools/atools/skyObject.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def setDefaults(self):
self.process.buildActions.hist_gaap1p0_sn = CalcSn(fluxType="{band}_gaap1p0Flux")

self.produce.plot = HistPlot()
self.produce.plot.panelsPerRow = 1

self.produce.plot.panels["panel_flux"] = HistPanel()
self.produce.plot.panels["panel_flux"].label = "Flux (nJy)"
Expand Down
74 changes: 74 additions & 0 deletions tests/test_histPlot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# This file is part of analysis_tools.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.


import unittest

import matplotlib
import matplotlib.pyplot as plt

import lsst.utils.tests
from lsst.analysis.tools.actions.plot.histPlot import HistPanel, HistPlot

matplotlib.use("Agg")


class HistPlotLayoutTestCase(lsst.utils.tests.TestCase):
def setUp(self):
self.plot = HistPlot()
for panel_name in ("panel_a", "panel_b", "panel_c"):
self.plot.panels[panel_name] = HistPanel()
self.plot.panels[panel_name].hists = {f"{panel_name}_hist": panel_name}

def test_default_layout_uses_two_columns(self):
fig = plt.figure()

axes, ncols, nrows = self.plot._makeAxes(fig)

self.assertEqual(ncols, 2)
self.assertEqual(nrows, 2)
self.assertEqual(len(axes), 3)

def test_panels_per_row_one_stacks_vertically(self):
self.plot.panelsPerRow = 1
fig = plt.figure()

axes, ncols, nrows = self.plot._makeAxes(fig)

self.assertEqual(ncols, 1)
self.assertEqual(nrows, 3)
self.assertEqual(len(axes), 3)
self.assertTrue(all(ax.get_position().x0 == axes[0].get_position().x0 for ax in axes[1:]))
self.assertGreater(axes[0].get_position().y0, axes[1].get_position().y0)
self.assertGreater(axes[1].get_position().y0, axes[2].get_position().y0)


class MemoryTester(lsst.utils.tests.MemoryTestCase):
pass


def setup_module(module):
lsst.utils.tests.init()


if __name__ == "__main__":
lsst.utils.tests.init()
unittest.main()
Loading