Welcome to the Digit Classification Project! This project focuses on training a model to classify handwritten digits and using the trained model to predict digits from new images.
This project consists of two main components:
- digit_classification_model_trainer.py: This script is responsible for training a model to classify digits based on a dataset of handwritten digits.
- digit_classification_detector.py: This script uses the trained model to predict digits from new images.
The following libraries are used in this project:
- tensorflow: TensorFlow is an open-source machine learning library used for training and inference in this project.
- numpy: NumPy is used for numerical operations on image data.
- sklearn: Scikit-learn is used for splitting the dataset into training and testing sets.
- os: Python's built-in os module is used for directory and file manipulation.
The digit_classification_model_trainer.py script trains the digit classification model. Its key components are:
- DigitClassifier Class: This class manages the entire training process, from loading and preprocessing the dataset to building, training, and saving the model. It uses TensorFlow’s Keras API to create a simple neural network for digit classification.
- _load_and_preprocess_data() Method: This method loads the images from the dataset directory, preprocesses them (resizes and normalizes), and splits them into training and testing sets.
- _build_model() Method: This method constructs a simple neural network model with a flatten layer, a dense hidden layer, and an output layer for digit classification.
- train() Method: This method trains the model on the preprocessed dataset.
- evaluate() Method: This method evaluates the model on the test dataset and prints the test accuracy.
- save_model() Method: This method saves the trained model to a file for later use.
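The project builds its network with TensorFlow's Keras API. The minimal NumPy sketch below makes the data flow of the flatten -> dense hidden -> output architecture concrete without requiring TensorFlow. The 28x28 input size, 128 hidden units, ReLU activation, and softmax output are assumptions for illustration, not values taken from the project's code. The weights are random, so the probabilities are meaningless until the network is trained.

```python
import numpy as np

# Assumed sizes for illustration only: 28x28 grayscale inputs, 128 hidden units, 10 classes.
IMG_SIZE = 28
HIDDEN_UNITS = 128
NUM_CLASSES = 10


def init_params(rng: np.random.Generator) -> dict:
    """Randomly initialise weights for Flatten -> Dense(ReLU) -> Dense(softmax)."""
    n_in = IMG_SIZE * IMG_SIZE
    return {
        "W1": rng.normal(0.0, 0.01, size=(n_in, HIDDEN_UNITS)),
        "b1": np.zeros(HIDDEN_UNITS),
        "W2": rng.normal(0.0, 0.01, size=(HIDDEN_UNITS, NUM_CLASSES)),
        "b2": np.zeros(NUM_CLASSES),
    }


def softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the row-wise max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def forward(params: dict, images: np.ndarray) -> np.ndarray:
    """Map a batch of (N, 28, 28) images to (N, 10) class probabilities."""
    x = images.reshape(images.shape[0], -1)               # flatten layer
    h = np.maximum(0.0, x @ params["W1"] + params["b1"])  # dense hidden layer (ReLU)
    return softmax(h @ params["W2"] + params["b2"])       # output layer: one probability per digit


rng = np.random.default_rng(0)
params = init_params(rng)
batch = rng.random((4, IMG_SIZE, IMG_SIZE))  # four random "images" with values in [0, 1)
probs = forward(params, batch)
print(probs.shape)  # (4, 10)
```

In Keras the same stack would typically be expressed as a `Sequential` model of `Flatten` and `Dense` layers; training then adjusts `W1`, `b1`, `W2`, and `b2` so that the correct digit receives the highest probability.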
The digit_classification_detector.py script uses the trained model to predict digits from new images. Its key components are:
- DigitPredictor Class: This class loads the trained model and provides methods to preprocess images and make predictions.
- preprocess_image() Method: This method loads and preprocesses a new image, preparing it for prediction.
- predict() Method: This method predicts the digit in the provided image using the trained model and prints the predicted digit and confidence.
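A classifier with a softmax output layer produces one probability per digit class: the predicted digit is the index of the largest probability, and that probability is a natural confidence score. The sketch below assumes the model output has that form (for example, what a Keras `model.predict` call returns for a single preprocessed image). The helper name `top_prediction` and the example probabilities are hypothetical and not part of the project.

```python
import numpy as np


def top_prediction(probabilities: np.ndarray) -> tuple[int, float]:
    """Return (predicted digit, confidence) for one image's class probabilities.

    Assumes `probabilities` has shape (1, 10) or (10,), as produced by a softmax output layer.
    """
    probs = np.asarray(probabilities).reshape(-1)
    digit = int(np.argmax(probs))     # index of the most probable class = predicted digit
    confidence = float(probs[digit])  # its probability, reported as the confidence
    return digit, confidence


# Hypothetical model output for one image, with most of the probability mass on class 7.
example = np.array([[0.01, 0.01, 0.02, 0.01, 0.01, 0.01, 0.01, 0.88, 0.02, 0.02]])
digit, confidence = top_prediction(example)
print(f"Predicted digit: {digit} (confidence: {confidence:.2%})")
```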
Model Training:
- The digit_classification_model_trainer.py script reads images from the specified dataset directory.
- The images are resized, normalized, and fed into a neural network model, which is trained to classify them into one of ten digit classes (0-9).
- The trained model is saved for later use.
Digit Prediction:
- The digit_classification_detector.py script loads the trained model and processes new images.
- Each image is preprocessed and classified by the model, which outputs the predicted digit and its confidence.
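The resize, normalize, and train/test split steps performed during model training can be sketched with NumPy alone. This is an illustration under assumptions: the 28x28 target size, nearest-neighbour resizing, 8-bit input pixels, and the 80/20 split are hypothetical choices, and `train_test_split_np` is a plain NumPy stand-in for scikit-learn's `train_test_split`, which the project actually uses. Synthetic random images replace real dataset files.

```python
import numpy as np

TARGET_SIZE = 28  # assumed target resolution; the project's actual size may differ


def resize_nearest(image: np.ndarray, size: int = TARGET_SIZE) -> np.ndarray:
    """Nearest-neighbour resize of a 2-D grayscale image to (size, size)."""
    h, w = image.shape
    rows = np.arange(size) * h // size  # source row for each output row
    cols = np.arange(size) * w // size  # source column for each output column
    return image[rows[:, None], cols[None, :]]


def preprocess(image: np.ndarray) -> np.ndarray:
    # Resize, then scale 8-bit pixel values from [0, 255] to [0, 1].
    return resize_nearest(image).astype(np.float32) / 255.0


def train_test_split_np(x, y, test_size=0.2, seed=0):
    """Shuffle and split arrays; a NumPy stand-in for scikit-learn's train_test_split."""
    idx = np.random.default_rng(seed).permutation(len(x))
    n_test = int(round(len(x) * test_size))
    test_idx, train_idx = idx[:n_test], idx[n_test:]
    return x[train_idx], x[test_idx], y[train_idx], y[test_idx]


rng = np.random.default_rng(42)
raw = [rng.integers(0, 256, size=(64, 48), dtype=np.uint8) for _ in range(10)]
labels = np.arange(10)  # one synthetic image per digit class
x = np.stack([preprocess(img) for img in raw])
x_train, x_test, y_train, y_test = train_test_split_np(x, labels, test_size=0.2)
print(x.shape, x_train.shape, x_test.shape)
```

With scikit-learn installed, the last helper can be replaced by `from sklearn.model_selection import train_test_split`, which returns the arrays in the same `x_train, x_test, y_train, y_test` order.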
The dataset used for training the model can be accessed via this Dataset.
To use this project, follow these steps:
- Clone the repository:
  git clone https://github.com/amiriiw/digit_classification
  cd digit_classification
- Install the required libraries:
  pip install tensorflow numpy scikit-learn
- Download and prepare your dataset, placing it in the dataset directory that digit_classification_model_trainer.py reads from.
- Run the model training script:
  python digit_classification_model_trainer.py
- Use the trained model to predict digits from new images:
  python digit_classification_detector.py
This project is licensed under the MIT License. See the LICENSE file for details.