Skip to content

Commit 4d8af09

Browse files
committed
color_func is now back-compatible with sympy.plotting
1 parent 57f3753 commit 4d8af09

File tree

6 files changed

+161
-25
lines changed

6 files changed

+161
-25
lines changed

spb/backends/k3d.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -252,15 +252,15 @@ def _process_series(self, series):
252252
x, y, z, u, v = s.get_data()
253253
vertices, indices = get_vertices_indices(x, y, z)
254254
vertices = vertices.astype(np.float32)
255-
attribute = s.color_func(vertices[:, 0], vertices[:, 1], vertices[:, 2], u.flatten().astype(np.float32), v.flatten().astype(np.float32))
255+
attribute = s.eval_color_func(vertices[:, 0], vertices[:, 1], vertices[:, 2], u.flatten().astype(np.float32), v.flatten().astype(np.float32))
256256
else:
257257
x, y, z = s.get_data()
258258
x = x.flatten()
259259
y = y.flatten()
260260
z = z.flatten()
261261
vertices = np.vstack([x, y, z]).T.astype(np.float32)
262262
indices = Triangulation(x, y).triangles.astype(np.uint32)
263-
attribute = s.color_func(vertices[:, 0], vertices[:, 1], vertices[:, 2])
263+
attribute = s.eval_color_func(vertices[:, 0], vertices[:, 1], vertices[:, 2])
264264

