Why use -y⋅log(ŷ) - (1-y)⋅log(1-ŷ) as the loss function for training logistic regression and sigmoid outputs?
The purpose of this article is to explain the reasoning that leads to this choice of loss function. There are four areas you will want to understand:
- Basic Probability: interpreting numbers from 0 to 1 as probability
- Statistical Insight: what is the best understanding of "best"?
- Mathematical Tricks to Make Life Easier: things you can do to simplify the calculations
- Limitations of Computing Hardware: what kind of maths will existing computers get wrong?
Understanding "why this loss function?" should help you understand why not: when you should not use it, and how you might derive a correct alternative.
Background
When training a neural network, for instance in supervised learning, we typically start with a network of random matrices. We then use sample data to repeatedly adjust them until we have shaped them to give the right answers. This requires three things:
- Sample data. We need hundreds or thousands or millions of examples that are already labelled with the correct answer.
- A cost function. The cost function gives a number to express how wrong our current set of matrices is. "Training a neural network" means "minimise that wrongness by adjusting the matrices to reduce the value of the cost function."
- An algorithm for how exactly you do this minimising. Gradient descent using back-propagation is currently the one-size-fits-all choice, and a consequence is that your cost function must be differentiable.
The example we'll work towards is a cost function you commonly see in neural network introductions:
-1/m * ∑ (for i=1 to m) y*log( ŷ(i) ) + (1-y)*log( 1-ŷ(i) )
where
- m is the number of examples in your sample data
- y means the correct result for a given input x
- ŷ (pronounced "y-hat") means your network's output for a given input x
- ∑ (pronounced "sigma") is the mathematical symbol for sum, or adding up
- log is the mathematical logarithm function found in log tables
1. Basic Probability: Interpreting numbers from 0 to 1 as probability
How certain are you that you are reading this article? We take 100% certain to mean completely certain, and 0% certain to mean not even a little bit certain: in fact, certainly not. Rather than using percentages, we just use the number 0 for certainly not, 1 for completely certain, and the numbers in between for a scale of probability.
The probability of two independent events both happening is found by multiplying the probability of each individual event. Consider rolling a fair six-sided die. The probability of rolling a 6 is 1/6, which can also be written as 16.6667% or as 0.166667. The probability of rolling two 6s is 1/6 * 1/6, which is 1/36, which is about 2.77778% or 0.0277778.
The probability of something not happening is (1 - probability of it happening). The probability of not rolling a six is 1 - 1/6, which is 5/6, or 83.3333%, or 0.833333. It's the same as the probability of rolling one of 1, 2, 3, 4, 5. The probability of not rolling two sixes is (1 - 1/36), which is 97.2222% or 0.972222.
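All of these dice calculations can be checked in a few lines of Python, using the standard library's `fractions` module to keep the arithmetic exact:

```python
from fractions import Fraction

# Probability of rolling a 6 on one fair die
p_six = Fraction(1, 6)

# Independent events multiply: two 6s in a row
p_two_sixes = p_six * p_six

# The complement rule: probability of something NOT happening
p_not_six = 1 - p_six
p_not_two_sixes = 1 - p_two_sixes

print(p_two_sixes, float(p_two_sixes))          # 1/36, about 0.0277778
print(p_not_two_sixes, float(p_not_two_sixes))  # 35/36, about 0.972222
```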
We abbreviate the expression "the probability of …" to p(…). So if we call the die roll d, then
p( d = 6 )
is read as "the probability that die roll d is 6."
Conditional probability is "the probability of … given … already happened" and is written p( … | … ). So if we call our two dice rolls d1 and d2, then
p( d1 + d2 = 12 | d2 = 6 )
can be read as "the probability of d1 + d2 adding up to 12, given that we already rolled d2 = 6."
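Both kinds of probability can be checked by brute-force enumeration of the 36 equally likely outcomes of two dice; a quick sketch:

```python
from fractions import Fraction
from itertools import product

# All 36 equally likely outcomes of rolling two fair dice
outcomes = list(product(range(1, 7), repeat=2))

# p(d1 + d2 = 12): only the outcome (6, 6) qualifies
p_sum_12 = Fraction(sum(1 for d1, d2 in outcomes if d1 + d2 == 12),
                    len(outcomes))

# p(d1 + d2 = 12 | d2 = 6): restrict attention to outcomes where d2 = 6
given_d2_is_6 = [(d1, d2) for d1, d2 in outcomes if d2 == 6]
p_conditional = Fraction(sum(1 for d1, d2 in given_d2_is_6 if d1 + d2 == 12),
                         len(given_d2_is_6))

assert p_sum_12 == Fraction(1, 36)
assert p_conditional == Fraction(1, 6)
```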
Further reading: https://en.wikipedia.org/wiki/Probability#Mathematical_treatment
2. Statistical Insight: What is the best understanding of "Best"?
So you started with some data: some examples for which you already know the right answer. Let's say you have m examples and call this set X:
X = { x(1),...,x(m) }
and the matching set of correct answers, or labels:
Y = { y(1),...,y(m) }
You might think that the best setting for your network, given these examples, is:
- "Best" is the one that most often returns the correct result y(i) for a given input x(i).
But then you find that your network almost never returns exactly the right answer. For instance, in binary classification the right answers are the integers 0 and 1; but your network returns answers like 0.997254 or 0.0023.
So next you think about "being closest to the correct answer" and consider:
- "Best" is the one that gets closest to the correct y(i)s.
But there is no single definition of "closest" for multiple points. If you know about averages and also learnt Pythagoras' theorem at school, you might think of mean squared error as a good definition of closest:
- "Best" is the one that minimises the mean squared error for your example data.
And that would work. This is how linear regression (finding the best straight line through points plotted on graph paper) is often done. (And, if the input data is normally distributed, mean squared error will give very similar results to the cost function we are deriving here.)
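For the record, mean squared error is only a few lines of code. A minimal sketch (the labels and outputs below are invented for illustration):

```python
def mean_squared_error(y_true, y_pred):
    """One common meaning of 'closest': the average of the squared differences."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

# Invented example: correct answers and a network's outputs
labels = [1, 0, 1, 1]
outputs = [0.9, 0.2, 0.8, 0.6]
print(mean_squared_error(labels, outputs))  # about 0.0625
```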
But if you've studied even more statistics, or probability, or information theory, you may know that "closest to the correct answer" for a distribution of data which need not be normal is actually really tricky, and you might think of:
- "Best" is the one that maximises the expected probability of the known correct answers.
The new idea here is "expected value." It is worth your while to study enough statistics and probability to understand it.
And indeed, statisticians often take this as the best meaning of "best". It's called the Maximum Likelihood Estimator, and reasons for considering it "best" include: it is consistent, and it does not rely on any assumptions about the distribution of the data. So it's a good default choice. Minimising the mean squared error, surprisingly, can sometimes result in a biased estimator.
What the Maximum Likelihood Estimator looks like depends on the details of the problem you are trying to solve.
Specific Example: Binary Classification with a 0..1 output network
So let's take the case of:
- A binary classification task: the correct answer is one of two possible labels. Think of them as two buckets.
- A network whose output is a single number in the range 0 to 1.
A concrete example would be a cat recogniser: the task is to identify photos of cats. The input data are photographs. The sample data we start with are a set of photographs already correctly labelled as "cat" and "not-cat". The final layer of the network is a single sigmoid unit.
After we have seen some mathematical tricks, we will be able to write down a mathematical formula for "expected probability of the known correct answers" for the case of binary classification using a network with a single output in the range 0 to 1.
3. Mathematical Tricks to Make Life Easier: Things you can do to simplify the calculations
1st Trick: Use 0 and 1 as the bucket labels.
You can't easily do maths with words like "cat" and "not-cat". Instead, let's use 0 and 1 as the names of our buckets. We'll use 1 to mean "it's a cat" and 0 to mean "it's not a cat".
2nd Trick: Interpret a 0-1 output as a probability
This is a neat trick: since the network output is in the range 0 to 1, we decide to interpret the output as "our estimate of the probability that 1 is the correct bucket". To be clear, we do this not because of any deep insight, but just because it makes the maths simpler.
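To make this concrete: a sigmoid output unit always produces a number strictly between 0 and 1, so this interpretation is always available. A minimal sketch:

```python
import math

def sigmoid(z):
    """Squash any real number into the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

# Whatever the network computes internally, the final sigmoid output can be
# read as 'our estimate of the probability that 1 is the correct bucket'
for z in (-5.0, 0.0, 5.0):
    assert 0.0 < sigmoid(z) < 1.0
print(sigmoid(0.0))  # 0.5: maximum uncertainty
```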
With this interpretation we can turn our statistical phrase "expected probability of the known correct answer" into a mathematical formula:
p( ŷ(i) = y(i) | x(i) ; θ )
which you should read aloud as "the probability that our estimate y-hat-i equals the correct value y-i, given that the current input is x-i and given that the matrices of our network are currently set to θ."
We use the letter θ here as an abbreviation for "all of the current values in all of the matrices in the network." θ is in fact a very long list of hundreds or thousands of numbers.
We will have found the Maximum Likelihood Estimator if we can find the values for the network that maximise this probability.
Before we move on: if the output means "the probability of 1 being correct", what about the probability of 0 being correct?
In binary classification, everything must go into either bucket 0 or bucket 1. The chance of being in one or other of the buckets is 100%. But if you're in bucket 1, the chance of being in bucket 0 is none at all; and vice versa.
If you're in between ("I estimate the chance of this being a cat photo is 80%") then remember that in probability, all possibilities together must add up to 100% (it's 100% certain that a photo goes into one of the buckets). So "I estimate the chance of this being a cat photo is 80%" automatically means "I estimate the chance of this being a not-cat photo is 20%".
In general:
- if output means "the probability of 1 being correct"
- then (100% - output), or (1 - output), must mean "the probability of 0 being correct"
2½th Trick: Put tricks 1 and 2 together
Here's the clever bit. If you put tricks 1 and 2 together you can rewrite our formula from (2):
p( ŷ(i) = y(i) | x(i) ; θ )
as just
ŷ(i)
when the correct answer is 1, and
1 - ŷ(i)
when the correct answer is 0.
Which is a lot simpler. I mean, really a lot simpler.
How does that work? Remember that the set of y(i)s were the correct answers for our sample data. So y(i) = 1 if the correct answer is 1, and y(i) = 0 if the correct answer is 0. Now try putting these two sentences from earlier together, using the word output to substitute the second line into the first:
- ŷ(i) means the output given that the current input is x(i) and the matrices of our network are currently set to θ.
- We interpret the output as our estimate of the probability that the correct answer is 1.
And you get:
"ŷ(i) is our estimate of the probability that the correct answer is 1, given that the current input is x(i) and the matrices of our network are currently set to θ."
But notice that when the correct answer y(i) is 1, this sentence, which defines ŷ(i), is exactly what we meant by
p( ŷ(i) = y(i) | x(i) ; θ )
In the case when the correct answer is y(i) = 0, remember the rule that "the probability of the result being 0 is (1 - probability of the result being 1)": that sentence in quotes is then what we mean by
p( ŷ(i) = 1 - y(i) | x(i) ; θ )
With some high school algebra, our basic knowledge of probability, and again the (1 - …) rule, we realise that
p( ŷ(i) = 1 - y(i) )
is the same as p( 1 - ŷ(i) = 1 ), which is the same as 1 - ŷ(i).
Conclusion: We have turned our definition of maximum likelihood into a pair of very simple formulae. For each example datum x(i):
- if the correct answer is 1, then our maximum likelihood estimate is ŷ(i)
- if the correct answer is 0, then our maximum likelihood estimate is 1 - ŷ(i)
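In code, this pair of formulae is just a two-branch function. A sketch:

```python
def likelihood(y, y_hat):
    """The probability our network assigned to the correct answer y (0 or 1)."""
    if y == 1:
        return y_hat
    else:
        return 1 - y_hat

print(likelihood(1, 0.8))  # 0.8: the network was fairly confident, and right
print(likelihood(0, 0.8))  # about 0.2: fairly confident, but wrong
```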
We need just one more trick. Remember that to use this with back propagation, we need a single differentiable function. We must combine those two formulae into a single formula, and it has to be differentiable.
3rd mathematical trick: Invent a differentiable if(…, then …, else …) function
Let's try to invent a differentiable if(…, then …, else …) function that can combine the two formulae into one:
if( y(i) = 1, then ŷ(i), else (1 - ŷ(i)) )
To keep our maths simple, we want some kind of arithmetic version of this if( y=1, then ŷ(i), else (1-ŷ(i)) ) function. Similar to when we noticed that "the probability of not x is 1-x", we will use a 1-x trick.
There are a couple of simple options. Try these two definitions:
if(y,a,b) = a*y + b*(1-y)
if(y,a,b) = a^y * b^(1-y)
Both work. We could call the first one the "times & add" version of if and the second the "exponent & times" version. You could invent more. The constraints are:
- It must return a when y=1 and return b when y=0
- It should have a (preferably simple!) derivative.
Both of my suggestions meet the need, but we go with the second option and define:
if(y,a,b) = a^y * b^(1-y)
This choice is not some profound insight; again, it is only because it will make the maths easier further down the page.
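Both candidate definitions are easy to check numerically; a quick sketch (y must be exactly 0 or 1 for this to work):

```python
def if_times_add(y, a, b):
    """The 'times & add' version."""
    return a * y + b * (1 - y)

def if_exponent_times(y, a, b):
    """The 'exponent & times' version."""
    return a ** y * b ** (1 - y)

# Both satisfy the constraint: return a when y = 1, return b when y = 0
for f in (if_times_add, if_exponent_times):
    assert f(1, 0.7, 0.3) == 0.7
    assert f(0, 0.7, 0.3) == 0.3
```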
This new if() function combines our two separate formulae for "expected probability of the known correct answer" into a single, simple formula:
ŷ(i)^y * (1-ŷ(i))^(1-y)
Maximum Likelihood Estimator for all m examples
So far, we've only considered one sample datum, x(i), at a time. What about all m samples? Recall that the probability of m independent events is the probability of each of the individual events all multiplied together. We use the capital Greek letter pi (Π), meaning product, to show lots of terms being multiplied together, and write it as:
Π (for i=1 to m) ŷ(i)^y * (1-ŷ(i))^(1-y)
Which you read aloud as "the product, for every i from 1 to m, of y-hat-i to the power y, times one minus y-hat-i to the power 1-y". Or as "multiply all the m individual maximum likelihood estimates together."
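A sketch of that product in code, with invented labels and outputs:

```python
import math

labels  = [1, 0, 1]        # y(i): invented correct answers
outputs = [0.9, 0.2, 0.7]  # y-hat(i): invented network outputs

total_likelihood = math.prod(
    yh ** y * (1 - yh) ** (1 - y) for y, yh in zip(labels, outputs)
)
# Each factor is the probability assigned to the correct answer:
# 0.9 * (1 - 0.2) * 0.7
assert abs(total_likelihood - 0.9 * 0.8 * 0.7) < 1e-12
```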
At this point, we have done enough work to get started on our back-propagation algorithm and train our network. We have worked out a cost function that will guide us to a maximum likelihood estimator for all the data we have.
Why don't we go with it?
4. Limitations of Computing Hardware: What kind of maths will existing computers get wrong?
Well. Suppose you have a small training set of only 1,000 examples. Then this product will be 1,000 small numbers multiplied together. Suppose our typical ŷ(i) is about 0.5; then the result will usually be around 2^-1000 or less.
The smallest positive normal number that the IEEE 754-2008 standard for 64-bit floating point arithmetic can represent is about 2^-1022. With only a thousand training examples we are already within a hair's breadth of having the cost function calculation underflow and round to zero! That's before we've thought about the rounding errors of doing 1,000 multiplications. If we used single precision 32-bit arithmetic (which we might want to, because it's about twice as fast) we would hit the underflow problem with only about a hundred training examples.
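You can watch this underflow happen with ordinary Python floats, which are IEEE 754 doubles:

```python
# A thousand 'typical' probabilities of about 0.5, multiplied together,
# is around 2^-1000: tiny, but still representable
p_1000 = 0.5 ** 1000
assert p_1000 > 0.0   # roughly 9.3e-302

# A couple of thousand, though, underflows to exactly zero
p_2000 = 1.0
for _ in range(2000):
    p_2000 *= 0.5
assert p_2000 == 0.0
```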
If only we could use additions instead of multiplications. That would avoid underflow and dramatically reduce rounding errors. If only…
Computing trick: Use log() to avoid underflow and rounding errors
Those of you old enough to have used log tables at school will recall that the log() function neatly replaces multiplication with addition:
log(a * b * c) = log(a) + log(b) + log(c)
log also replaces exponentiation with multiplication:
log(a^b) = log(a) * b
And the logs of very small numbers are very big (bigly negative, that is). And log() is differentiable. The derivative of log(x) is 1/x.
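All three facts are easy to verify numerically; a quick sketch:

```python
import math

a, b, c = 2.0, 3.0, 5.0

# Multiplication becomes addition
assert math.isclose(math.log(a * b * c), math.log(a) + math.log(b) + math.log(c))

# Exponentiation becomes multiplication
assert math.isclose(math.log(a ** b), math.log(a) * b)

# The derivative of log(x) is 1/x, checked here with a finite difference
x, h = 2.0, 1e-8
slope = (math.log(x + h) - math.log(x)) / h
assert math.isclose(slope, 1 / x, rel_tol=1e-6)
```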
Sounds perfect. What if, instead of using the product-of-probabilities for our cost function, we used the sum-of-logs-of-probabilities?
Important to the success of this trick is that the log() function is monotonic. That is, when something goes up or down, log(something) goes up or down exactly in step. So when you increase or maximise log(something), you simultaneously increase or maximise something. And vice versa.
What this means is, we can use logs. If we find the value of θ that improves or maximises the log of
Π (for i=1 to m) ŷ(i)^y * (1-ŷ(i))^(1-y)
then we know for sure that that same value of θ simultaneously improves or maximises the product itself.
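This monotonicity argument is easy to check numerically; a quick sketch:

```python
import math

# Some candidate likelihood values, listed in increasing order
values = [0.001, 0.2, 0.5, 0.9, 0.999]
logs = [math.log(v) for v in values]

# log preserves the ordering ...
assert logs == sorted(logs)
# ... so whichever candidate maximises the value also maximises its log
assert values.index(max(values)) == logs.index(max(logs))
```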
Let's do it. The log of the product is the sum of the logs:
∑ (for i=1 to m) log( ŷ(i)^y * (1-ŷ(i))^(1-y) )
and, remembering your high school grasp of log(), you can simplify even further to:
The Maximum Likelihood Estimator is the network that maximises
∑ (for i=1 to m) y*log( ŷ(i) ) + (1-y)*log( 1-ŷ(i) )
You say maximise, I say minimise
For no good reason whatsoever, we often think of optimisation problems as a minimisation challenge, not a maximisation one. For the same sort of reason, we prefer to divide by m so that the cost function is sort of averagey.
It's tradition. Whatever. So we stick a minus sign in front, divide by m, and, ta-da:
-1/m * ∑ (for i=1 to m) y*log( ŷ(i) ) + (1-y)*log( 1-ŷ(i) )
That is how you derive the familiar formula for the cost function for optimising a sigmoid or logistic output with back propagation and gradient descent.
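As a final sanity check, the whole cost function is only a few lines of plain Python (a sketch; a production implementation would also clip ŷ away from exactly 0 and 1 so the logs stay finite):

```python
import math

def cross_entropy_cost(labels, outputs):
    """-1/m * sum over i of y*log(yhat(i)) + (1-y)*log(1-yhat(i))."""
    m = len(labels)
    return -sum(
        y * math.log(yh) + (1 - y) * math.log(1 - yh)
        for y, yh in zip(labels, outputs)
    ) / m

# Confident, correct outputs give a low cost; confident, wrong ones a high cost
good = cross_entropy_cost([1, 0, 1], [0.9, 0.1, 0.95])
bad = cross_entropy_cost([1, 0, 1], [0.1, 0.9, 0.05])
assert good < bad
```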
Recap
The ideas that lead to this formula are:
- Basic Probability: interpret numbers from 0 to 1 as probability
- Statistical Insight: the best understanding of "best" is usually the Maximum Likelihood Estimator
- Mathematical Tricks to Make Life Easier: a series of tricks you can do to simplify the calculations
- Limitations of Computing Hardware: know what kind of maths existing computers will get wrong
So my recommendations, in priority order, are:
- Take a course on probability and statistics. Seriously. You can't do good machine learning without a good grasp of statistics.
- Practise your high school maths.
- Learn enough about your computer hardware and platform to know the gotchas.