Skip to content

Commit bce5683

Browse files
committed
[Refactor] Support concurrent inference across tasks.
1 parent ae78149 commit bce5683

File tree

13 files changed

+1081
-84
lines changed

13 files changed

+1081
-84
lines changed

opencompass/cli/main.py

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,36 @@
55
import getpass
66
import os
77
import os.path as osp
8+
import threading
89
from datetime import datetime
910

1011
from mmengine.config import Config, DictAction
1112

1213
from opencompass.registry import PARTITIONERS, RUNNERS, build_from_cfg
1314
from opencompass.runners import SlurmRunner
1415
from opencompass.summarizers import DefaultSummarizer
15-
from opencompass.utils import (LarkReporter, get_logger, pretty_print_config,
16-
read_from_station, save_to_station)
16+
from opencompass.utils import (HeartBeatManager, LarkReporter, get_logger,
17+
pretty_print_config, read_from_station,
18+
save_to_station)
1719
from opencompass.utils.run import (fill_eval_cfg, fill_infer_cfg,
1820
get_config_from_arg)
1921

2022

23+
def _run_eval_tasks(runner, tasks):
24+
if isinstance(tasks, list) and len(tasks) != 0 and isinstance(tasks[0],
25+
list):
26+
for task_part in tasks:
27+
runner(task_part)
28+
else:
29+
runner(tasks)
30+
31+
32+
def _is_eval_daemon(task_type) -> bool:
33+
if isinstance(task_type, str):
34+
return task_type.endswith('OpenICLEvalWatchTask')
35+
return getattr(task_type, '__name__', '') == 'OpenICLEvalWatchTask'
36+
37+
2138
def parse_args():
2239
parser = argparse.ArgumentParser(description='Run an evaluation task')
2340
parser.add_argument('config', nargs='?', help='Train config file path')
@@ -318,7 +335,15 @@ def main():
318335
if args.config_verbose:
319336
pretty_print_config(cfg)
320337

321-
# infer
338+
infer_tasks = None
339+
infer_runner = None
340+
eval_tasks = None
341+
eval_runner = None
342+
eval_daemon = False
343+
344+
# ========================
345+
# Setup Configuration
346+
# ========================
322347
if args.mode in ['all', 'infer']:
323348
# When user have specified --slurm or --dlc, or have not set
324349
# "infer" in config, we will provide a default configuration
@@ -358,7 +383,8 @@ def main():
358383
if args.dump_res_length:
359384
for task in tasks:
360385
task.dump_res_length = True
361-
runner(tasks)
386+
infer_tasks = tasks
387+
infer_runner = runner
362388

363389
# evaluate
364390
if args.mode in ['all', 'eval']:
@@ -397,14 +423,35 @@ def main():
397423
if args.dry_run:
398424
return
399425
runner = RUNNERS.build(cfg.eval.runner)
400-
401-
# For meta-review-judge in subjective evaluation
402-
if isinstance(tasks, list) and len(tasks) != 0 and isinstance(
403-
tasks[0], list):
404-
for task_part in tasks:
405-
runner(task_part)
406-
else:
407-
runner(tasks)
426+
task_type = getattr(cfg.eval.runner, 'task', {}).get('type', '')
427+
eval_daemon = _is_eval_daemon(task_type)
428+
429+
eval_tasks = tasks
430+
eval_runner = runner
431+
432+
# =================
433+
# Startup Runner
434+
# =================
435+
if infer_runner and eval_runner and eval_daemon:
436+
heartbeat = HeartBeatManager(cfg['work_dir'])
437+
stop_event, hb_thread = heartbeat.start_heartbeat()
438+
439+
eval_thread = threading.Thread(target=_run_eval_tasks,
440+
args=(eval_runner, eval_tasks),
441+
daemon=True)
442+
eval_thread.start()
443+
444+
infer_runner(infer_tasks)
445+
446+
stop_event.set()
447+
hb_thread.join()
448+
logger.info('All infer tasks finished, stop heartbeat.')
449+
eval_thread.join()
450+
else:
451+
if infer_runner is not None:
452+
infer_runner(infer_tasks)
453+
if eval_runner is not None:
454+
_run_eval_tasks(eval_runner, eval_tasks)
408455

409456
# save to station
410457
if args.station_path is not None or cfg.get('station_path') is not None:

opencompass/models/openai_api.py

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def __init__(
138138
self.keys = [key]
139139
else:
140140
self.keys = key
141+
self._key_lock = Lock()
141142

142143
# record invalid keys and skip them when requesting API
143144
# - keys have insufficient_quota
@@ -160,6 +161,23 @@ def __init__(
160161

161162
self.path = path
162163

164+
def _next_valid_key(self):
165+
with self._key_lock:
166+
if len(self.invalid_keys) == len(self.keys):
167+
raise RuntimeError('All keys have insufficient quota.')
168+
169+
# find the next valid key
170+
while True:
171+
self.key_ctr += 1
172+
if self.key_ctr == len(self.keys):
173+
self.key_ctr = 0
174+
175+
if self.keys[self.key_ctr] not in self.invalid_keys:
176+
break
177+
178+
key = self.keys[self.key_ctr]
179+
return key
180+
163181
def generate(
164182
self,
165183
inputs: List[PromptType],
@@ -185,6 +203,10 @@ def generate(
185203
if self.temperature is not None:
186204
temperature = self.temperature
187205

206+
if len(inputs) == 1:
207+
# Forget multi-thread for single inference.
208+
return [self._generate(inputs[0], max_out_len, temperature)]
209+
188210
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
189211
results = list(
190212
tqdm(
@@ -224,22 +246,7 @@ def _generate(self, input: PromptType, max_out_len: int,
224246

225247
max_num_retries = 0
226248
while max_num_retries < self.retry:
227-
self.wait()
228-
229-
with Lock():
230-
if len(self.invalid_keys) == len(self.keys):
231-
raise RuntimeError('All keys have insufficient quota.')
232-
233-
# find the next valid key
234-
while True:
235-
self.key_ctr += 1
236-
if self.key_ctr == len(self.keys):
237-
self.key_ctr = 0
238-
239-
if self.keys[self.key_ctr] not in self.invalid_keys:
240-
break
241-
242-
key = self.keys[self.key_ctr]
249+
key = self._next_valid_key()
243250

244251
header = {
245252
'Authorization': f'Bearer {key}',
@@ -254,6 +261,7 @@ def _generate(self, input: PromptType, max_out_len: int,
254261
self.org_ctr = 0
255262
header['OpenAI-Organization'] = self.orgs[self.org_ctr]
256263

264+
self.acquire()
257265
try:
258266
if any(model in self.path
259267
for model in OAI_REASONING_MODEL_LIST):
@@ -314,23 +322,13 @@ def _generate(self, input: PromptType, max_out_len: int,
314322
self.logger.debug(
315323
f'Get response from {self.proxy_url}')
316324

317-
except requests.ConnectionError:
318-
self.logger.error('Got connection error, retrying...')
319-
continue
320-
try:
321325
if raw_response.status_code != 200:
322326
self.logger.error(f'Request failed with status code '
323327
f'{raw_response.status_code}, response: '
324328
f'{raw_response.content.decode()}')
325329
continue
326330
response = raw_response.json()
327-
except requests.JSONDecodeError:
328-
self.logger.error(f'JsonDecode error, got status code '
329-
f'{raw_response.status_code}, response: '
330-
f'{raw_response.content.decode()}')
331-
continue
332-
self.logger.debug(str(response))
333-
try:
331+
self.logger.debug(str(response))
334332
if self.logprobs:
335333
return response['choices']
336334
else:
@@ -356,6 +354,12 @@ def _generate(self, input: PromptType, max_out_len: int,
356354
return reasoning_content
357355
else:
358356
return content.strip()
357+
except requests.ConnectionError:
358+
self.logger.error('Got connection error, retrying...')
359+
except requests.JSONDecodeError:
360+
self.logger.error(f'JsonDecode error, got status code '
361+
f'{raw_response.status_code}, response: '
362+
f'{raw_response.content.decode()}')
359363
except KeyError:
360364
if 'error' in response:
361365
if response['error']['code'] == 'rate_limit_exceeded':
@@ -377,6 +381,8 @@ def _generate(self, input: PromptType, max_out_len: int,
377381
'Find error message in response: ',
378382
str(response['error']),
379383
)
384+
finally:
385+
self.release()
380386
max_num_retries += 1
381387

382388
raise RuntimeError('Calling OpenAI failed after retrying for '
@@ -575,7 +581,7 @@ def __init__(
575581
query_per_second: int = 1,
576582
rpm_verbose: bool = False,
577583
retry: int = 2,
578-
key: str | List[str] = 'ENV',
584+
key: str = 'ENV',
579585
org: str | List[str] | None = None,
580586
meta_template: Dict | None = None,
581587
openai_api_base: str | List[str] = OPENAISDK_API_BASE,
@@ -671,7 +677,6 @@ def _generate(
671677

672678
num_retries = 0
673679
while num_retries < self.retry:
674-
self.wait()
675680
if any(model in self.path for model in OAI_REASONING_MODEL_LIST):
676681
self.logger.warning(
677682
f"'max_token' is unsupported for model {self.path}")
@@ -697,6 +702,7 @@ def _generate(
697702
if self.openai_extra_kwargs:
698703
query_data.update(self.openai_extra_kwargs)
699704

705+
self.acquire()
700706
try:
701707
if self.verbose:
702708
self.logger.info('Start calling OpenAI API')
@@ -789,6 +795,8 @@ def _generate(
789795
except Exception as e:
790796
self.logger.error(f'error occurs at {self.openai_api_base}')
791797
self.logger.error(e)
798+
finally:
799+
self.release()
792800
num_retries += 1
793801
raise RuntimeError('Calling OpenAI API failed after retrying for '
794802
f'{self.retry} times. Check the logs for details.')
@@ -925,6 +933,7 @@ def _generate(
925933
if self.openai_extra_kwargs:
926934
query_data.update(self.openai_extra_kwargs)
927935

936+
self.acquire()
928937
try:
929938
if self.verbose:
930939
self.logger.info('Start calling OpenAI API')
@@ -1052,6 +1061,8 @@ def _generate(
10521061
except Exception as e:
10531062
self.logger.error(f'error occurs at {self.openai_api_base}')
10541063
self.logger.error(e)
1064+
finally:
1065+
self.release()
10551066
num_retries += 1
10561067
raise RuntimeError('Calling OpenAI API failed after retrying for '
10571068
f'{self.retry} times. Check the logs for details.')

0 commit comments

Comments
 (0)