265265
self._high_aspect_ratio(x, y, z)
266266
a = dict(
@@ -493,11 +493,11 @@ def _update_interactive(self, params):
493493
if s.is_parametric:
494494
x, y, z, u, v = s.get_data()
495495
x, y, z, u, v = [t.flatten().astype(np.float32) for t in [x, y, z, u, v]]
496-
attribute = s.color_func(x, y, z, u, v)
496+
attribute = s.eval_color_func(x, y, z, u, v)
497497
else:
498498
x, y, z = s.get_data()
499499
x, y, z = [t.flatten().astype(np.float32) for t in [x, y, z]]
500-
attribute = s.color_func(x, y, z)
500+
attribute = s.eval_color_func(x, y, z)
501501

502502
vertices = np.vstack([x, y, z]).astype(np.float32)
503503
self._fig.objects[i].vertices = vertices.T

spb/backends/matplotlib.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -413,10 +413,10 @@ def _process_series(self, series):
413413
elif (s.is_3Dsurface and (not s.is_domain_coloring) and (not s.is_implicit)):
414414
if not s.is_parametric:
415415
x, y, z = self.series[i].get_data()
416-
facecolors = s.color_func(x, y, z)
416+
facecolors = s.eval_color_func(x, y, z)
417417
else:
418418
x, y, z, u, v = self.series[i].get_data()
419-
facecolors = s.color_func(x, y, z, u, v)
419+
facecolors = s.eval_color_func(x, y, z, u, v)
420420
skw = dict(rstride=1, cstride=1, linewidth=0.1)
421421
norm, cmap = None, None
422422
if s.use_cm:
@@ -855,10 +855,10 @@ def _update_interactive(self, params):
855855
elif s.is_3Dsurface and (not s.is_domain_coloring) and (not s.is_implicit):
856856
if not s.is_parametric:
857857
x, y, z = self.series[i].get_data()
858-
facecolors = s.color_func(x, y, z)
858+
facecolors = s.eval_color_func(x, y, z)
859859
else:
860860
x, y, z, u, v = self.series[i].get_data()
861-
facecolors = s.color_func(x, y, z, u, v)
861+
facecolors = s.eval_color_func(x, y, z, u, v)
862862
# TODO: by setting the keyword arguments, somehow the
863863
# update becomes really really slow.
864864
kw, is_cb_added, cax = self._handles[i][1:]

spb/backends/plotly.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -349,10 +349,10 @@ def _process_series(self, series):
349349
elif s.is_3Dsurface and (not s.is_domain_coloring) and (not s.is_implicit):
350350
if not s.is_parametric:
351351
xx, yy, zz = s.get_data()
352-
surfacecolor = s.color_func(xx, yy, zz)
352+
surfacecolor = s.eval_color_func(xx, yy, zz)
353353
else:
354354
xx, yy, zz, uu, vv = s.get_data()
355-
surfacecolor = s.color_func(xx, yy, zz, uu, vv)
355+
surfacecolor = s.eval_color_func(xx, yy, zz, uu, vv)
356356

357357
# create a solid color to be used when s.use_cm=False
358358
col = next(self._cl)
@@ -658,10 +658,10 @@ def _update_interactive(self, params):
658658
elif s.is_3Dsurface and (not s.is_domain_coloring) and (not s.is_implicit):
659659
if not s.is_parametric:
660660
x, y, z = s.get_data()
661-
surfacecolor = s.color_func(x, y, z)
661+
surfacecolor = s.eval_color_func(x, y, z)
662662
else:
663663
x, y, z, u, v = s.get_data()
664-
surfacecolor = s.color_func(x, y, z, u, v)
664+
surfacecolor = s.eval_color_func(x, y, z, u, v)
665665
self.fig.data[i]["x"] = x
666666
self.fig.data[i]["y"] = y
667667

spb/functions.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -626,10 +626,16 @@ def plot_parametric(*args, show=True, **kwargs):
626626
backend : Plot, optional
627627
A subclass of `Plot`, which will perform the rendering.
628628
Default to `MatplotlibBackend`.
629-
629+
630630
color_func : callable, optional
631-
A function of 3 variables, x, y, parameter (the points computed by
632-
the internal algorithm) which defines the line color. Default to None.
631+
A function defining the line color. The arity can be:
632+
633+
* 1 argument: ``f(t)``, where ``t`` is the parameter.
634+
* 2 arguments: ``f(x, y)`` where ``x, y`` are the coordinates of the
635+
points.
636+
* 3 arguments: ``f(x, y, t)``.
637+
638+
Default to None.
633639
634640
label : str or list/tuple, optional
635641
The label to be shown in the legend or in the colorbar. If not
@@ -844,10 +850,16 @@ def plot3d_parametric_line(*args, show=True, **kwargs):
844850
backend : Plot, optional
845851
A subclass of `Plot`, which will perform the rendering.
846852
Default to `MatplotlibBackend`.
847-
853+
848854
color_func : callable, optional
849-
A function of 4 variables, x, y, z, parameter (the points computed by
850-
the internal algorithm) which defines the line color. Default to None.
855+
A function defining the line color. The arity can be:
856+
857+
* 1 argument: ``f(t)``, where ``t`` is the parameter.
858+
* 3 arguments: ``f(x, y, z)`` where ``x, y, z`` are the coordinates of
859+
the points.
860+
* 4 arguments: ``f(x, y, z, t)``.
861+
862+
Default to None.
851863
852864
label : str or list/tuple, optional
853865
The label to be shown in the legend or in the colorbar. If not
@@ -1306,9 +1318,16 @@ def plot3d_parametric_surface(*args, show=True, **kwargs):
13061318
Default to `MatplotlibBackend`.
13071319
13081320
color_func : callable, optional
1309-
A function of 5 variables, x, y, z, u, v (the points computed by the
1310-
internal algorithm and the parameters) which defines the surface color
1311-
when ``use_cm=True``. Default to None.
1321+
A function defining the surface color when ``use_cm=True``. The arity
1322+
can be:
1323+
1324+
* 1 argument: ``f(u)``, where ``u`` is the first parameter.
1325+
* 2 arguments: ``f(u, v)`` where ``u, v`` are the parameters.
1326+
* 3 arguments: ``f(x, y, z)`` where ``x, y, z`` are the coordinates of
1327+
the points.
1328+
* 5 arguments: ``f(x, y, z, u, v)``.
1329+
1330+
Default to None.
13121331
13131332
label : str or list/tuple, optional
13141333
The label to be shown in the colorbar. If not provided, the string
@@ -1401,7 +1420,7 @@ def plot3d_parametric_surface(*args, show=True, **kwargs):
14011420
... r * cos(v)
14021421
... )
14031422
>>> plot3d_parametric_surface(*expr, (u, 0, 2 * pi), (v, 0, pi), "u",
1404-
... n=200, use_cm=True, color_func=lambda x, y, z, u, v: u)
1423+
... n=200, use_cm=True, color_func=lambda u, v: u)
14051424
Plot object containing:
14061425
[0]: parametric cartesian surface: ((sin(7*u + 5*v) + 2)*sin(v)*cos(u), (sin(7*u + 5*v) + 2)*sin(u)*sin(v), (sin(7*u + 5*v) + 2)*cos(v)) for u over (0.0, 6.283185307179586) and v over (0.0, 3.141592653589793)
14071426

