Skip to content

Commit 3820aa3

Browse files
author
Suzannah Cooper
committed
Separate helper classes out into separate file
Move the classes used in the migrate_with_timeouts command out into a separate class, with the intention of enabling this to be used in a more general way.
1 parent f948dfd commit 3820aa3

File tree

4 files changed

+274
-259
lines changed

4 files changed

+274
-259
lines changed
Lines changed: 8 additions & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
1-
import dataclasses
2-
import datetime
31
import hashlib
4-
import importlib
52
import io
63
import time
7-
from typing import Any, Protocol, cast
4+
from typing import Any
85

96
from django.core.management import base
107
from django.core.management.commands.migrate import Command as DjangoMigrationMC
118
from django.db import connections
12-
from typing_extensions import Self
139

14-
from django_pg_migration_tools import timeouts
10+
from django_pg_migration_tools import timeout_retries, timeouts
1511

1612

1713
class MaximumRetriesReached(base.CommandError):
@@ -100,9 +96,13 @@ def add_arguments(self, parser: base.CommandParser) -> None:
10096

10197
@base.no_translations
10298
def handle(self, *args: Any, **options: Any) -> None:
103-
timeout_options = MigrationTimeoutOptions.from_dictionary(options)
99+
timeout_options = timeout_retries.MigrationTimeoutOptions.from_dictionary(
100+
options
101+
)
104102
timeout_options.validate()
105-
retry_strategy = MigrateRetryStrategy(timeout_options=timeout_options)
103+
retry_strategy = timeout_retries.MigrateRetryStrategy(
104+
timeout_options=timeout_options
105+
)
106106

