|
1 | | -import dataclasses |
2 | | -import datetime |
3 | 1 | import hashlib |
4 | | -import importlib |
5 | 2 | import io |
6 | 3 | import time |
7 | | -from typing import Any, Protocol, cast |
| 4 | +from typing import Any |
8 | 5 |
|
9 | 6 | from django.core.management import base |
10 | 7 | from django.core.management.commands.migrate import Command as DjangoMigrationMC |
11 | 8 | from django.db import connections |
12 | | -from typing_extensions import Self |
13 | 9 |
|
14 | | -from django_pg_migration_tools import timeouts |
| 10 | +from django_pg_migration_tools import timeout_retries, timeouts |
15 | 11 |
|
16 | 12 |
|
17 | 13 | class MaximumRetriesReached(base.CommandError): |
@@ -100,9 +96,13 @@ def add_arguments(self, parser: base.CommandParser) -> None: |
100 | 96 |
|
101 | 97 | @base.no_translations |
102 | 98 | def handle(self, *args: Any, **options: Any) -> None: |
103 | | - timeout_options = MigrationTimeoutOptions.from_dictionary(options) |
| 99 | + timeout_options = timeout_retries.MigrationTimeoutOptions.from_dictionary( |
| 100 | + options |
| 101 | + ) |
104 | 102 | timeout_options.validate() |
105 | | - retry_strategy = MigrateRetryStrategy(timeout_options=timeout_options) |
| 103 | + retry_strategy = timeout_retries.MigrateRetryStrategy( |
| 104 | + timeout_options=timeout_options |
| 105 | + ) |
106 | 106 |
|
107 | 107 | stdout: io.StringIO = options.pop("stdout", io.StringIO()) |
108 | 108 | start_time: float = time.time() |
@@ -158,172 +158,3 @@ def _cast_lock_value_to_int(self, value: str) -> int: |
158 | 158 | return int.from_bytes( |
159 | 159 | hashlib.sha256(value.encode("utf-8")).digest()[:8], "little", signed=True |
160 | 160 | ) |
161 | | - |
162 | | - |
163 | | -@dataclasses.dataclass |
164 | | -class RetryState: |
165 | | - current_exception: timeouts.DBTimeoutError |
166 | | - lock_timeouts_count: int |
167 | | - stdout: io.StringIO |
168 | | - time_since_start: datetime.timedelta |
169 | | - database: str |
170 | | - |
171 | | - |
172 | | -class RetryCallback(Protocol): |
173 | | - def __call__(self, retry_state: RetryState, /) -> None: ... # pragma: no cover |
174 | | - |
175 | | - |
176 | | -@dataclasses.dataclass(kw_only=True) |
177 | | -class TimeoutRetryOptions: |
178 | | - max_retries: int |
179 | | - exp: int |
180 | | - max_wait: datetime.timedelta |
181 | | - min_wait: datetime.timedelta |
182 | | - |
183 | | - def validate(self) -> None: |
184 | | - if (self.min_wait is not None and self.max_wait is not None) and ( |
185 | | - self.min_wait > self.max_wait |
186 | | - ): |
187 | | - raise ValueError( |
188 | | - "The minimum wait cannot be greater than the maximum wait for retries." |
189 | | - ) |
190 | | - |
191 | | - |
192 | | -@dataclasses.dataclass(frozen=True, kw_only=True) |
193 | | -class MigrationTimeoutOptions: |
194 | | - lock_timeout: datetime.timedelta | None |
195 | | - statement_timeout: datetime.timedelta | None |
196 | | - lock_retry_options: TimeoutRetryOptions |
197 | | - retry_callback: RetryCallback | None |
198 | | - |
199 | | - @classmethod |
200 | | - def from_dictionary(cls, options: dict[str, Any]) -> Self: |
201 | | - return cls( |
202 | | - lock_timeout=_Parser.optional_positive_ms_to_timedelta( |
203 | | - options.pop("lock_timeout_in_ms", None) |
204 | | - ), |
205 | | - statement_timeout=_Parser.optional_positive_ms_to_timedelta( |
206 | | - options.pop("statement_timeout_in_ms", None), |
207 | | - ), |
208 | | - lock_retry_options=TimeoutRetryOptions( |
209 | | - max_retries=_Parser.required_positive_int( |
210 | | - options.pop("lock_timeout_max_retries") |
211 | | - ), |
212 | | - exp=_Parser.required_positive_int( |
213 | | - options.pop("lock_timeout_retry_exp") |
214 | | - ), |
215 | | - max_wait=_Parser.required_positive_ms_to_timedelta( |
216 | | - options.pop("lock_timeout_retry_max_wait_in_ms") |
217 | | - ), |
218 | | - min_wait=_Parser.required_positive_ms_to_timedelta( |
219 | | - options.pop("lock_timeout_retry_min_wait_in_ms") |
220 | | - ), |
221 | | - ), |
222 | | - retry_callback=_Parser.optional_retry_callback( |
223 | | - options.pop("retry_callback_path", None) |
224 | | - ), |
225 | | - ) |
226 | | - |
227 | | - def validate(self) -> None: |
228 | | - if self.statement_timeout is None and self.lock_timeout is None: |
229 | | - raise ValueError( |
230 | | - "At least one of --lock-timeout-in-ms or --statement-timeout-in-ms " |
231 | | - "must be specified." |
232 | | - ) |
233 | | - self.lock_retry_options.validate() |
234 | | - |
235 | | - |
236 | | -class MigrateRetryStrategy: |
237 | | - timeout_options: MigrationTimeoutOptions |
238 | | - retries: int |
239 | | - |
240 | | - def __init__(self, timeout_options: MigrationTimeoutOptions): |
241 | | - self.timeout_options = timeout_options |
242 | | - self.retries = 0 |
243 | | - |
244 | | - def wait(self) -> None: |
245 | | - exp = self.timeout_options.lock_retry_options.exp |
246 | | - min_wait = self.timeout_options.lock_retry_options.min_wait |
247 | | - max_wait = self.timeout_options.lock_retry_options.max_wait |
248 | | - |
249 | | - if not self.can_migrate(): |
250 | | - # No point waiting if we can't migrate. |
251 | | - return |
252 | | - try: |
253 | | - # self.retries is an integer, but it is turned into a float here |
254 | | - # because a huge exponentiation in Python between integers |
255 | | - # **never** overflows. Instead, the CPU is left trying to calculate |
256 | | - # the result forever and it will eventually return a memory error |
257 | | - # instead. Which we absolutely do not want. Please see: |
258 | | - # https://docs.python.org/3.12/library/exceptions.html#OverflowError |
259 | | - result = exp ** (float(self.retries)) |
260 | | - except OverflowError: |
261 | | - result = max_wait.total_seconds() |
262 | | - wait = max(min_wait.total_seconds(), min(result, max_wait.total_seconds())) |
263 | | - time.sleep(wait) |
264 | | - |
265 | | - def attempt_callback( |
266 | | - self, |
267 | | - current_exception: timeouts.DBTimeoutError, |
268 | | - stdout: io.StringIO, |
269 | | - start_time: float, |
270 | | - database: str, |
271 | | - ) -> None: |
272 | | - if self.timeout_options.retry_callback: |
273 | | - self.timeout_options.retry_callback( |
274 | | - RetryState( |
275 | | - current_exception=current_exception, |
276 | | - lock_timeouts_count=self.retries, |
277 | | - stdout=stdout, |
278 | | - time_since_start=datetime.timedelta( |
279 | | - seconds=time.time() - start_time |
280 | | - ), |
281 | | - database=database, |
282 | | - ) |
283 | | - ) |
284 | | - |
285 | | - def can_migrate(self) -> bool: |
286 | | - if self.retries == 0: |
287 | | - # This is the first time migration will run. |
288 | | - return True |
289 | | - return bool(self.retries <= self.timeout_options.lock_retry_options.max_retries) |
290 | | - |
291 | | - def increment_retry_count(self) -> None: |
292 | | - self.retries += 1 |
293 | | - |
294 | | - |
295 | | -class _Parser: |
296 | | - @classmethod |
297 | | - def optional_positive_ms_to_timedelta( |
298 | | - cls, value: int | None |
299 | | - ) -> datetime.timedelta | None: |
300 | | - if value is None: |
301 | | - return None |
302 | | - return cls.required_positive_ms_to_timedelta(value) |
303 | | - |
304 | | - @classmethod |
305 | | - def required_positive_ms_to_timedelta(cls, value: int) -> datetime.timedelta: |
306 | | - value = cls.required_positive_int(value) |
307 | | - return datetime.timedelta(milliseconds=value) |
308 | | - |
309 | | - @classmethod |
310 | | - def required_positive_int(cls, value: Any) -> int: |
311 | | - if (not isinstance(value, int)) or (value < 0): |
312 | | - raise ValueError(f"{value} is not a positive integer.") |
313 | | - return value |
314 | | - |
315 | | - @classmethod |
316 | | - def optional_retry_callback(cls, value: str | None) -> RetryCallback | None: |
317 | | - if not value: |
318 | | - return None |
319 | | - |
320 | | - assert "." in value |
321 | | - module, attr_name = value.rsplit(".", 1) |
322 | | - |
323 | | - # This raises ModuleNotFoundError, which gives a good explanation |
324 | | - # of the error already (see tests). We don't have to wrap this into |
325 | | - # our own exception. |
326 | | - callback_module = importlib.import_module(module) |
327 | | - callback = getattr(callback_module, attr_name) |
328 | | - assert callable(callback) |
329 | | - return cast(RetryCallback, callback) |
0 commit comments