Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 35 additions & 8 deletions src/Lean/Meta/Tactic/Simp/SimpTheorems.lean
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@ register_builtin_option debug.tactic.simp.checkDefEqAttr : Bool := {
of the `defeq` attribute, and warn if it was. Note that this is a costly check."
}

register_builtin_option warning.simp.varHead : Bool := {
defValue := true
descr := "If true, warns when the head symbol of the left-hand side of a `@[simp]` theorem \
is a variable. Such lemmas are tried on every simp step, which can be slow."
}

register_builtin_option warning.simp.otherHead : Bool := {
defValue := true
descr := "If true, warns when the left-hand side of a `@[simp]` theorem is headed by a \
`.other` key in the discrimination tree (e.g. a lambda expression). Such lemmas can cause slowdowns."
}

/--
An `Origin` is an identifier for simp theorems which indicates roughly
what action the user took which lead to this theorem existing in the simp set.
Expand Down Expand Up @@ -359,12 +371,27 @@ private def checkTypeIsProp (type : Expr) : MetaM Unit :=
unless (← isProp type) do
throwError "Invalid simp theorem: Expected a proposition, but found{indentExpr type}"

private def mkSimpTheoremKeys (type : Expr) (noIndexAtArgs : Bool) : MetaM (Array SimpTheoremKey × Bool) := do
private def mkSimpTheoremKeys (type : Expr) (noIndexAtArgs : Bool) (checkLhs : Bool := false) : MetaM (Array SimpTheoremKey × Bool) := do
withNewMCtxDepth do
let (_, _, type) ← forallMetaTelescopeReducing type
let type ← whnfR type
match type.eq? with
| some (_, lhs, rhs) => pure (← DiscrTree.mkPath lhs noIndexAtArgs, ← isPerm lhs rhs)
| some (_, lhs, rhs) =>
let keys ← DiscrTree.mkPath lhs noIndexAtArgs
if checkLhs then
if warning.simp.varHead.get (← getOptions) && keys[0]? == some .star then
logWarning m!"Left-hand side of simp theorem has a variable as head symbol. \
This means the theorem will be tried on every simp step, which can be expensive. \
This may be acceptable for `local` or `scoped` simp lemmas.\n\
Use `set_option warning.simp.varHead false` to disable this warning."
if warning.simp.otherHead.get (← getOptions) && keys[0]? == some .other then
logWarning m!"Left-hand side of simp theorem is headed by a `.other` key in the \
discrimination tree (e.g. because it is a lambda expression). \
This theorem will be tried against all expressions that also have a `.other` key as head, \
which can cause slowdowns. \
This may be acceptable for `local` or `scoped` simp lemmas.\n\
Use `set_option warning.simp.otherHead false` to disable this warning."
pure (keys, ← isPerm lhs rhs)
| none => throwError "Unexpected kind of simp theorem{indentExpr type}"