107107
stdout: io.StringIO = options.pop("stdout", io.StringIO())
108108
start_time: float = time.time()
@@ -158,172 +158,3 @@ def _cast_lock_value_to_int(self, value: str) -> int:
158158
return int.from_bytes(
159159
hashlib.sha256(value.encode("utf-8")).digest()[:8], "little", signed=True
160160
)
161-
162-
163-
@dataclasses.dataclass
164-
class RetryState:
165-
current_exception: timeouts.DBTimeoutError
166-
lock_timeouts_count: int
167-
stdout: io.StringIO
168-
time_since_start: datetime.timedelta
169-
database: str
170-
171-
172-
class RetryCallback(Protocol):
173-
def __call__(self, retry_state: RetryState, /) -> None: ... # pragma: no cover
174-
175-
176-
@dataclasses.dataclass(kw_only=True)
177-
class TimeoutRetryOptions:
178-
max_retries: int
179-
exp: int
180-
max_wait: datetime.timedelta
181-
min_wait: datetime.timedelta
182-
183-
def validate(self) -> None:
184-
if (self.min_wait is not None and self.max_wait is not None) and (
185-
self.min_wait > self.max_wait
186-
):
187-
raise ValueError(
188-
"The minimum wait cannot be greater than the maximum wait for retries."
189-
)
190-
191-
192-
@dataclasses.dataclass(frozen=True, kw_only=True)
193-
class MigrationTimeoutOptions:
194-
lock_timeout: datetime.timedelta | None
195-
statement_timeout: datetime.timedelta | None
196-
lock_retry_options: TimeoutRetryOptions
197-
retry_callback: RetryCallback | None
198-
199-
@classmethod
200-
def from_dictionary(cls, options: dict[str, Any]) -> Self:
201-
return cls(
202-
lock_timeout=_Parser.optional_positive_ms_to_timedelta(
203-
options.pop("lock_timeout_in_ms", None)
204-
),
205-
statement_timeout=_Parser.optional_positive_ms_to_timedelta(
206-
options.pop("statement_timeout_in_ms", None),
207-
),
208-
lock_retry_options=TimeoutRetryOptions(
209-
max_retries=_Parser.required_positive_int(
210-
options.pop("lock_timeout_max_retries")
211-
),
212-
exp=_Parser.required_positive_int(
213-
options.pop("lock_timeout_retry_exp")
214-
),
215-
max_wait=_Parser.required_positive_ms_to_timedelta(
216-
options.pop("lock_timeout_retry_max_wait_in_ms")
217-
),
218-
min_wait=_Parser.required_positive_ms_to_timedelta(
219-
options.pop("lock_timeout_retry_min_wait_in_ms")
220-
),
221-
),
222-
retry_callback=_Parser.optional_retry_callback(
223-
options.pop("retry_callback_path", None)
224-
),
225-
)
226-
227-
def validate(self) -> None:
228-
if self.statement_timeout is None and self.lock_timeout is None:
229-
raise ValueError(
230-
"At least one of --lock-timeout-in-ms or --statement-timeout-in-ms "
231-
"must be specified."
232-
)
233-
self.lock_retry_options.validate()
234-
235-
236-
class MigrateRetryStrategy:
237-
timeout_options: MigrationTimeoutOptions
238-
retries: int
239-
240-
def __init__(self, timeout_options: MigrationTimeoutOptions):
241-
self.timeout_options = timeout_options
242-
self.retries = 0
243-
244-
def wait(self) -> None:
245-
exp = self.timeout_options.lock_retry_options.exp
246-
min_wait = self.timeout_options.lock_retry_options.min_wait
247-
max_wait = self.timeout_options.lock_retry_options.max_wait
248-
249-
if not self.can_migrate():
250-
# No point waiting if we can't migrate.
251-
return
252-
try:
253-
# self.retries is an integer, but it is turned into a float here
254-
# because a huge exponentiation in Python between integers
255-
# **never** overflows. Instead, the CPU is left trying to calculate
256-
# the result forever and it will eventually return a memory error
257-
# instead. Which we absolutely do not want. Please see:
258-
# https://docs.python.org/3.12/library/exceptions.html#OverflowError
259-
result = exp ** (float(self.retries))
260-
except OverflowError:
261-
result = max_wait.total_seconds()
262-
wait = max(min_wait.total_seconds(), min(result, max_wait.total_seconds()))
263-
time.sleep(wait)
264-
265-
def attempt_callback(
266-
self,
267-
current_exception: timeouts.DBTimeoutError,
268-
stdout: io.StringIO,
269-
start_time: float,
270-
database: str,
271-
) -> None:
272-
if self.timeout_options.retry_callback:
273-
self.timeout_options.retry_callback(
274-
RetryState(
275-
current_exception=current_exception,
276-
lock_timeouts_count=self.retries,
277-
stdout=stdout,
278-
time_since_start=datetime.timedelta(
279-
seconds=time.time() - start_time
280-
),
281-
database=database,
282-
)
283-
)
284-
285-
def can_migrate(self) -> bool:
286-
if self.retries == 0:
287-
# This is the first time migration will run.
288-
return True
289-
return bool(self.retries <= self.timeout_options.lock_retry_options.max_retries)
290-
291-
def increment_retry_count(self) -> None:
292-
self.retries += 1
293-
294-
295-
class _Parser:
296-
@classmethod
297-
def optional_positive_ms_to_timedelta(
298-
cls, value: int | None
299-
) -> datetime.timedelta | None:
300-
if value is None:
301-
return None
302-
return cls.required_positive_ms_to_timedelta(value)
303-
304-
@classmethod
305-
def required_positive_ms_to_timedelta(cls, value: int) -> datetime.timedelta:
306-
value = cls.required_positive_int(value)
307-
return datetime.timedelta(milliseconds=value)
308-
309-
@classmethod
310-
def required_positive_int(cls, value: Any) -> int:
311-
if (not isinstance(value, int)) or (value < 0):
312-
raise ValueError(f"{value} is not a positive integer.")
313-
return value
314-
315-
@classmethod
316-
def optional_retry_callback(cls, value: str | None) -> RetryCallback | None:
317-
if not value:
318-
return None
319-
320-
assert "." in value
321-
module, attr_name = value.rsplit(".", 1)
322-
323-
# This raises ModuleNotFoundError, which gives a good explanation
324-
# of the error already (see tests). We don't have to wrap this into
325-
# our own exception.
326-
callback_module = importlib.import_module(module)
327-
callback = getattr(callback_module, attr_name)
328-
assert callable(callback)
329-
return cast(RetryCallback, callback)
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
import dataclasses
2+
import datetime
3+
import importlib
4+
import io
5+
import time
6+
from typing import Any, Protocol, cast
7+
8+
from typing_extensions import Self
9+
10+
from django_pg_migration_tools import timeouts
11+
12+
13+
@dataclasses.dataclass
14+
class RetryState:
15+
current_exception: timeouts.DBTimeoutError
16+
lock_timeouts_count: int
17+
stdout: io.StringIO
18+
time_since_start: datetime.timedelta
19+
database: str
20+
21+
22+
class RetryCallback(Protocol):
23+
def __call__(self, retry_state: RetryState, /) -> None: ... # pragma: no cover
24+
25+
26+
@dataclasses.dataclass(kw_only=True)
27+
class TimeoutRetryOptions:
28+
max_retries: int
29+
exp: int
30+
max_wait: datetime.timedelta
31+
min_wait: datetime.timedelta
32+
33+
def validate(self) -> None:
34+
if (self.min_wait is not None and self.max_wait is not None) and (
35+
self.min_wait > self.max_wait
36+
):
37+
raise ValueError(
38+
"The minimum wait cannot be greater than the maximum wait for retries."
39+
)
40+
41+
42+
@dataclasses.dataclass(frozen=True, kw_only=True)
43+
class MigrationTimeoutOptions:
44+
lock_timeout: datetime.timedelta | None
45+
statement_timeout: datetime.timedelta | None
46+
lock_retry_options: TimeoutRetryOptions
47+
retry_callback: RetryCallback | None
48+
49+
@classmethod
50+
def from_dictionary(cls, options: dict[str, Any]) -> Self:
51+
return cls(
52+
lock_timeout=_Parser.optional_positive_ms_to_timedelta(
53+
options.pop("lock_timeout_in_ms", None)
54+
),
55+
statement_timeout=_Parser.optional_positive_ms_to_timedelta(
56+
options.pop("statement_timeout_in_ms", None),
57+
),
58+
lock_retry_options=TimeoutRetryOptions(
59+
max_retries=_Parser.required_positive_int(
60+
options.pop("lock_timeout_max_retries")
61+
),
62+
exp=_Parser.required_positive_int(
63+
options.pop("lock_timeout_retry_exp")
64+
),
65+
max_wait=_Parser.required_positive_ms_to_timedelta(
66+
options.pop("lock_timeout_retry_max_wait_in_ms")
67+
),
68+
min_wait=_Parser.required_positive_ms_to_timedelta(
69+
options.pop("lock_timeout_retry_min_wait_in_ms")
70+
),
71+
),
72+
retry_callback=_Parser.optional_retry_callback(
73+
options.pop("retry_callback_path", None)
74+
),
75+
)
76+
77+
def validate(self) -> None:
78+
if self.statement_timeout is None and self.lock_timeout is None:
79+
raise ValueError(
80+
"At least one of --lock-timeout-in-ms or --statement-timeout-in-ms "
81+
"must be specified."
82+
)
83+
self.lock_retry_options.validate()
84+
85+
86+
class MigrateRetryStrategy:
87+
timeout_options: MigrationTimeoutOptions
88+
retries: int
89+
90+
def __init__(self, timeout_options: MigrationTimeoutOptions):
91+
self.timeout_options = timeout_options
92+
self.retries = 0
93+
94+
def wait(self) -> None:
95+
exp = self.timeout_options.lock_retry_options.exp
96+
min_wait = self.timeout_options.lock_retry_options.min_wait
97+
max_wait = self.timeout_options.lock_retry_options.max_wait
98+
99+
if not self.can_migrate():
100+
# No point waiting if we can't migrate.
101+
return
102+
try:
103+
# self.retries is an integer, but it is turned into a float here
104+
# because a huge exponentiation in Python between integers
105+
# **never** overflows. Instead, the CPU is left trying to calculate
106+
# the result forever and it will eventually return a memory error
107+
# instead. Which we absolutely do not want. Please see:
108+
# https://docs.python.org/3.12/library/exceptions.html#OverflowError
109+
result = exp ** (float(self.retries))
110+
except OverflowError:
111+
result = max_wait.total_seconds()
112+
wait = max(min_wait.total_seconds(), min(result, max_wait.total_seconds()))
113+
time.sleep(wait)
114+
115+
def attempt_callback(
116+
self,
117+
current_exception: timeouts.DBTimeoutError,
118+
stdout: io.StringIO,
119+
start_time: float,
120+
database: str,
121+
) -> None:
122+
if self.timeout_options.retry_callback:
123+
self.timeout_options.retry_callback(
124+
RetryState(
125+
current_exception=current_exception,
126+
lock_timeouts_count=self.retries,
127+
stdout=stdout,
128+
time_since_start=datetime.timedelta(
129+
seconds=time.time() - start_time
130+
),
131+
database=database,
132+
)
133+
)
134+
135+
def can_migrate(self) -> bool:
136+
if self.retries == 0:
137+
# This is the first time migration will run.
138+
return True
139+
return bool(self.retries <= self.timeout_options.lock_retry_options.max_retries)
140+
141+
def increment_retry_count(self) -> None:
142+
self.retries += 1
143+
144+
145+
class _Parser:
146+
@classmethod
147+
def optional_positive_ms_to_timedelta(
148+
cls, value: int | None
149+
) -> datetime.timedelta | None:
150+
if value is None:
151+
return None
152+
return cls.required_positive_ms_to_timedelta(value)
153+
154+
@classmethod
155+
def required_positive_ms_to_timedelta(cls, value: int) -> datetime.timedelta:
156+
value = cls.required_positive_int(value)
157+
return datetime.timedelta(milliseconds=value)
158+
159+
@classmethod
160+
def required_positive_int(cls, value: Any) -> int:
161+
if (not isinstance(value, int)) or (value < 0):
162+
raise ValueError(f"{value} is not a positive integer.")
163+
return value
164+
165+
@classmethod
166+
def optional_retry_callback(cls, value: str | None) -> RetryCallback | None:
167+
if not value:
168+
return None
169+
170+
assert "." in value
171+
module, attr_name = value.rsplit(".", 1)
172+
173+
# This raises ModuleNotFoundError, which gives a good explanation
174+
# of the error already (see tests). We don't have to wrap this into
175+
# our own exception.
176+
callback_module = importlib.import_module(module)
177+
callback = getattr(callback_module, attr_name)
178+
assert callable(callback)
179+
return cast(RetryCallback, callback)

0 commit comments

Comments
 (0)