Skip to content

Commit bc7c02b

Browse files
authored
Merge pull request #310 from MilesCranmer/dimensionless-constants
Add option to force dimensionless constants
2 parents 1203a81 + 571f85b commit bc7c02b

7 files changed

Lines changed: 95 additions & 12 deletions

File tree

docs/src/examples.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,9 @@ which can cancel out other units in the expression.) For example,
225225
would indicate that the expression is dimensionally consistent, with
226226
a constant `"2.6353e-22[m s⁻²]"`.
227227

228+
Note that you can also restrict the search to dimensionless constants by setting
229+
`dimensionless_constants_only` to `true`.
230+
228231
## 7. Additional features
229232

230233
For the many other features available in SymbolicRegression.jl,

src/DimensionalAnalysis.jl

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,18 @@ end
116116

117117
# Define dimensionally-aware evaluation routine:
118118
@inline function deg0_eval(
119-
x::AbstractVector{T}, x_units::Vector{Q}, t::AbstractExpressionNode{T}
119+
x::AbstractVector{T},
120+
x_units::Vector{Q},
121+
t::AbstractExpressionNode{T},
122+
allow_wildcards::Bool,
120123
) where {T,R,Q<:AbstractQuantity{T,R}}
121-
t.constant && return WildcardQuantity{Q}(Quantity(t.val, R), true, false)
122-
return WildcardQuantity{Q}(
123-
(@inbounds x[t.feature]) * (@inbounds x_units[t.feature]), false, false
124-
)
124+
if t.constant
125+
return WildcardQuantity{Q}(Quantity(t.val, R), allow_wildcards, false)
126+
else
127+
return WildcardQuantity{Q}(
128+
(@inbounds x[t.feature]) * (@inbounds x_units[t.feature]), false, false
129+
)
130+
end
125131
end
126132
@inline function deg1_eval(
127133
op::F, l::W
@@ -149,16 +155,26 @@ end
149155
end
150156

151157
function violates_dimensional_constraints_dispatch(
152-
tree::AbstractExpressionNode{T}, x_units::Vector{Q}, x::AbstractVector{T}, operators
158+
tree::AbstractExpressionNode{T},
159+
x_units::Vector{Q},
160+
x::AbstractVector{T},
161+
operators,
162+
allow_wildcards,
153163
) where {T,Q<:AbstractQuantity{T}}
154164
if tree.degree == 0
155-
return deg0_eval(x, x_units, tree)::WildcardQuantity{Q}
165+
return deg0_eval(x, x_units, tree, allow_wildcards)::WildcardQuantity{Q}
156166
elseif tree.degree == 1
157-
l = violates_dimensional_constraints_dispatch(tree.l, x_units, x, operators)
167+
l = violates_dimensional_constraints_dispatch(
168+
tree.l, x_units, x, operators, allow_wildcards
169+
)
158170
return deg1_eval((@inbounds operators.unaops[tree.op]), l)::WildcardQuantity{Q}
159171
else
160-
l = violates_dimensional_constraints_dispatch(tree.l, x_units, x, operators)
161-
r = violates_dimensional_constraints_dispatch(tree.r, x_units, x, operators)
172+
l = violates_dimensional_constraints_dispatch(
173+
tree.l, x_units, x, operators, allow_wildcards
174+
)
175+
r = violates_dimensional_constraints_dispatch(
176+
tree.r, x_units, x, operators, allow_wildcards
177+
)
162178
return deg2_eval((@inbounds operators.binops[tree.op]), l, r)::WildcardQuantity{Q}
163179
end
164180
end
@@ -186,8 +202,9 @@ function violates_dimensional_constraints(
186202
if X_units === nothing && y_units === nothing
187203
return false
188204
end
205+
allow_wildcards = !(options.dimensionless_constants_only)
189206
dimensional_output = violates_dimensional_constraints_dispatch(
190-
tree, X_units, x, options.operators
207+
tree, X_units, x, options.operators, allow_wildcards
191208
)
192209
# ^ Eventually do this with map_treereduce. However, right now it seems
193210
# like we are passing around too many arguments, which slows things down.

src/InterfaceDynamicExpressions.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,11 @@ Convert an equation to a string.
170170
tree,
171171
options.operators;
172172
f_variable=(feature, vname) -> string_variable(feature, vname, X_sym_units),
173-
f_constant=(val,) -> string_constant(val, vprecision, WILDCARD_UNIT_STRING),
173+
f_constant=let
174+
unit_placeholder =
175+
options.dimensionless_constants_only ? "" : WILDCARD_UNIT_STRING
176+
(val,) -> string_constant(val, vprecision, unit_placeholder)
177+
end,
174178
variable_names=display_variable_names,
175179
kws...,
176180
)

src/Options.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,8 @@ const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators
277277
punished.
278278
- `dimensional_constraint_penalty`: An additive factor if the dimensional
279279
constraint is violated.
280+
- `dimensionless_constants_only`: Whether to only allow dimensionless
281+
constants.
280282
- `use_frequency`: Whether to use a parsimony that adapts to the
281283
relative proportion of equations at each complexity; this will
282284
ensure that there are a balanced number of equations considered
@@ -387,6 +389,7 @@ function Options end
387389
complexity_of_variables::Union{Nothing,Real}=nothing,
388390
parsimony::Real=0.0032,
389391
dimensional_constraint_penalty::Union{Nothing,Real}=nothing,
392+
dimensionless_constants_only::Bool=false,
390393
alpha::Real=0.100000,
391394
maxsize::Integer=20,
392395
maxdepth::Union{Nothing,Integer}=nothing,
@@ -780,6 +783,7 @@ function Options end
780783
tournament_selection_weights,
781784
parsimony,
782785
dimensional_constraint_penalty,
786+
dimensionless_constants_only,
783787
alpha,
784788
maxsize,
785789
maxdepth,

src/OptionsStruct.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ struct Options{
5959
tournament_selection_weights::W
6060
parsimony::Float32
6161
dimensional_constraint_penalty::Union{Float32,Nothing}
62+
dimensionless_constants_only::Bool
6263
alpha::Float32
6364
maxsize::Int
6465
maxdepth::Int

test/test_units.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ using DynamicQuantities:
3333
using Test
3434
using MLJBase: MLJBase as MLJ
3535
using MLJModelInterface: MLJModelInterface as MMI
36+
include("utils.jl")
3637

3738
custom_op(x, y) = x + y
3839

@@ -369,6 +370,57 @@ end
369370
X_sym_units=dataset2.X_sym_units,
370371
y_sym_units=dataset2.y_sym_units,
371372
) == "x₅[5.0 m] * 3.2[?]"
373+
374+
# With dimensionless_constants_only, it will not print the [?]:
375+
options = Options(;
376+
binary_operators=[+, -, *, /],
377+
unary_operators=[cos, sin],
378+
dimensionless_constants_only=true,
379+
)
380+
@test string_tree(
381+
x5 * 3.2,
382+
options;
383+
raw=false,
384+
display_variable_names=dataset2.display_variable_names,
385+
X_sym_units=dataset2.X_sym_units,
386+
y_sym_units=dataset2.y_sym_units,
387+
) == "x₅[5.0 m] * 3.2"
388+
end
389+
390+
@testset "Dimensionless constants" begin
391+
options = Options(;
392+
binary_operators=[+, -, *, /, square, cube],
393+
unary_operators=[cos, sin],
394+
dimensionless_constants_only=true,
395+
)
396+
X = randn(5, 64)
397+
y = randn(64)
398+
dataset = Dataset(X, y; X_units=["m^3", "km/s", "kg", "hr", "1"], y_units="kg")
399+
x1, x2, x3, x4, x5 = [Node(Float64; feature=i) for i in 1:5]
400+
401+
dimensionally_valid_equations = [
402+
1.5 * x1 / (cube(x2) * cube(x4)) * x3, x3, (square(x3) / x3) + x3
403+
]
404+
for tree in dimensionally_valid_equations
405+
onfail(@test !violates_dimensional_constraints(tree, dataset, options)) do
406+
@warn "Failed on" tree
407+
end
408+
end
409+
dimensionally_invalid_equations = [Node(Float64; val=1.5), 1.5 * x1, x3 - 1.0 * x1]
410+
for tree in dimensionally_invalid_equations
411+
onfail(@test violates_dimensional_constraints(tree, dataset, options)) do
412+
@warn "Failed on" tree
413+
end
414+
end
415+
# But, all of these would be fine if we allow constants with dimensions:
416+
let
417+
options = Options(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])
418+
for tree in dimensionally_invalid_equations
419+
onfail(@test !violates_dimensional_constraints(tree, dataset, options)) do
420+
@warn "Failed on" tree
421+
end
422+
end
423+
end
372424
end
373425

374426
@testset "Miscellaneous" begin

test/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
onfail(f, ::Test.Fail) = f()
2+
onfail(_, ::Test.Pass) = nothing

0 commit comments

Comments
 (0)