-
Notifications
You must be signed in to change notification settings - Fork 405
user provided bound for torchtrt compile when export dimension is unb… #4213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d1b9ebf
4cb00ce
487e868
457213a
75d9826
37726a6
fd52275
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,11 +5,13 @@ | |
| import os | ||
| import platform | ||
| import warnings | ||
| from typing import Any, Collection, List, Optional, Sequence, Union | ||
| from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple, Union | ||
|
|
||
| import sympy | ||
| import torch | ||
| from torch.export import ExportedProgram | ||
| from torch.fx.node import Target | ||
| from torch.utils._sympy.numbers import int_oo | ||
| from torch_tensorrt._Device import Device | ||
| from torch_tensorrt._enums import EngineCapability, dtype | ||
| from torch_tensorrt._features import ENABLED_FEATURES, needs_cross_compile | ||
|
|
@@ -874,6 +876,139 @@ def _insert_complex_io_adapters( | |
| partitioned_module.recompile() | ||
|
|
||
|
|
||
| def _build_user_symbol_bounds( | ||
| gm: torch.fx.GraphModule, | ||
| sample_arg_inputs: Sequence[Input], | ||
| sample_kwarg_inputs: dict[Any, Any], | ||
| ) -> Dict[sympy.Symbol, Tuple[int, int]]: | ||
| """Map ``sympy.Symbol -> (min, max)`` from dynamic ``Input``s, used to | ||
| fill ``Dim.DYNAMIC`` upper bounds without mutating ``ShapeEnv``. | ||
| Validates against finite exporter bounds: ``user_max > exp_max`` and | ||
| ``user_min < exp_min`` raise (TRT would reject those shapes at runtime); | ||
| a strict subset narrows the engine profile to the user's bounds (info | ||
| log only); the ``user_min=1, exp_min=2`` case warns -- it's PyTorch's | ||
| 0/1 specialization artifact, not a user error. | ||
| """ | ||
| placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"] | ||
|
|
||
| sample_by_name: dict[str, Input] = {} | ||
| for i, node in enumerate(placeholders): | ||
| if i < len(sample_arg_inputs): | ||
| inp = sample_arg_inputs[i] | ||
| if isinstance(inp, Input) and inp.shape_mode == Input._ShapeMode.DYNAMIC: | ||
| sample_by_name[node.target] = inp | ||
| for name, inp in sample_kwarg_inputs.items(): | ||
| if isinstance(inp, Input) and inp.shape_mode == Input._ShapeMode.DYNAMIC: | ||
| sample_by_name[name] = inp | ||
|
|
||
| user_symbol_bounds: Dict[sympy.Symbol, Tuple[int, int]] = {} | ||
| if not sample_by_name: | ||
| return user_symbol_bounds | ||
|
|
||
| for node in placeholders: | ||
| if node.target not in sample_by_name: | ||
| continue | ||
| sample_input = sample_by_name[node.target] | ||
| fake_val = node.meta.get("val") | ||
| if not isinstance(fake_val, torch.Tensor): | ||
| continue | ||
|
|
||
| min_shape = sample_input.shape["min_shape"] | ||
| max_shape = sample_input.shape["max_shape"] | ||
|
|
||
| for d, dim in enumerate(fake_val.size()): | ||
| if not isinstance(dim, torch.SymInt) or d >= len(min_shape): | ||
| continue | ||
| expr = dim.node.expr | ||
| # Composite exprs (e.g. ``2*s0``) are recomputed by | ||
| # ``ShapeEnv.bound_sympy``; overriding them directly would lie. | ||
| if not isinstance(expr, sympy.Symbol): | ||
| continue | ||
| if expr in user_symbol_bounds: | ||
| continue | ||
| user_min = int(min_shape[d]) | ||
| user_max = int(max_shape[d]) | ||
| user_symbol_bounds[expr] = (user_min, user_max) | ||
| logger.debug( | ||
| "Recorded user-supplied bounds for %s: [%d, %d]", | ||
| expr, | ||
| user_min, | ||
| user_max, | ||
| ) | ||
|
|
||
| # If exporter bounds are finite, ``extract_var_range_info`` keeps | ||
| # them (override is gated on ``max_val is None``). Catch the | ||
| # mismatch here so the user doesn't hit a runtime "shape outside | ||
| # profile" error on shapes they explicitly declared. | ||
| shape_env = getattr(dim.node, "shape_env", None) | ||
| if shape_env is None: | ||
| continue | ||
| exp_range = shape_env.var_to_range.get(expr) | ||
| if exp_range is None: | ||
| continue | ||
| exp_lower = exp_range.lower | ||
| exp_upper = exp_range.upper | ||
| exp_max_unbounded = exp_upper is int_oo or exp_upper == sympy.oo | ||
| if exp_max_unbounded: | ||
| # Dim.DYNAMIC: user fills the gap (intended use). | ||
| continue | ||
| try: | ||
| exp_min = int(exp_lower) | ||
| exp_max = int(exp_upper) | ||
| except (TypeError, ValueError): | ||
| continue | ||
| if user_min == exp_min and user_max == exp_max: | ||
| continue | ||
|
|
||
| mismatch = ( | ||
| f"symbol {expr}: Input({user_min}, {user_max}) vs " | ||
| f"exporter({exp_min}, {exp_max})." | ||
| ) | ||
| hint = ( | ||
| f" Re-export with Dim('{expr}', min={user_min}, " | ||
| f"max={user_max}) or adjust Input to match." | ||
| ) | ||
|
|
||
| if user_max > exp_max: | ||
| raise ValueError( | ||
| f"{mismatch} Input.max_shape exceeds the exporter's max " | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the exported program right? |
||
| f"({user_max} > {exp_max}); TRT will reject shapes above " | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TRT will reject or pytorch will? |
||
| f"{exp_max} at runtime.{hint}" | ||
| ) | ||
|
|
||
| if user_min < exp_min: | ||
| # 1->2 is the 0/1 specialization artifact, not a user error. | ||
| if user_min == 1 and exp_min == 2: | ||
| logger.warning( | ||
| "%s Input.min_shape=1 vs exporter min=2 is the " | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we need to make sure these error messages are clear to follow from a user perspective |
||
| "PyTorch 0/1 specialization artifact; TRT engine " | ||
| "min will be 2.", | ||
| mismatch, | ||
| ) | ||
| continue | ||
| raise ValueError( | ||
| f"{mismatch} Input.min_shape is below the exporter's min " | ||
| f"({user_min} < {exp_min}); TRT will reject shapes " | ||
| f"below {exp_min} at runtime.{hint}" | ||
| ) | ||
|
|
||
| # Strict subset: engine profile narrows to the user's bounds | ||
| # (applied in ``extract_var_range_info``). Not a warning -- the | ||
| # user got exactly what they asked for. | ||
| logger.info( | ||
| "%s Narrowing engine profile to user bounds [%d, %d] " | ||
| "(exporter range was [%d, %d]).", | ||
| mismatch, | ||
| user_min, | ||
| user_max, | ||
| exp_min, | ||
| exp_max, | ||
| ) | ||
|
|
||
| return user_symbol_bounds | ||
|
|
||
|
|
||
| @fn_supports_debugger # type: ignore[misc] | ||
| def compile_module( | ||
| gm: torch.fx.GraphModule, | ||
|
|
@@ -905,6 +1040,12 @@ def compile_module( | |
| if sample_kwarg_inputs is None: | ||
| sample_kwarg_inputs = {} | ||
|
|
||
| # Forwarded to the partitioner to fill Dim.DYNAMIC upper bounds. | ||
| # Read-only w.r.t. ShapeEnv so range_constraints survive save/re-export. | ||
| user_symbol_bounds = _build_user_symbol_bounds( | ||
| gm, sample_arg_inputs, sample_kwarg_inputs | ||
| ) | ||
|
|
||
| # Configure user compilation settings to converters. | ||
| CONVERTERS.set_compilation_settings(settings) | ||
|
|
||
|
|
@@ -1086,7 +1227,9 @@ def preserve_module_specs( | |
| ) | ||
|
|
||
| # Get the submodule inputs for min, opt, max shapes of the graph inputs | ||
| submodule_inputs = partitioning.construct_submodule_inputs(submodule) | ||
| submodule_inputs = partitioning.construct_submodule_inputs( | ||
| submodule, user_symbol_bounds=user_symbol_bounds | ||
| ) | ||
|
|
||
| assert submodule_inputs is not None | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain what the logic here is?