1515import itertools
1616import json
1717import logging
18+ import random
1819import sys
1920from dataclasses import asdict
2021from datetime import datetime
@@ -247,17 +248,26 @@ def optimize_strategy(
247248 stock_data : pd .DataFrame ,
248249 benchmark : pd .DataFrame ,
249250 n_folds : int = 5 ,
251+ max_samples : int = 100 ,
250252) -> List [Dict [str , Any ]]:
251- """Run walk-forward grid search for a strategy."""
253+ """Run walk-forward random search for a strategy."""
252254 all_dates = sorted (stock_data ["timestamp" ].unique ())
253255 date_index = pd .DatetimeIndex (all_dates )
254256 splits = make_walk_forward_splits (date_index , n_folds )
255257
256- combos = expand_grid (grid )
257- logger .info (
258- "%s: %d parameter combos × %d folds = %d trials" ,
259- name , len (combos ), len (splits ), len (combos ) * len (splits ),
260- )
258+ all_combos = expand_grid (grid )
259+ if len (all_combos ) > max_samples :
260+ combos = random .sample (all_combos , max_samples )
261+ logger .info (
262+ "%s: %d/%d random samples × %d folds = %d trials" ,
263+ name , max_samples , len (all_combos ), len (splits ), max_samples * len (splits ),
264+ )
265+ else :
266+ combos = all_combos
267+ logger .info (
268+ "%s: %d parameter combos × %d folds = %d trials" ,
269+ name , len (combos ), len (splits ), len (combos ) * len (splits ),
270+ )
261271
262272 # Aggregate OOS results across folds
263273 results_by_combo = {i : [] for i in range (len (combos ))}
@@ -339,11 +349,12 @@ def main():
339349 parser = argparse .ArgumentParser (description = "APEX Strategy Parameter Optimizer" )
340350 parser .add_argument ("--output" , default = "optimization_results.json" )
341351 parser .add_argument ("--folds" , type = int , default = 5 )
352+ parser .add_argument ("--max-samples" , type = int , default = 100 , help = "Max random samples per strategy" )
342353 args = parser .parse_args ()
343354
344355 print (SEPARATOR )
345356 print (" APEX STRATEGY PARAMETER OPTIMIZER" )
346- print (f" Walk-Forward Grid Search | { args .folds } folds" )
357+ print (f" Walk-Forward Random Search | { args .folds } folds | { args . max_samples } samples/strategy " )
347358 print (f" { datetime .now ().strftime ('%Y-%m-%d %H:%M' )} " )
348359 print (SEPARATOR )
349360
@@ -369,6 +380,7 @@ def main():
369380 stocks_no_spy ,
370381 benchmark ,
371382 n_folds = args .folds ,
383+ max_samples = args .max_samples ,
372384 )
373385 print_top_results (momentum_results )
374386 all_results ["cross_sectional_momentum" ] = momentum_results [:10 ]
@@ -385,6 +397,7 @@ def main():
385397 stocks_no_spy ,
386398 benchmark ,
387399 n_folds = args .folds ,
400+ max_samples = args .max_samples ,
388401 )
389402 print_top_results (statarb_results )
390403 all_results ["pairs_trading" ] = statarb_results [:10 ]
0 commit comments