Skip to content

Commit c4d9573

Browse files
authored
feat: warn when simp theorem LHS has variable or unrecognized head symbol (#13325)
This PR adds warnings when registering `@[simp]` theorems whose left-hand side has a problematic head symbol in the discrimination tree: - **Variable head** (`.star` key): The theorem will be tried on every `simp` step, which can be expensive. The warning notes this may be acceptable for `local` or `scoped` simp lemmas. Controlled by `warning.simp.varHead` (default: `true`). - **Unrecognized head** (`.other` key, e.g. a lambda expression): The theorem is unlikely to ever be applied by `simp`. Controlled by `warning.simp.otherHead` (default: `true`).
1 parent 9db52c7 commit c4d9573

File tree

3 files changed

+136
-9
lines changed

3 files changed

+136
-9
lines changed

src/Lean/Meta/Tactic/Simp/SimpTheorems.lean

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,18 @@ register_builtin_option debug.tactic.simp.checkDefEqAttr : Bool := {
5050
of the `defeq` attribute, and warn if it was. Note that this is a costly check."
5151
}
5252

53+
register_builtin_option warning.simp.varHead : Bool := {
54+
defValue := false
55+
descr := "If true, warns when the head symbol of the left-hand side of a `@[simp]` theorem \
56+
is a variable. Such lemmas are tried on every simp step, which can be slow."
57+
}
58+
59+
register_builtin_option warning.simp.otherHead : Bool := {
60+
defValue := true
61+
descr := "If true, warns when the left-hand side of a `@[simp]` theorem is headed by a \
62+
`.other` key in the discrimination tree (e.g. a lambda expression). Such lemmas can cause slowdowns."
63+
}
64+
5365
/--
5466
An `Origin` is an identifier for simp theorems which indicates roughly
5567
what action the user took which lead to this theorem existing in the simp set.
@@ -359,12 +371,27 @@ private def checkTypeIsProp (type : Expr) : MetaM Unit :=
359371
unless (← isProp type) do
360372
throwError "Invalid simp theorem: Expected a proposition, but found{indentExpr type}"
361373

362-
private def mkSimpTheoremKeys (type : Expr) (noIndexAtArgs : Bool) : MetaM (Array SimpTheoremKey × Bool) := do
374+
private def mkSimpTheoremKeys (type : Expr) (noIndexAtArgs : Bool) (checkLhs : Bool := false) : MetaM (Array SimpTheoremKey × Bool) := do
363375
withNewMCtxDepth do
364376
let (_, _, type) ← forallMetaTelescopeReducing type
365377
let type ← whnfR type
366378
match type.eq? with
367-
| some (_, lhs, rhs) => pure (← DiscrTree.mkPath lhs noIndexAtArgs, ← isPerm lhs rhs)
379+
| some (_, lhs, rhs) =>
380+
let keys ← DiscrTree.mkPath lhs noIndexAtArgs
381+
if checkLhs then
382+
if warning.simp.varHead.get (← getOptions) && keys[0]? == some .star then
383+
logWarning m!"Left-hand side of simp theorem has a variable as head symbol. \
384+
This means the theorem will be tried on every simp step, which can be expensive. \
385+
This may be acceptable for `local` or `scoped` simp lemmas.\n\
386+
Use `set_option warning.simp.varHead false` to disable this warning."
387+
if warning.simp.otherHead.get (← getOptions) && keys[0]? == some .other then
388+
logWarning m!"Left-hand side of simp theorem is headed by a `.other` key in the \
389+
discrimination tree (e.g. because it is a lambda expression). \
390+
This theorem will be tried against all expressions that also have a `.other` key as head, \
391+
which can cause slowdowns. \
392+
This may be acceptable for `local` or `scoped` simp lemmas.\n\
393+
Use `set_option warning.simp.otherHead false` to disable this warning."
394+
pure (keys, ← isPerm lhs rhs)
368395
| none => throwError "Unexpected kind of simp theorem{indentExpr type}"
369396

370397
register_builtin_option simp.rfl.checkTransparency: Bool := {
@@ -373,10 +400,10 @@ register_builtin_option simp.rfl.checkTransparency: Bool := {
373400
}
374401

375402
private def mkSimpTheoremCore (origin : Origin) (e : Expr) (levelParams : Array Name) (proof : Expr)
376-
(post : Bool) (prio : Nat) (noIndexAtArgs : Bool) : MetaM SimpTheorem := do
403+
(post : Bool) (prio : Nat) (noIndexAtArgs : Bool) (checkLhs : Bool := false): MetaM SimpTheorem := do
377404
assert! origin != .fvar ⟨.anonymous⟩
378405
let type ← instantiateMVars (← inferType e)
379-
let (keys, perm) ← mkSimpTheoremKeys type noIndexAtArgs
406+
let (keys, perm) ← mkSimpTheoremKeys type noIndexAtArgs checkLhs
380407
let rfl ← isRflProof proof
381408
if rfl && simp.rfl.checkTransparency.get (← getOptions) then
382409
forallTelescopeReducing type fun _ type => do
@@ -399,7 +426,7 @@ Creates a `SimpTheorem` from a global theorem.
399426
Because some theorems lead to multiple `SimpTheorems` (in particular conjunctions), returns an array.
400427
-/
401428
def mkSimpTheoremFromConst (declName : Name) (post := true) (inv := false)
402-
(prio : Nat := eval_prio default) : MetaM (Array SimpTheorem) := do
429+
(prio : Nat := eval_prio default) (checkLhs : Bool := false) : MetaM (Array SimpTheorem) := do
403430
let cinfo ← getConstVal declName
404431
let us := cinfo.levelParams.map mkLevelParam
405432
let origin := .decl declName post inv
@@ -413,10 +440,10 @@ def mkSimpTheoremFromConst (declName : Name) (post := true) (inv := false)
413440
let auxName ← mkAuxLemma (kind? := `_simp) cinfo.levelParams type val (inferRfl := true)
414441
(forceExpose := true) -- These kinds of theorems are small and `to_additive` may need to
415442
-- unfold them.
416-
r := r.push <| (← do mkSimpTheoremCore origin (mkConst auxName us) #[] (mkConst auxName) post prio (noIndexAtArgs := false))
443+
r := r.push <| (← do mkSimpTheoremCore origin (mkConst auxName us) #[] (mkConst auxName) post prio (noIndexAtArgs := false) (checkLhs := checkLhs))
417444
return r
418445
else
419-
return #[← withoutExporting do mkSimpTheoremCore origin (mkConst declName us) #[] (mkConst declName) post prio (noIndexAtArgs := false)]
446+
return #[← withoutExporting do mkSimpTheoremCore origin (mkConst declName us) #[] (mkConst declName) post prio (noIndexAtArgs := false) (checkLhs := checkLhs)]
420447

421448
def SimpTheorem.getValue (simpThm : SimpTheorem) : MetaM Expr := do
422449
if simpThm.proof.isConst && simpThm.levelParams.isEmpty then
@@ -670,7 +697,7 @@ def SimpExtension.getTheorems (ext : SimpExtension) : CoreM SimpTheorems :=
670697
Adds a simp theorem to a simp extension
671698
-/
672699
def addSimpTheorem (ext : SimpExtension) (declName : Name) (post : Bool) (inv : Bool) (attrKind : AttributeKind) (prio : Nat) : MetaM Unit := do
673-
let simpThms ← withExporting (isExporting := attrKind != .local && !isPrivateName declName) do mkSimpTheoremFromConst declName post inv prio
700+
let simpThms ← withExporting (isExporting := attrKind != .local && !isPrivateName declName) do mkSimpTheoremFromConst declName post inv prio (checkLhs := true)
674701
for simpThm in simpThms do
675702
ext.add (SimpEntry.thm simpThm) attrKind
676703

stage0/src/stdlib_flags.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include "util/options.h"
2-
2+
// please update stage0
33
namespace lean {
44
options get_default_options() {
55
options opts;

tests/elab/simpVarHead.lean

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
set_option warning.simp.varHead true
2+
3+
section
4+
theorem broken1 (x : Nat) : x = x + 0 := by simp
5+
6+
/--
7+
warning: Left-hand side of simp theorem has a variable as head symbol. This means the theorem will be tried on every simp step, which can be expensive. This may be acceptable for `local` or `scoped` simp lemmas.
8+
Use `set_option warning.simp.varHead false` to disable this warning.
9+
-/
10+
#guard_msgs in
11+
attribute [local simp] broken1
12+
end
13+
14+
section
15+
theorem broken2 (x : Nat) : x + 0 = x := by simp
16+
17+
-- Works in the usual direction
18+
attribute [local simp] broken2
19+
20+
-- Breaks in the other direction
21+
/--
22+
warning: Left-hand side of simp theorem has a variable as head symbol. This means the theorem will be tried on every simp step, which can be expensive. This may be acceptable for `local` or `scoped` simp lemmas.
23+
Use `set_option warning.simp.varHead false` to disable this warning.
24+
-/
25+
#guard_msgs in
26+
attribute [local simp ←] broken2
27+
end
28+
29+
theorem broken3 (x : Nat → Nat) : x 0 = x 0 + 0 := by simp
30+
31+
/--
32+
warning: Left-hand side of simp theorem has a variable as head symbol. This means the theorem will be tried on every simp step, which can be expensive. This may be acceptable for `local` or `scoped` simp lemmas.
33+
Use `set_option warning.simp.varHead false` to disable this warning.
34+
-/
35+
#guard_msgs in
36+
attribute [simp] broken3
37+
38+
theorem broken4 (x : Nat → Nat) : x 0 + 0 = x 0 := by rfl
39+
40+
/--
41+
warning: Left-hand side of simp theorem has a variable as head symbol. This means the theorem will be tried on every simp step, which can be expensive. This may be acceptable for `local` or `scoped` simp lemmas.
42+
Use `set_option warning.simp.varHead false` to disable this warning.
43+
-/
44+
#guard_msgs in
45+
attribute [simp ←] broken4
46+
47+
section
48+
/--
49+
warning: Left-hand side of simp theorem has a variable as head symbol. This means the theorem will be tried on every simp step, which can be expensive. This may be acceptable for `local` or `scoped` simp lemmas.
50+
Use `set_option warning.simp.varHead false` to disable this warning.
51+
-/
52+
#guard_msgs in
53+
@[local simp] theorem broken5 (x : Prop) : x ↔ x ∧ True := by simp
54+
end
55+
56+
theorem broken6 (x : PropProp) : x False ∧ True ↔ x False := by simp
57+
58+
/--
59+
warning: Left-hand side of simp theorem has a variable as head symbol. This means the theorem will be tried on every simp step, which can be expensive. This may be acceptable for `local` or `scoped` simp lemmas.
60+
Use `set_option warning.simp.varHead false` to disable this warning.
61+
-/
62+
#guard_msgs in
63+
attribute [simp ←] broken6
64+
65+
-- Abbrev as head symbol should not trigger the warning (mathlib false positive regression test)
66+
structure Foo where
67+
val : Nat
68+
69+
abbrev Foo.get (f : Foo) : Nat := f.val
70+
71+
theorem Foo.get_mk (n : Nat) : (Foo.mk n).get = n := rfl
72+
73+
#guard_msgs in
74+
attribute [simp] Foo.get_mk
75+
76+
-- `.other` head key: lambda as LHS head
77+
theorem broken8 : (fun x : Nat => x + 0) = (fun x => x) := by ext; omega
78+
79+
/--
80+
warning: Left-hand side of simp theorem is headed by a `.other` key in the discrimination tree (e.g. because it is a lambda expression). This theorem will be tried against all expressions that also have a `.other` key as head, which can cause slowdowns. This may be acceptable for `local` or `scoped` simp lemmas.
81+
Use `set_option warning.simp.otherHead false` to disable this warning.
82+
-/
83+
#guard_msgs in
84+
attribute [simp] broken8
85+
86+
-- Option to disable the `.other` head warning
87+
section
88+
#guard_msgs in
89+
set_option warning.simp.otherHead false in
90+
attribute [local simp] broken8
91+
end
92+
93+
-- Option to disable the warning
94+
section
95+
theorem broken7 (x : Nat) : x = x + 0 := by omega
96+
97+
#guard_msgs in
98+
set_option warning.simp.varHead false in
99+
attribute [local simp] broken7
100+
end

0 commit comments

Comments
 (0)