diff --git a/spatialmath/base/graphics.py b/spatialmath/base/graphics.py index 2ce18dc8..8c6d1ef2 100644 --- a/spatialmath/base/graphics.py +++ b/spatialmath/base/graphics.py @@ -1462,10 +1462,11 @@ def _axes_dimensions(ax: plt.Axes) -> int: else: # handle the case of Animate objects pretending to be Axes classname = ax.__class__.__name__ - if classname == "Animate": - ret = 3 - elif classname == "Animate2": - ret = 2 + base_classes = ax.__class__.__bases__ + if classname in ("Axes3DSubplot", "Animate") or any(base_class.__name__ in ("Axes3DSubplot", "Animate") for base_class in base_classes): + return 3 + elif classname in ("AxesSubplot", "Animate2"): + return 2 # print("_axes_dimensions ", ax, ret) return ret