From ba236a87d8f9262f6a41fe06e45fca495896a25e Mon Sep 17 00:00:00 2001 From: Adrian Schneider Date: Mon, 28 Apr 2025 16:49:51 +0200 Subject: [PATCH] Refactored constraints --- .gitignore | 2 + cadquery/__init__.py | 3 +- cadquery/assembly.py | 143 +++-- cadquery/occ_impl/solver.py | 1208 +++++++++++++++++++++-------------- tests/test_assembly.py | 17 +- 5 files changed, 847 insertions(+), 526 deletions(-) diff --git a/.gitignore b/.gitignore index 41ad64e42..eac3b07d7 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,5 @@ out1.3mf out2.3mf out3.3mf orig.dxf +box.brep +sketch.dxf \ No newline at end of file diff --git a/cadquery/__init__.py b/cadquery/__init__.py index 0678377a6..73fe8d738 100644 --- a/cadquery/__init__.py +++ b/cadquery/__init__.py @@ -37,7 +37,7 @@ ) from .sketch import Sketch from .cq import CQ, Workplane -from .assembly import Assembly, Color, Constraint +from .assembly import Assembly, Color from . import selectors from . import plugins @@ -47,7 +47,6 @@ "Workplane", "Assembly", "Color", - "Constraint", "plugins", "selectors", "Plane", diff --git a/cadquery/assembly.py b/cadquery/assembly.py index 1aac2cb3d..6474695ec 100644 --- a/cadquery/assembly.py +++ b/cadquery/assembly.py @@ -1,4 +1,5 @@ from functools import reduce +from itertools import chain from typing import ( Union, Optional, @@ -22,7 +23,7 @@ from .occ_impl.solver import ( ConstraintKind, ConstraintSolver, - ConstraintSpec as Constraint, + BaseConstraint, UnaryConstraintKind, BinaryConstraintKind, ) @@ -92,7 +93,7 @@ class Assembly(object): children: List["Assembly"] objects: Dict[str, "Assembly"] - constraints: List[Constraint] + constraints: List[BaseConstraint] # Allows metadata to be stored for exports _subshape_names: dict[Shape, str] @@ -341,12 +342,21 @@ def _subloc(self, name: str) -> Tuple[Location, str]: @overload def constrain( - self, q1: str, q2: str, kind: ConstraintKind, param: Any = None + self, + q1: str, + q2: str, + kind: Literal["Point", "Axis", "PointInPlane", "PointOnLine", "Plane"], + param: Any = None, ) -> "Assembly": ... @overload - def constrain(self, q1: str, kind: ConstraintKind, param: Any = None) -> "Assembly": + def constrain( + self, + q1: str, + kind: Literal["Fixed", "FixedPoint", "FixedAxis", "FixedRotation"], + param: Any = None, + ) -> "Assembly": ... @overload @@ -356,57 +366,95 @@ def constrain( s1: Shape, id2: str, s2: Shape, - kind: ConstraintKind, + kind: Literal["Point", "Axis", "PointInPlane", "PointOnLine", "Plane"], param: Any = None, ) -> "Assembly": ... @overload def constrain( - self, id1: str, s1: Shape, kind: ConstraintKind, param: Any = None, + self, + id1: str, + s1: Shape, + kind: Literal["Fixed", "FixedPoint", "FixedAxis", "FixedRotation"], + param: Any = None, ) -> "Assembly": ... def constrain(self, *args, param=None): """ Define a new constraint. - """ - # dispatch on arguments - if len(args) == 2: - q1, kind = args - id1, s1 = self._query(q1) - elif len(args) == 3 and instance_of(args[1], UnaryConstraintKind): - q1, kind, param = args - id1, s1 = self._query(q1) - elif len(args) == 3: - q1, q2, kind = args - id1, s1 = self._query(q1) - id2, s2 = self._query(q2) - elif len(args) == 4: - q1, q2, kind, param = args - id1, s1 = self._query(q1) - id2, s2 = self._query(q2) - elif len(args) == 5: - id1, s1, id2, s2, kind = args - elif len(args) == 6: - id1, s1, id2, s2, kind, param = args - else: - raise ValueError(f"Incompatible arguments: {args}") - - # handle unary and binary constraints - if instance_of(kind, UnaryConstraintKind): - loc1, id1_top = self._subloc(id1) - c = Constraint((id1_top,), (s1,), (loc1,), kind, param) - elif instance_of(kind, BinaryConstraintKind): - loc1, id1_top = self._subloc(id1) - loc2, id2_top = self._subloc(id2) - c = Constraint((id1_top, id2_top), (s1, s2), (loc1, loc2), kind, param) + The method accepts several call signatures: + + 1. Unary constraints (Fixed, FixedPoint, FixedAxis, FixedRotation): + - constrain(query_str, kind, param=None) + - constrain(id_str, shape, kind, param=None) + + 2. Binary constraints (Point, Axis, PointInPlane, PointOnLine, Plane): + - constrain(query_str1, query_str2, kind, param=None) + - constrain(id_str1, shape1, id_str2, shape2, kind, param=None) + + 3. Higher order constraints: + - constrain(query_str1, query_str2, ..., query_strN, kind, param=None) + - constrain(id_str1, shape1, id_str2, shape2, ..., id_strN, shapeN, kind, param=None) + """ + + # Collect all arguments into ids, shapes, and kind + ids = [] + shapes = [] + + if len(args) < 2: + raise ValueError("At least two arguments required") + + # Find the kind argument - it should be a string and a valid constraint kind + kind_idx = -1 + for i, arg in enumerate(args): + constraint_kinds = chain.from_iterable( + get_args(x) for x in get_args(ConstraintKind) + ) + if isinstance(arg, str) and arg in constraint_kinds: + kind_idx = i + break + + if kind_idx == -1: + raise ValueError("No valid constraint kind found in arguments") + + kind = args[kind_idx] + + # Handle arguments before the kind + if all(isinstance(arg, str) for arg in args[:kind_idx]): + # Query string pattern + for q in args[:kind_idx]: + id_, shape = self._query(q) + ids.append(id_) + shapes.append(shape) else: - raise ValueError(f"Unknown constraint: {kind}") - + # id/shape pairs pattern + if kind_idx % 2 != 0: # Should be even (pairs) + raise ValueError("Arguments before kind must be id/shape pairs") + for i in range(0, kind_idx, 2): + ids.append(args[i]) + shapes.append(args[i + 1]) + + # Handle param if present after kind + if kind_idx < len(args) - 1: + param = args[kind_idx + 1] + + # Get locations based on whether it's a unary or binary constraint + locs = [] + ids_top = [] + for id_ in ids: + loc, id_top = self._subloc(id_) + locs.append(loc) + ids_top.append(id_top) + + args_tuple = (tuple(ids_top), tuple(shapes), tuple(locs)) + + # Create the appropriate constraint based on kind + constraint_class = BaseConstraint.get_constraint_class(kind) + c = constraint_class(*args_tuple, param) self.constraints.append(c) - return self def solve(self, verbosity: int = 0) -> "Assembly": @@ -453,17 +501,8 @@ def solve(self, verbosity: int = 0) -> "Assembly": locs = [self.objects[n].loc for n in ents] - # construct the constraint mapping - constraints = [] - for c in self.constraints: - ixs = tuple(ents[obj] for obj in c.objects) - pods = c.toPODs() - - for pod in pods: - constraints.append((ixs, pod)) - # check if any constraints were specified - if not constraints: + if not self.constraints: raise ValueError("At least one constraint required") # check if at least two entities are present @@ -472,7 +511,9 @@ def solve(self, verbosity: int = 0) -> "Assembly": # instantiate the solver scale = self.toCompound().BoundingBox().DiagonalLength - solver = ConstraintSolver(locs, constraints, locked=locked, scale=scale) + solver = ConstraintSolver( + locs, self.constraints, object_indices=ents, locked=locked, scale=scale + ) # solve locs_new, self._solve_result = solver.solve(verbosity) diff --git a/cadquery/occ_impl/solver.py b/cadquery/occ_impl/solver.py index 69a1cecb9..f628a65e1 100644 --- a/cadquery/occ_impl/solver.py +++ b/cadquery/occ_impl/solver.py @@ -1,18 +1,13 @@ from typing import ( - List, - Tuple, Union, Any, - Callable, Optional, - Dict, Literal, - cast as tcast, Type, ) - +from dataclasses import dataclass from math import radians, pi -from typish import instance_of, get_type +from abc import ABC, abstractmethod import casadi as ca @@ -32,581 +27,842 @@ from .geom import Location, Vector, Plane from .shapes import Shape, Face, Edge, Wire -from ..types import Real - -# type definitions +# Type definitions NoneType = type(None) -DOF6 = Tuple[Tuple[float, float, float], Tuple[float, float, float]] +DOF6 = tuple[tuple[float, float, float], tuple[float, float, float]] ConstraintMarker = Union[gp_Pln, gp_Dir, gp_Pnt, gp_Lin, None] UnaryConstraintKind = Literal["Fixed", "FixedPoint", "FixedAxis", "FixedRotation"] BinaryConstraintKind = Literal["Plane", "Point", "Axis", "PointInPlane", "PointOnLine"] -ConstraintKind = Literal[ - "Plane", - "Point", - "Axis", - "PointInPlane", - "Fixed", - "FixedPoint", - "FixedAxis", - "PointOnLine", - "FixedRotation", -] - -# (arity, marker types, param type, conversion func) -ConstraintInvariants = { - "Point": (2, (gp_Pnt, gp_Pnt), Real, None), - "Axis": ( - 2, - (gp_Dir, gp_Dir), - Real, - lambda x: radians(x) if x is not None else None, - ), - "PointInPlane": (2, (gp_Pnt, gp_Pln), Real, None), - "PointOnLine": (2, (gp_Pnt, gp_Lin), Real, None), - "Fixed": (1, (None,), Type[None], None), - "FixedPoint": (1, (gp_Pnt,), Tuple[Real, Real, Real], None), - "FixedAxis": (1, (gp_Dir,), Tuple[Real, Real, Real], None), - "FixedRotation": ( - 1, - (None,), - Tuple[Real, Real, Real], - lambda x: tuple(map(radians, x)), - ), -} - -# translation table for compound constraints {name : (name, ...), converter} -CompoundConstraints: Dict[ - ConstraintKind, Tuple[Tuple[ConstraintKind, ...], Callable[[Any], Tuple[Any, ...]]] -] = { - "Plane": (("Axis", "Point"), lambda x: (radians(x) if x is not None else None, 0)), -} - -# constraint POD type -Constraint = Tuple[ - Tuple[ConstraintMarker, ...], ConstraintKind, Optional[Any], -] - -NDOF_V = 3 -NDOF_Q = 3 -NDOF = 6 -DIR_SCALING = 1e2 -DIFF_EPS = 1e-10 -TOL = 1e-12 -MAXITER = 2000 - -# high-level constraint class - to be used by clients - - -class ConstraintSpec(object): - """ - Geometrical constraint specification between two shapes of an assembly. +ConstraintKind = Union[UnaryConstraintKind, BinaryConstraintKind] + +# Constants for solver +NDOF_V = 3 # Number of degrees of freedom for translation +NDOF_Q = 3 # Number of degrees of freedom for rotation +NDOF = 6 # Total degrees of freedom +DIR_SCALING = 1e2 # Scaling factor for directions +DIFF_EPS = 1e-10 # Epsilon for finite differences +TOL = 1e-12 # Tolerance for convergence +MAXITER = 2000 # Maximum number of iterations + +# Helper functions for constraint cost calculations +def Quaternion(R): + m = ca.sumsqr(R) + u = 2 * R / (1 + m) + s = (1 - m) / (1 + m) + return s, u + + +def Rotate(v, R): + s, u = Quaternion(R) + return 2 * ca.dot(u, v) * u + (s ** 2 - ca.dot(u, u)) * v + 2 * s * ca.cross(u, v) + + +def Transform(v, T, R): + return Rotate(v, R) + T + + +def loc_to_dof6(loc: Location) -> DOF6: + """Convert a Location to a 6-DOF representation (translation and rotation).""" + Tr = loc.wrapped.Transformation() + v = Tr.TranslationPart() + q = Tr.GetRotation() + + alpha_2 = (1 - q.W()) / (1 + q.W()) + a = (alpha_2 + 1) * q.X() / 2 + b = (alpha_2 + 1) * q.Y() / 2 + c = (alpha_2 + 1) * q.Z() / 2 + + return (v.X(), v.Y(), v.Z()), (a, b, c) + + +def getDir(arg: Shape) -> gp_Dir: + if isinstance(arg, Face): + rv = arg.normalAt() + elif isinstance(arg, Edge) and arg.geomType() != "CIRCLE": + rv = arg.tangentAt() + elif isinstance(arg, Edge) and arg.geomType() == "CIRCLE": + rv = arg.normal() + else: + raise ValueError(f"Cannot construct Axis for {arg}") + + return rv.toDir() + + +def getPln(arg: Shape) -> gp_Pln: + if isinstance(arg, Face): + rv = gp_Pln(getPnt(arg), arg.normalAt().toDir()) + elif isinstance(arg, (Edge, Wire)): + normal = arg.normal() + origin = arg.Center() + plane = Plane(origin, normal=normal) + rv = plane.toPln() + else: + raise ValueError(f"Cannot construct a plane for {arg}.") + + return rv + + +def getPnt(arg: Shape) -> gp_Pnt: + # check for infinite face + if isinstance(arg, Face) and any( + Precision.IsInfinite_s(x) for x in BRepTools.UVBounds_s(arg.wrapped) + ): + # fall back to gp_Pln center + pln = arg.toPln() + center = Vector(pln.Location()) + else: + center = arg.Center() + + return center.toPnt() + + +def getLin(arg: Shape) -> gp_Lin: + if isinstance(arg, (Edge, Wire)): + center = arg.Center() + tangent = arg.tangentAt() + else: + raise ValueError(f"Cannot construct a plane for {arg}.") + + return gp_Lin(center.toPnt(), tangent.toDir()) + + +@dataclass +class CostParams: + """Parameters passed to constraint cost functions. + + This class standardizes the parameters passed to all constraint cost functions, + making them compatible with any number of objects involved in the constraint. """ - objects: Tuple[str, ...] - args: Tuple[Shape, ...] - sublocs: Tuple[Location, ...] + problem: ca.Opti + markers: list[Union[gp_Pnt, gp_Dir, gp_Pln, gp_Lin, None]] + initial_translations: list[ca.DM] # T0 values for each object + initial_rotations: list[ca.DM] # R0 values for each object + translations: list[ca.MX] # T values for each object + rotations: list[ca.MX] # R values for each object + param: Optional[Any] = None # Optional constraint-specific parameter + scale: float = 1.0 # Scale factor for the optimization + + def __post_init__(self): + """Validate that all lists have the same length.""" + n_objects = len(self.markers) + if not all( + len(lst) == n_objects + for lst in [ + self.initial_translations, + self.initial_rotations, + self.translations, + self.rotations, + ] + ): + raise ValueError("All parameter lists must have the same length") + + +class BaseConstraint(ABC): + """Base class for all constraints.""" + kind: ConstraintKind - param: Any + _registry: dict[ConstraintKind, Type["BaseConstraint"]] = {} + + def __init_subclass__(cls, **kwargs): + """Register constraint classes by their kind.""" + super().__init_subclass__(**kwargs) + if hasattr(cls, "kind"): + BaseConstraint._registry[cls.kind] = cls def __init__( self, - objects: Tuple[str, ...], - args: Tuple[Shape, ...], - sublocs: Tuple[Location, ...], - kind: ConstraintKind, + objects: tuple[str, ...], + args: tuple[Shape, ...], + sublocs: tuple[Location, ...], param: Any = None, ): """ - Construct a constraint. + Initialize a constraint. - :param objects: object names referenced in the constraint - :param args: subshapes (e.g. faces or edges) of the objects - :param sublocs: locations of the objects (only relevant if the objects are nested in a sub-assembly) - :param kind: constraint kind - :param param: optional arbitrary parameter passed to the solver + :param objects: Tuple of object names involved in the constraint + :param args: Tuple of shapes involved in the constraint + :param sublocs: Tuple of locations for each object + :param param: Optional constraint-specific parameter """ - - # validate - if not instance_of(kind, ConstraintKind): - raise ValueError(f"Unknown constraint {kind}.") - - if kind in CompoundConstraints: - kinds, convert_compound = CompoundConstraints[kind] - for k, p in zip(kinds, convert_compound(param)): - self._validate(args, k, p) - else: - self._validate(args, kind, param) - - # convert here for simple constraints - convert = ConstraintInvariants[kind][-1] - param = convert(param) if convert else param - - # store self.objects = objects self.args = args self.sublocs = sublocs - self.kind = kind self.param = param - def _validate(self, args: Tuple[Shape, ...], kind: ConstraintKind, param: Any): + @classmethod + def get_constraint_class(cls, kind: ConstraintKind) -> Type["BaseConstraint"]: + """Get the constraint class for a given kind.""" + return cls._registry[kind] + + +class ConstraintSpec(BaseConstraint, ABC): + """ + Geometrical constraint specification between two shapes of an assembly. + """ + + objects: tuple[str, ...] # Names of objects involved in the constraint + args: tuple[Shape, ...] # Shapes involved in the constraint + sublocs: tuple[Location, ...] # Locations of objects in the constraint + kind: ConstraintKind # Type of constraint + param: Any # Constraint-specific parameter + arity: int = 0 # Number of objects involved in the constraint + marker_types: tuple[ + Type[ConstraintMarker], ... + ] = () # Types of geometric markers needed + param_type: Optional[Any] = None # Type of the constraint parameter - arity, marker_types, param_type, converter = ConstraintInvariants[kind] + def __init__( + self, + objects: tuple[str, ...], + args: tuple[Shape, ...], + sublocs: tuple[Location, ...], + param: Any = None, + ): + super().__init__(objects, args, sublocs, param) + self._validate(args) + self.validate_param(param) + self.param = self.convert_param(param) - # check arity - if arity != len(args): + @staticmethod + @abstractmethod + def cost(params: CostParams) -> float: + """Cost function for the constraint. + + :param params: CostParams object containing the necessary parameters + :return: float value of the cost function + """ + pass + + @abstractmethod + def get_markers(self) -> tuple[ConstraintMarker, ...]: + """Get the geometric markers for this constraint. + + :return: tuple of geometric markers + """ + pass + + def get_param(self) -> Optional[Any]: + """Get the parameter for this constraint. + + :return: constraint parameter or None if not applicable + """ + return self.param + + def validate_param(self, param: Any) -> None: + """Validate the constraint parameter. + + :param param: Parameter to validate + :raises ValueError: If parameter is invalid + """ + pass + + def convert_param(self, param: Any) -> Any: + """Convert the parameter to the required type. + + :param param: Parameter to convert + :return: Converted parameter + """ + return param + + def _validate(self, args: tuple[Shape, ...]) -> None: + """Validate arguments for the constraint. + + Args: + args: tuple of shapes to validate + + Raises: + ValueError: If number of arguments doesn't match arity or if arguments are of wrong type + """ + # Validate number of arguments matches constraint arity + if self.arity != len(args): raise ValueError( - f"Invalid number of entities for constraint {kind}. Provided {len(args)}, required {arity}." + f"Invalid number of entities for constraint {self.kind}. " + f"Provided {len(args)}, required {self.arity}." ) - # check arguments - arg_check: Dict[Any, Callable[[Shape], Any]] = { - gp_Pnt: self._getPnt, - gp_Dir: self._getAxis, - gp_Pln: self._getPln, - gp_Lin: self._getLin, - None: lambda x: True, # dummy check for None marker + # Define validation functions for each marker type + MARKER_VALIDATORS = { + gp_Pnt: getPnt, # Point validation + gp_Dir: getDir, # Direction validation + gp_Pln: getPln, # Plane validation + gp_Lin: getLin, # Line validation + NoneType: lambda _: True, # No validation needed for None markers } - for a, t in zip(args, tcast(Tuple[Type[ConstraintMarker], ...], marker_types)): + # Validate each argument against its expected marker type + for arg, marker_type in zip(args, self.marker_types): try: - arg_check[t](a) + MARKER_VALIDATORS[marker_type](arg) except ValueError: - raise ValueError(f"Unsupported entity {a} for constraint {kind}.") + raise ValueError( + f"Unsupported entity {arg} for constraint {self.kind}. " + f"Expected type: {marker_type.__name__}" + ) + - # check parameter - if not instance_of(param, param_type) and param is not None: +class CompoundConstraintSpec(BaseConstraint, ABC): + """Base class for compound constraints that consist of multiple simple constraints.""" + + kind: ConstraintKind + + @abstractmethod + def expand(self) -> list[ConstraintSpec]: + """Expand the compound constraint into its constituent simple constraints.""" + pass + + +class PointConstraint(ConstraintSpec): + """Point constraint between two points.""" + + arity = 2 + marker_types = (gp_Pnt, gp_Pnt) + kind = "Point" + param_type = Optional[float] # Distance between points + + def validate_param(self, param: Optional[float]) -> None: + """Validate that the parameter is a numeric distance. + + :param param: Optional distance between points + :raises ValueError: If parameter is not numeric + """ + if param is not None and not isinstance(param, (int, float)): raise ValueError( - f"Unsupported argument types {get_type(param)}, required {param_type}." + f"Point constraint parameter must be numeric, got {type(param)}" ) - # check parameter conversion - try: - if param is not None and converter: - converter(param) - except Exception as e: - raise ValueError(f"Exception {e} occured in the parameter conversion") - - def _getAxis(self, arg: Shape) -> gp_Dir: - - if isinstance(arg, Face): - rv = arg.normalAt() - elif isinstance(arg, Edge) and arg.geomType() != "CIRCLE": - rv = arg.tangentAt() - elif isinstance(arg, Edge) and arg.geomType() == "CIRCLE": - rv = arg.normal() - else: - raise ValueError(f"Cannot construct Axis for {arg}") - - return rv.toDir() - - def _getPln(self, arg: Shape) -> gp_Pln: - - if isinstance(arg, Face): - rv = gp_Pln(self._getPnt(arg), arg.normalAt().toDir()) - elif isinstance(arg, (Edge, Wire)): - normal = arg.normal() - origin = arg.Center() - plane = Plane(origin, normal=normal) - rv = plane.toPln() - else: - raise ValueError(f"Cannot construct a plane for {arg}.") + def get_markers(self) -> tuple[ConstraintMarker, ...]: + """Get the geometric markers for this constraint.""" + # apply sublocation + args = tuple( + arg.located(loc * arg.location()) + for arg, loc in zip(self.args, self.sublocs) + ) - return rv + return (getPnt(args[0]), getPnt(args[1])) + + @staticmethod + def cost(params: CostParams) -> float: + """Cost function for Point constraint. + + Minimizes the distance between two points. + If val is provided, enforces that distance to be val. + """ + m1, m2 = params.markers + if not isinstance(m1, gp_Pnt) or not isinstance(m2, gp_Pnt): + raise TypeError("Point constraint requires two points as markers") - def _getPnt(self, arg: Shape) -> gp_Pnt: + T1_0, T2_0 = params.initial_translations + R1_0, R2_0 = params.initial_rotations + T1, T2 = params.translations + R1, R2 = params.rotations + val = 0 if params.param is None else params.param + scale = params.scale - # check for infinite face - if isinstance(arg, Face) and any( - Precision.IsInfinite_s(x) for x in BRepTools.UVBounds_s(arg.wrapped) - ): - # fall back to gp_Pln center - pln = arg.toPln() - center = Vector(pln.Location()) - else: - center = arg.Center() + m1_dm = ca.DM((m1.X(), m1.Y(), m1.Z())) + m2_dm = ca.DM((m2.X(), m2.Y(), m2.Z())) - return center.toPnt() + point_error = ( + Transform(m1_dm, T1_0 + T1, R1_0 + R1) + - Transform(m2_dm, T2_0 + T2, R2_0 + R2) + ) / scale - def _getLin(self, arg: Shape) -> gp_Lin: + if val == 0: + return ca.sumsqr(point_error) - if isinstance(arg, (Edge, Wire)): - center = arg.Center() - tangent = arg.tangentAt() - else: - raise ValueError(f"Cannot construct a plane for {arg}.") + return (ca.sumsqr(point_error) - (val / scale) ** 2) ** 2 - return gp_Lin(center.toPnt(), tangent.toDir()) - def toPODs(self) -> Tuple[Constraint, ...]: +class AxisConstraint(ConstraintSpec): + """Axis constraint between two axes.""" + + arity = 2 + marker_types = (gp_Dir, gp_Dir) + kind = "Axis" + param_type = Optional[float] # Angle between axes in degrees + + def validate_param(self, param: Optional[float]) -> None: + """Validate that the parameter is a numeric angle. + + :param param: Optional angle between axes in degrees + :raises ValueError: If parameter is not numeric """ - Convert the constraint to a representation used by the solver. + if param is not None and not isinstance(param, (int, float)): + raise ValueError( + f"Axis constraint parameter must be numeric, got {type(param)}" + ) - NB: Compound constraints are decomposed into simple ones. + def convert_param(self, param: Optional[float]) -> Optional[float]: + """Convert angle from degrees to radians. + + :param param: Angle in degrees + :return: Angle in radians """ + return radians(param) if param is not None else None + def get_markers(self) -> tuple[ConstraintMarker, ...]: + """Get the geometric markers for this constraint.""" # apply sublocation args = tuple( arg.located(loc * arg.location()) for arg, loc in zip(self.args, self.sublocs) ) - markers: List[Tuple[ConstraintMarker, ...]] - - # convert to marker objects - if self.kind == "Axis": - markers = [(self._getAxis(args[0]), self._getAxis(args[1]),)] + return (getDir(args[0]), getDir(args[1])) - elif self.kind == "Point": - markers = [(self._getPnt(args[0]), self._getPnt(args[1]))] + @staticmethod + def cost(params: CostParams) -> float: + """Cost function for Axis constraint. + + Minimizes the angle between two axes. + If val is provided, enforces that angle to be val. + """ + m1, m2 = params.markers + if not isinstance(m1, gp_Dir) or not isinstance(m2, gp_Dir): + raise TypeError("Axis constraint requires two directions as markers") - elif self.kind == "Plane": - markers = [ - (self._getAxis(args[0]), self._getAxis(args[1]),), - (self._getPnt(args[0]), self._getPnt(args[1])), - ] + R1_0, R2_0 = params.initial_rotations + R1, R2 = params.rotations + val = pi if params.param is None else params.param - elif self.kind == "PointInPlane": - markers = [(self._getPnt(args[0]), self._getPln(args[1]))] + m1_dm = ca.DM((m1.X(), m1.Y(), m1.Z())) + m2_dm = ca.DM((m2.X(), m2.Y(), m2.Z())) - elif self.kind == "PointOnLine": - markers = [(self._getPnt(args[0]), self._getLin(args[1]))] + d1, d2 = (Rotate(m1_dm, R1_0 + R1), Rotate(m2_dm, R2_0 + R2)) - elif self.kind == "Fixed": - markers = [(None,)] + if val == 0: + axis_error = d1 - d2 + return ca.sumsqr(axis_error) - elif self.kind == "FixedPoint": - markers = [(self._getPnt(args[0]),)] + elif val == pi: + axis_error = d1 + d2 + return ca.sumsqr(axis_error) - elif self.kind == "FixedAxis": - markers = [(self._getAxis(args[0]),)] + axis_error = ca.dot(d1, d2) - ca.cos(val) + return axis_error ** 2 - elif self.kind == "FixedRotation": - markers = [(None,), (None,), (None,)] - elif self.kind == "FixedRotationAxis": - markers = [(None,)] +class PointInPlaneConstraint(ConstraintSpec): + """Point in plane constraint.""" - else: - raise ValueError(f"Unknown constraint kind {self.kind}") + arity = 2 + marker_types = (gp_Pnt, gp_Pln) + kind = "PointInPlane" + param_type = Optional[float] # Distance from point to plane - # specify kinds of the simple constraint - if self.kind in CompoundConstraints: - kinds, converter = CompoundConstraints[self.kind] - params = converter(self.param,) - else: - kinds = (self.kind,) - params = (self.param,) + def validate_param(self, param: Optional[float]) -> None: + """Validate that the parameter is a numeric distance. + + :param param: Optional distance from point to plane + :raises ValueError: If parameter is not numeric + """ + if param is not None and not isinstance(param, (int, float)): + raise ValueError( + f"PointInPlane constraint parameter must be numeric, got {type(param)}" + ) - # builds the tuple and return - return tuple(zip(markers, kinds, params)) + def get_markers(self) -> tuple[ConstraintMarker, ...]: + """Get the geometric markers for this constraint.""" + # apply sublocation + args = tuple( + arg.located(loc * arg.location()) + for arg, loc in zip(self.args, self.sublocs) + ) + return (getPnt(args[0]), getPln(args[1])) -# Cost functions of simple constraints -def Quaternion(R): + @staticmethod + def cost(params: CostParams) -> float: + """Cost function for PointInPlane constraint. + + Minimizes the distance between a point and a plane. + If val is provided, enforces that distance to be val. + """ + m1, m2 = params.markers + if not isinstance(m1, gp_Pnt) or not isinstance(m2, gp_Pln): + raise TypeError( + "PointInPlane constraint requires a point and a plane as markers" + ) - m = ca.sumsqr(R) + T1_0, T2_0 = params.initial_translations + R1_0, R2_0 = params.initial_rotations + T1, T2 = params.translations + R1, R2 = params.rotations + val = 0 if params.param is None else params.param + scale = params.scale - u = 2 * R / (1 + m) - s = (1 - m) / (1 + m) + m1_dm = ca.DM((m1.X(), m1.Y(), m1.Z())) - return s, u + m2_dir = m2.Axis().Direction() + m2_pnt = m2.Axis().Location().Translated(val * gp_Vec(m2_dir)) + m2_dir_dm = ca.DM((m2_dir.X(), m2_dir.Y(), m2_dir.Z())) + m2_pnt_dm = ca.DM((m2_pnt.X(), m2_pnt.Y(), m2_pnt.Z())) -def Rotate(v, R): + plane_error = ( + ca.dot( + Rotate(m2_dir_dm, R2_0 + R2), + Transform(m2_pnt_dm, T2_0 + T2, R2_0 + R2) + - Transform(m1_dm, T1_0 + T1, R1_0 + R1), + ) + / scale + ) - s, u = Quaternion(R) + return plane_error ** 2 - return 2 * ca.dot(u, v) * u + (s ** 2 - ca.dot(u, u)) * v + 2 * s * ca.cross(u, v) +class PointOnLineConstraint(ConstraintSpec): + """Point on line constraint.""" -def Transform(v, T, R): + arity = 2 + marker_types = (gp_Pnt, gp_Lin) + kind = "PointOnLine" + param_type = Optional[float] # Distance from point to line - return Rotate(v, R) + T + def validate_param(self, param: Optional[float]) -> None: + """Validate that the parameter is a numeric distance. + + :param param: Optional distance from point to line + :raises ValueError: If parameter is not numeric + """ + if param is not None and not isinstance(param, (int, float)): + raise ValueError( + f"PointOnLine constraint parameter must be numeric, got {type(param)}" + ) + def get_markers(self) -> tuple[ConstraintMarker, ...]: + """Get the geometric markers for this constraint.""" + # apply sublocation + args = tuple( + arg.located(loc * arg.location()) + for arg, loc in zip(self.args, self.sublocs) + ) -def point_cost( - problem, - m1: gp_Pnt, - m2: gp_Pnt, - T1_0, - R1_0, - T2_0, - R2_0, - T1, - R1, - T2, - R2, - val: Optional[float] = None, - scale: float = 1, -) -> float: + return (getPnt(args[0]), getLin(args[1])) - val = 0 if val is None else val + @staticmethod + def cost(params: CostParams) -> float: + """Cost function for PointOnLine constraint. + + Minimizes the distance between a point and a line. + If val is provided, enforces that distance to be val. + """ + m1, m2 = params.markers + if not isinstance(m1, gp_Pnt) or not isinstance(m2, gp_Lin): + raise TypeError( + "PointOnLine constraint requires a point and a line as markers" + ) - m1_dm = ca.DM((m1.X(), m1.Y(), m1.Z())) - m2_dm = ca.DM((m2.X(), m2.Y(), m2.Z())) + T1_0, T2_0 = params.initial_translations + R1_0, R2_0 = params.initial_rotations + T1, T2 = params.translations + R1, R2 = params.rotations + val = 0 if params.param is None else params.param + scale = params.scale - dummy = ( - Transform(m1_dm, T1_0 + T1, R1_0 + R1) - Transform(m2_dm, T2_0 + T2, R2_0 + R2) - ) / scale + m1_dm = ca.DM((m1.X(), m1.Y(), m1.Z())) - if val == 0: - return ca.sumsqr(dummy) + m2_dir = m2.Direction() + m2_pnt = m2.Location() - return (ca.sumsqr(dummy) - (val / scale) ** 2) ** 2 + m2_dir_dm = ca.DM((m2_dir.X(), m2_dir.Y(), m2_dir.Z())) + m2_pnt_dm = ca.DM((m2_pnt.X(), m2_pnt.Y(), m2_pnt.Z())) + d = Transform(m1_dm, T1_0 + T1, R1_0 + R1) - Transform( + m2_pnt_dm, T2_0 + T2, R2_0 + R2 + ) + n = Rotate(m2_dir_dm, R2_0 + R2) -def axis_cost( - problem, - m1: gp_Dir, - m2: gp_Dir, - T1_0, - R1_0, - T2_0, - R2_0, - T1, - R1, - T2, - R2, - val: Optional[float] = None, - scale: float = 1, -) -> float: + line_error = (d - n * ca.dot(d, n)) / scale - val = pi if val is None else val + if val == 0: + return ca.sumsqr(line_error) - m1_dm = ca.DM((m1.X(), m1.Y(), m1.Z())) - m2_dm = ca.DM((m2.X(), m2.Y(), m2.Z())) + return (ca.sumsqr(line_error) - val) ** 2 - d1, d2 = (Rotate(m1_dm, R1_0 + R1), Rotate(m2_dm, R2_0 + R2)) - if val == 0: - dummy = d1 - d2 +class FixedConstraint(ConstraintSpec): + """Fixed constraint.""" - return ca.sumsqr(dummy) + arity = 1 + marker_types = (NoneType,) + kind = "Fixed" + param_type = None # No parameters allowed - elif val == pi: - dummy = d1 + d2 + def validate_param(self, param: None) -> None: + """Validate that no parameter is provided. + + :param param: Must be None + :raises ValueError: If parameter is not None + """ + if param is not None: + raise ValueError("Fixed constraint cannot have parameters") - return ca.sumsqr(dummy) + def get_markers(self) -> tuple[ConstraintMarker, ...]: + """Get the geometric markers for this constraint.""" + return (None,) - dummy = ca.dot(d1, d2) - ca.cos(val) + @staticmethod + def cost(params: CostParams) -> float: + """Cost function for Fixed constraint. + + This is a dummy cost function as fixed constraints are handled at the variable level. + Returns 0.0 to satisfy the type system. + """ + m1 = params.markers[0] + if m1 is not None: + raise TypeError("Fixed constraint should have no markers") + return 0.0 - return dummy ** 2 +class FixedPointConstraint(ConstraintSpec): + """Fixed point constraint.""" -def point_in_plane_cost( - problem, - m1: gp_Pnt, - m2: gp_Pln, - T1_0, - R1_0, - T2_0, - R2_0, - T1, - R1, - T2, - R2, - val: Optional[float] = None, - scale: float = 1, -) -> float: + arity = 1 + marker_types = (gp_Pnt,) + kind = "FixedPoint" + param_type = tuple[float, float, float] # 3D coordinates (x, y, z) - val = 0 if val is None else val + def validate_param(self, param: tuple[float, float, float]) -> None: + """Validate that the parameter is a 3D point. + + :param param: 3D coordinates (x, y, z) + :raises ValueError: If parameter is not a 3D point + """ + if ( + not isinstance(param, (tuple, list)) + or len(param) != 3 + or not all(isinstance(x, (int, float)) for x in param) + ): + raise ValueError( + "FixedPoint constraint parameter must be tuple/list of 3 numbers" + ) - m1_dm = ca.DM((m1.X(), m1.Y(), m1.Z())) + def get_markers(self) -> tuple[ConstraintMarker, ...]: + """Get the geometric markers for this constraint.""" + # apply sublocation + args = tuple( + arg.located(loc * arg.location()) + for arg, loc in zip(self.args, self.sublocs) + ) - m2_dir = m2.Axis().Direction() - m2_pnt = m2.Axis().Location().Translated(val * gp_Vec(m2_dir)) + return (getPnt(args[0]),) - m2_dir_dm = ca.DM((m2_dir.X(), m2_dir.Y(), m2_dir.Z())) - m2_pnt_dm = ca.DM((m2_pnt.X(), m2_pnt.Y(), m2_pnt.Z())) + @staticmethod + def cost(params: CostParams) -> float: + """Cost function for FixedPoint constraint. + + Fixes a point at a specific location in space. + """ + m1 = params.markers[0] + if not isinstance(m1, gp_Pnt): + raise TypeError("FixedPoint constraint requires a point as marker") - dummy = ( - ca.dot( - Rotate(m2_dir_dm, R2_0 + R2), - Transform(m2_pnt_dm, T2_0 + T2, R2_0 + R2) - - Transform(m1_dm, T1_0 + T1, R1_0 + R1), - ) - / scale - ) + T1_0 = params.initial_translations[0] + R1_0 = params.initial_rotations[0] + T1 = params.translations[0] + R1 = params.rotations[0] + val = params.param + scale = params.scale - return dummy ** 2 + m1_dm = ca.DM((m1.X(), m1.Y(), m1.Z())) + point_error = (Transform(m1_dm, T1_0 + T1, R1_0 + R1) - ca.DM(val)) / scale -def point_on_line_cost( - problem, - m1: gp_Pnt, - m2: gp_Lin, - T1_0, - R1_0, - T2_0, - R2_0, - T1, - R1, - T2, - R2, - val: Optional[float] = None, - scale: float = 1, -) -> float: + return ca.sumsqr(point_error) - val = 0 if val is None else val - m1_dm = ca.DM((m1.X(), m1.Y(), m1.Z())) +class FixedAxisConstraint(ConstraintSpec): + """Fixed axis constraint.""" - m2_dir = m2.Direction() - m2_pnt = m2.Location() + arity = 1 + marker_types = (gp_Dir,) + kind = "FixedAxis" + param_type = tuple[float, float, float] # 3D direction vector (x, y, z) - m2_dir_dm = ca.DM((m2_dir.X(), m2_dir.Y(), m2_dir.Z())) - m2_pnt_dm = ca.DM((m2_pnt.X(), m2_pnt.Y(), m2_pnt.Z())) + def validate_param(self, param: tuple[float, float, float]) -> None: + """Validate that the parameter is a 3D direction vector. + + :param param: 3D direction vector (x, y, z) + :raises ValueError: If parameter is not a 3D vector + """ + if ( + not isinstance(param, (tuple, list)) + or len(param) != 3 + or not all(isinstance(x, (int, float)) for x in param) + ): + raise ValueError( + "FixedAxis constraint parameter must be tuple/list of 3 numbers" + ) - d = Transform(m1_dm, T1_0 + T1, R1_0 + R1) - Transform( - m2_pnt_dm, T2_0 + T2, R2_0 + R2 - ) - n = Rotate(m2_dir_dm, R2_0 + R2) + def get_markers(self) -> tuple[ConstraintMarker, ...]: + """Get the geometric markers for this constraint.""" + # apply sublocation + args = tuple( + arg.located(loc * arg.location()) + for arg, loc in zip(self.args, self.sublocs) + ) - dummy = (d - n * ca.dot(d, n)) / scale + return (getDir(args[0]),) - if val == 0: - return ca.sumsqr(dummy) + @staticmethod + def cost(params: CostParams) -> float: + """Cost function for FixedAxis constraint. + + Fixes an axis in a specific direction. + """ + m1 = params.markers[0] + if not isinstance(m1, gp_Dir): + raise TypeError("FixedAxis constraint requires a direction as marker") - return (ca.sumsqr(dummy) - val) ** 2 + R1_0 = params.initial_rotations[0] + R1 = params.rotations[0] + val = params.param + m1_dm = ca.DM((m1.X(), m1.Y(), m1.Z())) + m_val = ca.DM(val) / ca.norm_2(ca.DM(val)) -# dummy cost, fixed constraint is handled on variable level -def fixed_cost( - problem, - m1: Type[None], - T1_0, - R1_0, - T1, - R1, - val: Optional[Type[None]] = None, - scale: float = 1, -): + axis_error = Rotate(m1_dm, R1_0 + R1) - m_val - return None + return ca.sumsqr(axis_error) -def fixed_point_cost( - problem, - m1: gp_Pnt, - T1_0, - R1_0, - T1, - R1, - val: Tuple[float, float, float], - scale: float = 1, -): +class FixedRotationConstraint(ConstraintSpec): + """Fixed rotation constraint.""" - m1_dm = ca.DM((m1.X(), m1.Y(), m1.Z())) + arity = 1 + marker_types = (NoneType,) + kind = "FixedRotation" + param_type = tuple[float, float, float] # 3D rotation angles in degrees (x, y, z) - dummy = (Transform(m1_dm, T1_0 + T1, R1_0 + R1) - ca.DM(val)) / scale + def validate_param(self, param: tuple[float, float, float]) -> None: + """Validate that the parameter is a 3D rotation vector. + + :param param: 3D rotation angles in degrees (x, y, z) + :raises ValueError: If parameter is not a 3D vector + """ + if ( + not isinstance(param, (tuple, list)) + or len(param) != 3 + or not all(isinstance(x, (int, float)) for x in param) + ): + raise ValueError( + "FixedRotation constraint parameter must be tuple/list of 3 numbers" + ) - return ca.sumsqr(dummy) + def convert_param( + self, param: Optional[tuple[float, float, float]] + ) -> Optional[tuple[float, float, float]]: + """Convert rotation angles from degrees to radians. + + :param param: Rotation angles in degrees + :return: Rotation angles in radians + """ + if param is None: + return None + x, y, z = param + return (radians(x), radians(y), radians(z)) + def get_markers(self) -> tuple[ConstraintMarker, ...]: + """Get the geometric markers for this constraint.""" + return (None,) -def fixed_axis_cost( - problem, - m1: gp_Dir, - T1_0, - R1_0, - T1, - R1, - val: Tuple[float, float, float], - scale: float = 1, -): + @staticmethod + def cost(params: CostParams) -> float: + """Cost function for FixedRotation constraint. + + Fixes the rotation of an object using Euler angles. + """ + m1 = params.markers[0] + if m1 is not None: + raise TypeError("FixedRotation constraint should have no markers") - m1_dm = ca.DM((m1.X(), m1.Y(), m1.Z())) - m_val = ca.DM(val) / ca.norm_2(ca.DM(val)) + R1_0 = params.initial_rotations[0] + R1 = params.rotations[0] + val = (0.0, 0.0, 0.0) if params.param is None else tuple(params.param) - dummy = Rotate(m1_dm, R1_0 + R1) - m_val + q = gp_Quaternion() + q.SetEulerAngles(gp_Extrinsic_XYZ, *val) + q_dm = ca.DM((q.W(), q.X(), q.Y(), q.Z())) - return ca.sumsqr(dummy) + rotation_error = 1 - ca.dot(ca.vertcat(*Quaternion(R1_0 + R1)), q_dm) ** 2 + return rotation_error -def fixed_rotation_cost( - problem, - m1: Type[None], - T1_0, - R1_0, - T1, - R1, - val: Tuple[float, float, float], - scale: float = 1, -): - - q = gp_Quaternion() - q.SetEulerAngles(gp_Extrinsic_XYZ, *val) - q_dm = ca.DM((q.W(), q.X(), q.Y(), q.Z())) - - dummy = 1 - ca.dot(ca.vertcat(*Quaternion(R1_0 + R1)), q_dm) ** 2 - return dummy +class PlaneConstraint(CompoundConstraintSpec): + """Plane constraint (compound of Axis and Point constraints).""" + kind: ConstraintKind = "Plane" -# dictionary of individual constraint cost functions -costs: Dict[str, Callable[..., float]] = dict( - Point=point_cost, - Axis=axis_cost, - PointInPlane=point_in_plane_cost, - PointOnLine=point_on_line_cost, - Fixed=fixed_cost, - FixedPoint=fixed_point_cost, - FixedAxis=fixed_axis_cost, - FixedRotation=fixed_rotation_cost, -) + def expand(self) -> list[ConstraintSpec]: + """Expand into Axis and Point constraints.""" + # Create Axis constraint + axis_constraint = AxisConstraint( + objects=self.objects, args=self.args, sublocs=self.sublocs, param=self.param + ) -scaling: Dict[str, bool] = dict( - Point=True, - Axis=False, - PointInPlane=True, - PointOnLine=True, - Fixed=False, - FixedPoint=True, - FixedAxis=False, - FixedRotation=False, -) + # Create Point constraint + point_constraint = PointConstraint( + objects=self.objects, + args=self.args, + sublocs=self.sublocs, + param=0, # Distance between points should be 0 + ) -# Actual solver class + return [axis_constraint, point_constraint] +# Actual solver class class ConstraintSolver(object): - opti: ca.Opti - variables: List[Tuple[ca.MX, ca.MX]] - starting_points: List[Tuple[ca.MX, ca.MX]] - constraints: List[Tuple[Tuple[int, ...], Constraint]] - locked: List[int] + variables: list[tuple[ca.MX, ca.MX]] + initial_points: list[tuple[ca.MX, ca.MX]] + constraints: list[BaseConstraint] + locked: list[int] ne: int nc: int scale: float + object_indices: dict[str, int] def __init__( self, - entities: List[Location], - constraints: List[Tuple[Tuple[int, ...], Constraint]], - locked: List[int] = [], + entities: list[Location], + constraints: list[BaseConstraint], + object_indices: dict[str, int], + locked: list[int] = [], scale: float = 1, ): + """ + Initialize the constraint solver. + :param entities: list of locations for each entity + :param constraints: list of constraint specifications + :param object_indices: dictionary mapping object names to their indices + :param locked: list of indices of locked entities + :param scale: Scale factor for the optimization + """ self.scale = scale self.opti = opti = ca.Opti() + self.object_indices = object_indices self.variables = [ (scale * opti.variable(NDOF_V), opti.variable(NDOF_Q)) if i not in locked else (opti.parameter(NDOF_V), opti.parameter(NDOF_Q)) for i, _ in enumerate(entities) ] - self.start_points = [ + self.initial_points = [ (opti.parameter(NDOF_V), opti.parameter(NDOF_Q)) for _ in entities ] # initialize, add the unit quaternion constraints and handle locked for i, ((T, R), (T0, R0), loc) in enumerate( - zip(self.variables, self.start_points, entities) + zip(self.variables, self.initial_points, entities) ): - T0val, R0val = self._locToDOF6(loc) + T0val, R0val = loc_to_dof6(loc) opti.set_value(T0, T0val) opti.set_value(R0, R0val) @@ -625,22 +881,8 @@ def __init__( self.locked = locked self.nc = len(self.constraints) - @staticmethod - def _locToDOF6(loc: Location) -> DOF6: - - Tr = loc.wrapped.Transformation() - v = Tr.TranslationPart() - q = Tr.GetRotation() - - alpha_2 = (1 - q.W()) / (1 + q.W()) - a = (alpha_2 + 1) * q.X() / 2 - b = (alpha_2 + 1) * q.Y() / 2 - c = (alpha_2 + 1) * q.Z() / 2 - - return (v.X(), v.Y(), v.Z()), (a, b, c) - def _build_transform(self, T: ca.MX, R: ca.MX) -> gp_Trsf: - + """Build a transformation from translation and rotation vectors.""" opti = self.opti rv = gp_Trsf() @@ -657,49 +899,82 @@ def _build_transform(self, T: ca.MX, R: ca.MX) -> gp_Trsf: return rv - def solve(self, verbosity: int = 0) -> Tuple[List[Location], Dict[str, Any]]: + def solve(self, verbosity: int = 0) -> tuple[list[Location], dict[str, Any]]: + """ + Solve the constraints. + :param verbosity: Verbosity level for the solver + :return: tuple of (list of new locations, solver results) + """ suppress_banner = "yes" if verbosity == 0 else "no" opti = self.opti - - constraints = self.constraints variables = self.variables - start_points = self.start_points + initial_points = self.initial_points - # construct a penalty term + # Construct a penalty term to prevent large transformations penalty = 0.0 + for translation, rotation in variables: + penalty += ca.sumsqr(ca.vertcat(translation / self.scale, rotation)) - for T, R in variables: - penalty += ca.sumsqr(ca.vertcat(T / self.scale, R)) - - # construct the objective + # Initialize the objective function objective = 0.0 - for ks, (ms, kind, params) in constraints: - - # select the relevant variables and starting points - s_ks: List[ca.DM] = [] - v_ks: List[ca.MX] = [] - - for k in ks: - s_ks.extend(start_points[k]) - v_ks.extend(variables[k]) - - c = costs[kind]( - opti, - *ms, - *s_ks, - *v_ks, - params, - scale=self.scale if scaling[kind] else 1, + + # Expand all constraints (including compound ones) into simple constraints + expanded_constraints = [] + for constraint in self.constraints: + if isinstance(constraint, CompoundConstraintSpec): + expanded_constraints.extend(constraint.expand()) + elif isinstance(constraint, ConstraintSpec): + expanded_constraints.append(constraint) + else: + raise ValueError(f"Invalid constraint type: {type(constraint)}") + + # Process each constraint and add its cost to the objective + for constraint in expanded_constraints: + # Get indices of objects involved in the constraint + object_indices = [self.object_indices[obj] for obj in constraint.objects] + + # Get the markers and parameters from the constraint + markers = constraint.get_markers() + params = constraint.get_param() + + # Collect the relevant variables and initial points for each object + initial_translations: list[ca.DM] = [] + initial_rotations: list[ca.DM] = [] + current_translations: list[ca.MX] = [] + current_rotations: list[ca.MX] = [] + + for obj_idx in object_indices: + initial_translations.append(initial_points[obj_idx][0]) # Translation + initial_rotations.append(initial_points[obj_idx][1]) # Rotation + current_translations.append(variables[obj_idx][0]) # Translation + current_rotations.append(variables[obj_idx][1]) # Rotation + + # Compute constraint cost + constraint_cost = constraint.cost( + CostParams( + problem=opti, + markers=list(markers), + initial_translations=initial_translations, + initial_rotations=initial_rotations, + translations=current_translations, + rotations=current_rotations, + param=params, + scale=self.scale + if constraint.kind + in ["Point", "PointInPlane", "PointOnLine", "FixedPoint"] + else 1, + ) ) - if c is not None: - objective += c + if constraint_cost is not None: + objective += constraint_cost + # Add the penalty term to the objective and minimize opti.minimize(objective + 1e-16 * penalty) - # solve + # Configure and run the solver opti.solver( "ipopt", {"print_time": False}, @@ -722,9 +997,10 @@ def solve(self, verbosity: int = 0) -> Tuple[List[Location], Dict[str, Any]]: result = sol.stats() result["opti"] = opti # this might be removed in the future + # Convert the solution to locations locs = [ Location(self._build_transform(T + T0, R + R0)) - for (T, R), (T0, R0) in zip(variables, start_points) + for (T, R), (T0, R0) in zip(variables, initial_points) ] return locs, result diff --git a/tests/test_assembly.py b/tests/test_assembly.py index 0a46d414d..d0f53a1c9 100644 --- a/tests/test_assembly.py +++ b/tests/test_assembly.py @@ -37,6 +37,8 @@ from OCP.Quantity import Quantity_ColorRGBA, Quantity_TOC_RGB from OCP.TopAbs import TopAbs_ShapeEnum +from cadquery.occ_impl.solver import getPln + @pytest.fixture(scope="function") def tmpdir(tmp_path_factory): @@ -1529,13 +1531,15 @@ def test_PointInPlane_param(box_and_vertex, param0, param1): def test_constraint_getPln(): """ - Test that _getPln does the right thing with different arguments + Test that getPln does the right thing with different arguments """ ids = (0, 1) sublocs = (cq.Location(), cq.Location()) def make_constraint(shape0): - return cq.Constraint(ids, (shape0, shape0), sublocs, "PointInPlane", 0) + return cq.occ_impl.solver.PointInPlaneConstraint( + ids, (shape0, shape0), sublocs, 0 + ) def fail_this(shape0): with pytest.raises(ValueError): @@ -1543,7 +1547,7 @@ def fail_this(shape0): def resulting_pln(shape0): c0 = make_constraint(shape0) - return c0._getPln(c0.args[0]) + return getPln(c0.args[0]) def resulting_plane(shape0): p0 = resulting_pln(shape0) @@ -1663,13 +1667,12 @@ def test_infinite_face_constraint_PointInPlane(origin, normal): f0 = cq.Face.makePlane(length=None, width=None, basePnt=origin, dir=normal) - c0 = cq.assembly.Constraint( + c0 = cq.occ_impl.solver.PointInPlaneConstraint( ("point", "plane"), (cq.Vertex.makeVertex(10, 10, 10), f0), sublocs=(cq.Location(), cq.Location()), - kind="PointInPlane", ) - p0 = c0._getPln(c0.args[1]) # a gp_Pln + p0 = getPln(c0.args[1]) # a gp_Pln derived_origin = cq.Vector(p0.Location()) assert derived_origin == cq.Vector(origin) @@ -1726,7 +1729,7 @@ def test_constraint_validation(simple_assy2): simple_assy2.constrain("b1", "Fixed?") with pytest.raises(ValueError): - cq.assembly.Constraint((), (), (), "Fixed?") + cq.occ_impl.solver.FixedConstraint((), (), ()) def test_single_unary_constraint(simple_assy2):