
torch.triangular_solve

torch.triangular_solve(input, A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor)

Solves a system of equations with a triangular coefficient matrix A and multiple right-hand sides b.

In particular, solves AX = b and assumes A is upper-triangular with the default keyword arguments.

torch.triangular_solve(b, A) accepts 2D inputs b and A, or inputs that are batches of 2D matrices. If the inputs are batches, the outputs X are batched as well.

Supports real-valued and complex-valued inputs.
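
The batched behaviour described above can be sketched as follows. This is a minimal illustration, not part of the original reference: the batch size, the matrix sizes, and the choice of a fixed diagonal of 2 (used only to keep the systems well conditioned) are assumptions made for the example.

```python
import torch

torch.manual_seed(0)

# Batch of 4 upper-triangular 3x3 coefficient matrices.
# Assumption for this sketch: strictly-upper random entries plus a fixed
# diagonal of 2, so every system in the batch is safely invertible.
A = torch.randn(4, 3, 3).triu(diagonal=1) + 2.0 * torch.eye(3)

# Batch of 4 right-hand sides, each with 2 columns.
b = torch.randn(4, 3, 2)

# The batch dimensions of b and A match, so the solution is batched too.
X, A_clone = torch.triangular_solve(b, A)

# Largest absolute entry of A @ X - b over the whole batch; it should be
# close to zero, up to floating-point rounding.
residual = (A @ X - b).abs().max().item()
print(X.shape, residual)
```

Each slice X[i] solves A[i] X[i] = b[i] independently of the other batch entries.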

Parameters
  • input (Tensor) – multiple right-hand sides b of size (*, m, k), where * is zero or more batch dimensions
  • A (Tensor) – the input triangular coefficient matrix of size (*, m, m), where * is zero or more batch dimensions
  • upper (bool, optional) – whether to solve the upper-triangular system of equations (default) or the lower-triangular system of equations. Default: True.
  • transpose (bool, optional) – whether A should be transposed before being sent into the solver. Default: False.
  • unitriangular (bool, optional) – whether A is unit triangular. If True, the diagonal elements of A are assumed to be 1 and not referenced from A. Default: False.
Returns

A namedtuple (solution, cloned_coefficient) where cloned_coefficient is a clone of A and solution is the solution X to AX = b (or whichever variant of the system of equations applies, depending on the keyword arguments).
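
As a small sketch of how the returned namedtuple might be consumed, the snippet below reads both fields by name. The 2x2 system is a hand-picked example, not taken from the original page, chosen so the solution can be checked by back substitution.

```python
import torch

# Hand-picked upper-triangular system (illustrative values only):
#   2*x1 + 1*x2 = 3
#          4*x2 = 8
A = torch.tensor([[2.0, 1.0],
                  [0.0, 4.0]])
b = torch.tensor([[3.0],
                  [8.0]])

result = torch.triangular_solve(b, A)

# Fields can be read by name...
X = result.solution
A_copy = result.cloned_coefficient

# ...or unpacked positionally, like any tuple.
X_again, A_copy_again = result

print(X)  # back substitution gives x2 = 2.0, then x1 = 0.5
```

Accessing the fields by name keeps the call site readable when only the solution is needed.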

Examples:

>>> A = torch.randn(2, 2).triu()
>>> A
tensor([[ 1.1527, -1.0753],
        [ 0.0000,  0.7986]])
>>> b = torch.randn(2, 3)
>>> b
tensor([[-0.0210,  2.3513, -1.5492],
        [ 1.5429,  0.7403, -1.0243]])
>>> torch.triangular_solve(b, A)
torch.return_types.triangular_solve(
solution=tensor([[ 1.7841,  2.9046, -2.5405],
        [ 1.9320,  0.9270, -1.2826]]),
cloned_coefficient=tensor([[ 1.1527, -1.0753],
        [ 0.0000,  0.7986]]))
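
The keyword arguments listed under Parameters can be explored with a sketch like the one below. The matrix and right-hand side are hand-picked illustrative values, and the reading of transpose=True as "solve AᵀX = b with upper still describing A as passed in" is stated here as an assumption that the residual check is meant to confirm.

```python
import torch

# Hand-picked lower-triangular matrix and right-hand side (illustrative).
A = torch.tensor([[2.0, 0.0],
                  [1.0, 4.0]])
b = torch.tensor([[2.0],
                  [9.0]])

# upper=False: solve A X = b, reading only the lower triangle of A.
# Forward substitution: x1 = 2/2 = 1, then x2 = (9 - 1*1)/4 = 2.
X_lower, _ = torch.triangular_solve(b, A, upper=False)

# transpose=True: solve A^T X = b. In this sketch `upper` still describes
# A as passed in (lower-triangular), not its transpose.
X_t, _ = torch.triangular_solve(b, A, upper=False, transpose=True)
residual_t = (A.t() @ X_t - b).abs().max().item()

# unitriangular=True: the diagonal of A is treated as all ones, so the
# stored values 2 and 4 are ignored: x1 = 2, then x2 = 9 - 1*2 = 7.
X_unit, _ = torch.triangular_solve(b, A, upper=False, unitriangular=True)

print(X_lower, X_t, X_unit, sep="\n")
```

Checking the residual of each variant against the system it is supposed to solve is a quick way to confirm which combination of flags matches a given problem.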

© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/generated/torch.triangular_solve.html