torch.jit.fork
torch.jit.fork(func, *args, **kwargs)
Creates an asynchronous task executing func and a reference to the value of the result of this execution. fork will return immediately, so the return value of func may not have been computed yet. To force completion of the task and access the return value, invoke torch.jit.wait on the Future. fork invoked with a func which returns T is typed as torch.jit.Future[T]. fork calls can be arbitrarily nested, and may be invoked with positional and keyword arguments.

Asynchronous execution will only occur when run in TorchScript. If run in pure Python, fork will not execute in parallel. fork will also not execute in parallel when invoked while tracing; however, the fork and wait calls will be captured in the exported IR Graph.

Warning
fork tasks will execute non-deterministically. We recommend only spawning parallel fork tasks for pure functions that do not modify their inputs, module attributes, or global state.

Parameters
- func (callable or torch.nn.Module) – A Python function or torch.nn.Module that will be invoked. If executed in TorchScript, it will execute asynchronously; otherwise it will not. Traced invocations of fork will be captured in the IR.
- *args – positional arguments to invoke func with.
- **kwargs – keyword arguments to invoke func with.

Returns
a reference to the execution of func. The value T can only be accessed by forcing completion of func through torch.jit.wait.

Return type
torch.jit.Future[T]

Example (fork a free function):

    import torch
    from torch import Tensor

    def foo(a : Tensor, b : int) -> Tensor:
        return a + b

    def bar(a):
        fut : torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
        return torch.jit.wait(fut)

    script_bar = torch.jit.script(bar)
    input = torch.tensor(2)
    # only the scripted version executes asynchronously
    assert script_bar(input) == bar(input)
    # trace is not run asynchronously, but fork is captured in IR
    graph = torch.jit.trace(bar, (input,)).graph
    assert "fork" in str(graph)

Example (fork a module method):

    import torch
    from torch import Tensor

    class AddMod(torch.nn.Module):
        def forward(self, a: Tensor, b : int):
            return a + b

    class Mod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mod = AddMod()

        def forward(self, input):
            # fork the submodule call, passing the forward argument through
            fut = torch.jit.fork(self.mod, input, b=2)
            return torch.jit.wait(fut)

    input = torch.tensor(2)
    mod = Mod()
    # eager and scripted modules should produce the same result
    assert mod(input) == torch.jit.script(mod).forward(input)

© 2024, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
https://pytorch.org/docs/2.1/generated/torch.jit.fork.html