diff --git a/src/EFCore.PG.NodaTime/Query/Internal/NpgsqlNodaTimeEvaluatableExpressionFilterPlugin.cs b/src/EFCore.PG.NodaTime/Query/Internal/NpgsqlNodaTimeEvaluatableExpressionFilterPlugin.cs index 6e837baf6..0de760716 100644 --- a/src/EFCore.PG.NodaTime/Query/Internal/NpgsqlNodaTimeEvaluatableExpressionFilterPlugin.cs +++ b/src/EFCore.PG.NodaTime/Query/Internal/NpgsqlNodaTimeEvaluatableExpressionFilterPlugin.cs @@ -8,12 +8,6 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.NodaTime.Query.Internal; /// public class NpgsqlNodaTimeEvaluatableExpressionFilterPlugin : IEvaluatableExpressionFilterPlugin { - private static readonly MethodInfo GetCurrentInstantMethod = - typeof(SystemClock).GetRuntimeMethod(nameof(SystemClock.GetCurrentInstant), [])!; - - private static readonly MemberInfo SystemClockInstanceMember = - typeof(SystemClock).GetMember(nameof(SystemClock.Instance)).FirstOrDefault()!; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -24,11 +18,14 @@ public virtual bool IsEvaluatableExpression(Expression expression) { switch (expression) { - case MethodCallExpression methodCallExpression when methodCallExpression.Method == GetCurrentInstantMethod: + case MethodCallExpression methodCallExpression + when methodCallExpression.Method.DeclaringType == typeof(SystemClock) + && methodCallExpression.Method.Name == nameof(SystemClock.GetCurrentInstant): return false; case MemberExpression memberExpression: - if (memberExpression.Member == SystemClockInstanceMember) + if (memberExpression.Member.DeclaringType == typeof(SystemClock) + && memberExpression.Member.Name == nameof(SystemClock.Instance)) { return false; } diff --git a/src/EFCore.PG.NodaTime/Query/Internal/NpgsqlNodaTimeMemberTranslatorPlugin.cs b/src/EFCore.PG.NodaTime/Query/Internal/NpgsqlNodaTimeMemberTranslatorPlugin.cs index cda280630..50f44ff61 100644 --- a/src/EFCore.PG.NodaTime/Query/Internal/NpgsqlNodaTimeMemberTranslatorPlugin.cs +++ b/src/EFCore.PG.NodaTime/Query/Internal/NpgsqlNodaTimeMemberTranslatorPlugin.cs @@ -9,31 +9,22 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.NodaTime.Query.Internal; /// /// See: https://www.postgresql.org/docs/current/static/functions-datetime.html /// -public class NpgsqlNodaTimeMemberTranslatorPlugin : IMemberTranslatorPlugin +public class NpgsqlNodaTimeMemberTranslatorPlugin( + IRelationalTypeMappingSource typeMappingSource, + ISqlExpressionFactory sqlExpressionFactory) + : IMemberTranslatorPlugin { + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public NpgsqlNodaTimeMemberTranslatorPlugin( - IRelationalTypeMappingSource typeMappingSource, - ISqlExpressionFactory sqlExpressionFactory) - { - Translators = + public virtual IEnumerable Translators { get; } = [ new NpgsqlNodaTimeMemberTranslator(typeMappingSource, (NpgsqlSqlExpressionFactory)sqlExpressionFactory) ]; - } - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public virtual IEnumerable Translators { get; } } /// @@ -42,64 +33,22 @@ public NpgsqlNodaTimeMemberTranslatorPlugin( /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class NpgsqlNodaTimeMemberTranslator : IMemberTranslator +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class NpgsqlNodaTimeMemberTranslator( + IRelationalTypeMappingSource typeMappingSource, + NpgsqlSqlExpressionFactory sqlExpressionFactory) + : IMemberTranslator { - private static readonly MemberInfo SystemClock_Instance = - typeof(SystemClock).GetRuntimeProperty(nameof(SystemClock.Instance))!; - - private static readonly MemberInfo ZonedDateTime_LocalDateTime = - typeof(ZonedDateTime).GetRuntimeProperty(nameof(ZonedDateTime.LocalDateTime))!; - - private static readonly MemberInfo Interval_Start = - typeof(Interval).GetRuntimeProperty(nameof(Interval.Start))!; - - private static readonly MemberInfo Interval_End = - typeof(Interval).GetRuntimeProperty(nameof(Interval.End))!; - - private static readonly MemberInfo Interval_HasStart = - typeof(Interval).GetRuntimeProperty(nameof(Interval.HasStart))!; - - private static readonly MemberInfo Interval_HasEnd = - typeof(Interval).GetRuntimeProperty(nameof(Interval.HasEnd))!; - - private static readonly MemberInfo Interval_Duration = - typeof(Interval).GetRuntimeProperty(nameof(Interval.Duration))!; - - private static readonly MemberInfo DateInterval_Start = - typeof(DateInterval).GetRuntimeProperty(nameof(DateInterval.Start))!; - - private static readonly MemberInfo DateInterval_End = - typeof(DateInterval).GetRuntimeProperty(nameof(DateInterval.End))!; - - private static readonly MemberInfo DateInterval_Length = - typeof(DateInterval).GetRuntimeProperty(nameof(DateInterval.Length))!; - - private static readonly MemberInfo DateTimeZoneProviders_TzDb = - typeof(DateTimeZoneProviders).GetRuntimeProperty(nameof(DateTimeZoneProviders.Tzdb))!; - - private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; - private readonly IRelationalTypeMappingSource _typeMappingSource; - private readonly RelationalTypeMapping _dateTypeMapping; - private readonly RelationalTypeMapping _periodTypeMapping; - private readonly RelationalTypeMapping _localDateTimeTypeMapping; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public NpgsqlNodaTimeMemberTranslator( - IRelationalTypeMappingSource typeMappingSource, - NpgsqlSqlExpressionFactory sqlExpressionFactory) - { - _typeMappingSource = typeMappingSource; - _sqlExpressionFactory = sqlExpressionFactory; - _dateTypeMapping = typeMappingSource.FindMapping(typeof(LocalDate))!; - _periodTypeMapping = typeMappingSource.FindMapping(typeof(Period))!; - _localDateTimeTypeMapping = typeMappingSource.FindMapping(typeof(LocalDateTime))!; - } - + private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory = sqlExpressionFactory; + private readonly IRelationalTypeMappingSource _typeMappingSource = typeMappingSource; + private readonly RelationalTypeMapping _dateTypeMapping = typeMappingSource.FindMapping(typeof(LocalDate))!; + private readonly RelationalTypeMapping _periodTypeMapping = typeMappingSource.FindMapping(typeof(Period))!; + private readonly RelationalTypeMapping _localDateTimeTypeMapping = typeMappingSource.FindMapping(typeof(LocalDateTime))!; private static readonly bool[][] TrueArrays = [[], [true], [true, true]]; /// @@ -110,12 +59,12 @@ public NpgsqlNodaTimeMemberTranslator( IDiagnosticsLogger logger) { // This is necessary to allow translation of methods on SystemClock.Instance - if (member == SystemClock_Instance) + if (member.DeclaringType == typeof(SystemClock) && member.Name == nameof(SystemClock.Instance)) { return _sqlExpressionFactory.Constant(SystemClock.Instance); } - if (member == DateTimeZoneProviders_TzDb) + if (member.DeclaringType == typeof(DateTimeZoneProviders) && member.Name == nameof(DateTimeZoneProviders.Tzdb)) { return PendingDateTimeZoneProviderExpression.Instance; } @@ -181,44 +130,34 @@ SqlExpression TranslateDurationTotalMember(SqlExpression instance, double diviso private SqlExpression? TranslateInterval(SqlExpression instance, MemberInfo member) { - if (member == Interval_Start) - { - return Lower(); - } - - if (member == Interval_End) - { - return Upper(); - } - - if (member == Interval_HasStart) - { - return _sqlExpressionFactory.Not( - _sqlExpressionFactory.Function( - "lower_inf", - [instance], - nullable: true, - argumentsPropagateNullability: TrueArrays[1], - typeof(bool))); - } - - if (member == Interval_HasEnd) - { - return _sqlExpressionFactory.Not( - _sqlExpressionFactory.Function( - "upper_inf", - [instance], - nullable: true, - argumentsPropagateNullability: TrueArrays[1], - typeof(bool))); - } - - if (member == Interval_Duration) + return member.Name switch { - return _sqlExpressionFactory.Subtract(Upper(), Lower(), _typeMappingSource.FindMapping(typeof(Duration))); - } + nameof(Interval.Start) => Lower(), + nameof(Interval.End) => Upper(), + + nameof(Interval.HasStart) + => _sqlExpressionFactory.Not( + _sqlExpressionFactory.Function( + "lower_inf", + [instance], + nullable: true, + argumentsPropagateNullability: TrueArrays[1], + typeof(bool))), + + nameof(Interval.HasEnd) + => _sqlExpressionFactory.Not( + _sqlExpressionFactory.Function( + "upper_inf", + [instance], + nullable: true, + argumentsPropagateNullability: TrueArrays[1], + typeof(bool))), + + nameof(Interval.Duration) + => _sqlExpressionFactory.Subtract(Upper(), Lower(), _typeMappingSource.FindMapping(typeof(Duration))), - return null; + _ => null + }; SqlExpression Lower() => _sqlExpressionFactory.Function( @@ -244,28 +183,23 @@ SqlExpression Upper() // NodaTime DateInterval is inclusive on both ends. // PostgreSQL daterange is a discrete range type; this means it gets normalized to inclusive lower bound, exclusive upper bound. // So we can translate Start as-is, but need to subtract a day for End. - if (member == DateInterval_Start) + return member.Name switch { - return Lower(); - } + nameof(DateInterval.Start) => Lower(), - if (member == DateInterval_End) - { // PostgreSQL creates a result of type 'timestamp without time zone' when subtracting intervals from dates, so add a cast back // to date. - return _sqlExpressionFactory.Convert( - _sqlExpressionFactory.Subtract( - Upper(), - _sqlExpressionFactory.Constant(Period.FromDays(1), _periodTypeMapping)), typeof(LocalDate), - _typeMappingSource.FindMapping(typeof(LocalDate))); - } + nameof(DateInterval.End) + => _sqlExpressionFactory.Convert( + _sqlExpressionFactory.Subtract( + Upper(), + _sqlExpressionFactory.Constant(Period.FromDays(1), _periodTypeMapping)), typeof(LocalDate), + _typeMappingSource.FindMapping(typeof(LocalDate))), - if (member == DateInterval_Length) - { - return _sqlExpressionFactory.Subtract(Upper(), Lower()); - } + nameof(DateInterval.Length) => _sqlExpressionFactory.Subtract(Upper(), Lower()), - return null; + _ => null + }; SqlExpression Lower() => _sqlExpressionFactory.Function( @@ -378,7 +312,7 @@ private SqlExpression GetDatePartExpressionDouble( typeof(LocalDateTime), _localDateTimeTypeMapping); - return member == ZonedDateTime_LocalDateTime + return member.Name == nameof(ZonedDateTime.LocalDateTime) ? instance : TranslateDateTime(instance, member); } @@ -388,7 +322,7 @@ private SqlExpression GetDatePartExpressionDouble( // The same works also for the LocalDateTime member. instance = _sqlExpressionFactory.AtUtc(instance); - return member == ZonedDateTime_LocalDateTime + return member.Name == nameof(ZonedDateTime.LocalDateTime) ? instance : TranslateDateTime(instance, member); } diff --git a/src/EFCore.PG.NodaTime/Query/Internal/NpgsqlNodaTimeMethodCallTranslatorPlugin.cs b/src/EFCore.PG.NodaTime/Query/Internal/NpgsqlNodaTimeMethodCallTranslatorPlugin.cs index 4d9273a3d..8f6de4ee7 100644 --- a/src/EFCore.PG.NodaTime/Query/Internal/NpgsqlNodaTimeMethodCallTranslatorPlugin.cs +++ b/src/EFCore.PG.NodaTime/Query/Internal/NpgsqlNodaTimeMethodCallTranslatorPlugin.cs @@ -44,98 +44,16 @@ public NpgsqlNodaTimeMethodCallTranslatorPlugin( /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class NpgsqlNodaTimeMethodCallTranslator : IMethodCallTranslator +public class NpgsqlNodaTimeMethodCallTranslator( + IRelationalTypeMappingSource typeMappingSource, + NpgsqlSqlExpressionFactory sqlExpressionFactory) + : IMethodCallTranslator { - private readonly IRelationalTypeMappingSource _typeMappingSource; - private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; - - private static readonly MethodInfo SystemClock_GetCurrentInstant = - typeof(SystemClock).GetRuntimeMethod(nameof(SystemClock.GetCurrentInstant), Type.EmptyTypes)!; - - private static readonly MethodInfo Instant_InUtc = - typeof(Instant).GetRuntimeMethod(nameof(Instant.InUtc), Type.EmptyTypes)!; - - private static readonly MethodInfo Instant_InZone = - typeof(Instant).GetRuntimeMethod(nameof(Instant.InZone), [typeof(DateTimeZone)])!; - - private static readonly MethodInfo Instant_ToDateTimeUtc = - typeof(Instant).GetRuntimeMethod(nameof(Instant.ToDateTimeUtc), Type.EmptyTypes)!; - - private static readonly MethodInfo Instant_Distance = - typeof(NpgsqlNodaTimeDbFunctionsExtensions).GetRuntimeMethod( - nameof(NpgsqlNodaTimeDbFunctionsExtensions.Distance), [typeof(DbFunctions), typeof(Instant), typeof(Instant)])!; - - private static readonly MethodInfo ZonedDateTime_ToInstant = - typeof(ZonedDateTime).GetRuntimeMethod(nameof(ZonedDateTime.ToInstant), Type.EmptyTypes)!; - - private static readonly MethodInfo ZonedDateTime_Distance = - typeof(NpgsqlNodaTimeDbFunctionsExtensions).GetRuntimeMethod( - nameof(NpgsqlNodaTimeDbFunctionsExtensions.Distance), - [typeof(DbFunctions), typeof(ZonedDateTime), typeof(ZonedDateTime)])!; - - private static readonly MethodInfo LocalDateTime_InZoneLeniently = - typeof(LocalDateTime).GetRuntimeMethod(nameof(LocalDateTime.InZoneLeniently), [typeof(DateTimeZone)])!; - - private static readonly MethodInfo LocalDateTime_Distance = - typeof(NpgsqlNodaTimeDbFunctionsExtensions).GetRuntimeMethod( - nameof(NpgsqlNodaTimeDbFunctionsExtensions.Distance), - [typeof(DbFunctions), typeof(LocalDateTime), typeof(LocalDateTime)])!; - - private static readonly MethodInfo LocalDate_Distance = - typeof(NpgsqlNodaTimeDbFunctionsExtensions).GetRuntimeMethod( - nameof(NpgsqlNodaTimeDbFunctionsExtensions.Distance), [typeof(DbFunctions), typeof(LocalDate), typeof(LocalDate)])!; - - private static readonly MethodInfo Period_FromYears = typeof(Period).GetRuntimeMethod(nameof(Period.FromYears), [typeof(int)])!; - - private static readonly MethodInfo Period_FromMonths = - typeof(Period).GetRuntimeMethod(nameof(Period.FromMonths), [typeof(int)])!; - - private static readonly MethodInfo Period_FromWeeks = typeof(Period).GetRuntimeMethod(nameof(Period.FromWeeks), [typeof(int)])!; - private static readonly MethodInfo Period_FromDays = typeof(Period).GetRuntimeMethod(nameof(Period.FromDays), [typeof(int)])!; - - private static readonly MethodInfo Period_FromHours = typeof(Period).GetRuntimeMethod( - nameof(Period.FromHours), [typeof(long)])!; - - private static readonly MethodInfo Period_FromMinutes = - typeof(Period).GetRuntimeMethod(nameof(Period.FromMinutes), [typeof(long)])!; - - private static readonly MethodInfo Period_FromSeconds = - typeof(Period).GetRuntimeMethod(nameof(Period.FromSeconds), [typeof(long)])!; - - private static readonly MethodInfo Interval_Contains - = typeof(Interval).GetRuntimeMethod(nameof(Interval.Contains), [typeof(Instant)])!; - - private static readonly MethodInfo DateInterval_Contains_LocalDate - = typeof(DateInterval).GetRuntimeMethod(nameof(DateInterval.Contains), [typeof(LocalDate)])!; - - private static readonly MethodInfo DateInterval_Contains_DateInterval - = typeof(DateInterval).GetRuntimeMethod(nameof(DateInterval.Contains), [typeof(DateInterval)])!; - - private static readonly MethodInfo DateInterval_Intersection - = typeof(DateInterval).GetRuntimeMethod(nameof(DateInterval.Intersection), [typeof(DateInterval)])!; - - private static readonly MethodInfo DateInterval_Union - = typeof(DateInterval).GetRuntimeMethod(nameof(DateInterval.Union), [typeof(DateInterval)])!; - - private static readonly MethodInfo IDateTimeZoneProvider_get_Item - = typeof(IDateTimeZoneProvider).GetRuntimeMethod("get_Item", [typeof(string)])!; + private readonly IRelationalTypeMappingSource _typeMappingSource = typeMappingSource; + private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory = sqlExpressionFactory; private static readonly bool[][] TrueArrays = [[], [true], [true, true]]; - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public NpgsqlNodaTimeMethodCallTranslator( - IRelationalTypeMappingSource typeMappingSource, - NpgsqlSqlExpressionFactory sqlExpressionFactory) - { - _typeMappingSource = typeMappingSource; - _sqlExpressionFactory = sqlExpressionFactory; - } - #pragma warning disable EF1001 /// public virtual SqlExpression? Translate( @@ -158,7 +76,7 @@ public NpgsqlNodaTimeMethodCallTranslator( return translated; } - if (method == IDateTimeZoneProvider_get_Item && instance is PendingDateTimeZoneProviderExpression) + if (method.DeclaringType == typeof(IDateTimeZoneProvider) && method.Name == "get_Item" && instance is PendingDateTimeZoneProviderExpression) { // We're translating an expression such as 'DateTimeZoneProviders.Tzdb["Europe/Berlin"]'. // Note that the .NET type of that expression is DateTimeZone, but we just return the string ID for the time zone. @@ -172,111 +90,96 @@ public NpgsqlNodaTimeMethodCallTranslator( SqlExpression? instance, MethodInfo method, IReadOnlyList arguments) - { - if (method == SystemClock_GetCurrentInstant) + => method.Name switch { - return NpgsqlNodaTimeTypeMappingSourcePlugin.LegacyTimestampBehavior - ? _sqlExpressionFactory.AtTimeZone( - _sqlExpressionFactory.Function( + nameof(SystemClock.GetCurrentInstant) when method.DeclaringType == typeof(SystemClock) + => NpgsqlNodaTimeTypeMappingSourcePlugin.LegacyTimestampBehavior + ? _sqlExpressionFactory.AtTimeZone( + _sqlExpressionFactory.Function( + "NOW", + [], + nullable: false, + argumentsPropagateNullability: [], + method.ReturnType), + _sqlExpressionFactory.Constant("UTC"), + method.ReturnType) + : _sqlExpressionFactory.Function( "NOW", [], nullable: false, argumentsPropagateNullability: [], - method.ReturnType), - _sqlExpressionFactory.Constant("UTC"), - method.ReturnType) - : _sqlExpressionFactory.Function( - "NOW", - [], - nullable: false, - argumentsPropagateNullability: [], - method.ReturnType, - _typeMappingSource.FindMapping(typeof(Instant), "timestamp with time zone")); - } + method.ReturnType, + _typeMappingSource.FindMapping(typeof(Instant), "timestamp with time zone")), - if (method == Instant_InUtc) - { // Instant -> ZonedDateTime is a no-op (different types in .NET but both mapped to timestamptz in PG) - return instance; - } + nameof(Instant.InUtc) when method.DeclaringType == typeof(Instant) + => instance, - if (method == Instant_InZone) - { // When InZone is called, we have a mismatch: on the .NET NodaTime side, we have a ZonedDateTime; but on the PostgreSQL side, // the AT TIME ZONE expression returns a 'timestamp without time zone' (when applied to a 'timestamp with time zone', which is // what ZonedDateTime is mapped to). - return new PendingZonedDateTimeExpression(instance!, arguments[0]); - } + nameof(Instant.InZone) when method.DeclaringType == typeof(Instant) + => new PendingZonedDateTimeExpression(instance!, arguments[0]), - if (method == Instant_ToDateTimeUtc) - { - return _sqlExpressionFactory.Convert( - instance!, - typeof(DateTime), - _typeMappingSource.FindMapping(typeof(DateTime), "timestamp with time zone")); - } + nameof(Instant.ToDateTimeUtc) when method.DeclaringType == typeof(Instant) + => _sqlExpressionFactory.Convert( + instance!, + typeof(DateTime), + _typeMappingSource.FindMapping(typeof(DateTime), "timestamp with time zone")), - if (method == Instant_Distance) - { - return _sqlExpressionFactory.MakePostgresBinary(PgExpressionType.Distance, arguments[1], arguments[2]); - } + nameof(NpgsqlNodaTimeDbFunctionsExtensions.Distance) + when method.DeclaringType == typeof(NpgsqlNodaTimeDbFunctionsExtensions) + => _sqlExpressionFactory.MakePostgresBinary(PgExpressionType.Distance, arguments[1], arguments[2]), - return null; - } + _ => null + }; private SqlExpression? TranslateZonedDateTime( SqlExpression? instance, MethodInfo method, IReadOnlyList arguments) - { - if (method == ZonedDateTime_ToInstant) + => method.Name switch { // We get here with the expression localDateTime.InZoneLeniently(DateTimeZoneProviders.Tzdb["Europe/Berlin"]).ToInstant() - if (instance is PendingZonedDateTimeExpression pendingZonedDateTime) - { - return _sqlExpressionFactory.AtTimeZone( - pendingZonedDateTime.Operand, - pendingZonedDateTime.TimeZoneId, - typeof(Instant), - _typeMappingSource.FindMapping(typeof(Instant))); - } - - // Otherwise, ZonedDateTime -> ToInstant is a no-op (different types in .NET but both mapped to timestamptz in PG) - return instance; - } - - if (method == ZonedDateTime_Distance) - { - return _sqlExpressionFactory.MakePostgresBinary(PgExpressionType.Distance, arguments[1], arguments[2]); - } - - return null; - } + nameof(ZonedDateTime.ToInstant) when method.DeclaringType == typeof(ZonedDateTime) + => instance is PendingZonedDateTimeExpression pendingZonedDateTime + ? _sqlExpressionFactory.AtTimeZone( + pendingZonedDateTime.Operand, + pendingZonedDateTime.TimeZoneId, + typeof(Instant), + _typeMappingSource.FindMapping(typeof(Instant))) + // Otherwise, ZonedDateTime -> ToInstant is a no-op (different types in .NET but both mapped to timestamptz in PG) + : instance, + + nameof(NpgsqlNodaTimeDbFunctionsExtensions.Distance) + when method.DeclaringType == typeof(NpgsqlNodaTimeDbFunctionsExtensions) + => _sqlExpressionFactory.MakePostgresBinary(PgExpressionType.Distance, arguments[1], arguments[2]), + + _ => null + }; private SqlExpression? TranslateLocalDateTime( SqlExpression? instance, MethodInfo method, IReadOnlyList arguments) - { - if (method == LocalDateTime_InZoneLeniently) + => method.Name switch { - return new PendingZonedDateTimeExpression(instance!, arguments[0]); - } + nameof(LocalDateTime.InZoneLeniently) when method.DeclaringType == typeof(LocalDateTime) + => new PendingZonedDateTimeExpression(instance!, arguments[0]), - if (method == LocalDateTime_Distance) - { - return _sqlExpressionFactory.MakePostgresBinary(PgExpressionType.Distance, arguments[1], arguments[2]); - } + nameof(NpgsqlNodaTimeDbFunctionsExtensions.Distance) + when method.DeclaringType == typeof(NpgsqlNodaTimeDbFunctionsExtensions) + => _sqlExpressionFactory.MakePostgresBinary(PgExpressionType.Distance, arguments[1], arguments[2]), - return null; - } + _ => null + }; private SqlExpression? TranslateLocalDate( SqlExpression? instance, MethodInfo method, IReadOnlyList arguments) { - if (method == LocalDate_Distance) + if (method.DeclaringType == typeof(NpgsqlNodaTimeDbFunctionsExtensions) && method.Name == nameof(NpgsqlNodaTimeDbFunctionsExtensions.Distance)) { return _sqlExpressionFactory.MakePostgresBinary(PgExpressionType.Distance, arguments[1], arguments[2]); } @@ -316,43 +219,18 @@ public NpgsqlNodaTimeMethodCallTranslator( return null; } - if (method == Period_FromYears) + return method.Name switch { - return IntervalPart("years", arguments[0]); - } - - if (method == Period_FromMonths) - { - return IntervalPart("months", arguments[0]); - } - - if (method == Period_FromWeeks) - { - return IntervalPart("weeks", arguments[0]); - } - - if (method == Period_FromDays) - { - return IntervalPart("days", arguments[0]); - } - - if (method == Period_FromHours) - { - return IntervalPartOverBigInt("hours", arguments[0]); - } - - if (method == Period_FromMinutes) - { - return IntervalPartOverBigInt("mins", arguments[0]); - } - - if (method == Period_FromSeconds) - { - return IntervalPart( - "secs", _sqlExpressionFactory.Convert(arguments[0], typeof(double), _typeMappingSource.FindMapping(typeof(double)))); - } - - return null; + nameof(Period.FromYears) => IntervalPart("years", arguments[0]), + nameof(Period.FromMonths) => IntervalPart("months", arguments[0]), + nameof(Period.FromWeeks) => IntervalPart("weeks", arguments[0]), + nameof(Period.FromDays) => IntervalPart("days", arguments[0]), + nameof(Period.FromHours) => IntervalPartOverBigInt("hours", arguments[0]), + nameof(Period.FromMinutes) => IntervalPartOverBigInt("mins", arguments[0]), + nameof(Period.FromSeconds) => IntervalPart( + "secs", _sqlExpressionFactory.Convert(arguments[0], typeof(double), _typeMappingSource.FindMapping(typeof(double)))), + _ => null + }; static PgFunctionExpression IntervalPart(string datePart, SqlExpression parameter) => PgFunctionExpression.CreateWithNamedArguments( @@ -390,7 +268,7 @@ PgFunctionExpression IntervalPartOverBigInt(string datePart, SqlExpression param IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (method == Interval_Contains) + if (method.DeclaringType == typeof(Interval) && method.Name == nameof(Interval.Contains)) { return _sqlExpressionFactory.Contains(instance!, arguments[0]); } @@ -404,18 +282,17 @@ PgFunctionExpression IntervalPartOverBigInt(string datePart, SqlExpression param IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (method == DateInterval_Contains_LocalDate - || method == DateInterval_Contains_DateInterval) + if (method.DeclaringType == typeof(DateInterval) && method.Name == nameof(DateInterval.Contains)) { return _sqlExpressionFactory.Contains(instance!, arguments[0]); } - if (method == DateInterval_Intersection) + if (method.DeclaringType == typeof(DateInterval) && method.Name == nameof(DateInterval.Intersection)) { return _sqlExpressionFactory.MakePostgresBinary(PgExpressionType.RangeIntersect, instance!, arguments[0]); } - if (method == DateInterval_Union) + if (method.DeclaringType == typeof(DateInterval) && method.Name == nameof(DateInterval.Union)) { return _sqlExpressionFactory.MakePostgresBinary(PgExpressionType.RangeUnion, instance!, arguments[0]); } diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlArrayMethodTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlArrayMethodTranslator.cs index f4c56ba38..a957ae601 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlArrayMethodTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlArrayMethodTranslator.cs @@ -11,58 +11,11 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte /// /// https://www.postgresql.org/docs/current/static/functions-array.html /// -public class NpgsqlArrayMethodTranslator : IMethodCallTranslator +public class NpgsqlArrayMethodTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory, NpgsqlJsonPocoTranslator jsonPocoTranslator) + : IMethodCallTranslator { - #region Methods - - // ReSharper disable InconsistentNaming - private static readonly MethodInfo Array_IndexOf1 = - typeof(Array).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly) - .Single(m => m is { Name: nameof(Array.IndexOf), IsGenericMethod: true } && m.GetParameters().Length == 2); - - private static readonly MethodInfo Array_IndexOf2 = - typeof(Array).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly) - .Single(m => m is { Name: nameof(Array.IndexOf), IsGenericMethod: true } && m.GetParameters().Length == 3); - - private static readonly MethodInfo Enumerable_ElementAt = - typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly) - .Single( - m => m.Name == nameof(Enumerable.ElementAt) - && m.GetParameters().Length == 2 - && m.GetParameters()[1].ParameterType == typeof(int)); - - private static readonly MethodInfo Enumerable_SequenceEqual = - typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly) - .Single(m => m.Name == nameof(Enumerable.SequenceEqual) && m.GetParameters().Length == 2); - - // TODO: Enumerable Append and Concat are only here because primitive collections aren't handled in ExecuteUpdate, - // https://github.com/dotnet/efcore/issues/32494 - private static readonly MethodInfo Enumerable_Append = - typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly) - .Single(m => m.Name == nameof(Enumerable.Append) && m.GetParameters().Length == 2); - - private static readonly MethodInfo Enumerable_Concat = - typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly) - .Single(m => m.Name == nameof(Enumerable.Concat) && m.GetParameters().Length == 2); - - // ReSharper restore InconsistentNaming - - #endregion Methods - - private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; - private readonly NpgsqlJsonPocoTranslator _jsonPocoTranslator; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public NpgsqlArrayMethodTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory, NpgsqlJsonPocoTranslator jsonPocoTranslator) - { - _sqlExpressionFactory = sqlExpressionFactory; - _jsonPocoTranslator = jsonPocoTranslator; - } + private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory = sqlExpressionFactory; + private readonly NpgsqlJsonPocoTranslator _jsonPocoTranslator = jsonPocoTranslator; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -77,39 +30,44 @@ public NpgsqlArrayMethodTranslator(NpgsqlSqlExpressionFactory sqlExpressionFacto IDiagnosticsLogger logger) { // During preprocessing, ArrayIndex and List[] get normalized to ElementAt; so we handle indexing into array/list here - if (method.IsClosedFormOf(Enumerable_ElementAt)) + if (method is { IsGenericMethod: true, Name: nameof(Enumerable.ElementAt) } + && method.DeclaringType == typeof(Enumerable) + && arguments is [var source, var index] + && index.Type == typeof(int)) { - return arguments[0].TypeMapping switch + return source.TypeMapping switch { // Indexing over bytea is special, we have to use function rather than subscript NpgsqlByteArrayTypeMapping => _sqlExpressionFactory.Function( "get_byte", - [arguments[0], arguments[1]], + [source, index], nullable: true, argumentsPropagateNullability: TrueArrays[2], typeof(byte)), NpgsqlArrayTypeMapping typeMapping => _sqlExpressionFactory.ArrayIndex( - arguments[0], - _sqlExpressionFactory.GenerateOneBasedIndexExpression(arguments[1]), + source, + _sqlExpressionFactory.GenerateOneBasedIndexExpression(index), nullable: true), // Try translating indexing inside JSON column // Note that Length over PG arrays (not within JSON) gets translated by QueryableMethodTranslatingEV, since arrays are primitive // collections - _ => _jsonPocoTranslator.TranslateMemberAccess(arguments[0], arguments[1], method.ReturnType) + _ => _jsonPocoTranslator.TranslateMemberAccess(source, index, method.ReturnType) }; } - if (method.IsClosedFormOf(Enumerable_SequenceEqual) - && arguments[0].Type.IsArrayOrGenericList() - && !IsMappedToNonArray(arguments[0]) - && arguments[1].Type.IsArrayOrGenericList() - && !IsMappedToNonArray(arguments[1])) + if (method is { IsGenericMethod: true, Name: nameof(Enumerable.SequenceEqual) } + && method.DeclaringType == typeof(Enumerable) + && arguments is [var first, var second] + && first.Type.IsArrayOrGenericList() + && !IsMappedToNonArray(first) + && second.Type.IsArrayOrGenericList() + && !IsMappedToNonArray(second)) { - return _sqlExpressionFactory.Equal(arguments[0], arguments[1]); + return _sqlExpressionFactory.Equal(first, second); } // Translate instance methods on List @@ -135,12 +93,11 @@ static bool IsMappedToNonArray(SqlExpression arrayOrList) SqlExpression? TranslateCommon(SqlExpression arrayOrList, IReadOnlyList arguments) #pragma warning restore CS8321 { - if (method.IsClosedFormOf(Array_IndexOf1) - || method.Name == nameof(List.IndexOf) - && method.DeclaringType.IsGenericList() - && method.GetParameters().Length == 1) + if (arguments is [var searchItem] + && (method is { IsGenericMethod: true, Name: nameof(Array.IndexOf) } && method.DeclaringType == typeof(Array) + || method.Name == nameof(List.IndexOf) && method.DeclaringType.IsGenericList())) { - var (item, array) = _sqlExpressionFactory.ApplyTypeMappingsOnItemAndArray(arguments[0], arrayOrList); + var (item, array) = _sqlExpressionFactory.ApplyTypeMappingsOnItemAndArray(searchItem, arrayOrList); return _sqlExpressionFactory.Coalesce( _sqlExpressionFactory.Subtract( @@ -157,13 +114,12 @@ static bool IsMappedToNonArray(SqlExpression arrayOrList) _sqlExpressionFactory.Constant(-1)); } - if (method.IsClosedFormOf(Array_IndexOf2) - || method.Name == nameof(List.IndexOf) - && method.DeclaringType.IsGenericList() - && method.GetParameters().Length == 2) + if (arguments is [var searchItem2, var startIndexArg] + && (method is { IsGenericMethod: true, Name: nameof(Array.IndexOf) } && method.DeclaringType == typeof(Array) + || method.Name == nameof(List<>.IndexOf) && method.DeclaringType.IsGenericList())) { - var (item, array) = _sqlExpressionFactory.ApplyTypeMappingsOnItemAndArray(arguments[0], arrayOrList); - var startIndex = _sqlExpressionFactory.GenerateOneBasedIndexExpression(arguments[1]); + var (item, array) = _sqlExpressionFactory.ApplyTypeMappingsOnItemAndArray(searchItem2, arrayOrList); + var startIndex = _sqlExpressionFactory.GenerateOneBasedIndexExpression(startIndexArg); return _sqlExpressionFactory.Coalesce( _sqlExpressionFactory.Subtract( @@ -182,9 +138,11 @@ static bool IsMappedToNonArray(SqlExpression arrayOrList) // TODO: Enumerable Append and Concat are only here because primitive collections aren't handled in ExecuteUpdate, // https://github.com/dotnet/efcore/issues/32494 - if (method.IsClosedFormOf(Enumerable_Append)) + if (method is { IsGenericMethod: true, Name: nameof(Enumerable.Append) } + && method.DeclaringType == typeof(Enumerable) + && arguments is [var element]) { - var (item, array) = _sqlExpressionFactory.ApplyTypeMappingsOnItemAndArray(arguments[0], arrayOrList); + var (item, array) = _sqlExpressionFactory.ApplyTypeMappingsOnItemAndArray(element, arrayOrList); return _sqlExpressionFactory.Function( "array_append", @@ -195,15 +153,17 @@ static bool IsMappedToNonArray(SqlExpression arrayOrList) arrayOrList.TypeMapping); } - if (method.IsClosedFormOf(Enumerable_Concat)) + if (method is { IsGenericMethod: true, Name: nameof(Enumerable.Concat) } + && method.DeclaringType == typeof(Enumerable) + && arguments is [var otherArray]) { - var inferredMapping = ExpressionExtensions.InferTypeMapping(arrayOrList, arguments[0]); + var inferredMapping = ExpressionExtensions.InferTypeMapping(arrayOrList, otherArray); return _sqlExpressionFactory.Function( "array_cat", [ _sqlExpressionFactory.ApplyTypeMapping(arrayOrList, inferredMapping), - _sqlExpressionFactory.ApplyTypeMapping(arguments[0], inferredMapping) + _sqlExpressionFactory.ApplyTypeMapping(otherArray, inferredMapping) ], nullable: true, TrueArrays[2], diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlBigIntegerMemberTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlBigIntegerMemberTranslator.cs index 6e9a3a5c6..b56ad89ab 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlBigIntegerMemberTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlBigIntegerMemberTranslator.cs @@ -8,25 +8,8 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class NpgsqlBigIntegerMemberTranslator : IMemberTranslator +public class NpgsqlBigIntegerMemberTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory) : IMemberTranslator { - private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; - - private static readonly MemberInfo IsZero = typeof(BigInteger).GetProperty(nameof(BigInteger.IsZero))!; - private static readonly MemberInfo IsOne = typeof(BigInteger).GetProperty(nameof(BigInteger.IsOne))!; - private static readonly MemberInfo IsEven = typeof(BigInteger).GetProperty(nameof(BigInteger.IsEven))!; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public NpgsqlBigIntegerMemberTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory) - { - _sqlExpressionFactory = sqlExpressionFactory; - } - /// public virtual SqlExpression? Translate( SqlExpression? instance, @@ -34,26 +17,22 @@ public NpgsqlBigIntegerMemberTranslator(NpgsqlSqlExpressionFactory sqlExpression Type returnType, IDiagnosticsLogger logger) { - if (member.DeclaringType == typeof(BigInteger)) + if (member.DeclaringType != typeof(BigInteger)) { - if (member == IsZero) - { - return _sqlExpressionFactory.Equal(instance!, _sqlExpressionFactory.Constant(BigInteger.Zero)); - } - - if (member == IsOne) - { - return _sqlExpressionFactory.Equal(instance!, _sqlExpressionFactory.Constant(BigInteger.One)); - } - - if (member == IsEven) - { - return _sqlExpressionFactory.Equal( - _sqlExpressionFactory.Modulo(instance!, _sqlExpressionFactory.Constant(new BigInteger(2))), - _sqlExpressionFactory.Constant(BigInteger.Zero)); - } + return null; } - return null; + return member.Name switch + { + nameof(BigInteger.IsZero) + => sqlExpressionFactory.Equal(instance!, sqlExpressionFactory.Constant(BigInteger.Zero)), + nameof(BigInteger.IsOne) + => sqlExpressionFactory.Equal(instance!, sqlExpressionFactory.Constant(BigInteger.One)), + nameof(BigInteger.IsEven) + => sqlExpressionFactory.Equal( + sqlExpressionFactory.Modulo(instance!, sqlExpressionFactory.Constant(new BigInteger(2))), + sqlExpressionFactory.Constant(BigInteger.Zero)), + _ => null + }; } } diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlConvertTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlConvertTranslator.cs index 84c27eb54..3c7d22e6d 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlConvertTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlConvertTranslator.cs @@ -3,56 +3,8 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte /// /// Translates methods defined on into PostgreSQL CAST expressions. /// -public class NpgsqlConvertTranslator : IMethodCallTranslator +public class NpgsqlConvertTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly Dictionary TypeMapping = new() - { - [nameof(Convert.ToBoolean)] = "bool", - [nameof(Convert.ToByte)] = "smallint", - [nameof(Convert.ToDecimal)] = "numeric", - [nameof(Convert.ToDouble)] = "double precision", - [nameof(Convert.ToInt16)] = "smallint", - [nameof(Convert.ToInt32)] = "int", - [nameof(Convert.ToInt64)] = "bigint", - [nameof(Convert.ToString)] = "text" - }; - - private static readonly List SupportedTypes = - [ - typeof(bool), - typeof(byte), - typeof(decimal), - typeof(double), - typeof(float), - typeof(int), - typeof(long), - typeof(short), - typeof(string), - typeof(object) - ]; - - private static readonly List SupportedMethods - = TypeMapping.Keys - .SelectMany( - t => typeof(Convert).GetTypeInfo().GetDeclaredMethods(t) - .Where( - m => m.GetParameters().Length == 1 - && SupportedTypes.Contains(m.GetParameters().First().ParameterType))) - .ToList(); - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public NpgsqlConvertTranslator(ISqlExpressionFactory sqlExpressionFactory) - { - _sqlExpressionFactory = sqlExpressionFactory; - } - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -64,7 +16,35 @@ public NpgsqlConvertTranslator(ISqlExpressionFactory sqlExpressionFactory) MethodInfo method, IReadOnlyList arguments, IDiagnosticsLogger logger) - => SupportedMethods.Contains(method) - ? _sqlExpressionFactory.Convert(arguments[0], method.ReturnType) - : null; + { + if (method.DeclaringType != typeof(Convert)) + { + return null; + } + + var isSupported = method.Name is nameof(Convert.ToBoolean) or nameof(Convert.ToByte) or nameof(Convert.ToDecimal) + or nameof(Convert.ToDouble) or nameof(Convert.ToInt16) or nameof(Convert.ToInt32) or nameof(Convert.ToInt64) + or nameof(Convert.ToString); + + if (!isSupported + || arguments is not [var convertArg] + || !IsSupportedType(convertArg.Type)) + { + return null; + } + + return sqlExpressionFactory.Convert(convertArg, method.ReturnType); + } + + private static bool IsSupportedType(Type type) + => type == typeof(bool) + || type == typeof(byte) + || type == typeof(decimal) + || type == typeof(double) + || type == typeof(float) + || type == typeof(int) + || type == typeof(long) + || type == typeof(short) + || type == typeof(string) + || type == typeof(object); } diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlDateTimeMethodTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlDateTimeMethodTranslator.cs index 2cb2a5d22..b67bc124a 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlDateTimeMethodTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlDateTimeMethodTranslator.cs @@ -9,119 +9,17 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class NpgsqlDateTimeMethodTranslator : IMethodCallTranslator +public class NpgsqlDateTimeMethodTranslator( + IRelationalTypeMappingSource typeMappingSource, + NpgsqlSqlExpressionFactory sqlExpressionFactory) + : IMethodCallTranslator { - private static readonly Dictionary MethodInfoDatePartMapping = new() - { - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddYears), [typeof(int)])!, "years" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMonths), [typeof(int)])!, "months" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddDays), [typeof(double)])!, "days" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddHours), [typeof(double)])!, "hours" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMinutes), [typeof(double)])!, "mins" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddSeconds), [typeof(double)])!, "secs" }, - //{ typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMilliseconds), new[] { typeof(double) })!, "milliseconds" }, - - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddYears), [typeof(int)])!, "years" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMonths), [typeof(int)])!, "months" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddDays), [typeof(double)])!, "days" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddHours), [typeof(double)])!, "hours" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMinutes), [typeof(double)])!, "mins" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddSeconds), [typeof(double)])!, "secs" }, - //{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMilliseconds), new[] { typeof(double) })!, "milliseconds" } - - // DateOnly.AddDays, AddMonths and AddYears have a specialized translation, see below - { typeof(TimeOnly).GetRuntimeMethod(nameof(TimeOnly.AddHours), [typeof(int)])!, "hours" }, - { typeof(TimeOnly).GetRuntimeMethod(nameof(TimeOnly.AddMinutes), [typeof(int)])!, "mins" }, - }; - - // ReSharper disable InconsistentNaming - private static readonly MethodInfo DateTime_ToUniversalTime - = typeof(DateTime).GetRuntimeMethod(nameof(DateTime.ToUniversalTime), [])!; - - private static readonly MethodInfo DateTime_ToLocalTime - = typeof(DateTime).GetRuntimeMethod(nameof(DateTime.ToLocalTime), [])!; - - private static readonly MethodInfo DateTime_SpecifyKind - = typeof(DateTime).GetRuntimeMethod(nameof(DateTime.SpecifyKind), [typeof(DateTime), typeof(DateTimeKind)])!; - - private static readonly MethodInfo DateTime_Distance - = typeof(NpgsqlDbFunctionsExtensions).GetRuntimeMethod( - nameof(NpgsqlDbFunctionsExtensions.Distance), [typeof(DbFunctions), typeof(DateTime), typeof(DateTime)])!; - - private static readonly MethodInfo DateOnly_FromDateTime - = typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.FromDateTime), [typeof(DateTime)])!; - - private static readonly MethodInfo DateOnly_ToDateTime - = typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.ToDateTime), [typeof(TimeOnly)])!; - - private static readonly MethodInfo DateOnly_Distance - = typeof(NpgsqlDbFunctionsExtensions).GetRuntimeMethod( - nameof(NpgsqlDbFunctionsExtensions.Distance), [typeof(DbFunctions), typeof(DateOnly), typeof(DateOnly)])!; - - private static readonly MethodInfo DateOnly_AddDays - = typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddDays), [typeof(int)])!; - - private static readonly MethodInfo DateOnly_AddMonths - = typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddMonths), [typeof(int)])!; - - private static readonly MethodInfo DateOnly_AddYears - = typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddYears), [typeof(int)])!; - - private static readonly MethodInfo DateOnly_FromDayNumber - = typeof(DateOnly).GetRuntimeMethod( - nameof(DateOnly.FromDayNumber), [typeof(int)])!; - - private static readonly MethodInfo TimeOnly_FromDateTime - = typeof(TimeOnly).GetRuntimeMethod(nameof(TimeOnly.FromDateTime), [typeof(DateTime)])!; - - private static readonly MethodInfo TimeOnly_FromTimeSpan - = typeof(TimeOnly).GetRuntimeMethod(nameof(TimeOnly.FromTimeSpan), [typeof(TimeSpan)])!; - - private static readonly MethodInfo TimeOnly_ToTimeSpan - = typeof(TimeOnly).GetRuntimeMethod(nameof(TimeOnly.ToTimeSpan), Type.EmptyTypes)!; - - private static readonly MethodInfo TimeOnly_IsBetween - = typeof(TimeOnly).GetRuntimeMethod(nameof(TimeOnly.IsBetween), [typeof(TimeOnly), typeof(TimeOnly)])!; - - private static readonly MethodInfo TimeOnly_Add_TimeSpan - = typeof(TimeOnly).GetRuntimeMethod(nameof(TimeOnly.Add), [typeof(TimeSpan)])!; - - private static readonly MethodInfo TimeZoneInfo_ConvertTimeBySystemTimeZoneId_DateTime - = typeof(TimeZoneInfo).GetRuntimeMethod( - nameof(TimeZoneInfo.ConvertTimeBySystemTimeZoneId), [typeof(DateTime), typeof(string)])!; - - private static readonly MethodInfo TimeZoneInfo_ConvertTimeBySystemTimeZoneId_DateTimeOffset - = typeof(TimeZoneInfo).GetRuntimeMethod( - nameof(TimeZoneInfo.ConvertTimeBySystemTimeZoneId), [typeof(DateTimeOffset), typeof(string)])!; - - private static readonly MethodInfo TimeZoneInfo_ConvertTimeToUtc - = typeof(TimeZoneInfo).GetRuntimeMethod(nameof(TimeZoneInfo.ConvertTimeToUtc), [typeof(DateTime)])!; - // ReSharper restore InconsistentNaming - - private readonly IRelationalTypeMappingSource _typeMappingSource; - private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; - private readonly RelationalTypeMapping _timestampMapping; - private readonly RelationalTypeMapping _timestampTzMapping; - private readonly RelationalTypeMapping _intervalMapping; - private readonly RelationalTypeMapping _textMapping; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public NpgsqlDateTimeMethodTranslator( - IRelationalTypeMappingSource typeMappingSource, - NpgsqlSqlExpressionFactory sqlExpressionFactory) - { - _typeMappingSource = typeMappingSource; - _sqlExpressionFactory = sqlExpressionFactory; - _timestampMapping = typeMappingSource.FindMapping(typeof(DateTime), "timestamp without time zone")!; - _timestampTzMapping = typeMappingSource.FindMapping(typeof(DateTime), "timestamp with time zone")!; - _intervalMapping = typeMappingSource.FindMapping(typeof(TimeSpan), "interval")!; - _textMapping = typeMappingSource.FindMapping("text")!; - } + private readonly IRelationalTypeMappingSource _typeMappingSource = typeMappingSource; + private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory = sqlExpressionFactory; + private readonly RelationalTypeMapping _timestampMapping = typeMappingSource.FindMapping(typeof(DateTime), "timestamp without time zone")!; + private readonly RelationalTypeMapping _timestampTzMapping = typeMappingSource.FindMapping(typeof(DateTime), "timestamp with time zone")!; + private readonly RelationalTypeMapping _intervalMapping = typeMappingSource.FindMapping(typeof(TimeSpan), "interval")!; + private readonly RelationalTypeMapping _textMapping = typeMappingSource.FindMapping("text")!; /// public virtual SqlExpression? Translate( @@ -139,11 +37,28 @@ public NpgsqlDateTimeMethodTranslator( SqlExpression? instance, MethodInfo method, IReadOnlyList arguments) - => instance is not null - && MethodInfoDatePartMapping.TryGetValue(method, out var datePart) - && CreateIntervalExpression(arguments[0], datePart) is SqlExpression interval - ? _sqlExpressionFactory.Add(instance, interval, instance.TypeMapping) - : null; + { + if (instance is null + || (method.DeclaringType != typeof(DateTime) && method.DeclaringType != typeof(DateTimeOffset) && method.DeclaringType != typeof(TimeOnly))) + { + return null; + } + + var datePart = method.Name switch + { + nameof(DateTime.AddYears) => "years", + nameof(DateTime.AddMonths) => "months", + nameof(DateTime.AddDays) => "days", + nameof(DateTime.AddHours) => "hours", + nameof(DateTime.AddMinutes) => "mins", + nameof(DateTime.AddSeconds) => "secs", + _ => null + }; + + return datePart is not null && CreateIntervalExpression(arguments[0], datePart) is SqlExpression interval + ? _sqlExpressionFactory.Add(instance, interval, instance.TypeMapping) + : null; + } private SqlExpression? TranslateDateTime( SqlExpression? instance, @@ -152,7 +67,7 @@ public NpgsqlDateTimeMethodTranslator( { if (instance is null) { - if (method == DateTime_SpecifyKind) + if (method.DeclaringType == typeof(DateTime) && method.Name == nameof(DateTime.SpecifyKind)) { if (arguments[1] is not SqlConstantExpression { Value: DateTimeKind kind }) { @@ -181,19 +96,22 @@ public NpgsqlDateTimeMethodTranslator( } } - if (method == DateTime_Distance) + if (method.DeclaringType == typeof(NpgsqlDbFunctionsExtensions) + && method.Name == nameof(NpgsqlDbFunctionsExtensions.Distance) + && arguments is [_, var dateTime1, var dateTime2] + && dateTime1.Type == typeof(DateTime)) { - return _sqlExpressionFactory.MakePostgresBinary(PgExpressionType.Distance, arguments[1], arguments[2]); + return _sqlExpressionFactory.MakePostgresBinary(PgExpressionType.Distance, dateTime1, dateTime2); } } else { - if (method == DateTime_ToUniversalTime) + if (method.DeclaringType == typeof(DateTime) && method.Name == nameof(DateTime.ToUniversalTime)) { return _sqlExpressionFactory.Convert(instance, method.ReturnType, _timestampTzMapping); } - if (method == DateTime_ToLocalTime) + if (method.DeclaringType == typeof(DateTime) && method.Name == nameof(DateTime.ToLocalTime)) { return _sqlExpressionFactory.Convert(instance, method.ReturnType, _timestampMapping); } @@ -209,7 +127,7 @@ public NpgsqlDateTimeMethodTranslator( { if (instance is null) { - if (method == DateOnly_FromDateTime) + if (method.DeclaringType == typeof(DateOnly) && method.Name == nameof(DateOnly.FromDateTime)) { // Note: converting timestamptz to date performs a timezone conversion, which is not what .NET DateOnly.FromDateTime does. // So if our operand is a timestamptz, we first change the type to timestamp with AT TIME ZONE 'UTC' (returns the same value @@ -226,12 +144,15 @@ public NpgsqlDateTimeMethodTranslator( return _sqlExpressionFactory.Convert(dateTime, typeof(DateOnly), _typeMappingSource.FindMapping(typeof(DateOnly))); } - if (method == DateOnly_Distance) + if (method.DeclaringType == typeof(NpgsqlDbFunctionsExtensions) + && method.Name == nameof(NpgsqlDbFunctionsExtensions.Distance) + && arguments is [_, var dateOnly1, var dateOnly2] + && dateOnly1.Type == typeof(DateOnly)) { - return _sqlExpressionFactory.MakePostgresBinary(PgExpressionType.Distance, arguments[1], arguments[2]); + return _sqlExpressionFactory.MakePostgresBinary(PgExpressionType.Distance, dateOnly1, dateOnly2); } - if (method == DateOnly_FromDayNumber) + if (method.DeclaringType == typeof(DateOnly) && method.Name == nameof(DateOnly.FromDayNumber)) { // We use fragment rather than a DateOnly constant, since 0001-01-01 gets rendered as -infinity by default. // TODO: Set the right type/type mapping after https://github.com/dotnet/efcore/pull/34995 is merged @@ -245,7 +166,7 @@ public NpgsqlDateTimeMethodTranslator( } else { - if (method == DateOnly_ToDateTime) + if (method.DeclaringType == typeof(DateOnly) && method.Name == nameof(DateOnly.ToDateTime)) { return new SqlBinaryExpression( ExpressionType.Add, @@ -256,21 +177,21 @@ public NpgsqlDateTimeMethodTranslator( } // In PG, date + int = date (int interpreted as days) - if (method == DateOnly_AddDays) + if (method.DeclaringType == typeof(DateOnly) && method.Name == nameof(DateOnly.AddDays)) { return _sqlExpressionFactory.Add(instance, arguments[0]); } // For months and years, date + interval yields a timestamp (since interval could have a time component), so we need to cast // the results back to date - if (method == DateOnly_AddMonths + if (method.DeclaringType == typeof(DateOnly) && method.Name == nameof(DateOnly.AddMonths) && CreateIntervalExpression(arguments[0], "months") is SqlExpression interval1) { return _sqlExpressionFactory.Convert( _sqlExpressionFactory.Add(instance, interval1, instance.TypeMapping), typeof(DateOnly)); } - if (method == DateOnly_AddYears + if (method.DeclaringType == typeof(DateOnly) && method.Name == nameof(DateOnly.AddYears) && CreateIntervalExpression(arguments[0], "years") is SqlExpression interval2) { return _sqlExpressionFactory.Convert( @@ -286,7 +207,7 @@ public NpgsqlDateTimeMethodTranslator( MethodInfo method, IReadOnlyList arguments) { - if (method == TimeOnly_FromDateTime) + if (method.DeclaringType == typeof(TimeOnly) && method.Name == nameof(TimeOnly.FromDateTime)) { // Note: converting timestamptz to time performs a timezone conversion, which is not what .NET TimeOnly.FromDateTime does. // So if our operand is a timestamptz, we first change the type to timestamp with AT TIME ZONE 'UTC' (returns the same value @@ -306,28 +227,30 @@ public NpgsqlDateTimeMethodTranslator( _typeMappingSource.FindMapping(typeof(TimeOnly))); } - if (method == TimeOnly_FromTimeSpan) + if (method.DeclaringType == typeof(TimeOnly) && method.Name == nameof(TimeOnly.FromTimeSpan)) { return _sqlExpressionFactory.Convert(arguments[0], typeof(TimeOnly), _typeMappingSource.FindMapping(typeof(TimeOnly))); } if (instance is not null) { - if (method == TimeOnly_ToTimeSpan) + if (method.DeclaringType == typeof(TimeOnly) && method.Name == nameof(TimeOnly.ToTimeSpan)) { return _sqlExpressionFactory.Convert(instance, typeof(TimeSpan), _typeMappingSource.FindMapping(typeof(TimeSpan))); } - if (method == TimeOnly_IsBetween) + if (method.DeclaringType == typeof(TimeOnly) && method.Name == nameof(TimeOnly.IsBetween)) { return _sqlExpressionFactory.And( _sqlExpressionFactory.GreaterThanOrEqual(instance, arguments[0]), _sqlExpressionFactory.LessThan(instance, arguments[1])); } - if (method == TimeOnly_Add_TimeSpan) + if (method.DeclaringType == typeof(TimeOnly) && method.Name == nameof(TimeOnly.Add) + && arguments is [var timeSpan] + && timeSpan.Type == typeof(TimeSpan)) { - return _sqlExpressionFactory.Add(instance, arguments[0]); + return _sqlExpressionFactory.Add(instance, timeSpan); } } @@ -338,9 +261,11 @@ public NpgsqlDateTimeMethodTranslator( MethodInfo method, IReadOnlyList arguments) { - if (method == TimeZoneInfo_ConvertTimeBySystemTimeZoneId_DateTime) + if (method.DeclaringType == typeof(TimeZoneInfo) && method.Name == nameof(TimeZoneInfo.ConvertTimeBySystemTimeZoneId) + && arguments is [var convertDateTime, var timeZoneId] + && convertDateTime.Type == typeof(DateTime)) { - var typeMapping = arguments[0].TypeMapping; + var typeMapping = convertDateTime.TypeMapping; if (typeMapping is null || (typeMapping.StoreType != "timestamp with time zone" && typeMapping.StoreType != "timestamptz")) { @@ -348,12 +273,14 @@ public NpgsqlDateTimeMethodTranslator( "TimeZoneInfo.ConvertTimeBySystemTimeZoneId is only supported on columns with type 'timestamp with time zone'"); } - return _sqlExpressionFactory.AtTimeZone(arguments[0], arguments[1], typeof(DateTime), _timestampMapping); + return _sqlExpressionFactory.AtTimeZone(convertDateTime, timeZoneId, typeof(DateTime), _timestampMapping); } - if (method == TimeZoneInfo_ConvertTimeToUtc) + if (method.DeclaringType == typeof(TimeZoneInfo) && method.Name == nameof(TimeZoneInfo.ConvertTimeToUtc) + && arguments is [var utcDateTime] + && utcDateTime.Type == typeof(DateTime)) { - var typeMapping = arguments[0].TypeMapping; + var typeMapping = utcDateTime.TypeMapping; if (typeMapping is null || (typeMapping.StoreType != "timestamp without time zone" && typeMapping.StoreType != "timestamp")) { @@ -361,7 +288,7 @@ public NpgsqlDateTimeMethodTranslator( "TimeZoneInfo.ConvertTimeToUtc) is only supported on columns with type 'timestamp without time zone'"); } - return _sqlExpressionFactory.Convert(arguments[0], arguments[0].Type, _timestampTzMapping); + return _sqlExpressionFactory.Convert(utcDateTime, utcDateTime.Type, _timestampTzMapping); } return null; diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlFullTextSearchMethodTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlFullTextSearchMethodTranslator.cs index e4db8fba8..13005876b 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlFullTextSearchMethodTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlFullTextSearchMethodTranslator.cs @@ -6,41 +6,19 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte /// /// Provides translations for PostgreSQL full-text search methods. /// -public class NpgsqlFullTextSearchMethodTranslator : IMethodCallTranslator +public class NpgsqlFullTextSearchMethodTranslator( + IRelationalTypeMappingSource typeMappingSource, + NpgsqlSqlExpressionFactory sqlExpressionFactory, + IModel model) + : IMethodCallTranslator { - private static readonly MethodInfo TsQueryParse = - typeof(NpgsqlTsQuery).GetMethod(nameof(NpgsqlTsQuery.Parse), BindingFlags.Public | BindingFlags.Static)!; - - private static readonly MethodInfo TsVectorParse = - typeof(NpgsqlTsVector).GetMethod(nameof(NpgsqlTsVector.Parse), BindingFlags.Public | BindingFlags.Static)!; - - private readonly IRelationalTypeMappingSource _typeMappingSource; - private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; - private readonly IModel _model; - private readonly RelationalTypeMapping _tsQueryMapping; - private readonly RelationalTypeMapping _tsVectorMapping; - private readonly RelationalTypeMapping _regconfigMapping; - private readonly RelationalTypeMapping _regdictionaryMapping; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public NpgsqlFullTextSearchMethodTranslator( - IRelationalTypeMappingSource typeMappingSource, - NpgsqlSqlExpressionFactory sqlExpressionFactory, - IModel model) - { - _typeMappingSource = typeMappingSource; - _sqlExpressionFactory = sqlExpressionFactory; - _model = model; - _tsQueryMapping = typeMappingSource.FindMapping("tsquery")!; - _tsVectorMapping = typeMappingSource.FindMapping("tsvector")!; - _regconfigMapping = typeMappingSource.FindMapping("regconfig")!; - _regdictionaryMapping = typeMappingSource.FindMapping("regdictionary")!; - } + private readonly IRelationalTypeMappingSource _typeMappingSource = typeMappingSource; + private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory = sqlExpressionFactory; + private readonly IModel _model = model; + private readonly RelationalTypeMapping _tsQueryMapping = typeMappingSource.FindMapping("tsquery")!; + private readonly RelationalTypeMapping _tsVectorMapping = typeMappingSource.FindMapping("tsvector")!; + private readonly RelationalTypeMapping _regconfigMapping = typeMappingSource.FindMapping("regconfig")!; + private readonly RelationalTypeMapping _regdictionaryMapping = typeMappingSource.FindMapping("regdictionary")!; /// public virtual SqlExpression? Translate( @@ -49,7 +27,8 @@ public NpgsqlFullTextSearchMethodTranslator( IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (method == TsQueryParse || method == TsVectorParse) + if ((method.DeclaringType == typeof(NpgsqlTsQuery) || method.DeclaringType == typeof(NpgsqlTsVector)) + && method.Name == nameof(NpgsqlTsQuery.Parse)) { return _sqlExpressionFactory.Convert(arguments[0], method.ReturnType); } diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlFuzzyStringMatchMethodTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlFuzzyStringMatchMethodTranslator.cs index b0c6aa651..d66d596c3 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlFuzzyStringMatchMethodTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlFuzzyStringMatchMethodTranslator.cs @@ -6,35 +6,8 @@ /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class NpgsqlFuzzyStringMatchMethodTranslator : IMethodCallTranslator +public class NpgsqlFuzzyStringMatchMethodTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly Dictionary Functions = new() - { - [GetRuntimeMethod(nameof(NpgsqlFuzzyStringMatchDbFunctionsExtensions.FuzzyStringMatchSoundex), typeof(DbFunctions), typeof(string))] - = "soundex", - [GetRuntimeMethod(nameof(NpgsqlFuzzyStringMatchDbFunctionsExtensions.FuzzyStringMatchDifference), typeof(DbFunctions), typeof(string), typeof(string))] - = "difference", - [GetRuntimeMethod(nameof(NpgsqlFuzzyStringMatchDbFunctionsExtensions.FuzzyStringMatchLevenshtein), typeof(DbFunctions), typeof(string), typeof(string))] - = "levenshtein", - [GetRuntimeMethod(nameof(NpgsqlFuzzyStringMatchDbFunctionsExtensions.FuzzyStringMatchLevenshtein), typeof(DbFunctions), typeof(string), typeof(string), typeof(int), typeof(int), typeof(int))] - = "levenshtein", - [GetRuntimeMethod(nameof(NpgsqlFuzzyStringMatchDbFunctionsExtensions.FuzzyStringMatchLevenshteinLessEqual), typeof(DbFunctions), typeof(string), typeof(string), typeof(int))] - = "levenshtein_less_equal", - [GetRuntimeMethod(nameof(NpgsqlFuzzyStringMatchDbFunctionsExtensions.FuzzyStringMatchLevenshteinLessEqual), typeof(DbFunctions), typeof(string), typeof(string), typeof(int), typeof(int), typeof(int), typeof(int))] - = "levenshtein_less_equal", - [GetRuntimeMethod(nameof(NpgsqlFuzzyStringMatchDbFunctionsExtensions.FuzzyStringMatchMetaphone), typeof(DbFunctions), typeof(string), typeof(int))] - = "metaphone", - [GetRuntimeMethod(nameof(NpgsqlFuzzyStringMatchDbFunctionsExtensions.FuzzyStringMatchDoubleMetaphone), typeof(DbFunctions), typeof(string))] - = "dmetaphone", - [GetRuntimeMethod(nameof(NpgsqlFuzzyStringMatchDbFunctionsExtensions.FuzzyStringMatchDoubleMetaphoneAlt), typeof(DbFunctions), typeof(string))] - = "dmetaphone_alt" - }; - - private static MethodInfo GetRuntimeMethod(string name, params Type[] parameters) - => typeof(NpgsqlFuzzyStringMatchDbFunctionsExtensions).GetRuntimeMethod(name, parameters)!; - - private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; - private static readonly bool[][] TrueArrays = [ [], @@ -46,29 +19,37 @@ private static MethodInfo GetRuntimeMethod(string name, params Type[] parameters [true, true, true, true, true, true] ]; - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public NpgsqlFuzzyStringMatchMethodTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory) - { - _sqlExpressionFactory = sqlExpressionFactory; - } - /// public virtual SqlExpression? Translate( SqlExpression? instance, MethodInfo method, IReadOnlyList arguments, IDiagnosticsLogger logger) - => Functions.TryGetValue(method, out var function) - ? _sqlExpressionFactory.Function( + { + if (method.DeclaringType != typeof(NpgsqlFuzzyStringMatchDbFunctionsExtensions)) + { + return null; + } + + var function = method.Name switch + { + nameof(NpgsqlFuzzyStringMatchDbFunctionsExtensions.FuzzyStringMatchSoundex) => "soundex", + nameof(NpgsqlFuzzyStringMatchDbFunctionsExtensions.FuzzyStringMatchDifference) => "difference", + nameof(NpgsqlFuzzyStringMatchDbFunctionsExtensions.FuzzyStringMatchLevenshtein) => "levenshtein", + nameof(NpgsqlFuzzyStringMatchDbFunctionsExtensions.FuzzyStringMatchLevenshteinLessEqual) => "levenshtein_less_equal", + nameof(NpgsqlFuzzyStringMatchDbFunctionsExtensions.FuzzyStringMatchMetaphone) => "metaphone", + nameof(NpgsqlFuzzyStringMatchDbFunctionsExtensions.FuzzyStringMatchDoubleMetaphone) => "dmetaphone", + nameof(NpgsqlFuzzyStringMatchDbFunctionsExtensions.FuzzyStringMatchDoubleMetaphoneAlt) => "dmetaphone_alt", + _ => null + }; + + return function is null + ? null + : sqlExpressionFactory.Function( function, arguments.Skip(1), nullable: true, argumentsPropagateNullability: TrueArrays[arguments.Count - 1], - method.ReturnType) - : null; + method.ReturnType); + } } diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlJsonDomTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlJsonDomTranslator.cs index ec838d19d..9812745d9 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlJsonDomTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlJsonDomTranslator.cs @@ -11,56 +11,16 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class NpgsqlJsonDomTranslator : IMemberTranslator, IMethodCallTranslator +public class NpgsqlJsonDomTranslator( + IRelationalTypeMappingSource typeMappingSource, + NpgsqlSqlExpressionFactory sqlExpressionFactory, + IModel model) + : IMemberTranslator, IMethodCallTranslator { - private static readonly MemberInfo RootElement = typeof(JsonDocument).GetProperty(nameof(JsonDocument.RootElement))!; - - private static readonly MethodInfo GetProperty = typeof(JsonElement).GetRuntimeMethod( - nameof(JsonElement.GetProperty), [typeof(string)])!; - - private static readonly MethodInfo GetArrayLength = typeof(JsonElement).GetRuntimeMethod( - nameof(JsonElement.GetArrayLength), Type.EmptyTypes)!; - - private static readonly MethodInfo ArrayIndexer = typeof(JsonElement).GetProperties() - .Single(p => p.GetIndexParameters().Length == 1 && p.GetIndexParameters()[0].ParameterType == typeof(int)) - .GetMethod!; - - private static readonly string[] GetMethods = - [ - nameof(JsonElement.GetBoolean), - nameof(JsonElement.GetDateTime), - nameof(JsonElement.GetDateTimeOffset), - nameof(JsonElement.GetDecimal), - nameof(JsonElement.GetDouble), - nameof(JsonElement.GetGuid), - nameof(JsonElement.GetInt16), - nameof(JsonElement.GetInt32), - nameof(JsonElement.GetInt64), - nameof(JsonElement.GetSingle), - nameof(JsonElement.GetString) - ]; - - private readonly IRelationalTypeMappingSource _typeMappingSource; - private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; - private readonly RelationalTypeMapping _stringTypeMapping; - private readonly IModel _model; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public NpgsqlJsonDomTranslator( - IRelationalTypeMappingSource typeMappingSource, - NpgsqlSqlExpressionFactory sqlExpressionFactory, - IModel model) - { - _typeMappingSource = typeMappingSource; - _sqlExpressionFactory = sqlExpressionFactory; - _model = model; - _stringTypeMapping = typeMappingSource.FindMapping(typeof(string), model)!; - } + private readonly IRelationalTypeMappingSource _typeMappingSource = typeMappingSource; + private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory = sqlExpressionFactory; + private readonly RelationalTypeMapping _stringTypeMapping = typeMappingSource.FindMapping(typeof(string), model)!; + private readonly IModel _model = model; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -79,7 +39,8 @@ public NpgsqlJsonDomTranslator( return null; } - if (member == RootElement && instance is ColumnExpression { TypeMapping: NpgsqlJsonTypeMapping } column) + if (member.Name == nameof(JsonDocument.RootElement) + && instance is ColumnExpression { TypeMapping: NpgsqlJsonTypeMapping } column) { // Simply get rid of the RootElement member access return column; @@ -114,14 +75,25 @@ public NpgsqlJsonDomTranslator( columnExpression, returnsText: false, typeof(string), mapping) : instance; - if (method == GetProperty || method == ArrayIndexer) + if (method.Name is nameof(JsonElement.GetProperty) or "get_Item" && arguments is [_]) { return instance is PgJsonTraversalExpression prevPathTraversal ? prevPathTraversal.Append(_sqlExpressionFactory.ApplyDefaultTypeMapping(arguments[0])) : null; } - if (GetMethods.Contains(method.Name) && arguments.Count == 0 && instance is PgJsonTraversalExpression traversal) + if (method.Name is nameof(JsonElement.GetBoolean) + or nameof(JsonElement.GetDateTime) + or nameof(JsonElement.GetDateTimeOffset) + or nameof(JsonElement.GetDecimal) + or nameof(JsonElement.GetDouble) + or nameof(JsonElement.GetGuid) + or nameof(JsonElement.GetInt16) + or nameof(JsonElement.GetInt32) + or nameof(JsonElement.GetInt64) + or nameof(JsonElement.GetSingle) + or nameof(JsonElement.GetString) + && arguments.Count == 0 && instance is PgJsonTraversalExpression traversal) { var traversalToText = new PgJsonTraversalExpression( traversal.Expression, @@ -137,7 +109,7 @@ public NpgsqlJsonDomTranslator( traversalToText, method.ReturnType, _typeMappingSource.FindMapping(method.ReturnType, _model)); } - if (method == GetArrayLength) + if (method.Name == nameof(JsonElement.GetArrayLength) && arguments is []) { return _sqlExpressionFactory.Function( mapping.IsJsonb ? "jsonb_array_length" : "json_array_length", diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlJsonPocoTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlJsonPocoTranslator.cs index fbaf19142..d89396d83 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlJsonPocoTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlJsonPocoTranslator.cs @@ -11,33 +11,16 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class NpgsqlJsonPocoTranslator : IMemberTranslator, IMethodCallTranslator +public class NpgsqlJsonPocoTranslator( + IRelationalTypeMappingSource typeMappingSource, + NpgsqlSqlExpressionFactory sqlExpressionFactory, + IModel model) + : IMemberTranslator, IMethodCallTranslator { - private readonly IRelationalTypeMappingSource _typeMappingSource; - private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; - private readonly RelationalTypeMapping _stringTypeMapping; - private readonly IModel _model; - - private static readonly MethodInfo Enumerable_AnyWithoutPredicate = - typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly) - .Single(mi => mi.Name == nameof(Enumerable.Any) && mi.GetParameters().Length == 1); - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public NpgsqlJsonPocoTranslator( - IRelationalTypeMappingSource typeMappingSource, - NpgsqlSqlExpressionFactory sqlExpressionFactory, - IModel model) - { - _typeMappingSource = typeMappingSource; - _sqlExpressionFactory = sqlExpressionFactory; - _model = model; - _stringTypeMapping = typeMappingSource.FindMapping(typeof(string), model)!; - } + private readonly IRelationalTypeMappingSource _typeMappingSource = typeMappingSource; + private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory = sqlExpressionFactory; + private readonly RelationalTypeMapping _stringTypeMapping = typeMappingSource.FindMapping(typeof(string), model)!; + private readonly IModel _model = model; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -50,16 +33,13 @@ public NpgsqlJsonPocoTranslator( MethodInfo method, IReadOnlyList arguments, IDiagnosticsLogger logger) - { // Predicate-less Any - translate to a simple length check. - if (method.IsClosedFormOf(Enumerable_AnyWithoutPredicate) - && TranslateArrayLength(arguments[0]) is SqlExpression arrayLengthTranslation) - { - return _sqlExpressionFactory.GreaterThan(arrayLengthTranslation, _sqlExpressionFactory.Constant(0)); - } - - return null; - } + => method is { IsGenericMethod: true, Name: nameof(Enumerable.Any) } + && method.DeclaringType == typeof(Enumerable) + && arguments is [var source] + && TranslateArrayLength(source) is SqlExpression arrayLengthTranslation + ? _sqlExpressionFactory.GreaterThan(arrayLengthTranslation, _sqlExpressionFactory.Constant(0)) + : null; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlLikeTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlLikeTranslator.cs index 6aba0252d..26512ba50 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlLikeTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlLikeTranslator.cs @@ -3,41 +3,8 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte /// /// Translates methods into PostgreSQL LIKE expressions. /// -public class NpgsqlLikeTranslator : IMethodCallTranslator +public class NpgsqlLikeTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly MethodInfo Like = - typeof(DbFunctionsExtensions).GetRuntimeMethod( - nameof(DbFunctionsExtensions.Like), - [typeof(DbFunctions), typeof(string), typeof(string)])!; - - private static readonly MethodInfo LikeWithEscape = - typeof(DbFunctionsExtensions).GetRuntimeMethod( - nameof(DbFunctionsExtensions.Like), - [typeof(DbFunctions), typeof(string), typeof(string), typeof(string)])!; - - // ReSharper disable once InconsistentNaming - private static readonly MethodInfo ILike = - typeof(NpgsqlDbFunctionsExtensions).GetRuntimeMethod( - nameof(NpgsqlDbFunctionsExtensions.ILike), - [typeof(DbFunctions), typeof(string), typeof(string)])!; - - // ReSharper disable once InconsistentNaming - private static readonly MethodInfo ILikeWithEscape = - typeof(NpgsqlDbFunctionsExtensions).GetRuntimeMethod( - nameof(NpgsqlDbFunctionsExtensions.ILike), - [typeof(DbFunctions), typeof(string), typeof(string), typeof(string)])!; - - private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; - - /// - /// Initializes a new instance of the class. - /// - /// The SQL expression factory to use when generating expressions.. - public NpgsqlLikeTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory) - { - _sqlExpressionFactory = sqlExpressionFactory; - } - /// public virtual SqlExpression? Translate( SqlExpression? instance, @@ -45,22 +12,14 @@ public NpgsqlLikeTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory) IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (method == LikeWithEscape) - { - return _sqlExpressionFactory.Like(arguments[1], arguments[2], arguments[3]); - } - - if (method == ILikeWithEscape) - { - return _sqlExpressionFactory.ILike(arguments[1], arguments[2], arguments[3]); - } - bool sensitive; - if (method == Like) + if (method.DeclaringType == typeof(DbFunctionsExtensions) + && method.Name == nameof(DbFunctionsExtensions.Like)) { sensitive = true; } - else if (method == ILike) + else if (method.DeclaringType == typeof(NpgsqlDbFunctionsExtensions) + && method.Name == nameof(NpgsqlDbFunctionsExtensions.ILike)) { sensitive = false; } @@ -69,6 +28,14 @@ public NpgsqlLikeTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory) return null; } + // The 4-argument overloads have an escape char parameter + if (arguments is [_, _, _, var escapeChar]) + { + return sensitive + ? sqlExpressionFactory.Like(arguments[1], arguments[2], escapeChar) + : sqlExpressionFactory.ILike(arguments[1], arguments[2], escapeChar); + } + // PostgreSQL has backslash as the default LIKE escape character, but EF Core expects // no escape character unless explicitly requested (https://github.com/aspnet/EntityFramework/issues/8696). @@ -82,12 +49,12 @@ public NpgsqlLikeTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory) && !patternValue.Contains('\\')) { return sensitive - ? _sqlExpressionFactory.Like(match, pattern) - : _sqlExpressionFactory.ILike(match, pattern); + ? sqlExpressionFactory.Like(match, pattern) + : sqlExpressionFactory.ILike(match, pattern); } return sensitive - ? _sqlExpressionFactory.Like(match, pattern, _sqlExpressionFactory.Constant(string.Empty)) - : _sqlExpressionFactory.ILike(match, pattern, _sqlExpressionFactory.Constant(string.Empty)); + ? sqlExpressionFactory.Like(match, pattern, sqlExpressionFactory.Constant(string.Empty)) + : sqlExpressionFactory.ILike(match, pattern, sqlExpressionFactory.Constant(string.Empty)); } } diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMathTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMathTranslator.cs index 07919734b..c8bab43fe 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMathTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMathTranslator.cs @@ -12,143 +12,13 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte /// - https://www.postgresql.org/docs/current/static/functions-math.html /// - https://www.postgresql.org/docs/current/static/functions-conditional.html#FUNCTIONS-GREATEST-LEAST /// -public class NpgsqlMathTranslator : IMethodCallTranslator +public class NpgsqlMathTranslator( + IRelationalTypeMappingSource typeMappingSource, + ISqlExpressionFactory sqlExpressionFactory, + IModel model) : IMethodCallTranslator { - private static readonly Dictionary SupportedMethodTranslations = new() - { - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(decimal)])!, "abs" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(double)])!, "abs" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(float)])!, "abs" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(int)])!, "abs" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(long)])!, "abs" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(short)])!, "abs" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Abs), [typeof(float)])!, "abs" }, - { typeof(BigInteger).GetRuntimeMethod(nameof(BigInteger.Abs), [typeof(BigInteger)])!, "abs" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), [typeof(decimal)])!, "ceiling" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), [typeof(double)])!, "ceiling" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Ceiling), [typeof(float)])!, "ceiling" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Floor), [typeof(decimal)])!, "floor" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Floor), [typeof(double)])!, "floor" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Floor), [typeof(float)])!, "floor" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Pow), [typeof(double), typeof(double)])!, "power" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Pow), [typeof(float), typeof(float)])!, "power" }, - { typeof(BigInteger).GetRuntimeMethod(nameof(BigInteger.Pow), [typeof(BigInteger), typeof(int)])!, "power" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Exp), [typeof(double)])!, "exp" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Exp), [typeof(float)])!, "exp" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Log10), [typeof(double)])!, "log" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Log10), [typeof(float)])!, "log" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Log), [typeof(double)])!, "ln" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Log), [typeof(float)])!, "ln" }, - // Note: PostgreSQL has log(x,y) but only for decimal, whereas .NET has it only for double/float - - { typeof(Math).GetRuntimeMethod(nameof(Math.Sqrt), [typeof(double)])!, "sqrt" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Sqrt), [typeof(float)])!, "sqrt" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Acos), [typeof(double)])!, "acos" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Acos), [typeof(float)])!, "acos" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Asin), [typeof(double)])!, "asin" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Asin), [typeof(float)])!, "asin" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Atan), [typeof(double)])!, "atan" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan), [typeof(float)])!, "atan" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Atan2), [typeof(double), typeof(double)])!, "atan2" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan2), [typeof(float), typeof(float)])!, "atan2" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Cos), [typeof(double)])!, "cos" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Cos), [typeof(float)])!, "cos" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sin), [typeof(double)])!, "sin" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Sin), [typeof(float)])!, "sin" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Tan), [typeof(double)])!, "tan" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Tan), [typeof(float)])!, "tan" }, - { typeof(double).GetRuntimeMethod(nameof(double.DegreesToRadians), [typeof(double)])!, "radians" }, - { typeof(float).GetRuntimeMethod(nameof(float.DegreesToRadians), [typeof(float)])!, "radians" }, - { typeof(double).GetRuntimeMethod(nameof(double.RadiansToDegrees), [typeof(double)])!, "degrees" }, - { typeof(float).GetRuntimeMethod(nameof(float.RadiansToDegrees), [typeof(float)])!, "degrees" }, - - // https://www.postgresql.org/docs/current/functions-conditional.html#FUNCTIONS-GREATEST-LEAST - { typeof(Math).GetRuntimeMethod(nameof(Math.Max), [typeof(decimal), typeof(decimal)])!, "GREATEST" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Max), [typeof(double), typeof(double)])!, "GREATEST" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Max), [typeof(float), typeof(float)])!, "GREATEST" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Max), [typeof(int), typeof(int)])!, "GREATEST" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Max), [typeof(long), typeof(long)])!, "GREATEST" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Max), [typeof(short), typeof(short)])!, "GREATEST" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Max), [typeof(float), typeof(float)])!, "GREATEST" }, - { typeof(BigInteger).GetRuntimeMethod(nameof(BigInteger.Max), [typeof(BigInteger), typeof(BigInteger)])!, "GREATEST" }, - - // https://www.postgresql.org/docs/current/functions-conditional.html#FUNCTIONS-GREATEST-LEAST - { typeof(Math).GetRuntimeMethod(nameof(Math.Min), [typeof(decimal), typeof(decimal)])!, "LEAST" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Min), [typeof(double), typeof(double)])!, "LEAST" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Min), [typeof(float), typeof(float)])!, "LEAST" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Min), [typeof(int), typeof(int)])!, "LEAST" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Min), [typeof(long), typeof(long)])!, "LEAST" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Min), [typeof(short), typeof(short)])!, "LEAST" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Min), [typeof(float), typeof(float)])!, "LEAST" }, - { typeof(BigInteger).GetRuntimeMethod(nameof(BigInteger.Min), [typeof(BigInteger), typeof(BigInteger)])!, "LEAST" }, - }; - - private static readonly IEnumerable TruncateMethodInfos = - [ - typeof(Math).GetRequiredRuntimeMethod(nameof(Math.Truncate), typeof(decimal)), - typeof(Math).GetRequiredRuntimeMethod(nameof(Math.Truncate), typeof(double)), - typeof(MathF).GetRequiredRuntimeMethod(nameof(MathF.Truncate), typeof(float)) - ]; - - private static readonly IEnumerable RoundMethodInfos = - [ - typeof(Math).GetRequiredRuntimeMethod(nameof(Math.Round), typeof(decimal)), - typeof(Math).GetRequiredRuntimeMethod(nameof(Math.Round), typeof(double)), - typeof(MathF).GetRequiredRuntimeMethod(nameof(MathF.Round), typeof(float)) - ]; - - private static readonly IEnumerable SignMethodInfos = - [ - typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(decimal)])!, - typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(double)])!, - typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(float)])!, - typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(int)])!, - typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(long)])!, - typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(sbyte)])!, - typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(short)])!, - typeof(MathF).GetRuntimeMethod(nameof(MathF.Sign), [typeof(float)])! - ]; - - private static readonly MethodInfo RoundDecimalTwoParams - = typeof(Math).GetRuntimeMethod(nameof(Math.Round), [typeof(decimal), typeof(int)])!; - - private static readonly MethodInfo DoubleIsNanMethodInfo - = typeof(double).GetRuntimeMethod(nameof(double.IsNaN), [typeof(double)])!; - - private static readonly MethodInfo DoubleIsPositiveInfinityMethodInfo - = typeof(double).GetRuntimeMethod(nameof(double.IsPositiveInfinity), [typeof(double)])!; - - private static readonly MethodInfo DoubleIsNegativeInfinityMethodInfo - = typeof(double).GetRuntimeMethod(nameof(double.IsNegativeInfinity), [typeof(double)])!; - - private static readonly MethodInfo FloatIsNanMethodInfo - = typeof(float).GetRuntimeMethod(nameof(float.IsNaN), [typeof(float)])!; - - private static readonly MethodInfo FloatIsPositiveInfinityMethodInfo - = typeof(float).GetRuntimeMethod(nameof(float.IsPositiveInfinity), [typeof(float)])!; - - private static readonly MethodInfo FloatIsNegativeInfinityMethodInfo - = typeof(float).GetRuntimeMethod(nameof(float.IsNegativeInfinity), [typeof(float)])!; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - private readonly RelationalTypeMapping _intTypeMapping; - private readonly RelationalTypeMapping _decimalTypeMapping; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public NpgsqlMathTranslator( - IRelationalTypeMappingSource typeMappingSource, - ISqlExpressionFactory sqlExpressionFactory, - IModel model) - { - _sqlExpressionFactory = sqlExpressionFactory; - _intTypeMapping = typeMappingSource.FindMapping(typeof(int), model)!; - _decimalTypeMapping = typeMappingSource.FindMapping(typeof(decimal), model)!; - } + private readonly RelationalTypeMapping _intTypeMapping = typeMappingSource.FindMapping(typeof(int), model)!; + private readonly RelationalTypeMapping _decimalTypeMapping = typeMappingSource.FindMapping(typeof(decimal), model)!; /// public virtual SqlExpression? Translate( @@ -157,38 +27,151 @@ public NpgsqlMathTranslator( IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (SupportedMethodTranslations.TryGetValue(method, out var sqlFunctionName)) - { - var typeMapping = arguments.Count == 1 - ? ExpressionExtensions.InferTypeMapping(arguments[0]) - : ExpressionExtensions.InferTypeMapping(arguments[0], arguments[1]); + var declaringType = method.DeclaringType; - var newArguments = new SqlExpression[arguments.Count]; - newArguments[0] = _sqlExpressionFactory.ApplyTypeMapping(arguments[0], typeMapping); - - if (arguments.Count == 2) - { - newArguments[1] = _sqlExpressionFactory.ApplyTypeMapping(arguments[1], typeMapping); - } + if (declaringType != typeof(Math) + && declaringType != typeof(MathF) + && declaringType != typeof(BigInteger) + && declaringType != typeof(double) + && declaringType != typeof(float)) + { + return null; + } - // Note: GREATER/LEAST only return NULL if *all* arguments are null, but we currently can't - // convey this. - return _sqlExpressionFactory.Function( + return method.Name switch + { + nameof(Math.Abs) when arguments is [var arg] + && (arg.Type == typeof(decimal) || arg.Type == typeof(double) || arg.Type == typeof(float) + || arg.Type == typeof(int) || arg.Type == typeof(long) || arg.Type == typeof(short) + || arg.Type == typeof(BigInteger)) + => TranslateFunction("abs", arg), + nameof(Math.Ceiling) when arguments is [var arg] + && (arg.Type == typeof(decimal) || arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("ceiling", arg), + nameof(Math.Floor) when arguments is [var arg] + && (arg.Type == typeof(decimal) || arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("floor", arg), + nameof(Math.Pow) when arguments is [var arg1, var arg2] + && (arg1.Type == typeof(double) || arg1.Type == typeof(float) || arg1.Type == typeof(BigInteger)) + => TranslateBinaryFunction("power", arg1, arg2), + nameof(Math.Exp) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("exp", arg), + nameof(Math.Log10) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("log", arg), + nameof(Math.Log) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("ln", arg), + // Note: PostgreSQL has log(x,y) but only for decimal, whereas .NET has it only for double/float + nameof(Math.Sqrt) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("sqrt", arg), + nameof(Math.Acos) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("acos", arg), + nameof(Math.Asin) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("asin", arg), + nameof(Math.Atan) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("atan", arg), + nameof(Math.Atan2) when arguments is [var arg1, var arg2] + && (arg1.Type == typeof(double) || arg1.Type == typeof(float)) + => TranslateBinaryFunction("atan2", arg1, arg2), + nameof(Math.Cos) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("cos", arg), + nameof(Math.Sin) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("sin", arg), + nameof(Math.Tan) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("tan", arg), + nameof(double.DegreesToRadians) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("radians", arg), + nameof(double.RadiansToDegrees) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("degrees", arg), + // https://www.postgresql.org/docs/current/functions-conditional.html#FUNCTIONS-GREATEST-LEAST + nameof(Math.Max) when arguments is [var arg1, var arg2] + && (arg1.Type == typeof(decimal) || arg1.Type == typeof(double) || arg1.Type == typeof(float) + || arg1.Type == typeof(int) || arg1.Type == typeof(long) || arg1.Type == typeof(short) + || arg1.Type == typeof(BigInteger)) + => TranslateBinaryFunction("GREATEST", arg1, arg2), + nameof(Math.Min) when arguments is [var arg1, var arg2] + && (arg1.Type == typeof(decimal) || arg1.Type == typeof(double) || arg1.Type == typeof(float) + || arg1.Type == typeof(int) || arg1.Type == typeof(long) || arg1.Type == typeof(short) + || arg1.Type == typeof(BigInteger)) + => TranslateBinaryFunction("LEAST", arg1, arg2), + + nameof(Math.Truncate) when arguments is [var arg] + && (arg.Type == typeof(decimal) || arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateTruncate(arg), + nameof(Math.Round) when arguments is [var arg] + && (arg.Type == typeof(decimal) || arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateRound(arg), + nameof(Math.Round) when arguments is [var arg, var digits] + && arg.Type == typeof(decimal) && digits.Type == typeof(int) + => TranslateRoundWithDigits(arg, digits), + nameof(Math.Sign) when arguments is [var arg] + && (arg.Type == typeof(decimal) || arg.Type == typeof(double) || arg.Type == typeof(float) + || arg.Type == typeof(int) || arg.Type == typeof(long) || arg.Type == typeof(sbyte) || arg.Type == typeof(short)) + => TranslateSign(arg), + + // PostgreSQL treats NaN values as equal, against IEEE754 + nameof(double.IsNaN) when arguments is [var arg] + => sqlExpressionFactory.Equal( + arg, + sqlExpressionFactory.Constant(declaringType == typeof(double) ? double.NaN : (object)float.NaN)), + nameof(double.IsPositiveInfinity) when arguments is [var arg] + => sqlExpressionFactory.Equal( + arg, + sqlExpressionFactory.Constant( + declaringType == typeof(double) ? double.PositiveInfinity : (object)float.PositiveInfinity)), + nameof(double.IsNegativeInfinity) when arguments is [var arg] + => sqlExpressionFactory.Equal( + arg, + sqlExpressionFactory.Constant( + declaringType == typeof(double) ? double.NegativeInfinity : (object)float.NegativeInfinity)), + + _ => null + }; + + SqlExpression TranslateFunction(string sqlFunctionName, SqlExpression arg) + { + var typeMapping = ExpressionExtensions.InferTypeMapping(arg); + return sqlExpressionFactory.Function( sqlFunctionName, - newArguments, + [sqlExpressionFactory.ApplyTypeMapping(arg, typeMapping)], nullable: true, - argumentsPropagateNullability: TrueArrays[newArguments.Length], + argumentsPropagateNullability: TrueArrays[1], method.ReturnType, typeMapping); } - if (TruncateMethodInfos.Contains(method)) + SqlExpression TranslateBinaryFunction(string sqlFunctionName, SqlExpression arg1, SqlExpression arg2) { - var argument = arguments[0]; + var typeMapping = ExpressionExtensions.InferTypeMapping(arg1, arg2); + // Note: GREATEST/LEAST only return NULL if *all* arguments are null, but we currently can't convey this. + return sqlExpressionFactory.Function( + sqlFunctionName, + [ + sqlExpressionFactory.ApplyTypeMapping(arg1, typeMapping), + sqlExpressionFactory.ApplyTypeMapping(arg2, typeMapping) + ], + nullable: true, + argumentsPropagateNullability: TrueArrays[2], + method.ReturnType, + typeMapping); + } - // C# has Round over decimal/double/float only so our argument will be one of those types (compiler puts convert node) + SqlExpression TranslateTruncate(SqlExpression argument) + { + // C# has Truncate over decimal/double/float only so our argument will be one of those types (compiler puts convert node) // In database result will be same type except for float which returns double which we need to cast back to float. - var result = _sqlExpressionFactory.Function( + var result = sqlExpressionFactory.Function( "trunc", [argument], nullable: true, @@ -197,19 +180,17 @@ public NpgsqlMathTranslator( if (argument.Type == typeof(float)) { - result = _sqlExpressionFactory.Convert(result, typeof(float)); + result = sqlExpressionFactory.Convert(result, typeof(float)); } - return _sqlExpressionFactory.ApplyTypeMapping(result, argument.TypeMapping); + return sqlExpressionFactory.ApplyTypeMapping(result, argument.TypeMapping); } - if (RoundMethodInfos.Contains(method)) + SqlExpression TranslateRound(SqlExpression argument) { - var argument = arguments[0]; - // C# has Round over decimal/double/float only so our argument will be one of those types (compiler puts convert node) // In database result will be same type except for float which returns double which we need to cast back to float. - var result = _sqlExpressionFactory.Function( + var result = sqlExpressionFactory.Function( "round", [argument], nullable: true, @@ -218,73 +199,35 @@ public NpgsqlMathTranslator( if (argument.Type == typeof(float)) { - result = _sqlExpressionFactory.Convert(result, typeof(float)); + result = sqlExpressionFactory.Convert(result, typeof(float)); } - return _sqlExpressionFactory.ApplyTypeMapping(result, argument.TypeMapping); - } - - // PostgreSQL sign() returns 1, 0, -1, but in the same type as the argument, so we need to convert - // the return type to int. - if (SignMethodInfos.Contains(method)) - { - return - _sqlExpressionFactory.Convert( - _sqlExpressionFactory.Function( - "sign", - arguments, - nullable: true, - argumentsPropagateNullability: TrueArrays[1], - method.ReturnType), - typeof(int), - _intTypeMapping); + return sqlExpressionFactory.ApplyTypeMapping(result, argument.TypeMapping); } - if (method == RoundDecimalTwoParams) - { - return _sqlExpressionFactory.Function( + SqlExpression TranslateRoundWithDigits(SqlExpression argument, SqlExpression digits) + => sqlExpressionFactory.Function( "round", [ - _sqlExpressionFactory.ApplyDefaultTypeMapping(arguments[0]), - _sqlExpressionFactory.ApplyDefaultTypeMapping(arguments[1]) + sqlExpressionFactory.ApplyDefaultTypeMapping(argument), + sqlExpressionFactory.ApplyDefaultTypeMapping(digits) ], nullable: true, argumentsPropagateNullability: TrueArrays[2], method.ReturnType, _decimalTypeMapping); - } - - // PostgreSQL treats NaN values as equal, against IEEE754 - if (method == DoubleIsNanMethodInfo) - { - return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(double.NaN)); - } - - if (method == FloatIsNanMethodInfo) - { - return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(float.NaN)); - } - - if (method == DoubleIsPositiveInfinityMethodInfo) - { - return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(double.PositiveInfinity)); - } - - if (method == FloatIsPositiveInfinityMethodInfo) - { - return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(float.PositiveInfinity)); - } - if (method == DoubleIsNegativeInfinityMethodInfo) - { - return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(double.NegativeInfinity)); - } - - if (method == FloatIsNegativeInfinityMethodInfo) - { - return _sqlExpressionFactory.Equal(arguments[0], _sqlExpressionFactory.Constant(float.NegativeInfinity)); - } - - return null; + // PostgreSQL sign() returns 1, 0, -1, but in the same type as the argument, so we need to convert + // the return type to int. + SqlExpression TranslateSign(SqlExpression argument) + => sqlExpressionFactory.Convert( + sqlExpressionFactory.Function( + "sign", + [argument], + nullable: true, + argumentsPropagateNullability: TrueArrays[1], + method.ReturnType), + typeof(int), + _intTypeMapping); } } diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMiscAggregateMethodTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMiscAggregateMethodTranslator.cs index b938ecb96..920455001 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMiscAggregateMethodTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlMiscAggregateMethodTranslator.cs @@ -9,33 +9,15 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class NpgsqlMiscAggregateMethodTranslator : IAggregateMethodCallTranslator +public class NpgsqlMiscAggregateMethodTranslator( + NpgsqlSqlExpressionFactory sqlExpressionFactory, + IRelationalTypeMappingSource typeMappingSource, + IModel model) + : IAggregateMethodCallTranslator { - private static readonly MethodInfo StringJoin - = typeof(string).GetRuntimeMethod(nameof(string.Join), [typeof(string), typeof(IEnumerable)])!; - - private static readonly MethodInfo StringConcat - = typeof(string).GetRuntimeMethod(nameof(string.Concat), [typeof(IEnumerable)])!; - - private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; - private readonly IRelationalTypeMappingSource _typeMappingSource; - private readonly IModel _model; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public NpgsqlMiscAggregateMethodTranslator( - NpgsqlSqlExpressionFactory sqlExpressionFactory, - IRelationalTypeMappingSource typeMappingSource, - IModel model) - { - _sqlExpressionFactory = sqlExpressionFactory; - _typeMappingSource = typeMappingSource; - _model = model; - } + private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory = sqlExpressionFactory; + private readonly IRelationalTypeMappingSource _typeMappingSource = typeMappingSource; + private readonly IModel _model = model; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -56,8 +38,21 @@ public NpgsqlMiscAggregateMethodTranslator( return null; } - if (method == StringJoin || method == StringConcat) + if (method.DeclaringType == typeof(string)) { + SqlExpression separator; + switch (method.Name) + { + case nameof(string.Concat) when arguments is []: + separator = _sqlExpressionFactory.Constant(string.Empty, typeof(string)); + break; + case nameof(string.Join) when arguments is [var sep]: + separator = sep; + break; + default: + return null; + } + // string_agg filters out nulls, but string.Join treats them as empty strings; coalesce unless we know we're aggregating over // a non-nullable column. if (sqlExpression is not ColumnExpression { IsNullable: false }) @@ -73,7 +68,7 @@ public NpgsqlMiscAggregateMethodTranslator( "string_agg", [ sqlExpression, - method == StringJoin ? arguments[0] : _sqlExpressionFactory.Constant(string.Empty, typeof(string)) + separator ], source, nullable: true, diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlNetworkTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlNetworkTranslator.cs index 5dedf94da..0ab7543f3 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlNetworkTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlNetworkTranslator.cs @@ -11,38 +11,18 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte /// /// See: https://www.postgresql.org/docs/current/static/functions-net.html /// -public class NpgsqlNetworkTranslator : IMethodCallTranslator +public class NpgsqlNetworkTranslator( + IRelationalTypeMappingSource typeMappingSource, + NpgsqlSqlExpressionFactory sqlExpressionFactory, + IModel model) + : IMethodCallTranslator { - private static readonly MethodInfo IPAddressParse = - typeof(IPAddress).GetRuntimeMethod(nameof(IPAddress.Parse), [typeof(string)])!; + private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory = sqlExpressionFactory; - private static readonly MethodInfo PhysicalAddressParse = - typeof(PhysicalAddress).GetRuntimeMethod(nameof(PhysicalAddress.Parse), [typeof(string)])!; - - private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; - - private readonly RelationalTypeMapping _inetMapping; - private readonly RelationalTypeMapping _cidrMapping; - private readonly RelationalTypeMapping _macaddr8Mapping; - private readonly RelationalTypeMapping _longAddressMapping; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public NpgsqlNetworkTranslator( - IRelationalTypeMappingSource typeMappingSource, - NpgsqlSqlExpressionFactory sqlExpressionFactory, - IModel model) - { - _sqlExpressionFactory = sqlExpressionFactory; - _inetMapping = typeMappingSource.FindMapping("inet")!; - _cidrMapping = typeMappingSource.FindMapping("cidr")!; - _macaddr8Mapping = typeMappingSource.FindMapping("macaddr8")!; - _longAddressMapping = typeMappingSource.FindMapping(typeof(long), model)!; - } + private readonly RelationalTypeMapping _inetMapping = typeMappingSource.FindMapping("inet")!; + private readonly RelationalTypeMapping _cidrMapping = typeMappingSource.FindMapping("cidr")!; + private readonly RelationalTypeMapping _macaddr8Mapping = typeMappingSource.FindMapping("macaddr8")!; + private readonly RelationalTypeMapping _longAddressMapping = typeMappingSource.FindMapping(typeof(long), model)!; /// public virtual SqlExpression? Translate( @@ -51,17 +31,24 @@ public NpgsqlNetworkTranslator( IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (method == IPAddressParse) + if (method.DeclaringType == typeof(IPAddress) + && method.Name == nameof(IPAddress.Parse) + && arguments is [var ipAddressArg] + && ipAddressArg.Type == typeof(string)) { - return _sqlExpressionFactory.Convert(arguments[0], typeof(IPAddress)); + return _sqlExpressionFactory.Convert(ipAddressArg, typeof(IPAddress)); } - if (method == PhysicalAddressParse) + if (method.DeclaringType == typeof(PhysicalAddress) + && method.Name == nameof(PhysicalAddress.Parse) + && arguments is [var physicalAddressArg] + && physicalAddressArg.Type == typeof(string)) { - return _sqlExpressionFactory.Convert(arguments[0], typeof(PhysicalAddress)); + return _sqlExpressionFactory.Convert(physicalAddressArg, typeof(PhysicalAddress)); } - if (method.DeclaringType == typeof(NpgsqlNetworkDbFunctionsExtensions)) + if (method.DeclaringType == typeof(NpgsqlNetworkDbFunctionsExtensions) + && arguments is [_, var networkArg, ..]) { var paramType = method.GetParameters()[1].ParameterType; diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlRandomTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlRandomTranslator.cs index 3368c9a65..2a046ff40 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlRandomTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlRandomTranslator.cs @@ -6,24 +6,8 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class NpgsqlRandomTranslator : IMethodCallTranslator +public class NpgsqlRandomTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly MethodInfo _methodInfo - = typeof(DbFunctionsExtensions).GetRuntimeMethod(nameof(DbFunctionsExtensions.Random), [typeof(DbFunctions)])!; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public NpgsqlRandomTranslator(ISqlExpressionFactory sqlExpressionFactory) - { - _sqlExpressionFactory = sqlExpressionFactory; - } - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -35,18 +19,13 @@ public NpgsqlRandomTranslator(ISqlExpressionFactory sqlExpressionFactory) MethodInfo method, IReadOnlyList arguments, IDiagnosticsLogger logger) - { - Check.NotNull(method, nameof(method)); - Check.NotNull(arguments, nameof(arguments)); - Check.NotNull(logger, nameof(logger)); - - return _methodInfo.Equals(method) - ? _sqlExpressionFactory.Function( + => method.DeclaringType == typeof(DbFunctionsExtensions) + && method.Name == nameof(DbFunctionsExtensions.Random) + ? sqlExpressionFactory.Function( "random", [], nullable: false, argumentsPropagateNullability: [], method.ReturnType) : null; - } } diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlRangeTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlRangeTranslator.cs index 2b6747fdf..3a722b1d7 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlRangeTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlRangeTranslator.cs @@ -18,10 +18,6 @@ public class NpgsqlRangeTranslator : IMethodCallTranslator, IMemberTranslator private readonly IModel _model; private readonly bool _supportsMultiranges; - private static readonly MethodInfo EnumerableAnyWithoutPredicate = - typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly) - .Single(mi => mi.Name == nameof(Enumerable.Any) && mi.GetParameters().Length == 1); - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -50,13 +46,14 @@ public NpgsqlRangeTranslator( // Any() over multirange -> NOT isempty(). NpgsqlRange has IsEmpty which is translated below. if (_supportsMultiranges && method.IsGenericMethod - && method.GetGenericMethodDefinition() == EnumerableAnyWithoutPredicate - && arguments[0].IsMultirange()) + && method.DeclaringType == typeof(Enumerable) && method.Name == nameof(Enumerable.Any) + && arguments is [var multirange] + && multirange.IsMultirange()) { return _sqlExpressionFactory.Not( _sqlExpressionFactory.Function( "isempty", - [arguments[0]], + [multirange], nullable: true, argumentsPropagateNullability: TrueArrays[1], typeof(bool))); diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlRowValueTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlRowValueTranslator.cs index f1e6f3aeb..a16c79ee0 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlRowValueTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlRowValueTranslator.cs @@ -11,43 +11,8 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class NpgsqlRowValueTranslator : IMethodCallTranslator +public class NpgsqlRowValueTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; - - private static readonly MethodInfo GreaterThan = - typeof(NpgsqlDbFunctionsExtensions).GetRuntimeMethod( - nameof(NpgsqlDbFunctionsExtensions.GreaterThan), - [typeof(DbFunctions), typeof(ITuple), typeof(ITuple)])!; - - private static readonly MethodInfo LessThan = - typeof(NpgsqlDbFunctionsExtensions).GetMethods() - .Single(m => m.Name == nameof(NpgsqlDbFunctionsExtensions.LessThan)); - - private static readonly MethodInfo GreaterThanOrEqual = - typeof(NpgsqlDbFunctionsExtensions).GetMethods() - .Single(m => m.Name == nameof(NpgsqlDbFunctionsExtensions.GreaterThanOrEqual)); - - private static readonly MethodInfo LessThanOrEqual = - typeof(NpgsqlDbFunctionsExtensions).GetMethods() - .Single(m => m.Name == nameof(NpgsqlDbFunctionsExtensions.LessThanOrEqual)); - - private static readonly Dictionary ComparisonMethods = new() - { - { GreaterThan, ExpressionType.GreaterThan }, - { LessThan, ExpressionType.LessThan }, - { GreaterThanOrEqual, ExpressionType.GreaterThanOrEqual }, - { LessThanOrEqual, ExpressionType.LessThanOrEqual } - }; - - /// - /// Initializes a new instance of the class. - /// - public NpgsqlRowValueTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory) - { - _sqlExpressionFactory = sqlExpressionFactory; - } - /// [DynamicDependency(DynamicallyAccessedMemberTypes.PublicMethods, typeof(ValueType))] // For ValueTuple.Create public virtual SqlExpression? Translate( @@ -63,7 +28,21 @@ public NpgsqlRowValueTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory) } // Translate EF.Functions.GreaterThan and other comparisons - if (method.DeclaringType != typeof(NpgsqlDbFunctionsExtensions) || !ComparisonMethods.TryGetValue(method, out var expressionType)) + if (method.DeclaringType != typeof(NpgsqlDbFunctionsExtensions)) + { + return null; + } + + var expressionType = method.Name switch + { + nameof(NpgsqlDbFunctionsExtensions.GreaterThan) => ExpressionType.GreaterThan, + nameof(NpgsqlDbFunctionsExtensions.LessThan) => ExpressionType.LessThan, + nameof(NpgsqlDbFunctionsExtensions.GreaterThanOrEqual) => ExpressionType.GreaterThanOrEqual, + nameof(NpgsqlDbFunctionsExtensions.LessThanOrEqual) => ExpressionType.LessThanOrEqual, + _ => (ExpressionType?)null + }; + + if (expressionType is null) { return null; } @@ -90,6 +69,6 @@ public NpgsqlRowValueTranslator(NpgsqlSqlExpressionFactory sqlExpressionFactory) throw new ArgumentException(NpgsqlStrings.RowValueComparisonRequiresTuplesOfSameLength); } - return _sqlExpressionFactory.MakeBinary(expressionType, arguments[1], arguments[2], typeMapping: null); + return sqlExpressionFactory.MakeBinary(expressionType.Value, arguments[1], arguments[2], typeMapping: null); } } diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStringMethodTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStringMethodTranslator.cs index d56f63380..0e0395dee 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStringMethodTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStringMethodTranslator.cs @@ -18,115 +18,6 @@ public class NpgsqlStringMethodTranslator : IMethodCallTranslator private readonly IRelationalTypeMappingSource _typeMappingSource; private readonly SqlExpression _whitespace; - #region MethodInfo - - private static readonly MethodInfo IndexOfChar = typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(char)])!; - private static readonly MethodInfo IndexOfString = typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(string)])!; - - private static readonly MethodInfo IsNullOrWhiteSpace = - typeof(string).GetRuntimeMethod(nameof(string.IsNullOrWhiteSpace), [typeof(string)])!; - - private static readonly MethodInfo PadLeft = typeof(string).GetRuntimeMethod(nameof(string.PadLeft), [typeof(int)])!; - - private static readonly MethodInfo PadLeftWithChar = typeof(string).GetRuntimeMethod( - nameof(string.PadLeft), [typeof(int), typeof(char)])!; - - private static readonly MethodInfo PadRight = typeof(string).GetRuntimeMethod(nameof(string.PadRight), [typeof(int)])!; - - private static readonly MethodInfo PadRightWithChar = typeof(string).GetRuntimeMethod( - nameof(string.PadRight), [typeof(int), typeof(char)])!; - - private static readonly MethodInfo Replace = typeof(string).GetRuntimeMethod( - nameof(string.Replace), [typeof(string), typeof(string)])!; - - private static readonly MethodInfo Substring = typeof(string).GetTypeInfo().GetDeclaredMethods(nameof(string.Substring)) - .Single(m => m.GetParameters().Length == 1); - - private static readonly MethodInfo SubstringWithLength = typeof(string).GetTypeInfo().GetDeclaredMethods(nameof(string.Substring)) - .Single(m => m.GetParameters().Length == 2); - - private static readonly MethodInfo ToLower = typeof(string).GetRuntimeMethod(nameof(string.ToLower), [])!; - private static readonly MethodInfo ToUpper = typeof(string).GetRuntimeMethod(nameof(string.ToUpper), [])!; - private static readonly MethodInfo TrimBothWithNoParam = typeof(string).GetRuntimeMethod(nameof(string.Trim), Type.EmptyTypes)!; - private static readonly MethodInfo TrimBothWithChars = typeof(string).GetRuntimeMethod(nameof(string.Trim), [typeof(char[])])!; - - private static readonly MethodInfo TrimBothWithSingleChar = - typeof(string).GetRuntimeMethod(nameof(string.Trim), [typeof(char)])!; - - private static readonly MethodInfo TrimEndWithNoParam = typeof(string).GetRuntimeMethod(nameof(string.TrimEnd), Type.EmptyTypes)!; - - private static readonly MethodInfo TrimEndWithChars = typeof(string).GetRuntimeMethod( - nameof(string.TrimEnd), [typeof(char[])])!; - - private static readonly MethodInfo TrimEndWithSingleChar = - typeof(string).GetRuntimeMethod(nameof(string.TrimEnd), [typeof(char)])!; - - private static readonly MethodInfo TrimStartWithNoParam = typeof(string).GetRuntimeMethod(nameof(string.TrimStart), Type.EmptyTypes)!; - - private static readonly MethodInfo TrimStartWithChars = - typeof(string).GetRuntimeMethod(nameof(string.TrimStart), [typeof(char[])])!; - - private static readonly MethodInfo TrimStartWithSingleChar = - typeof(string).GetRuntimeMethod(nameof(string.TrimStart), [typeof(char)])!; - - private static readonly MethodInfo Reverse = typeof(NpgsqlDbFunctionsExtensions).GetRuntimeMethod( - nameof(NpgsqlDbFunctionsExtensions.Reverse), [typeof(DbFunctions), typeof(string)])!; - - private static readonly MethodInfo StringToArray = typeof(NpgsqlDbFunctionsExtensions).GetRuntimeMethod( - nameof(NpgsqlDbFunctionsExtensions.StringToArray), [typeof(DbFunctions), typeof(string), typeof(string)])!; - - private static readonly MethodInfo StringToArrayNullString = typeof(NpgsqlDbFunctionsExtensions).GetRuntimeMethod( - nameof(NpgsqlDbFunctionsExtensions.StringToArray), [typeof(DbFunctions), typeof(string), typeof(string), typeof(string)])!; - - private static readonly MethodInfo ToDate = typeof(NpgsqlDbFunctionsExtensions).GetRuntimeMethod( - nameof(NpgsqlDbFunctionsExtensions.ToDate), [typeof(DbFunctions), typeof(string), typeof(string)])!; - - private static readonly MethodInfo ToTimestamp = typeof(NpgsqlDbFunctionsExtensions).GetRuntimeMethod( - nameof(NpgsqlDbFunctionsExtensions.ToTimestamp), [typeof(DbFunctions), typeof(string), typeof(string)])!; - - private static readonly MethodInfo FirstOrDefaultMethodInfoWithoutArgs - = typeof(Enumerable).GetRuntimeMethods().Single( - m => m.Name == nameof(Enumerable.FirstOrDefault) - && m.GetParameters().Length == 1).MakeGenericMethod(typeof(char)); - - private static readonly MethodInfo LastOrDefaultMethodInfoWithoutArgs - = typeof(Enumerable).GetRuntimeMethods().Single( - m => m.Name == nameof(Enumerable.LastOrDefault) - && m.GetParameters().Length == 1).MakeGenericMethod(typeof(char)); - - // ReSharper disable InconsistentNaming - private static readonly MethodInfo String_Join1 = - typeof(string).GetMethod(nameof(string.Join), [typeof(string), typeof(object[])])!; - - private static readonly MethodInfo String_Join2 = - typeof(string).GetMethod(nameof(string.Join), [typeof(string), typeof(string[])])!; - - private static readonly MethodInfo String_Join3 = - typeof(string).GetMethod(nameof(string.Join), [typeof(char), typeof(object[])])!; - - private static readonly MethodInfo String_Join4 = - typeof(string).GetMethod(nameof(string.Join), [typeof(char), typeof(string[])])!; - - private static readonly MethodInfo String_Join5 = - typeof(string).GetMethod(nameof(string.Join), [typeof(string), typeof(IEnumerable)])!; - - private static readonly MethodInfo String_Join_generic1 = - typeof(string).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly) - .Single( - m => m is { Name: nameof(string.Join), IsGenericMethod: true } - && m.GetParameters().Length == 2 - && m.GetParameters()[0].ParameterType == typeof(string)); - - private static readonly MethodInfo String_Join_generic2 = - typeof(string).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly) - .Single( - m => m is { Name: nameof(string.Join), IsGenericMethod: true } - && m.GetParameters().Length == 2 - && m.GetParameters()[0].ParameterType == typeof(char)); - // ReSharper restore InconsistentNaming - - #endregion - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -164,27 +55,28 @@ public NpgsqlStringMethodTranslator(NpgsqlTypeMappingSource typeMappingSource, I return TranslateDbFunctionsMethod(instance, method, arguments); } - if (method == FirstOrDefaultMethodInfoWithoutArgs) + if (method.DeclaringType == typeof(Enumerable) + && method is { IsGenericMethod: true, Name: nameof(Enumerable.FirstOrDefault) or nameof(Enumerable.LastOrDefault) } + && arguments is [var stringArg] + && method.ReturnType == typeof(char)) { - var argument = arguments[0]; - return _sqlExpressionFactory.Function( - "substr", - [argument, _sqlExpressionFactory.Constant(1), _sqlExpressionFactory.Constant(1)], - nullable: true, - argumentsPropagateNullability: TrueArrays[3], - method.ReturnType); - } + if (method.Name == nameof(Enumerable.FirstOrDefault)) + { + return _sqlExpressionFactory.Function( + "substr", + [stringArg, _sqlExpressionFactory.Constant(1), _sqlExpressionFactory.Constant(1)], + nullable: true, + argumentsPropagateNullability: TrueArrays[3], + method.ReturnType); + } - if (method == LastOrDefaultMethodInfoWithoutArgs) - { - var argument = arguments[0]; return _sqlExpressionFactory.Function( "substr", [ - argument, + stringArg, _sqlExpressionFactory.Function( "length", - [argument], + [stringArg], nullable: true, argumentsPropagateNullability: [true], typeof(int)), @@ -200,259 +92,232 @@ public NpgsqlStringMethodTranslator(NpgsqlTypeMappingSource typeMappingSource, I private SqlExpression? TranslateStringMethod(SqlExpression? instance, MethodInfo method, IReadOnlyList arguments) { - if (method == IndexOfString || method == IndexOfChar) + switch (method.Name) { - var argument = arguments[0]; - var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance!, argument); + case nameof(string.IndexOf) when arguments is [_]: + { + var argument = arguments[0]; + var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance!, argument); - return _sqlExpressionFactory.Subtract( - _sqlExpressionFactory.Function( - "strpos", + return _sqlExpressionFactory.Subtract( + _sqlExpressionFactory.Function( + "strpos", + [ + _sqlExpressionFactory.ApplyTypeMapping(instance!, stringTypeMapping), + _sqlExpressionFactory.ApplyTypeMapping(argument, stringTypeMapping) + ], + nullable: true, + argumentsPropagateNullability: TrueArrays[2], + method.ReturnType), + _sqlExpressionFactory.Constant(1)); + } + + case nameof(string.Replace) when arguments is [var oldValue, var newValue] && oldValue.Type == typeof(string): + { + var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance!, oldValue, newValue); + + return _sqlExpressionFactory.Function( + "replace", [ _sqlExpressionFactory.ApplyTypeMapping(instance!, stringTypeMapping), - _sqlExpressionFactory.ApplyTypeMapping(argument, stringTypeMapping) + _sqlExpressionFactory.ApplyTypeMapping(oldValue, stringTypeMapping), + _sqlExpressionFactory.ApplyTypeMapping(newValue, stringTypeMapping) ], nullable: true, - argumentsPropagateNullability: TrueArrays[2], - method.ReturnType), - _sqlExpressionFactory.Constant(1)); - } - - if (method == Replace) - { - var oldValue = arguments[0]; - var newValue = arguments[1]; - var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance!, oldValue, newValue); - - return _sqlExpressionFactory.Function( - "replace", - [ - _sqlExpressionFactory.ApplyTypeMapping(instance!, stringTypeMapping), - _sqlExpressionFactory.ApplyTypeMapping(oldValue, stringTypeMapping), - _sqlExpressionFactory.ApplyTypeMapping(newValue, stringTypeMapping) - ], - nullable: true, - argumentsPropagateNullability: TrueArrays[3], - method.ReturnType, - stringTypeMapping); - } - - if (method == ToLower || method == ToUpper) - { - return _sqlExpressionFactory.Function( - method == ToLower ? "lower" : "upper", - [instance!], - nullable: true, - argumentsPropagateNullability: TrueArrays[1], - method.ReturnType, - instance!.TypeMapping); - } + argumentsPropagateNullability: TrueArrays[3], + method.ReturnType, + stringTypeMapping); + } - if (method == Substring || method == SubstringWithLength) - { - var args = - method == Substring - ? [instance!, GenerateOneBasedIndexExpression(arguments[0])] - : new[] { instance!, GenerateOneBasedIndexExpression(arguments[0]), arguments[1] }; - return _sqlExpressionFactory.Function( - "substring", - args, - nullable: true, - argumentsPropagateNullability: TrueArrays[args.Length], - method.ReturnType, - instance!.TypeMapping); - } + case nameof(string.ToLower) or nameof(string.ToUpper) when arguments is []: + return _sqlExpressionFactory.Function( + method.Name == nameof(string.ToLower) ? "lower" : "upper", + [instance!], + nullable: true, + argumentsPropagateNullability: TrueArrays[1], + method.ReturnType, + instance!.TypeMapping); - if (method == IsNullOrWhiteSpace) - { - var argument = arguments[0]; + case nameof(string.Substring): + { + var args = + arguments is [var startIndex] + ? new SqlExpression[] { instance!, GenerateOneBasedIndexExpression(startIndex) } + : [instance!, GenerateOneBasedIndexExpression(arguments[0]), arguments[1]]; + return _sqlExpressionFactory.Function( + "substring", + args, + nullable: true, + argumentsPropagateNullability: TrueArrays[args.Length], + method.ReturnType, + instance!.TypeMapping); + } - return _sqlExpressionFactory.OrElse( - _sqlExpressionFactory.IsNull(argument), - _sqlExpressionFactory.Equal( - _sqlExpressionFactory.Function( - "btrim", - [argument, _whitespace], - nullable: true, - argumentsPropagateNullability: TrueArrays[2], - argument.Type, - argument.TypeMapping), - _sqlExpressionFactory.Constant(string.Empty, argument.TypeMapping))); - } + case nameof(string.IsNullOrWhiteSpace): + { + var argument = arguments[0]; + + return _sqlExpressionFactory.OrElse( + _sqlExpressionFactory.IsNull(argument), + _sqlExpressionFactory.Equal( + _sqlExpressionFactory.Function( + "btrim", + [argument, _whitespace], + nullable: true, + argumentsPropagateNullability: TrueArrays[2], + argument.Type, + argument.TypeMapping), + _sqlExpressionFactory.Constant(string.Empty, argument.TypeMapping))); + } - var isTrimStart = method == TrimStartWithNoParam || method == TrimStartWithChars || method == TrimStartWithSingleChar; - var isTrimEnd = method == TrimEndWithNoParam || method == TrimEndWithChars || method == TrimEndWithSingleChar; - var isTrimBoth = method == TrimBothWithNoParam || method == TrimBothWithChars || method == TrimBothWithSingleChar; - if (isTrimStart || isTrimEnd || isTrimBoth) - { - char[]? trimChars = null; - - if (method == TrimStartWithChars - || method == TrimStartWithSingleChar - || method == TrimEndWithChars - || method == TrimEndWithSingleChar - || method == TrimBothWithChars - || method == TrimBothWithSingleChar) + case nameof(string.TrimStart) or nameof(string.TrimEnd) or nameof(string.Trim): { - var constantTrimChars = arguments[0] as SqlConstantExpression; - if (constantTrimChars is null) + char[]? trimChars = null; + + if (arguments is [_]) { - return null; // Don't translate if trim chars isn't a constant + var constantTrimChars = arguments[0] as SqlConstantExpression; + if (constantTrimChars is null) + { + return null; // Don't translate if trim chars isn't a constant + } + + trimChars = constantTrimChars.Value is char c + ? [c] + : (char[]?)constantTrimChars.Value; } - trimChars = constantTrimChars.Value is char c - ? [c] - : (char[]?)constantTrimChars.Value; - } + var isTrimStart = method.Name is nameof(string.TrimStart); + var isTrimEnd = method.Name is nameof(string.TrimEnd); - return _sqlExpressionFactory.Function( - isTrimStart ? "ltrim" : isTrimEnd ? "rtrim" : "btrim", - [ - instance!, - trimChars is null || trimChars.Length == 0 - ? _whitespace - : _sqlExpressionFactory.Constant(new string(trimChars)) - ], - nullable: true, - argumentsPropagateNullability: TrueArrays[2], - instance!.Type, - instance.TypeMapping); - } + return _sqlExpressionFactory.Function( + isTrimStart ? "ltrim" : isTrimEnd ? "rtrim" : "btrim", + [ + instance!, + trimChars is null || trimChars.Length == 0 + ? _whitespace + : _sqlExpressionFactory.Constant(new string(trimChars)) + ], + nullable: true, + argumentsPropagateNullability: TrueArrays[2], + instance!.Type, + instance.TypeMapping); + } - if (method == PadLeft || method == PadLeftWithChar || method == PadRight || method == PadRightWithChar) - { - var args = - method == PadLeft || method == PadRight - ? [instance!, arguments[0]] - : new[] { instance!, arguments[0], arguments[1] }; + case nameof(string.PadLeft) or nameof(string.PadRight): + { + var args = + arguments is [var padCount] + ? new SqlExpression[] { instance!, padCount } + : new[] { instance!, arguments[0], arguments[1] }; - return _sqlExpressionFactory.Function( - method == PadLeft || method == PadLeftWithChar ? "lpad" : "rpad", - args, - nullable: true, - argumentsPropagateNullability: TrueArrays[args.Length], - instance!.Type, - instance.TypeMapping); - } + return _sqlExpressionFactory.Function( + method.Name is nameof(string.PadLeft) ? "lpad" : "rpad", + args, + nullable: true, + argumentsPropagateNullability: TrueArrays[args.Length], + instance!.Type, + instance.TypeMapping); + } - if (method.DeclaringType == typeof(string) - && (method == String_Join1 - || method == String_Join2 - || method == String_Join3 - || method == String_Join4 - || method == String_Join5 - || method.IsClosedFormOf(String_Join_generic1) - || method.IsClosedFormOf(String_Join_generic2)) - && arguments[1].TypeMapping is NpgsqlArrayTypeMapping) - { - // If the array of strings to be joined is a constant (NewArrayExpression), we translate to concat_ws. - // Otherwise we translate to array_to_string, which also supports array columns and parameters. - if (arguments[1] is PgNewArrayExpression newArrayExpression) + case nameof(string.Join) + when arguments is [_, var joinArray] && joinArray.TypeMapping is NpgsqlArrayTypeMapping: { - var rewrittenArguments = new SqlExpression[newArrayExpression.Expressions.Count + 1]; - rewrittenArguments[0] = arguments[0]; - - for (var i = 0; i < newArrayExpression.Expressions.Count; i++) + // If the array of strings to be joined is a constant (NewArrayExpression), we translate to concat_ws. + // Otherwise we translate to array_to_string, which also supports array columns and parameters. + if (arguments[1] is PgNewArrayExpression newArrayExpression) { - var argument = newArrayExpression.Expressions[i]; + var rewrittenArguments = new SqlExpression[newArrayExpression.Expressions.Count + 1]; + rewrittenArguments[0] = arguments[0]; - rewrittenArguments[i + 1] = argument switch + for (var i = 0; i < newArrayExpression.Expressions.Count; i++) { - ColumnExpression { IsNullable: false } => argument, - SqlConstantExpression constantExpression => constantExpression.Value is null - ? _sqlExpressionFactory.Constant(string.Empty, typeof(string)) - : constantExpression, - _ => _sqlExpressionFactory.Coalesce(argument, _sqlExpressionFactory.Constant(string.Empty, typeof(string))) - }; + var argument = newArrayExpression.Expressions[i]; + + rewrittenArguments[i + 1] = argument switch + { + ColumnExpression { IsNullable: false } => argument, + SqlConstantExpression constantExpression => constantExpression.Value is null + ? _sqlExpressionFactory.Constant(string.Empty, typeof(string)) + : constantExpression, + _ => _sqlExpressionFactory.Coalesce(argument, _sqlExpressionFactory.Constant(string.Empty, typeof(string))) + }; + } + + // Only the delimiter (first arg) propagates nullability - all others are non-nullable, since we wrap the others in coalesce + // (where needed). + var argumentsPropagateNullability = new bool[rewrittenArguments.Length]; + argumentsPropagateNullability[0] = true; + + return _sqlExpressionFactory.Function( + "concat_ws", + rewrittenArguments, + nullable: true, + argumentsPropagateNullability, + typeof(string)); } - // Only the delimiter (first arg) propagates nullability - all others are non-nullable, since we wrap the others in coalesce - // (where needed). - var argumentsPropagateNullability = new bool[rewrittenArguments.Length]; - argumentsPropagateNullability[0] = true; - return _sqlExpressionFactory.Function( - "concat_ws", - rewrittenArguments, + "array_to_string", + [arguments[1], arguments[0], _sqlExpressionFactory.Constant("")], nullable: true, - argumentsPropagateNullability, + argumentsPropagateNullability: TrueArrays[3], typeof(string)); } - return _sqlExpressionFactory.Function( - "array_to_string", - [arguments[1], arguments[0], _sqlExpressionFactory.Constant("")], - nullable: true, - argumentsPropagateNullability: TrueArrays[3], - typeof(string)); + default: + return null; } - - return null; } private SqlExpression? TranslateDbFunctionsMethod(SqlExpression? instance, MethodInfo method, IReadOnlyList arguments) - { - if (method == Reverse) + => method.Name switch { - return _sqlExpressionFactory.Function( + nameof(NpgsqlDbFunctionsExtensions.Reverse) => _sqlExpressionFactory.Function( "reverse", [arguments[1]], nullable: true, argumentsPropagateNullability: TrueArrays[1], typeof(string), - arguments[1].TypeMapping); - } + arguments[1].TypeMapping), - if (method == StringToArray) - { // Note that string_to_array always returns text[], regardless of the input type - return _sqlExpressionFactory.Function( - "string_to_array", - [arguments[1], arguments[2]], - nullable: true, - argumentsPropagateNullability: [true, false], - typeof(string[]), - _typeMappingSource.FindMapping(typeof(string[]))); - } + nameof(NpgsqlDbFunctionsExtensions.StringToArray) when arguments is [_, var strArg, var delimArg] + => _sqlExpressionFactory.Function( + "string_to_array", + [strArg, delimArg], + nullable: true, + argumentsPropagateNullability: [true, false], + typeof(string[]), + _typeMappingSource.FindMapping(typeof(string[]))), - if (method == StringToArrayNullString) - { // Note that string_to_array always returns text[], regardless of the input type - return _sqlExpressionFactory.Function( - "string_to_array", - [arguments[1], arguments[2], arguments[3]], - nullable: true, - argumentsPropagateNullability: [true, false, false], - typeof(string[]), - _typeMappingSource.FindMapping(typeof(string[]))); - } + nameof(NpgsqlDbFunctionsExtensions.StringToArray) when arguments is [_, _, _, _] + => _sqlExpressionFactory.Function( + "string_to_array", + [arguments[1], arguments[2], arguments[3]], + nullable: true, + argumentsPropagateNullability: [true, false, false], + typeof(string[]), + _typeMappingSource.FindMapping(typeof(string[]))), - if (method == ToDate) - { - return _sqlExpressionFactory.Function( + nameof(NpgsqlDbFunctionsExtensions.ToDate) => _sqlExpressionFactory.Function( "to_date", [arguments[1], arguments[2]], nullable: true, argumentsPropagateNullability: [true, true], typeof(DateOnly), - _typeMappingSource.FindMapping(typeof(DateOnly)) - ); - } + _typeMappingSource.FindMapping(typeof(DateOnly))), - if (method == ToTimestamp) - { - return _sqlExpressionFactory.Function( + nameof(NpgsqlDbFunctionsExtensions.ToTimestamp) => _sqlExpressionFactory.Function( "to_timestamp", [arguments[1], arguments[2]], nullable: true, argumentsPropagateNullability: [true, true], typeof(DateTime), - _typeMappingSource.FindMapping(typeof(DateTime)) - ); - } - - return null; - } + _typeMappingSource.FindMapping(typeof(DateTime))), + _ => null, + }; private SqlExpression GenerateOneBasedIndexExpression(SqlExpression expression) => expression is SqlConstantExpression constant diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlTrigramsMethodTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlTrigramsMethodTranslator.cs index 291dcd3d4..aaa40ace7 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlTrigramsMethodTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlTrigramsMethodTranslator.cs @@ -8,73 +8,18 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class NpgsqlTrigramsMethodTranslator : IMethodCallTranslator +public class NpgsqlTrigramsMethodTranslator( + IRelationalTypeMappingSource typeMappingSource, + NpgsqlSqlExpressionFactory sqlExpressionFactory, + IModel model) + : IMethodCallTranslator { - private static readonly Dictionary Functions = new() - { - [GetRuntimeMethod(nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsShow), typeof(DbFunctions), typeof(string))] - = "show_trgm", - [GetRuntimeMethod(nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsSimilarity), typeof(DbFunctions), typeof(string), typeof(string))] - = "similarity", - [GetRuntimeMethod(nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsWordSimilarity), typeof(DbFunctions), typeof(string), typeof(string))] - = "word_similarity", - [GetRuntimeMethod(nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsStrictWordSimilarity), typeof(DbFunctions), typeof(string), typeof(string))] - = "strict_word_similarity" - }; - - private static readonly Dictionary BoolReturningOperators = new() - { - [GetRuntimeMethod(nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsAreSimilar), typeof(DbFunctions), typeof(string), typeof(string))] - = "%", - [GetRuntimeMethod(nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsAreWordSimilar), typeof(DbFunctions), typeof(string), typeof(string))] - = "<%", - [GetRuntimeMethod(nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsAreNotWordSimilar), typeof(DbFunctions), typeof(string), typeof(string))] - = "%>", - [GetRuntimeMethod(nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsAreStrictWordSimilar), typeof(DbFunctions), typeof(string), typeof(string))] - = "<<%", - [GetRuntimeMethod(nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsAreNotStrictWordSimilar), typeof(DbFunctions), typeof(string), typeof(string))] - = "%>>" - }; - - private static readonly Dictionary FloatReturningOperators = new() - { - [GetRuntimeMethod(nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsSimilarityDistance), typeof(DbFunctions), typeof(string), typeof(string))] - = "<->", - [GetRuntimeMethod(nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsWordSimilarityDistance), typeof(DbFunctions), typeof(string), typeof(string))] - = "<<->", - [GetRuntimeMethod(nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsWordSimilarityDistanceInverted), typeof(DbFunctions), typeof(string), typeof(string))] - = "<->>", - [GetRuntimeMethod(nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsStrictWordSimilarityDistance), typeof(DbFunctions), typeof(string), typeof(string))] - = "<<<->", - [GetRuntimeMethod(nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsStrictWordSimilarityDistanceInverted), typeof(DbFunctions), typeof(string), typeof(string))] - = "<->>>" - }; - - private static MethodInfo GetRuntimeMethod(string name, params Type[] parameters) - => typeof(NpgsqlTrigramsDbFunctionsExtensions).GetRuntimeMethod(name, parameters)!; - - private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; - private readonly RelationalTypeMapping _boolMapping; - private readonly RelationalTypeMapping _floatMapping; + private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory = sqlExpressionFactory; + private readonly RelationalTypeMapping _boolMapping = typeMappingSource.FindMapping(typeof(bool), model)!; + private readonly RelationalTypeMapping _floatMapping = typeMappingSource.FindMapping(typeof(float), model)!; private static readonly bool[][] TrueArrays = [[], [true], [true, true]]; - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public NpgsqlTrigramsMethodTranslator( - IRelationalTypeMappingSource typeMappingSource, - NpgsqlSqlExpressionFactory sqlExpressionFactory, - IModel model) - { - _sqlExpressionFactory = sqlExpressionFactory; - _boolMapping = typeMappingSource.FindMapping(typeof(bool), model)!; - _floatMapping = typeMappingSource.FindMapping(typeof(float), model)!; - } - #pragma warning disable EF1001 /// public virtual SqlExpression? Translate( @@ -83,37 +28,56 @@ public NpgsqlTrigramsMethodTranslator( IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (Functions.TryGetValue(method, out var function)) + if (method.DeclaringType != typeof(NpgsqlTrigramsDbFunctionsExtensions)) { - return _sqlExpressionFactory.Function( - function, + return null; + } + + return method.Name switch + { + nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsShow) => Function("show_trgm"), + nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsSimilarity) => Function("similarity"), + nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsWordSimilarity) => Function("word_similarity"), + nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsStrictWordSimilarity) => Function("strict_word_similarity"), + + nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsAreSimilar) => BoolOperator("%"), + nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsAreWordSimilar) => BoolOperator("<%"), + nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsAreNotWordSimilar) => BoolOperator("%>"), + nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsAreStrictWordSimilar) => BoolOperator("<<%"), + nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsAreNotStrictWordSimilar) => BoolOperator("%>>"), + + nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsSimilarityDistance) => FloatOperator("<->"), + nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsWordSimilarityDistance) => FloatOperator("<<->"), + nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsWordSimilarityDistanceInverted) => FloatOperator("<->>"), + nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsStrictWordSimilarityDistance) => FloatOperator("<<<->"), + nameof(NpgsqlTrigramsDbFunctionsExtensions.TrigramsStrictWordSimilarityDistanceInverted) => FloatOperator("<->>>"), + + _ => null + }; + + SqlExpression Function(string name) + => _sqlExpressionFactory.Function( + name, arguments.Skip(1), nullable: true, argumentsPropagateNullability: TrueArrays[arguments.Count - 1], method.ReturnType); - } - if (BoolReturningOperators.TryGetValue(method, out var boolOperator)) - { - return new PgUnknownBinaryExpression( + PgUnknownBinaryExpression BoolOperator(string op) + => new( _sqlExpressionFactory.ApplyDefaultTypeMapping(arguments[1]), _sqlExpressionFactory.ApplyDefaultTypeMapping(arguments[2]), - boolOperator, + op, _boolMapping.ClrType, _boolMapping); - } - if (FloatReturningOperators.TryGetValue(method, out var floatOperator)) - { - return new PgUnknownBinaryExpression( + PgUnknownBinaryExpression FloatOperator(string op) + => new( _sqlExpressionFactory.ApplyDefaultTypeMapping(arguments[1]), _sqlExpressionFactory.ApplyDefaultTypeMapping(arguments[2]), - floatOperator, + op, _floatMapping.ClrType, _floatMapping); - } - - return null; } #pragma warning restore EF1001 } diff --git a/src/EFCore.PG/Query/Internal/NpgsqlEvaluatableExpressionFilter.cs b/src/EFCore.PG/Query/Internal/NpgsqlEvaluatableExpressionFilter.cs index ffb952826..3326215fc 100644 --- a/src/EFCore.PG/Query/Internal/NpgsqlEvaluatableExpressionFilter.cs +++ b/src/EFCore.PG/Query/Internal/NpgsqlEvaluatableExpressionFilter.cs @@ -13,12 +13,6 @@ public class NpgsqlEvaluatableExpressionFilter : RelationalEvaluatableExpression { private readonly Version _postgresVersion; - private static readonly MethodInfo TsQueryParse = - typeof(NpgsqlTsQuery).GetRuntimeMethod(nameof(NpgsqlTsQuery.Parse), [typeof(string)])!; - - private static readonly MethodInfo TsVectorParse = - typeof(NpgsqlTsVector).GetRuntimeMethod(nameof(NpgsqlTsVector.Parse), [typeof(string)])!; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -48,8 +42,8 @@ public override bool IsEvaluatableExpression(Expression expression, IModel model var declaringType = methodCallExpression.Method.DeclaringType; var method = methodCallExpression.Method; - if (method == TsQueryParse - || method == TsVectorParse + if ((method.Name == nameof(NpgsqlTsQuery.Parse) + && (method.DeclaringType == typeof(NpgsqlTsQuery) || method.DeclaringType == typeof(NpgsqlTsVector))) || declaringType == typeof(NpgsqlDbFunctionsExtensions) || declaringType == typeof(NpgsqlFullTextSearchDbFunctionsExtensions) || declaringType == typeof(NpgsqlFullTextSearchLinqExtensions)