@@ -138,6 +138,7 @@ def __init__(
138138 self .keys = [key ]
139139 else :
140140 self .keys = key
141+ self ._key_lock = Lock ()
141142
142143 # record invalid keys and skip them when requesting API
143144 # - keys have insufficient_quota
@@ -160,6 +161,23 @@ def __init__(
160161
161162 self .path = path
162163
164+ def _next_valid_key (self ):
165+ with self ._key_lock :
166+ if len (self .invalid_keys ) == len (self .keys ):
167+ raise RuntimeError ('All keys have insufficient quota.' )
168+
169+ # find the next valid key
170+ while True :
171+ self .key_ctr += 1
172+ if self .key_ctr == len (self .keys ):
173+ self .key_ctr = 0
174+
175+ if self .keys [self .key_ctr ] not in self .invalid_keys :
176+ break
177+
178+ key = self .keys [self .key_ctr ]
179+ return key
180+
163181 def generate (
164182 self ,
165183 inputs : List [PromptType ],
@@ -185,6 +203,10 @@ def generate(
185203 if self .temperature is not None :
186204 temperature = self .temperature
187205
206+ if len (inputs ) == 1 :
207+ # Forget multi-thread for single inference.
208+ return [self ._generate (inputs [0 ], max_out_len , temperature )]
209+
188210 with ThreadPoolExecutor (max_workers = self .max_workers ) as executor :
189211 results = list (
190212 tqdm (
@@ -224,22 +246,7 @@ def _generate(self, input: PromptType, max_out_len: int,
224246
225247 max_num_retries = 0
226248 while max_num_retries < self .retry :
227- self .wait ()
228-
229- with Lock ():
230- if len (self .invalid_keys ) == len (self .keys ):
231- raise RuntimeError ('All keys have insufficient quota.' )
232-
233- # find the next valid key
234- while True :
235- self .key_ctr += 1
236- if self .key_ctr == len (self .keys ):
237- self .key_ctr = 0
238-
239- if self .keys [self .key_ctr ] not in self .invalid_keys :
240- break
241-
242- key = self .keys [self .key_ctr ]
249+ key = self ._next_valid_key ()
243250
244251 header = {
245252 'Authorization' : f'Bearer { key } ' ,
@@ -254,6 +261,7 @@ def _generate(self, input: PromptType, max_out_len: int,
254261 self .org_ctr = 0
255262 header ['OpenAI-Organization' ] = self .orgs [self .org_ctr ]
256263
264+ self .acquire ()
257265 try :
258266 if any (model in self .path
259267 for model in OAI_REASONING_MODEL_LIST ):
@@ -314,23 +322,13 @@ def _generate(self, input: PromptType, max_out_len: int,
314322 self .logger .debug (
315323 f'Get response from { self .proxy_url } ' )
316324
317- except requests .ConnectionError :
318- self .logger .error ('Got connection error, retrying...' )
319- continue
320- try :
321325 if raw_response .status_code != 200 :
322326 self .logger .error (f'Request failed with status code '
323327 f'{ raw_response .status_code } , response: '
324328 f'{ raw_response .content .decode ()} ' )
325329 continue
326330 response = raw_response .json ()
327- except requests .JSONDecodeError :
328- self .logger .error (f'JsonDecode error, got status code '
329- f'{ raw_response .status_code } , response: '
330- f'{ raw_response .content .decode ()} ' )
331- continue
332- self .logger .debug (str (response ))
333- try :
331+ self .logger .debug (str (response ))
334332 if self .logprobs :
335333 return response ['choices' ]
336334 else :
@@ -356,6 +354,12 @@ def _generate(self, input: PromptType, max_out_len: int,
356354 return reasoning_content
357355 else :
358356 return content .strip ()
357+ except requests .ConnectionError :
358+ self .logger .error ('Got connection error, retrying...' )
359+ except requests .JSONDecodeError :
360+ self .logger .error (f'JsonDecode error, got status code '
361+ f'{ raw_response .status_code } , response: '
362+ f'{ raw_response .content .decode ()} ' )
359363 except KeyError :
360364 if 'error' in response :
361365 if response ['error' ]['code' ] == 'rate_limit_exceeded' :
@@ -377,6 +381,8 @@ def _generate(self, input: PromptType, max_out_len: int,
377381 'Find error message in response: ' ,
378382 str (response ['error' ]),
379383 )
384+ finally :
385+ self .release ()
380386 max_num_retries += 1
381387
382388 raise RuntimeError ('Calling OpenAI failed after retrying for '
@@ -575,7 +581,7 @@ def __init__(
575581 query_per_second : int = 1 ,
576582 rpm_verbose : bool = False ,
577583 retry : int = 2 ,
578- key : str | List [ str ] = 'ENV' ,
584+ key : str = 'ENV' ,
579585 org : str | List [str ] | None = None ,
580586 meta_template : Dict | None = None ,
581587 openai_api_base : str | List [str ] = OPENAISDK_API_BASE ,
@@ -671,7 +677,6 @@ def _generate(
671677
672678 num_retries = 0
673679 while num_retries < self .retry :
674- self .wait ()
675680 if any (model in self .path for model in OAI_REASONING_MODEL_LIST ):
676681 self .logger .warning (
677682 f"'max_token' is unsupported for model { self .path } " )
@@ -697,6 +702,7 @@ def _generate(
697702 if self .openai_extra_kwargs :
698703 query_data .update (self .openai_extra_kwargs )
699704
705+ self .acquire ()
700706 try :
701707 if self .verbose :
702708 self .logger .info ('Start calling OpenAI API' )
@@ -789,6 +795,8 @@ def _generate(
789795 except Exception as e :
790796 self .logger .error (f'error occurs at { self .openai_api_base } ' )
791797 self .logger .error (e )
798+ finally :
799+ self .release ()
792800 num_retries += 1
793801 raise RuntimeError ('Calling OpenAI API failed after retrying for '
794802 f'{ self .retry } times. Check the logs for details.' )
@@ -925,6 +933,7 @@ def _generate(
925933 if self .openai_extra_kwargs :
926934 query_data .update (self .openai_extra_kwargs )
927935
936+ self .acquire ()
928937 try :
929938 if self .verbose :
930939 self .logger .info ('Start calling OpenAI API' )
@@ -1052,6 +1061,8 @@ def _generate(
10521061 except Exception as e :
10531062 self .logger .error (f'error occurs at { self .openai_api_base } ' )
10541063 self .logger .error (e )
1064+ finally :
1065+ self .release ()
10551066 num_retries += 1
10561067 raise RuntimeError ('Calling OpenAI API failed after retrying for '
10571068 f'{ self .retry } times. Check the logs for details.' )
0 commit comments