spb/series.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from spb.defaults import cfg
22
from sympy import latex
33
from sympy.core.containers import Tuple
4+
from sympy.core.function import arity
45
from sympy.core.symbol import symbols
56
from sympy.core.sympify import sympify
67
from sympy.solvers.solvers import solve
@@ -376,6 +377,43 @@ def _correct_size(a, b):
376377
a = a.reshape(b.shape)
377378
return a
378379

380+
def eval_color_func(self, *args):
381+
"""Evaluate the color function.
382+
383+
Parameters
384+
==========
385+
386+
args : tuple
387+
Arguments to be passed to the coloring function. Can be coordinates
388+
or parameters or both.
389+
390+
Notes
391+
=====
392+
393+
The backend will request the data series to generate the numerical
394+
data. Depending on the data series, either the data series itself or
395+
the backend will eventually execute this function to generate the
396+
appropriate coloring value.
397+
"""
398+
nargs = arity(self.color_func)
399+
if nargs == 1:
400+
if self.is_2Dline and self.is_parametric:
401+
if len(args) == 2:
402+
# ColoredLineOver1DRangeSeries
403+
return self.color_func(args[0])
404+
# Parametric2DLineSeries
405+
return self.color_func(args[2])
406+
elif self.is_3Dline and self.is_parametric:
407+
return self.color_func(args[3])
408+
elif self.is_3Dsurface and self.is_parametric:
409+
return self.color_func(args[3])
410+
return self.color_func(args[0])
411+
elif nargs == 2:
412+
if self.is_3Dsurface and self.is_parametric:
413+
return self.color_func(*args[3:])
414+
return self.color_func(*args[:2])
415+
return self.color_func(*args[:nargs])
416+
379417
def get_data(self):
380418
"""Compute and returns the numerical data.
381419
@@ -723,7 +761,7 @@ def get_points(self):
723761
Color associated to each point.
724762
"""
725763
x, y = super().get_points()
726-
return x, y, self.color_func(x, y)
764+
return x, y, self.eval_color_func(x, y)
727765

728766

729767
class AbsArgLineSeries(LineOver1DRangeSeries):
@@ -878,7 +916,7 @@ def get_points(self):
878916

879917
if callable(self.color_func):
880918
coords = list(coords)
881-
coords[-1] = self.color_func(*coords)
919+
coords[-1] = self.eval_color_func(*coords)
882920
return coords
883921

884922

