diff --git a/pesummary/core/plots/corner.py b/pesummary/core/plots/corner.py
index 6288fb1caa61bc33fba231964aefc2a7aff40c19..ce971188d8711168e3d5ea9bd99a105d19a3a9c3 100644
--- a/pesummary/core/plots/corner.py
+++ b/pesummary/core/plots/corner.py
@@ -39,7 +39,7 @@ def hist2d(
     plot_contours=True, no_fill_contours=False, fill_contours=False,
     contour_kwargs=None, contourf_kwargs=None, data_kwargs=None,
     pcolor_kwargs=None, new_fig=True, kde=None, kde_kwargs={},
-    density_cmap=None, label=None, **kwargs
+    density_cmap=None, label=None, grid=True, **kwargs
 ):
     """Extension of the corner.hist2d function. Allows the user to specify the
     kde used when estimating the 2d probability density
diff --git a/pesummary/core/plots/publication.py b/pesummary/core/plots/publication.py
index acad4fa2397e632e312737ca217bdaf03beebf18..a4713589241fefc5ff7da57ef28980afd0110fe7 100644
--- a/pesummary/core/plots/publication.py
+++ b/pesummary/core/plots/publication.py
@@ -25,7 +25,7 @@ from pesummary import conf
 
 def pcolormesh(
     x, y, density, ax=None, levels=None, smooth=None, bins=None, label=None,
-    level_kwargs={}, range=None, **kwargs
+    level_kwargs={}, range=None, grid=True, **kwargs
 ):
     """Generate a colormesh plot on a given axis
 
@@ -55,8 +55,11 @@ def pcolormesh(
     _off = False
     if _cmap is not None and isinstance(_cmap, str) and _cmap.lower() == "off":
         _off = True
+    _zorder = 10.
+    if grid:
+        _zorder = -10
     if not _off:
-        ax.pcolormesh(x, y, density, **kwargs)
+        ax.pcolormesh(x, y, density, zorder=_zorder, **kwargs)
     if levels is not None:
         ax.contour(x, y, density, levels=levels, **level_kwargs)
     return ax
@@ -154,7 +157,7 @@ def twod_contour_plot(
 
     _function(
         x, y, *args, ax=ax, levels=levels, bins=bins, smooth=smooth,
-        label=label, **kwargs
+        label=label, grid=grid, **kwargs
     )
     if truth is not None:
         _default_truth_kwargs.update(truth_kwargs)
@@ -435,7 +438,7 @@ def _analytic_triangle_plot(
     """
     ax1, ax3, ax4 = axes
     analytic_twod_contour_plot(
-        x, y, probs_xy, ax=ax3, smooth=smooth, **kwargs
+        x, y, probs_xy, ax=ax3, smooth=smooth, grid=grid, **kwargs
     )
     ax1.plot(x, probs_x)
     ax4.plot(probs_y, y)
@@ -445,8 +448,12 @@ def _analytic_triangle_plot(
     if ylabel is not None:
         ax3.set_ylabel(ylabel, fontsize=fontsize["label"])
     ax1.grid(grid)
-    ax3.grid(grid)
+    ax3.grid(grid, zorder=10)
     ax4.grid(grid)
+    xlims = ax3.get_xlim()
+    ax1.set_xlim(xlims)
+    ylims = ax3.get_ylim()
+    ax4.set_ylim(ylims)
     return fig, ax1, ax3, ax4