MNIST-CNN Integration Module — PRD (mnist-int)
Summary
Create a standalone, reusable Gradle/Kotlin Multiplatform module that makes it trivial for apps to load and run the pre-trained MNIST CNN model. The module abstracts model format IO via io-core, uses io-gguf for GGUF parameter loading (initial implementation), embeds a pre-trained GGUF binary as a resource, and exposes a coroutine/Flow-based API for lifecycle and inference. This enables future swaps of model storage formats without changing the consuming app code.
Problem Statement
- Apps currently need to know about model loading, resource handling, and execution details for MNIST CNN.
- Model format/loader dependencies leak into apps, increasing coupling.
- No consistent lifecycle or shared-instance pattern for loading and executing the model across app features.
Goals
- Provide an app-friendly module that:
  - Loads a pre-trained MNIST CNN model with progress reporting.
  - Executes inference on 1x28x28 grayscale images (or batched inputs) and returns logits/probabilities.
  - Offers a Flow-based shared model lifecycle: loading → ready → running → unloaded → error.
  - Hides the model format and loader (GGUF initially), enabling future replacements via the `io-core` abstraction.
- Works in multiplatform settings with platform-appropriate resource loading.
Non-Goals
- Training the model or providing training utilities in this module.
- General-purpose CNN abstractions beyond what is already in `skainet-lang-models`.
- Full-featured image pre/post-processing pipelines (provide minimal helpers only).
Users
- App developers in `skainet-apps` (e.g., the KGPChat demo or other apps) who need MNIST classification without dealing with model internals.
Current Capabilities and Architecture (as-is)
- `skainet-lang-models` defines MNIST CNN network structures (`sk/ainet/lang/model/dnn/cnn/MNIST.kt`) and helpers (e.g., grayscale conversion).
- `skainet-io-gguf` provides GGUF parameter loading; the project hints at a progress-capable loader.
- The `io-core` abstraction exists (in `skainet-io`) to unify parameter load sources independently of file format.
- Prior RFC/notes in `mnist-gguf.md` outline a Shared/Flow-based lifecycle and resource packaging approach.
Gaps:
- No dedicated integration module that:
  - Bundles the model file as a resource.
  - Wires the CNN definition to the GGUF loader via `io-core`.
  - Exposes a stable, simple API surface for apps.
Proposal: New Module `skainet-int/skainet-int-mnist-cnn`
A KMP module delivering a single responsibility: load and run the pre-trained MNIST CNN with minimal setup, using Flow to expose lifecycle, and abstracting IO.
High-level Architecture
```mermaid
flowchart LR
    A[App/UI Layer] -->|observe state, call run| B[MnistCnnShared]
    B -->|build network| C[MnistCnnModule]
    C -->|load params via| D[io-core Loader API]
    D -->|GGUF impl| E[io-gguf]
    C -->|ops| F[ExecutionContext]
    C -->|model def| G[skainet-lang-models]
    C -->|resource bytes| H[(embedded gguf)]
```
- `MnistCnnShared` provides a coroutine-safe shared instance for model loading and inference, exposing `StateFlow<MnistCnnSharedState>`.
- `MnistCnnModule` encapsulates the assembled network and forward pass.
- `io-core` defines the loader interface (format-agnostic); `io-gguf` is the initial concrete implementation.
- The GGUF file is embedded as a resource; platform-specific resource access is abstracted.
Module Boundaries
- Public API: `MnistCnnShared`, `MnistCnnSharedState`, minimal helpers for input shaping.
- Internal: `MnistCnnModule` assembly, resource accessors, loader wiring.
- Dependencies:
  - `skainet-lang-models` for MNIST CNN topology and tensor types.
  - `skainet-io:io-core` for loader abstractions.
  - `skainet-io:io-gguf` for the GGUF implementation (runtime dep; swappable later).
  - Kotlin coroutines/Flow.
Detailed Design
Public API (commonMain)
```kotlin
package sk.ainet.int.mnist

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.StateFlow
import sk.ainet.lang.tensor.*
import sk.ainet.lang.dtype.FP32
import sk.ainet.lang.exec.ExecutionContext

sealed interface MnistCnnSharedState {
    data object Unloaded : MnistCnnSharedState
    data class Loading(val current: Long, val total: Long, val message: String?) : MnistCnnSharedState
    data class Ready(val module: MnistCnnModule) : MnistCnnSharedState
    data class Running(val current: Long, val total: Long, val message: String?) : MnistCnnSharedState
    data class Error(val throwable: Throwable) : MnistCnnSharedState
}

class MnistCnnShared(
    private val scope: CoroutineScope,
    private val execContext: ExecutionContext,
    private val loader: ModelParameterLoader = DefaultModelParameterLoader // from io-core
) {
    val state: StateFlow<MnistCnnSharedState>

    suspend fun load(resourcePath: String = DEFAULT_GGUF_PATH)
    suspend fun run(input: Tensor<FP32, Float>): Tensor<FP32, Float>
    suspend fun unload()

    companion object {
        const val DEFAULT_GGUF_PATH: String = "/models/mnist/mnist-cnn-f32.gguf"
    }
}
```

- The `loader` is an `io-core` interface; the default implementation delegates to `io-gguf`.
- `ExecutionContext` is passed in to align with the app's backend selection.
Internal Assembly
```kotlin
// internal
class MnistCnnModule(
    private val net: MnistCnn, // network from skainet-lang-models
    private val execContext: ExecutionContext
) {
    suspend fun forward(x: Tensor<FP32, Float>): Tensor<FP32, Float> = net.forward(x, execContext)
}

// internal resource access (common expect API)
expect fun readResourceBytes(path: String): ByteArray
```

Platform-specific `actual` implementations are provided in `jvmMain`, `androidMain`, `iosMain`, etc.
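For illustration, a JVM-side `actual` might resolve the path against the classpath. This is a hedged sketch, not the repo's actual code; the classloader-based lookup and the error message are assumptions:

```kotlin
// jvmMain — hypothetical sketch of the actual resource accessor.
// Resolves an absolute classpath path such as "/models/mnist/mnist-cnn-f32.gguf".
actual fun readResourceBytes(path: String): ByteArray =
    object {}.javaClass.getResourceAsStream(path)
        ?.use { it.readBytes() }
        ?: throw IllegalStateException("Resource not found on classpath: $path")
```

Failing loudly here lets `MnistCnnShared` map the exception to the `Error` state instead of surfacing a null somewhere downstream.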
Loader Abstraction
```kotlin
// io-core (existing in skainet-io)
interface ModelParameterLoader {
    suspend fun load(
        target: ParameterTarget,
        bytes: ByteArray,
        onProgress: (current: Long, total: Long, message: String?) -> Unit = { _, _, _ -> }
    )
}
```

The module calls `loader.load(target = net, bytes = resourceBytes, onProgress = ...)`; `io-gguf` implements this interface for GGUF.
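To make the contract concrete, here is a minimal in-memory fake that exercises the progress callback (handy for unit tests). The interface is repeated and `ParameterTarget` stubbed as an empty interface so the sketch is self-contained; the real shapes live in `io-core`, so treat these as assumptions:

```kotlin
// Stand-ins for the io-core types (assumptions for this sketch).
interface ParameterTarget

interface ModelParameterLoader {
    suspend fun load(
        target: ParameterTarget,
        bytes: ByteArray,
        onProgress: (current: Long, total: Long, message: String?) -> Unit = { _, _, _ -> }
    )
}

// A fake loader that "loads" in fixed-size chunks and reports monotonic progress.
class FakeParameterLoader(private val chunk: Int = 1024) : ModelParameterLoader {
    override suspend fun load(
        target: ParameterTarget,
        bytes: ByteArray,
        onProgress: (Long, Long, String?) -> Unit
    ) {
        val total = bytes.size.toLong()
        var current = 0L
        while (current < total) {
            current = minOf(current + chunk, total)
            onProgress(current, total, "loading")
        }
    }
}
```

A fake like this lets unit tests assert the `current`-increases-to-`total` property without touching real GGUF bytes.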
State and Concurrency
- Use a `Mutex` to guard `load`/`run`/`unload` and prevent concurrent mutation.
- Emit progress via `StateFlow<MnistCnnSharedState>`.
- Always restore `Ready` after `run`, or set `Error` if an exception occurs.
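The guard pattern can be sketched end to end. This is a self-contained simplification with placeholder types (`FloatArray` instead of tensors, a generic `M` instead of `MnistCnnModule`), not the module's actual implementation:

```kotlin
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock

// Simplified state machine mirroring MnistCnnSharedState.
sealed interface State {
    data object Unloaded : State
    data class Loading(val current: Long, val total: Long) : State
    data class Ready<M>(val module: M) : State
    data object Running : State
    data class Error(val throwable: Throwable) : State
}

class SharedSketch<M : Any>(
    private val loadImpl: suspend ((Long, Long) -> Unit) -> M,
    private val runImpl: suspend (M, FloatArray) -> FloatArray
) {
    private val mutex = Mutex()
    private val _state = MutableStateFlow<State>(State.Unloaded)
    val state: StateFlow<State> = _state
    private var module: M? = null

    suspend fun load() = mutex.withLock {
        if (module != null) return@withLock // already loaded: no-op
        try {
            _state.value = State.Loading(0, 0)
            val m = loadImpl { cur, total -> _state.value = State.Loading(cur, total) }
            module = m
            _state.value = State.Ready(m)
        } catch (t: Throwable) {
            _state.value = State.Error(t)
            throw t
        }
    }

    suspend fun run(input: FloatArray): FloatArray = mutex.withLock {
        val m = module ?: error("Model not loaded")
        _state.value = State.Running
        try {
            runImpl(m, input)
        } finally {
            _state.value = State.Ready(m) // always restore Ready after run
        }
    }

    suspend fun unload() = mutex.withLock {
        module = null
        _state.value = State.Unloaded
    }
}
```

The single `Mutex` serializes `load`/`run`/`unload`, so observers of `state` never see interleaved transitions.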
Resource Packaging
- Place the model file at `src/commonMain/resources/models/mnist/mnist-cnn-f32.gguf`.
- The file is included in published artifacts on JVM/Android and bundled appropriately for other targets.
- Provide tools script or task to validate presence and checksum of resource.
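The checksum step can be a few lines on the JVM; a sketch using `MessageDigest` (the path and expected digest in the usage function are placeholders, not real values):

```kotlin
import java.io.File
import java.security.MessageDigest

// Compute the SHA-256 hex digest of a byte array.
fun sha256Hex(bytes: ByteArray): String =
    MessageDigest.getInstance("SHA-256").digest(bytes)
        .joinToString("") { "%02x".format(it) }

// Hypothetical verification entry point for a build task.
fun verifyModel(path: String, expected: String) {
    val actual = sha256Hex(File(path).readBytes())
    require(actual == expected) { "Checksum mismatch for $path: got $actual" }
}
```

Wiring `verifyModel` into a Gradle check task would fail the build early if the embedded GGUF is missing or corrupt.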
Example Usage in an App
```kotlin
class MyViewModel(
    private val execContext: ExecutionContext,
    private val scope: CoroutineScope
) {
    private val shared = MnistCnnShared(scope, execContext)
    val state: StateFlow<MnistCnnSharedState> = shared.state

    suspend fun init() = shared.load()
    suspend fun classify(x: Tensor<FP32, Float>) = shared.run(x)
}
```

Module Layout and Gradle
- Module path: `skainet-int/skainet-int-mnist-cnn`
- Source sets: `commonMain`, `commonTest`, `jvmMain`, `androidMain`, `iosMain`, `jsMain` (as needed)
Gradle (KMP) sketch:
```kotlin
plugins {
    kotlin("multiplatform")
    `maven-publish`
}

kotlin {
    jvm()
    androidTarget() // if Android
    ios()
    js(IR) // optional
    sourceSets {
        val commonMain by getting {
            dependencies {
                api(projects.skainetLang.skainetLangModels)
                implementation(projects.skainetIo.ioCore)
                implementation(projects.skainetIo.ioGguf) // runtime impl
                implementation(libs.coroutines.core)
            }
        }
        val commonTest by getting {
            dependencies { implementation(kotlin("test")) }
        }
        val jvmMain by getting {}
        val androidMain by getting {}
        // iosMain/jsMain as needed
    }
}
```

Update `settings.gradle.kts`:

```kotlin
include(":skainet-int:skainet-int-mnist-cnn")
```

Sequence: Load → Run → Unload
```mermaid
sequenceDiagram
    participant App
    participant Shared as MnistCnnShared
    participant Loader as ModelParameterLoader (io-gguf)
    participant Net as MnistCnnModule
    App->>Shared: load()
    Shared->>Shared: state=Loading(...)
    Shared->>Shared: readResourceBytes(gguf)
    Shared->>Loader: load(target=net, bytes, onProgress)
    Loader-->>Shared: progress callbacks
    Loader-->>Shared: OK
    Shared->>Shared: state=Ready(Net)
    App->>Shared: run(input)
    Shared->>Shared: state=Running
    Shared->>Net: forward(input)
    Net-->>Shared: output
    Shared->>Shared: state=Ready
    Shared-->>App: output
    App->>Shared: unload()
    Shared->>Shared: module=null; state=Unloaded
```
Data Contracts
- Input: `Tensor<FP32, Float>` shaped `[1, 1, 28, 28]`, or batched `[N, 1, 28, 28]`.
- Output: `Tensor<FP32, Float>` shaped `[N, 10]` logits. A helper can apply softmax for probabilities.
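As a concrete sketch of the post-processing, here is plain Kotlin over a `FloatArray` (a simplification; the real helpers would take `Tensor<FP32, Float>`):

```kotlin
import kotlin.math.exp

// Numerically stable softmax: subtract the max logit before exponentiating.
fun softmax(logits: FloatArray): FloatArray {
    val max = logits.maxOrNull() ?: return FloatArray(0)
    val exps = FloatArray(logits.size) { exp((logits[it] - max).toDouble()).toFloat() }
    val sum = exps.sum()
    return FloatArray(exps.size) { exps[it] / sum }
}

// Index of the largest logit, i.e. the predicted digit 0..9.
fun argmaxOf(logits: FloatArray): Int =
    logits.indices.maxByOrNull { logits[it] } ?: -1
```

Subtracting the max before `exp` avoids overflow for large logits without changing the resulting probabilities.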
Optional helpers:

```kotlin
fun asMnistInput(gray: FloatArray): Tensor<FP32, Float> { /* normalize, shape to [1,1,28,28] */ }
fun argmax(logits: Tensor<FP32, Float>): Int { /* returns predicted digit 0..9 */ }
```

Error Handling & Telemetry
- Map loader or IO exceptions to `MnistCnnSharedState.Error`.
- Provide simple metrics hooks (time to load, time to run, bytes loaded) via an optional callback or logger.
Performance Considerations
- Avoid re-parsing GGUF on every run; keep `MnistCnnModule` in memory until `unload()`.
- Allow apps to supply an `ExecutionContext` selecting CPU/GPU backends.
- Consider a lazy-loading variant for cold start vs. a fully eager load.
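One possible shape for the lazy variant, assuming deferral until first use is acceptable (a sketch, not a committed design):

```kotlin
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock

// Lazily materialize an expensive resource on first access; later calls reuse it.
class LazyModel<M : Any>(private val load: suspend () -> M) {
    private val mutex = Mutex()
    private var cached: M? = null

    suspend fun get(): M = cached ?: mutex.withLock {
        // Re-check inside the lock in case another coroutine loaded it first.
        cached ?: load().also { cached = it }
    }
}
```

With this shape, `run` would call `get()` instead of requiring an explicit `load()`, trading predictable startup cost for a slower first inference.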
Security & Licensing
- Verify license for distributing pre-trained MNIST CNN weights in GGUF; include NOTICE if required.
- Ensure no PII is involved; the MNIST dataset is public-domain friendly.
Testing Strategy
- Unit tests (common):
  - State transitions: `Unloaded -> Loading -> Ready -> Running -> Ready -> Unloaded`.
  - Error pathways when the resource is missing or corrupt.
- Integration tests (JVM):
  - Load the real GGUF resource and classify a known digit sample; assert top-1 is correct.
  - Progress callback emits increasing `current` up to `total`.
- Resource test: checksum validation of the embedded GGUF file.
Migration & Rollout
- Add module in repo; enable publication if relevant.
- Update one reference app (e.g., `skainet-apps/...`) to use `MnistCnnShared` instead of ad-hoc loading.
- Document minimal usage in the new module's README.
Implementation Plan (Step-by-step)
- Create module `skainet-int/skainet-int-mnist-cnn` (KMP skeleton).
- Add dependencies on `skainet-lang-models`, `skainet-io:io-core`, `skainet-io:io-gguf`, and coroutines.
- Add `MnistCnnSharedState`, `MnistCnnShared`, and the internal `MnistCnnModule`.
- Define `expect fun readResourceBytes(path: String): ByteArray` in common; provide `actual` implementations per platform.
- Place `mnist-cnn-f32.gguf` under `src/commonMain/resources/models/mnist/` with a checksum.
- Wire the loader call: `loader.load(target = net, bytes = readResourceBytes(DEFAULT_GGUF_PATH), onProgress = ...)`.
- Implement concurrency guards (`Mutex`) and the state Flow.
- Add helper functions for input shaping and softmax/argmax.
- Write unit tests for state transitions and error cases.
- Write a JVM integration test: load the real model and classify a sample.
- Update `settings.gradle.kts` and publish tasks as needed.
- Update an app to consume the new module and remove duplicated logic.
Open Questions
- Exact `io-core` loader interfaces in this repo: confirm the signature to align the progress callback and target parameter mapping.
- Which platforms must be supported in v1 (JVM-only vs. Android/iOS/JS)?
- Preferred execution backend defaults (CPU vs. auto-detect GPU when present).
- Model file size constraints for mobile store distribution.
Acceptance Criteria
- Module builds and publishes.
- An app can load the MNIST CNN with a single `load()` call and observe progress via `StateFlow`.
- Inference works with expected accuracy on standard MNIST test samples.
- GGUF is an internal detail; replacing GGUF with another loader requires no app code changes.
- Documentation includes quickstart and API reference with example.