What does log_prob do?
Solution 1
As your own answer mentions, log_prob returns the logarithm of the density or probability. Here I will address the remaining points in your question:
- How is that different from log? Distributions do not have a method log. If they did, the closest possible interpretation would indeed be something like log_prob, but it would not be a very precise name since it begs the question "log of what?" A distribution has multiple numeric properties (for example its mean, variance, etc.) and the probability or density is just one of them, so the name would be ambiguous.
The same does not apply to the Tensor.log() method (which may be what you had in mind) because a Tensor is itself a mathematical quantity we can take the log of.
- Why take the log of a probability only to exponentiate it later? You may not need to exponentiate it later. For example, if you have the logs of probabilities p and q, then you can directly compute log(p * q) as log(p) + log(q), avoiding intermediate exponentiations. This is more numerically stable (avoiding underflow) because probabilities may become very close to zero while their logs do not. Addition is also more efficient than multiplication in general, and its derivative is simpler. There is a good article about those topics at https://en.wikipedia.org/wiki/Log_probability.
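The underflow point can be seen with a small sketch in plain Python (no torch needed). The probabilities here are hypothetical values chosen only to make the effect visible:

```python
import math

# 100 independent events, each with probability 1e-5 (hypothetical values).
probs = [1e-5] * 100

# Multiplying directly: the true value, 1e-500, is far below the smallest
# positive float64, so the running product underflows to exactly 0.0.
product = 1.0
for p in probs:
    product *= p

# Summing logs instead keeps the result comfortably representable.
log_product = sum(math.log(p) for p in probs)

print(product)      # 0.0
print(log_product)  # about -1151.29
```

The product carries no usable information once it hits zero, while the sum of logs can still be compared, differentiated, or added to other log-probabilities.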
Solution 2
Part of the answer is that log_prob returns the log of the probability density/mass function evaluated at the given sample value.
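As an illustration for a continuous distribution, here is a minimal sketch of the log-density of a normal distribution, written with the standard library only. The function name normal_log_prob and the numeric values are made up for this example; the result should agree with what a library's log_prob reports for a normal distribution, up to floating-point error.

```python
import math
from statistics import NormalDist

def normal_log_prob(x, mu, sigma):
    """Log of the Normal(mu, sigma) density evaluated at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)

x, mu, sigma = 0.5, 0.0, 2.0  # hypothetical values
log_p = normal_log_prob(x, mu, sigma)

# Exponentiating the log-density recovers the ordinary density.
density = math.exp(log_p)
reference = NormalDist(mu, sigma).pdf(x)  # stdlib cross-check
print(log_p, density, reference)
```

Note that the log-density is computed directly from the formula; at no point is the density itself evaluated and then passed through a log.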
Solution 3
log_prob takes the log of the probability (of some actions). Example:
import torch
import torch.nn.functional as F
from torch.distributions import Categorical
action_logits = torch.rand(5)
action_probs = F.softmax(action_logits, dim=-1)
action_probs
Returns:
tensor([0.1457, 0.2831, 0.1569, 0.2221, 0.1922])
Then:
dist = Categorical(action_probs)
action = dist.sample()
print(dist.log_prob(action), torch.log(action_probs[action]))
Returns:
tensor(-1.8519) tensor(-1.8519)
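The same check can be sketched without torch, assuming the rounded probabilities printed above. The helper categorical_log_prob is a hypothetical name, and the sampled action is inferred to be index 2 because that is the only entry whose log is close to -1.8519:

```python
import math

# Probabilities copied from the printed tensor above (rounded to 4 digits).
action_probs = [0.1457, 0.2831, 0.1569, 0.2221, 0.1922]

def categorical_log_prob(probs, index):
    """Log-probability of drawing `index` from a categorical distribution."""
    return math.log(probs[index])

# Index 2 is inferred from the printed result, not stated in the original.
log_p = categorical_log_prob(action_probs, 2)
print(log_p)  # close to -1.8519; the small gap comes from the rounded inputs
```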
Solution 4
We can go through an easy example to understand what the log_prob function does.
First, generate a probability a by sampling from a uniform distribution bounded in [0, 1]:
import torch.distributions as D
import torch
a = torch.empty(1).uniform_(0, 1)
a # OUTPUT: tensor([0.3291])
Then, based on this probability and the Python class torch.distributions.Bernoulli, we can instantiate a Bernoulli distribution b (which generates 1 with probability a=0.3291 and 0 with probability 1-a=0.6709 in each Bernoulli experiment):
b = D.Bernoulli(a)
b # OUTPUT: Bernoulli()
Here, we can run one Bernoulli experiment to get a sample c (recall that we have a 0.3291 probability of getting 1 and a 0.6709 probability of getting 0):
c = b.sample()
c # OUTPUT: tensor([0.])
With the Bernoulli distribution b and one specific sample c, we can get the log probability of the experiment result (c) under a specific distribution (the Bernoulli distribution b). Officially, log_prob returns the log of the probability density/mass function evaluated at value (c):
b.log_prob(c) # OUTPUT: tensor([-0.3991])
We already know that the probability of a sample being 0 is 0.6709 (for a single experiment, this probability is simply the probability mass function evaluated at 0), so we can verify the log_prob result with:
torch.log(torch.tensor(0.6709)) # OUTPUT: tensor(-0.3991)
So b.log_prob(c) is the log of the probability density/mass function evaluated at value (c).
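The Bernoulli case above can also be reproduced by hand. This is a minimal sketch in plain Python, assuming the rounded probability 0.3291 from the walkthrough; bernoulli_log_prob is a hypothetical helper, not a library function:

```python
import math

def bernoulli_log_prob(value, p):
    """Log of the Bernoulli(p) mass function at value (0 or 1)."""
    return math.log(p) if value == 1 else math.log(1.0 - p)

p = 0.3291  # the probability `a` from the walkthrough, rounded

print(bernoulli_log_prob(0, p))  # about -0.3991, matching b.log_prob(c) for c = 0
print(bernoulli_log_prob(1, p))  # the log-probability had the sample been 1
```

Exponentiating the two results gives back 0.6709 and 0.3291, which sum to 1 as expected for a probability mass function.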
Hope it works for you.
cerebrou
Updated on June 04, 2022
Comments
-
cerebrou almost 2 years
In some (e.g. machine learning) libraries, we can find a log_prob function. What does it do and how is it different from just taking the regular log?
For example, what is the purpose of this code:
dist = Normal(mean, std)
sample = dist.sample()
logprob = dist.log_prob(sample)
And subsequently, why would we first take a log and then exponentiate the resulting value instead of just evaluating it directly:
prob = torch.exp(dist.log_prob(sample))
-
JustinBlaber over 3 years
You ever find an answer? I was kind of hoping there was a direct way to compute PDFs in torch. This is close, but it's annoying that you have to exp it.