diff --git a/neat/aggregations.py b/neat/aggregations.py index 3f7764be..79da35f2 100644 --- a/neat/aggregations.py +++ b/neat/aggregations.py @@ -3,6 +3,7 @@ and code for adding new user-defined ones. """ +import inspect import types import warnings from functools import reduce @@ -48,15 +49,25 @@ class InvalidAggregationFunction(TypeError): pass -def validate_aggregation(function): # TODO: Recognize when need `reduce` - if not isinstance(function, - (types.BuiltinFunctionType, - types.FunctionType, - types.LambdaType)): - raise InvalidAggregationFunction("A function object is required.") - - if not (function.__code__.co_argcount >= 1): - raise InvalidAggregationFunction("A function taking at least one argument is required") +def validate_aggregation(function): + if not callable(function): + raise InvalidAggregationFunction("A callable object is required.") + + try: + signature = inspect.signature(function) + except (TypeError, ValueError) as exc: + # CPython builtins (e.g. max, sum) often lack introspectable signatures. + # Skip signature validation for these; they are assumed correct. + if isinstance(function, types.BuiltinFunctionType): + return + raise InvalidAggregationFunction("Unable to inspect aggregation callable signature.") from exc + + try: + signature.bind(object()) + except TypeError as exc: + raise InvalidAggregationFunction( + "A callable with exactly one required positional argument is required" + ) from exc class AggregationFunctionSet: diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py index 6ac2ce82..81a3d0fa 100644 --- a/tests/test_aggregation.py +++ b/tests/test_aggregation.py @@ -72,10 +72,29 @@ def test_add_minabs(): assert config.genome_config.aggregation_function_defs.is_valid('minabs') +def test_add_builtin_max(): + local_dir = os.path.dirname(__file__) + config_path = os.path.join(local_dir, 'test_configuration') + config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction, + neat.DefaultSpeciesSet, neat.DefaultStagnation, + config_path) + + config.genome_config.add_aggregation('builtin_max', max) + assert config.genome_config.aggregation_function_defs.get('builtin_max') is max + + def dud_function(): return 0.0 +def keyword_only_function(*, items): + return sum(items) + + +def two_argument_function(items, scale): + return sum(items) * scale + + def test_function_set(): s = aggregations.AggregationFunctionSet() assert s.get('sum') is not None @@ -135,6 +154,36 @@ def test_bad_add2(): raise Exception("Should have had a TypeError/derived for dud_function") +def test_bad_add3(): + local_dir = os.path.dirname(__file__) + config_path = os.path.join(local_dir, 'test_configuration') + config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction, + neat.DefaultSpeciesSet, neat.DefaultStagnation, + config_path) + + try: + config.genome_config.add_aggregation('keyword_only_function', keyword_only_function) + except TypeError: + pass + else: + raise Exception("Should have had a TypeError/derived for keyword_only_function") + + +def test_bad_add4(): + local_dir = os.path.dirname(__file__) + config_path = os.path.join(local_dir, 'test_configuration') + config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction, + neat.DefaultSpeciesSet, neat.DefaultStagnation, + config_path) + + try: + config.genome_config.add_aggregation('two_argument_function', two_argument_function) + except TypeError: + pass + else: + raise Exception("Should have had a TypeError/derived for two_argument_function") + + if __name__ == '__main__': test_sum() test_product()