Skip to content

Commit b074be3

Browse files
axumweyaneclaude
andcommitted
Add random sampling to strategy optimizer (100 samples/strategy default)
Replaces exhaustive grid search (13k+ trials) with random sampling, reducing runtime from 30+ minutes to ~4 minutes while maintaining walk-forward cross-validation quality. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d5499fc commit b074be3

1 file changed

Lines changed: 20 additions & 7 deletions

File tree

TFT-main/optimize_strategies.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import itertools
1616
import json
1717
import logging
18+
import random
1819
import sys
1920
from dataclasses import asdict
2021
from datetime import datetime
@@ -247,17 +248,26 @@ def optimize_strategy(
247248
stock_data: pd.DataFrame,
248249
benchmark: pd.DataFrame,
249250
n_folds: int = 5,
251+
max_samples: int = 100,
250252
) -> List[Dict[str, Any]]:
251-
"""Run walk-forward grid search for a strategy."""
253+
"""Run walk-forward random search for a strategy."""
252254
all_dates = sorted(stock_data["timestamp"].unique())
253255
date_index = pd.DatetimeIndex(all_dates)
254256
splits = make_walk_forward_splits(date_index, n_folds)
255257

256-
combos = expand_grid(grid)
257-
logger.info(
258-
"%s: %d parameter combos × %d folds = %d trials",
259-
name, len(combos), len(splits), len(combos) * len(splits),
260-
)
258+
all_combos = expand_grid(grid)
259+
if len(all_combos) > max_samples:
260+
combos = random.sample(all_combos, max_samples)
261+
logger.info(
262+
"%s: %d/%d random samples × %d folds = %d trials",
263+
name, max_samples, len(all_combos), len(splits), max_samples * len(splits),
264+
)
265+
else:
266+
combos = all_combos
267+
logger.info(
268+
"%s: %d parameter combos × %d folds = %d trials",
269+
name, len(combos), len(splits), len(combos) * len(splits),
270+
)
261271

262272
# Aggregate OOS results across folds
263273
results_by_combo = {i: [] for i in range(len(combos))}
@@ -339,11 +349,12 @@ def main():
339349
parser = argparse.ArgumentParser(description="APEX Strategy Parameter Optimizer")
340350
parser.add_argument("--output", default="optimization_results.json")
341351
parser.add_argument("--folds", type=int, default=5)
352+
parser.add_argument("--max-samples", type=int, default=100, help="Max random samples per strategy")
342353
args = parser.parse_args()
343354

344355
print(SEPARATOR)
345356
print(" APEX STRATEGY PARAMETER OPTIMIZER")
346-
print(f" Walk-Forward Grid Search | {args.folds} folds")
357+
print(f" Walk-Forward Random Search | {args.folds} folds | {args.max_samples} samples/strategy")
347358
print(f" {datetime.now().strftime('%Y-%m-%d %H:%M')}")
348359
print(SEPARATOR)
349360

@@ -369,6 +380,7 @@ def main():
369380
stocks_no_spy,
370381
benchmark,
371382
n_folds=args.folds,
383+
max_samples=args.max_samples,
372384
)
373385
print_top_results(momentum_results)
374386
all_results["cross_sectional_momentum"] = momentum_results[:10]
@@ -385,6 +397,7 @@ def main():
385397
stocks_no_spy,
386398
benchmark,
387399
n_folds=args.folds,
400+
max_samples=args.max_samples,
388401
)
389402
print_top_results(statarb_results)
390403
all_results["pairs_trading"] = statarb_results[:10]

0 commit comments

Comments
 (0)