A hybrid quantum-classical neural network for classifying handwritten digits from the MNIST dataset. This project demonstrates how quantum computing can be integrated with classical machine learning using Qiskit and PyTorch.
This project implements a hybrid architecture that combines classical neural networks with parameterized quantum circuits for image classification. The quantum component uses variational quantum circuits with trainable parameters, similar to weights in classical neural networks.
- Hybrid quantum-classical neural network architecture
- Binary and multi-class classification support
- Classical baseline model for performance comparison
- Comprehensive training and evaluation pipeline
- Detailed visualizations including training curves, confusion matrices, and ROC curves
- Modular and extensible code structure
- Interactive Jupyter notebook tutorial
- Works on free quantum simulators (no quantum hardware required)
The hybrid model consists of three main components:
- Classical Preprocessing: Reduces the 784-dimensional MNIST images down to 4 features suitable for quantum processing
- Quantum Circuit: A parameterized quantum circuit with 4 qubits that processes the compressed features
- Classical Postprocessing: Maps the quantum circuit output to final class predictions
Input (28×28 image) → Classical Layers → Quantum Circuit → Classical Layers → Output
Quantum_MNIST/
├── src/
│ ├── quantum_circuit.py # Quantum circuit definitions
│ ├── models.py # Quantum and classical model architectures
│ ├── data_utils.py # Data loading and preprocessing
│ ├── train.py # Training utilities
│ └── visualize.py # Evaluation and visualization tools
├── notebooks/
│ └── quantum_mnist_tutorial.ipynb # Interactive tutorial
├── data/ # MNIST dataset (auto-downloaded)
├── models/ # Saved model checkpoints
├── results/
│ ├── plots/ # Training curves, confusion matrices, etc.
│ └── metrics/ # Performance metrics (CSV files)
├── logs/ # Training logs
├── config.py # Configuration and hyperparameters
├── train_model.py # Main training script
├── requirements.txt # Python dependencies
└── README.md # This file
- Python 3.8 or higher
- pip package manager
- 4GB+ RAM recommended
git clone <your-repo-url>
cd Quantum_MNIST
# On macOS/Linux
python3 -m venv venv
source venv/bin/activate
# On Windows
python -m venv venv
venv\Scripts\activate
pip install -r requirements.txt
This will install:
- qiskit (quantum computing framework)
- qiskit-machine-learning (quantum ML tools)
- torch and torchvision (deep learning)
- numpy, matplotlib, scikit-learn (scientific computing)
- jupyter (for notebooks)
The easiest way to train both models and generate all results:
python train_model.py
This will:
- Download MNIST dataset (first run only)
- Train both quantum and classical models
- Generate training curves, confusion matrices, and ROC curves
- Save models and metrics to disk
- Print performance comparison
Expected runtime: 15-30 minutes depending on your hardware
For a step-by-step walkthrough with detailed explanations:
jupyter notebook notebooks/quantum_mnist_tutorial.ipynb
The notebook provides:
- Interactive code cells with explanations
- Visualizations of quantum circuits
- Sample data exploration
- Model training and evaluation
- Experiments you can modify
Edit config.py to customize training parameters:
# Binary classification (faster, recommended for testing)
DATASET_TYPE = 'binary'
BINARY_CLASS_A = 0 # First digit to classify
BINARY_CLASS_B = 1 # Second digit to classify
BINARY_TRAIN_SIZE = 500 # Samples per class
# Multi-class classification (slower but more comprehensive)
DATASET_TYPE = 'multiclass'
MULTICLASS_CLASSES = [0, 1, 2] # Which digits to include
MULTICLASS_SAMPLES_PER_CLASS = 200
MODEL_TYPE = 'hybrid' # 'hybrid', 'classical', or 'both'
N_QUBITS = 4 # Number of qubits in quantum circuit
LEARNING_RATE = 0.01
NUM_EPOCHS = 20
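BATCH_SIZE = 32
To train only the hybrid model, edit config.py:
MODEL_TYPE = 'hybrid'
Then run:
python train_model.py
To train only the classical baseline, edit config.py:
MODEL_TYPE = 'classical'
Then run:
python train_model.py
To train both models and compare them, edit config.py:
MODEL_TYPE = 'both'
Then run: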
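python train_model.py
import torch
from src.models import SimplifiedHybridQNN
# Load the model
model = SimplifiedHybridQNN(n_qubits=4, n_classes=2)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
checkpoint = torch.load('models/quantum_hybrid_best.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
# Make predictions (test_images is a batch of preprocessed MNIST tensors)
model.eval()
with torch.no_grad():  # gradients are not needed for inference
    predictions = model(test_images)
After training, you'll find: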
- models/quantum_hybrid_best.pth: Best quantum model checkpoint
- models/classical_baseline_best.pth: Best classical model checkpoint
- results/plots/quantum_hybrid_training_curves.png: Training progress
- results/plots/quantum_hybrid_confusion_matrix.png: Classification errors
- results/plots/quantum_hybrid_roc_curve.png: ROC analysis
- results/plots/model_comparison.png: Side-by-side comparison
- results/metrics/results.csv: Quantitative performance metrics
- logs/quantum_hybrid_history.json: Detailed training history
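The layout of the training history file is not documented here, so a cautious first step is to inspect it before relying on any particular field. The sketch below assumes only that the file is a JSON object at the top level; `summarize_history` is a hypothetical helper, not part of the project.

```python
import json
from pathlib import Path


def summarize_history(path):
    """Load a JSON training history and describe its top-level entries.

    Assumes the file holds a JSON object (a dict); the actual keys
    written by the training script are not known here.
    """
    with Path(path).open() as fh:
        history = json.load(fh)
    summary = {}
    for key, value in history.items():
        if isinstance(value, list):
            summary[key] = f"list of {len(value)} values"
        else:
            summary[key] = type(value).__name__
    return summary


history_path = Path("logs/quantum_hybrid_history.json")
if history_path.exists():  # only present after a training run
    for key, description in summarize_history(history_path).items():
        print(f"{key}: {description}")
```

Once the keys are known, the per-epoch lists can be plotted or compared against the classical baseline's history.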
For binary classification (0 vs 1) with default settings:
| Model | Accuracy | Training Time |
|---|---|---|
| Quantum Hybrid | 85-95% | 20-30 min |
| Classical Baseline | 90-98% | 10-15 min |
Performance notes:
- Results vary based on random initialization
- Quantum simulations are computationally expensive
- Real quantum hardware would have different characteristics
- The goal is to demonstrate the approach, not achieve state-of-the-art accuracy
The quantum circuits used in this project contain rotation gates with trainable parameters. During training, these parameters are optimized using gradient descent, just like weights in classical neural networks.
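To make the idea of training a rotation angle concrete, here is a minimal single-qubit sketch in plain Python. It assumes an RY rotation and uses the parameter-shift rule to obtain the gradient; both are illustrative choices, and the project itself may compute gradients differently (for example through Qiskit's gradient tooling). The function names are hypothetical.

```python
import math


def expectation_z(theta: float) -> float:
    """<Z> after applying RY(theta) to |0>; analytically cos(theta)."""
    # RY(theta)|0> = cos(theta/2)|0> + sin(theta/2)|1>
    amp0 = math.cos(theta / 2)
    amp1 = math.sin(theta / 2)
    return amp0 ** 2 - amp1 ** 2


def parameter_shift_grad(f, theta: float) -> float:
    """Gradient of f at theta via the parameter-shift rule for rotation gates."""
    shift = math.pi / 2
    return (f(theta + shift) - f(theta - shift)) / 2


def train(theta: float = 0.5, lr: float = 0.1, steps: int = 200) -> float:
    """Plain gradient descent that minimizes <Z>, pushing theta toward pi."""
    for _ in range(steps):
        theta -= lr * parameter_shift_grad(expectation_z, theta)
    return theta


theta_opt = train()
print(f"theta = {theta_opt:.4f}, <Z> = {expectation_z(theta_opt):.4f}")
```

The update line is the same one used for a classical weight; only the way the gradient is obtained differs, since it comes from two extra circuit evaluations rather than backpropagation.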
Current quantum simulators have practical limitations. Using 4 qubits:
- Keeps training time reasonable on CPU
- Demonstrates the concept effectively
- Fits within the limits of free cloud-based simulators
Real quantum computers could potentially use more qubits for larger problems.
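The cost of simulation can be estimated with simple arithmetic: a dense statevector for n qubits holds 2^n complex amplitudes. The sketch below assumes 16 bytes per amplitude (double-precision complex) and ignores any overhead a particular simulator adds; `statevector_size` is a hypothetical helper.

```python
def statevector_size(n_qubits: int, bytes_per_amplitude: int = 16):
    """Return (number of amplitudes, bytes) for a dense n-qubit statevector."""
    amplitudes = 2 ** n_qubits
    return amplitudes, amplitudes * bytes_per_amplitude


for n in (4, 10, 20, 30):
    amps, size = statevector_size(n)
    print(f"{n:>2} qubits: {amps:>13,} amplitudes, {size / 2**20:,.4f} MiB")
```

Each added qubit doubles both numbers, which is why a 4-qubit circuit is cheap to simulate while a few dozen qubits already exceed the memory of a typical laptop.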
We use angle encoding to map classical data to quantum states:
- Classical values are normalized to [0, 2π]
- These angles parameterize rotation gates
- The resulting quantum state encodes the input features
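The three steps above can be sketched without any quantum library. This example assumes features in the range [0, 1], one feature per qubit, and RY as the rotation gate; the text only says "rotation gates", so the gate choice and the helper names are assumptions for illustration.

```python
import math


def normalize_to_angles(features, lo, hi):
    """Linearly rescale each feature from [lo, hi] to an angle in [0, 2*pi]."""
    span = hi - lo
    return [2 * math.pi * (x - lo) / span for x in features]


def ry_state(theta):
    """Amplitudes of RY(theta)|0> as (amplitude of |0>, amplitude of |1>)."""
    return (math.cos(theta / 2), math.sin(theta / 2))


def encode(features, lo=0.0, hi=1.0):
    """Angle-encode one feature per qubit; returns per-qubit amplitude pairs."""
    return [ry_state(t) for t in normalize_to_angles(features, lo, hi)]


features = [0.0, 0.25, 0.5, 1.0]  # four features for four qubits
qubits = encode(features)
prob_one = [a1 ** 2 for _, a1 in qubits]  # probability of measuring |1>
print(prob_one)
```

One consequence of the [0, 2π] range is that the smallest and largest feature values produce the same measurement probabilities, because RY(0) and RY(2π) differ only by a global sign. Some implementations scale to [0, π] instead to avoid this aliasing.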
Solution: The project works fine on CPU. Edit config.py:
DEVICE = 'cpu'
Solution: Reduce the dataset size in config.py:
BINARY_TRAIN_SIZE = 200 # Instead of 500
NUM_EPOCHS = 10 # Instead of 20
Solution: Make sure you activated the virtual environment and installed all dependencies:
source venv/bin/activate # or venv\Scripts\activate on Windows
pip install -r requirements.txt
Solution: Ensure Qiskit is properly installed:
pip install --upgrade qiskit qiskit-machine-learning qiskit-aer
Solution: Install the kernel in your virtual environment:
python -m ipykernel install --user --name=quantum_mnist
Then select this kernel in Jupyter.
Edit src/quantum_circuit.py and config.py to increase the number of qubits. Note that simulation time and memory grow exponentially with the number of qubits.
Modify the ansatz in src/quantum_circuit.py:
from qiskit.circuit.library import EfficientSU2
ansatz = EfficientSU2(num_qubits=4, reps=3)
Sign up for IBM Quantum and modify the code to use real quantum processors instead of simulators. See the Qiskit documentation for details.
Change the configuration to classify more digits:
DATASET_TYPE = 'multiclass'
MULTICLASS_CLASSES = [0, 1, 2, 3, 4]
- Qiskit: Open-source quantum computing framework
- PyTorch: Deep learning framework
- Qiskit Machine Learning: Quantum ML tools that bridge Qiskit and PyTorch
- torchvision: Provides MNIST dataset
- matplotlib/seaborn: Visualization
- scikit-learn: Metrics and evaluation
Quantum hybrid model (binary classification):
- Classical preprocessing: ~105,000 parameters
- Quantum circuit: 36 trainable parameters
- Classical postprocessing: ~50 parameters
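The quantum circuit has far fewer parameters than the classical layers, but it acts on a state space whose dimension grows exponentially with the number of qubits (2^n, so 16 for 4 qubits), which may let it represent complex functions compactly.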
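To put the figures above in proportion, the following sketch tallies them. The classical counts are the approximate values quoted in the list, so the total and the share are rough estimates rather than exact model statistics.

```python
# Approximate counts quoted above for the binary-classification hybrid model.
param_counts = {
    "classical_preprocessing": 105_000,  # "~105,000" in the text
    "quantum_circuit": 36,
    "classical_postprocessing": 50,      # "~50" in the text
}

total = sum(param_counts.values())
quantum_share = param_counts["quantum_circuit"] / total
state_dim = 2 ** 4  # dimension of the 4-qubit state space

print(f"total parameters: about {total:,}")
print(f"quantum share of parameters: {quantum_share:.4%}")
print(f"4-qubit state space dimension: {state_dim}")
```

Almost all trainable parameters sit in the classical preprocessing layers, which is worth keeping in mind when comparing the hybrid model with the classical baseline.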
If you use this code in your research, please cite:
@software{quantum_mnist_classifier,
title = {Quantum MNIST Classification: Hybrid Quantum-Classical Neural Networks},
year = {2024},
url = {https://github.com/yourusername/quantum-mnist}
}
This project is open source and available under the MIT License.
Contributions are welcome! Please feel free to submit a Pull Request. Areas for improvement:
- Add support for more quantum circuit designs
- Implement additional evaluation metrics
- Optimize quantum circuit depth
- Add support for other datasets
- Improve documentation
- Built using Qiskit by IBM Research
- MNIST dataset by Yann LeCun
- Inspired by recent advances in quantum machine learning
For questions, issues, or suggestions, please open an issue on GitHub or contact the maintainer.
- Implement quantum convolutional layers
- Explore quantum attention mechanisms
- Test on real quantum hardware
- Compare with other quantum ML approaches
- Extend to other image classification tasks
Happy Quantum Computing!