A hybrid CNN–Transformer framework for precise industrial surface defect detection and segmentation, integrating Vision Transformer (ViT) with convolutional modules to effectively capture both local texture details and global contextual features.

rasoulameri/Defect_Segmentation

🔍 Defect_Segmentation

A Hybrid CNN-Transformer Approach for Accurate Surface Defect Detection


📘 Overview

This repository implements a hybrid CNN-Transformer architecture designed for surface defect segmentation on industrial images. It combines the local feature extraction power of CNNs with the global dependency modeling of Transformers, based on a modified TransUNet framework.

Key enhancements include:

  • A Mean Filter Module in the encoder to denoise input images and preserve local details.
  • An Attention Gate Module in the decoder to enhance positional and spatial precision.
  • The architecture was evaluated on the Crack500 dataset, where it reports higher accuracy, F1-score, and IoU than the compared baselines (see Performance Summary).
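As a rough illustration of the Mean Filter Module idea, the sketch below averages box-filtered copies of the input at several kernel sizes and blends the result with the original tensor. The class name `MeanFilterSketch`, the kernel sizes, and the 50/50 blend are assumptions made for illustration; the repository's actual module may be structured differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MeanFilterSketch(nn.Module):
    """Hypothetical multi-scale mean filter (not the repository's module)."""

    def __init__(self, kernel_sizes=(3, 5, 7)):
        super().__init__()
        self.kernel_sizes = kernel_sizes  # assumed scales, odd so padding is symmetric

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Box-filter the input at each scale; stride 1 with k // 2 padding keeps H and W.
        smoothed = [
            F.avg_pool2d(x, kernel_size=k, stride=1, padding=k // 2,
                         count_include_pad=False)
            for k in self.kernel_sizes
        ]
        # Average across scales, then blend with the input so fine detail is retained.
        return 0.5 * x + 0.5 * torch.stack(smoothed).mean(dim=0)
```

The module has no learnable parameters, so it could sit in front of any encoder without changing the checkpoint format.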

🧩 Project Structure

Defect_Segmentation/
│
├── configs/
│   └── configs.py                # Configuration and training hyperparameters
│
├── data/
│   ├── traincrop/                # Training images and masks
│   ├── valcrop/                  # Validation images and masks
│   └── testcrop/                 # Test images and masks
│
├── docs/
│   ├── crack500_samples.jpg      # Sample images from Crack500 dataset
│   ├── crack500_results.jpg      # Example segmentation results
│   └── modified_transunet.jpg    # Proposed architecture illustration
│
├── networks/
│   ├── vit_seg_modeling.py       # Vision Transformer backbone for segmentation
│   └── vit_seg_configs.py        # ViT architecture configuration
│
├── utils/
│   ├── dataset.py                # Crack500Dataset class
│   ├── evaluate.py               # Evaluation metrics (IoU, F1, Precision, Recall)
│   ├── losses.py                 # FocalLoss implementation
│   └── utils.py                  # Utility functions (GPU usage, result saving)
│
├── results/                      # Saved metrics, checkpoints, and plots
├── main.py                       # Training & evaluation entry point
└── requirements.txt              # Project dependencies

⚙️ Workflow

Architecture Diagram

Workflow Summary

  1. Data Loading & Normalization using PyTorch Dataset and DataLoader.
  2. Data Augmentation through random flips, rotations, and normalization.
  3. Model Initialization with Vision Transformer backbone and hybrid CNN layers.
  4. Training using Focal Loss and AdamW optimizer with mixed precision (torch.cuda.amp).
  5. Validation & Early Stopping based on the lowest validation loss.
  6. Evaluation on the test set with accuracy, precision, recall, F1-score, and IoU metrics.
  7. Result Visualization — outputs, metrics, and segmentation maps are saved in /results/.
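Steps 4 and 5 above can be sketched roughly as follows. This is a minimal outline rather than the contents of `main.py`: the names `EarlyStopping` and `train_one_epoch` are hypothetical, and the loader is assumed to yield `(images, masks)` batches.

```python
import torch
import torch.nn as nn


class EarlyStopping:
    """Signals a stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience: int = 10):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        if val_loss < self.best:
            self.best, self.bad_epochs = val_loss, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def train_one_epoch(model, loader, optimizer, loss_fn, scaler, device):
    """One pass over the loader; mixed precision is used only on CUDA devices."""
    use_amp = device.type == "cuda"
    model.train()
    total = 0.0
    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, enabled=use_amp):
            loss = loss_fn(model(images), masks)
        # With a disabled scaler these calls reduce to plain backward() and step().
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total += loss.item()
    return total / max(len(loader), 1)


# Assumed setup, created once before the epoch loop:
#   optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
#   scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
```

In an epoch loop, `EarlyStopping.step(val_loss)` would be called after each validation pass, and training would stop once it returns `True`.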

📊 Datasets

🔹 Crack500

  • Pavement crack images (2000×1500 px) with pixel-level annotations.
  • Split: 250 train, 50 validation, 200 test images.
  • After cropping: 1896 train / 348 val / 1124 test samples.
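Because every sample pairs an image with a pixel-level mask, any geometric augmentation has to be applied identically to both. The sketch below shows one way to do that with NumPy; the function name `augment_pair` is hypothetical, and restricting rotations to multiples of 90 degrees is an assumption, not something the repository states.

```python
import numpy as np


def augment_pair(image: np.ndarray, mask: np.ndarray, rng: np.random.Generator):
    """Apply the same random flips and 90-degree rotation to an H x W x C image
    and its H x W mask (hypothetical helper, not the repository's Dataset code)."""
    if rng.random() < 0.5:          # horizontal flip
        image, mask = image[:, ::-1], mask[:, ::-1]
    if rng.random() < 0.5:          # vertical flip
        image, mask = image[::-1], mask[::-1]
    k = int(rng.integers(0, 4))     # number of quarter turns
    image = np.rot90(image, k, axes=(0, 1))
    mask = np.rot90(mask, k, axes=(0, 1))
    # Flips and rot90 return views with negative strides; copy before tensor conversion.
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)
```

Sharing a single `rng` draw per transform is what keeps the crack pixels in the mask aligned with the crack in the image.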

Crack500 Samples

🧠 Model Architecture

Modified TransUNet Framework

  • Encoder: ResNet backbone + Transformer blocks
  • Decoder: CNN layers + Attention Gate Module
  • Mean Filter Module: Reduces input noise via multi-scale averaging
  • Attention Gate: Enhances feature localization for precise segmentation
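To make the Attention Gate bullet concrete, here is a minimal sketch of an additive attention gate in the style commonly used with U-Net decoders. It is an assumed formulation: the class name, channel arguments, and the requirement that the gating signal already matches the skip feature's spatial size are illustrative choices, and the repository's module may differ.

```python
import torch
import torch.nn as nn


class AttentionGateSketch(nn.Module):
    """Hypothetical additive attention gate for a decoder skip connection."""

    def __init__(self, skip_channels: int, gate_channels: int, inter_channels: int):
        super().__init__()
        self.theta = nn.Conv2d(skip_channels, inter_channels, kernel_size=1)
        self.phi = nn.Conv2d(gate_channels, inter_channels, kernel_size=1)
        self.psi = nn.Conv2d(inter_channels, 1, kernel_size=1)

    def forward(self, skip: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        # Assumes `gate` was already upsampled to the spatial size of `skip`.
        combined = torch.relu(self.theta(skip) + self.phi(gate))
        attention = torch.sigmoid(self.psi(combined))  # (N, 1, H, W), values in (0, 1)
        # Scale skip features so regions the decoder considers relevant dominate.
        return skip * attention
```

The gated skip features would then be concatenated with the upsampled decoder features in place of the raw skip connection.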

Segmentation Results

Results on Crack500 dataset: (a) Image, (b) Ground Truth, (c) DeepLabV3, (d) DeepLabV3+, (e) FPN, (f) MANet, (g) PAN, (h) PSPNet, (i) U-Net, (j) U-Net++, (k) UperNet, (l) TransUNet, (m) SegFormer, and (n) Our Proposed Method.

🚀 Training Configuration

| Parameter | Value |
|---|---|
| Batch Size | 8 |
| Learning Rate | 5e-5 |
| Epochs | 100 (configurable) |
| Optimizer | AdamW |
| Early Stopping Patience | 10 |
| Loss Function | Focal Loss (α = 0.5, γ = 2.0) |
| Input Size | 256×256 |
| Framework | PyTorch 2.0+ |
| GPU | RTX 3090 |
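For readers unfamiliar with Focal Loss, the following is a minimal sketch using the α and γ values listed above. It assumes a binary crack-versus-background mask and raw logits as model output; the function name `binary_focal_loss` is illustrative and is not claimed to match the implementation in `utils/losses.py`.

```python
import torch
import torch.nn.functional as F


def binary_focal_loss(logits: torch.Tensor, targets: torch.Tensor,
                      alpha: float = 0.5, gamma: float = 2.0) -> torch.Tensor:
    """Binary focal loss on raw logits (illustrative sketch)."""
    # Per-pixel cross-entropy; equals -log(p_t), where p_t is the probability
    # assigned to the true class.
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)
    # alpha weights the positive class, (1 - alpha) the negative class.
    alpha_t = alpha * targets + (1.0 - alpha) * (1.0 - targets)
    # (1 - p_t) ** gamma down-weights pixels that are already classified well.
    return (alpha_t * (1.0 - p_t) ** gamma * bce).mean()
```

With γ = 0 and α = 0.5 this reduces to half the ordinary binary cross-entropy, which is a convenient sanity check.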

📈 Evaluation Metrics

| Metric | Description |
|---|---|
| Accuracy | Correct predictions ratio |
| Precision | True positives over predicted positives |
| Recall | True positives over actual positives |
| F1-Score | Harmonic mean of Precision & Recall |
| IoU | Intersection-over-Union for defect regions |
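The metrics in the table can all be derived from the pixel-level confusion counts of a binary mask. The sketch below shows one way to compute them for a single image; the function name and the small `eps` guard against division by zero are assumptions, and `utils/evaluate.py` may aggregate differently (for example, per image versus over the whole test set).

```python
import numpy as np


def binary_segmentation_metrics(pred, target, eps: float = 1e-7) -> dict:
    """Accuracy, precision, recall, F1, and IoU for one binary mask pair."""
    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)
    tp = np.logical_and(pred, target).sum()     # defect predicted, defect present
    fp = np.logical_and(pred, ~target).sum()    # defect predicted, background present
    fn = np.logical_and(~pred, target).sum()    # background predicted, defect present
    tn = np.logical_and(~pred, ~target).sum()   # background predicted, background present
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return {
        "accuracy": float((tp + tn) / (tp + tn + fp + fn + eps)),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(2 * precision * recall / (precision + recall + eps)),
        "iou": float(tp / (tp + fp + fn + eps)),
    }
```

Because crack pixels are a small fraction of each image, accuracy is dominated by the background class, which is why F1 and IoU are the more informative numbers for this task.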

Results are exported to:

results/
├── training_metrics.xlsx
├── per_image_metrics.xlsx
└── evaluation_metrics.xlsx

🧪 Performance Summary

| Dataset | Accuracy (%) | F1-Score (%) | IoU (%) |
|---|---|---|---|
| Crack500 | 96.72 | 66.95 | 52.18 |

✅ Outperforms classical CNNs (U-Net, PSPNet) and Transformers (SegFormer, TransUNet).
✅ Strong balance between local detail preservation and global feature learning.

💾 Installation and Usage

Clone the Repository

git clone https://github.com/rasoulameri/Defect_Segmentation.git
cd Defect_Segmentation

Install Dependencies

pip install -r requirements.txt

Train and Evaluate

python main.py

🧰 Dependencies

torch >= 2.0
torchvision
numpy
pandas
tqdm
matplotlib
ml_collections
opencv-python
scikit-learn
thop

📚 Citation

If you use this code in your research, please cite:

R. Ameri, C.-C. Hsu, and S. S. Band,
A Hybrid CNN-Transformer Approach for Accurate Surface Defect Detection,
TAAI 2024, Taiwan, 2024.

BibTeX

@inproceedings{ameri2024hybrid,
  title={A Hybrid CNN-Transformer Approach for Accurate Surface Defect Detection},
  author={Ameri, R. and Hsu, C.-C. and Band, S. S.},
  booktitle={TAAI 2024},
  year={2024},
  address={Taiwan}
}

📫 Contact

Rasoul Ameri
📧 rasoulameri90@gmail.com
🔗 GitHub Profile

