diff --git a/python/lsst/analysis/tools/actions/plot/histPlot.py b/python/lsst/analysis/tools/actions/plot/histPlot.py index b494a9716..ff1b89e7d 100644 --- a/python/lsst/analysis/tools/actions/plot/histPlot.py +++ b/python/lsst/analysis/tools/actions/plot/histPlot.py @@ -224,6 +224,16 @@ class HistPlot(PlotAction): "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.", default="rubin", ) + panelsPerRow = Field[int]( + doc="Maximum number of histogram panels to place in each row. Set to 1 to stack panels vertically.", + default=2, + ) + + def validate(self): + super().validate() + if self.panelsPerRow < 1: + msg = "panelsPerRow must be at least 1." + raise FieldValidationError(self.__class__.panelsPerRow, self, msg) def getInputSchema(self) -> KeyedDataSchema: for panel in self.panels: # type: ignore @@ -292,14 +302,14 @@ def makePlot( nth_panel = len(self.panels) nth_col = ncols nth_row = nrows - 1 - label_font_size = max(6, 10 - nrows) + label_font_size = max(6, 10 - ((len(self.panels) + 1) // 2)) for panel, ax in zip(self.panels, axs): nth_panel -= 1 nth_col = ncols - 1 if nth_col == 0 else nth_col - 1 if nth_panel == 0 and nrows * ncols - len(self.panels) > 0: nth_col -= 1 # Set font size for legend based on number of panels being plotted. - legend_font_size = max(4, int(7 - len(self.panels[panel].hists) / 2 - nrows // 2)) # type: ignore + legend_font_size = max(4, int(7 - len(self.panels[panel].hists) / 2 - len(self.panels) / 4)) nums, meds, mads, stats_dict = self._makePanel( data, panel, @@ -348,7 +358,7 @@ def _makeAxes(self, fig): if num_panels <= 1: ncols = 1 else: - ncols = 2 + ncols = min(self.panelsPerRow, num_panels) nrows = int(np.ceil(num_panels / ncols)) gs = GridSpec(nrows, ncols, left=0.12, right=0.88, bottom=0.1, top=0.88, wspace=0.41, hspace=0.45) @@ -477,7 +487,7 @@ def _makePanel( nHist += 1 if nHist > 0: - ax.legend(fontsize=legend_font_size, loc="upper left", frameon=False) + ax.legend(fontsize=legend_font_size, loc="upper left", frameon=False, borderaxespad=1.1) ax.set_xlim(panel_range) # The following accommodates spacing for ranges with large numbers # but small-ish dynamic range (example use case: RA 300-301). @@ -624,7 +634,9 @@ def _addReferenceLines(self, ax, panel, panel_range, meds, legend_font_size=7): if ax2.get_ylim()[1] < 1.05 * y_max: ax.set_ylim(ax.get_ylim()[0], 1.05 * y_max) ax2.set_ylim(ax.get_ylim()) - ax2.legend(fontsize=legend_font_size, handlelength=1.5, loc="upper right", frameon=False) + ax2.legend( + fontsize=legend_font_size, handlelength=1.5, loc="upper right", frameon=False, borderaxespad=1.1 + ) return ax @@ -652,14 +664,20 @@ def _addStatisticsPanel( # set up new legend handles and labels legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3) + def format_stat(x): + if isinstance(x, int): + return str(x) if abs(x) <= 9999 else f"{x:.2e}" + else: + return f"{x:.2f}" if abs(x) < 999 else f"{x:.2e}" + legend_labels = ( ([""] * (len(handles) + 1)) + [stats_dict["statLabels"][0]] - + [f"{x:.3g}" if abs(x) > 0.01 else f"{x:.2e}" for x in stats_dict["stat1"]] + + [format_stat(x) for x in stats_dict["stat1"]] + [stats_dict["statLabels"][1]] - + [f"{x:.3g}" if abs(x) > 0.01 else f"{x:.2e}" for x in stats_dict["stat2"]] + + [format_stat(x) for x in stats_dict["stat2"]] + [stats_dict["statLabels"][2]] - + [f"{x:.3g}" if abs(x) > 0.01 else f"{x:.2e}" for x in stats_dict["stat3"]] + + [format_stat(x) for x in stats_dict["stat3"]] ) # Replace "e+0" with "e" and "e-0" with "e-" to save space. legend_labels = [label.replace("e+0", "e") for label in legend_labels] diff --git a/python/lsst/analysis/tools/atools/skyObject.py b/python/lsst/analysis/tools/atools/skyObject.py index 5a7fc8e4f..b83718940 100644 --- a/python/lsst/analysis/tools/atools/skyObject.py +++ b/python/lsst/analysis/tools/atools/skyObject.py @@ -68,6 +68,7 @@ def setDefaults(self): self.process.buildActions.hist_gaap1p0_sn = CalcSn(fluxType="{band}_gaap1p0Flux") self.produce.plot = HistPlot() + self.produce.plot.panelsPerRow = 1 self.produce.plot.panels["panel_flux"] = HistPanel() self.produce.plot.panels["panel_flux"].label = "Flux (nJy)" diff --git a/tests/test_histPlot.py b/tests/test_histPlot.py new file mode 100644 index 000000000..8c8a2c888 --- /dev/null +++ b/tests/test_histPlot.py @@ -0,0 +1,74 @@ +# This file is part of analysis_tools. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +import unittest + +import matplotlib +import matplotlib.pyplot as plt + +import lsst.utils.tests +from lsst.analysis.tools.actions.plot.histPlot import HistPanel, HistPlot + +matplotlib.use("Agg") + + +class HistPlotLayoutTestCase(lsst.utils.tests.TestCase): + def setUp(self): + self.plot = HistPlot() + for panel_name in ("panel_a", "panel_b", "panel_c"): + self.plot.panels[panel_name] = HistPanel() + self.plot.panels[panel_name].hists = {f"{panel_name}_hist": panel_name} + + def test_default_layout_uses_two_columns(self): + fig = plt.figure() + + axes, ncols, nrows = self.plot._makeAxes(fig) + + self.assertEqual(ncols, 2) + self.assertEqual(nrows, 2) + self.assertEqual(len(axes), 3) + + def test_panels_per_row_one_stacks_vertically(self): + self.plot.panelsPerRow = 1 + fig = plt.figure() + + axes, ncols, nrows = self.plot._makeAxes(fig) + + self.assertEqual(ncols, 1) + self.assertEqual(nrows, 3) + self.assertEqual(len(axes), 3) + self.assertTrue(all(ax.get_position().x0 == axes[0].get_position().x0 for ax in axes[1:])) + self.assertGreater(axes[0].get_position().y0, axes[1].get_position().y0) + self.assertGreater(axes[1].get_position().y0, axes[2].get_position().y0) + + +class MemoryTester(lsst.utils.tests.MemoryTestCase): + pass + + +def setup_module(module): + lsst.utils.tests.init() + + +if __name__ == "__main__": + lsst.utils.tests.init() + unittest.main()