diff --git a/tests/flex_moe_micro_bench.py b/tests/flex_moe_micro_bench.py index 9523c6707..d2bf5a67c 100644 --- a/tests/flex_moe_micro_bench.py +++ b/tests/flex_moe_micro_bench.py @@ -43,7 +43,7 @@ def main(): p.add_argument("--compile_mode", choices=["off", "walker", "walker_fullgraph"], default="off", help="wrap call_moe_model_with_flex_kwargs in torch.compile") - p.add_argument("--compile_opts", choices=["stock", "unsloth_O3", "inference_freeze"], + p.add_argument("--compile_opts", choices=["stock", "unsloth_O3", "inference_freeze", "coord_descent"], default="stock", help="which inductor / dynamo options profile to apply before compile") p.add_argument("--explain", action="store_true", @@ -76,6 +76,12 @@ def main(): _dc.capture_scalar_outputs = True _dc.capture_dynamic_output_shape_ops = True print("[micro] inductor/dynamo options: unsloth_O3") + elif args.compile_opts == "coord_descent": + # Just ``coordinate_descent_tuning = True`` — fast compile, small + # fusion upside. + import torch._inductor.config as _ic + _ic.coordinate_descent_tuning = True + print("[micro] inductor options: coord_descent only") elif args.compile_opts == "inference_freeze": # Inference-friendly: constant-fold weights via freezing=True. # Only safe when the model weights won't be updated after compile