A PyTorch implementation of the MiniTransformer, a compact Transformer architecture designed for small-sample clinical and behavioral data. The model balances predictive performance with interpretability by combining architectural simplifications with a built-in framework for statistical testing.
This project implements a custom transformer architecture (MiniTransformer) that learns patterns in sequential data, particularly focusing on:
- Feature Interactions: Understanding how different features at various positions influence predictions
- Statistical Testing: Rigorous evaluation of learned patterns with significance testing
- Context Effects: Analysis of how historical context affects future predictions
The MiniTransformer consists of:
- Multi-Head Attention: Custom attention mechanism with positional encodings
- Distance Matrices: Incorporates both pairwise distances and distance-to-end information
- Custom Masking: Specialized attention masks for prediction tasks
- Statistical Analysis: Built-in tools for evaluating feature importance and interactions
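The pairwise-distance and distance-to-end inputs can be precomputed from positions alone. A minimal sketch (the function name is hypothetical; the actual construction in the source may differ):

```python
import torch

def distance_matrices(seq_len: int):
    """Pairwise distances |i - j| and each position's distance to the sequence end."""
    pos = torch.arange(seq_len)
    pairwise = (pos[None, :] - pos[:, None]).abs()  # (seq_len, seq_len)
    to_end = (seq_len - 1) - pos                    # (seq_len,)
    return pairwise, to_end

pairwise, to_end = distance_matrices(4)
```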
- Simulated Data: Generates synthetic sequential data with controllable patterns
- Real Data: Supports real-world datasets (GHQ health questionnaire data)
- Variable Length Sequences: Handles sequences of different lengths with proper padding
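Variable-length handling with padding can be sketched using PyTorch's `pad_sequence`; the shapes below are illustrative, not taken from the source:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three sequences of different lengths, each with p = 2 features
seqs = [torch.randn(5, 2), torch.randn(3, 2), torch.randn(7, 2)]
padded = pad_sequence(seqs, batch_first=True)  # (3, 7, 2), zero-padded to max length
lengths = torch.tensor([s.shape[0] for s in seqs])
# Boolean mask that is True at padded positions, for exclusion from attention/loss
pad_mask = torch.arange(padded.shape[1])[None, :] >= lengths[:, None]
```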
- Multi-head attention with customizable heads, key/value dimensions
- Position-aware attention using distance matrices
- Cumulant-based feature aggregation
- L2 regularization with bias exclusion
- Multiple baseline comparisons (average, informed, repeat baselines)
- Regression baseline using scikit-learn
- Statistical significance testing with p-value computation
- Context-target effect visualization
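As one reading of the cumulant-based aggregation listed above, the first two cumulants are the mean and variance; the sketch below pools features that way. This is a guess at the scheme (function name and shapes hypothetical); the actual implementation in src/transformers.py may differ:

```python
import torch

def cumulant_pool(x: torch.Tensor, ncum: int = 2) -> torch.Tensor:
    """Aggregate over the sequence dimension via the first `ncum` cumulants.

    Cumulant 1 = mean, cumulant 2 = variance.
    x: (seq_len, p)  ->  (ncum * p,)
    """
    stats = [x.mean(dim=0)]
    if ncum >= 2:
        stats.append(x.var(dim=0, unbiased=False))
    return torch.cat(stats)

out = cumulant_pool(torch.tensor([[0.0, 1.0], [2.0, 3.0]]))
```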
- Clone the repository:

```bash
git clone git@github.com:kianaf/MiniTransformer.git
cd mini_transformer
```

- Create and activate a virtual environment:

```bash
python -m venv env
source env/bin/activate  # On Windows: env\Scripts\activate
```

- Install dependencies:

```bash
pip install -r requirements.txt
```

Run the main training script:

```bash
python main.py
```

Key hyperparameters can be modified in main.py:
```python
# Data configuration
data_str = "simulation"  # or "ghq_sum", "ghq_b_sum"
batch_size = 1
n = 200        # Training samples
p = 10         # Number of features
maxlen = 10    # Maximum sequence length

# Model architecture
nheads = 16    # Number of attention heads
ncum = 2       # Number of cumulants
dk = 1         # Key dimension
dv = 1         # Value dimension

# Training
learning_rate = 1e-3
lambda_l2 = 1e-3
EPOCHS = 100
```

Explore the analysis notebooks in the notebooks/ directory:
- simulation_experiments.ipynb: Basic simulation experiments
- simulation_experiments_statistical_testing.ipynb: Statistical analysis of simulated data
- real_data_experiments_D1.ipynb: Real data analysis (Dataset 1)
- real_data_experiments_D2.ipynb: Real data analysis (Dataset 2)
```
mini_transformer/
├── main.py                    # Main training script
├── requirements.txt           # Python dependencies
├── src/                       # Source code
│   ├── transformers.py        # MiniTransformer implementation
│   ├── data_preparation.py    # Data loading and preprocessing
│   ├── evaluation.py          # Model evaluation metrics
│   └── statistical_testing.py # Statistical analysis tools
├── notebooks/                 # Jupyter notebooks for experiments
│   ├── simulation_experiments.ipynb
│   ├── real_data_experiments_D1.ipynb
│   └── ...
└── runs/                      # TensorBoard logs and results
```
The core model, implemented in src/transformers.py, provides:
- Multi-head attention with distance-aware weights
- Custom masking for causal prediction
- Linear prediction layer
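The exact mask used in the source is not shown here, but a standard causal mask for next-step prediction, which the custom masking presumably builds on, looks like this:

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    """Boolean mask: position i may attend only to positions j <= i."""
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

scores = torch.randn(4, 4)                                  # raw attention scores
masked = scores.masked_fill(~causal_mask(4), float("-inf"))  # block future positions
attn = torch.softmax(masked, dim=-1)  # rows sum to 1 over allowed positions only
```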
SimulatedDataset generates synthetic sequential data with controlled dependencies:
- Variable sequence lengths with probabilistic termination
- Configurable feature interactions
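A toy stand-in for SimulatedDataset, assuming a fixed per-step stopping probability (class name, defaults, and the i.i.d. Gaussian features are all illustrative):

```python
import torch
from torch.utils.data import Dataset

class SimulatedSequences(Dataset):
    """Sequences that terminate with probability p_stop at each step, capped at maxlen."""
    def __init__(self, n=200, p=10, maxlen=10, p_stop=0.2, seed=0):
        g = torch.Generator().manual_seed(seed)
        self.seqs = []
        for _ in range(n):
            length = 1
            while length < maxlen and torch.rand(1, generator=g).item() > p_stop:
                length += 1
            self.seqs.append(torch.randn(length, p, generator=g))

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        return self.seqs[idx]

ds = SimulatedSequences()
```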
- Permutation-based significance testing
- Context-target effect analysis
- P-value computation with multiple comparisons correction
- Visualization of feature interactions
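The core of a permutation-based significance test can be sketched as follows (a generic two-group version, not the project's exact statistic):

```python
import numpy as np

def permutation_pvalue(scores, labels, n_perm=1000, seed=0):
    """p-value for the observed group-mean difference under random label permutation."""
    rng = np.random.default_rng(seed)
    scores, labels = np.asarray(scores, float), np.asarray(labels)
    observed = abs(scores[labels == 1].mean() - scores[labels == 0].mean())
    count = 0
    for _ in range(n_perm):
        perm = rng.permutation(labels)
        count += abs(scores[perm == 1].mean() - scores[perm == 0].mean()) >= observed
    return (count + 1) / (n_perm + 1)  # add-one correction avoids p = 0
```

When many interactions are tested this way, the resulting p-values need a multiple-comparisons correction (e.g. Bonferroni), as the feature list above indicates.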
The model evaluation includes:
- Loss Comparisons: Against multiple baselines (average, informed, regression)
- Statistical Significance: P-values for feature interactions
- Context Effects: Heatmaps showing how context influences predictions
- Parameter Analysis: Distance weights and attention patterns
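The "average" and "repeat" baselines from the list above can be sketched for next-step prediction on a single 1-D sequence (the "informed" and regression baselines are omitted; function name hypothetical):

```python
import torch

def baseline_losses(seq: torch.Tensor) -> dict:
    """MSE of two naive next-step predictors on a 1-D sequence."""
    targets = seq[1:]
    # "average": always predict the overall mean of the sequence
    average_loss = torch.mean((seq.mean() - targets) ** 2).item()
    # "repeat": predict that the next value equals the previous one
    repeat_loss = torch.mean((seq[:-1] - targets) ** 2).item()
    return {"average": average_loss, "repeat": repeat_loss}

losses = baseline_losses(torch.tensor([1.0, 1.0, 1.0, 2.0]))
```

A trained model is only interesting to the extent that its loss beats such baselines.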
This implementation is particularly useful for:
- Behavioral Data Analysis: Understanding sequential patterns in questionnaire responses
- Feature Interaction Discovery: Identifying which features influence each other
- Causal Inference: Testing statistical significance of learned patterns
- Time Series Analysis: Modeling dependencies in sequential data
Key dependencies include:
- PyTorch 2.4.1
- NumPy 2.0.2
- Pandas 2.2.2
- Matplotlib & Seaborn (visualization)
- TensorBoard (logging)
- Scikit-learn (baseline models)
[Add your license information here]
[Add citation information if this is for a research paper]
[Add contribution guidelines if applicable]