register_builtin_option simp.rfl.checkTransparency: Bool := {
Expand All @@ -373,10 +400,10 @@ register_builtin_option simp.rfl.checkTransparency: Bool := {
}

private def mkSimpTheoremCore (origin : Origin) (e : Expr) (levelParams : Array Name) (proof : Expr)
(post : Bool) (prio : Nat) (noIndexAtArgs : Bool) : MetaM SimpTheorem := do
(post : Bool) (prio : Nat) (noIndexAtArgs : Bool) (checkLhs : Bool := false) : MetaM SimpTheorem := do
assert! origin != .fvar ⟨.anonymous⟩
let type ← instantiateMVars (← inferType e)
let (keys, perm) ← mkSimpTheoremKeys type noIndexAtArgs
let (keys, perm) ← mkSimpTheoremKeys type noIndexAtArgs checkLhs
let rfl ← isRflProof proof
if rfl && simp.rfl.checkTransparency.get (← getOptions) then
forallTelescopeReducing type fun _ type => do
Expand All @@ -399,7 +426,7 @@ Creates a `SimpTheorem` from a global theorem.
Because some theorems lead to multiple `SimpTheorems` (in particular conjunctions), returns an array.
-/
def mkSimpTheoremFromConst (declName : Name) (post := true) (inv := false)
(prio : Nat := eval_prio default) : MetaM (Array SimpTheorem) := do
(prio : Nat := eval_prio default) (checkLhs : Bool := false) : MetaM (Array SimpTheorem) := do
let cinfo ← getConstVal declName
let us := cinfo.levelParams.map mkLevelParam
let origin := .decl declName post inv
Expand All @@ -413,10 +440,10 @@ def mkSimpTheoremFromConst (declName : Name) (post := true) (inv := false)
let auxName ← mkAuxLemma (kind? := `_simp) cinfo.levelParams type val (inferRfl := true)
(forceExpose := true) -- These kinds of theorems are small and `to_additive` may need to
-- unfold them.
r := r.push <| (← do mkSimpTheoremCore origin (mkConst auxName us) #[] (mkConst auxName) post prio (noIndexAtArgs := false))
r := r.push <| (← do mkSimpTheoremCore origin (mkConst auxName us) #[] (mkConst auxName) post prio (noIndexAtArgs := false) (checkLhs := checkLhs))
return r
else
return #[← withoutExporting do mkSimpTheoremCore origin (mkConst declName us) #[] (mkConst declName) post prio (noIndexAtArgs := false)]
return #[← withoutExporting do mkSimpTheoremCore origin (mkConst declName us) #[] (mkConst declName) post prio (noIndexAtArgs := false) (checkLhs := checkLhs)]

def SimpTheorem.getValue (simpThm : SimpTheorem) : MetaM Expr := do
if simpThm.proof.isConst && simpThm.levelParams.isEmpty then
Expand Down Expand Up @@ -670,7 +697,7 @@ def SimpExtension.getTheorems (ext : SimpExtension) : CoreM SimpTheorems :=
Adds a simp theorem to a simp extension
-/
def addSimpTheorem (ext : SimpExtension) (declName : Name) (post : Bool) (inv : Bool) (attrKind : AttributeKind) (prio : Nat) : MetaM Unit := do
let simpThms ← withExporting (isExporting := attrKind != .local && !isPrivateName declName) do mkSimpTheoremFromConst declName post inv prio
let simpThms ← withExporting (isExporting := attrKind != .local && !isPrivateName declName) do mkSimpTheoremFromConst declName post inv prio (checkLhs := true)
for simpThm in simpThms do
ext.add (SimpEntry.thm simpThm) attrKind

Expand Down
98 changes: 98 additions & 0 deletions tests/elab/simpVarHead.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
section
theorem broken1 (x : Nat) : x = x + 0 := by simp

/--
warning: Left-hand side of simp theorem has a variable as head symbol. This means the theorem will be tried on every simp step, which can be expensive. This may be acceptable for `local` or `scoped` simp lemmas.
Use `set_option warning.simp.varHead false` to disable this warning.
-/
#guard_msgs in
attribute [local simp] broken1
end

section
theorem broken2 (x : Nat) : x + 0 = x := by simp

-- Works in the usual direction
attribute [local simp] broken2

-- Breaks in the other direction
/--
warning: Left-hand side of simp theorem has a variable as head symbol. This means the theorem will be tried on every simp step, which can be expensive. This may be acceptable for `local` or `scoped` simp lemmas.
Use `set_option warning.simp.varHead false` to disable this warning.
-/
#guard_msgs in
attribute [local simp ←] broken2
end

theorem broken3 (x : Nat → Nat) : x 0 = x 0 + 0 := by simp

/--
warning: Left-hand side of simp theorem has a variable as head symbol. This means the theorem will be tried on every simp step, which can be expensive. This may be acceptable for `local` or `scoped` simp lemmas.
Use `set_option warning.simp.varHead false` to disable this warning.
-/
#guard_msgs in
attribute [simp] broken3

theorem broken4 (x : Nat → Nat) : x 0 + 0 = x 0 := by rfl

/--
warning: Left-hand side of simp theorem has a variable as head symbol. This means the theorem will be tried on every simp step, which can be expensive. This may be acceptable for `local` or `scoped` simp lemmas.
Use `set_option warning.simp.varHead false` to disable this warning.
-/
#guard_msgs in
attribute [simp ←] broken4

section
/--
warning: Left-hand side of simp theorem has a variable as head symbol. This means the theorem will be tried on every simp step, which can be expensive. This may be acceptable for `local` or `scoped` simp lemmas.
Use `set_option warning.simp.varHead false` to disable this warning.
-/
#guard_msgs in
@[local simp] theorem broken5 (x : Prop) : x ↔ x ∧ True := by simp
end

theorem broken6 (x : Prop → Prop) : x False ∧ True ↔ x False := by simp

/--
warning: Left-hand side of simp theorem has a variable as head symbol. This means the theorem will be tried on every simp step, which can be expensive. This may be acceptable for `local` or `scoped` simp lemmas.
Use `set_option warning.simp.varHead false` to disable this warning.
-/
#guard_msgs in
attribute [simp ←] broken6

-- Abbrev as head symbol should not trigger the warning (mathlib false positive regression test)
structure Foo where
val : Nat

abbrev Foo.get (f : Foo) : Nat := f.val

theorem Foo.get_mk (n : Nat) : (Foo.mk n).get = n := rfl

#guard_msgs in
attribute [simp] Foo.get_mk

-- `.other` head key: lambda as LHS head
theorem broken8 : (fun x : Nat => x + 0) = (fun x => x) := by ext; omega

/--
warning: Left-hand side of simp theorem is headed by a `.other` key in the discrimination tree (e.g. because it is a lambda expression). This theorem will be tried against all expressions that also have a `.other` key as head, which can cause slowdowns. This may be acceptable for `local` or `scoped` simp lemmas.
Use `set_option warning.simp.otherHead false` to disable this warning.
-/
#guard_msgs in
attribute [simp] broken8

-- Option to disable the `.other` head warning
section
#guard_msgs in
set_option warning.simp.otherHead false in
attribute [local simp] broken8
end

-- Option to disable the warning
section
theorem broken7 (x : Nat) : x = x + 0 := by omega

#guard_msgs in
set_option warning.simp.varHead false in
attribute [local simp] broken7
end
Loading