A scalable, reasoning-centric neural architecture that unifies Mixture-of-Experts efficiency, dynamic compute allocation, long-context modeling (up to 256K tokens), and built-in self-verification.
⚠️ Research Project: This is an experimental architecture for research purposes. Not production-ready.
- Overview
- Key Features
- Architecture
- Installation
- Quick Start
- Configuration
- Training
- Inference
- Performance
- Advanced Usage
- Project Structure
- Testing
- Contributing
- FAQ
- License
The Adaptive Cognitive Graph (ACG) is a novel neural architecture designed for efficient large-scale language modeling with:
- Sparse Expert Activation: Uses Mixture-of-Experts with dynamic routing to activate only 2-16 experts per token from a pool of 8-64 experts
- Long-Context Processing: Hybrid Transformer + SSM architecture supports up to 256K tokens with linear-time complexity
- Cognitive Phase Routing: Intent-based routing system that determines both cognitive phase and expert selection
- Continual Learning: Per-expert LoRA adapters enable ongoing adaptation without catastrophic forgetting
- Self-Verification: Built-in verifier module assesses output quality and applies corrections when confidence is low
- Scalable Architecture: Designed to scale beyond 1T parameters while maintaining <40B active parameters per token
- Sparse Activation: Only 12-25% of model parameters active per forward pass
- Linear-Time Complexity: SSM blocks enable efficient processing of extreme context lengths
- Expert Parallelism: Distributes experts across GPUs for scalable training
- Dynamic Compute: Adapts computational resources based on input complexity
- Modular Design: Each component can be configured or replaced independently
- Multiple Topologies: Supports phase-grouped, fully-connected, and chain expert graphs
- Self-Verification: Confidence-based quality assessment with automatic correction
- Cross-Expert Communication: DAG topology enables lateral information flow between experts
- Balanced Routing: Entropy-based loss ensures diverse expert utilization
┌─────────────┐
│ Input Tokens│
└──────┬──────┘
│
▼
┌─────────────────┐
│SemanticEncoder │ ◄── Hybrid Transformer + SSM
│(Contextual Rep) │ (Linear-time long context)
└────────┬────────┘
│
├──────────────────────────────────┐
│ │
▼ │ (Residual)
┌─────────────────┐ │
│ IntentRouter │ ◄── Phase + Expert │
│ (Dynamic Route) │ Selection │
└────────┬────────┘ │
│ │
▼ │
┌─────────────────┐ │
│ GraphOfExperts │ ◄── Sparse Activation │
│ (DAG of 8-64) │ (12-25% active) │
└────────┬────────┘ │
│ │
▼ │
┌─────────────────┐ │
│ MemoryAdapter │ ◄── LoRA-based │
│(Continual Learn)│ Adaptation │
└────────┬────────┘ │
│ │
▼ │
┌─────────────────┐ │
│ VerifierExpert │ ◄── Quality Check │
│ (Confidence) │ + Correction │
└────────┬────────┘ │
│ │
▼ │
┌─────────────────┐ │
│OutputFusionLayer│ ◄─────────────────────┘
│ (Final Logits) │
└────────┬────────┘
│
▼
┌─────────────────┐
│ Predictions │
│ + Confidence │
└─────────────────┘
Mathematical Formulation:
For input tokens
# Token Embedding with RoPE
E = Embed(X) + RoPE(positions)
# Hybrid Block (interleaved every ssm_every layers)
for layer in range(n_layers):
if layer % ssm_every == 0:
# State Space Model (Mamba-style)
H = SSM(E, Δ, A, B, C)
else:
# Multi-Head Attention
Q, K, V = W_q·E, W_k·E, W_v·E
Attn = softmax(QK^T / √d_k)V
H = LayerNorm(E + Attn)
# SwiGLU Feed-Forward
FFN = (W_1·H ⊙ σ(W_2·H)) · W_3
E = LayerNorm(H + FFN)
Output: H ∈ ℝ^(B×L×d)
Key Properties:
- RoPE encoding: Rotary position embeddings for better length extrapolation
- SSM blocks: O(L) complexity for long sequences vs O(L²) for attention
- SwiGLU: Gated activation for better gradient flow
Mathematical Formulation:
# Phase Classification
φ = softmax(W_phase · pool(H)) # φ ∈ ℝ^P (P phases)
# Expert Scoring
s = W_expert · H # s ∈ ℝ^(B×L×E) (E experts)
# Phase-Gated Expert Selection
g = φ ⊗ s # Element-wise modulation
top_k_indices, top_k_gates = TopK(g, k=K)
# Normalize gates
G = softmax(top_k_gates) # G ∈ ℝ^(B×L×K)
Output: (expert_ids, gates, phase_probs)
Routing Balance Loss:
# Entropy-based load balancing
p_avg = mean(G, dim=[0,1]) # Average expert usage
H_routing = -Σ p_i · log(p_i)
H_max = log(E)
L_balance = 1 - H_routing / H_max
Mathematical Formulation:
# Expert Processing with DAG
for phase p in range(P):
experts_in_phase = get_experts(phase=p)
for expert e in experts_in_phase:
# Expert computation
h_e = Expert_e(H, routing_gates)
# Cross-edge communication (if enabled)
if enable_cross_edges:
h_e = h_e + Σ α_ij · h_j # α_ij: learned edge weights
outputs[e] = h_e
# Weighted aggregation
H_expert = Σ G_i · h_i # G_i: routing gates
Expert Architecture:
Expert_e(x):
for layer in range(expert_layers):
x = LayerNorm(x + MHA(x))
x = LayerNorm(x + FFN(x))
return x
LoRA Formulation:
# Standard layer: y = Wx
# LoRA-adapted: y = Wx + ΔW·x where ΔW = BA
W_adapted = W_frozen + (B · A) / α
where:
B ∈ ℝ^(d×r) # Down-projection
A ∈ ℝ^(r×d) # Up-projection
r << d # Low rank (4-64)
α: scaling factor (0.1-0.5)
EMA Update:
# After each training step
θ_ema = β · θ_ema + (1-β) · θ_current
where β ∈ [0.85, 0.95] (ema_decay)
Mathematical Formulation:
# Confidence Scoring
z = MLP([H_expert; H_encoder]) # Concatenate representations
c = σ(W_conf · z) # c ∈ [0,1] (confidence)
# Correction Mechanism
if c < threshold:
H_corrected = Corrector(H_expert, H_encoder)
else:
H_corrected = H_expert
Output: (H_corrected, c)
Verification Loss:
L_verify = BCE(c, y_correct)
where y_correct ∈ {0,1} indicates correctness
Mathematical Formulation:
# Per-expert projections
O_i = W_i · H_expert_i # for each active expert
# Confidence-weighted fusion
O_fused = Σ (c_i · G_i · O_i)
# Residual connection
O_final = O_fused + W_residual · H_encoder
# Vocabulary projection
logits = W_vocab · O_final # logits ∈ ℝ^(B×L×V)
Algorithm: ACG Forward Pass
Input: token_ids ∈ ℝ^(B×L)
Output: (logits ∈ ℝ^(B×L×V), confidence ∈ ℝ^B)
1. H_enc = SemanticEncoder(token_ids)
2. (expert_ids, gates, phases) = IntentRouter(H_enc)
3. H_expert = GraphOfExperts(H_enc, expert_ids, gates)
4. H_adapted = MemoryAdapter(H_expert, expert_ids)
5. (H_verified, c) = VerifierExpert(H_adapted, H_enc)
6. logits = OutputFusion(H_verified, c, H_enc)
7. return (logits, c)Total Loss:
L_total = L_lm + α·L_balance + β·L_verify + γ·L_phase
where:
L_lm = CrossEntropy(logits, targets)
L_balance = 1 - H(expert_usage) / log(E)
L_verify = BCE(confidence, correctness)
L_phase = CrossEntropy(phase_logits, phase_labels)
α ∈ [0.001, 0.1] # balance weight
β ∈ [0.01, 0.1] # verify weight
γ ∈ [0.01, 0.1] # phase weight
| Component | Time Complexity | Space Complexity |
|---|---|---|
| Encoder (Attention) | O(L² · d) | O(L² + L·d) |
| Encoder (SSM) | O(L · d²) | O(L · d) |
| Router | O(L · d · E) | O(L · E) |
| Experts (sparse) | O(L · d² · K/E) | O(d² · E) |
| Verifier | O(L · d²) | O(L · d) |
| Fusion | O(L · d · V) | O(d · V) |
| Total (per token) | O(L · d²) | O(L · d + d² · E) |
Active Parameters: ~12-25% of total (K/E ratio)
- Python 3.8+
- PyTorch 2.1+
- CUDA 11.8+ (for GPU training)
- 8GB+ GPU memory (small model) or 32GB+ (medium/large models)
# Clone repository
git clone https://github.com/yourusername/acg.git
cd acg
# Install dependencies
pip install -r requirements.txt
# Install package
pip install -e .torch>=2.1.0
numpy>=1.24.0
tqdm>=4.65.0
Optional dependencies for distributed training:
deepspeed>=0.12.0
transformers>=4.35.0
from acg import ACGConfig, ACGModel
import torch
# Create configuration
config = ACGConfig(
d_model=2048,
n_experts=32,
active_experts=4,
n_layers=8,
max_seq_len=2048
)
# Initialize model
model = ACGModel(config)
# Forward pass
token_ids = torch.randint(0, config.vocab_size, (4, 512))
logits, confidence = model(token_ids)
print(f"Output shape: {logits.shape}") # (4, 512, 50257)
print(f"Confidence: {confidence}") # (4,)from acg import ACGConfig, ACGModel
from acg.training import ACGTrainer
import torch
# Setup
config = ACGConfig()
model = ACGModel(config).cuda()
trainer = ACGTrainer(model, config, use_mixed_precision=True)
# Training step
token_ids = torch.randint(0, config.vocab_size, (8, 512)).cuda()
targets = torch.randint(0, config.vocab_size, (8, 512)).cuda()
metrics = trainer.training_step(token_ids, targets)
print(f"Loss: {metrics['total_loss']:.4f}")Train a model:
python examples/train.py \
--model-size medium \
--batch-size 8 \
--num-epochs 10 \
--checkpoint-dir checkpoints \
--mixed-precisionRun inference:
python examples/inference.py \
--checkpoint checkpoints/best_checkpoint.pt \
--mode generate \
--prompt "Your prompt here" \
--max-length 100Three predefined configurations are available in examples/configs/:
- 8 experts, 2 active per token
- 512 model dimension, 4 layers
- Suitable for single GPU training
- Good for experimentation and prototyping
python examples/train.py --model-size small- 32 experts, 4 active per token
- 2048 model dimension, 8 layers
- Requires 4-8 GPUs for training
- Production-ready configuration
python examples/train.py --model-size medium- 64 experts, 8 active per token
- 4096 model dimension, 16 layers
- Requires 32-128 GPUs for training
- Scales to 1T+ parameters with <40B active
python examples/train.py --model-size largeCreate a custom configuration file:
{
"d_model": 2048,
"n_heads": 32,
"n_layers": 8,
"n_experts": 32,
"active_experts": 4,
"n_phases": 4,
"max_seq_len": 32768,
"use_memory_adapter": true,
"lora_rank": 16,
"verifier_threshold": 0.5
}Load in Python:
import json
from acg import ACGConfig
with open('my_config.json') as f:
config_dict = json.load(f)
config = ACGConfig(**config_dict)| Parameter | Range | Default | Description |
|---|---|---|---|
d_model |
512-4096 | 2048 | Model dimension |
n_heads |
divisor of d_model | 32 | Attention heads |
n_layers |
4-16 | 8 | Encoder layers |
n_experts |
8-64 | 64 | Total experts |
active_experts |
2-16 | 4 | Active experts per token |
n_phases |
2-8 | 4 | Cognitive phases |
max_seq_len |
512-256000 | 256000 | Max context length |
ssm_every |
1-4 | 2 | SSM block interval |
ffn_mult |
2-8 | 4 | FFN dimension multiplier |
dropout |
0.0-0.3 | 0.1 | Dropout rate |
lora_rank |
4-64 | 16 | LoRA adapter rank |
verifier_threshold |
0.3-0.7 | 0.5 | Confidence threshold |
See acg/config.py for complete parameter list and validation rules.
The examples/train.py script provides a complete training pipeline:
python examples/train.py \
--config examples/configs/medium.json \
--batch-size 16 \
--num-epochs 10 \
--learning-rate 1e-4 \
--checkpoint-dir checkpoints \
--log-dir logs \
--save-every 1000 \
--log-every 100 \
--mixed-precision- Mixed Precision Training: Automatic fp16/bf16 support with loss scaling
- Gradient Checkpointing: Reduces memory usage for large models
- Checkpoint Management: Automatic saving of best and latest checkpoints
- Metrics Logging: Comprehensive tracking of loss components and diagnostics
- Resume Training: Load from checkpoint to continue training
The training objective combines multiple loss terms:
- Language Modeling Loss: Cross-entropy on next token prediction
- Routing Balance Loss: Entropy-based expert utilization (weight: 0.001-0.1)
- Verification Loss: Binary cross-entropy on correctness (weight: 0.01-0.1)
- Phase Classification Loss: Optional phase supervision (weight: 0.01-0.1)
Total loss: L = L_lm + α*L_balance + β*L_verify + γ*L_phase
For multi-GPU training, use DeepSpeed or PyTorch FSDP:
# DeepSpeed
deepspeed --num_gpus=8 examples/train.py \
--model-size large \
--deepspeed ds_config.json
# PyTorch FSDP
torchrun --nproc_per_node=8 examples/train.py \
--model-size large \
--use-fsdpThe examples/inference.py script supports multiple inference modes:
python examples/inference.py \
--checkpoint checkpoints/best_checkpoint.pt \
--mode generate \
--prompt "Once upon a time" \
--max-length 200 \
--temperature 0.8 \
--top-k 50 \
--top-p 0.9python examples/inference.py \
--checkpoint checkpoints/best_checkpoint.pt \
--mode evaluate \
--input-file eval_data.txt \
--batch-size 8 \
--output-file results.jsonpython examples/inference.py \
--checkpoint checkpoints/best_checkpoint.pt \
--mode interactive- torch.compile: Use
--compileflag for 2x speedup (PyTorch 2.0+) - Batch Processing: Process multiple prompts efficiently
- Mixed Precision: Automatic fp16/bf16 inference
- KV Cache: (Future) Cache attention keys/values for faster generation
| Model Size | Total Params | Active Params | Memory (bf16) | Throughput* |
|---|---|---|---|---|
| Small | ~50M | ~15M | ~200MB | ~5K tok/s |
| Medium | ~7B | ~1.5B | ~14GB | ~2K tok/s |
| Large | ~100B | ~20B | ~200GB | ~500 tok/s |
*Throughput measured on A100 80GB with batch size optimized for hardware
Thanks to SSM blocks, ACG maintains linear complexity for long contexts:
| Context Length | Memory | Latency |
|---|---|---|
| 2K tokens | 1x | 1x |
| 8K tokens | 1.2x | 1.3x |
| 32K tokens | 1.8x | 2.1x |
| 128K tokens | 3.5x | 4.2x |
| 256K tokens | 6.8x | 8.5x |
Compare to standard Transformer O(n²) complexity which would be 128x slower at 256K tokens.
The training script automatically tracks:
- Loss Components: LM loss, balance loss, verify loss, phase loss
- Routing Metrics: Expert utilization, routing entropy, phase distribution
- Verifier Metrics: Mean confidence, correction rate
- Training Metrics: Gradient norm, learning rate, throughput
- Memory Metrics: Peak memory usage, active parameters
Metrics are saved to JSON files in the log directory:
import json
import matplotlib.pyplot as plt
# Load metrics
with open('logs/metrics.json') as f:
metrics = json.load(f)
# Plot loss curve
plt.plot(metrics['steps'], metrics['total_loss'])
plt.xlabel('Training Steps')
plt.ylabel('Total Loss')
plt.savefig('loss_curve.png')from acg import ACGConfig, ACGModel
config = ACGConfig(
n_experts=32,
active_experts=4,
dag_topology='fully_connected', # or 'phase_grouped', 'chain'
enable_cross_edges=True
)
model = ACGModel(config)config = ACGConfig(
use_memory_adapter=True,
lora_rank=32, # Higher rank = more capacity
lora_alpha=0.3, # Scaling factor
ema_decay=0.92 # EMA smoothing
)config = ACGConfig(
verifier_threshold=0.5, # Confidence threshold
verifier_hidden=1024, # Hidden dimension
correction_type='transformer' # or 'mlp'
)acg/
├── __init__.py # Package initialization
├── config.py # Configuration dataclass with validation
├── model.py # Main ACG model
├── training.py # Training utilities and loss functions
├── models/ # Core model components
│ ├── __init__.py
│ ├── base.py # Base module interface
│ ├── encoder.py # SemanticEncoder
│ ├── router.py # IntentRouter
│ ├── experts.py # GraphOfExperts
│ ├── adapter.py # MemoryAdapter
│ ├── verifier.py # VerifierExpert
│ └── fusion.py # OutputFusionLayer
└── utils/ # Utility functions
├── __init__.py
├── routing.py # RoutingMap data structure
├── validation.py # Input validation utilities
├── metrics.py # Metrics tracking
├── logging.py # Logging setup
├── deployment.py # Deployment utilities
├── deepspeed_integration.py # DeepSpeed integration
└── parallelism.py # Parallelism utilities
examples/
├── train.py # Training script
├── inference.py # Inference script
└── configs/ # Predefined configurations
├── small.json
├── medium.json
└── large.json
If you use ACG in your research, please cite:
@article{acg2025,
title={Adaptive Cognitive Graph: A Scalable Architecture for Reasoning-Centric Language Models},
author={afif amir},
year={2025}
}A: Yes! Use the small config (~50M params) with CPU mode. See the "Quick Start" section.
A: ACG uses sparse expert activation (12-25% of params active) and SSM blocks for linear-time long-context processing, making it more efficient for large-scale models.
A: Yes, under MIT License. However, this is experimental research code - use at your own risk.
A: See the Contributing section. We welcome bug reports, feature requests, and pull requests!
A: Pre-trained weights are not yet available. You'll need to train from scratch using the provided scripts.
- DeepSpeed integration requires additional testing
- Flash Attention support is experimental
- Very long contexts (>128K) may require gradient checkpointing
- Expert load balancing may need tuning for specific datasets
- Complete implementation of all 6 core components
- Training and inference scripts
- Three predefined model sizes
- Comprehensive documentation
- Unit test coverage
- Initial architecture design
- Configuration system
- Basic component interfaces
Apache License 2.0
Copyright 2025 afif amir
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
See the LICENSE file for the full license text.
This architecture builds on ideas from:
- Mixture-of-Experts (Shazeer et al., 2017)
- State Space Models (Gu et al., 2021)
- LoRA (Hu et al., 2021)
- Self-Verification in LLMs (Madaan et al., 2023)
- Documentation: See this README and inline code documentation
- Issues: GitHub Issues - Report bugs or request features
- Discussions: GitHub Discussions - Ask questions and share ideas
- Pull Requests: Contributions welcome! See Contributing section
- Check the FAQ section above
- Search existing GitHub issues
- Review the example scripts in
examples/ - Open a new issue with detailed information
If you find this project useful, please consider giving it a ⭐ on GitHub!