
GraphInfo

class torch.onnx.verification.GraphInfo(graph, input_args, params_dict, export_options=<factory>, id='', _EXCLUDED_NODE_KINDS=frozenset({'aten::ScalarImplicit', 'prim::Constant', 'prim::ListConstruct'})) [source]

GraphInfo contains validation information about a TorchScript graph and its converted ONNX graph.

all_mismatch_leaf_graph_info() [source]

Return a list of all leaf GraphInfo objects that have a mismatch.

Return type

List[GraphInfo]

clear() [source]

Clear states and results of previous verification.

essential_node_count() [source]

Return the number of nodes in the subgraph excluding those in _EXCLUDED_NODE_KINDS.

Return type

int

essential_node_kinds() [source]

Return the set of node kinds in the subgraph excluding those in _EXCLUDED_NODE_KINDS.

Return type

Set[str]
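The two methods above can be pictured as a filter over node kinds. The following is a minimal sketch, assuming a subgraph is represented as a plain list of node-kind strings (the real class operates on TorchScript graph nodes). The standalone functions and the `toy_subgraph` data are illustrative only, not the library implementation; the excluded set is copied from the default in the class signature.

```python
from typing import List, Set

# Copied from the default shown in the GraphInfo signature.
EXCLUDED_NODE_KINDS = frozenset(
    {"aten::ScalarImplicit", "prim::Constant", "prim::ListConstruct"}
)


def essential_node_kinds(node_kinds: List[str]) -> Set[str]:
    """Distinct kinds that remain after dropping the excluded ones."""
    return {kind for kind in node_kinds if kind not in EXCLUDED_NODE_KINDS}


def essential_node_count(node_kinds: List[str]) -> int:
    """Number of nodes whose kind is not excluded (duplicates are counted)."""
    return sum(1 for kind in node_kinds if kind not in EXCLUDED_NODE_KINDS)


# Hypothetical subgraph, listed as node kinds in graph order.
toy_subgraph = [
    "prim::Constant",
    "aten::linear",
    "aten::relu",
    "prim::Constant",
    "aten::relu",
]
count = essential_node_count(toy_subgraph)
kinds = essential_node_kinds(toy_subgraph)
```

Note that the count includes repeated kinds, while the set collapses them.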

export_repro(repro_dir=None, name=None) [source]

Export the subgraph to ONNX along with the input/output data for repro.

The repro directory will contain the following files:

dir
├── test_<name>
│   ├── model.onnx
│   └── test_data_set_0
│       ├── input_0.pb
│       ├── input_1.pb
│       ├── output_0.pb
│       └── output_1.pb

Parameters
  • repro_dir (Optional[str]) – The directory to export the repro files to. Defaults to the current working directory if None.
  • name (Optional[str]) – An optional name for the test case folder: “test_{name}”.
Returns

The path to the exported repro directory.

Return type

str
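The layout above can be restated as path arithmetic. This is a rough sketch that only computes the file paths one would expect for a given case name; it writes nothing and does not call PyTorch. The helper `expected_repro_files`, its parameters, and the case name `relu_case` are hypothetical, and the helper requires an explicit name because the fallback used when name is None is not described here.

```python
import os
from pathlib import Path
from typing import List, Optional


def expected_repro_files(
    name: str,
    num_inputs: int,
    num_outputs: int,
    repro_dir: Optional[str] = None,
) -> List[Path]:
    """List the paths in the documented repro layout (hypothetical helper)."""
    # Mirror the documented default: fall back to the current working directory.
    base = Path(repro_dir if repro_dir is not None else os.getcwd())
    case_dir = base / f"test_{name}"
    data_dir = case_dir / "test_data_set_0"
    files = [case_dir / "model.onnx"]
    files += [data_dir / f"input_{i}.pb" for i in range(num_inputs)]
    files += [data_dir / f"output_{i}.pb" for i in range(num_outputs)]
    return files


paths = expected_repro_files("relu_case", num_inputs=2, num_outputs=2, repro_dir="repro")
relative = [path.as_posix() for path in paths]
```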

find_mismatch(options=None) [source]

Find all mismatches between the TorchScript IR graph and the exported ONNX model.

Performs a binary search over the model graph to find the minimal subgraphs that exhibit a mismatch. A GraphInfo object is created for each subgraph, recording the test inputs and export options as well as the validation results.

Parameters

options (Optional[VerificationOptions]) – The verification options.
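The bisection described above can be sketched in plain Python. This toy model makes several assumptions that are not part of the library: a graph is a list of node-kind strings, a subgraph "mismatches" whenever it contains a kind listed in `BROKEN_KINDS` (standing in for actually exporting and running it), and child ids are formed by appending "0" or "1" to the parent id, as the tree printed by pretty_print_tree suggests. `ToyGraphInfo` and its `upper`/`lower` fields are hypothetical names.

```python
from dataclasses import dataclass
from typing import List, Optional

# Assumption for this toy model only: these node kinds produce wrong outputs.
BROKEN_KINDS = {"aten::relu"}


@dataclass
class ToyGraphInfo:
    node_kinds: List[str]
    id: str = ""
    upper: Optional["ToyGraphInfo"] = None
    lower: Optional["ToyGraphInfo"] = None

    def has_mismatch(self) -> bool:
        # Stand-in for exporting the subgraph and comparing outputs.
        return any(kind in BROKEN_KINDS for kind in self.node_kinds)

    def find_mismatch(self) -> None:
        # Stop at subgraphs that match and at single-node subgraphs.
        if not self.has_mismatch() or len(self.node_kinds) <= 1:
            return
        mid = len(self.node_kinds) // 2
        self.upper = ToyGraphInfo(self.node_kinds[:mid], self.id + "0")
        self.lower = ToyGraphInfo(self.node_kinds[mid:], self.id + "1")
        self.upper.find_mismatch()
        self.lower.find_mismatch()

    def all_mismatch_leaf_graph_info(self) -> List["ToyGraphInfo"]:
        if not self.has_mismatch():
            return []
        if self.upper is None or self.lower is None:
            return [self]  # a mismatching subgraph that was not split further
        return (
            self.upper.all_mismatch_leaf_graph_info()
            + self.lower.all_mismatch_leaf_graph_info()
        )


# Hypothetical five-node graph with two faulty nodes.
root = ToyGraphInfo(
    ["aten::linear", "aten::relu", "aten::add", "aten::relu", "aten::mul"]
)
root.find_mismatch()
leaf_ids = [leaf.id for leaf in root.all_mismatch_leaf_graph_info()]
```

Only subgraphs that still show a mismatch are split again, so the search narrows to the faulty nodes without expanding the halves that already match.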

find_partition(id) [source]

Find the GraphInfo object with the given id.

Return type

Optional[GraphInfo]
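A lookup by id can be sketched as a prefix-guided walk over the partition tree. This sketch assumes, based on the ids shown in the pretty_print_tree example, that a child's id is its parent's id with one more digit appended; `ToyPartition` and its `children` field are hypothetical and are not the library's data structure.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyPartition:
    id: str
    children: List["ToyPartition"] = field(default_factory=list)

    def find_partition(self, id: str) -> Optional["ToyPartition"]:
        if id == self.id:
            return self
        # Only descend into a child whose id is a prefix of the target id.
        for child in self.children:
            if id.startswith(child.id):
                found = child.find_partition(id)
                if found is not None:
                    return found
        return None


# Tree shaped like the pretty_print_tree example (root id is the empty string).
tree = ToyPartition("", [
    ToyPartition("0", [ToyPartition("00"), ToyPartition("01")]),
    ToyPartition("1", [
        ToyPartition("10"),
        ToyPartition("11", [ToyPartition("110"), ToyPartition("111")]),
    ]),
])
hit = tree.find_partition("110")
miss = tree.find_partition("2")
```

Returning None for an unknown id matches the Optional[GraphInfo] return type.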

has_mismatch() [source]

Return True if the subgraph has an output mismatch between torch and ONNX.

Return type

bool

pretty_print_mismatch(graph=False) [source]

Pretty print details of the mismatch between torch and ONNX.

Parameters

graph (bool) – If True, print the ATen JIT graph and ONNX graph.

pretty_print_tree() [source]

Pretty print GraphInfo tree.

Each node represents a subgraph, showing the number of nodes in the subgraph and a marker for the verification result: ✓ if the outputs match between torch and ONNX, X if the subgraph has an output mismatch.

The id of the subgraph is shown under the node. The GraphInfo object for any subgraph can be retrieved by calling graph_info.find_partition(id).

Example:

==================================== Tree: =====================================
5 X   __2 X    __1 ✓
id:  |  id: 0 |  id: 00
     |        |
     |        |__1 X (aten::relu)
     |           id: 01
     |
     |__3 X    __1 ✓
        id: 1 |  id: 10
              |
              |__2 X     __1 X (aten::relu)
                 id: 11 |  id: 110
                        |
                        |__1 ✓
                           id: 111
=========================== Mismatch leaf subgraphs: ===========================
['01', '110']
============================= Mismatch node kinds: =============================
{'aten::relu': 2}

verify_export(options) [source]

Verify the export from TorchScript IR graph to ONNX.

Export the TorchScript IR graph to ONNX, with the inputs, parameters and export options recorded in this object. Then verify the exported ONNX graph against the original TorchScript IR graph under the provided verification options.

Parameters

options (VerificationOptions) – The verification options.

Returns
  • error: The AssertionError raised during the verification, or None if no error is raised.
  • onnx_graph: The exported ONNX graph in TorchScript IR format.
  • onnx_outs: The outputs from running the exported ONNX model under the ONNX backend specified in options.
  • pt_outs: The outputs from running the TorchScript IR graph.

Return type

Tuple of (error, onnx_graph, onnx_outs, pt_outs)
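The return convention, where a failed comparison is handed back as an AssertionError instead of being raised, can be sketched as follows. This is a toy stand-in rather than the library code: outputs are plain floats, `compare_outputs` and `toy_verify_export` are hypothetical helpers, the tolerance values are arbitrary, and the graph is a placeholder string.

```python
import math
from typing import Optional, Sequence, Tuple


def compare_outputs(
    onnx_outs: Sequence[float],
    pt_outs: Sequence[float],
    rtol: float = 1e-3,  # arbitrary illustrative tolerance
    atol: float = 1e-7,  # arbitrary illustrative tolerance
) -> Optional[AssertionError]:
    """Return the AssertionError describing the first mismatch, or None."""
    try:
        if len(onnx_outs) != len(pt_outs):
            raise AssertionError("output count differs")
        for index, (onnx_value, pt_value) in enumerate(zip(onnx_outs, pt_outs)):
            if not math.isclose(onnx_value, pt_value, rel_tol=rtol, abs_tol=atol):
                raise AssertionError(
                    f"output {index}: ONNX {onnx_value} vs torch {pt_value}"
                )
    except AssertionError as error:
        return error  # captured and returned, not propagated
    return None


def toy_verify_export(
    onnx_outs: Sequence[float], pt_outs: Sequence[float]
) -> Tuple[Optional[AssertionError], str, tuple, tuple]:
    error = compare_outputs(onnx_outs, pt_outs)
    onnx_graph = "<exported graph placeholder>"
    return error, onnx_graph, tuple(onnx_outs), tuple(pt_outs)


error, onnx_graph, onnx_outs, pt_outs = toy_verify_export([1.0, 2.5], [1.0, 2.0])
mismatch_found = error is not None
clean_error, *_ = toy_verify_export([1.0, 2.0], [1.0, 2.0])
```

Returning the error alongside both sets of outputs lets the caller inspect the mismatch without wrapping the call in its own try/except.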

© 2024, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
https://pytorch.org/docs/2.1/generated/torch.onnx.verification.GraphInfo.html