@@ -38,6 +38,8 @@ def _validate_model(model):
3838
3939
4040class JokerPrior :
41+ _sb2 = False
42+
4143 def __init__ (self , pars = None , poly_trend = 1 , v0_offsets = None , model = None ):
4244 """
4345 This class controls the prior probability distributions for the
@@ -121,7 +123,9 @@ def __init__(self, pars=None, poly_trend=1, v0_offsets=None, model=None):
121123 # are only used to validate that the units for each parameter are
122124 # equivalent to these
123125 self ._nonlinear_equiv_units = get_nonlinear_equiv_units ()
124- self ._linear_equiv_units = get_linear_equiv_units (self .poly_trend )
126+ self ._linear_equiv_units = get_linear_equiv_units (
127+ self .poly_trend , sb2 = self ._sb2
128+ )
125129 self ._v0_offsets_equiv_units = get_v0_offsets_equiv_units (self .n_offsets )
126130 self ._all_par_unit_equiv = {
127131 ** self ._nonlinear_equiv_units ,
@@ -291,10 +295,7 @@ def __repr__(self):
291295 def __str__ (self ):
292296 return ", " .join (self .par_names )
293297
294- @deprecated_renamed_argument (
295- "random_state" , "rng" , since = "v1.3" , warning_type = DeprecationWarning
296- )
297- def sample (
298+ def _get_raw_samples (
298299 self ,
299300 size = 1 ,
300301 generate_linear = False ,
@@ -303,29 +304,6 @@ def sample(
303304 dtype = None ,
304305 ** kwargs ,
305306 ):
306- """
307- Generate random samples from the prior.
308-
309- Parameters
310- ----------
311- size : int (optional)
312- The number of samples to generate.
313- generate_linear : bool (optional)
314- Also generate samples in the linear parameters.
315- return_logprobs : bool (optional)
316- Generate the log-prior probability at the position of each sample.
317- **kwargs
318- Additional keyword arguments are passed to the
319- `~thejoker.JokerSamples` initializer.
320-
321- Returns
322- -------
323- samples : `thejoker.Jokersamples`
324- The random samples.
325-
326- """
327- from .samples import JokerSamples
328-
329307 if dtype is None :
330308 dtype = np .float64
331309
@@ -339,11 +317,6 @@ def sample(
339317 )
340318 }
341319
342- if generate_linear :
343- par_names = self .par_names
344- else :
345- par_names = list (self ._nonlinear_equiv_units .keys ())
346-
347320 # MAJOR HACK RELATED TO UPSTREAM ISSUES WITH pymc3:
348321 # init_shapes = {}
349322 # for name, par in sub_pars.items():
@@ -374,12 +347,68 @@ def sample(
374347
375348 logp .append (_logp )
376349 log_prior = np .sum (logp , axis = 0 )
350+ else :
351+ log_prior = None
377352
378353 # CONTINUED MAJOR HACK RELATED TO UPSTREAM ISSUES WITH pymc3:
379354 # for name, par in sub_pars.items():
380355 # if hasattr(par, "distribution"):
381356 # par.distribution.shape = init_shapes[name]
382357
358+ return raw_samples , sub_pars , log_prior
359+
360+ @deprecated_renamed_argument (
361+ "random_state" , "rng" , since = "v1.3" , warning_type = DeprecationWarning
362+ )
363+ def sample (
364+ self ,
365+ size = 1 ,
366+ generate_linear = False ,
367+ return_logprobs = False ,
368+ rng = None ,
369+ dtype = None ,
370+ ** kwargs ,
371+ ):
372+ """
373+ Generate random samples from the prior.
374+
375+ .. note::
376+
377+ Right now, generating samples with the prior values is slow (i.e.
378+ with ``return_logprobs=True``) because of pymc3 issues (see
379+ discussion here:
380+ https://discourse.pymc.io/t/draw-values-speed-scaling-with-transformed-variables/4076).
381+ This will hopefully be resolved in the future...
382+
383+ Parameters
384+ ----------
385+ size : int (optional)
386+ The number of samples to generate.
387+ generate_linear : bool (optional)
388+ Also generate samples in the linear parameters.
389+ return_logprobs : bool (optional)
390+ Generate the log-prior probability at the position of each sample.
391+ **kwargs
392+ Additional keyword arguments are passed to the
393+ `~thejoker.JokerSamples` initializer.
394+
395+ Returns
396+ -------
397+ samples : `thejoker.Jokersamples`
398+ The random samples.
399+
400+ """
401+ from thejoker .samples import JokerSamples
402+
403+ raw_samples , sub_pars , log_prior = self ._get_raw_samples (
404+ size , generate_linear , return_logprobs , rng , dtype , ** kwargs
405+ )
406+
407+ if generate_linear :
408+ par_names = self .par_names
409+ else :
410+ par_names = list (self ._nonlinear_equiv_units .keys ())
411+
383412 # Apply units if they are specified:
384413 prior_samples = JokerSamples (
385414 poly_trend = self .poly_trend , n_offsets = self .n_offsets , ** kwargs
@@ -448,9 +477,8 @@ def default_nonlinear_prior(P_min=None, P_max=None, s=None, model=None, pars=Non
448477
449478 if isinstance (s , pt .TensorVariable ):
450479 pars ["s" ] = pars .get ("s" , s )
451- else :
452- if not hasattr (s , "unit" ) or not s .unit .is_equivalent (u .km / u .s ):
453- raise u .UnitsError ("Invalid unit for s: must be equivalent to km/s" )
480+ elif not hasattr (s , "unit" ) or not s .unit .is_equivalent (u .km / u .s ):
481+ raise u .UnitsError ("Invalid unit for s: must be equivalent to km/s" )
454482
455483 # dictionary of parameters to return
456484 out_pars = {}
0 commit comments