Skip to content

Commit 3bcde1a

Browse files
feat: pure-JS LSTM/GRU + walk-forward + HMM regime [3.8.0]
1 parent 8a9b9a2 commit 3bcde1a

11 files changed

Lines changed: 897 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,24 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [3.8.0] - 2026-05-03
9+
10+
### Added
11+
- **Pure-JS RNN cells** (`src/ml/lstm.ts`, `src/ml/gru.ts`):
12+
- `LSTMCell` with `[i, f, g, o]` gate ordering, weights `Wi[4H×N]`, `Wh[4H×H]`, `b[4H]`.
13+
- `GRUCell` with `[r, z, n]` gate ordering, weights `Wi`, `Wh`, `bi`, `bh`.
14+
- Forward pass only — load pre-trained weights from Python/JAX for inference.
15+
- **Walk-forward validation** (`src/ml/walk-forward.ts`):
16+
- `walkForward(X, y, cfg)` with `expanding` and `rolling` modes.
17+
- Per-fold MSE/MAE/R² plus combined out-of-sample predictions.
18+
- **Feature engineering** (`src/ml/feature-engineer.ts`):
19+
- `lagFeatures`, `rollingMean`, `rollingStd`, `logReturns`, `simpleReturns`, `zScore`, `minMaxScale`, `diff`.
20+
- **HMM regime detection** (`src/ml/hmm-regime.ts`):
21+
- `trainHMM` — Gaussian-emission Baum-Welch with scaled forward-backward.
22+
- `viterbi` — log-domain decoding.
23+
- Docs: `docs/ML.md`.
24+
- Tests: `tests/ml.test.ts` (10 tests, including 2-state regime decode >70% accuracy).
25+
826
## [3.7.0] - 2026-05-03
927

1028
### Added

docs/ML.md

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Machine Learning (Pure-JS)
2+
3+
Zero-dependency ML primitives for time-series forecasting and regime detection.
4+
No native bindings, no tfjs — runs in Node, Deno, Bun, browsers.
5+
6+
## Modules
7+
8+
| Module | Purpose |
9+
|--------|---------|
10+
| `LSTMCell` / `GRUCell` | Forward-pass RNN cells |
11+
| `walkForward` | Time-series cross-validation |
12+
| Feature engineering | Lags, rolling stats, returns, scaling |
13+
| `trainHMM` / `viterbi` | Gaussian HMM regime detection |
14+
15+
## LSTM / GRU Forward Pass
16+
17+
Pre-trained weights only (no backprop). Useful for deploying models trained
18+
elsewhere (Python, JAX) into pure-JS runtimes.
19+
20+
```typescript
21+
import { LSTMCell, randomLSTMWeights } from 'meridianalgo';
22+
23+
const inputSize = 5;
24+
const hiddenSize = 16;
25+
const cell = new LSTMCell(randomLSTMWeights(inputSize, hiddenSize));
26+
27+
const sequence = [/* ... [number[]] ... */];
28+
const { h, c } = cell.forward(sequence);
29+
// h: final hidden state (length hiddenSize)
30+
```
31+
32+
GRU API mirrors LSTM, returns only `h`.
33+
34+
Gate ordering:
35+
- LSTM: `[i, f, g, o]` (input, forget, candidate, output)
36+
- GRU: `[r, z, n]` (reset, update, candidate)
37+
38+
Weight shapes:
39+
- LSTM: `Wi[4H × N]`, `Wh[4H × H]`, `b[4H]`
40+
- GRU: `Wi[3H × N]`, `Wh[3H × H]`, `bi[3H]`, `bh[3H]`
41+
42+
## Walk-Forward Validation
43+
44+
Time-series CV that respects causality. Two modes:
45+
46+
```typescript
47+
import { walkForward } from 'meridianalgo';
48+
49+
const result = walkForward(X, y, {
50+
mode: 'expanding', // or 'rolling'
51+
initialTrainSize: 100,
52+
testSize: 20,
53+
step: 20,
54+
fit: (Xtrain, ytrain) => trainModel(Xtrain, ytrain),
55+
predict: (model, Xtest) => model.predict(Xtest),
56+
});
57+
58+
result.folds; // per-fold {predictions, actual, mse, mae, rSquared}
59+
result.combinedPredictions; // concatenated out-of-sample predictions
60+
result.meanMse;
61+
result.meanMae;
62+
```
63+
64+
- **Expanding**: train window grows each fold (`[0, end)`).
65+
- **Rolling**: train window slides at fixed size.
66+
67+
## Feature Engineering
68+
69+
```typescript
70+
import {
71+
lagFeatures, rollingMean, rollingStd,
72+
logReturns, simpleReturns,
73+
zScore, minMaxScale, diff,
74+
} from 'meridianalgo';
75+
76+
const lags = lagFeatures(prices, [1, 5, 10]); // matrix [n × 3]
77+
const ma = rollingMean(prices, 20);
78+
const sd = rollingStd(prices, 20);
79+
const r = logReturns(prices);
80+
const z = zScore(values); // (x - μ)/σ
81+
const scaled = minMaxScale(values); // [0, 1]
82+
```
83+
84+
NaN-padded arrays preserve index alignment with original series.
85+
86+
## HMM Regime Detection
87+
88+
Gaussian-emission HMM with Baum-Welch training (forward-backward + scaling)
89+
and Viterbi decoding in log-domain.
90+
91+
```typescript
92+
import { trainHMM, viterbi } from 'meridianalgo';
93+
94+
const observations = returns; // 1-D series
95+
const k = 2; // num regimes (e.g. bull/bear)
96+
97+
const { params, logLik, iterations } = trainHMM(observations, k, {
98+
maxIter: 100,
99+
tol: 1e-4,
100+
});
101+
102+
const states = viterbi(observations, params);
103+
// states[t] ∈ {0, 1, ..., k-1}
104+
```
105+
106+
`HMMParams`:
107+
- `pi[k]` — initial state distribution
108+
- `A[k][k]` — transition matrix
109+
- `mu[k]`, `sigma[k]` — emission Gaussian params
110+
111+
## Limitations
112+
113+
- No autograd / training for RNN cells (forward-only).
114+
- HMM Gaussian emissions are univariate scalar.
115+
- For heavy training workloads use Python — load weights here for inference.
116+
117+
See also: `INDICATORS-PATTERNS.md` for streaming indicators that pair well
118+
with online ML inference.

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "meridianalgo",
3-
"version": "3.7.0",
3+
"version": "3.8.0",
44
"description": "Professional-grade quantitative finance framework for JavaScript/TypeScript - algorithmic trading, backtesting, risk management, and portfolio optimization",
55
"main": "dist/index.js",
66
"module": "dist/index.mjs",

