
Cortical

Cortical is a framework for building scalable differentiable memory architectures in PyTorch, CUDA, and Triton.

Cortical supports two ways to build memory systems:

  • Fabrics: streaming recurrent substrates made from many small cells connected by a graph.
  • Stacks: transformer-style stacked architectures built from recurrent, attention, and routed-memory blocks.

Cortical treats memory as structure, not only as a token cache or a hidden vector. A model can keep a persistent state field, move signals through it, and learn local updates over time.

Install

pip install cortical

From this repository:

pip install -e .

CUDA and Triton paths are used when available. PyTorch paths are kept for development, CPU runs, and parity checks.

Fabric

Fabric is a streaming recurrent substrate.

Instead of processing a sequence as a stack of layers, Fabric keeps a graph of small recurrent cells alive across time. Inputs arrive at named regions of the graph, perturb the substrate, and then disappear. The cells keep running, exchange messages with nearby cells, and carry memory forward in their recurrent state.

The mental model is closer to a small patch of neural tissue than to a transformer block:

  • the graph defines where cells live and who can talk to whom
  • each cell keeps private state
  • cells expose public signals that neighboring cells can read
  • inputs and outputs are regions on the graph
  • computation is local, recurrent, and streaming

Fabric is still an ordinary differentiable PyTorch module. It is trained with gradient descent. It differs in structure: the model has a persistent spatial substrate rather than only a stack of per-token activations.

What Is A Cell?

A Fabric cell is a small learned recurrent unit attached to one node in the graph.

At every Fabric step, a cell receives:

  • its previous private state
  • messages from connected sender cells
  • optional external input if it belongs to an input region
  • reset information when a stream boundary occurs

It produces:

  • updated private state
  • public output that other cells can read
  • public output that named output regions can aggregate

Cells describe local transitions. Fabric batches, routes, projects, and executes them.
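The per-cell contract above can be sketched in plain Python. This is an illustrative stand-in, not the Cortical cell API: the function name `cell_step`, the mixing rule, and the ReLU public emission are all assumptions chosen to keep the sketch minimal.

```python
def cell_step(private_state, incoming_message, external_input=None, reset=False):
    """One hypothetical local transition: (new_private_state, public_output).

    Illustrative only; the real cell families (SLSTM, AxonCell) use learned
    gated updates rather than this fixed blend.
    """
    if reset:
        # Stream boundary: start the recurrent state from zero.
        private_state = [0.0] * len(private_state)
    # Combine the aggregated neighbor message with optional external input.
    drive = incoming_message if external_input is None else [
        m + u for m, u in zip(incoming_message, external_input)
    ]
    # Blend old state with the new drive (a stand-in for a gated update).
    new_state = [0.5 * s + 0.5 * d for s, d in zip(private_state, drive)]
    # Expose a public signal that neighboring cells can read next step.
    public = [max(0.0, v) for v in new_state]
    return new_state, public
```

Fabric owns everything outside this local function: batching cells, routing messages, and projecting inputs and outputs.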

What Is Message Passing?

Message passing is the differentiable function cells use to communicate across graph edges. For each edge, the cell being updated is the receiver and the connected neighbor it reads from is the sender.

The graph chooses the sender set for each receiver. The message_rules declaration defines how the receiver combines those sender signals into one incoming message. Dot-product attention is one example: the receiver provides the query, senders provide keys and values, and the rule returns a weighted message. Other rules can use different differentiable math while keeping the same graph and cell model.

Message passing is trained with the rest of the model. Fabric owns the routing, batching, layout, and GPU execution.
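The dot-product rule described above can be written out as plain receiver-side math. This is an illustrative sketch of the attention computation for one receiver, not the batched Cortical kernel; projections and multi-head structure are omitted.

```python
import math

def dot_product_message(query, sender_keys, sender_values):
    """Receiver-side attention over connected senders:
    softmax(q . k_i) weighted sum of sender values."""
    # One logit per connected sender.
    logits = [sum(q * k for q, k in zip(query, key)) for key in sender_keys]
    # Numerically stable softmax over the receiver's neighborhood.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    weights = [e / z for e in exps]
    # Mix sender values into one incoming message.
    dim = len(sender_values[0])
    return [sum(w * v[d] for w, v in zip(weights, sender_values))
            for d in range(dim)]
```

With equal logits the rule reduces to averaging the sender values, which makes it easy to sanity-check by hand.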

What Happens In One Step?

A Fabric step updates the whole substrate once:

  • Input: external data enters named graph regions.
  • Message passing: cells read public signals from connected neighbors.
  • Cell update: each cell updates its private recurrent state.
  • Public emission: cells expose new public signals.
  • Output: named output regions aggregate public signals.
  • Carry: the new substrate state becomes the starting point for the next step.

For sequence input [B, T, H], Fabric runs this step repeatedly over time. Fabric is streaming by design: callers can keep the returned state and pass it into the next call, so the same substrate continues from one window to the next. Fabric can also be used as a regular sequence-to-sequence module by passing a full [B, T, H] window and reading the window outputs.

Losses attach to the output tensors as usual. During backpropagation, gradients flow backward from the loss through output aggregation, public emission, cell updates, message passing, and input projections. For a sequence window, autograd also sends gradients backward through the repeated Fabric steps in time.

In both cases, the computation is a recurrent substrate, not a stack of independent per-timestep activations.
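The window loop described above can be sketched as a plain Python driver. The name `run_window` and the shape of `step_fn` are illustrative; Fabric performs this loop internally when given a full [B, T, H] window.

```python
def run_window(substrate_state, window, step_fn):
    """Run one substrate step per timestep, carrying state forward.

    `step_fn(state, x_t) -> (new_state, y_t)` stands in for the full
    input/message/update/emit/output pipeline of one Fabric step.
    """
    outputs = []
    state = substrate_state
    for x_t in window:            # one entry per timestep
        state, y_t = step_fn(state, x_t)
        outputs.append(y_t)
    # The caller keeps `state` so the next window continues the stream.
    return outputs, state
```

Because the same state object threads through every step, a loss on any output timestep reaches earlier steps through backpropagation in time.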

Quick Start

import torch
import cortical.fabric as fabric

graph = fabric.graphs.lattice2d.Graph(
    width=8,
    height=8,
    populations={
        "core": fabric.Population(
            cell=fabric.cells.SLSTM(hidden_dim=32),
        ),
    },
)

blueprint = fabric.Blueprint(
    interface=fabric.Interface(public_dim=32, message_dim=32),
    graph=graph,
    inputs={"tokens": fabric.Input(dim=128)},
    outputs={"prediction": fabric.Output(dim=128)},
    message_passing=fabric.message_rules.DotProduct(head_dim=32),
    execution=fabric.ExecutionSpec(backend="auto", inner_steps=1),
)

model = fabric.compile(blueprint)

x = torch.randn(4, 32, 128)  # [batch, time, external hidden]
y, state = model(x, state=None)

fabric.compile(blueprint) returns a PyTorch module. External dimensions live on Blueprint.inputs and Blueprint.outputs; Fabric projects those tensors into and out of the graph substrate internally.

Visualization With cscope

cortical.scope is the graph and activity visualizer for Fabric. It opens an interactive WebGL viewer for inspecting cell positions, input/output boundary edges, local and patch connectivity, live sculpted lattice parameters, and recorded message-flow traces.

Launch the sculptable viewer:

cscope

Open a saved .cscope bundle:

cscope open path/to/run.cscope

Record and view a model from Python:

import torch
import cortical.scope as cscope

x = torch.randn(2, 16, 128)
run = cscope.record(model, x)
viewer = cscope.show(run)
run.save("run.cscope")

The viewer is served locally and keeps the visualization stack separate from Fabric execution. Python code builds scenes, recordings, and bundles; the browser viewer owns presentation, camera controls, rendering presets, and sculpt controls.

Declaring A Fabric

The main user object is a Blueprint. A blueprint says:

  • what graph to build
  • which external adapters exist
  • what cell type each region uses
  • which top-level message_passing rule cells use across graph edges
  • what dimensions and backend to use
import torch
import cortical.fabric as fabric

graph = fabric.graphs.lattice2d.Graph(
    width=16,
    height=16,
    wrap=True,
    inputs={"tokens": fabric.graphs.lattice2d.XBand("low", width=1)},
    outputs={
        "prediction": fabric.graphs.lattice2d.Output(
            fabric.graphs.lattice2d.XBand("high", width=1),
            aggregate="mean",
        )
    },
    connectivity=[fabric.graphs.lattice2d.LocalRadius(radius=1.5)],
    populations={
        "core": fabric.Population(cell=fabric.cells.SLSTM(hidden_dim=32)),
    },
)

blueprint = fabric.Blueprint(
    interface=fabric.Interface(public_dim=32, message_dim=32),
    graph=graph,
    inputs={"tokens": fabric.Input(dim=128)},
    outputs={"prediction": fabric.Output(dim=128)},
    message_passing=fabric.message_rules.DotProduct(head_dim=32),
    execution=fabric.ExecutionSpec(backend="auto", inner_steps=1),
)

model = fabric.compile(blueprint)

x = torch.randn(8, 16, 128)
y, state = model(x, state=None)

This example creates a 16 by 16 substrate with one input region, one output region, one sLSTM cell population, local connectivity, and dot-product message passing.

message_passing is part of the Blueprint, not part of the graph. The graph says which cells are connected; the message rule says what computation runs across those connections.

Mixed Cell Families

Use multiple populations when regions need different recurrent dynamics.

import cortical.fabric as fabric

graph = fabric.graphs.lattice2d.Graph(
    width=16,
    height=16,
    populations={
        "fast": fabric.Population(
            cell=fabric.cells.SLSTM(hidden_dim=32),
            nodes=fabric.graphs.lattice2d.Region(x=(0.0, 0.5)),
        ),
        "trace": fabric.Population(
            cell=fabric.cells.AxonCell(hidden_dim=32, trace_dim=32),
            nodes=fabric.graphs.lattice2d.Region(x=(0.5, 1.0)),
        ),
    },
)

blueprint = fabric.Blueprint(
    interface=fabric.Interface(public_dim=32, message_dim=32),
    graph=graph,
    inputs={"tokens": fabric.Input(dim=128)},
    outputs={"prediction": fabric.Output(dim=128)},
    message_passing=fabric.message_rules.DotProduct(head_dim=32),
    execution=fabric.ExecutionSpec(backend="auto", inner_steps=1),
)

Multimodal Boundaries

Graphs can name multiple input and output regions for multimodal boundaries.

import cortical.fabric as fabric

height = 16

graph = fabric.graphs.lattice2d.Graph(
    width=16,
    height=height,
    inputs={
        "vision": tuple(range(0, height)),
        "language": tuple(range(height, 2 * height)),
    },
    outputs={
        "policy": fabric.graphs.lattice2d.Output(tuple(range(14 * height, 15 * height))),
        "value": fabric.graphs.lattice2d.Output(tuple(range(15 * height, 16 * height))),
    },
    connectivity=[fabric.graphs.lattice2d.LocalRadius(radius=1.5)],
    populations={
        "core": fabric.Population(cell=fabric.cells.SLSTM(hidden_dim=32)),
    },
)

These names are graph facts. The compiled model still returns the runtime output tensor; output names remain available to backend and higher-level output surfaces.

Graphs Are The Body

The graph is the physical substrate of a Fabric model. It defines:

  • how many cells exist
  • where cells are located
  • which nodes receive input
  • which nodes produce named outputs
  • which nodes are recurrent internal substrate
  • which cells can communicate

The high-level graph constructor included today is fabric.graphs.lattice2d.Graph:

import cortical.fabric as fabric

graph = fabric.graphs.lattice2d.Graph(
    width=16,
    height=16,
    wrap=False,
    inputs={"sensory": fabric.graphs.lattice2d.XBand("low", width=2)},
    outputs={"action": fabric.graphs.lattice2d.Output(fabric.graphs.lattice2d.XBand("high", width=2))},
    connectivity=[fabric.graphs.lattice2d.LocalRadius(radius=2.0)],
    populations={
        "core": fabric.Population(cell=fabric.cells.SLSTM(hidden_dim=32)),
    },
)

Graph constructors can expose geometry-specific helpers. XBand, Region, and Corners make sense for a 2D lattice. A different graph family could expose different primitives while still normalizing to the same facts: nodes, coordinates, input groups, output groups, recurrent nodes, and edges.

For example, a future graph family could describe a mixture of cortical columns: repeated local columns with dense within-column communication and longer fibers between columns. That would still normalize to the same backend facts: node ids, coordinates or metadata, input groups, output groups, recurrent nodes, edges, and edge metadata.

For fully explicit connectivity, pass ExplicitEdges as graph connectivity:

import cortical.fabric as fabric

width = 8
height = 8
node_count = width * height

graph = fabric.graphs.lattice2d.Graph(
    width=width,
    height=height,
    inputs={"tokens": (0, 1, 2, 3)},
    outputs={"prediction": fabric.graphs.lattice2d.Output((60, 61, 62, 63))},
    connectivity=[
        fabric.graphs.lattice2d.ExplicitEdges(
            edges=tuple((receiver, (receiver + 1) % node_count) for receiver in range(node_count)),
        )
    ],
    populations={
        "core": fabric.Population(cell=fabric.cells.SLSTM(hidden_dim=32)),
    },
)

Cell Types

Fabric currently exposes two cell families:

  • cells.SLSTM(...): a stabilized gated recurrent cell. Use it for general-purpose recurrent substrate cells.
  • cells.AxonCell(...): a trace-based recurrent cell inspired by RTRL-style online credit assignment. Use it for streaming settings that need compact traces of recent activity.

The Axon cell is especially relevant for streaming data. It keeps trace state that carries information about recent dynamics and gradient-relevant history, so the cell can support online recurrent behavior without requiring the user to turn the entire past stream into a replayed sequence.
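The kind of compact trace described above can be illustrated with an exponentially decayed summary of recent activity. This is a conceptual sketch only; `update_trace` and its `decay` parameter are not the AxonCell API, and the real cell learns its trace dynamics.

```python
def update_trace(trace, activity, decay=0.9):
    """Illustrative trace update: an exponential moving summary of recent
    activity, carried forward online instead of replaying the past stream."""
    return [decay * t + (1.0 - decay) * a for t, a in zip(trace, activity)]
```

Each step folds the newest activity into the trace in O(trace_dim) work, which is what makes this style of state attractive for unbounded streams.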

Message Passing

Cells communicate through local message passing. A message rule is evaluated from the receiver's point of view: for each cell update, Fabric gathers the public signals from connected sender cells and computes the incoming message consumed by the receiver cell.

The default rule is dot-product local attention:

import cortical.fabric as fabric

message = fabric.message_rules.DotProduct(
    query=fabric.message_rules.ReceiverSlot(),
    key=fabric.message_rules.SenderPublic(),
    value=fabric.message_rules.SenderPublic(),
    num_heads=1,
    head_dim=32,
    kv_sharing=fabric.message_rules.ShareBySenderTile(tile_shape=(8, 8)),
)

ReceiverSlot means the query comes from the cell being updated. SenderPublic means keys and values come from the public outputs of connected neighbor cells. The graph supplies the neighbor list; the message rule supplies the differentiable communication math.

Training And State

Fabric modules are trained the same way as other PyTorch modules:

y, state = model(x, state)
loss = loss_fn(y, target)
loss.backward()
optimizer.step()

The loss can be placed on one output timestep, every output timestep, or any downstream head that consumes Fabric output. The backward pass trains the cells and the message-passing rule together.

For streaming use, callers keep the returned state and pass it into the next call. Reset masks mark stream boundaries.

y1, state = model(x1, state=None)
y2, state = model(x2, state=state)

This lets Fabric serve as a persistent memory system for data streams, sequence datasets, agents, or long-running contexts.
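A reset mask can be pictured as zeroing the state rows of streams that hit a boundary while leaving the others untouched. This sketch uses per-stream Python lists; the name `apply_resets` is illustrative, not the Cortical API.

```python
def apply_resets(state_rows, reset_mask):
    """Zero the state row of each stream marked as reset, keep the rest.

    Mirrors the ZeroSourceRows idea from the backend declarations,
    expressed over a batch of per-stream state rows.
    """
    return [
        [0.0] * len(row) if reset else row
        for row, reset in zip(state_rows, reset_mask)
    ]
```

Only the marked streams restart; unmarked streams keep their carried memory across the call boundary.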

How Fabric Scales And Streams

Fabric exposes structural scaling axes that are different from transformer-style stacks:

  • Batch B: how many independent data streams or sequences run at once.
  • Graph size N: how many cells live in the substrate.
  • Connectivity: how many neighbors each cell can read through message passing.
  • Cell width: how much private state and public signal each cell carries.
  • Cell populations: which recurrent dynamics are assigned to which graph regions.
  • Weight sharing: how parameters are shared across cells, tiles, or populations.

Transformer-style stacks usually scale by token length, depth, width, heads, and experts. Fabric scales by substrate area, local connectivity, cell width, cell families, and sharing policy. Adding cells expands the memory field. Changing the graph changes who can communicate. Increasing degree or message width changes the bandwidth of local communication.

Time T is different. T is the length of the current streaming window, not a new substrate dimension. Increasing T runs the same graph of cells for more steps with carried state. The backend should stream or checkpoint through that window rather than materializing a full [T, cells, state] recurrent surface. Per-step throughput should not drop just because T is larger. In other words, longer windows should mostly mean more steps of the same substrate, not a larger Fabric model.

That is why graph structure is part of the Fabric API. The user describes the substrate; the backend owns the physical routing, batching, workspace, and streaming execution needed to run it efficiently.

Backend Design

Fabric is built around one separation: declarations describe meaning, and the backend decides how that meaning runs.

The user-facing Python API declares a Blueprint: graph structure, input/output regions, cell populations, message_passing, shared dimensions, and backend preference. Normalization turns that into backend graph facts with flat node ids, input/output/recurrent node sets, neighbor tables, edge metadata, population assignments, and sharing groups. From that point on, the backend plans over graph facts. A 2D lattice, a future cortical-column constructor, and explicit graph primitives all normalize to the same kind of backend input.

Cells and message rules have a second declaration layer inside the backend: fabric.cuda.nn. This is the CUDA-side semantic layer where cells and message rules declare computation in backend-understood terms: sources, parameters, projections, reset policies, reduction boundaries, message aggregation, and emitted boundaries.

The lowering flow is:

  • Blueprint: user intent (graph, populations, message rule, dimensions, backend preference).
  • Normalized graph facts: node ids, boundary nodes, recurrent nodes, neighbor tables, edge metadata, population ids, sharing groups.
  • Cell and message IR: semantic math (state sources, public sources, parameters, affines, reductions, reset policy, message aggregation).
  • Step plan: runtime facts (batch size, active receivers, reset mask, state materialization request, streaming window, dtype, device).
  • Physical plan: execution choices (buckets, layouts, source packing, GEMM family, workspace aliases, tape policy, launch metadata).
  • Forward/backward: PyTorch reference or CUDA execution with the same semantic boundaries.

The backend selects execution from shape and structure: batch size, receiver count, cell width, population buckets, source kind, reset policy, weight sharing, degree buckets, graph layout, dtype, device capability, workspace pressure, and streaming-window needs.

Cell Declarations

Most users choose existing cells:

graph = fabric.graphs.lattice2d.Graph(
    width=8,
    height=8,
    populations={
        "core": fabric.Population(cell=fabric.cells.SLSTM(hidden_dim=32)),
    },
)

Internally, a CUDA cell declares its transition through fabric.cuda.nn. The current sLSTM path declares private state, public output, parameters, two affine sources, and a reduction boundary:

fabric::cuda::nn::Builder builder;
builder.private_state("state", 3);           // private recurrent state
builder.public_tensor("public", 3);          // public output neighbors read
builder.parameter("input_gate_weight", 3);
builder.parameter("recurrent_gate_weight", 3);
builder.parameter("input_gate_bias", 2);
// Affine over the incoming projected message; no reset handling needed.
builder.state_affine(
    fabric::cuda::nn::StateAffineSourceKind::ProjectedMessage,
    -1, 0, 6,
    receivers, projected_message_dim, gate_dim, 1,
    fabric::cuda::nn::ResetPolicy::None);
// Affine over previous state; source rows are zeroed at stream boundaries.
builder.state_affine(
    fabric::cuda::nn::StateAffineSourceKind::StatePrev,
    0, 5, -1,
    receivers, hidden, gate_dim, 1,
    fabric::cuda::nn::ResetPolicy::ZeroSourceRows);
builder.reduction_boundary(kReductionStatsDim);  // backend-owned emission work
return builder.build_cell_transition();

This declaration says:

  • the cell has private recurrent state and public output
  • one affine reads the incoming projected_message
  • one affine reads previous state
  • previous-state rows are zeroed on reset
  • a reduction boundary exists for backend-owned state/public emission work

The backend lowers those declarations automatically. For one shape it may pack sources and use a receiver-affine super-op. For another it may bucket populations differently, chunk receivers to fit workspace, or choose a different dense execution family. The cell-local code contains the true recurrent update. Graph routing, message aggregation, public projection, readout, sequence scheduling, and backward execution are planned by the backend.

Add a new cell only when the local recurrent transition itself is new. A real cell addition needs a user-facing declaration, PyTorch reference math, CUDA-side fabric.cuda.nn transition declaration, registration, parity tests, and profiling once correctness is established.

Message Rule Declarations

Message passing is separate from cell recurrence. A graph says which cells are connected. A message rule says how a receiver combines information from its connected senders.

The current public rule is dot-product local attention:

message = fabric.message_rules.DotProduct(
    query=fabric.message_rules.ReceiverSlot(),
    key=fabric.message_rules.SenderPublic(),
    value=fabric.message_rules.SenderPublic(),
    num_heads=1,
    head_dim=32,
    kv_sharing=fabric.message_rules.ShareBySenderTile(tile_shape=(8, 8)),
)

ReceiverSlot means the query comes from the receiver cell. SenderPublic means keys and values come from connected sender cells' public outputs. ShareBySenderTile says parameters are shared over sender groups.

The corresponding CUDA-side message rule also uses fabric.cuda.nn:

// Receiver query slot, sender public signals, and edge metadata.
const int receiver_slot = builder.receiver_source(fabric::cuda::nn::MessageSourceKind::ReceiverSlot);
const int sender_public = builder.sender_source(
    fabric::cuda::nn::MessageSourceKind::SenderPublicPrev,
    fabric::cuda::nn::ResetPolicy::ZeroSourceRows,
    fabric::cuda::nn::ResetScope::BatchRow);
const int edge_distance = builder.edge_source(fabric::cuda::nn::MessageSourceKind::EdgeDistance);
// Q/K/V projections and attention logits.
const int q = builder.linear(receiver_slot, q_weight);
const int k = builder.linear(sender_public, k_weight);
const int v = builder.linear(sender_public, v_weight);
const int logits = builder.dot(q, k);
// Softmax and weighted sum, segmented per receiver neighborhood.
const int weights = builder.segment_softmax(
    builder.add(logits, edge_distance),
    fabric::cuda::nn::MessageSegmentKind::ReceiverNeighborhood);
const int mixed = builder.segment_weighted_sum(
    weights, v, fabric::cuda::nn::MessageSegmentKind::ReceiverNeighborhood);
// Emit the canonical projected_message boundary consumed by cell updates.
builder.emit_projected_message(builder.linear(mixed, out_weight));

This declares receiver data, sender data, edge metadata, projections, segmented softmax, weighted aggregation, and the canonical projected_message boundary. The backend uses the graph and runtime shape to choose dense regular-neighborhood blocks, degree buckets, sparse segmented kernels, grouped leftovers, graph capture, and backward execution.
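The segmented softmax and weighted sum above can be sketched over a flat edge list in plain Python. This is illustrative math, not the segmented kernel: edges are grouped by receiver id, the softmax runs within each receiver's neighborhood, and values are mixed per receiver.

```python
import math

def segment_softmax_sum(logits, values, receiver_ids):
    """Per-receiver softmax over a flat edge list, then a weighted sum
    of edge values. Returns {receiver_id: mixed_message}."""
    out = {}
    for r in set(receiver_ids):
        # Edges belonging to this receiver's neighborhood.
        idx = [i for i, rid in enumerate(receiver_ids) if rid == r]
        # Stable softmax within the segment.
        m = max(logits[i] for i in idx)
        exps = {i: math.exp(logits[i] - m) for i in idx}
        z = sum(exps.values())
        # Weighted sum of the segment's values.
        dim = len(values[idx[0]])
        out[r] = [sum(exps[i] / z * values[i][d] for i in idx)
                  for d in range(dim)]
    return out
```

The backend's job is to run exactly this semantics efficiently, whether the neighborhoods are regular dense blocks or ragged degree buckets.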

Add a new message_rules declaration only when the communication math itself is new. The rule should declare the semantic inputs, parameters, sharing policy, aggregation, reset behavior, and output boundary. The backend should then lower it from graph and shape metadata, the same way cells are lowered from CellTransitionIR.

Automatic Lowering

The same Fabric declaration can produce different physical programs as shapes change:

  • A single population can become one transition bucket; multiple populations become multiple buckets over the same graph.
  • Uniform affine shapes can use batched dense execution; ragged leftovers can use grouped execution.
  • Reset-aware state sources can be packed once and reused by all compatible affine buckets.
  • Regular local neighborhoods can use dense message blocks; irregular neighborhoods can use degree buckets and segmented reductions.
  • Large receiver sets can be chunked to fit workspace without changing cell semantics.
  • T windows can stream through rolling tape or checkpoint/recompute policy without changing the graph.
  • Backward reuses the same semantic boundaries: projected_message, state-affine outputs, raw public values, public projection, and output aggregation.

This is why Fabric can stay generic as it grows. A new cell usually means "new local recurrence." A new message rule usually means "new communication math." Changes to bucketing, tiling, layout, workspace, graph capture, message routing, readout, or backward execution belong in the backend.
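One of the planning choices listed above, chunking large receiver sets to fit workspace, is simple to picture. The helper name `chunk_receivers` is hypothetical; the point is that chunking is a physical-plan decision that leaves cell semantics untouched.

```python
def chunk_receivers(receiver_ids, max_chunk):
    """Split a receiver set into workspace-sized chunks.

    Each chunk is launched separately; the union of chunks covers the
    same receivers, so the computed semantics do not change.
    """
    return [receiver_ids[i:i + max_chunk]
            for i in range(0, len(receiver_ids), max_chunk)]
```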

Stacks

Stacks are Cortical's transformer-style architecture path. Use them for layered sequence backbones rather than spatial recurrent substrates.

The stack API is organized around:

  • Core: the stateful computation unit, such as LSTM, sLSTM, mLSTM, XL, AGaLiTe, or Axon.
  • Scaffold: the wrapper that owns projection, gating, normalization, residual structure, and routing.
  • Cell config: the public preset that pairs a scaffold with a core.
  • Column: one or more expert scaffolds; more than one expert uses routed mixing.
  • Stack: a layered composition of columns or scaffolds.

Stack Quick Start

This creates a two-layer stack. Each layer is a Column. Each Column below has four experts: Axon, XL, mLSTM, and sLSTM. Because each layer lists more than one cell config, the Column combines the experts as a mixture of experts.

import torch
from cortical.stacks import build_cortical_auto_stack
from cortical.stacks.cells import AxonCellConfig, XLCellConfig, mLSTMCellConfig, sLSTMCellConfig

stack = build_cortical_auto_stack(
    d_hidden=256,
    num_layers=2,
    layers=[
        [AxonCellConfig(), XLCellConfig(), mLSTMCellConfig(), sLSTMCellConfig()],
        [AxonCellConfig(), XLCellConfig(), mLSTMCellConfig(), sLSTMCellConfig()],
    ],
)

x = torch.randn(4, 32, 256)  # [batch, time, hidden]
y, state = stack(x, state=None)

For a plain stack without expert mixing, give each layer one cell config:

from cortical.stacks import build_cortical_auto_stack
from cortical.stacks.cells import XLCellConfig

stack = build_cortical_auto_stack(
    d_hidden=256,
    num_layers=2,
    layers=[
        [XLCellConfig()],
        [XLCellConfig()],
    ],
)

This creates two XL layers. Each Column has one expert, so the router and mixer paths are skipped.
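The expert mixing a multi-expert Column performs can be sketched as a soft gate over expert outputs for one token. This is an illustrative mixing rule under the assumption of dense softmax routing; the actual Column router may gate differently (for example, sparsely).

```python
import math

def mix_experts(expert_outputs, gate_logits):
    """Softmax-gated mixture of expert outputs for a single token."""
    # Stable softmax over the gate logits, one per expert.
    m = max(gate_logits)
    exps = [math.exp(g - m) for g in gate_logits]
    z = sum(exps)
    # Blend expert outputs by their gate weights.
    dim = len(expert_outputs[0])
    return [sum(e / z * out[d] for e, out in zip(exps, expert_outputs))
            for d in range(dim)]
```

With a single expert the gate is a constant 1.0, which is why the router and mixer paths can be skipped entirely in the plain-stack case.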

Step mode uses [B, H]:

x_t = torch.randn(4, 256)
state = stack.init_state(batch=4, device=x_t.device, dtype=x_t.dtype)
y_t, state = stack.step(x_t, state)

Built-In Stack Cells

  • AxonCellConfig() (default core AxonCoreConfig): streaming diagonal recurrent memory.
  • XLCellConfig() (default core XLCoreConfig): Transformer-XL style rolling attention memory.
  • mLSTMCellConfig() (default core mLSTMCoreConfig): matrix-LSTM sequence memory.
  • sLSTMCellConfig() (default core sLSTMCoreConfig): structured LSTM with stabilized accumulators.
  • LSTMCellConfig() (default core LSTMCoreConfig): plain recurrent baseline.
  • CausalConv1dCellConfig() (default core CausalConv1dCoreConfig): causal convolutional memory.
  • AGaLiTeCellConfig() (default core AGaLiTeCoreConfig): attention-style recurrent discounted state.

Stacks share the same runtime shape as other stateful modules:

def forward(x, state, *, resets=None):
    return y, next_state

Inputs are [B, T, H] in sequence mode and [B, H] in step mode. State is a TensorDict nested by stack, column, scaffold, and core.
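The nesting of that state can be pictured with plain dicts standing in for a TensorDict. The helper `init_stack_state` and its key names (`layer_0`, `expert_0`, `core`, `hidden`) are illustrative, not the real state schema.

```python
def init_stack_state(num_layers, num_experts, batch):
    """Illustrative stack -> column -> expert scaffold -> core nesting,
    with one zeroed hidden row per batch element."""
    return {
        f"layer_{l}": {
            f"expert_{e}": {
                "core": {"hidden": [[0.0] for _ in range(batch)]},
            }
            for e in range(num_experts)
        }
        for l in range(num_layers)
    }
```

Keeping the nesting keyed by structural unit means a column, scaffold, or core can read and write its own state without knowing where it sits in the stack.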

Package Layout

  • cortical.fabric: Fabric declarations, graphs, cells, message_rules, normalization, and runtime construction
  • cortical.fabric.backend.cuda: Fabric CUDA semantic declarations, kernels, backend operators, and physical lowering
  • cortical.stacks: stack builders, configs, cells, cores, scaffolds, columns, and adapters
  • cortical.scope: Fabric graph visualization, inference activity recording, and architecture sculpting
  • cortical.evaluations: lightweight synthetic tasks for checking memory behavior