From 6cce575dcb59424e5d87ed81264b298b9af8f0d1 Mon Sep 17 00:00:00 2001 From: Alex Soffronow-Pagonidis Date: Wed, 11 Mar 2026 14:07:12 +0100 Subject: [PATCH] add aggregate method translator --- .../ClickHouseServiceCollectionExtensions.cs | 1 + ...seAggregateMethodCallTranslatorProvider.cs | 18 ++ ...HouseQueryableAggregateMethodTranslator.cs | 183 ++++++++++++++++++ .../GroupByAggregateTests.cs | 168 ++++++++++++++++ 4 files changed, 370 insertions(+) create mode 100644 src/EFCore.ClickHouse/Query/ExpressionTranslators/Internal/ClickHouseAggregateMethodCallTranslatorProvider.cs create mode 100644 src/EFCore.ClickHouse/Query/ExpressionTranslators/Internal/ClickHouseQueryableAggregateMethodTranslator.cs create mode 100644 test/EFCore.ClickHouse.Tests/GroupByAggregateTests.cs diff --git a/src/EFCore.ClickHouse/Extensions/ClickHouseServiceCollectionExtensions.cs b/src/EFCore.ClickHouse/Extensions/ClickHouseServiceCollectionExtensions.cs index 29dd231..4d9528c 100644 --- a/src/EFCore.ClickHouse/Extensions/ClickHouseServiceCollectionExtensions.cs +++ b/src/EFCore.ClickHouse/Extensions/ClickHouseServiceCollectionExtensions.cs @@ -37,6 +37,7 @@ public static IServiceCollection AddEntityFrameworkClickHouse(this IServiceColle .TryAdd() .TryAdd() .TryAdd() + .TryAdd() .TryAdd() .TryAdd() .TryAdd() diff --git a/src/EFCore.ClickHouse/Query/ExpressionTranslators/Internal/ClickHouseAggregateMethodCallTranslatorProvider.cs b/src/EFCore.ClickHouse/Query/ExpressionTranslators/Internal/ClickHouseAggregateMethodCallTranslatorProvider.cs new file mode 100644 index 0000000..39e99f6 --- /dev/null +++ b/src/EFCore.ClickHouse/Query/ExpressionTranslators/Internal/ClickHouseAggregateMethodCallTranslatorProvider.cs @@ -0,0 +1,18 @@ +using Microsoft.EntityFrameworkCore.Query; + +namespace ClickHouse.EntityFrameworkCore.Query.ExpressionTranslators.Internal; + +public class ClickHouseAggregateMethodCallTranslatorProvider : RelationalAggregateMethodCallTranslatorProvider +{ + public ClickHouseAggregateMethodCallTranslatorProvider( + RelationalAggregateMethodCallTranslatorProviderDependencies dependencies) + : base(dependencies) + { + var sqlExpressionFactory = dependencies.SqlExpressionFactory; + + AddTranslators( + [ + new ClickHouseQueryableAggregateMethodTranslator(sqlExpressionFactory), + ]); + } +} diff --git a/src/EFCore.ClickHouse/Query/ExpressionTranslators/Internal/ClickHouseQueryableAggregateMethodTranslator.cs b/src/EFCore.ClickHouse/Query/ExpressionTranslators/Internal/ClickHouseQueryableAggregateMethodTranslator.cs new file mode 100644 index 0000000..1c39e47 --- /dev/null +++ b/src/EFCore.ClickHouse/Query/ExpressionTranslators/Internal/ClickHouseQueryableAggregateMethodTranslator.cs @@ -0,0 +1,183 @@ +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; +using System.Reflection; + +namespace ClickHouse.EntityFrameworkCore.Query.ExpressionTranslators.Internal; + +/// +/// Translates grouped LINQ aggregate methods (Count, Sum, Average, Min, Max) +/// into ClickHouse SQL aggregate function calls. +/// +/// Scalar aggregates (without GROUP BY) are handled by the base EF Core classes; +/// this translator is needed for grouped aggregates produced by +/// GroupBy().Select(g => g.Count()) and similar patterns. +/// +public class ClickHouseQueryableAggregateMethodTranslator : IAggregateMethodCallTranslator +{ + private readonly ISqlExpressionFactory _sqlExpressionFactory; + + public ClickHouseQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) + { + _sqlExpressionFactory = sqlExpressionFactory; + } + + public SqlExpression? Translate( + MethodInfo method, + EnumerableExpression source, + IReadOnlyList arguments, + IDiagnosticsLogger logger) + { + if (method.DeclaringType != typeof(Queryable)) + return null; + + var methodInfo = method.IsGenericMethod + ? method.GetGenericMethodDefinition() + : method; + + switch (methodInfo.Name) + { + case nameof(Queryable.Average) + when (QueryableMethods.IsAverageWithoutSelector(methodInfo) + || QueryableMethods.IsAverageWithSelector(methodInfo)) + && source.Selector is SqlExpression averageSqlExpression: + { + // ClickHouse avg() on integer columns returns 0 for empty groups; + // avgOrNull() returns NULL instead, matching LINQ/SQL Server semantics. + // Cast int/long to double first so avg doesn't do integer division. + var averageInputType = averageSqlExpression.Type; + if (averageInputType == typeof(int) || averageInputType == typeof(long)) + { + averageSqlExpression = _sqlExpressionFactory.ApplyDefaultTypeMapping( + _sqlExpressionFactory.Convert(averageSqlExpression, typeof(double))); + } + + averageSqlExpression = CombineTerms(source, averageSqlExpression); + + return _sqlExpressionFactory.Function( + "avgOrNull", + [averageSqlExpression], + nullable: true, + argumentsPropagateNullability: [false], + typeof(double)); + } + + case nameof(Queryable.Count) + when methodInfo == QueryableMethods.CountWithoutPredicate + || methodInfo == QueryableMethods.CountWithPredicate: + { + var countSqlExpression = (source.Selector as SqlExpression) + ?? _sqlExpressionFactory.Fragment("*"); + countSqlExpression = CombineTerms(source, countSqlExpression); + + return _sqlExpressionFactory.Function( + "COUNT", + [countSqlExpression], + nullable: false, + argumentsPropagateNullability: [false], + typeof(int)); + } + + case nameof(Queryable.LongCount) + when methodInfo == QueryableMethods.LongCountWithoutPredicate + || methodInfo == QueryableMethods.LongCountWithPredicate: + { + var longCountSqlExpression = (source.Selector as SqlExpression) + ?? _sqlExpressionFactory.Fragment("*"); + longCountSqlExpression = CombineTerms(source, longCountSqlExpression); + + return _sqlExpressionFactory.Function( + "COUNT", + [longCountSqlExpression], + nullable: false, + argumentsPropagateNullability: [false], + typeof(long)); + } + + case nameof(Queryable.Max) + when (methodInfo == QueryableMethods.MaxWithoutSelector + || methodInfo == QueryableMethods.MaxWithSelector) + && source.Selector is SqlExpression maxSqlExpression: + { + maxSqlExpression = CombineTerms(source, maxSqlExpression); + + return _sqlExpressionFactory.Function( + "MAX", + [maxSqlExpression], + nullable: true, + argumentsPropagateNullability: [false], + maxSqlExpression.Type, + maxSqlExpression.TypeMapping); + } + + case nameof(Queryable.Min) + when (methodInfo == QueryableMethods.MinWithoutSelector + || methodInfo == QueryableMethods.MinWithSelector) + && source.Selector is SqlExpression minSqlExpression: + { + minSqlExpression = CombineTerms(source, minSqlExpression); + + return _sqlExpressionFactory.Function( + "MIN", + [minSqlExpression], + nullable: true, + argumentsPropagateNullability: [false], + minSqlExpression.Type, + minSqlExpression.TypeMapping); + } + + case nameof(Queryable.Sum) + when (QueryableMethods.IsSumWithoutSelector(methodInfo) + || QueryableMethods.IsSumWithSelector(methodInfo)) + && source.Selector is SqlExpression sumSqlExpression: + { + sumSqlExpression = CombineTerms(source, sumSqlExpression); + + return _sqlExpressionFactory.Function( + "SUM", + [sumSqlExpression], + nullable: true, + argumentsPropagateNullability: [false], + sumSqlExpression.Type, + sumSqlExpression.TypeMapping); + } + } + + return null; + } + + /// + /// Wraps the aggregate operand to handle predicate filtering and DISTINCT. + /// + /// When a predicate is present (e.g. g.Count(x => x.IsActive)), the operand + /// is wrapped in CASE WHEN predicate THEN expr ELSE NULL END so that only + /// matching rows contribute to the aggregate. If the operand is * (a fragment), + /// it's replaced with the constant 1 since CASE WHEN ... THEN * END + /// isn't valid SQL. + /// + /// When DISTINCT is requested, the operand is wrapped in a + /// so the SQL generator emits COUNT(DISTINCT expr) etc. + /// + private SqlExpression CombineTerms(EnumerableExpression enumerableExpression, SqlExpression sqlExpression) + { + if (enumerableExpression.Predicate != null) + { + if (sqlExpression is SqlFragmentExpression) + { + sqlExpression = _sqlExpressionFactory.Constant(1); + } + + sqlExpression = _sqlExpressionFactory.Case( + [new CaseWhenClause(enumerableExpression.Predicate, sqlExpression)], + elseResult: null); + } + + if (enumerableExpression.IsDistinct) + { + sqlExpression = new DistinctExpression(sqlExpression); + } + + return sqlExpression; + } +} diff --git a/test/EFCore.ClickHouse.Tests/GroupByAggregateTests.cs b/test/EFCore.ClickHouse.Tests/GroupByAggregateTests.cs new file mode 100644 index 0000000..c163ff1 --- /dev/null +++ b/test/EFCore.ClickHouse.Tests/GroupByAggregateTests.cs @@ -0,0 +1,168 @@ +using Microsoft.EntityFrameworkCore; +using Xunit; + +namespace EFCore.ClickHouse.Tests; + +public class GroupByAggregateTests : IClassFixture +{ + private readonly ClickHouseFixture _fixture; + + public GroupByAggregateTests(ClickHouseFixture fixture) + { + _fixture = fixture; + } + + [Fact] + public async Task GroupBy_Count_ReturnsCorrectCounts() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .GroupBy(e => e.IsActive) + .Select(g => new { IsActive = g.Key, Count = g.Count() }) + .OrderBy(x => x.IsActive) + .AsNoTracking() + .ToListAsync(); + + Assert.Equal(2, results.Count); + // false group: Charlie, Eve, Hank, Jack = 4 + Assert.False(results[0].IsActive); + Assert.Equal(4, results[0].Count); + // true group: Alice, Bob, Diana, Frank, Grace, Ivy = 6 + Assert.True(results[1].IsActive); + Assert.Equal(6, results[1].Count); + } + + [Fact] + public async Task GroupBy_Sum_ReturnsCorrectSums() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .GroupBy(e => e.IsActive) + .Select(g => new { IsActive = g.Key, TotalAge = g.Sum(e => e.Age) }) + .OrderBy(x => x.IsActive) + .AsNoTracking() + .ToListAsync(); + + Assert.Equal(2, results.Count); + // false: 35 + 22 + 27 + 29 = 113 + Assert.Equal(113, results[0].TotalAge); + // true: 30 + 25 + 28 + 40 + 33 + 31 = 187 + Assert.Equal(187, results[1].TotalAge); + } + + [Fact] + public async Task GroupBy_Average_ReturnsCorrectAverages() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .GroupBy(e => e.IsActive) + .Select(g => new { IsActive = g.Key, AvgAge = g.Average(e => e.Age) }) + .OrderBy(x => x.IsActive) + .AsNoTracking() + .ToListAsync(); + + Assert.Equal(2, results.Count); + // false: 113 / 4 = 28.25 + Assert.Equal(28.25, results[0].AvgAge); + // true: 187 / 6 ≈ 31.1667 + Assert.Equal(31.1667, results[1].AvgAge, 3); + } + + [Fact] + public async Task GroupBy_MinMax_ReturnsCorrectValues() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .GroupBy(e => e.IsActive) + .Select(g => new { IsActive = g.Key, MinAge = g.Min(e => e.Age), MaxAge = g.Max(e => e.Age) }) + .OrderBy(x => x.IsActive) + .AsNoTracking() + .ToListAsync(); + + Assert.Equal(2, results.Count); + // false: min=22(Eve), max=35(Charlie) + Assert.Equal(22, results[0].MinAge); + Assert.Equal(35, results[0].MaxAge); + // true: min=25(Bob), max=40(Frank) + Assert.Equal(25, results[1].MinAge); + Assert.Equal(40, results[1].MaxAge); + } + + [Fact] + public async Task GroupBy_Having_FiltersGroups() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + // Only groups where count > 4 + var results = await context.TestEntities + .GroupBy(e => e.IsActive) + .Where(g => g.Count() > 4) + .Select(g => new { IsActive = g.Key, Count = g.Count() }) + .AsNoTracking() + .ToListAsync(); + + // Only active group has 6, inactive has 4 + Assert.Single(results); + Assert.True(results[0].IsActive); + Assert.Equal(6, results[0].Count); + } + + [Fact] + public async Task GroupBy_MultipleAggregates_ReturnsAll() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .GroupBy(e => e.IsActive) + .Select(g => new + { + IsActive = g.Key, + Count = g.Count(), + TotalAge = g.Sum(e => e.Age), + MinAge = g.Min(e => e.Age), + MaxAge = g.Max(e => e.Age), + }) + .OrderBy(x => x.IsActive) + .AsNoTracking() + .ToListAsync(); + + Assert.Equal(2, results.Count); + + // Inactive group + Assert.Equal(4, results[0].Count); + Assert.Equal(113, results[0].TotalAge); + Assert.Equal(22, results[0].MinAge); + Assert.Equal(35, results[0].MaxAge); + + // Active group + Assert.Equal(6, results[1].Count); + Assert.Equal(187, results[1].TotalAge); + Assert.Equal(25, results[1].MinAge); + Assert.Equal(40, results[1].MaxAge); + } + + [Fact] + public async Task GroupBy_OrderByAggregate_Sorts() + { + await using var context = new TestDbContext(_fixture.ConnectionString); + + var results = await context.TestEntities + .GroupBy(e => e.IsActive) + .Select(g => new { IsActive = g.Key, Count = g.Count() }) + .OrderByDescending(x => x.Count) + .AsNoTracking() + .ToListAsync(); + + Assert.Equal(2, results.Count); + // Active group (6) should come first + Assert.True(results[0].IsActive); + Assert.Equal(6, results[0].Count); + // Inactive group (4) second + Assert.False(results[1].IsActive); + Assert.Equal(4, results[1].Count); + } +}