@@ -1818,7 +1856,7 @@ def get_points(self):
18181856
Color associated to each point.
18191857
"""
18201858
x, y = super().get_points()
1821-
return x, y, self.color_func(x, y)
1859+
return x, y, self.eval_color_func(x, y)
18221860

18231861

18241862
class AbsArgLineInteractiveSeries(LineInteractiveSeries):

tests/test_series.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2364,3 +2364,82 @@ def test_is_polar_3d():
23642364
x22, y22 = x1 * np.cos(y1), x1 * np.sin(y1)
23652365
assert np.allclose(x2, x22)
23662366
assert np.allclose(y2, y22)
2367+
2368+
2369+
def test_color_func():
2370+
# verify that eval_color_func produces the expected results in order to
2371+
# maintain back compatibility with the old sympy.plotting module
2372+
2373+
x, y, z, u, v = symbols("x, y, z, u, v")
2374+
2375+
s = LineOver1DRangeSeries(sin(x), (x, -5, 5), adaptive=False, n=10,
2376+
color_func=lambda x: x)
2377+
xx, yy, col = s.get_data()
2378+
assert np.allclose(col, xx)
2379+
s = LineOver1DRangeSeries(sin(x), (x, -5, 5), adaptive=False, n=10,
2380+
color_func=lambda x, y: y)
2381+
xx, yy, col = s.get_data()
2382+
assert np.allclose(col, yy)
2383+
2384+
s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi),
2385+
adaptive=False, n=10, color_func=lambda t: t)
2386+
xx, yy, col = s.get_data()
2387+
assert (not np.allclose(xx, col)) and (not np.allclose(yy, col))
2388+
s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi),
2389+
adaptive=False, n=10, color_func=lambda x, y: x * y)
2390+
xx, yy, col = s.get_data()
2391+
assert np.allclose(col, xx * yy)
2392+
s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi),
2393+
adaptive=False, n=10, color_func=lambda x, y, t: x * y * t)
2394+
xx, yy, col = s.get_data()
2395+
assert np.allclose(col, xx * yy * np.linspace(0, 2*np.pi, 10))
2396+
2397+
s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2*pi),
2398+
adaptive=False, n=10, color_func=lambda t: t)
2399+
xx, yy, zz, col = s.get_data()
2400+
assert (not np.allclose(xx, col)) and (not np.allclose(yy, col))
2401+
s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2*pi),
2402+
adaptive=False, n=10, color_func=lambda x, y, z: x * y * z)
2403+
xx, yy, zz, col = s.get_data()
2404+
assert np.allclose(col, xx * yy * zz)
2405+
s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2*pi),
2406+
adaptive=False, n=10, color_func=lambda x, y, z, t: x * y * z * t)
2407+
xx, yy, zz, col = s.get_data()
2408+
assert np.allclose(col, xx * yy * zz * np.linspace(0, 2*np.pi, 10))
2409+
2410+
s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2),
2411+
adaptive=False, n1=10, n2=10, color_func=lambda x: x)
2412+
xx, yy, zz = s.get_data()
2413+
col = s.eval_color_func(xx, yy, zz)
2414+
assert np.allclose(xx, col)
2415+
s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2),
2416+
adaptive=False, n1=10, n2=10, color_func=lambda x, y: x * y)
2417+
xx, yy, zz = s.get_data()
2418+
col = s.eval_color_func(xx, yy, zz)
2419+
assert np.allclose(xx * yy, col)
2420+
s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2),
2421+
adaptive=False, n1=10, n2=10, color_func=lambda x, y, z: x * y * z)
2422+
xx, yy, zz = s.get_data()
2423+
col = s.eval_color_func(xx, yy, zz)
2424+
assert np.allclose(xx * yy * zz, col)
2425+
2426+
s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False,
2427+
n1=10, n2=10, color_func=lambda u:u)
2428+
xx, yy, zz, uu, vv = s.get_data()
2429+
col = s.eval_color_func(xx, yy, zz, uu, vv)
2430+
assert np.allclose(uu, col)
2431+
s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False,
2432+
n1=10, n2=10, color_func=lambda u, v: u * v)
2433+
xx, yy, zz, uu, vv = s.get_data()
2434+
col = s.eval_color_func(xx, yy, zz, uu, vv)
2435+
assert np.allclose(uu * vv, col)
2436+
s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False,
2437+
n1=10, n2=10, color_func=lambda x, y, z: x * y * z)
2438+
xx, yy, zz, uu, vv = s.get_data()
2439+
col = s.eval_color_func(xx, yy, zz, uu, vv)
2440+
assert np.allclose(xx * yy * zz, col)
2441+
s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False,
2442+
n1=10, n2=10, color_func=lambda x, y, z, u, v: x * y * z * u * v)
2443+
xx, yy, zz, uu, vv = s.get_data()
2444+
col = s.eval_color_func(xx, yy, zz, uu, vv)
2445+
assert np.allclose(xx * yy * zz * uu * vv, col)

0 commit comments

Comments
 (0)