diff --git a/Postgrest/Linq/WhereExpressionVisitor.cs b/Postgrest/Linq/WhereExpressionVisitor.cs index 8b6b63c..57c5da9 100644 --- a/Postgrest/Linq/WhereExpressionVisitor.cs +++ b/Postgrest/Linq/WhereExpressionVisitor.cs @@ -26,74 +26,213 @@ internal class WhereExpressionVisitor : ExpressionVisitor /// /// An entry point that will be used to populate . + /// This method handles comparisons, logical operations, and simple arithmetic expressions in a Where clause. /// - /// Invoked like: - /// `Table<Movies>().Where(x => x.Name == "Top Gun").Get();` + /// Examples: + /// Table<Movies>().Where(x => x.Name == "Top Gun").Get(); + /// Table<Movies>().Where(x => x.Rating > 5 && x.Year == 1986).Get(); + /// Table<Movies>().Where(x => x.Rating >= maxRating - 1).Get(); /// - /// - /// - /// + /// The binary expression to process, such as a comparison (e.g., x.Name == "Top Gun") or logical operation (e.g., x.Rating > 5 && x.Year == 1986). + /// The processed expression, typically the input . + /// Thrown if the left side of the expression does not correspond to a property with a or . + /// Thrown if the right side of the expression cannot be evaluated to a constant value. + /// Thrown if the is not set after processing the expression. protected override Expression VisitBinary(BinaryExpression node) { var op = GetMappedOperator(node); - // In the event this is a nested expression (n.Name == "Example" || n.Id = 3) - switch (node.NodeType) + // Handle logical operations (e.g., x.Rating > 5 && x.Year == 1986) + if (IsLogicalOperation(node.NodeType)) { - case ExpressionType.And: - case ExpressionType.Or: - case ExpressionType.AndAlso: - case ExpressionType.OrElse: - var leftVisitor = new WhereExpressionVisitor(); - leftVisitor.Visit(node.Left); + var conditions = FlattenLogicalConditions(node, op); + Filter = new QueryFilter(op, conditions); + return node; + } + + // Handle simple comparisons (e.g., x.Name == "Top Gun" or x.Rating >= maxRating - 1) + var column = ExtractColumnName(node.Left); + var rightValue = EvaluateRightExpression(node.Right); + + // Define the filter for a simple comparison + Filter = new QueryFilter(column, op, rightValue); + return node; + } - var rightVisitor = new WhereExpressionVisitor(); - rightVisitor.Visit(node.Right); + /// + /// Flattens a tree of logical conditions (e.g., AND, OR) into a single list of conditions at the same level. + /// + /// The binary expression node representing a logical operation. + /// The operator (e.g., AND, OR) for the logical operation. + /// A list of filters representing all conditions at the same level. + private List FlattenLogicalConditions(BinaryExpression node, Operator op) + { + var conditions = new List(); + + // Recursively flatten the left and right sides + FlattenLogicalConditionsRecursive(node, op, conditions); - Filter = new QueryFilter(op, - new List { leftVisitor.Filter!, rightVisitor.Filter! }); + return conditions; + } - return node; + /// + /// Recursively flattens a tree of logical conditions into a list of filters. + /// + /// The current binary expression node. + /// The operator (e.g., AND, OR) for the logical operation. + /// The list to accumulate the flattened conditions. + private void FlattenLogicalConditionsRecursive(BinaryExpression node, Operator op, List conditions) + { + // If the node is a logical operation with the same operator, recurse into its children + if (IsLogicalOperation(node.NodeType) && GetMappedOperator(node) == op) + { + if (node.Left is BinaryExpression leftBinary) + { + FlattenLogicalConditionsRecursive(leftBinary, op, conditions); + } + else + { + conditions.Add(ProcessSubExpression(node.Left)); + } + + if (node.Right is BinaryExpression rightBinary) + { + FlattenLogicalConditionsRecursive(rightBinary, op, conditions); + } + else + { + conditions.Add(ProcessSubExpression(node.Right)); + } } + else + { + // If the node is not a logical operation (or has a different operator), process it as a single condition + conditions.Add(ProcessSubExpression(node)); + } + } - // Otherwise, the base case. + /// + /// Determines if the node type represents a logical operation (AND, OR). + /// + /// The type of the expression node. + /// True if the node type is a logical operation; otherwise, false. + private static bool IsLogicalOperation(ExpressionType nodeType) + { + return nodeType == ExpressionType.And || + nodeType == ExpressionType.Or || + nodeType == ExpressionType.AndAlso || + nodeType == ExpressionType.OrElse; + } - var left = Visit(node.Left); - var right = Visit(node.Right); + /// + /// Processes a subexpression and returns the resulting filter. + /// + /// The subexpression to process. + /// The filter generated by the subexpression. + /// Thrown if the subexpression does not produce a valid filter. + private IPostgrestQueryFilter ProcessSubExpression(Expression expression) + { + var visitor = new WhereExpressionVisitor(); + visitor.Visit(expression); + return visitor.Filter ?? throw new InvalidOperationException($"Subexpression '{expression}' did not produce a valid filter."); + } - string? column = null; + /// + /// Extracts the column name from the left side of a binary expression. + /// + /// The left side expression, expected to be a property access. + /// The column name corresponding to the property. + /// Thrown if the left side does not correspond to a property with a or . + private string ExtractColumnName(Expression left) + { if (left is MemberExpression leftMember) { - column = GetColumnFromMemberExpression(leftMember); - } //To handle properly if it's a Convert ExpressionType generally with nullable properties - else if (left is UnaryExpression leftUnary && leftUnary.NodeType == ExpressionType.Convert && - leftUnary.Operand is MemberExpression leftOperandMember) + return GetColumnFromMemberExpression(leftMember); + } + if (left is UnaryExpression leftUnary && leftUnary.NodeType == ExpressionType.Convert && + leftUnary.Operand is MemberExpression leftOperandMember) { - column = GetColumnFromMemberExpression(leftOperandMember); + return GetColumnFromMemberExpression(leftOperandMember); } - if (column == null) - throw new ArgumentException( - $"Left side of expression: '{node}' is expected to be property with a ColumnAttribute or PrimaryKeyAttribute"); + throw new ArgumentException( + $"Left side of expression: '{left}' is expected to be a property with a ColumnAttribute or PrimaryKeyAttribute"); + } + + /// + /// Evaluates the right side of a binary expression to produce a constant value, applying special handling for certain types. + /// + /// The right side expression to evaluate. + /// The evaluated value of the expression, formatted appropriately for use in a PostgREST query. + /// Thrown if the right side cannot be evaluated to a constant value. + private object EvaluateRightExpression(Expression right) + { + right = Visit(right); // Process the right expression + + object value = right switch + { + ConstantExpression constant => constant.Value, + MemberExpression member => EvaluateExpression(member), + NewExpression newExpr => EvaluateExpression(newExpr), + UnaryExpression unary => EvaluateExpression(unary), + BinaryExpression binary => EvaluateBinaryExpression(binary) ?? throw new NotSupportedException( + $"Binary expression '{binary}' on the right side is not supported. Only constant values or simple expressions are allowed."), + _ => throw new NotSupportedException( + $"Right side of expression: '{right}' is not supported. Expected a constant, member, new, unary, or simple binary expression.") + }; - if (right is ConstantExpression rightConstant) + return value switch { - HandleConstantExpression(column, op, rightConstant); + DateTime dateTime => dateTime, + DateTimeOffset dateTimeOffset => dateTimeOffset, + Guid guid => guid.ToString(), + Enum enumValue => enumValue, + _ => value + }; + } + + /// + /// Evaluates an expression to produce a constant value. + /// + /// The type of the expression to evaluate (e.g., MemberExpression, NewExpression, UnaryExpression). + /// The expression to evaluate. + /// The evaluated value of the expression. + /// Thrown if the expression cannot be evaluated. + private object EvaluateExpression(TExpression expression) where TExpression : Expression + { + try + { + var lambda = Expression.Lambda(expression); + var compiled = lambda.Compile(); + return compiled.DynamicInvoke(); } - else if (right is MemberExpression memberExpression) + catch (Exception ex) { - HandleMemberExpression(column, op, memberExpression); + throw new InvalidOperationException($"Failed to evaluate {typeof(TExpression).Name.ToLower()}: '{expression}'.", ex); } - else if (right is NewExpression newExpression) + } + + /// + /// Evaluates a binary expression to compute its constant value, if possible. + /// + /// The binary expression to evaluate (e.g., 'x - 5'). + /// The computed value of the expression as an object, or null if the expression cannot be evaluated. + /// + /// Returns null if the expression cannot be evaluated due to unresolved variables or invalid operations. + /// The calling code should handle the null return value appropriately. + /// + private object? EvaluateBinaryExpression(BinaryExpression binaryExpression) + { + try { - HandleNewExpression(column, op, newExpression); + var lambda = Expression.Lambda(binaryExpression); + var compiled = lambda.Compile(); + return compiled.DynamicInvoke(); } - else if (right is UnaryExpression unaryExpression) + catch (Exception) { - HandleUnaryExpression(column, op, unaryExpression); + return null; } - - return node; } /// @@ -135,100 +274,6 @@ protected override Expression VisitMethodCall(MethodCallExpression node) return node; } - /// - /// A constant expression parser (i.e. x => x.Id == 5 <- where '5' is the constant) - /// - /// - /// - /// - private void HandleConstantExpression(string column, Operator op, ConstantExpression constantExpression) - { - if (constantExpression.Type.IsEnum) - { - var enumValue = constantExpression.Value; - Filter = new QueryFilter(column, op, enumValue); - } - else - { - Filter = new QueryFilter(column, op, constantExpression.Value); - } - } - - /// - /// A member expression parser (i.e. => x.Id == Example.Id <- where both `x.Id` and `Example.Id` are parsed as 'members') - /// - /// - /// - /// - private void HandleMemberExpression(string column, Operator op, MemberExpression memberExpression) - { - Filter = new QueryFilter(column, op, GetMemberExpressionValue(memberExpression)); - } - - /// - /// A unary expression parser (i.e. => x.Id == 1 <- where both `1` is considered unary) - /// - /// - /// - /// - private void HandleUnaryExpression(string column, Operator op, UnaryExpression unaryExpression) - { - if (unaryExpression.Operand is ConstantExpression constantExpression) - { - HandleConstantExpression(column, op, constantExpression); - } - else if (unaryExpression.Operand is MemberExpression memberExpression) - { - HandleMemberExpression(column, op, memberExpression); - } - else if (unaryExpression.Operand is NewExpression newExpression) - { - HandleNewExpression(column, op, newExpression); - } - } - - /// - /// An instantiated class parser (i.e. x => x.CreatedAt <= new DateTime(2022, 08, 20) <- where `new DateTime(...)` is an instantiated expression. - /// - /// - /// - /// - private void HandleNewExpression(string column, Operator op, NewExpression newExpression) - { - var argumentValues = new List(); - foreach (var argument in newExpression.Arguments) - { - var lambda = Expression.Lambda(argument); - var func = lambda.Compile(); - argumentValues.Add(func.DynamicInvoke()); - } - - var constructor = newExpression.Constructor; - var instance = constructor.Invoke(argumentValues.ToArray()); - - switch (instance) - { - case DateTime dateTime: - Filter = new QueryFilter(column, op, dateTime); - break; - case DateTimeOffset dateTimeOffset: - Filter = new QueryFilter(column, op, dateTimeOffset); - break; - case Guid guid: - Filter = new QueryFilter(column, op, guid.ToString()); - break; - default: - { - if (instance.GetType().IsEnum) - { - Filter = new QueryFilter(column, op, instance); - } - - break; - } - } - } - /// /// Gets a column name (postgrest) from a Member Expression (used on BaseModel) /// @@ -238,19 +283,21 @@ private string GetColumnFromMemberExpression(MemberExpression node) { var type = node.Member.ReflectedType; var prop = type?.GetProperty(node.Member.Name); - var attrs = prop?.GetCustomAttributes(true); + if (prop == null) + { + return node.Member.Name; + } - if (attrs == null) return node.Member.Name; + var columnAttr = prop.GetCustomAttribute(true); + if (columnAttr != null) + { + return columnAttr.ColumnName; + } - foreach (var attr in attrs) + var primaryKeyAttr = prop.GetCustomAttribute(true); + if (primaryKeyAttr != null) { - switch (attr) - { - case ColumnAttribute columnAttr: - return columnAttr.ColumnName; - case PrimaryKeyAttribute primaryKeyAttr: - return primaryKeyAttr.ColumnName; - } + return primaryKeyAttr.ColumnName; } return node.Member.Name;