Linear Regression using OLS and SGD Methods in Python
Linear Regression is very widely used in data analysis. Many people run the analysis in Excel, but did you know that you can read the data from an Excel file, plot the fit, and calculate many useful metrics in Python?
Methods to fit a Linear Model
In Python, there are many ways to fit a Linear Model. Below, we mainly focus on the OLS (Ordinary Least Squares) Method, which minimizes the sum of the squared differences between the observed and predicted values. Afterwards, we use scikit-learn's SGDRegressor, a linear model fitted by minimizing a regularized empirical loss with SGD (Stochastic Gradient Descent). More can be found here.
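For reference, here is the standard OLS objective for the one-feature model \(y = mx + c\), which is the form used later for years vs. population, together with its textbook closed-form solution. \(\bar{x}\) and \(\bar{y}\) denote sample means. The second line is a simplified sketch of one plain SGD step on a single sample \((x_i, y_i)\) with learning rate \(\eta\) and loss \(\tfrac{1}{2}\bigl(y_i - (m x_i + c)\bigr)^2\). It ignores the regularization term and learning-rate schedule that SGDRegressor adds, so it illustrates the idea rather than scikit-learn's exact update.

\[
\min_{m,\,c}\ \sum_{i=1}^{n}\bigl(y_i - (m x_i + c)\bigr)^2,
\qquad
m = \frac{\sum_{i=1}^{n}(x_i-\bar{x})(y_i-\bar{y})}{\sum_{i=1}^{n}(x_i-\bar{x})^2},
\qquad
c = \bar{y} - m\,\bar{x}
\]

\[
m \leftarrow m + \eta\,\bigl(y_i - (m x_i + c)\bigr)\,x_i,
\qquad
c \leftarrow c + \eta\,\bigl(y_i - (m x_i + c)\bigr)
\]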
How to know if your linear regression model is accurate?
- \(R^2\) Score
  - It represents the proportion of the variance in the dependent variable that is explained by the independent variables used in the model (see below for more theoretical details).
  - Scikit-Learn Function: sklearn.metrics.r2_score
- Mean Absolute Error (MAE)
  - The average of the absolute errors between the predicted values and the actual values.
  - Scikit-Learn Function: sklearn.metrics.mean_absolute_error
- Mean Squared Error (MSE) (the quantity OLS minimizes and the default loss of SGDRegressor)
  - The average of the squared errors between the predicted values and the actual values.
  - Scikit-Learn Function: sklearn.metrics.mean_squared_error
- Root Mean Squared Error (RMSE)
  - The square root of the MSE. It expresses the error in the same unit as the dependent variable.
- Residual Plots
  - A good fit is indicated by residuals scattered randomly around zero with no discernible pattern.
We will mainly use \(R^2\), MSE, and MAE in the analysis below.
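To make these metric definitions concrete, here is a small sketch that computes \(R^2\), MAE, MSE and RMSE directly with NumPy and cross-checks them against the scikit-learn functions listed above. The helper `regression_metrics` and the four-point example are made up for illustration. \(R^2\) is computed as \(1 - SS_{res}/SS_{tot}\), which is how `r2_score` defines it by default.

```python
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score


def regression_metrics(y_true, y_pred):
    """Compute R^2, MAE, MSE and RMSE by hand (hypothetical helper)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    residuals = y_true - y_pred
    mse = np.mean(residuals ** 2)
    ss_res = np.sum(residuals ** 2)                 # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares
    return {
        "r2": 1.0 - ss_res / ss_tot,
        "mae": np.mean(np.abs(residuals)),
        "mse": mse,
        "rmse": np.sqrt(mse),  # same unit as y
    }


# Made-up example values (not the world population data)
y_true = [3.0, 5.0, 7.0, 9.0]
y_pred = [2.5, 5.5, 7.0, 9.5]

ours = regression_metrics(y_true, y_pred)
print(ours)  # r2 0.9625, mae 0.375, mse 0.1875, rmse ~0.433

# Cross-check against scikit-learn
print(r2_score(y_true, y_pred),
      mean_absolute_error(y_true, y_pred),
      mean_squared_error(y_true, y_pred))
```

Because \(R^2\) divides by the total variance of the actual values, it is unitless, whereas MAE and RMSE stay in the units of the dependent variable.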
Running Linear Regression on World Population Data
Given the World Population Data, we can try to find and plot a linear trend in the data.
We start by storing the data in an Excel file and then reading it into Python with the Pandas library, which lets us handle and manipulate the data easily. The data consists of years (the independent variable) and the world population (the dependent variable).
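A minimal sketch of this loading step, assuming a workbook whose first sheet has columns named `Year` and `Population`. The file path, the column names and the helpers `load_population` / `to_xy` are hypothetical, since the original spreadsheet layout is not shown. `pd.read_excel` also needs an Excel engine such as `openpyxl` installed to read `.xlsx` files.

```python
import pandas as pd


def load_population(path, year_col="Year", pop_col="Population"):
    """Read the spreadsheet and return (X, y).

    Path and column names are hypothetical placeholders.
    """
    df = pd.read_excel(path)  # first sheet by default
    return to_xy(df, year_col, pop_col)


def to_xy(df, year_col="Year", pop_col="Population"):
    # Double brackets keep X two-dimensional: shape (n_samples, 1)
    X = df[[year_col]].to_numpy(dtype=float)
    y = df[pop_col].to_numpy(dtype=float)  # 1-D target
    return X, y


# Usage with an in-memory DataFrame (synthetic values, illustration only)
demo = pd.DataFrame({"Year": [2000, 2001, 2002], "Population": [1.0, 2.0, 3.0]})
X, y = to_xy(demo)
print(X.shape, y.shape)  # (3, 1) (3,)
```

Keeping `X` two-dimensional matters because scikit-learn estimators expect a feature matrix of shape `(n_samples, n_features)`, even when there is only one feature.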
OLS Method
We use \(80 \%\) of the data to train the model: this portion is used to “learn” the relationship between the years and the world population by finding the \(m\) and \(c\) in \(y=mx+c\) that minimize the sum of squared residuals. The remaining \(20 \%\) is used to test and validate the model’s accuracy.
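The 80/20 split and OLS fit might look like the sketch below. It uses synthetic, roughly linear data in place of the real population figures, and `random_state=42` is an arbitrary choice for a reproducible split. It also recomputes \(m\) and \(c\) from the textbook closed-form least-squares formulas to show that `LinearRegression` finds the same line.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the year/population columns (NOT real data)
rng = np.random.default_rng(0)
years = np.arange(1960, 2021, dtype=float)
population = 7.5e7 * (years - 1960) + 3.0e9 + rng.normal(0, 2e7, years.size)

X = years.reshape(-1, 1)  # one feature column: the year
y = population

# 80 % train / 20 % test
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

ols = LinearRegression().fit(X_train, y_train)
m, c = ols.coef_[0], ols.intercept_

# Closed-form least squares on the same training data
x_tr = X_train.ravel()
m_closed = np.sum((x_tr - x_tr.mean()) * (y_train - y_train.mean())) / np.sum(
    (x_tr - x_tr.mean()) ** 2
)
c_closed = y_train.mean() - m_closed * x_tr.mean()

y_pred = ols.predict(X_test)  # predictions on the held-out 20 %
print(f"m = {m:.4e}, c = {c:.4e}")
```

The held-out predictions `y_pred` are what the metrics (\(R^2\), MSE, MAE) are computed on, so they measure how well the line generalizes rather than how well it fits the training points.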
SGD Method
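For the SGD method, we first need to scale the independent variable (the years), because gradient-based optimization is sensitive to feature scale: raw year values around 2000 make the problem poorly conditioned, so the steps for \(m\) and \(c\) are badly mismatched in size. We then use the SGDRegressor class, which reduces the squared error by repeatedly nudging \(m\) and \(c\) towards the minimum, with at most \(1000\) passes over the training data. With scikit-learn's default tolerance, training can stop earlier once the loss stops improving; the resulting best-fit line, which approximately minimizes the MSE, is then plotted.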
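A hedged sketch of the SGD variant, again on synthetic data rather than the real figures. Wrapping `StandardScaler` and `SGDRegressor` in a pipeline means the scaling learned on the training years is reused unchanged on the test years. `max_iter=1000` follows the text; `random_state=42` is an arbitrary choice, and all other settings are scikit-learn defaults (squared-error loss, small L2 penalty).

```python
import numpy as np
from sklearn.linear_model import SGDRegressor
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in data (NOT the real world population figures)
rng = np.random.default_rng(0)
years = np.arange(1960, 2021, dtype=float)
population = 7.5e7 * (years - 1960) + 3.0e9 + rng.normal(0, 2e7, years.size)
X, y = years.reshape(-1, 1), population

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# StandardScaler maps the years to zero mean and unit variance;
# SGDRegressor then fits m and c with stochastic gradient steps.
sgd = make_pipeline(
    StandardScaler(),
    SGDRegressor(max_iter=1000, random_state=42),
)
sgd.fit(X_train, y_train)

y_pred = sgd.predict(X_test)
print(f"test R^2: {r2_score(y_test, y_pred):.4f}")
```

`max_iter` caps the number of passes (epochs) over the training data. Because the fitted slope and intercept refer to the scaled feature, predictions should always go through the same pipeline rather than through the raw coefficients.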
By following these steps, we can attempt to find and visualize the linear trend in the world population data using either the OLS or SGD regression methods.
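For the visualization step, here is a sketch with matplotlib (an assumption, since the original plotting library is not named) that draws the data points and both fitted lines in the colors the figure uses: red for OLS and orange for SGD. For brevity both models are fitted on all of the synthetic points rather than on an 80/20 split.

```python
import matplotlib

matplotlib.use("Agg")  # off-screen rendering; remove to open a window
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression, SGDRegressor
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in data (NOT the real world population figures)
rng = np.random.default_rng(0)
years = np.arange(1960, 2021, dtype=float)
population = 7.5e7 * (years - 1960) + 3.0e9 + rng.normal(0, 2e7, years.size)
X, y = years.reshape(-1, 1), population

ols = LinearRegression().fit(X, y)
sgd = make_pipeline(
    StandardScaler(), SGDRegressor(max_iter=1000, random_state=42)
).fit(X, y)

fig, ax = plt.subplots()
ax.scatter(years, population, s=10, label="data")
ax.plot(years, ols.predict(X), color="red", label="OLS")
ax.plot(years, sgd.predict(X), color="orange", linestyle="--", label="SGD")
ax.set_xlabel("Year")
ax.set_ylabel("Population")
ax.legend()
# fig.savefig("population_fit.png")  # hypothetical output file name
```

A dashed style for the SGD line keeps it visible even where it lies almost exactly on top of the OLS line.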
Results
OLS Method
- MSE: 503030576553090.75
- \(R^2\): 0.9997738328468143
- MAE: 17666518.47913985
SGD Method
- MSE: 454359657773408.56
- \(R^2\): 0.9997335657388569
- MAE: 16986193.058504555
The difference is very small: by MSE and MAE the SGD Method is slightly more accurate, while the OLS Method has a marginally higher \(R^2\).
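A little arithmetic makes the reported numbers easier to read. RMSE is \(\sqrt{\text{MSE}}\), in the same unit as the population figures, which gives roughly 22.4 million for OLS and 21.3 million for SGD. Also, on a fixed test set \(R^2 = 1 - \text{MSE}/\operatorname{Var}(y_{\text{test}})\) (population variance), so two models scored on the same test set should imply the same \(\operatorname{Var}(y_{\text{test}})\). The reported values imply about \(2.22 \times 10^{18}\) and \(1.71 \times 10^{18}\), which suggests the two methods were evaluated on different test splits; this is an inference from the numbers, not something stated in the original setup.

```python
import math

# Metrics as reported in the Results section
reported = {
    "OLS": {"mse": 503030576553090.75, "r2": 0.9997738328468143},
    "SGD": {"mse": 454359657773408.56, "r2": 0.9997335657388569},
}

summary = {}
for name, r in reported.items():
    rmse = math.sqrt(r["mse"])  # same unit as the population values
    # On one fixed test set: R^2 = 1 - MSE / Var(y_test)  =>  Var = MSE / (1 - R^2)
    implied_var = r["mse"] / (1.0 - r["r2"])
    summary[name] = (rmse, implied_var)
    print(f"{name}: RMSE ~ {rmse:,.0f}, implied Var(y_test) ~ {implied_var:.3e}")
```

Scoring both models on the same `X_test`/`y_test` (for example by reusing one `train_test_split` call with a fixed `random_state`) would make the comparison direct.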
Here is an image of the fit (red line for the OLS Method, orange line for the SGD Method; the two almost overlap):