torch.library
The Python operator registration API lets you extend PyTorch’s core library of operators with user-defined operators. Currently, this can be done in two ways:
Creating new libraries
- Lets you register new operators and kernels for various backends and functionalities by specifying the appropriate dispatch keys. For example:
  - Consider registering a new operator add in your newly created namespace foo. You can access this operator through the torch.ops API by calling torch.ops.foo.add. You can also access specific registered overloads by calling torch.ops.foo.add.{overload_name}.
  - If you registered a new kernel for the CUDA dispatch key for this operator, then your custom defined function will be called for CUDA tensor inputs.
- This can be done by creating Library class objects of "DEF" kind.
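The workflow described above can be sketched as follows. This is a minimal illustration, not an official recipe: the namespace demo_def, the operator add_one, and the CPU kernel are hypothetical names chosen for the example, and a CPU kernel is used in place of CUDA so the sketch runs on any machine.

```python
import torch
from torch.library import Library

# "demo_def" and "add_one" are hypothetical names chosen for this sketch.
# The Library object is kept at module level so its registrations stay alive.
demo_lib = Library("demo_def", "DEF")
demo_lib.define("add_one(Tensor self) -> Tensor")


def add_one_cpu(x):
    # Kernel that runs when the operator is called with CPU tensors.
    return x + 1


demo_lib.impl("add_one", add_one_cpu, "CPU")

x = torch.tensor([1.0, 2.0])
# Access through the torch.ops API.
result = torch.ops.demo_def.add_one(x)
# Access a specific overload; an operator defined without an overload name
# exposes it as "default".
result_overload = torch.ops.demo_def.add_one.default(x)
```

Registering the same function under the CUDA dispatch key instead would make it the kernel used for CUDA tensor inputs.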
Extending existing C++ libraries (e.g., aten)
- Lets you register kernels for existing operators corresponding to various backends and functionalities by specifying the appropriate dispatch keys.
  This can help fill gaps in operator support for a feature implemented through a dispatch key. For example:
  - You can add operator support for Meta Tensors (by registering a function to the Meta dispatch key).
- This can be done by creating Library class objects of "IMPL" kind.
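As a rough sketch of the Meta-tensor case, the snippet below adds a Meta kernel to an operator that initially has only a CPU kernel. To stay self-contained it does not touch aten: the namespace demo_impl and the operator double are hypothetical stand-ins for an existing library that lacks Meta support.

```python
import torch
from torch.library import Library

# Stand-in for an "existing" library: a hypothetical namespace with one
# operator that only has a CPU kernel.
base_lib = Library("demo_impl", "DEF")
base_lib.define("double(Tensor self) -> Tensor")
base_lib.impl("double", lambda x: x * 2, "CPU")

# An "IMPL" library registers kernels for operators that already exist
# in the namespace.
meta_lib = Library("demo_impl", "IMPL")


def double_meta(x):
    # A Meta kernel only produces output metadata (shape, dtype); no data.
    return torch.empty_like(x)


meta_lib.impl("double", double_meta, "Meta")

# Meta tensors carry shape and dtype but no storage.
meta_out = torch.ops.demo_impl.double(torch.empty(2, 3, device="meta"))
```

The same pattern should apply to an aten operator that is missing a Meta kernel, with "aten" as the namespace and the operator's overload name in place of double.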
A tutorial that walks you through some examples on how to use this API is available on Google Colab.
Warning
The dispatcher is a complicated PyTorch concept, and a sound understanding of it is crucial for doing anything advanced with this API. This blog post is a good starting point to learn about the dispatcher.
class torch.library.Library(ns, kind, dispatch_key='')
A class to create libraries that can be used to register new operators or override operators in existing libraries from Python. A user can optionally pass in a dispatch key name if they want to register kernels for only one specific dispatch key.
To create a library to override operators in an existing library (with name ns), set the kind to “IMPL”. To create a new library (with name ns) to register new operators, set the kind to “DEF”. To create a fragment of a possibly existing library to register operators (and bypass the limitation that there is only one library for a given namespace), set the kind to “FRAGMENT”.
Parameters
- ns – library name
- kind – “DEF”, “IMPL” (default: “IMPL”), “FRAGMENT”
- dispatch_key – PyTorch dispatch key (default: “”)
define(schema, alias_analysis='')
Defines a new operator and its semantics in the ns namespace.
Parameters
- schema – function schema to define a new operator.
- alias_analysis (optional) – Indicates if the aliasing properties of the operator arguments can be inferred from the schema (default behavior) or not (“CONSERVATIVE”).
Returns
- name of the operator as inferred from the schema.
Example:
>>> from torch.library import Library
>>> my_lib = Library("foo", "DEF")
>>> my_lib.define("sum(Tensor self) -> Tensor")
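One way to use the return value of define is to pass it straight to impl, so the operator name is written only once. The sketch below assumes a hypothetical namespace demo_define and operator my_sum; the exact form of the returned string is not asserted here.

```python
import torch
from torch.library import Library

lib = Library("demo_define", "DEF")  # hypothetical namespace

# define() returns the operator name it inferred from the schema string.
op_name = lib.define("my_sum(Tensor self) -> Tensor")

# Reuse the returned name when registering a kernel for the new operator.
lib.impl(op_name, lambda x: x.sum(), "CPU")

total = torch.ops.demo_define.my_sum(torch.tensor([1.0, 2.0, 3.0]))
```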
impl(op_name, fn, dispatch_key='')
Registers the function implementation for an operator defined in the library.
Parameters
- op_name – operator name (along with the overload) or OpOverload object.
- fn – function that’s the operator implementation for the input dispatch key or fallthrough_kernel() to register a fallthrough.
- dispatch_key – dispatch key that the input function should be registered for. By default, it uses the dispatch key that the library was created with.
Example:
>>> from torch.library import Library
>>> my_lib = Library("aten", "IMPL")
>>> def div_cpu(self, other):
...     return self * (1 / other)
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")
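The default behavior of the dispatch_key parameter can be sketched as follows: when the library itself is created with a dispatch key, impl can be called without one. The namespace demo_key and the operator halve are hypothetical and exist only to keep the example from overriding real aten kernels.

```python
import torch
from torch.library import Library

def_lib = Library("demo_key", "DEF")  # hypothetical namespace
def_lib.define("halve(Tensor self) -> Tensor")

# This library is bound to the CPU dispatch key when it is created ...
cpu_lib = Library("demo_key", "IMPL", "CPU")

# ... so impl() is called without a dispatch_key argument and the kernel
# is registered for CPU.
cpu_lib.impl("halve", lambda x: x / 2)

halved = torch.ops.demo_key.halve(torch.tensor([2.0, 4.0]))
```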
torch.library.fallthrough_kernel()
A dummy function to pass to Library.impl in order to register a fallthrough.
Function decorators are also available to make it convenient to register functions for operators:
torch.library.impl()
torch.library.define()
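The page does not show the decorator signatures, so the sketch below rests on an assumption: that both decorators accept a Library object as their first argument, followed by the schema (for define) or the operator name and dispatch key (for impl). The namespace demo_deco and the operators triple and quadruple are hypothetical.

```python
import torch
from torch.library import Library

lib = Library("demo_deco", "DEF")  # hypothetical namespace


# define() as a decorator (assumed signature: Library, schema): declares
# the schema and registers the decorated function as its implementation.
@torch.library.define(lib, "triple(Tensor self) -> Tensor")
def triple(x):
    return x * 3


# impl() as a decorator (assumed signature: Library, name, dispatch key):
# registers a kernel for an operator that has already been defined.
lib.define("quadruple(Tensor self) -> Tensor")


@torch.library.impl(lib, "quadruple", "CPU")
def quadruple_cpu(x):
    return x * 4


v = torch.tensor([1.0, 2.0])
tripled = torch.ops.demo_deco.triple(v)
quadrupled = torch.ops.demo_deco.quadruple(v)
```

If these signatures differ in your installed version, calling Library.define and Library.impl directly achieves the same registrations.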
© 2024, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
https://pytorch.org/docs/2.1/library.html