src/index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,5 +72,8 @@ export * from './indicators/range-vol';
7272
// Microstructure (order book, spread estimators, market impact)
7373
export * from './microstructure';
7474

75+
// Machine learning (LSTM/GRU forward pass, walk-forward, features, HMM regimes)
76+
export * from './ml';
77+
7578
// CLI (not exported as class usually, but internal)
7679
// export * from './cli';

src/ml/feature-engineer.ts

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/**
2+
* Feature engineering helpers — lag features, rolling statistics, returns,
3+
* z-score normalization, and standard time-series transforms.
4+
*/
5+
6+
/** Build lag features: [x_{t-1}, x_{t-2}, ..., x_{t-K}]. */
7+
export function lagFeatures(xs: readonly number[], lags: number): number[][] {
8+
if (lags <= 0) throw new Error('lagFeatures: lags must be > 0');
9+
const out: number[][] = [];
10+
for (let i = lags; i < xs.length; i++) {
11+
const row: number[] = [];
12+
for (let k = 1; k <= lags; k++) row.push(xs[i - k]);
13+
out.push(row);
14+
}
15+
return out;
16+
}
17+
18+
/** Rolling mean. */
19+
export function rollingMean(xs: readonly number[], n: number): number[] {
20+
const out: number[] = new Array(xs.length).fill(NaN);
21+
let sum = 0;
22+
for (let i = 0; i < xs.length; i++) {
23+
sum += xs[i];
24+
if (i >= n) sum -= xs[i - n];
25+
if (i >= n - 1) out[i] = sum / n;
26+
}
27+
return out;
28+
}
29+
30+
/** Rolling standard deviation. */
31+
export function rollingStd(xs: readonly number[], n: number): number[] {
32+
const out: number[] = new Array(xs.length).fill(NaN);
33+
for (let i = n - 1; i < xs.length; i++) {
34+
let mean = 0;
35+
for (let j = i - n + 1; j <= i; j++) mean += xs[j];
36+
mean /= n;
37+
let v = 0;
38+
for (let j = i - n + 1; j <= i; j++) v += (xs[j] - mean) ** 2;
39+
out[i] = Math.sqrt(v / (n - 1));
40+
}
41+
return out;
42+
}
43+
44+
/** Log returns. */
45+
export function logReturns(prices: readonly number[]): number[] {
46+
const out: number[] = new Array(prices.length).fill(NaN);
47+
for (let i = 1; i < prices.length; i++) out[i] = Math.log(prices[i] / prices[i - 1]);
48+
return out;
49+
}
50+
51+
/** Simple returns. */
52+
export function simpleReturns(prices: readonly number[]): number[] {
53+
const out: number[] = new Array(prices.length).fill(NaN);
54+
for (let i = 1; i < prices.length; i++) out[i] = prices[i] / prices[i - 1] - 1;
55+
return out;
56+
}
57+
58+
/** Z-score normalization (population stats). */
59+
export function zScore(xs: readonly number[]): number[] {
60+
const n = xs.length;
61+
if (n === 0) return [];
62+
const mean = xs.reduce((s, v) => s + v, 0) / n;
63+
let v = 0;
64+
for (const x of xs) v += (x - mean) ** 2;
65+
const sd = Math.sqrt(v / n);
66+
if (sd === 0) return new Array(n).fill(0);
67+
return xs.map((x) => (x - mean) / sd);
68+
}
69+
70+
/** Min-max normalization to [0, 1]. */
71+
export function minMaxScale(xs: readonly number[]): number[] {
72+
if (xs.length === 0) return [];
73+
let lo = Infinity;
74+
let hi = -Infinity;
75+
for (const x of xs) {
76+
if (x < lo) lo = x;
77+
if (x > hi) hi = x;
78+
}
79+
const r = hi - lo;
80+
if (r === 0) return xs.map(() => 0);
81+
return xs.map((x) => (x - lo) / r);
82+
}
83+
84+
/** First difference (xs[i] - xs[i-1]). */
85+
export function diff(xs: readonly number[], lag = 1): number[] {
86+
const out: number[] = new Array(xs.length).fill(NaN);
87+
for (let i = lag; i < xs.length; i++) out[i] = xs[i] - xs[i - lag];
88+
return out;
89+
}

