diff --git a/src/Pure.DI.Core/Core/Code/LocalFunctions.cs b/src/Pure.DI.Core/Core/Code/LocalFunctions.cs index 30cdc4ffd..a91bec755 100644 --- a/src/Pure.DI.Core/Core/Code/LocalFunctions.cs +++ b/src/Pure.DI.Core/Core/Code/LocalFunctions.cs @@ -1,13 +1,58 @@ -namespace Pure.DI.Core.Code; +using Pure.DI.Core.Models; + +namespace Pure.DI.Core.Code; class LocalFunctions(INodeTools nodeTools): ILocalFunctions { public bool UseFor(CodeContext ctx) { + if (ctx.IsFactory && HasOverridesInDependencies(ctx)) + { + return false; + } + var var = ctx.VarInjection.Var; return ctx is { HasOverrides: false, Accumulators.Length: 0 } && nodeTools.IsBlock(var.AbstractNode) && ctx.RootContext.Graph.Graph.TryGetOutEdges(var.Declaration.Node.Node, out var targets) && targets.Count > 1; } -} \ No newline at end of file + + private static bool HasOverridesInDependencies(CodeContext ctx) + { + var graph = ctx.RootContext.Graph.Graph; + var visited = new HashSet(); + var stack = new Stack(); + stack.Push(ctx.VarInjection.Var.AbstractNode.Node); + while (stack.Count > 0) + { + var node = stack.Pop(); + if (!visited.Add(node.Binding.Id)) + { + continue; + } + + if (node.Factory?.HasOverrides == true) + { + return true; + } + + if (!graph.TryGetInEdges(node, out var dependencies)) + { + continue; + } + + foreach (var dependency in dependencies) + { + if (dependency.Injection.Kind is InjectionKind.FactoryInjection or InjectionKind.Override) + { + return true; + } + + stack.Push(dependency.Source); + } + } + + return false; + } +} diff --git a/src/Pure.DI.Core/Core/Code/VarsMap.cs b/src/Pure.DI.Core/Core/Code/VarsMap.cs index b1016b10d..6b7a07068 100644 --- a/src/Pure.DI.Core/Core/Code/VarsMap.cs +++ b/src/Pure.DI.Core/Core/Code/VarsMap.cs @@ -170,7 +170,8 @@ private VarDeclaration CreateDeclaration(IDependencyNode node) => /// private IReadOnlyDictionary CreateState(Var var) => _map - .Where(i => i.Key != var.Declaration.Node.BindingId) + .Where(i => i.Key != var.Declaration.Node.BindingId + && i.Value.Declaration.Node.Construct is not { Source.Kind: MdConstructKind.Override }) .ToDictionary(i => i.Key, i => new VarState(i.Value)); /// @@ -189,12 +190,14 @@ private void RemoveNewNonPersistentVars(Var var, IReadOnlyDictionary(processed); + if (branchProcessed.TryGetValue(targetNode.Binding.Id, out var node)) { return node; } @@ -98,7 +106,7 @@ private DependencyNode Override( overridesEnumerable = []; } - processed.Add(targetNode.Binding.Id, targetNode); + branchProcessed[targetNode.Binding.Id] = targetNode; var newDependencies = new List(dependencies.Count); var lastDependencyPosition = 0; using var overridesEnumerator = overridesEnumerable.GetEnumerator(); @@ -179,12 +187,13 @@ private DependencyNode Override( var currentDependency = dependency with { Target = targetNode }; if (!localNodesMap.TryGetValue(currentDependency.Injection, out var overridingSourceNode)) { + var sourceOverrides = overridesMap.ToDictionary(); var source = Override( - processed, + branchProcessed, nodesMap, nextLocalOverrides, nextLocalOverrides.Count > 0, - overridesMap, + sourceOverrides, setup, graph, rootNode, @@ -206,7 +215,17 @@ private DependencyNode Override( newDependencies.Add(currentDependency); } - entries.Add(new GraphEntry(targetNode, newDependencies)); + var entry = new GraphEntry(targetNode, newDependencies); + var entryIndex = entries.FindIndex(i => Equals(i.Target, targetNode)); + if (entryIndex >= 0) + { + entries[entryIndex] = entry; + } + else + { + entries.Add(entry); + } + return targetNode; } } diff --git a/src/Pure.DI.Core/Core/OverrideIdProvider.cs b/src/Pure.DI.Core/Core/OverrideIdProvider.cs index 18d148cc4..75e178433 100644 --- a/src/Pure.DI.Core/Core/OverrideIdProvider.cs +++ b/src/Pure.DI.Core/Core/OverrideIdProvider.cs @@ -23,10 +23,21 @@ public override bool Equals(object? obj) if (obj.GetType() != GetType()) return false; var other = (Key)obj; return SymbolEqualityComparer.Default.Equals(_type, other._type) - && (_tags.Count == 0 && other._tags.Count == 0 || _tags.Intersect(other._tags).Any()); + && _tags.SetEquals(other._tags); } - public override int GetHashCode() => - SymbolEqualityComparer.Default.GetHashCode(_type); + public override int GetHashCode() + { + var hashCode = SymbolEqualityComparer.Default.GetHashCode(_type); + foreach (var tagHashCode in _tags.Select(GetTagHashCode).OrderBy(i => i)) + { + hashCode = (hashCode * 397) ^ tagHashCode; + } + + return hashCode; + } + + private static int GetTagHashCode(object tag) => + tag.GetHashCode(); } -} \ No newline at end of file +} diff --git a/tests/Pure.DI.IntegrationTests/OverrideTests.cs b/tests/Pure.DI.IntegrationTests/OverrideTests.cs index 7605015c2..a2cf2c6a7 100644 --- a/tests/Pure.DI.IntegrationTests/OverrideTests.cs +++ b/tests/Pure.DI.IntegrationTests/OverrideTests.cs @@ -2424,6 +2424,223 @@ public static void Main() result.StdOut.ShouldBe(["3", "0", "1", "2"], result); } + [Fact] + public async Task ShouldSupportOverrideInFactoryWithLocalFunctionAndFuncArgs() + { + // Given + + // When + var result = await """ + using System; + using System.Collections.Generic; + using System.Threading.Tasks; + using Pure.DI; + + namespace Sample + { + interface IStorage {} + + class Storage : IStorage + { + } + + class Command + { + public Command(Func canExecute, Func execute) + { + CanExecute = canExecute; + Execute = execute; + } + + public Func CanExecute { get; } + + public Func Execute { get; } + + public IDispatcher? Dispatcher { get; set; } + } + + interface IDispatcher {} + + class Dispatcher : IDispatcher + { + } + + class NodeName + { + public NodeName(string value) + { + Value = value; + } + + public string Value { get; } + } + + interface ITreeNodeViewModel + { + string Kind { get; } + } + + class DirectoryNodeViewModel : ITreeNodeViewModel + { + public DirectoryNodeViewModel( + Func, Func, Command> commandBuilder, + IStorage storage, + NodeName name, + List children) + { + Kind = $"Dir:{name.Value}:{children.Count}"; + } + + public string Kind { get; } + } + + class FileTreeNodeViewModel : ITreeNodeViewModel + { + public FileTreeNodeViewModel( + Func, Func, Command> commandBuilder, + IStorage storage, + NodeName name) + { + Kind = $"File:{name.Value}"; + } + + public string Kind { get; } + } + + class NodeFactory + { + private readonly IStorage _storage; + private readonly Func, Func, Command> _commandBuilder; + private readonly Func, Func, Command>, IStorage, NodeName, List, ITreeNodeViewModel> _dirFactory; + private readonly Func, Func, Command>, IStorage, NodeName, ITreeNodeViewModel> _fileFactory; + + public NodeFactory( + IStorage storage, + Func, Func, Command>, IStorage, NodeName, List, ITreeNodeViewModel> dirFactory, + Func, Func, Command>, IStorage, NodeName, ITreeNodeViewModel> fileFactory, + Func, Func, Command> commandBuilder) + { + _storage = storage; + _commandBuilder = commandBuilder; + _dirFactory = dirFactory; + _fileFactory = fileFactory; + } + + public ITreeNodeViewModel CreateDir(NodeName name, List children) => + _dirFactory(_commandBuilder, _storage, name, children); + + public ITreeNodeViewModel CreateFile(NodeName name) => + _fileFactory(_commandBuilder, _storage, name); + } + + class TreeCompressor + { + public TreeCompressor( + Func, ITreeNodeViewModel> dirFactory, + Func fileFactory) + { + DirFactory = dirFactory; + FileFactory = fileFactory; + } + + public Func, ITreeNodeViewModel> DirFactory { get; } + + public Func FileFactory { get; } + } + + interface IFileDuplicates + { + int Id { get; } + + string Name { get; } + } + + class FileDuplicates : IFileDuplicates + { + public FileDuplicates(int id, string name) + { + Id = id; + Name = name; + } + + public int Id { get; } + + public string Name { get; } + } + + interface IFileDuplicatesViewModel + { + int Id { get; } + + string Name { get; } + } + + class FileDuplicatesViewModel : IFileDuplicatesViewModel + { + public FileDuplicatesViewModel( + IFileDuplicates duplicates, + NodeFactory nodeFactory, + TreeCompressor treeCompressor) + { + Id = duplicates.Id; + Name = duplicates.Name; + } + + public int Id { get; } + + public string Name { get; } + } + + class DuplicatesViewModel + { + public DuplicatesViewModel(Func factory) + { + Factory = factory; + } + + public Func Factory { get; } + } + + static class Setup + { + private static void SetupComposition() + { + DI.Setup(nameof(Composition)) + .Hint(Hint.Resolve, "Off") + .Bind().As(Lifetime.Singleton).To() + .Bind().As(Lifetime.Singleton).To() + .Bind().As(Lifetime.Singleton).To, Func, Command>>(ctx => + { + ctx.Inject(out IDispatcher dispatcher); + return (canExecute, execute) => + new Command(canExecute, execute) { Dispatcher = dispatcher }; + }) + .Singleton() + .Transient() + .Bind().To() + .Bind().As(Lifetime.Singleton).To() + .Root("Root"); + } + } + + public class Program + { + public static void Main() + { + var composition = new Composition(); + var root = composition.Root; + var vm = root.Factory(new FileDuplicates(42, "custom")); + Console.WriteLine($"{vm.Id} {vm.Name}"); + } + } + } + """.RunAsync(); + + // Then + result.Success.ShouldBeTrue(result); + result.StdOut.ShouldBe(["42 custom"], result); + } + [Fact] public async Task ShouldSupportStdFuncWithArg() { @@ -2615,6 +2832,110 @@ public static void Main() result.StdOut.ShouldBe(["3", "a", "0", "b", "1", "c", "2"], result); } + [Fact] + public async Task ShouldSupportOverrideWhenFuncWith2ArgsOfSameTypeAndTag() + { + // Given + + // When + var result = await """ + using System; + using System.Collections.Generic; + using Pure.DI; + + namespace Sample + { + interface IClock + { + DateTimeOffset Now { get; } + } + + class Clock : IClock + { + public DateTimeOffset Now => DateTimeOffset.Now; + } + + interface IDependency + { + int Id { get; } + int SubId { get; } + } + + class Dependency : IDependency + { + private readonly int _id; + private readonly int _subId; + + public Dependency(IClock clock, int id, [Tag("sub")] int subId) + { + _id = id; + _subId = subId; + } + + public int Id => _id; + + public int SubId => _subId; + } + + interface IService + { + List Dependencies { get; } + } + + class Service : IService + { + public Service(Func dependencyFactory) + { + Dependencies = new List + { + dependencyFactory(10, 100), + dependencyFactory(11, 101), + dependencyFactory(12, 102) + }; + } + + public List Dependencies { get; } + } + + static class Setup + { + private static void SetupComposition() + { + DI.Setup(nameof(Composition)) + .Bind().As(Lifetime.Singleton).To() + .Bind().To>(ctx => + (id, subId) => + { + ctx.Override(id); + ctx.Override(subId, "sub"); + ctx.Inject(out var dependency); + return dependency; + }) + .Bind().To() + .Root("Root"); + } + } + + public class Program + { + public static void Main() + { + var composition = new Composition(); + var service = composition.Root; + Console.WriteLine(service.Dependencies.Count); + Console.WriteLine($"{service.Dependencies[0].Id}:{service.Dependencies[0].SubId}"); + Console.WriteLine($"{service.Dependencies[1].Id}:{service.Dependencies[1].SubId}"); + Console.WriteLine($"{service.Dependencies[2].Id}:{service.Dependencies[2].SubId}"); + } + } + } + """.RunAsync(); + + // Then + result.Success.ShouldBeTrue(result); + result.StdOut.ShouldBe(["3", "10:100", "11:101", "12:102"], result); + } + [Fact] public async Task ShouldSupportOverrideWhenCtor() {