Skip to content

Commit 87eb325

Browse files
authored
Merge pull request #226 from MilesCranmer/mlj-integration
MLJ Integration
2 parents 6e08d42 + be60f13 commit 87eb325

11 files changed

Lines changed: 841 additions & 100 deletions

Project.toml

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
name = "SymbolicRegression"
22
uuid = "8254be44-1295-4e6a-a16d-46603ac705cb"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.19.1"
4+
version = "0.20.0"
55

66
[deps]
77
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
@@ -10,6 +10,8 @@ DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
1010
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
1111
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
1212
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
13+
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
14+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1315
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1416
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1517
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
@@ -33,6 +35,8 @@ DynamicExpressions = "0.9"
3335
JSON3 = "1"
3436
LineSearches = "7"
3537
LossFunctions = "0.6, 0.7, 0.8, 0.10"
38+
MLJModelInterface = "1.5, 1.6, 1.7, 1.8"
39+
MacroTools = "0.4, 0.5"
3640
Optim = "0.19, 1.1"
3741
Pkg = "1"
3842
PrecompileTools = "1"
@@ -47,10 +51,12 @@ julia = "1.6"
4751
[extras]
4852
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
4953
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
54+
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
55+
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
5056
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
5157
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
5258
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5359
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5460

5561
[targets]
56-
test = ["Test", "SafeTestsets", "ForwardDiff", "LinearAlgebra", "SymbolicUtils", "Zygote"]
62+
test = ["Test", "SafeTestsets", "ForwardDiff", "LinearAlgebra", "MLJBase", "MLJTestInterface", "SymbolicUtils", "Zygote"]

src/HallOfFame.jl

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -161,4 +161,44 @@ function string_dominating_pareto_curve(
161161
return output
162162
end
163163

164+
"""
    format_hall_of_fame(hof::HallOfFame{T,L}, options, baseline_loss::L)

Summarize the Pareto frontier of `hof` as a named tuple of parallel vectors.

The frontier is obtained from `calculate_pareto_frontier(hof)`. For each member,
in frontier order, the result holds its `tree`, its `loss`, its complexity (from
`compute_complexity(member, options)`), and a score.

The score of a member is `-log(abs(loss / prev_loss) + 1e-10) / (complexity -
prev_complexity)`, where `prev_loss` and `prev_complexity` belong to the
preceding frontier member. For the first member they are `baseline_loss` and `0`.

# Returns
A named tuple `(; trees, scores, losses, complexities)`.
"""
function format_hall_of_fame(
    hof::HallOfFame{T,L}, options, baseline_loss::L
) where {T<:DATA_TYPE,L<:LOSS_TYPE}
    # Small offset added inside the logarithm so a zero loss ratio does not give log(0).
    ZERO_POINT = L(1e-10)

    frontier = calculate_pareto_frontier(hof)

    trees = [entry.tree for entry in frontier]
    losses = [entry.loss for entry in frontier]
    complexities = [compute_complexity(entry, options) for entry in frontier]
    scores = Vector{L}(undef, length(frontier))

    # Each score compares a member against the previous one on the frontier;
    # the very first member is compared against the baseline at complexity 0.
    prev_loss = baseline_loss
    prev_complexity = 0
    for (i, (loss, complexity)) in enumerate(zip(losses, complexities))
        log_loss_ratio = log(abs(loss / prev_loss) + ZERO_POINT)
        complexity_step = complexity - prev_complexity

        scores[i] = -log_loss_ratio / complexity_step

        prev_loss = loss
        prev_complexity = complexity
    end

    return (; trees, scores, losses, complexities)
end
191+
"""
    format_hall_of_fame(hof::AbstractVector{<:HallOfFame}, options, baseline_loss)

Apply `format_hall_of_fame` to every hall of fame in `hof`, passing the same
`options` and `baseline_loss` to each call, and regroup the results by field.

# Returns
A named tuple `(; trees, scores, losses, complexities)` in which each field is a
vector with one entry per element of `hof`; each entry is the corresponding
field of that element's individual result.
"""
function format_hall_of_fame(
    hof::AH, options, baseline_loss
) where {T,L,H<:HallOfFame{T,L},AH<:AbstractVector{H}}
    per_output = [format_hall_of_fame(single, options, baseline_loss) for single in hof]

    # Regroup from "one result per hall of fame" to "one vector per field".
    trees = [res.trees for res in per_output]
    scores = [res.scores for res in per_output]
    losses = [res.losses for res in per_output]
    complexities = [res.complexities for res in per_output]

    return (; trees, scores, losses, complexities)
end
202+
# TODO: Re-use this in `string_dominating_pareto_curve`
203+
164204
end

0 commit comments

Comments
 (0)