From deece7b51d1e4bf8e546581145a52ab65ad69c1f Mon Sep 17 00:00:00 2001 From: shuofengzhang Date: Sun, 15 Mar 2026 06:51:10 +0800 Subject: [PATCH 1/3] Fix aggregation validation for builtins and callables --- neat/aggregations.py | 30 ++++++++++++++++++++++-------- tests/test_aggregation.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/neat/aggregations.py b/neat/aggregations.py index 3f7764be..989a501c 100644 --- a/neat/aggregations.py +++ b/neat/aggregations.py @@ -3,6 +3,7 @@ and code for adding new user-defined ones. """ +import inspect import types import warnings from functools import reduce @@ -49,14 +50,27 @@ class InvalidAggregationFunction(TypeError): def validate_aggregation(function): # TODO: Recognize when need `reduce` - if not isinstance(function, - (types.BuiltinFunctionType, - types.FunctionType, - types.LambdaType)): - raise InvalidAggregationFunction("A function object is required.") - - if not (function.__code__.co_argcount >= 1): - raise InvalidAggregationFunction("A function taking at least one argument is required") + if not callable(function): + raise InvalidAggregationFunction("A callable object is required.") + + try: + signature = inspect.signature(function) + except (TypeError, ValueError) as exc: + if isinstance(function, types.BuiltinFunctionType): + return + raise InvalidAggregationFunction("Unable to inspect aggregation callable signature.") from exc + + accepts_positional = any( + parameter.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.VAR_POSITIONAL, + ) + for parameter in signature.parameters.values() + ) + + if not accepts_positional: + raise InvalidAggregationFunction("A function taking at least one positional argument is required") class AggregationFunctionSet: diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py index 6ac2ce82..10e0bf9d 100644 --- a/tests/test_aggregation.py +++ b/tests/test_aggregation.py @@ -72,10 +72,25 @@ def test_add_minabs(): assert config.genome_config.aggregation_function_defs.is_valid('minabs') +def test_add_builtin_max(): + local_dir = os.path.dirname(__file__) + config_path = os.path.join(local_dir, 'test_configuration') + config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction, + neat.DefaultSpeciesSet, neat.DefaultStagnation, + config_path) + + config.genome_config.add_aggregation('builtin_max', max) + assert config.genome_config.aggregation_function_defs.get('builtin_max') is max + + def dud_function(): return 0.0 +def keyword_only_function(*, items): + return sum(items) + + def test_function_set(): s = aggregations.AggregationFunctionSet() assert s.get('sum') is not None @@ -135,6 +150,21 @@ def test_bad_add2(): raise Exception("Should have had a TypeError/derived for dud_function") +def test_bad_add3(): + local_dir = os.path.dirname(__file__) + config_path = os.path.join(local_dir, 'test_configuration') + config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction, + neat.DefaultSpeciesSet, neat.DefaultStagnation, + config_path) + + try: + config.genome_config.add_aggregation('keyword_only_function', keyword_only_function) + except TypeError: + pass + else: + raise Exception("Should have had a TypeError/derived for keyword_only_function") + + if __name__ == '__main__': test_sum() test_product() From 548a8b459668668da12820e39f5bba3c02c83457 Mon Sep 17 00:00:00 2001 From: shuofengzhang Date: Sun, 15 Mar 2026 00:44:42 +0000 Subject: [PATCH 2/3] Validate aggregation callables accept single positional argument --- neat/aggregations.py | 19 +++++++------------ tests/test_aggregation.py | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/neat/aggregations.py b/neat/aggregations.py index 989a501c..7819172d 100644 --- a/neat/aggregations.py +++ b/neat/aggregations.py @@ -49,7 +49,7 @@ class InvalidAggregationFunction(TypeError): pass -def validate_aggregation(function): # TODO: Recognize when need `reduce` +def validate_aggregation(function): if not callable(function): raise InvalidAggregationFunction("A callable object is required.") @@ -60,17 +60,12 @@ def validate_aggregation(function): # TODO: Recognize when need `reduce` return raise InvalidAggregationFunction("Unable to inspect aggregation callable signature.") from exc - accepts_positional = any( - parameter.kind in ( - inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.VAR_POSITIONAL, - ) - for parameter in signature.parameters.values() - ) - - if not accepts_positional: - raise InvalidAggregationFunction("A function taking at least one positional argument is required") + try: + signature.bind(object()) + except TypeError as exc: + raise InvalidAggregationFunction( + "A function taking a single positional argument is required" + ) from exc class AggregationFunctionSet: diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py index 10e0bf9d..81a3d0fa 100644 --- a/tests/test_aggregation.py +++ b/tests/test_aggregation.py @@ -91,6 +91,10 @@ def keyword_only_function(*, items): return sum(items) +def two_argument_function(items, scale): + return sum(items) * scale + + def test_function_set(): s = aggregations.AggregationFunctionSet() assert s.get('sum') is not None @@ -165,6 +169,21 @@ def test_bad_add3(): raise Exception("Should have had a TypeError/derived for keyword_only_function") +def test_bad_add4(): + local_dir = os.path.dirname(__file__) + config_path = os.path.join(local_dir, 'test_configuration') + config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction, + neat.DefaultSpeciesSet, neat.DefaultStagnation, + config_path) + + try: + config.genome_config.add_aggregation('two_argument_function', two_argument_function) + except TypeError: + pass + else: + raise Exception("Should have had a TypeError/derived for two_argument_function") + + if __name__ == '__main__': test_sum() test_product() From 5b53b557ed2b077e595cd9c0e03ad4c02d6a691e Mon Sep 17 00:00:00 2001 From: CodeReclaimers Date: Sun, 15 Mar 2026 15:31:05 -0400 Subject: [PATCH 3/3] Clarify error message and add comment on builtin fallback - Error message now says "one required positional argument" (functions with extra optional args are still valid) - Added comment explaining why builtins skip signature validation Co-Authored-By: Claude Opus 4.6 (1M context) --- neat/aggregations.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/neat/aggregations.py b/neat/aggregations.py index 7819172d..79da35f2 100644 --- a/neat/aggregations.py +++ b/neat/aggregations.py @@ -56,6 +56,8 @@ def validate_aggregation(function): try: signature = inspect.signature(function) except (TypeError, ValueError) as exc: + # CPython builtins (e.g. max, sum) often lack introspectable signatures. + # Skip signature validation for these; they are assumed correct. if isinstance(function, types.BuiltinFunctionType): return raise InvalidAggregationFunction("Unable to inspect aggregation callable signature.") from exc @@ -64,7 +66,7 @@ def validate_aggregation(function): signature.bind(object()) except TypeError as exc: raise InvalidAggregationFunction( - "A function taking a single positional argument is required" + "A callable with exactly one required positional argument is required" ) from exc