src/ml/gru.ts

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/**
2+
* Pure-JS GRU cell — forward-pass inference only.
3+
*/
4+
5+
export interface GRUWeights {
6+
/** input-to-hidden: shape [3*hiddenSize, inputSize]. Order: [r, z, n]. */
7+
Wi: number[][];
8+
/** hidden-to-hidden: shape [3*hiddenSize, hiddenSize]. */
9+
Wh: number[][];
10+
/** input bias: shape [3*hiddenSize]. */
11+
bi: number[];
12+
/** hidden bias: shape [3*hiddenSize]. */
13+
bh: number[];
14+
}
15+
16+
const sigmoid = (x: number) => 1 / (1 + Math.exp(-x));
17+
const tanh = Math.tanh;
18+
19+
function matVec(M: readonly (readonly number[])[], v: readonly number[]): number[] {
20+
const out = new Array(M.length).fill(0);
21+
for (let i = 0; i < M.length; i++) {
22+
let s = 0;
23+
const row = M[i];
24+
for (let j = 0; j < v.length; j++) s += row[j] * v[j];
25+
out[i] = s;
26+
}
27+
return out;
28+
}
29+
30+
export class GRUCell {
31+
readonly hiddenSize: number;
32+
readonly inputSize: number;
33+
34+
constructor(public readonly weights: GRUWeights) {
35+
const threeH = weights.bi.length;
36+
if (threeH % 3 !== 0) throw new Error('GRUCell: bias length must be multiple of 3');
37+
this.hiddenSize = threeH / 3;
38+
this.inputSize = weights.Wi[0].length;
39+
}
40+
41+
step(x: readonly number[], h: readonly number[]): number[] {
42+
const H = this.hiddenSize;
43+
const xi = matVec(this.weights.Wi, x);
44+
const hi = matVec(this.weights.Wh, h);
45+
const r: number[] = new Array(H);
46+
const z: number[] = new Array(H);
47+
const n: number[] = new Array(H);
48+
for (let k = 0; k < H; k++) {
49+
r[k] = sigmoid(xi[k] + this.weights.bi[k] + hi[k] + this.weights.bh[k]);
50+
z[k] = sigmoid(xi[k + H] + this.weights.bi[k + H] + hi[k + H] + this.weights.bh[k + H]);
51+
n[k] = tanh(xi[k + 2 * H] + this.weights.bi[k + 2 * H] + r[k] * (hi[k + 2 * H] + this.weights.bh[k + 2 * H]));
52+
}
53+
const out: number[] = new Array(H);
54+
for (let k = 0; k < H; k++) out[k] = (1 - z[k]) * n[k] + z[k] * h[k];
55+
return out;
56+
}
57+
58+
forward(sequence: readonly (readonly number[])[]): number[] {
59+
let h = new Array(this.hiddenSize).fill(0);
60+
for (const x of sequence) h = this.step(x, h);
61+
return h;
62+
}
63+
}
64+
65+
export function randomGRUWeights(inputSize: number, hiddenSize: number, scale = 0.1, seed = 1): GRUWeights {
66+
let s = seed;
67+
const rng = () => {
68+
s = (s * 1664525 + 1013904223) | 0;
69+
return ((s >>> 0) / 0x100000000 - 0.5) * 2;
70+
};
71+
const threeH = 3 * hiddenSize;
72+
const Wi: number[][] = Array.from({ length: threeH }, () =>
73+
Array.from({ length: inputSize }, () => rng() * scale),
74+
);
75+
const Wh: number[][] = Array.from({ length: threeH }, () =>
76+
Array.from({ length: hiddenSize }, () => rng() * scale),
77+
);
78+
return { Wi, Wh, bi: new Array(threeH).fill(0), bh: new Array(threeH).fill(0) };
79+
}

0 commit comments

Comments
 (0)