Skip to content

Commit 43c998d

Browse files
committed
merge resolutions
1 parent f3f5f15 commit 43c998d

File tree

11 files changed

+872
-59
lines changed

11 files changed

+872
-59
lines changed

thejoker/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
"phase_coverage_per_period",
3232
]
3333

34+
# SB2:
35+
from .thejoker_sb2 import *
36+
from .prior_sb2 import JokerSB2Prior
37+
3438

3539
__bibtex__ = __citation__ = """@ARTICLE{thejoker,
3640
author = {{Price-Whelan}, Adrian M. and {Hogg}, David W. and
@@ -55,3 +59,12 @@
5559
adsnote = {Provided by the SAO/NASA Astrophysics Data System}
5660
}
5761
"""
62+
63+
__all__ = [
64+
'TheJoker',
65+
'RVData',
66+
'JokerSamples',
67+
'JokerPrior',
68+
'plot_rv_curves',
69+
'TheJokerSB2'
70+
]

thejoker/data.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,13 @@ class RVData:
3333
(days). Set to ``False`` to disable subtracting the reference time.
3434
clean : bool (optional)
3535
Filter out any NaN or Inf data points.
36+
sort : bool (optional)
37+
Whether or not to sort on time.
3638
3739
"""
3840

3941
@u.quantity_input(rv=u.km / u.s, rv_err=[u.km / u.s, (u.km / u.s) ** 2])
40-
def __init__(self, t, rv, rv_err, t_ref=None, clean=True):
42+
def __init__(self, t, rv, rv_err, t_ref=None, clean=True, sort=True):
4143
# For speed, time is saved internally as BMJD:
4244
if isinstance(t, Time):
4345
_t_bmjd = t.tcb.mjd
@@ -94,15 +96,16 @@ def __init__(self, t, rv, rv_err, t_ref=None, clean=True):
9496
else:
9597
self.rv_err = self.rv_err[idx]
9698

97-
# sort on times
98-
idx = self._t_bmjd.argsort()
99-
self._t_bmjd = self._t_bmjd[idx]
100-
self.rv = self.rv[idx]
101-
if self._has_cov:
102-
self.rv_err = self.rv_err[idx]
103-
self.rv_err = self.rv_err[:, idx]
104-
else:
105-
self.rv_err = self.rv_err[idx]
99+
if sort:
100+
# sort on times
101+
idx = self._t_bmjd.argsort()
102+
self._t_bmjd = self._t_bmjd[idx]
103+
self.rv = self.rv[idx]
104+
if self._has_cov:
105+
self.rv_err = self.rv_err[idx]
106+
self.rv_err = self.rv_err[:, idx]
107+
else:
108+
self.rv_err = self.rv_err[idx]
106109

107110
if t_ref is False:
108111
self.t_ref = None

thejoker/likelihood_helpers.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,13 @@ def marginal_ln_likelihood_inmem(joker_helper, prior_samples_batch):
6666
return np.array(ll)
6767

6868

69-
def make_full_samples_inmem(joker_helper, prior_samples_batch, rng, n_linear_samples=1):
70-
from .samples import JokerSamples
69+
def make_full_samples_inmem(
70+
joker_helper, prior_samples_batch, rng, n_linear_samples=1, SamplesCls=None
71+
):
72+
if SamplesCls is None:
73+
from .samples import JokerSamples
74+
75+
SamplesCls = JokerSamples
7176

7277
if prior_samples_batch.dtype != np.float64:
7378
prior_samples_batch = prior_samples_batch.astype(np.float64)
@@ -77,7 +82,7 @@ def make_full_samples_inmem(joker_helper, prior_samples_batch, rng, n_linear_sam
7782
)
7883

7984
# unpack the raw samples
80-
samples = JokerSamples.unpack(
85+
samples = SamplesCls.unpack(
8186
raw_samples,
8287
joker_helper.internal_units,
8388
t_ref=joker_helper.data.t_ref,
@@ -96,6 +101,7 @@ def rejection_sample_inmem(
96101
max_posterior_samples=None,
97102
n_linear_samples=1,
98103
return_all_logprobs=False,
104+
SamplesCls=None,
99105
):
100106
if max_posterior_samples is None:
101107
max_posterior_samples = len(prior_samples_batch)
@@ -114,6 +120,7 @@ def rejection_sample_inmem(
114120
prior_samples_batch[good_samples_idx],
115121
rng,
116122
n_linear_samples=n_linear_samples,
123+
SamplesCls=SamplesCls,
117124
)
118125

119126
if ln_prior is not None and ln_prior is not False:
@@ -136,6 +143,7 @@ def iterative_rejection_inmem(
136143
init_batch_size=None,
137144
growth_factor=128,
138145
n_linear_samples=1,
146+
SamplesCls=None,
139147
):
140148
n_total_samples = len(prior_samples_batch)
141149

@@ -219,6 +227,7 @@ def iterative_rejection_inmem(
219227
prior_samples_batch[full_samples_idx],
220228
rng,
221229
n_linear_samples=n_linear_samples,
230+
SamplesCls=SamplesCls,
222231
)
223232

224233
# FIXME: copy-pasted from function above

thejoker/multiproc_helpers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def make_full_samples(
155155
samples_idx,
156156
n_linear_samples=1,
157157
n_batches=None,
158+
SamplesCls=JokerSamples,
158159
):
159160
task_args = (prior_samples_file, joker_helper, n_linear_samples)
160161
results = run_worker(
@@ -164,14 +165,14 @@ def make_full_samples(
164165
task_args=task_args,
165166
n_batches=n_batches,
166167
samples_idx=samples_idx,
167-
rng=rng,
168+
random_state=rng,
168169
)
169170

170171
# Concatenate all of the raw samples arrays
171172
raw_samples = np.concatenate(results)
172173

173174
# unpack the raw samples
174-
samples = JokerSamples.unpack(
175+
samples = SamplesCls.unpack(
175176
raw_samples,
176177
joker_helper.internal_units,
177178
t_ref=joker_helper.data.t_ref,
@@ -195,6 +196,7 @@ def rejection_sample_helper(
195196
n_batches=None,
196197
randomize_prior_order=False,
197198
return_all_logprobs=False,
199+
SamplesCls=None,
198200
):
199201
# Total number of samples in the cache:
200202
with tb.open_file(prior_samples_file, mode="r") as f:
@@ -271,6 +273,7 @@ def rejection_sample_helper(
271273
full_samples_idx,
272274
n_linear_samples=n_linear_samples,
273275
n_batches=n_batches,
276+
SamplesCls=SamplesCls,
274277
)
275278

276279
if return_logprobs:
@@ -300,6 +303,7 @@ def iterative_rejection_helper(
300303
return_logprobs=False,
301304
n_batches=None,
302305
randomize_prior_order=False,
306+
SamplesCls=None,
303307
):
304308
# Total number of samples in the cache:
305309
with tb.open_file(prior_samples_file, mode="r") as f:
@@ -412,6 +416,7 @@ def iterative_rejection_helper(
412416
full_samples_idx,
413417
n_linear_samples=n_linear_samples,
414418
n_batches=n_batches,
419+
SamplesCls=SamplesCls,
415420
)
416421

417422
# FIXME: copy-pasted from function above

thejoker/prior.py

Lines changed: 64 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ def _validate_model(model):
3838

3939

4040
class JokerPrior:
41+
_sb2 = False
42+
4143
def __init__(self, pars=None, poly_trend=1, v0_offsets=None, model=None):
4244
"""
4345
This class controls the prior probability distributions for the
@@ -121,7 +123,9 @@ def __init__(self, pars=None, poly_trend=1, v0_offsets=None, model=None):
121123
# are only used to validate that the units for each parameter are
122124
# equivalent to these
123125
self._nonlinear_equiv_units = get_nonlinear_equiv_units()
124-
self._linear_equiv_units = get_linear_equiv_units(self.poly_trend)
126+
self._linear_equiv_units = get_linear_equiv_units(
127+
self.poly_trend, sb2=self._sb2
128+
)
125129
self._v0_offsets_equiv_units = get_v0_offsets_equiv_units(self.n_offsets)
126130
self._all_par_unit_equiv = {
127131
**self._nonlinear_equiv_units,
@@ -291,10 +295,7 @@ def __repr__(self):
291295
def __str__(self):
292296
return ", ".join(self.par_names)
293297

294-
@deprecated_renamed_argument(
295-
"random_state", "rng", since="v1.3", warning_type=DeprecationWarning
296-
)
297-
def sample(
298+
def _get_raw_samples(
298299
self,
299300
size=1,
300301
generate_linear=False,
@@ -303,29 +304,6 @@ def sample(
303304
dtype=None,
304305
**kwargs,
305306
):
306-
"""
307-
Generate random samples from the prior.
308-
309-
Parameters
310-
----------
311-
size : int (optional)
312-
The number of samples to generate.
313-
generate_linear : bool (optional)
314-
Also generate samples in the linear parameters.
315-
return_logprobs : bool (optional)
316-
Generate the log-prior probability at the position of each sample.
317-
**kwargs
318-
Additional keyword arguments are passed to the
319-
`~thejoker.JokerSamples` initializer.
320-
321-
Returns
322-
-------
323-
samples : `thejoker.Jokersamples`
324-
The random samples.
325-
326-
"""
327-
from .samples import JokerSamples
328-
329307
if dtype is None:
330308
dtype = np.float64
331309

@@ -339,11 +317,6 @@ def sample(
339317
)
340318
}
341319

342-
if generate_linear:
343-
par_names = self.par_names
344-
else:
345-
par_names = list(self._nonlinear_equiv_units.keys())
346-
347320
# MAJOR HACK RELATED TO UPSTREAM ISSUES WITH pymc3:
348321
# init_shapes = {}
349322
# for name, par in sub_pars.items():
@@ -374,12 +347,68 @@ def sample(
374347

375348
logp.append(_logp)
376349
log_prior = np.sum(logp, axis=0)
350+
else:
351+
log_prior = None
377352

378353
# CONTINUED MAJOR HACK RELATED TO UPSTREAM ISSUES WITH pymc3:
379354
# for name, par in sub_pars.items():
380355
# if hasattr(par, "distribution"):
381356
# par.distribution.shape = init_shapes[name]
382357

358+
return raw_samples, sub_pars, log_prior
359+
360+
@deprecated_renamed_argument(
361+
"random_state", "rng", since="v1.3", warning_type=DeprecationWarning
362+
)
363+
def sample(
364+
self,
365+
size=1,
366+
generate_linear=False,
367+
return_logprobs=False,
368+
rng=None,
369+
dtype=None,
370+
**kwargs,
371+
):
372+
"""
373+
Generate random samples from the prior.
374+
375+
.. note::
376+
377+
Right now, generating samples with the prior values is slow (i.e.
378+
with ``return_logprobs=True``) because of pymc3 issues (see
379+
discussion here:
380+
https://discourse.pymc.io/t/draw-values-speed-scaling-with-transformed-variables/4076).
381+
This will hopefully be resolved in the future...
382+
383+
Parameters
384+
----------
385+
size : int (optional)
386+
The number of samples to generate.
387+
generate_linear : bool (optional)
388+
Also generate samples in the linear parameters.
389+
return_logprobs : bool (optional)
390+
Generate the log-prior probability at the position of each sample.
391+
**kwargs
392+
Additional keyword arguments are passed to the
393+
`~thejoker.JokerSamples` initializer.
394+
395+
Returns
396+
-------
397+
samples : `thejoker.Jokersamples`
398+
The random samples.
399+
400+
"""
401+
from thejoker.samples import JokerSamples
402+
403+
raw_samples, sub_pars, log_prior = self._get_raw_samples(
404+
size, generate_linear, return_logprobs, rng, dtype, **kwargs
405+
)
406+
407+
if generate_linear:
408+
par_names = self.par_names
409+
else:
410+
par_names = list(self._nonlinear_equiv_units.keys())
411+
383412
# Apply units if they are specified:
384413
prior_samples = JokerSamples(
385414
poly_trend=self.poly_trend, n_offsets=self.n_offsets, **kwargs
@@ -448,9 +477,8 @@ def default_nonlinear_prior(P_min=None, P_max=None, s=None, model=None, pars=Non
448477

449478
if isinstance(s, pt.TensorVariable):
450479
pars["s"] = pars.get("s", s)
451-
else:
452-
if not hasattr(s, "unit") or not s.unit.is_equivalent(u.km / u.s):
453-
raise u.UnitsError("Invalid unit for s: must be equivalent to km/s")
480+
elif not hasattr(s, "unit") or not s.unit.is_equivalent(u.km / u.s):
481+
raise u.UnitsError("Invalid unit for s: must be equivalent to km/s")
454482

455483
# dictionary of parameters to return
456484
out_pars = {}

thejoker/prior_helpers.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,19 @@ def validate_poly_trend(poly_trend):
3131
return poly_trend, vtrend_names
3232

3333

34-
def get_linear_equiv_units(poly_trend):
34+
def get_linear_equiv_units(poly_trend, sb2=False):
3535
poly_trend, v_names = validate_poly_trend(poly_trend)
36-
return {
37-
'K': u.m/u.s,
38-
**{name: u.m/u.s/u.day**i for i, name in enumerate(v_names)}
39-
}
36+
if sb2:
37+
return {
38+
'K1': u.m/u.s,
39+
'K2': u.m/u.s,
40+
**{name: u.m/u.s/u.day**i for i, name in enumerate(v_names)}
41+
}
42+
else:
43+
return {
44+
'K': u.m/u.s,
45+
**{name: u.m/u.s/u.day**i for i, name in enumerate(v_names)}
46+
}
4047

4148

4249
def validate_sigma_v(sigma_v, poly_trend, v_names):

0 commit comments

Comments
 (0)