diff --git a/pesummary/core/plots/corner.py b/pesummary/core/plots/corner.py index 6288fb1caa61bc33fba231964aefc2a7aff40c19..ce971188d8711168e3d5ea9bd99a105d19a3a9c3 100644 --- a/pesummary/core/plots/corner.py +++ b/pesummary/core/plots/corner.py @@ -39,7 +39,7 @@ def hist2d( plot_contours=True, no_fill_contours=False, fill_contours=False, contour_kwargs=None, contourf_kwargs=None, data_kwargs=None, pcolor_kwargs=None, new_fig=True, kde=None, kde_kwargs={}, - density_cmap=None, label=None, **kwargs + density_cmap=None, label=None, grid=True, **kwargs ): """Extension of the corner.hist2d function. Allows the user to specify the kde used when estimating the 2d probability density diff --git a/pesummary/core/plots/publication.py b/pesummary/core/plots/publication.py index acad4fa2397e632e312737ca217bdaf03beebf18..a4713589241fefc5ff7da57ef28980afd0110fe7 100644 --- a/pesummary/core/plots/publication.py +++ b/pesummary/core/plots/publication.py @@ -25,7 +25,7 @@ from pesummary import conf def pcolormesh( x, y, density, ax=None, levels=None, smooth=None, bins=None, label=None, - level_kwargs={}, range=None, **kwargs + level_kwargs={}, range=None, grid=True, **kwargs ): """Generate a colormesh plot on a given axis @@ -55,8 +55,11 @@ def pcolormesh( _off = False if _cmap is not None and isinstance(_cmap, str) and _cmap.lower() == "off": _off = True + _zorder = 10. + if grid: + _zorder = -10 if not _off: - ax.pcolormesh(x, y, density, **kwargs) + ax.pcolormesh(x, y, density, zorder=_zorder, **kwargs) if levels is not None: ax.contour(x, y, density, levels=levels, **level_kwargs) return ax @@ -154,7 +157,7 @@ def twod_contour_plot( _function( x, y, *args, ax=ax, levels=levels, bins=bins, smooth=smooth, - label=label, **kwargs + label=label, grid=grid, **kwargs ) if truth is not None: _default_truth_kwargs.update(truth_kwargs) @@ -435,7 +438,7 @@ def _analytic_triangle_plot( """ ax1, ax3, ax4 = axes analytic_twod_contour_plot( - x, y, probs_xy, ax=ax3, smooth=smooth, **kwargs + x, y, probs_xy, ax=ax3, smooth=smooth, grid=grid, **kwargs ) ax1.plot(x, probs_x) ax4.plot(probs_y, y) @@ -445,8 +448,12 @@ def _analytic_triangle_plot( if ylabel is not None: ax3.set_ylabel(ylabel, fontsize=fontsize["label"]) ax1.grid(grid) - ax3.grid(grid) + ax3.grid(grid, zorder=10) ax4.grid(grid) + xlims = ax3.get_xlim() + ax1.set_xlim(xlims) + ylims = ax3.get_ylim() + ax4.set_ylim(ylims) return fig, ax1, ax3, ax4