diff --git a/CHANGELOG.md b/CHANGELOG.md index e71671961..05eca701f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ Attention: The newest changes should be on top --> ### Added +- ENH: Added Crop and Clip Methods to Function Class [#817](https://github.com/RocketPy-Team/RocketPy/pull/817) - ENH: Discretized and No-Pickle Encoding Options [#827] (https://github.com/RocketPy-Team/RocketPy/pull/827) - ENH: Add the Coriolis Force to the Flight class [#799](https://github.com/RocketPy-Team/RocketPy/pull/799) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 681aa68f5..648cc16c9 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -24,10 +24,9 @@ RBFInterpolator, ) +from rocketpy.plots.plot_helpers import show_or_save_plot from rocketpy.tools import deprecated, from_hex_decode, to_hex_encode -from ..plots.plot_helpers import show_or_save_plot - # Numpy 1.x compatibility, # TODO: remove these lines when all dependencies support numpy>=2.0.0 if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1": @@ -153,6 +152,7 @@ def __init__( self.__extrapolation__ = extrapolation self.title = title self.__img_dim__ = 1 # always 1, here for backwards compatibility + self.__cropped_domain__ = (None, None) # the x interval if cropped # args must be passed from self. self.set_source(self.source) @@ -625,10 +625,121 @@ def __get_value_opt_nd(self, *args): return result + def __determine_1d_domain_bounds(self, lower, upper): + """Determine domain bounds for 1-D function discretization. + + Parameters + ---------- + lower : scalar, optional + Lower bound. If None, will use cropped domain or default. + upper : scalar, optional + Upper bound. If None, will use cropped domain or default. + + Returns + ------- + tuple + (lower_bound, upper_bound) for the domain. + """ + domain = [0, 10] # default boundaries + cropped = self.__cropped_domain__ + + if cropped[0] is not None and cropped[0] > domain[0]: + domain[0] = cropped[0] + + if cropped[1] is not None and cropped[1] < domain[1]: + domain[1] = cropped[1] + + # Input bounds have preference + domain[0] = lower if lower is not None else domain[0] + domain[1] = upper if upper is not None else domain[1] + + return domain + + def __determine_2d_domain_bounds(self, lower, upper, samples): + """Determine domain bounds for 2-D function discretization. + + Parameters + ---------- + lower : scalar or list, optional + Lower bounds. If None, will use cropped domain or default. + upper : scalar or list, optional + Upper bounds. If None, will use cropped domain or default. + samples : int or list + Number of samples for each dimension. + + Returns + ------- + tuple + (lower_bounds, upper_bounds, sample_counts) for the 2D domain. + """ + default_bounds = [[0, 10], [0, 10]] + + # Apply cropped domain constraints if they exist + final_bounds = deepcopy(default_bounds) + if self.__cropped_domain__ is not None: + for dim in range(2): + cropped_limits = self.__cropped_domain__[dim] + if cropped_limits is not None: + # Use the more restrictive bounds (cropped domain takes precedence) + final_bounds[dim][0] = max( + default_bounds[dim][0], cropped_limits[0] + ) + final_bounds[dim][1] = min( + default_bounds[dim][1], cropped_limits[1] + ) + + # Convert parameters to consistent list format + lower_bounds = self.__normalize_2d_parameter( + lower, [final_bounds[0][0], final_bounds[1][0]] + ) + upper_bounds = self.__normalize_2d_parameter( + upper, [final_bounds[0][1], final_bounds[1][1]] + ) + sample_counts = self.__normalize_2d_parameter(samples, samples) + + return lower_bounds, upper_bounds, sample_counts + + def __normalize_2d_parameter(self, param, default_values): + if param is None: + return ( + default_values + if isinstance(default_values, list) + else [default_values, default_values] + ) + + if isinstance(param, NUMERICAL_TYPES): + return [param, param] + + return param + + def __discretize_1d_function( + self, func, lower, upper, samples, interpolation, extrapolation, one_by_one + ): + lower, upper = self.__determine_1d_domain_bounds(lower, upper) + xs = np.linspace(lower, upper, samples) + ys = func.get_value(xs.tolist()) if one_by_one else func.get_value(xs) + func.__interpolation__ = interpolation + func.__extrapolation__ = extrapolation + func.set_source(np.column_stack((xs, ys))) + + def __discretize_2d_function(self, func, lower, upper, samples): + lower, upper, sam = self.__determine_2d_domain_bounds(lower, upper, samples) + + # Create nodes to evaluate function + xs = np.linspace(lower[0], upper[0], sam[0]) + ys = np.linspace(lower[1], upper[1], sam[1]) + xs, ys = np.array(np.meshgrid(xs, ys)).reshape(2, xs.size * ys.size) + + # Evaluate function at all mesh nodes and convert it to matrix + zs = np.array(func.get_value(xs, ys)) + func.set_source(np.concatenate(([xs], [ys], [zs])).transpose()) + func.__interpolation__ = "shepard" + func.__extrapolation__ = "natural" + def set_discrete( self, - lower=0, - upper=10, + lower=None, + upper=None, samples=200, interpolation="spline", extrapolation="constant", @@ -647,9 +758,9 @@ def set_discrete( Parameters ---------- lower : scalar, optional - Value where sampling range will start. Default is 0. + Value where sampling range will start. Default is None. upper : scalar, optional - Value where sampling range will end. Default is 10. + Value where sampling range will end. Default is None. samples : int, optional Number of samples to be taken from inside range. Default is 200. interpolation : string @@ -689,24 +800,11 @@ def set_discrete( func = deepcopy(self) if not mutate_self else self if func.__dom_dim__ == 1: - xs = np.linspace(lower, upper, samples) - ys = func.get_value(xs.tolist()) if one_by_one else func.get_value(xs) - func.__interpolation__ = interpolation - func.__extrapolation__ = extrapolation - func.set_source(np.column_stack((xs, ys))) + self.__discretize_1d_function( + func, lower, upper, samples, interpolation, extrapolation, one_by_one + ) elif func.__dom_dim__ == 2: - lower = 2 * [lower] if isinstance(lower, NUMERICAL_TYPES) else lower - upper = 2 * [upper] if isinstance(upper, NUMERICAL_TYPES) else upper - sam = 2 * [samples] if isinstance(samples, NUMERICAL_TYPES) else samples - # Create nodes to evaluate function - xs = np.linspace(lower[0], upper[0], sam[0]) - ys = np.linspace(lower[1], upper[1], sam[1]) - xs, ys = np.array(np.meshgrid(xs, ys)).reshape(2, xs.size * ys.size) - # Evaluate function at all mesh nodes and convert it to matrix - zs = np.array(func.get_value(xs, ys)) - func.set_source(np.concatenate(([xs], [ys], [zs])).transpose()) - func.__interpolation__ = "shepard" - func.__extrapolation__ = "natural" + self.__discretize_2d_function(func, lower, upper, samples) else: raise ValueError( "Discretization is only supported for 1-D and 2-D Functions." @@ -897,6 +995,244 @@ def reset( return self + def __crop_array_source(self, cropped_func, x_lim): + """Crop the array source of a Function based on domain limits. + + Parameters + ---------- + cropped_func : Function + The Function instance to be cropped. + x_lim : list[tuple] + Range of values with lower and upper limits for cropping. + """ + if cropped_func.__dom_dim__ == 1: + cropped_func.source = cropped_func.source[ + (cropped_func.source[:, 0] >= x_lim[0][0]) + & (cropped_func.source[:, 0] <= x_lim[0][1]) + ] + elif cropped_func.__dom_dim__ == 2: + cropped_func.source = cropped_func.source[ + (cropped_func.source[:, 0] >= x_lim[0][0]) + & (cropped_func.source[:, 0] <= x_lim[0][1]) + & (cropped_func.source[:, 1] >= x_lim[1][0]) + & (cropped_func.source[:, 1] <= x_lim[1][1]) + ] + + def __set_cropped_domain_1d(self, cropped_func, x_lim): + """Set the cropped domain for 1-D functions. + + Parameters + ---------- + cropped_func : Function + The Function instance to set the cropped domain for. + x_lim : list[tuple] + Range of values with lower and upper limits. + """ + if x_lim[0][0] < x_lim[0][1]: + cropped_func.__cropped_domain__ = x_lim[0] + + def __set_cropped_domain_2d(self, cropped_func, x_lim): + """Set the cropped domain for 2-D functions. + + Parameters + ---------- + cropped_func : Function + The Function instance to set the cropped domain for. + x_lim : list[tuple] + Range of values with lower and upper limits. + """ + if len(x_lim) < 2: + raise IndexError("x_lim must have a length of 2 for 2-D function") + + if x_lim[0] is not None and x_lim[0][0] < x_lim[0][1]: + cropped_func.__cropped_domain__ = [x_lim[0]] + else: + cropped_func.__cropped_domain__ = [None] + + if len(x_lim) >= 2 and x_lim[1] is not None and x_lim[1][0] < x_lim[1][1]: + cropped_func.__cropped_domain__.append(x_lim[1]) + else: + cropped_func.__cropped_domain__.append(None) + + def crop(self, x_lim): + """Restrict the **input** domain of the Function to specified ranges. + + This method limits the input values of the Function to the intervals + defined in `x_lim`, effectively trimming the data so that only values + within the specified ranges are retained. For multi-dimensional + functions, each dimension can be cropped independently by providing a + tuple with lower and upper bounds for each input variable. If a + dimension is set to `None`, it will not be cropped. + + Parameters + ---------- + x_lim : list[tuple] + Range of values with lower and upper limits for input values to be + cropped within. + + Returns + ------- + Function + A new Function instance with the cropped domain. + + See also + -------- + Function.clip + + Examples + -------- + >>> from rocketpy import Function + >>> import numpy as np + + Create two 2D functions: + >>> f1 = Function( + ... lambda x1, x2: np.sin(x1)*np.cos(x2), + ... inputs=['x1', 'x2'], + ... outputs='y' + ... ) + >>> f2 = Function( + ... lambda x1, x2: np.cos(x1)*np.sin(x2), + ... inputs=['x1', 'x2'], + ... outputs='y' + ... ) + + Crop their domains: + >>> f1_cropped = f1.crop([(-1, 1), (-2, 2)]) + >>> f2_cropped = f2.crop([None, (-2, 2)]) + + Compare the cropped functions using Function.compare_plots: + >>> # Function.compare_plots([ + >>> # (f1_cropped, 'sin(x1)*cos(x2), cropped'), + >>> # (f2_cropped, 'cos(x1)*sin(x2), cropped') + >>> # ]) + """ + if not isinstance(x_lim, list): + raise TypeError("x_lim must be a list of tuples.") + + if len(x_lim) > self.__dom_dim__: + raise ValueError( + "x_lim must not exceed the length of the domain dimension." + ) + + cropped_func = deepcopy(self) + + if isinstance(cropped_func.source, np.ndarray): + self.__crop_array_source(cropped_func, x_lim) + + if cropped_func.__dom_dim__ == 1: + self.__set_cropped_domain_1d(cropped_func, x_lim) + elif cropped_func.__dom_dim__ == 2: + self.__set_cropped_domain_2d(cropped_func, x_lim) + + cropped_func.set_source(cropped_func.source) + return cropped_func + + def __validate_clip_parameters(self, y_lim): + if not isinstance(y_lim, list): + raise TypeError("y_lim must be a list of tuples.") + + if len(y_lim) != len(self.__outputs__): + raise ValueError( + "y_lim must have the same length as the output dimensions." + ) + + def __clip_array_source(self, clipped_func, y_lim: list[tuple]): + clipped_func.source = clipped_func.source[ + (clipped_func.source[:, clipped_func.__dom_dim__] >= y_lim[0][0]) + & (clipped_func.source[:, clipped_func.__dom_dim__] <= y_lim[0][1]) + ] + + def __clip_numerical_source(self, clipped_func, y_lim: list[tuple]): + try: + if clipped_func.source < y_lim[0][0]: + raise ArithmeticError("Constant function outside range") + if clipped_func.source > y_lim[0][1]: + raise ArithmeticError("Constant function outside range") + except TypeError as e: + raise TypeError("y_lim must be the same type as the function source") from e + + def __clip_callable_source(self, clipped_func, y_lim: list[tuple]): + original_function = clipped_func.source + + def clipped_function(*args): + results = original_function(*args) + clipped_results = [] + + if isinstance(results, (tuple, list)): + # Multi-dimensional output + for i, (lower, upper) in enumerate(y_lim): + clipped_results.append(max(lower, min(upper, results[i]))) + else: + # Single value output + for lower, upper in y_lim: + clipped_results.append(max(lower, min(upper, results))) + + return ( + tuple(clipped_results) + if len(clipped_results) > 1 + else clipped_results[0] + ) + + clipped_func.source = clipped_function + + def clip(self, y_lim): + """Restrict the **output** values of the Function to specified ranges. + + This method limits the output values of the Function to the intervals + defined in `y_lim`, effectively removing all input-output pairs where + the output values fall outside the specified ranges. This operation + filters the data based on output constraints rather than input domain + restrictions. + + Parameters + ---------- + y_lim : list[tuple] + Range of values with lower and upper limits for output values to be + clipped within. + + Returns + ------- + Function + A new Function instance with the clipped output values. + + See also + -------- + Function.crop + + Examples + -------- + >>> from rocketpy import Function + >>> + >>> f = Function(lambda x: x**2, inputs='x', outputs='y') + >>> print(f) + Function from R1 to R1 : (x) → (y) + >>> f_clipped = f.clip([(-5, 5)]) + >>> print(f_clipped) + Function from R1 to R1 : (x) → (y) + """ + self.__validate_clip_parameters(y_lim) + + clipped_func = deepcopy(self) + + if isinstance(clipped_func.source, np.ndarray): + self.__clip_array_source(clipped_func, y_lim) + elif isinstance(clipped_func.source, NUMERICAL_TYPES): + self.__clip_numerical_source(clipped_func, y_lim) + elif callable(clipped_func.source): + self.__clip_callable_source(clipped_func, y_lim) + + try: + clipped_func.set_source(clipped_func.source) + except ValueError as e: + raise ValueError( + "Cannot clip function as function reduces to " + f"{len(clipped_func.source) if isinstance(clipped_func.source, (list, np.ndarray)) else 'unknown'} points (too few data points to define" + " a domain). Ensure that the source is array-like and has " + "sufficient data points after applying the clipping function." + ) from e + + return clipped_func + # Define all get methods def get_inputs(self): "Return tuple of inputs of the function." @@ -1524,8 +1860,13 @@ def plot_1d( # pylint: disable=too-many-statements ax = fig.axes if self._source_type is SourceType.CALLABLE: # Determine boundaries - lower = 0 if lower is None else lower - upper = 10 if upper is None else upper + domain = [0, 10] + if self.__cropped_domain__[0] and self.__cropped_domain__[0] > domain[0]: + domain[0] = self.__cropped_domain__[0] + if self.__cropped_domain__[1] and self.__cropped_domain__[1] < domain[1]: + domain[1] = self.__cropped_domain__[1] + lower = domain[0] if lower is None else lower + upper = domain[1] if upper is None else upper else: # Determine boundaries x_data = self.x_array @@ -1635,9 +1976,17 @@ def plot_2d( # pylint: disable=too-many-statements # Define a mesh and f values at mesh nodes for plotting if self._source_type is SourceType.CALLABLE: # Determine boundaries - lower = [0, 0] if lower is None else lower + domain = [[0, 10], [0, 10]] + if self.__cropped_domain__ is not None: + for i in range(0, 2): + if self.__cropped_domain__[i] is not None: + if self.__cropped_domain__[i][0] > domain[i][0]: + domain[i][0] = self.__cropped_domain__[i][0] + if self.__cropped_domain__[i][1] < domain[i][1]: + domain[i][1] = self.__cropped_domain__[i][1] + lower = [domain[0][0], domain[1][0]] if lower is None else lower lower = 2 * [lower] if isinstance(lower, NUMERICAL_TYPES) else lower - upper = [10, 10] if upper is None else upper + upper = [domain[0][1], domain[1][1]] if upper is None else upper upper = 2 * [upper] if isinstance(upper, NUMERICAL_TYPES) else upper else: # Determine boundaries diff --git a/tests/unit/test_function.py b/tests/unit/test_function.py index 77f5916f4..96acf45f5 100644 --- a/tests/unit/test_function.py +++ b/tests/unit/test_function.py @@ -250,6 +250,113 @@ def test_set_discrete_based_on_model_non_mutator(linear_func): assert callable(func.source) +source_array = np.array( + [ + [-2, -4, -6], + [-0.75, -1.5, -2.25], + [0, 0, 0], + [0, 1, 1], + [0.5, 1, 1.5], + [1.5, 1, 2.5], + [2, 4, 6], + ] +) +cropped_array = np.array([[-0.75, -1.5, -2.25], [0, 0, 0], [0, 1, 1], [0.5, 1, 1.5]]) +clipped_array = np.array([[0, 0, 0], [0, 1, 1], [0.5, 1, 1.5]]) + + +@pytest.mark.parametrize( + "array3dsource, array3dcropped", + [ + (source_array, cropped_array), + ], +) +def test_crop_ndarray(array3dsource, array3dcropped): # pylint: disable=unused-argument + """Tests the functionality of crop method of the Function class. + The source is initialized as a ndarray before cropping. + """ + func = Function(array3dsource, inputs=["x1", "x2"], outputs="y") + cropped_func = func.crop([(-1, 1), (-2, 2)]) + + assert isinstance(func, Function) + assert isinstance(cropped_func, Function) + assert np.array_equal(cropped_func.source, array3dcropped) + assert isinstance(cropped_func.source, type(func.source)) + + +def test_crop_function(): + """Tests the functionality of crop method of the Function class. + The source is initialized as a function before cropping. + """ + func = Function( + lambda x1, x2: np.sin(x1) * np.cos(x2), inputs=["x1", "x2"], outputs="y" + ) + cropped_func = func.crop([(-1, 1), (-2, 2)]) + + assert isinstance(func, Function) + assert isinstance(cropped_func, Function) + assert callable(func.source) + assert callable(cropped_func.source) + + +def test_crop_constant(): + """Tests the functionality of crop method of the Function class. + The source is initialized as a single integer constant before cropping. + """ + func = Function(13) + cropped_func = func.crop([(-1, 1)]) + + assert isinstance(func, Function) + assert isinstance(cropped_func, Function) + assert callable(func.source) + assert callable(cropped_func.source) + + +@pytest.mark.parametrize( + "array3dsource, array3dclipped", + [ + (source_array, clipped_array), + ], +) +def test_clip_ndarray(array3dsource, array3dclipped): # pylint: disable=unused-argument + """Tests the functionality of clip method of the Function class. + The source is initialized as a ndarray before clipping. + """ + func = Function(array3dsource, inputs=["x1", "x2"], outputs="y") + clipped_func = func.clip([(-2, 2)]) + + assert isinstance(func, Function) + assert isinstance(clipped_func, Function) + assert np.array_equal(clipped_func.source, array3dclipped) + assert isinstance(clipped_func.source, type(func.source)) + + +def test_clip_function(): + """Tests the functionality of clip method of the Function class. + The source is initialized as a function before clipping. + """ + func = Function(lambda x: x**2, inputs="x", outputs="y") + clipped_func = func.clip([(-1, 1)]) + + assert isinstance(func, Function) + assert isinstance(clipped_func, Function) + assert callable(func.source) + assert callable(clipped_func.source) + + +def test_clip_constant(): + """Tests the functionality of clip method of the Function class. + The source is initialized as a single integer constant before clipping. + """ + func = Function(1) + clipped_func = func.clip([(-2, 2)]) + + assert isinstance(func, Function) + assert isinstance(clipped_func, Function) + assert callable(func.source) + assert callable(clipped_func.source) + + @pytest.mark.parametrize( "x, y, expected_x, expected_y", [