While working on machine learning projects, we generally start with data cleaning and exploratory analysis and end with feature extraction and model training. The primary objective is usually to obtain high accuracy and precision, but a well-trained model is not enough on its own: deploying it for a specific application or use case is the principal goal of an ML project. This article trains an ML model on fruit data/features to classify four different fruits, and then deploys it on a Flask server as a basic web application that predicts the fruit from the input data/features.
We assume that you have a Python IDE set up on your local machine. You will need a few packages before starting this project:
- Pandas: for reading the CSV file that contains the training dataset.
- Sklearn (scikit-learn): for splitting the data and for training and testing the machine learning model.
- Pickle: for saving the trained model to a .pkl file (pickle is part of Python's standard library, so it needs no separate installation).
- Flask: for running the basic web application on a Flask server.
- NumPy: for numeric operations such as creating arrays.
You can find all the project resources and files in this GitHub repository. Clone the repository to your local machine and, provided your environment has all the required dependencies, run the “python app.py” command in your terminal.
Training the ML model
Before deploying the model on the Flask server as a web application, we need to train it with a suitable machine learning algorithm that achieves sufficient accuracy and precision. The process starts by reading the dataset from a CSV file. The dataset contains seven columns: fruit label, fruit name, fruit subtype, mass, width, height, and color score. We will use mass, width, height, and color score as the input data for training the model, as these are the numeric feature columns of the dataset. The fruit label column is the target/output data.
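A minimal sketch of this loading step with pandas follows. The snake_case column names and the helper name load_fruits() are assumptions for illustration; check them against the header of your copy of the dataset. If your copy is tab-separated rather than comma-separated, pass sep="\t".

```python
import pandas as pd

# Numeric input features and target column (names assumed; match your CSV header).
FEATURE_COLUMNS = ["mass", "width", "height", "color_score"]
TARGET_COLUMN = "fruit_label"


def load_fruits(path, sep=","):
    """Read the fruit dataset and split it into inputs X and target y."""
    fruits = pd.read_csv(path, sep=sep)
    X = fruits[FEATURE_COLUMNS]  # DataFrame with the four numeric features
    y = fruits[TARGET_COLUMN]    # Series of integer class labels
    return X, y
```

For example, X, y = load_fruits("fruits.csv") returns a four-column feature DataFrame and the label Series; "fruits.csv" is a placeholder path.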
Reading and Processing the Data
Create a Python file “model.py” for training and testing the model, and import the libraries it needs: pandas for reading the CSV file and pickle for storing the trained model in the “.pkl” file format. We will use sklearn’s KNeighborsClassifier for training and testing the model, as it gave the highest accuracy among the sklearn classifiers compared on this dataset. For a detailed explanation of the fruit data and the classification algorithms, see Susan Li’s post on the Towards Data Science blog. Next, create the ‘X’ variable holding the input data extracted from the ‘fruits’ dataframe and the ‘y’ variable holding the ‘fruit_label’ values as the target data.
Building the ML model
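The model-building steps in this section (split, scale, fit, score, pickle) can be sketched as a single function. This is a minimal sketch rather than the repository's code: train_and_save() is a hypothetical name, and wrapping MinMaxScaler and KNeighborsClassifier in a scikit-learn Pipeline is an assumed design choice. Pickling the pipeline stores the fitted scaler together with the classifier, so raw feature values sent from the web form are scaled exactly as the training data was.

```python
import pickle

from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler


def train_and_save(X, y, model_path="model.pkl", n_neighbors=5, random_state=0):
    """Split, scale, fit k-NN, report accuracy, and pickle the fitted pipeline."""
    # Default split: 75% training, 25% testing.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=random_state
    )

    # The scaler learns each feature's min/max on the training set only;
    # bundling it with the classifier applies the same scaling in predict().
    knn = make_pipeline(MinMaxScaler(), KNeighborsClassifier(n_neighbors=n_neighbors))
    knn.fit(X_train, y_train)

    train_acc = knn.score(X_train, y_train)
    test_acc = knn.score(X_test, y_test)
    print(f"Training accuracy: {train_acc:.2f}")
    print(f"Test accuracy: {test_acc:.2f}")

    # Save scaler + classifier together for the Flask app to load.
    with open(model_path, "wb") as f:
        pickle.dump(knn, f)
    return train_acc, test_acc
```

Calling train_and_save(X, y) with the X and y extracted from the dataset writes model.pkl; n_neighbors=5 matches KNeighborsClassifier's own default.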
We have imported all the necessary dependencies from sklearn, so it’s time to use them. Use sklearn’s train_test_split() method to split the dataset into training and testing sets. The input features need to be scaled to a common range, so we will use sklearn’s MinMaxScaler() to implement this scaling. Next, assign a KNeighborsClassifier() object to the ‘knn’ variable and fit it to the training set with the knn.fit() method. Compute the accuracy on both the training and testing sets to check whether the k-nearest neighbors classifier is a good fit for this classification problem. Finally, dump the trained model with the pickle package to create a ‘model.pkl’ file for deployment. Because the web application passes raw feature values to the model, the fitted scaler must also be applied at prediction time, for example by saving it together with the model.
Building the Web Application on Flask Server
We have successfully created the model; now it’s time to deploy it on the Flask server. The backend is a Python script that loads the model.pkl file and uses it for server-side processing. Clients interact with the front end by submitting data, the server side processes that data with the model, and the response is returned to the client’s front end. This is how a basic web application works, and we will use the same approach in this project.
Creating the Index page for Web Application
Create a Python file ‘app.py’, which will act as the backend script for the web application. Import the flask, numpy, and pickle libraries for building the backend of the Flask server. Load the ‘model.pkl’ file with the pickle.load() method and store it in the ‘model’ variable. Then create an instance of the Flask application and assign it to the ‘app’ variable. To load the application’s index page, we need a URL route that renders the HTML page. The HTML pages must be stored in the ‘templates’ folder, which is where Flask’s render_template() looks for them by default. Refer to the image below for the syntax to create a URL route and render an HTML template.
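A minimal sketch of this backend skeleton follows. It uses an application-factory function, create_app() (a hypothetical name), instead of module-level ‘model’ and ‘app’ variables, and keeps the unpickled model in app.config so later routes can reach it; both choices are assumptions of this sketch rather than the repository's structure.

```python
import pickle

from flask import Flask, render_template


def create_app(model_path="model.pkl", template_folder="templates"):
    """Build the Flask app and load the pickled model once at startup."""
    app = Flask(__name__, template_folder=template_folder)

    # Load the trained model produced by model.py.
    with open(model_path, "rb") as f:
        app.config["MODEL"] = pickle.load(f)

    @app.route("/")
    def index():
        # Render templates/index.html, which contains the input form.
        return render_template("index.html")

    return app
```

Running app = create_app() followed by app.run(debug=True) serves the index page at Flask's default development address, http://127.0.0.1:5000/.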
We will not go into the details of building the index page in HTML. The basic idea is to create a form that uses the POST method to submit the data to the function assigned to the ‘prediction’ URL path, which we discuss in the next section. Because the model requires four features to predict the output, the form needs four input fields, created with <input> tags: mass, width, height, and color score.
Creating the Prediction page for Web Application
Create the route for the prediction page and define the prediction() function. The function reads the values submitted through the index page form, typecasts them to float, and stores them in variables. These variables are placed in a 2D array with np.array() so that the input has the shape the model expects: one row per sample and one column per feature. The model loaded through pickle provides a predict() method, which takes this 2D array and outputs the predicted fruit label. The function then renders the ‘prediction.html’ page, passing the predicted value through the data variable.
The dataset has four classes, labeled one to four, so we use Jinja template logic (if/else blocks) in the page to display the fruit name that corresponds to each fruit label, along with the fruit image. You can find all the fruit images in the static folder once you clone the GitHub repository. Here are the images of the working web application.
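The prediction route and the label-to-name logic can be sketched together. In this sketch the form field names (mass, width, height, color_score), the helper names parse_features() and register_prediction(), and the label mapping 1 = apple, 2 = mandarin, 3 = orange, 4 = lemon are all assumptions; check the mapping against the fruit_name column of your dataset. The Jinja template is inlined with render_template_string() to keep the example self-contained, whereas the article renders a separate prediction.html file, and the fruit images are omitted.

```python
import numpy as np
from flask import Flask, render_template_string, request

# Assumed names of the four <input> fields on the index page form.
FEATURES = ["mass", "width", "height", "color_score"]

# Inline stand-in for prediction.html. The label-to-name mapping is an
# assumption; verify it against the fruit_name column of your dataset.
PREDICTION_TEMPLATE = """
{% if data == 1 %}<h2>Apple</h2>
{% elif data == 2 %}<h2>Mandarin</h2>
{% elif data == 3 %}<h2>Orange</h2>
{% elif data == 4 %}<h2>Lemon</h2>
{% else %}<h2>Unknown fruit</h2>
{% endif %}
"""


def parse_features(form):
    """Typecast the four submitted values to float and shape them as 1 x 4."""
    values = [float(form[name]) for name in FEATURES]
    return np.array([values])  # 2D array: one sample, four features


def register_prediction(app: Flask, model) -> None:
    """Attach the /prediction route, which the index form POSTs to."""

    @app.route("/prediction", methods=["POST"])
    def prediction():
        features = parse_features(request.form)
        label = int(model.predict(features)[0])
        return render_template_string(PREDICTION_TEMPLATE, data=label)
```

To wire it up, call register_prediction(app, model) with the Flask app and the unpickled model; the index page form must then post to /prediction.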