Ordinary Least Square: Closed Form Solution & Gradient Descent
The Jupyter Notebook for this article can be found HERE.
Ordinary Least Squares (OLS) is a widely used method for estimating the parameters of a linear regression model. We are given some data D:
Where each Xi is a vector of length k, and yi is the value of the dependent variable for Xi. In the simplest scenario Xi has dimension 1 and is, thus, a scalar. OLS seeks to minimize the following objective function (loss):
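For reference, with W the vector of model parameters, the objective in the usual notation reads:

```latex
L(W) = \frac{1}{n} \sum_{i=1}^{n} \left( y_i - X_i^{T} W \right)^2
```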
OLS Visualized
OLS seeks to minimize the Mean Squared Error (MSE) of the model's predictions. The scatter plot below depicts the relationship between the leg length (X) of different animal species and their running speed (y). The relationship between y and X is clearly linear.
Figure 1. Observed data
How do we find the best fitting line, the one that minimizes the loss? Let's start by guessing a fitting line and calculating the MSE.
Figure 2. Observed data, first guess line and MSE
Well, our first guess is pretty bad: the line does not follow the data at all and the MSE is high (16).
Next, we can propose a better fitting line based on the result of our first attempt:
Figure 3. Observed data, second guess line and MSE
We improved on the first attempt and the second fitting line represents the data much better: the MSE is significantly lower (9) and the line is closer to the data points. Note that the size of the squared residuals is visibly smaller than before. We could continue with this approach and propose a third fitting line, a fourth, and so on, until the MSE stops changing or the change falls below a predefined threshold.
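The guess-and-check loop above can be sketched as follows. The data here is a hypothetical stand-in for the leg-length/speed dataset in Figure 1 (built on a true line y = 1 + 1.5x, the values stated later in the article); the notebook's actual data is not reproduced:

```python
import numpy as np

def mse(y, y_hat):
    """Mean squared error between observations and predictions."""
    return np.mean((y - y_hat) ** 2)

# Hypothetical stand-in data; the notebook's dataset is not reproduced here.
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = 1.0 + 1.5 * x  # noiseless line with intercept 1 and slope 1.5

first_guess = 3.0 + 0.5 * x   # a poor fitting line
second_guess = 0.5 + 1.4 * x  # a better fitting line

# The second guess yields a lower MSE, so it fits the data better
print(mse(y, first_guess), mse(y, second_guess))
```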
As you may have guessed, we are not going to find the best fitting line by blindly testing all possible fitting lines. This approach is called brute force, and it is seldom the best or smartest way to go. Instead, we are going to implement two methods: the closed form solution and gradient descent optimization.
1- Closed Form Solution
We are trying to minimize the following loss function:
Let's rewrite it in matrix notation (the 1/n factor was dropped since it is a constant and does not change the location of the minimum):
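Stacking the Xi into a matrix X and the yi into a vector y, the loss in matrix notation would read:

```latex
L(W) = (y - XW)^{T} (y - XW)
```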
Next we expand the square operation:
We can use the following identity:
To modify the loss function into:
Next we expand the expression:
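Applying the transpose identity (AB)^T = B^T A^T and multiplying everything out, the expanded loss would be (with the W-free term y^T y last):

```latex
L(W) = W^{T} X^{T} X W - W^{T} X^{T} y - y^{T} X W + y^{T} y
```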
The figure below depicts how the value of the loss function changes with W. The loss function is convex. Let's call W* the value of W for which L(W) is the minimum (Figure 4).
Figure 4. Shape of the loss function
At W* the first derivative of the loss function is zero:
Note that the last term does not contain W, thus:
The derivative of the first term is a little more complicated. In general:
Where A is a symmetric matrix. Let's update the derivative of the loss function by including the derivative of the first term (9) and removing the fourth term (8).
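With the identity for the first term (here A = X^T X, which is symmetric) and the W-free term y^T y dropped, the derivative would read:

```latex
\frac{\partial}{\partial W} \left( W^{T} A W \right) = 2 A W
\qquad\Rightarrow\qquad
\frac{\partial L}{\partial W} = 2 X^{T} X W
- \frac{\partial}{\partial W}\left( W^{T} X^{T} y \right)
- \frac{\partial}{\partial W}\left( y^{T} X W \right)
```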
The final trick is to realize that:
Why is the equality in (11) valid? The first two terms are equal to each other because the first term evaluates to a scalar, and transposing a scalar leaves it unchanged. To go from the second term to the third one we applied (4) again.
And because:
we end up with:
We get rid of the factor of 2 and rearrange the formula such that:
Let's multiply both sides by:
We get:
We recognize that:
Hence the closed form formula to calculate W is:
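Putting the steps above together, the chain of manipulations would be:

```latex
\frac{\partial L}{\partial W} = 2 X^{T} X W - 2 X^{T} y = 0
\quad\Rightarrow\quad
X^{T} X W = X^{T} y
\quad\Rightarrow\quad
W^{*} = \left( X^{T} X \right)^{-1} X^{T} y
```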
Great! Let's use this formula to estimate the intercept and slope of the data in Figure 1. The only caveat is that we need to add a column of ones to the original X, in order to estimate the intercept B0.
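A minimal NumPy sketch of this computation, run on synthetic stand-in data generated with true parameters B0 = 1 and B1 = 1.5 (the notebook's actual dataset is not reproduced here):

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in data: true intercept B0 = 1, true slope B1 = 1.5
n = 100
x = rng.uniform(0, 10, n)
y = 1.0 + 1.5 * x + rng.normal(0, 0.5, n)

# Prepend a column of ones so the first coefficient is the intercept B0
X = np.column_stack([np.ones(n), x])

# Closed form solution: W = (X^T X)^(-1) X^T y
W = np.linalg.inv(X.T @ X) @ X.T @ y
print(W)  # close to [1.0, 1.5]
```

In practice, np.linalg.lstsq (or np.linalg.solve on the normal equations) is preferred over an explicit matrix inverse for numerical stability.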
The true values of B0 and B1 are 1 and 1.5 respectively.
Finally, let's overlay the fitted line with the original data and calculate the final error:
Figure 5. Closed form solution fitted line
2- Gradient descent
Gradient descent is an optimization method that estimates the parameters of a model by minimizing the loss function iteratively. At each iteration, gradient descent updates the parameter values by moving in the direction of steepest descent, as given by the negative gradient of the loss function with respect to the current parameter values.
Figure 6. Gradient descent intuition
Let's walk through Figure 6. We start the search for the optimal parameter, W*, with an initial guess, W0. The loss at W0 is L(W0). Next, we propose a better value of W such that the loss is smaller. In Figure 6, that means increasing the value of W0 and moving toward the right (if W0 were higher than W*, we would propose a smaller W, thus moving toward the left). However, in practice we don't have a picture like Figure 6, so how do we know whether to increase or decrease W? That's where gradient descent comes to the rescue.
The first step of gradient descent is calculating the derivative, a.k.a. gradient, of the loss function at W0, L'(W0). In Figure 6 the red lines represent the gradient of the loss function at different values of W. The gradient at W0 is negative because the line points downward to the right. That tells us that the optimum W* is located on the right, and increasing W would decrease the loss. Conversely, if the gradient is positive, W is greater than W* and we need to decrease W, moving to the left.
Great! We have developed the intuition behind gradient descent and figured out when to increase or decrease W. The next step is formalizing this intuition and deciding by how much we should decrease (or increase) W at each iteration. The rule to update W at each iteration is given by the formula:
The value of W at the next iteration (t+1) is equal to the current value of W minus the gradient of the loss w.r.t. W multiplied by the learning rate, lr. The learning rate is a hyperparameter that controls the learning process; specifically, it dictates how fast we descend the curve. A higher lr means bigger updates of W and a faster descent. Values of the learning rate are typically on the order of 0.01 to 0.0001 and are often found through trial and error.
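In symbols, the update rule would read:

```latex
W_{t+1} = W_{t} - lr \cdot \frac{\partial L}{\partial W}\bigg|_{W = W_{t}}
```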
It is time to apply gradient descent to find the parameters of the line that best fits the data in Figure 1. However, we first need to derive the derivative of the loss function with respect to the model parameters. Our proposed linear model is:
Thus, we have two parameters to estimate, B0 (intercept) and B1 (slope).
The loss function is:
Where yi_hat is the model prediction and yi is the observed data. The derivation of the loss function can be found HERE.
Let's substitute yi_hat with (19).
The gradient of the loss w.r.t. B1 is:
The gradient of the loss w.r.t. B0 is:
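Reconstructing the model, the loss, and the two gradients in standard notation, they would read:

```latex
\hat{y}_i = B_0 + B_1 x_i
\qquad
L = \frac{1}{n} \sum_{i=1}^{n} \left( y_i - \hat{y}_i \right)^2
```

```latex
\frac{\partial L}{\partial B_1} = -\frac{2}{n} \sum_{i=1}^{n} x_i \left( y_i - \hat{y}_i \right)
\qquad
\frac{\partial L}{\partial B_0} = -\frac{2}{n} \sum_{i=1}^{n} \left( y_i - \hat{y}_i \right)
```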
Both (22) and (23) were calculated with the chain rule: (f(g(x)))' = f'(g(x))⋅g'(x).
It is time to code everything up. Let's start by defining two helper functions: one that returns the loss (given predictions and observed data), and one that returns the gradients (given predictions, observed data, and X):
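A sketch of these two helpers, assuming NumPy arrays (function names are illustrative, not necessarily the notebook's):

```python
import numpy as np

def loss(y, y_hat):
    """Mean squared error between observed data and predictions."""
    return np.mean((y - y_hat) ** 2)

def gradients(y, y_hat, x):
    """Gradients of the MSE w.r.t. the intercept B0 and the slope B1."""
    n = len(y)
    d_b0 = -2.0 / n * np.sum(y - y_hat)        # dL/dB0
    d_b1 = -2.0 / n * np.sum(x * (y - y_hat))  # dL/dB1
    return d_b0, d_b1
```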
Next we deploy gradient descent to estimate the B0 and B1 of the best fitting line:
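A self-contained sketch of this training loop on synthetic stand-in data (true B0 = 1, B1 = 1.5, matching the values the article quotes); the learning rate and iteration count follow the article, and the notebook's actual dataset is not reproduced:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in data: true intercept 1.0, true slope 1.5
x = rng.uniform(0, 10, 100)
y = 1.0 + 1.5 * x + rng.normal(0, 0.5, 100)

b0, b1 = 0.0, 0.0  # initial guesses for intercept and slope
lr = 0.01          # learning rate

for _ in range(300):
    y_hat = b0 + b1 * x                     # predictions with current parameters
    d_b0 = -2.0 * np.mean(y - y_hat)        # gradient w.r.t. the intercept
    d_b1 = -2.0 * np.mean(x * (y - y_hat))  # gradient w.r.t. the slope
    b0 -= lr * d_b0                         # batch update: whole dataset at once
    b1 -= lr * d_b1

print(b0, b1)  # approaches 1.0 and 1.5
```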
Note that at each iteration B0 and B1 were updated once, using the whole dataset. This variant of gradient descent is called batch gradient descent. In stochastic gradient descent, by contrast, parameters are updated one data point at a time.
After 300 iterations the estimated B0 and B1 are:
The real values of B0 and B1 are 1 and 1.5, respectively.
Finally, let's plot the loss during training and the final fitted line with the MSE:
Conclusion
Both methods reached the same MSE and produced the same values of B0 and B1. While the closed-form solution can be computed in a single step, performing matrix operations (in particular, the matrix inversion) on large datasets can be computationally expensive. In such situations, we must deploy iterative optimization algorithms like gradient descent. Gradient descent requires tuning the learning rate hyperparameter, which can be time-consuming and challenging.
What's next
In future articles I will introduce other variations of gradient descent.