Probability distributions - torch.distributions
The distributions package contains parameterizable probability distributions and sampling functions. This allows the construction of stochastic computation graphs and stochastic gradient estimators for optimization. This package generally follows the design of the TensorFlow Distributions package.
It is not possible to directly backpropagate through random samples. However, there are two main methods for creating surrogate functions that can be backpropagated through. These are the score function estimator/likelihood ratio estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly seen as the basis for policy gradient methods in reinforcement learning, and the pathwise derivative estimator is commonly seen in the reparameterization trick in variational autoencoders. Whilst the score function only requires the value of samples $f(x)$, the pathwise derivative requires the derivative $f'(x)$. The next sections discuss these two in a reinforcement learning example. For more details see Gradient Estimation Using Stochastic Computation Graphs.
Score function
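When the probability density function is differentiable with respect to its parameters, we only need sample() and log_prob() to implement REINFORCE:

$$\Delta\theta = \alpha r \frac{\partial \log p(a \mid \pi^\theta(s))}{\partial \theta}$$

where $\theta$ are the parameters, $\alpha$ is the learning rate, $r$ is the reward and $p(a \mid \pi^\theta(s))$ is the probability of taking action $a$ in state $s$ given policy $\pi^\theta$.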
In practice we would sample an action from the output of a network, apply this action in an environment, and then use log_prob to construct an equivalent loss function. Note that we use a negative sign because optimizers use gradient descent, whilst the rule above assumes gradient ascent. With a categorical policy, the code for implementing REINFORCE would be as follows:
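import torch
from torch.distributions import Categorical

# policy_network, state and env are placeholders for your own model and environment
probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
# Negate because optimizers minimize, while the REINFORCE rule ascends
loss = -m.log_prob(action) * reward
loss.backward()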
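The surrogate loss works because the reward-weighted gradient of log_prob is an unbiased estimate of the gradient of the expected reward. The sketch below checks this for a single Categorical over three actions by comparing the exact gradient with a Monte Carlo score-function estimate. The logits and the per-action reward `f` are made-up values for illustration; no policy network or environment is involved.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical logits and per-action rewards, chosen only for illustration.
logits = torch.tensor([0.5, -0.2, 0.1], requires_grad=True)
f = torch.tensor([1.0, 2.0, 3.0])

# Exact gradient of E[f(a)] = sum_a p(a) f(a) with respect to the logits.
exact = torch.autograd.grad((torch.softmax(logits, dim=-1) * f).sum(), logits)[0]

# Score-function (REINFORCE) estimate: gradient of mean(log p(a) * f(a))
# over sampled actions. The samples carry no gradient; only log_prob does.
m = Categorical(logits=logits)
actions = m.sample((200_000,))
surrogate = (m.log_prob(actions) * f[actions]).mean()
estimate = torch.autograd.grad(surrogate, logits)[0]

print(exact)
print(estimate)  # close to `exact`, up to Monte Carlo noise
```

The estimator never differentiates through `m.sample()`; all gradient information flows through `m.log_prob(actions)`.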
Pathwise derivative
The other way to implement these stochastic/policy gradients would be to use the reparameterization trick from the rsample() method, where the parameterized random variable can be constructed via a parameterized deterministic function of a parameter-free random variable. The reparameterized sample therefore becomes differentiable. The code for implementing the pathwise derivative would be as follows:
import torch
from torch.distributions import Normal

# policy_network is a placeholder that returns a (loc, scale) pair of tensors
params = policy_network(state)
m = Normal(*params)
# Any distribution with .has_rsample == True could work based on the application
action = m.rsample()
next_state, reward = env.step(action)  # Assuming that reward is differentiable
loss = -reward
loss.backward()
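For comparison, a minimal sketch of the pathwise estimator, assuming a toy objective E[x²] in place of a reward from an environment. Since E[x²] = μ² + σ² for x ~ Normal(μ, σ), the gradients obtained by backpropagating through rsample() should be close to 2μ and 2σ. The values μ = 1.5 and σ = 0.5 are arbitrary.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

mu = torch.tensor(1.5, requires_grad=True)
sigma = torch.tensor(0.5, requires_grad=True)
m = Normal(mu, sigma)

x = m.rsample((200_000,))  # x = mu + sigma * eps with eps ~ N(0, 1), so x depends on mu, sigma
loss = (x ** 2).mean()     # Monte Carlo estimate of E[x^2] = mu^2 + sigma^2
loss.backward()

print(mu.grad, sigma.grad)                        # close to 2*mu = 3.0 and 2*sigma = 1.0
print(x.requires_grad, m.sample().requires_grad)  # True False
```

The last line shows the practical difference: `rsample()` keeps the sample in the autograd graph, while `sample()` returns a detached tensor.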
Distribution
class torch.distributions.distribution.Distribution(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)[source]
Bases: object
Distribution is the abstract base class for probability distributions.
property arg_constraints: Dict[str, Constraint]
Returns a dictionary from argument names to Constraint objects that should be satisfied by each argument of this distribution. Args that are not tensors need not appear in this dict.
property batch_shape: Size
Returns the shape over which parameters are batched.
cdf(value)[source]
Returns the cumulative density/mass function evaluated at value.
entropy()[source]
Returns entropy of distribution, batched over batch_shape.
- Returns: Tensor of shape batch_shape.
- Return type: Tensor
enumerate_support(expand=True)[source]
Returns tensor containing all values supported by a discrete distribution. The result will enumerate over dimension 0, so the shape of the result will be (cardinality,) + batch_shape + event_shape (where event_shape = () for univariate distributions).
Note that this enumerates over all batched tensors in lock-step [[0, 0], [1, 1], …]. With expand=False, enumeration happens along dim 0, but with the remaining batch dimensions being singleton dimensions, [[0], [1], ...].
To iterate over the full Cartesian product use itertools.product(m.enumerate_support()).
property event_shape: Size
Returns the shape of a single sample (without batching).
expand(batch_shape, _instance=None)[source]
Returns a new distribution instance (or populates an existing instance provided by a derived class) with batch dimensions expanded to batch_shape. This method calls expand on the distribution’s parameters. As such, this does not allocate new memory for the expanded distribution instance. Additionally, this does not repeat any args checking or parameter broadcasting in __init__, when an instance is first created.
- Parameters
  - batch_shape (torch.Size) – the desired expanded size.
  - _instance – new instance provided by subclasses that need to override .expand.
- Returns: New distribution instance with batch dimensions expanded to batch_shape.
icdf(value)[source]
Returns the inverse cumulative density/mass function evaluated at value.
log_prob(value)[source]
Returns the log of the probability density/mass function evaluated at value.
property mean: Tensor
Returns the mean of the distribution.
property mode: Tensor
Returns the mode of the distribution.
perplexity()[source]
Returns perplexity of distribution, batched over batch_shape.
- Returns: Tensor of shape batch_shape.
- Return type: Tensor
rsample(sample_shape=torch.Size([]))[source]
Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched.
- Return type: Tensor
sample(sample_shape=torch.Size([]))[source]
Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.
- Return type: Tensor
sample_n(n)[source]
Generates n samples or n batches of samples if the distribution parameters are batched.
- Return type: Tensor
static set_default_validate_args(value)[source]
Sets whether validation is enabled or disabled.
The default behavior mimics Python’s assert statement: validation is on by default, but is disabled if Python is run in optimized mode (via python -O). Validation may be expensive, so you may want to disable it once a model is working.
- Parameters
  - value (bool) – Whether to enable validation.
property stddev: Tensor
Returns the standard deviation of the distribution.
property support: Optional[Any]
Returns a Constraint object representing this distribution’s support.
property variance: Tensor
Returns the variance of the distribution.
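A short sketch of how batch_shape, event_shape, sample_shape, expand() and enumerate_support() interact, using Normal and Categorical with arbitrary parameter values. The shapes in the comments follow from the rules described above.

```python
import torch
from torch.distributions import Categorical, Normal

# Three independent univariate normals: batch_shape (3,), event_shape ().
d = Normal(torch.zeros(3), torch.ones(3))
print(d.batch_shape, d.event_shape)  # torch.Size([3]) torch.Size([])

# sample_shape is prepended: (5,) + batch_shape + event_shape.
x = d.sample((5,))
print(x.shape)                       # torch.Size([5, 3])
print(d.log_prob(x).shape)           # torch.Size([5, 3])

# expand() broadcasts the parameters to a larger batch via tensor views.
e = d.expand(torch.Size([2, 3]))
print(e.batch_shape)                 # torch.Size([2, 3])

# A batch of three categoricals over two classes.
c = Categorical(probs=torch.tensor([[0.2, 0.8], [0.5, 0.5], [0.9, 0.1]]))
print(c.enumerate_support().shape)              # torch.Size([2, 3])
print(c.enumerate_support(expand=False).shape)  # torch.Size([2, 1])
```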
ExponentialFamily
class torch.distributions.exp_family.ExponentialFamily(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)[source]
Bases: Distribution
ExponentialFamily is the abstract base class for probability distributions belonging to an exponential family, whose probability mass/density function has the form defined below:

$$p_F(x; \theta) = \exp(\langle t(x), \theta \rangle - F(\theta) + k(x))$$

where $\theta$ denotes the natural parameters, $t(x)$ denotes the sufficient statistic, $F(\theta)$ is the log normalizer function for a given family and $k(x)$ is the carrier measure.
Note
This class is an intermediary between the Distribution class and distributions which belong to an exponential family, mainly to check the correctness of the .entropy() and analytic KL divergence methods. We use this class to compute the entropy and KL divergence using the AD framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and Cross-entropies of Exponential Families).
entropy()[source]
Method to compute the entropy using Bregman divergence of the log normalizer.
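As a concrete instance of the exponential-family form, a Bernoulli distribution with success probability p has natural parameter θ = log(p / (1 - p)), sufficient statistic t(x) = x, log normalizer F(θ) = log(1 + e^θ) and carrier measure k(x) = 0. The sketch below evaluates the formula directly and compares it with Bernoulli.log_prob(); p = 0.3 is an arbitrary choice.

```python
import torch
from torch.distributions import Bernoulli

p = torch.tensor(0.3)
theta = torch.log(p / (1 - p))          # natural parameter (log-odds)
F = torch.nn.functional.softplus(theta)  # log normalizer: log(1 + exp(theta))
x = torch.tensor([0.0, 1.0])             # t(x) = x and k(x) = 0 for the Bernoulli

by_formula = torch.exp(x * theta - F)
by_library = Bernoulli(probs=p).log_prob(x).exp()
print(by_formula, by_library)            # both approximately [0.7, 0.3]
```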
Bernoulli
class torch.distributions.bernoulli.Bernoulli(probs=None, logits=None, validate_args=None)[source]
Bases: ExponentialFamily
Creates a Bernoulli distribution parameterized by probs or logits (but not both).
Samples are binary (0 or 1). They take the value 1 with probability p and 0 with probability 1 - p.
Example:
>>> m = Bernoulli(torch.tensor([0.3]))
>>> m.sample()  # 30% chance 1; 70% chance 0
tensor([ 0.])
arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}
entropy()[source]
enumerate_support(expand=True)[source]
expand(batch_shape, _instance=None)[source]
has_enumerate_support = True
log_prob(value)[source]
property logits
property mean
property mode
property param_shape
property probs
sample(sample_shape=torch.Size([]))[source]
support = Boolean()
property variance
Beta
class torch.distributions.beta.Beta(concentration1, concentration0, validate_args=None)[source]
Bases: ExponentialFamily
Beta distribution parameterized by concentration1 and concentration0.
Example:
>>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
>>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
tensor([ 0.1046])
arg_constraints = {'concentration0': GreaterThan(lower_bound=0.0), 'concentration1': GreaterThan(lower_bound=0.0)}
property concentration0
property concentration1
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
log_prob(value)[source]
property mean
property mode
rsample(sample_shape=())[source]
support = Interval(lower_bound=0.0, upper_bound=1.0)
property variance
Binomial
class torch.distributions.binomial.Binomial(total_count=1, probs=None, logits=None, validate_args=None)[source]
Bases: Distribution
Creates a Binomial distribution parameterized by total_count and either probs or logits (but not both). total_count must be broadcastable with probs/logits.
Example:
>>> m = Binomial(100, torch.tensor([0, .2, .8, 1]))
>>> x = m.sample()
>>> x
tensor([ 0., 22., 71., 100.])
>>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8]))
>>> x = m.sample()
>>> x
tensor([[ 4., 5.],
        [ 7., 6.]])
arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0), 'total_count': IntegerGreaterThan(lower_bound=0)}
entropy()[source]
enumerate_support(expand=True)[source]
expand(batch_shape, _instance=None)[source]
has_enumerate_support = True
log_prob(value)[source]
property logits
property mean
property mode
property param_shape
property probs
sample(sample_shape=torch.Size([]))[source]
property support
property variance
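A small sanity check of Binomial, with an arbitrary total_count=10 and probs=0.3: enumerate_support() lists every possible count, the probabilities over that support sum to one, and mean and variance follow n·p and n·p·(1 - p).

```python
import torch
from torch.distributions import Binomial

m = Binomial(total_count=10, probs=torch.tensor(0.3))

k = m.enumerate_support()  # the counts 0, 1, ..., 10 as a float tensor
pmf = m.log_prob(k).exp()

print(pmf.sum())           # approximately 1
print(m.mean, m.variance)  # approximately 10*0.3 = 3.0 and 10*0.3*0.7 = 2.1
```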
Categorical
class torch.distributions.categorical.Categorical(probs=None, logits=None, validate_args=None)[source]
Bases: Distribution
Creates a categorical distribution parameterized by either probs or logits (but not both).
Note
It is equivalent to the distribution that torch.multinomial() samples from.
Samples are integers from $\{0, \ldots, K-1\}$ where K is probs.size(-1).
If probs is 1-dimensional with length-K, each element is the relative probability of sampling the class at that index.
If probs is N-dimensional, the first N-1 dimensions are treated as a batch of relative probability vectors.
Note
The probs argument must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1 along the last dimension. probs will return this normalized value. The logits argument will be interpreted as unnormalized log probabilities and can therefore be any real number. It will likewise be normalized so that the resulting probabilities sum to 1 along the last dimension. logits will return this normalized value.
See also: torch.multinomial()
Example:
>>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
>>> m.sample()  # equal probability of 0, 1, 2, 3
tensor(3)
arg_constraints = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}
entropy()[source]
enumerate_support(expand=True)[source]
expand(batch_shape, _instance=None)[source]
has_enumerate_support = True
log_prob(value)[source]
property logits
property mean
property mode
property param_shape
property probs
sample(sample_shape=torch.Size([]))[source]
property support
property variance
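A sketch of the normalization rules from the note above, using arbitrary weights: probs are rescaled to sum to 1, shifting every logit by the same constant leaves the distribution unchanged, and a 2-D probs tensor becomes a batch.

```python
import torch
from torch.distributions import Categorical

# Non-negative weights are rescaled to sum to 1.
m = Categorical(probs=torch.tensor([1.0, 1.0, 2.0]))
print(m.probs)  # tensor([0.2500, 0.2500, 0.5000])

# Logits are unnormalized log probabilities, so a constant shift cancels out.
a = Categorical(logits=torch.tensor([0.0, 0.0, 0.6931]))
b = Categorical(logits=torch.tensor([5.0, 5.0, 5.6931]))
print(torch.allclose(a.probs, b.probs))  # True

# A 2-D probs tensor is a batch of two categoricals over three classes.
batch = Categorical(probs=torch.tensor([[0.1, 0.9, 0.0], [1.0, 1.0, 1.0]]))
print(batch.batch_shape, batch.sample().shape)  # torch.Size([2]) torch.Size([2])
```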
Cauchy
class torch.distributions.cauchy.Cauchy(loc, scale, validate_args=None)[source]
Bases: Distribution
Samples from a Cauchy (Lorentz) distribution. The distribution of the ratio of independent normally distributed random variables with means 0 follows a Cauchy distribution.
Example:
>>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Cauchy distribution with loc=0 and scale=1
tensor([ 2.3214])
arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
cdf(value)[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
icdf(value)[source]
log_prob(value)[source]
property mean
property mode
rsample(sample_shape=torch.Size([]))[source]
support = Real()
property variance
Chi2
class torch.distributions.chi2.Chi2(df, validate_args=None)[source]
Bases: Gamma
Creates a Chi-squared distribution parameterized by shape parameter df. This is exactly equivalent to Gamma(alpha=0.5*df, beta=0.5).
Example:
>>> m = Chi2(torch.tensor([1.0]))
>>> m.sample()  # Chi2 distributed with shape df=1
tensor([ 0.1046])
arg_constraints = {'df': GreaterThan(lower_bound=0.0)}
property df
expand(batch_shape, _instance=None)[source]
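The stated equivalence with Gamma can be checked directly. In PyTorch's Gamma, the two arguments are concentration and rate; df = 3 and the evaluation points below are arbitrary.

```python
import torch
from torch.distributions import Chi2, Gamma

df = torch.tensor(3.0)
x = torch.tensor([0.5, 1.0, 4.0])

chi2 = Chi2(df)
gamma = Gamma(0.5 * df, 0.5)  # concentration = df / 2, rate = 1 / 2

print(torch.allclose(chi2.log_prob(x), gamma.log_prob(x)))  # True
print(chi2.mean)  # tensor(3.): the mean of a Chi2 equals df
```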
ContinuousBernoulli
class torch.distributions.continuous_bernoulli.ContinuousBernoulli(probs=None, logits=None, lims=(0.499, 0.501), validate_args=None)[source]
Bases: ExponentialFamily
Creates a continuous Bernoulli distribution parameterized by probs or logits (but not both).
The distribution is supported in [0, 1] and parameterized by ‘probs’ (in (0,1)) or ‘logits’ (real-valued). Note that, unlike the Bernoulli, ‘probs’ does not correspond to a probability and ‘logits’ does not correspond to log-odds, but the same names are used due to the similarity with the Bernoulli. See [1] for more details.
Example:
>>> m = ContinuousBernoulli(torch.tensor([0.3]))
>>> m.sample()
tensor([ 0.2538])
[1] The continuous Bernoulli: fixing a pervasive error in variational autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019. https://arxiv.org/abs/1907.06845
arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}
cdf(value)[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
icdf(value)[source]
log_prob(value)[source]
property logits
property mean
property param_shape
property probs
rsample(sample_shape=torch.Size([]))[source]
sample(sample_shape=torch.Size([]))[source]
property stddev
support = Interval(lower_bound=0.0, upper_bound=1.0)
property variance
Dirichlet
class torch.distributions.dirichlet.Dirichlet(concentration, validate_args=None)[source]
Bases: ExponentialFamily
Creates a Dirichlet distribution parameterized by the concentration vector concentration.
Example:
>>> m = Dirichlet(torch.tensor([0.5, 0.5]))
>>> m.sample()  # Dirichlet distributed with concentration [0.5, 0.5]
tensor([ 0.1046, 0.8954])
- Parameters
  - concentration (Tensor) – concentration parameter of the distribution (often referred to as alpha)
arg_constraints = {'concentration': IndependentConstraint(GreaterThan(lower_bound=0.0), 1)}
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
log_prob(value)[source]
property mean
property mode
rsample(sample_shape=())[source]
support = Simplex()
property variance
Exponential
class torch.distributions.exponential.Exponential(rate, validate_args=None)[source]
Bases: ExponentialFamily
Creates an Exponential distribution parameterized by rate.
Example:
>>> m = Exponential(torch.tensor([1.0]))
>>> m.sample()  # Exponential distributed with rate=1
tensor([ 0.1046])
arg_constraints = {'rate': GreaterThan(lower_bound=0.0)}
cdf(value)[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
icdf(value)[source]
log_prob(value)[source]
property mean
property mode
rsample(sample_shape=torch.Size([]))[source]
property stddev
support = GreaterThanEq(lower_bound=0.0)
property variance
FisherSnedecor
class torch.distributions.fishersnedecor.FisherSnedecor(df1, df2, validate_args=None)[source]
Bases: Distribution
Creates a Fisher-Snedecor distribution parameterized by df1 and df2.
Example:
>>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0]))
>>> m.sample()  # Fisher-Snedecor-distributed with df1=1 and df2=2
tensor([ 0.2453])
arg_constraints = {'df1': GreaterThan(lower_bound=0.0), 'df2': GreaterThan(lower_bound=0.0)}
expand(batch_shape, _instance=None)[source]
has_rsample = True
log_prob(value)[source]
property mean
property mode
rsample(sample_shape=torch.Size([]))[source]
support = GreaterThan(lower_bound=0.0)
property variance
Gamma
class torch.distributions.gamma.Gamma(concentration, rate, validate_args=None)[source]
Bases: ExponentialFamily
Creates a Gamma distribution parameterized by shape concentration and rate.
Example:
>>> m = Gamma(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # Gamma distributed with concentration=1 and rate=1
tensor([ 0.1046])
arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'rate': GreaterThan(lower_bound=0.0)}
cdf(value)[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
log_prob(value)[source]
property mean
property mode
rsample(sample_shape=torch.Size([]))[source]
support = GreaterThanEq(lower_bound=0.0)
property variance
Geometric
class torch.distributions.geometric.Geometric(probs=None, logits=None, validate_args=None)[source]
Bases: Distribution
Creates a Geometric distribution parameterized by probs, where probs is the probability of success of Bernoulli trials. It represents the probability that in $k + 1$ Bernoulli trials, the first $k$ trials failed, before seeing a success.
Samples are non-negative integers $[0, \infty)$.
Example:
>>> m = Geometric(torch.tensor([0.3]))
>>> m.sample()  # underlying Bernoulli has 30% chance 1; 70% chance 0
tensor([ 2.])
arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}
entropy()[source]
expand(batch_shape, _instance=None)[source]
log_prob(value)[source]
property logits
property mean
property mode
property probs
sample(sample_shape=torch.Size([]))[source]
support = IntegerGreaterThan(lower_bound=0)
property variance
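The description above corresponds to the probability mass function P(X = k) = (1 - p)^k · p. A quick check of that formula and of the mean (1 - p) / p against the library, with an arbitrary p = 0.3:

```python
import torch
from torch.distributions import Geometric

p = torch.tensor(0.3)
m = Geometric(probs=p)

# P(X = k) = (1 - p)^k * p: k failures followed by the first success.
k = torch.arange(5.0)
print(torch.allclose(m.log_prob(k).exp(), (1 - p) ** k * p))  # True

# Mean number of failures before the first success.
print(torch.allclose(m.mean, (1 - p) / p))  # True
```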
Gumbel
class torch.distributions.gumbel.Gumbel(loc, scale, validate_args=None)[source]
Bases: TransformedDistribution
Samples from a Gumbel Distribution.
Examples:
>>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
>>> m.sample()  # sample from Gumbel distribution with loc=1, scale=2
tensor([ 1.0124])
arg_constraints: Dict[str, Constraint] = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
entropy()[source]
expand(batch_shape, _instance=None)[source]
log_prob(value)[source]
property mean
property mode
property stddev
support = Real()
property variance
HalfCauchy
class torch.distributions.half_cauchy.HalfCauchy(scale, validate_args=None)[source]
Bases: TransformedDistribution
Creates a half-Cauchy distribution parameterized by scale where:
X ~ Cauchy(0, scale)
Y = |X| ~ HalfCauchy(scale)
Example:
>>> m = HalfCauchy(torch.tensor([1.0]))
>>> m.sample()  # half-cauchy distributed with scale=1
tensor([ 2.3214])
arg_constraints: Dict[str, Constraint] = {'scale': GreaterThan(lower_bound=0.0)}
cdf(value)[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
icdf(prob)[source]
log_prob(value)[source]
property mean
property mode
property scale
support = GreaterThanEq(lower_bound=0.0)
property variance
HalfNormal
class torch.distributions.half_normal.HalfNormal(scale, validate_args=None)[source]
Bases: TransformedDistribution
Creates a half-normal distribution parameterized by scale where:
X ~ Normal(0, scale)
Y = |X| ~ HalfNormal(scale)
Example:
>>> m = HalfNormal(torch.tensor([1.0]))
>>> m.sample()  # half-normal distributed with scale=1
tensor([ 0.1046])
arg_constraints: Dict[str, Constraint] = {'scale': GreaterThan(lower_bound=0.0)}
cdf(value)[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
icdf(prob)[source]
log_prob(value)[source]
property mean
property mode
property scale
support = GreaterThanEq(lower_bound=0.0)
property variance
Independent
class torch.distributions.independent.Independent(base_distribution, reinterpreted_batch_ndims, validate_args=None)[source]
Bases: Distribution
Reinterprets some of the batch dims of a distribution as event dims.
This is mainly useful for changing the shape of the result of log_prob(). For example to create a diagonal Normal distribution with the same shape as a Multivariate Normal distribution (so they are interchangeable), you can:
>>> from torch.distributions.multivariate_normal import MultivariateNormal
>>> from torch.distributions.normal import Normal
>>> loc = torch.zeros(3)
>>> scale = torch.ones(3)
>>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
>>> [mvn.batch_shape, mvn.event_shape]
[torch.Size([]), torch.Size([3])]
>>> normal = Normal(loc, scale)
>>> [normal.batch_shape, normal.event_shape]
[torch.Size([3]), torch.Size([])]
>>> diagn = Independent(normal, 1)
>>> [diagn.batch_shape, diagn.event_shape]
[torch.Size([]), torch.Size([3])]
- Parameters
  - base_distribution (torch.distributions.distribution.Distribution) – a base distribution
  - reinterpreted_batch_ndims (int) – the number of batch dims to reinterpret as event dims
arg_constraints: Dict[str, Constraint] = {}
entropy()[source]
enumerate_support(expand=True)[source]
expand(batch_shape, _instance=None)[source]
property has_enumerate_support
property has_rsample
log_prob(value)[source]
property mean
property mode
rsample(sample_shape=torch.Size([]))[source]
sample(sample_shape=torch.Size([]))[source]
property support
property variance
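Continuing the diagonal-Normal example, the sketch below shows the effect on log_prob(): the base Normal returns one log density per coordinate, while Independent sums over the reinterpreted dimension and then agrees with the equivalent MultivariateNormal. The random test points are arbitrary.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

torch.manual_seed(0)
loc, scale = torch.zeros(3), torch.ones(3)

normal = Normal(loc, scale)     # batch_shape (3,), event_shape ()
diagn = Independent(normal, 1)  # batch_shape (),   event_shape (3,)
mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))

x = torch.randn(4, 3)
print(normal.log_prob(x).shape)  # torch.Size([4, 3]): one log density per coordinate
print(diagn.log_prob(x).shape)   # torch.Size([4]): summed over the reinterpreted dim

print(torch.allclose(diagn.log_prob(x), normal.log_prob(x).sum(-1)))  # True
print(torch.allclose(diagn.log_prob(x), mvn.log_prob(x), atol=1e-5))  # True
```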
Kumaraswamy
class torch.distributions.kumaraswamy.Kumaraswamy(concentration1, concentration0, validate_args=None)[source]
Bases: TransformedDistribution
Samples from a Kumaraswamy distribution.
Example:
>>> m = Kumaraswamy(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Kumaraswamy distribution with concentration alpha=1 and beta=1
tensor([ 0.1729])
arg_constraints: Dict[str, Constraint] = {'concentration0': GreaterThan(lower_bound=0.0), 'concentration1': GreaterThan(lower_bound=0.0)}
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
property mean
property mode
support = Interval(lower_bound=0.0, upper_bound=1.0)
property variance
LKJCholesky
class torch.distributions.lkj_cholesky.LKJCholesky(dim, concentration=1.0, validate_args=None)[source]
Bases: Distribution
LKJ distribution for lower Cholesky factor of correlation matrices. The distribution is controlled by the concentration parameter $\eta$ to make the probability of the correlation matrix $M$ generated from a Cholesky factor proportional to $\det(M)^{\eta - 1}$. Because of that, when concentration == 1, we have a uniform distribution over Cholesky factors of correlation matrices:
L ~ LKJCholesky(dim, concentration)
X = L @ L' ~ LKJCorr(dim, concentration)
Note that this distribution samples the Cholesky factor of correlation matrices and not the correlation matrices themselves and thereby differs slightly from the derivations in [1] for the LKJCorr distribution. For sampling, this uses the Onion method from [1] Section 3.
Example:
>>> l = LKJCholesky(3, 0.5)
>>> l.sample()  # l @ l.T is a sample of a correlation 3x3 matrix
tensor([[ 1.0000, 0.0000, 0.0000],
        [ 0.3516, 0.9361, 0.0000],
        [-0.1899, 0.4748, 0.8593]])
References
[1] Generating random correlation matrices based on vines and extended onion method (2009), Daniel Lewandowski, Dorota Kurowicka, Harry Joe. Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
arg_constraints = {'concentration': GreaterThan(lower_bound=0.0)}
expand(batch_shape, _instance=None)[source]
log_prob(value)[source]
sample(sample_shape=torch.Size([]))[source]
support = CorrCholesky()
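A quick way to see what a sample from LKJCholesky is: it is lower triangular, and L @ L.T has a unit diagonal, as any correlation matrix must. The dimension 3 and concentration 0.5 mirror the example above; the seed is arbitrary.

```python
import torch
from torch.distributions import LKJCholesky

torch.manual_seed(0)
L = LKJCholesky(3, concentration=0.5).sample()

corr = L @ L.T  # the correlation matrix whose Cholesky factor is L
print(torch.allclose(torch.diagonal(corr), torch.ones(3), atol=1e-5))  # True: unit diagonal
print(torch.equal(L, torch.tril(L)))  # True: L is lower triangular
```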
Laplace
class torch.distributions.laplace.Laplace(loc, scale, validate_args=None)[source]
Bases: Distribution
Creates a Laplace distribution parameterized by loc and scale.
Example:
>>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # Laplace distributed with loc=0, scale=1
tensor([ 0.1046])
arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
cdf(value)[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
icdf(value)[source]
log_prob(value)[source]
property mean
property mode
rsample(sample_shape=torch.Size([]))[source]
property stddev
support = Real()
property variance
LogNormal
class torch.distributions.log_normal.LogNormal(loc, scale, validate_args=None)[source]
Bases: TransformedDistribution
Creates a log-normal distribution parameterized by loc and scale where:
X ~ Normal(loc, scale)
Y = exp(X) ~ LogNormal(loc, scale)
Example:
>>> m = LogNormal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # log-normal distributed with loc=0 and scale=1 (mean and stddev of the underlying normal)
tensor([ 0.1046])
arg_constraints: Dict[str, Constraint] = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
property loc
property mean
property mode
property scale
support = GreaterThan(lower_bound=0.0)
property variance
LowRankMultivariateNormal
class torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal(loc, cov_factor, cov_diag, validate_args=None)[source]
Bases: Distribution
Creates a multivariate normal distribution with covariance matrix having a low-rank form parameterized by cov_factor and cov_diag:
covariance_matrix = cov_factor @ cov_factor.T + diag(cov_diag)
Example
>>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
>>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
tensor([-0.2102, -0.5429])
- Parameters
  - loc (Tensor) – mean of the distribution with shape batch_shape + event_shape
  - cov_factor (Tensor) – factor part of low-rank form of covariance matrix with shape batch_shape + event_shape + (rank,)
  - cov_diag (Tensor) – diagonal part of low-rank form of covariance matrix with shape batch_shape + event_shape
Note
The computation for determinant and inverse of covariance matrix is avoided when cov_factor.shape[1] << cov_factor.shape[0] thanks to Woodbury matrix identity and matrix determinant lemma. Thanks to these formulas, we just need to compute the determinant and inverse of the small size “capacitance” matrix:
capacitance = I + cov_factor.T @ inv(diag(cov_diag)) @ cov_factor
arg_constraints = {'cov_diag': IndependentConstraint(GreaterThan(lower_bound=0.0), 1), 'cov_factor': IndependentConstraint(Real(), 2), 'loc': IndependentConstraint(Real(), 1)}
property covariance_matrix
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
log_prob(value)[source]
property mean
property mode
property precision_matrix
rsample(sample_shape=torch.Size([]))[source]
property scale_tril
support = IndependentConstraint(Real(), 1)
property variance
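To see that the low-rank parameterization describes an ordinary multivariate normal, the sketch below builds the dense covariance cov_factor @ cov_factor.T + diag(cov_diag) by hand and compares it, and the resulting log densities, with a MultivariateNormal. The rank-1 factor, diagonal and test points are arbitrary.

```python
import torch
from torch.distributions import LowRankMultivariateNormal, MultivariateNormal

# Arbitrary example: event size 3, rank-1 factor.
loc = torch.zeros(3)
cov_factor = torch.tensor([[1.0], [0.5], [0.0]])  # shape event_shape + (rank,)
cov_diag = torch.tensor([1.0, 2.0, 0.5])          # shape event_shape

low_rank = LowRankMultivariateNormal(loc, cov_factor, cov_diag)

# The same covariance written out densely.
dense_cov = cov_factor @ cov_factor.T + torch.diag(cov_diag)
full = MultivariateNormal(loc, covariance_matrix=dense_cov)

x = torch.tensor([[0.3, -1.0, 0.2], [1.0, 0.0, -0.5]])
print(torch.allclose(low_rank.covariance_matrix, dense_cov))              # True
print(torch.allclose(low_rank.log_prob(x), full.log_prob(x), atol=1e-5))  # True
```

The two objects agree numerically; the low-rank version simply avoids forming and factorizing the dense matrix when the rank is small.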
MixtureSameFamily
class torch.distributions.mixture_same_family.MixtureSameFamily(mixture_distribution, component_distribution, validate_args=None)[source]
Bases: Distribution
The MixtureSameFamily distribution implements a (batch of) mixture distribution where all components are from different parameterizations of the same distribution type. It is parameterized by a Categorical “selecting distribution” (over k components) and a component distribution, i.e., a Distribution with a rightmost batch shape (equal to [k]) which indexes each (batch of) component.
Examples:
>>> import torch.distributions as D
>>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
>>> # weighted normal distributions
>>> mix = D.Categorical(torch.ones(5,))
>>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
>>> gmm = MixtureSameFamily(mix, comp)
>>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
>>> # weighted bivariate normal distributions
>>> mix = D.Categorical(torch.ones(5,))
>>> comp = D.Independent(D.Normal(
...          torch.randn(5,2), torch.rand(5,2)), 1)
>>> gmm = MixtureSameFamily(mix, comp)
>>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
>>> # consisting of 5 random weighted bivariate normal distributions
>>> mix = D.Categorical(torch.rand(3,5))
>>> comp = D.Independent(D.Normal(
...         torch.randn(3,5,2), torch.rand(3,5,2)), 1)
>>> gmm = MixtureSameFamily(mix, comp)
- Parameters
  - mixture_distribution – torch.distributions.Categorical-like instance. Manages the probability of selecting components. The number of categories must match the rightmost batch dimension of the component_distribution. Must have either scalar batch_shape or batch_shape matching component_distribution.batch_shape[:-1]
  - component_distribution – torch.distributions.Distribution-like instance. Right-most batch dimension indexes components.
arg_constraints: Dict[str, Constraint] = {}
cdf(x)[source]
property component_distribution
expand(batch_shape, _instance=None)[source]
has_rsample = False
log_prob(x)[source]
property mean
property mixture_distribution
sample(sample_shape=torch.Size([]))[source]
property support
property variance
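The mixture density is log p(x) = logsumexp_k(log w_k + log p_k(x)). A sketch checking that formula against MixtureSameFamily.log_prob() for a two-component 1-D mixture; the weights, means, scales and evaluation points are arbitrary.

```python
import torch
import torch.distributions as D

# Two-component 1-D mixture with arbitrary weights, means and scales.
weights = torch.tensor([0.2, 0.8])
mix = D.Categorical(probs=weights)
comp = D.Normal(torch.tensor([-1.0, 2.0]), torch.tensor([0.5, 1.0]))
gmm = D.MixtureSameFamily(mix, comp)

x = torch.tensor([0.0, 1.5])

# Unsqueeze so each point broadcasts against the rightmost (component) batch dim.
manual = torch.logsumexp(weights.log() + comp.log_prob(x.unsqueeze(-1)), dim=-1)

print(torch.allclose(gmm.log_prob(x), manual))  # True
print(gmm.mean)  # 0.2 * (-1.0) + 0.8 * 2.0 = 1.4, up to float rounding
```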
Multinomial
class torch.distributions.multinomial.Multinomial(total_count=1, probs=None, logits=None, validate_args=None)[source]
Bases: Distribution
Creates a Multinomial distribution parameterized by total_count and either probs or logits (but not both). The innermost dimension of probs indexes over categories. All other dimensions index over batches.
Note that total_count need not be specified if only log_prob() is called (see example below).
Note
The probs argument must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1 along the last dimension. probs will return this normalized value. The logits argument will be interpreted as unnormalized log probabilities and can therefore be any real number. It will likewise be normalized so that the resulting probabilities sum to 1 along the last dimension. logits will return this normalized value.
- sample() requires a single shared total_count for all parameters and samples.
- log_prob() allows different total_count for each parameter and sample.
Example:
>>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
>>> x = m.sample()  # equal probability of 0, 1, 2, 3
>>> x
tensor([ 21., 24., 30., 25.])
>>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
tensor([-4.1338])
arg_constraints = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}
entropy()[source]
expand(batch_shape, _instance=None)[source]
log_prob(value)[source]
property logits
property mean
property param_shape
property probs
sample(sample_shape=torch.Size([]))[source]
property support
total_count: int
property variance
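A sketch of what a single Multinomial sample looks like: a vector of counts, one per category, that always sums to total_count, with mean total_count · probs. The probabilities and seed are arbitrary.

```python
import torch
from torch.distributions import Multinomial

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])  # arbitrary category probabilities
m = Multinomial(100, probs=probs)

x = m.sample()      # one vector of counts over the 4 categories
print(x.sum())      # tensor(100.): counts always add up to total_count
print(m.mean)       # approximately total_count * probs = [10, 20, 30, 40]
print(m.log_prob(x))
```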
MultivariateNormal
class torch.distributions.multivariate_normal.MultivariateNormal(loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)[source]
Bases: Distribution
Creates a multivariate normal (also called Gaussian) distribution parameterized by a mean vector and a covariance matrix.
The multivariate normal distribution can be parameterized either in terms of a positive definite covariance matrix $\Sigma$ or a positive definite precision matrix $\Sigma^{-1}$ or a lower-triangular matrix $L$ with positive-valued diagonal entries, such that $\Sigma = LL^\top$. This triangular matrix can be obtained via e.g. Cholesky decomposition of the covariance.
Example
>>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
>>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
tensor([-0.2102, -0.5429])
Note
Only one of covariance_matrix or precision_matrix or scale_tril can be specified.
Using scale_tril will be more efficient: all computations internally are based on scale_tril. If covariance_matrix or precision_matrix is passed instead, it is only used to compute the corresponding lower triangular matrices using a Cholesky decomposition.
arg_constraints = {'covariance_matrix': PositiveDefinite(), 'loc': IndependentConstraint(Real(), 1), 'precision_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}
property covariance_matrix
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
log_prob(value)[source]
property mean
property mode
property precision_matrix
rsample(sample_shape=torch.Size([]))[source]
property scale_tril
support = IndependentConstraint(Real(), 1)
property variance
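The note above says scale_tril is the internal representation. A minimal sketch, assuming an illustrative 2x2 covariance, showing that passing the Cholesky factor directly gives the same density as passing the covariance:

```python
import torch
from torch.distributions import MultivariateNormal

loc = torch.zeros(2)
cov = torch.tensor([[2.0, 0.6],
                    [0.6, 1.0]])
L = torch.linalg.cholesky(cov)  # lower-triangular, cov = L @ L.T

mvn_cov = MultivariateNormal(loc, covariance_matrix=cov)
mvn_tril = MultivariateNormal(loc, scale_tril=L)  # no internal Cholesky needed

x = torch.tensor([0.5, -1.0])
print(mvn_cov.log_prob(x), mvn_tril.log_prob(x))
```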
NegativeBinomial
class torch.distributions.negative_binomial.NegativeBinomial(total_count, probs=None, logits=None, validate_args=None)[source]
Bases: Distribution
Creates a Negative Binomial distribution, i.e. the distribution of the number of successful independent and identical Bernoulli trials before total_count failures are achieved. The probability of success of each Bernoulli trial is probs.
arg_constraints = {'logits': Real(), 'probs': HalfOpenInterval(lower_bound=0.0, upper_bound=1.0), 'total_count': GreaterThanEq(lower_bound=0)}
expand(batch_shape, _instance=None)[source]
log_prob(value)[source]
property logits
property mean
property mode
property param_shape
property probs
sample(sample_shape=torch.Size([]))[source]
support = IntegerGreaterThan(lower_bound=0)
property variance
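Counting successes before total_count failures gives the closed-form moments below. A minimal sketch with illustrative values r = 5 and p = 0.25, comparing them with the mean and variance properties:

```python
import torch
from torch.distributions import NegativeBinomial

total_count, p = 5.0, 0.25
nb = NegativeBinomial(total_count=total_count, probs=torch.tensor(p))

# Successes before `total_count` failures, success probability p:
# mean = r * p / (1 - p), variance = r * p / (1 - p) ** 2
expected_mean = total_count * p / (1 - p)
expected_var = total_count * p / (1 - p) ** 2
print(nb.mean, nb.variance)
```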
Normal
class torch.distributions.normal.Normal(loc, scale, validate_args=None)[source]
Bases: ExponentialFamily
Creates a normal (also called Gaussian) distribution parameterized by loc and scale.
Example:
>>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # normally distributed with loc=0 and scale=1
tensor([ 0.1046])
arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
cdf(value)[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
icdf(value)[source]
log_prob(value)[source]
property mean
property mode
rsample(sample_shape=torch.Size([]))[source]
sample(sample_shape=torch.Size([]))[source]
property stddev
support = Real()
property variance
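Because Normal has has_rsample = True, rsample() supports the pathwise derivative estimator: each sample is written as loc + scale * eps, so gradients reach the parameters. A minimal sketch with illustrative parameter values, which also checks that cdf and icdf invert each other:

```python
import torch
from torch.distributions import Normal

loc = torch.tensor(0.5, requires_grad=True)
scale = torch.tensor(2.0, requires_grad=True)

dist = Normal(loc, scale)
x = dist.rsample((1000,))  # x = loc + scale * eps, eps ~ N(0, 1)
x.sum().backward()

# Each x_i depends on loc with slope 1, so d(sum x)/d(loc) = 1000
print(loc.grad)

# cdf and icdf are inverses of each other
p = dist.cdf(torch.tensor(1.0))
print(dist.icdf(p))
```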
OneHotCategorical
class torch.distributions.one_hot_categorical.OneHotCategorical(probs=None, logits=None, validate_args=None)[source]
Bases: Distribution
Creates a one-hot categorical distribution parameterized by probs or logits.
Samples are one-hot coded vectors of size probs.size(-1).
Note
The probs argument must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1 along the last dimension. probs will return this normalized value. The logits argument will be interpreted as unnormalized log probabilities and can therefore be any real number. It will likewise be normalized so that the resulting probabilities sum to 1 along the last dimension. logits will return this normalized value.
See also: torch.distributions.Categorical() for specifications of probs and logits.
Example:
>>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
>>> m.sample()  # equal probability of 0, 1, 2, 3
tensor([ 0., 0., 0., 1.])
arg_constraints = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}
entropy()[source]
enumerate_support(expand=True)[source]
expand(batch_shape, _instance=None)[source]
has_enumerate_support = True
log_prob(value)[source]
property logits
property mean
property mode
property param_shape
property probs
sample(sample_shape=torch.Size([]))[source]
support = OneHot()
property variance
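Since has_enumerate_support = True, the full support can be listed as one-hot rows and scored in a single call. A minimal sketch with illustrative probabilities:

```python
import torch
from torch.distributions import OneHotCategorical

probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
m = OneHotCategorical(probs)

support = m.enumerate_support()  # each row is a one-hot vector
log_p = m.log_prob(support)      # log-probability of each category

sample = m.sample()
index = sample.argmax(-1)        # convert a one-hot sample back to an index
print(support, log_p.exp(), index)
```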
Pareto
class torch.distributions.pareto.Pareto(scale, alpha, validate_args=None)[source]
Bases: TransformedDistribution
Samples from a Pareto Type 1 distribution.
Example:
>>> m = Pareto(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Pareto distribution with scale=1 and alpha=1
tensor([ 1.5623])
arg_constraints: Dict[str, Constraint] = {'alpha': GreaterThan(lower_bound=0.0), 'scale': GreaterThan(lower_bound=0.0)}
entropy()[source]
expand(batch_shape, _instance=None)[source]
property mean
property mode
property support
property variance
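Two properties of the Type 1 Pareto are easy to check: every sample is at least scale, and for alpha > 1 the mean is alpha * scale / (alpha - 1). A minimal sketch with illustrative values scale = 2 and alpha = 3:

```python
import torch
from torch.distributions import Pareto

scale, alpha = 2.0, 3.0
m = Pareto(torch.tensor(scale), torch.tensor(alpha))

samples = m.sample((10_000,))
# Every Pareto Type 1 sample is at least `scale`
print(samples.min())
# For alpha > 1 the mean is alpha * scale / (alpha - 1) = 3.0
print(m.mean)
```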
Poisson
class torch.distributions.poisson.Poisson(rate, validate_args=None)[source]
Bases: ExponentialFamily
Creates a Poisson distribution parameterized by rate, the rate parameter.
Samples are nonnegative integers, with a pmf given by
$\mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}$
Example:
>>> m = Poisson(torch.tensor([4.0]))
>>> m.sample()
tensor([ 3.])
Parameters
- rate (Number, Tensor) – the rate parameter
arg_constraints = {'rate': GreaterThanEq(lower_bound=0.0)}
expand(batch_shape, _instance=None)[source]
log_prob(value)[source]
property mean
property mode
sample(sample_shape=torch.Size([]))[source]
support = IntegerGreaterThan(lower_bound=0)
property variance
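Taking the log of the pmf gives k log(rate) - rate - log(k!). A minimal sketch, with an illustrative rate of 4, comparing that expression against log_prob():

```python
import torch
from torch.distributions import Poisson

rate = torch.tensor(4.0)
m = Poisson(rate)

k = torch.arange(0., 10.)
# log pmf = k * log(rate) - rate - log(k!)
manual = k * rate.log() - rate - torch.lgamma(k + 1)
print(m.log_prob(k), manual)
```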
RelaxedBernoulli
class torch.distributions.relaxed_bernoulli.RelaxedBernoulli(temperature, probs=None, logits=None, validate_args=None)[source]
Bases: TransformedDistribution
Creates a RelaxedBernoulli distribution, parametrized by temperature, and either probs or logits (but not both). This is a relaxed version of the Bernoulli distribution, so its values lie in (0, 1) and its samples are reparametrizable.
Example:
>>> m = RelaxedBernoulli(torch.tensor([2.2]),
...                      torch.tensor([0.1, 0.2, 0.3, 0.99]))
>>> m.sample()
tensor([ 0.2951, 0.3442, 0.8918, 0.9021])
arg_constraints: Dict[str, Constraint] = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}
expand(batch_shape, _instance=None)[source]
has_rsample = True
property logits
property probs
support = Interval(lower_bound=0.0, upper_bound=1.0)
property temperature
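The temperature controls how close the relaxation is to a hard Bernoulli, and rsample() lets gradients flow to the logits. A minimal sketch; the temperatures and probabilities are illustrative choices:

```python
import torch
from torch.distributions import RelaxedBernoulli

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.5, 0.9])

# Low temperature: samples concentrate near 0 and 1 (close to a Bernoulli)
cold = RelaxedBernoulli(torch.tensor(0.1), probs=probs).sample((5,))
# High temperature: samples are smoother and tend to cluster around 0.5
hot = RelaxedBernoulli(torch.tensor(5.0), probs=probs).sample((5,))

# Samples are reparameterized, so gradients flow back to the logits
logits = torch.zeros(3, requires_grad=True)
y = RelaxedBernoulli(torch.tensor(0.5), logits=logits).rsample()
y.sum().backward()
print(cold, hot, logits.grad)
```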
LogitRelaxedBernoulli
class torch.distributions.relaxed_bernoulli.LogitRelaxedBernoulli(temperature, probs=None, logits=None, validate_args=None)[source]
Bases: Distribution
Creates a LogitRelaxedBernoulli distribution parameterized by probs or logits (but not both), which is the logit of a RelaxedBernoulli distribution.
Samples are logits of values in (0, 1). See [1] for more details.
[1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables (Maddison et al., 2017)
[2] Categorical Reparametrization with Gumbel-Softmax (Jang et al., 2017)
arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}
expand(batch_shape, _instance=None)[source]
log_prob(value)[source]
property logits
property param_shape
property probs
rsample(sample_shape=torch.Size([]))[source]
support = Real()
RelaxedOneHotCategorical
class torch.distributions.relaxed_categorical.RelaxedOneHotCategorical(temperature, probs=None, logits=None, validate_args=None)[source]
Bases: TransformedDistribution
Creates a RelaxedOneHotCategorical distribution parametrized by temperature, and either probs or logits. This is a relaxed version of the OneHotCategorical distribution, so its samples lie on the simplex and are reparametrizable.
Example:
>>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
...                              torch.tensor([0.1, 0.2, 0.3, 0.4]))
>>> m.sample()
tensor([ 0.1294, 0.2324, 0.3859, 0.2523])
arg_constraints: Dict[str, Constraint] = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}
expand(batch_shape, _instance=None)[source]
has_rsample = True
property logits
property probs
support = Simplex()
property temperature
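Since the samples lie on the simplex, every relaxed sample is a probability vector, and its argmax gives a hard category. A minimal sketch with illustrative parameters:

```python
import torch
from torch.distributions import RelaxedOneHotCategorical

probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
m = RelaxedOneHotCategorical(torch.tensor(0.5), probs=probs)

samples = m.rsample((8,))  # shape (8, 4); each row lies on the simplex
row_sums = samples.sum(-1)
# argmax of a relaxed sample gives a hard category index
hard = samples.argmax(-1)
print(row_sums, hard)
```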
StudentT
class torch.distributions.studentT.StudentT(df, loc=0.0, scale=1.0, validate_args=None)[source]
Bases: Distribution
Creates a Student’s t-distribution parameterized by degrees of freedom df, mean loc and scale scale.
Example:
>>> m = StudentT(torch.tensor([2.0]))
>>> m.sample()  # Student's t-distributed with degrees of freedom=2
tensor([ 0.1046])
arg_constraints = {'df': GreaterThan(lower_bound=0.0), 'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
log_prob(value)[source]
property mean
property mode
rsample(sample_shape=torch.Size([]))[source]
support = Real()
property variance
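The moments of a Student's t-distribution exist only for enough degrees of freedom. A minimal sketch with illustrative values: for df > 2 the variance is scale² · df / (df - 2), and for df <= 1 the mean property reports nan.

```python
import torch
from torch.distributions import StudentT

df, scale = 4.0, 2.0
m = StudentT(torch.tensor([df]), loc=torch.tensor([1.0]), scale=torch.tensor([scale]))

# For df > 2: variance = scale**2 * df / (df - 2) = 8.0
print(m.variance)
# For df <= 1 the mean is undefined and is reported as nan
undefined_mean = StudentT(torch.tensor([1.0])).mean
print(undefined_mean)
```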
TransformedDistribution
class torch.distributions.transformed_distribution.TransformedDistribution(base_distribution, transforms, validate_args=None)[source]
Bases: Distribution
Extension of the Distribution class, which applies a sequence of Transforms to a base distribution. Let f be the composition of transforms applied:
X ~ BaseDistribution
Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
log p(Y) = log p(X) + log |det (dX/dY)|
Note that the .event_shape of a TransformedDistribution is the maximum shape of its base distribution and its transforms, since transforms can introduce correlations among events.
An example for the usage of TransformedDistribution would be:
# Building a Logistic Distribution
# X ~ Uniform(0, 1)
# f = a + b * logit(X)
# Y ~ f(X) ~ Logistic(a, b)
base_distribution = Uniform(0, 1)
transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
logistic = TransformedDistribution(base_distribution, transforms)
For more examples, please look at the implementations of Gumbel, HalfCauchy, HalfNormal, LogNormal, Pareto, Weibull, RelaxedBernoulli and RelaxedOneHotCategorical.
arg_constraints: Dict[str, Constraint] = {}
cdf(value)[source]
Computes the cumulative distribution function by inverting the transform(s) and computing the score of the base distribution.
expand(batch_shape, _instance=None)[source]
property has_rsample
icdf(value)[source]
Computes the inverse cumulative distribution function using transform(s) and computing the score of the base distribution.
log_prob(value)[source]
Scores the sample by inverting the transform(s) and computing the score using the score of the base distribution and the log abs det jacobian.
rsample(sample_shape=torch.Size([]))[source]
Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched. Samples first from base distribution and applies transform() for every transform in the list.
sample(sample_shape=torch.Size([]))[source]
Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. Samples first from base distribution and applies transform() for every transform in the list.
property support
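The Logistic example above leaves a and b unspecified. The following sketch assigns them the illustrative values a = 1 and b = 2, then checks log_prob against the closed-form Logistic log-density with z = (y - a) / b, which is -z - log(b) - 2 log(1 + exp(-z)):

```python
import torch
import torch.nn.functional as F
from torch.distributions import TransformedDistribution, Uniform
from torch.distributions.transforms import AffineTransform, SigmoidTransform

a, b = 1.0, 2.0  # illustrative location and scale of the Logistic distribution
base_distribution = Uniform(0.0, 1.0)
transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
logistic = TransformedDistribution(base_distribution, transforms)

y = torch.tensor([-3.0, 0.0, 1.0, 4.0])
# Closed form: log p(y) = -z - log(b) - 2 * log(1 + exp(-z)), z = (y - a) / b
z = (y - a) / b
manual = -z - torch.log(torch.tensor(b)) - 2 * F.softplus(-z)
print(logistic.log_prob(y), manual)
```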
Uniform
class torch.distributions.uniform.Uniform(low, high, validate_args=None)[source]
Bases: Distribution
Generates uniformly distributed random samples from the half-open interval [low, high).
Example:
>>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
>>> m.sample()  # uniformly distributed in the range [0.0, 5.0)
tensor([ 2.3418])
arg_constraints = {'high': Dependent(), 'low': Dependent()}
cdf(value)[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
icdf(value)[source]
log_prob(value)[source]
property mean
property mode
rsample(sample_shape=torch.Size([]))[source]
property stddev
property support
property variance
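On [low, high) the cdf is linear, icdf is its inverse, and the density is constant. A minimal sketch with the illustrative interval [0, 5):

```python
import torch
from torch.distributions import Uniform

low, high = 0.0, 5.0
m = Uniform(torch.tensor(low), torch.tensor(high))

x = torch.tensor(2.0)
print(m.cdf(x))                   # (x - low) / (high - low) = 0.4
print(m.icdf(torch.tensor(0.4)))  # inverse of the cdf: 2.0
print(m.entropy())                # log(high - low) = log(5)
print(m.log_prob(x))              # -log(high - low)
```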
VonMises
class torch.distributions.von_mises.VonMises(loc, concentration, validate_args=None)[source]
Bases: Distribution
A circular von Mises distribution.
This implementation uses polar coordinates. The loc and value args can be any real number (to facilitate unconstrained optimization), but are interpreted as angles modulo 2 pi.
Example:
>>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # von Mises distributed with loc=1 and concentration=1
tensor([1.9777])
Parameters
- loc (torch.Tensor) – an angle in radians.
- concentration (torch.Tensor) – concentration parameter
arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'loc': Real()}
expand(batch_shape)[source]
has_rsample = False
log_prob(value)[source]
property mean
The provided mean is the circular one.
property mode
sample(sample_shape=torch.Size([]))[source]
The sampling algorithm for the von Mises distribution is based on the following paper: Best, D. J., and Nicholas I. Fisher. “Efficient simulation of the von Mises distribution.” Applied Statistics (1979): 152-157.
support = Real()
property variance
The provided variance is the circular one.
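Because values are interpreted as angles modulo 2 pi, shifting a value by a full turn should not change its density. A minimal sketch with illustrative loc and concentration:

```python
import math
import torch
from torch.distributions import VonMises

m = VonMises(torch.tensor(1.0), torch.tensor(2.0))
x = torch.tensor([0.0, 1.0, 3.0])

# Angles are interpreted modulo 2*pi, so shifting by a full turn
# leaves the density unchanged
print(m.log_prob(x), m.log_prob(x + 2 * math.pi))

samples = m.sample((5,))
print(samples)
```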
Weibull
class torch.distributions.weibull.Weibull(scale, concentration, validate_args=None)[source]
Bases: TransformedDistribution
Samples from a two-parameter Weibull distribution.
Example
>>> m = Weibull(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Weibull distribution with scale=1, concentration=1
tensor([ 0.4784])
arg_constraints: Dict[str, Constraint] = {'concentration': GreaterThan(lower_bound=0.0), 'scale': GreaterThan(lower_bound=0.0)}
entropy()[source]
expand(batch_shape, _instance=None)[source]
property mean
property mode
support = GreaterThan(lower_bound=0.0)
property variance
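With concentration = 1 the two-parameter Weibull reduces to an Exponential distribution with rate 1/scale, which gives a convenient cross-check. A minimal sketch with an illustrative scale of 2:

```python
import torch
from torch.distributions import Exponential, Weibull

scale = torch.tensor(2.0)
weibull = Weibull(scale, concentration=torch.tensor(1.0))
exponential = Exponential(rate=1 / scale)

x = torch.tensor([0.5, 1.0, 3.0])
# With concentration = 1 the Weibull reduces to an Exponential with rate 1/scale
print(weibull.log_prob(x), exponential.log_prob(x))
print(weibull.mean)  # scale * Gamma(1 + 1/concentration) = 2.0
```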
Wishart
class torch.distributions.wishart.Wishart(df, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)[source]
Bases: ExponentialFamily
Creates a Wishart distribution parameterized by a symmetric positive definite matrix $\mathbf{\Sigma}$, or its Cholesky decomposition $\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top$.
Example
>>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
>>> m.sample()  # Wishart distributed with mean=`df * I` and
>>> # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j
Parameters
- df (float or Tensor) – real-valued parameter larger than (dimension of the square matrix) - 1
- covariance_matrix (Tensor) – positive-definite covariance matrix
- precision_matrix (Tensor) – positive-definite precision matrix
- scale_tril (Tensor) – lower-triangular factor of covariance, with positive-valued diagonal
Note
Only one of covariance_matrix or precision_matrix or scale_tril can be specified. Using scale_tril will be more efficient: all computations internally are based on scale_tril. If covariance_matrix or precision_matrix is passed instead, it is only used to compute the corresponding lower triangular matrices using a Cholesky decomposition. torch.distributions.LKJCholesky is a restricted Wishart distribution [1].
References
[1] Wang, Z., Wu, Y. and Chu, H., 2018. On equivalence of the LKJ distribution and the restricted Wishart distribution.
[2] Sawyer, S., 2007. Wishart Distributions and Inverse-Wishart Sampling.
[3] Anderson, T. W., 2003. An Introduction to Multivariate Statistical Analysis (3rd ed.).
[4] Odell, P. L. & Feiveson, A. H., 1966. A Numerical Procedure to Generate a Sample Covariance Matrix. JASA, 61(313):199-203.
[5] Ku, Y.-C. & Bloomfield, P., 2010. Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX.
arg_constraints = {'covariance_matrix': PositiveDefinite(), 'df': GreaterThan(lower_bound=0), 'precision_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}
property covariance_matrix
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
log_prob(value)[source]
property mean
property mode
property precision_matrix
rsample(sample_shape=torch.Size([]), max_try_correction=None)[source]
Warning
In some cases, the sampling algorithm based on the Bartlett decomposition may return singular matrix samples. By default, several attempts are made to correct singular samples, but the result may still be singular. Singular samples may produce -inf values in .log_prob(). In those cases, the user should validate the samples and either fix the value of df or adjust the max_try_correction argument of .rsample accordingly.
property scale_tril
support = PositiveDefinite()
property variance
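The mean of a Wishart distribution is df times the covariance matrix, and each sample is symmetric positive definite. A minimal sketch with illustrative values df = 5 and covariance 2·I in three dimensions (df must exceed dimension - 1 = 2):

```python
import torch
from torch.distributions import Wishart

df = torch.tensor(5.0)
cov = 2.0 * torch.eye(3)
m = Wishart(df, covariance_matrix=cov)

print(m.mean)  # df * covariance_matrix = 10 * I
sample = m.sample()
# Each sample is a symmetric positive-definite 3x3 matrix
print(torch.linalg.eigvalsh(sample))
```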
KL Divergence
torch.distributions.kl.kl_divergence(p, q)[source]
Compute Kullback-Leibler divergence between two distributions.
Parameters
- p (Distribution) – A Distribution object.
- q (Distribution) – A Distribution object.
Returns
A batch of KL divergences of shape batch_shape.
Return type
Tensor
Raises
NotImplementedError – If the distribution types have not been registered via register_kl().
KL divergence is currently implemented for the following distribution pairs:
- Bernoulli and Bernoulli
- Bernoulli and Poisson
- Beta and Beta
- Beta and ContinuousBernoulli
- Beta and Exponential
- Beta and Gamma
- Beta and Normal
- Beta and Pareto
- Beta and Uniform
- Binomial and Binomial
- Categorical and Categorical
- Cauchy and Cauchy
- ContinuousBernoulli and ContinuousBernoulli
- ContinuousBernoulli and Exponential
- ContinuousBernoulli and Normal
- ContinuousBernoulli and Pareto
- ContinuousBernoulli and Uniform
- Dirichlet and Dirichlet
- Exponential and Beta
- Exponential and ContinuousBernoulli
- Exponential and Exponential
- Exponential and Gamma
- Exponential and Gumbel
- Exponential and Normal
- Exponential and Pareto
- Exponential and Uniform
- ExponentialFamily and ExponentialFamily
- Gamma and Beta
- Gamma and ContinuousBernoulli
- Gamma and Exponential
- Gamma and Gamma
- Gamma and Gumbel
- Gamma and Normal
- Gamma and Pareto
- Gamma and Uniform
- Geometric and Geometric
- Gumbel and Beta
- Gumbel and ContinuousBernoulli
- Gumbel and Exponential
- Gumbel and Gamma
- Gumbel and Gumbel
- Gumbel and Normal
- Gumbel and Pareto
- Gumbel and Uniform
- HalfNormal and HalfNormal
- Independent and Independent
- Laplace and Beta
- Laplace and ContinuousBernoulli
- Laplace and Exponential
- Laplace and Gamma
- Laplace and Laplace
- Laplace and Normal
- Laplace and Pareto
- Laplace and Uniform
- LowRankMultivariateNormal and LowRankMultivariateNormal
- LowRankMultivariateNormal and MultivariateNormal
- MultivariateNormal and LowRankMultivariateNormal
- MultivariateNormal and MultivariateNormal
- Normal and Beta
- Normal and ContinuousBernoulli
- Normal and Exponential
- Normal and Gamma
- Normal and Gumbel
- Normal and Laplace
- Normal and Normal
- Normal and Pareto
- Normal and Uniform
- OneHotCategorical and OneHotCategorical
- Pareto and Beta
- Pareto and ContinuousBernoulli
- Pareto and Exponential
- Pareto and Gamma
- Pareto and Normal
- Pareto and Pareto
- Pareto and Uniform
- Poisson and Bernoulli
- Poisson and Binomial
- Poisson and Poisson
- TransformedDistribution and TransformedDistribution
- Uniform and Beta
- Uniform and ContinuousBernoulli
- Uniform and Exponential
- Uniform and Gamma
- Uniform and Gumbel
- Uniform and Normal
- Uniform and Pareto
- Uniform and Uniform
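For the Normal and Normal pair, the registered result can be checked against the textbook closed form KL(p ‖ q) = log(σ_q/σ_p) + (σ_p² + (μ_p - μ_q)²) / (2σ_q²) - 1/2. A minimal sketch with illustrative parameters μ_p = 0, σ_p = 1, μ_q = 1, σ_q = 2:

```python
import torch
from torch.distributions import Normal, kl_divergence

p = Normal(torch.tensor(0.0), torch.tensor(1.0))
q = Normal(torch.tensor(1.0), torch.tensor(2.0))

kl = kl_divergence(p, q)
# Closed form for two univariate normals:
# KL(p || q) = log(s_q / s_p) + (s_p**2 + (m_p - m_q)**2) / (2 * s_q**2) - 1/2
closed_form = torch.log(torch.tensor(2.0)) + (1.0 + 1.0) / (2 * 4.0) - 0.5
print(kl, closed_form)
```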
torch.distributions.kl.register_kl(type_p, type_q)[source]
Decorator to register a pairwise function with kl_divergence(). Usage:
@register_kl(Normal, Normal)
def kl_normal_normal(p, q):
    # insert implementation here
    ...
Lookup returns the most specific (type,type) match ordered by subclass. If the match is ambiguous, a RuntimeWarning is raised. For example, to resolve the ambiguous situation:
@register_kl(BaseP, DerivedQ)
def kl_version1(p, q): ...
@register_kl(DerivedP, BaseQ)
def kl_version2(p, q): ...
you should register a third most-specific implementation, e.g.:
register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.
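To see the most-specific lookup in action, the sketch below registers a KL function for a hypothetical Normal subclass, ScaledNormal, introduced only for illustration. The call list records which implementation kl_divergence() dispatched to, and the result should still agree with the built-in Normal and Normal formula:

```python
import torch
from torch.distributions import Normal, kl_divergence, register_kl


class ScaledNormal(Normal):
    """Hypothetical subclass used only to illustrate register_kl dispatch."""


calls = []


@register_kl(ScaledNormal, ScaledNormal)
def _kl_scaled_scaled(p, q):
    calls.append("ScaledNormal")  # record that this implementation was chosen
    var_ratio = (p.scale / q.scale).pow(2)
    t1 = ((p.loc - q.loc) / q.scale).pow(2)
    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())


p = ScaledNormal(torch.tensor(0.0), torch.tensor(1.0))
q = ScaledNormal(torch.tensor(1.0), torch.tensor(2.0))
custom = kl_divergence(p, q)  # dispatches to the most specific registered match
builtin = kl_divergence(Normal(0.0, 1.0), Normal(1.0, 2.0))
print(calls, custom, builtin)
```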
Transforms
class torch.distributions.transforms.AbsTransform(cache_size=0)[source]
Transform via the mapping $y = |x|$.
class torch.distributions.transforms.AffineTransform(loc, scale, event_dim=0, cache_size=0)[source]
Transform via the pointwise affine mapping $y = \text{loc} + \text{scale} \times x$.
class torch.distributions.transforms.CatTransform(tseq, dim=0, lengths=None, cache_size=0)[source]
Transform functor that applies a sequence of transforms tseq component-wise to each submatrix at dim, of length lengths[dim], in a way compatible with torch.cat().
Example:
from torch.distributions.transforms import CatTransform, ExpTransform, identity_transform
x0 = torch.cat([torch.arange(1., 11.), torch.arange(1., 11.)], dim=0)
x = torch.cat([x0, x0], dim=0)
t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10])
t = CatTransform([t0, t0], dim=0, lengths=[20, 20])
y = t(x)
class torch.distributions.transforms.ComposeTransform(parts, cache_size=0)[source]
Composes multiple transforms in a chain. The transforms being composed are responsible for caching.
class torch.distributions.transforms.CorrCholeskyTransform(cache_size=0)[source]
Transforms an unconstrained real vector $x$ with length $D(D-1)/2$ into the Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower triangular matrix with positive diagonals and unit Euclidean norm for each row. The transform is processed as follows:
- First we convert x into a lower triangular matrix in row order.
- For each row $X_i$ of the lower triangular part, we apply a signed version of class StickBreakingTransform to transform $X_i$ into a unit Euclidean length vector using the following steps:
  - Scales into the interval $(-1, 1)$ domain: $r_i = \tanh(X_i)$.
  - Transforms into an unsigned domain: $z_i = r_i^2$.
  - Applies $s_i = \mathrm{StickBreakingTransform}(z_i)$.
  - Transforms back into signed domain: $y_i = \mathrm{sign}(r_i) \sqrt{s_i}$.
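The steps above can be checked numerically. A minimal sketch with an illustrative input of length D(D-1)/2 = 3 (so D = 3) confirms that the result is lower triangular, has unit-norm rows (hence a unit-diagonal correlation matrix), and inverts back to the input:

```python
import torch
from torch.distributions.transforms import CorrCholeskyTransform

t = CorrCholeskyTransform()
x = torch.tensor([0.3, -0.8, 1.1])  # D * (D - 1) / 2 = 3 unconstrained values, D = 3
L = t(x)                            # 3x3 Cholesky factor of a correlation matrix

corr = L @ L.T
print(L.norm(dim=-1))        # each row has unit Euclidean norm
print(torch.diagonal(corr))  # so the correlation matrix has a unit diagonal
print(t.inv(L))              # recovers x
```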
class torch.distributions.transforms.CumulativeDistributionTransform(distribution, cache_size=0)[source]
Transform via the cumulative distribution function of a probability distribution.
Parameters
- distribution (Distribution) – Distribution whose cumulative distribution function to use for the transformation.
Example:
# Construct a Gaussian copula from a multivariate normal.
base_dist = MultivariateNormal(
    loc=torch.zeros(2),
    scale_tril=LKJCholesky(2).sample(),
)
transform = CumulativeDistributionTransform(Normal(0, 1))
copula = TransformedDistribution(base_dist, [transform])
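For a univariate standard normal, the transform maps the reals into the unit interval, its log-Jacobian is the density (the derivative of a CDF is the PDF), and its inverse applies icdf. A minimal sketch with illustrative inputs:

```python
import torch
from torch.distributions import Normal
from torch.distributions.transforms import CumulativeDistributionTransform

normal = Normal(torch.tensor(0.0), torch.tensor(1.0))
t = CumulativeDistributionTransform(normal)

x = torch.tensor([-1.0, 0.0, 2.0])
u = t(x)  # values in the unit interval; t(0) = 0.5 for a standard normal
print(u)
# The derivative of a CDF is the density, so log|du/dx| equals log_prob(x)
print(t.log_abs_det_jacobian(x, u), normal.log_prob(x))
print(t.inv(u))  # the inverse applies icdf and recovers x
```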
class torch.distributions.transforms.ExpTransform(cache_size=0)[source]
Transform via the mapping $y = \exp(x)$.
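Chaining an AffineTransform with an ExpTransform in a ComposeTransform gives y = exp(1 + 2x). The composed log-Jacobian is then the sum of the parts, log 2 + (1 + 2x). A minimal sketch with illustrative loc and scale:

```python
import torch
from torch.distributions.transforms import AffineTransform, ComposeTransform, ExpTransform

# y = exp(1 + 2 * x): an affine map followed by exp
t = ComposeTransform([AffineTransform(loc=1.0, scale=2.0), ExpTransform()])

x = torch.tensor([-1.0, 0.0, 0.5])
y = t(x)
# log|dy/dx| = log(2) + (1 + 2 * x), the sum of each part's log-Jacobian
expected = torch.log(torch.tensor(2.0)) + (1 + 2 * x)
print(t.log_abs_det_jacobian(x, y), expected)
print(t.inv(y))  # log, then undo the affine map: recovers x
```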
class torch.distributions.transforms.IndependentTransform(base_transform, reinterpreted_batch_ndims, cache_size=0)[source]
Wrapper around another transform to treat reinterpreted_batch_ndims-many extra of the rightmost dimensions as dependent. This has no effect on the forward or backward transforms, but does sum out reinterpreted_batch_ndims-many of the rightmost dimensions in log_abs_det_jacobian().
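The difference from the wrapped transform shows up only in log_abs_det_jacobian(). A minimal sketch wrapping ExpTransform (whose elementwise log-Jacobian is x itself) with reinterpreted_batch_ndims=1:

```python
import torch
from torch.distributions.transforms import ExpTransform, IndependentTransform

x = torch.tensor([[0.1, 0.2, 0.3],
                  [1.0, -1.0, 0.5]])
elementwise = ExpTransform()
joint = IndependentTransform(ExpTransform(), reinterpreted_batch_ndims=1)

y = elementwise(x)
# ExpTransform alone: one log-Jacobian per element, shape (2, 3)
print(elementwise.log_abs_det_jacobian(x, y).shape)
# IndependentTransform sums over the rightmost dimension: shape (2,)
print(joint.log_abs_det_jacobian(x, y))
print(torch.equal(joint(x), y))  # the forward map itself is unchanged
```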
class torch.distributions.transforms.LowerCholeskyTransform(cache_size=0)[source]
Transform from unconstrained matrices to lower-triangular matrices with nonnegative diagonal entries.
This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization.
class torch.distributions.transforms.PositiveDefiniteTransform(cache_size=0)[source]
Transform from unconstrained matrices to positive-definite matrices.
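A minimal sketch, using an arbitrary illustrative input matrix, of how these two transforms relate: LowerCholeskyTransform keeps the strict lower part and exponentiates the diagonal, and PositiveDefiniteTransform then forms L @ L.T, whose Cholesky factor is that same L:

```python
import torch
from torch.distributions.transforms import LowerCholeskyTransform, PositiveDefiniteTransform

x = torch.tensor([[0.5, 2.0, 3.0],
                  [1.0, -0.3, 4.0],
                  [-1.0, 0.2, 0.1]])  # arbitrary unconstrained matrix

L = LowerCholeskyTransform()(x)     # strict lower part kept, diagonal exponentiated
S = PositiveDefiniteTransform()(x)  # S = L @ L.T

print(L)
print(torch.linalg.eigvalsh(S))     # all eigenvalues are positive
print(torch.linalg.cholesky(S))     # recovers L
```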
class torch.distributions.transforms.PowerTransform(exponent, cache_size=0)[source]
Transform via the mapping $y = x^{\text{exponent}}$.
class torch.distributions.transforms.ReshapeTransform(in_shape, out_shape, cache_size=0)[source]
Unit Jacobian transform to reshape the rightmost part of a tensor.
Note that in_shape and out_shape must have the same number of elements, just as for torch.Tensor.reshape().
Parameters
- in_shape (torch.Size) – The input event shape.
- out_shape (torch.Size) – The output event shape.
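Only the rightmost in_shape dimensions are reshaped; leading batch dimensions pass through, and the log-Jacobian is zero per batch element. A minimal sketch with illustrative shapes:

```python
import torch
from torch.distributions.transforms import ReshapeTransform

t = ReshapeTransform(in_shape=torch.Size([4]), out_shape=torch.Size([2, 2]))
x = torch.arange(20.).reshape(5, 4)  # batch of 5 events of shape (4,)

y = t(x)
print(y.shape)                       # torch.Size([5, 2, 2])
print(t.log_abs_det_jacobian(x, y))  # unit Jacobian: zeros of shape (5,)
print(torch.equal(t.inv(y), x))      # True
```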
class torch.distributions.transforms.SigmoidTransform(cache_size=0)[source]
Transform via the mapping $y = \frac{1}{1 + \exp(-x)}$ and $x = \text{logit}(y)$.
class torch.distributions.transforms.SoftplusTransform(cache_size=0)[source]
Transform via the mapping $\text{Softplus}(x) = \log(1 + \exp(x))$. The implementation reverts to the linear function when $x > 20$.
class torch.distributions.transforms.TanhTransform(cache_size=0)[source]
Transform via the mapping $y = \tanh(x)$.
It is equivalent to
ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)])
However, this might not be numerically stable, so it is recommended to use TanhTransform instead.
Note that one should use cache_size=1 when it comes to NaN/Inf values.
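The stated equivalence follows from tanh(x) = 2·sigmoid(2x) - 1. A minimal sketch checking it on moderate inputs, and showing that with cache_size=1 the inverse of the most recent output returns the cached input object rather than recomputing atanh:

```python
import torch
from torch.distributions.transforms import (
    AffineTransform, ComposeTransform, SigmoidTransform, TanhTransform)

composed = ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)])
tanh_t = TanhTransform(cache_size=1)

x = torch.linspace(-3.0, 3.0, 7)
y = tanh_t(x)
# tanh(x) = 2 * sigmoid(2x) - 1, so both agree for moderate inputs
print(composed(x), y)
# With cache_size=1, inverting the latest output returns the cached input
print(tanh_t.inv(y) is x)
```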
class torch.distributions.transforms.SoftmaxTransform(cache_size=0)[source]
Transform from unconstrained space to the simplex via $y = \exp(x)$ then normalizing.
This is not bijective and cannot be used for HMC. However, it acts mostly coordinate-wise (except for the final normalization), and thus is appropriate for coordinate-wise optimization algorithms.
class torch.distributions.transforms.StackTransform(tseq, dim=0, cache_size=0)[source]
Transform functor that applies a sequence of transforms tseq component-wise to each submatrix at dim in a way compatible with torch.stack().
Example:
from torch.distributions.transforms import ExpTransform, StackTransform, identity_transform
x = torch.stack([torch.arange(1., 11.), torch.arange(1., 11.)], dim=1)
t = StackTransform([ExpTransform(), identity_transform], dim=1)
y = t(x)
class torch.distributions.transforms.StickBreakingTransform(cache_size=0)[source]
Transform from unconstrained space to the simplex of one additional dimension via a stick-breaking process.
This transform arises as an iterated sigmoid transform in a stick-breaking construction of the Dirichlet distribution: the first logit is transformed via sigmoid to the first probability and the probability of everything else, and then the process recurses.
This is bijective and appropriate for use in HMC; however, it mixes coordinates together and is less appropriate for optimization.
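The practical difference between SoftmaxTransform and StickBreakingTransform is visible in the output shape and in invertibility. A minimal sketch with an illustrative three-element input:

```python
import torch
from torch.distributions.transforms import SoftmaxTransform, StickBreakingTransform

x = torch.tensor([0.5, -1.0, 2.0])

soft = SoftmaxTransform()(x)  # same length as x, sums to 1
stick_t = StickBreakingTransform()
stick = stick_t(x)            # one extra coordinate, sums to 1

print(soft.shape, soft.sum())
print(stick.shape, stick.sum())
print(stick_t.inv(stick))     # bijective: recovers x
```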
class torch.distributions.transforms.Transform(cache_size=0)[source]
Abstract class for invertible transformations with computable log det jacobians. They are primarily used in torch.distributions.TransformedDistribution.
Caching is useful for transforms whose inverses are either expensive or numerically unstable. Note that care must be taken with memoized values since the autograd graph may be reversed. For example, while the following works with or without caching:
y = t(x)
t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.
However, the following will error when caching due to dependency reversal:
y = t(x)
z = t.inv(y)
grad(z.sum(), [y])  # error because z is x
Derived classes should implement one or both of _call() or _inverse(). Derived classes that set bijective=True should also implement log_abs_det_jacobian().
Parameters
- cache_size (int) – Size of cache. If zero, no caching is done. If one, the latest single value is cached. Only 0 and 1 are supported.
Variables
- domain (Constraint) – The constraint representing valid inputs to this transform.
- codomain (Constraint) – The constraint representing valid outputs to this transform which are inputs to the inverse transform.
- bijective (bool) – Whether this transform is bijective. A transform t is bijective iff t.inv(t(x)) == x and t(t.inv(y)) == y for every x in the domain and y in the codomain. Transforms that are not bijective should at least maintain the weaker pseudoinverse properties t(t.inv(t(x))) == t(x) and t.inv(t(t.inv(y))) == t.inv(y).
- sign (int or Tensor) – For bijective univariate transforms, this should be +1 or -1 depending on whether the transform is monotone increasing or decreasing.
property inv
Returns the inverse Transform of this transform. This should satisfy t.inv.inv is t.
property sign
Returns the sign of the determinant of the Jacobian, if applicable. In general this only makes sense for bijective transforms.
log_abs_det_jacobian(x, y)[source]
Computes the log det jacobian log |dy/dx| given input and output.
forward_shape(shape)[source]
Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.
inverse_shape(shape)[source]
Infers the shape of the inverse computation, given the output shape. Defaults to preserving shape.
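Following the subclassing rules above (implement _call() and _inverse(), and log_abs_det_jacobian() because bijective=True), here is a minimal sketch of a custom transform. DoubleTransform is a hypothetical name used only for illustration; pushing a standard normal through it should reproduce Normal(0, 2):

```python
import math
import torch
from torch.distributions import Normal, TransformedDistribution, constraints
from torch.distributions.transforms import Transform


class DoubleTransform(Transform):
    """Hypothetical example transform: y = 2 * x."""
    domain = constraints.real
    codomain = constraints.real
    bijective = True

    def _call(self, x):
        return 2 * x

    def _inverse(self, y):
        return y / 2

    def log_abs_det_jacobian(self, x, y):
        # dy/dx = 2 everywhere
        return torch.full_like(x, math.log(2.0))


t = DoubleTransform()
dist = TransformedDistribution(Normal(0.0, 1.0), [t])
y = torch.tensor([-1.0, 0.0, 3.0])
# Doubling a standard normal gives Normal(0, 2)
print(dist.log_prob(y), Normal(0.0, 2.0).log_prob(y))
```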
Constraints
The following constraints are implemented:
- constraints.boolean
- constraints.cat
- constraints.corr_cholesky
- constraints.dependent
- constraints.greater_than(lower_bound)
- constraints.greater_than_eq(lower_bound)
- constraints.independent(constraint, reinterpreted_batch_ndims)
- constraints.integer_interval(lower_bound, upper_bound)
- constraints.interval(lower_bound, upper_bound)
- constraints.less_than(upper_bound)
- constraints.lower_cholesky
- constraints.lower_triangular
- constraints.multinomial
- constraints.nonnegative
- constraints.nonnegative_integer
- constraints.one_hot
- constraints.positive_integer
- constraints.positive
- constraints.positive_semidefinite
- constraints.positive_definite
- constraints.real_vector
- constraints.real
- constraints.simplex
- constraints.symmetric
- constraints.stack
- constraints.square
- constraints.unit_interval
class torch.distributions.constraints.Constraint[source]
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.
check(value)[source]
Returns a boolean tensor of shape sample_shape + batch_shape indicating whether each event in value satisfies this constraint.
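A minimal sketch of check() on a few built-in constraints, with illustrative inputs. Elementwise constraints return one result per element, while event-level constraints such as simplex reduce over the last dimension:

```python
import torch
from torch.distributions import constraints

x = torch.tensor([-1.0, 0.0, 2.0])
positive = constraints.positive.check(x)        # False, False, True
nonnegative = constraints.nonnegative.check(x)  # False, True, True
inside = constraints.interval(0.0, 1.0).check(torch.tensor(0.5))  # True

# Event-level constraints reduce over the event (last) dimension
p = torch.tensor([[0.2, 0.8], [0.6, 0.6]])
on_simplex = constraints.simplex.check(p)       # first row sums to 1, second does not
print(positive, nonnegative, inside, on_simplex)
```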
torch.distributions.constraints.cat
alias of _Cat
torch.distributions.constraints.dependent_property
alias of _DependentProperty
torch.distributions.constraints.greater_than
alias of _GreaterThan
torch.distributions.constraints.greater_than_eq
alias of _GreaterThanEq
torch.distributions.constraints.independent
alias of _IndependentConstraint
torch.distributions.constraints.integer_interval
alias of _IntegerInterval
torch.distributions.constraints.interval
alias of _Interval
torch.distributions.constraints.half_open_interval
alias of _HalfOpenInterval
torch.distributions.constraints.less_than
alias of _LessThan
torch.distributions.constraints.multinomial
alias of _Multinomial
torch.distributions.constraints.stack
alias of _Stack
Constraint Registry
PyTorch provides two global ConstraintRegistry objects that link Constraint objects to Transform objects. These objects both input constraints and return transforms, but they have different guarantees on bijectivity.
- biject_to(constraint) looks up a bijective Transform from constraints.real to the given constraint. The returned transform is guaranteed to have .bijective = True and should implement .log_abs_det_jacobian().
- transform_to(constraint) looks up a not-necessarily bijective Transform from constraints.real to the given constraint. The returned transform is not guaranteed to implement .log_abs_det_jacobian().
The transform_to() registry is useful for performing unconstrained optimization on constrained parameters of probability distributions, which are indicated by each distribution’s .arg_constraints dict. These transforms often overparameterize a space in order to avoid rotation; they are thus more suitable for coordinate-wise optimization algorithms like Adam:
loc = torch.zeros(100, requires_grad=True)
unconstrained = torch.zeros(100, requires_grad=True)
scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
loss = -Normal(loc, scale).log_prob(data).sum()
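The snippet above leaves data and the optimizer loop implicit. A self-contained sketch, assuming illustrative data [-2, 2] (whose maximum-likelihood scale with loc = 0 is 2) and an Adam learning rate chosen for this toy problem, fits the constrained scale through an unconstrained parameter:

```python
import torch
from torch.distributions import Normal, transform_to

data = torch.tensor([-2.0, 2.0])  # illustrative data; MLE: loc = 0, scale = 2

loc = torch.zeros(1, requires_grad=True)
unconstrained = torch.zeros(1, requires_grad=True)
to_scale = transform_to(Normal.arg_constraints['scale'])  # maps reals to (0, inf)
optimizer = torch.optim.Adam([loc, unconstrained], lr=0.05)

for _ in range(1000):
    optimizer.zero_grad()
    scale = to_scale(unconstrained)  # always positive, whatever `unconstrained` is
    loss = -Normal(loc, scale).log_prob(data).sum()
    loss.backward()
    optimizer.step()

fitted_scale = to_scale(unconstrained).item()
print(loc.item(), fitted_scale)
```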
The biject_to() registry is useful for Hamiltonian Monte Carlo, where samples from a probability distribution with constrained .support are propagated in an unconstrained space, and algorithms are typically rotation invariant:
dist = Exponential(rate)
unconstrained = torch.zeros(100, requires_grad=True)
sample = biject_to(dist.support)(unconstrained)
potential_energy = -dist.log_prob(sample).sum()
Note
An example where transform_to and biject_to differ is constraints.simplex: transform_to(constraints.simplex) returns a SoftmaxTransform that simply exponentiates and normalizes its inputs; this is a cheap and mostly coordinate-wise operation appropriate for algorithms like SVI. In contrast, biject_to(constraints.simplex) returns a StickBreakingTransform that bijects its input down to a one-fewer-dimensional space; this is a more expensive, less numerically stable transform but is needed for algorithms like HMC.
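The difference described in this note can be observed directly by looking up both registries for constraints.simplex. A minimal sketch with an illustrative input vector:

```python
import torch
from torch.distributions import biject_to, constraints, transform_to
from torch.distributions.transforms import SoftmaxTransform, StickBreakingTransform

soft = transform_to(constraints.simplex)
stick = biject_to(constraints.simplex)
print(type(soft).__name__, type(stick).__name__)

x = torch.tensor([0.1, -0.4, 1.3])
print(soft(x).shape, stick(x).shape)  # stick-breaking adds one coordinate
print(stick.bijective, soft.bijective)
```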
The biject_to and transform_to objects can be extended by user-defined constraints and transforms using their .register() method either as a function on singleton constraints:
transform_to.register(my_constraint, my_transform)
or as a decorator on parameterized constraints:
@transform_to.register(MyConstraintClass)
def my_factory(constraint):
    assert isinstance(constraint, MyConstraintClass)
    return MyTransform(constraint.param1, constraint.param2)
You can create your own registry by creating a new ConstraintRegistry object.
class torch.distributions.constraint_registry.ConstraintRegistry[source]
Registry to link constraints to transforms.
register(constraint, factory=None)[source]
Registers a Constraint subclass in this registry. Usage:
@my_registry.register(MyConstraintClass)
def construct_transform(constraint):
    assert isinstance(constraint, MyConstraintClass)
    return MyTransform(constraint.arg_constraints)
Parameters
- constraint (subclass of Constraint) – A subclass of Constraint, or a singleton object of the desired class.
- factory (Callable) – A callable that inputs a constraint object and returns a Transform object.
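A minimal sketch of a private registry built with ConstraintRegistry and register(). The factory _interval_factory is hypothetical and illustrative: it maps reals into an interval by squashing with a sigmoid and then applying an affine stretch. It is not necessarily what the built-in transform_to registry returns for intervals:

```python
import torch
from torch.distributions import constraints
from torch.distributions.constraint_registry import ConstraintRegistry
from torch.distributions.transforms import AffineTransform, ComposeTransform, SigmoidTransform

my_registry = ConstraintRegistry()  # independent of the global biject_to / transform_to


@my_registry.register(constraints.interval)
def _interval_factory(constraint):
    # Hypothetical factory: squash to (0, 1), then stretch to (lower, upper)
    width = constraint.upper_bound - constraint.lower_bound
    return ComposeTransform([SigmoidTransform(),
                             AffineTransform(loc=constraint.lower_bound, scale=width)])


t = my_registry(constraints.interval(2.0, 5.0))
out = t(torch.tensor([0.0, -10.0, 10.0]))
print(out)  # 3.5 at zero; all values stay inside (2, 5)
```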
© 2024, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
https://pytorch.org/docs/2.1/distributions.html