Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions src/Pure.DI.Core/Core/Code/LocalFunctions.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,58 @@
namespace Pure.DI.Core.Code;
using Pure.DI.Core.Models;

namespace Pure.DI.Core.Code;

class LocalFunctions(INodeTools nodeTools): ILocalFunctions
{
public bool UseFor(CodeContext ctx)
{
if (ctx.IsFactory && HasOverridesInDependencies(ctx))
{
return false;
}

var var = ctx.VarInjection.Var;
return ctx is { HasOverrides: false, Accumulators.Length: 0 }
&& nodeTools.IsBlock(var.AbstractNode)
&& ctx.RootContext.Graph.Graph.TryGetOutEdges(var.Declaration.Node.Node, out var targets)
&& targets.Count > 1;
}
}

private static bool HasOverridesInDependencies(CodeContext ctx)
{
var graph = ctx.RootContext.Graph.Graph;
var visited = new HashSet<int>();
var stack = new Stack<DependencyNode>();
stack.Push(ctx.VarInjection.Var.AbstractNode.Node);
while (stack.Count > 0)
{
var node = stack.Pop();
if (!visited.Add(node.Binding.Id))
{
continue;
}

if (node.Factory?.HasOverrides == true)
{
return true;
}

if (!graph.TryGetInEdges(node, out var dependencies))
{
continue;
}

foreach (var dependency in dependencies)
{
if (dependency.Injection.Kind is InjectionKind.FactoryInjection or InjectionKind.Override)
{
return true;
}

stack.Push(dependency.Source);
}
}

return false;
}
}
9 changes: 6 additions & 3 deletions src/Pure.DI.Core/Core/Code/VarsMap.cs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ private VarDeclaration CreateDeclaration(IDependencyNode node) =>
/// </summary>
private IReadOnlyDictionary<int, VarState> CreateState(Var var) =>
_map
.Where(i => i.Key != var.Declaration.Node.BindingId)
.Where(i => i.Key != var.Declaration.Node.BindingId
&& i.Value.Declaration.Node.Construct is not { Source.Kind: MdConstructKind.Override })
.ToDictionary(i => i.Key, i => new VarState(i.Value));

/// <summary>
Expand All @@ -189,12 +190,14 @@ private void RemoveNewNonPersistentVars(Var var, IReadOnlyDictionary<int, VarSta
}

var node = i.Value.Declaration.Node;
var isPersistent = node.ActualLifetime is Lifetime.Singleton or Lifetime.Scoped or Lifetime.PerResolve
|| node.Arg is { Source.Kind: ArgKind.Composition };
if (node.BindingId == var.Declaration.Node.BindingId)
{
return false;
return !isPersistent;
}

return !(node.ActualLifetime is Lifetime.Singleton or Lifetime.Scoped or Lifetime.PerResolve || node.Arg is not null);
return !isPersistent;
}).ToList();

foreach (var item in newItems)
Expand Down
29 changes: 24 additions & 5 deletions src/Pure.DI.Core/Core/GraphOverrider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,15 @@ private DependencyNode Override(
return targetNode;
}

if (processed.TryGetValue(targetNode.Binding.Id, out var node))
// Rewritten nodes are context-dependent when any override scope is active.
// In such cases we isolate memoization to the current branch to avoid
// leaking context-dependent rewrites into sibling branches.
var isContextFree = !consumeLocalOverrides
&& nodes.Count == 0
&& localOverrides.Count == 0
&& overrides.Count == 0;
var branchProcessed = isContextFree ? processed : new Dictionary<int, DependencyNode>(processed);
if (branchProcessed.TryGetValue(targetNode.Binding.Id, out var node))
{
return node;
}
Expand Down Expand Up @@ -98,7 +106,7 @@ private DependencyNode Override(
overridesEnumerable = [];
}

processed.Add(targetNode.Binding.Id, targetNode);
branchProcessed[targetNode.Binding.Id] = targetNode;
var newDependencies = new List<Dependency>(dependencies.Count);
var lastDependencyPosition = 0;
using var overridesEnumerator = overridesEnumerable.GetEnumerator();
Expand Down Expand Up @@ -179,12 +187,13 @@ private DependencyNode Override(
var currentDependency = dependency with { Target = targetNode };
if (!localNodesMap.TryGetValue(currentDependency.Injection, out var overridingSourceNode))
{
var sourceOverrides = overridesMap.ToDictionary();
var source = Override(
processed,
branchProcessed,
nodesMap,
nextLocalOverrides,
nextLocalOverrides.Count > 0,
overridesMap,
sourceOverrides,
setup,
graph,
rootNode,
Expand All @@ -206,7 +215,17 @@ private DependencyNode Override(
newDependencies.Add(currentDependency);
}

entries.Add(new GraphEntry<DependencyNode, Dependency>(targetNode, newDependencies));
var entry = new GraphEntry<DependencyNode, Dependency>(targetNode, newDependencies);
var entryIndex = entries.FindIndex(i => Equals(i.Target, targetNode));
if (entryIndex >= 0)
{
entries[entryIndex] = entry;
}
else
{
entries.Add(entry);
}

return targetNode;
}
}
19 changes: 15 additions & 4 deletions src/Pure.DI.Core/Core/OverrideIdProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,21 @@ public override bool Equals(object? obj)
if (obj.GetType() != GetType()) return false;
var other = (Key)obj;
return SymbolEqualityComparer.Default.Equals(_type, other._type)
&& (_tags.Count == 0 && other._tags.Count == 0 || _tags.Intersect(other._tags).Any());
&& _tags.SetEquals(other._tags);
}

public override int GetHashCode() =>
SymbolEqualityComparer.Default.GetHashCode(_type);
public override int GetHashCode()
{
var hashCode = SymbolEqualityComparer.Default.GetHashCode(_type);
foreach (var tagHashCode in _tags.Select(GetTagHashCode).OrderBy(i => i))
{
hashCode = (hashCode * 397) ^ tagHashCode;
}

return hashCode;
}

private static int GetTagHashCode(object tag) =>
tag.GetHashCode();
}
}
}
Loading
Loading