Skip to content

Commit 6002d07

Browse files
authored
Improve typing (#112)
* Improve typing * explicit params to fsm_transition_view
1 parent 3220dd7 commit 6002d07

4 files changed

Lines changed: 44 additions & 62 deletions

File tree

django_fsm/__init__.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,6 @@
2121
from django_fsm.signals import pre_transition
2222

2323
if typing.TYPE_CHECKING: # pragma: no cover
24-
from collections.abc import Callable
25-
from collections.abc import Collection
26-
from collections.abc import Generator
27-
from collections.abc import Iterable
28-
from collections.abc import Sequence
2924
from typing import Self
3025

3126
from django.contrib.auth.models import AbstractUser
@@ -39,10 +34,11 @@
3934
IntegerField: typing.TypeAlias = models.IntegerField[int, int]
4035
ForeignKey: typing.TypeAlias = models.ForeignKey[typing.Any, typing.Any]
4136

42-
_FSMModel = models.Model
37+
_FSMModel: typing.TypeAlias = models.Model
4338
_StateValue: typing.TypeAlias = str | int | models.Choices
44-
_Permission: typing.TypeAlias = str | Callable[[_FSMModel, typing.Any], bool]
45-
_Condition: typing.TypeAlias = Callable[[models.Model], bool]
39+
_Permission: typing.TypeAlias = str | typing.Callable[[_FSMModel, UserWithPermissions], bool]
40+
_Condition: typing.TypeAlias = typing.Callable[[_FSMModel], bool]
41+
_TransitionFunc: typing.TypeAlias = typing.Callable[..., _StateValue | typing.Any | None]
4642

4743
else:
4844
_FSMModel = object
@@ -99,7 +95,7 @@ class ConcurrentTransition(Exception): # noqa: N818
9995
class Transition:
10096
def __init__(
10197
self,
102-
method: Callable[..., _StateValue | None],
98+
method: _TransitionFunc,
10399
source: _StateValue,
104100
target: _StateValue,
105101
on_error: _StateValue | None,
@@ -147,7 +143,7 @@ def __eq__(self, other: object) -> bool:
147143

148144
def get_available_FIELD_transitions( # noqa: N802
149145
instance: _FSMModel, field: FSMFieldMixin
150-
) -> Generator[Transition]:
146+
) -> typing.Generator[Transition]:
151147
"""
152148
List of transitions available in current model state
153149
with all conditions met
@@ -161,7 +157,9 @@ def get_available_FIELD_transitions( # noqa: N802
161157
yield meta.get_transition(curr_state)
162158

163159

164-
def get_all_FIELD_transitions(instance: _FSMModel, field: FSMFieldMixin) -> Generator[Transition]: # noqa: N802
160+
def get_all_FIELD_transitions( # noqa: N802
161+
instance: _FSMModel, field: FSMFieldMixin
162+
) -> typing.Generator[Transition]:
165163
"""
166164
List of all transitions available in current model state
167165
"""
@@ -170,7 +168,7 @@ def get_all_FIELD_transitions(instance: _FSMModel, field: FSMFieldMixin) -> Gene
170168

171169
def get_available_user_FIELD_transitions( # noqa: N802
172170
instance: _FSMModel, user: UserWithPermissions, field: FSMFieldMixin
173-
) -> Generator[Transition]:
171+
) -> typing.Generator[Transition]:
174172
"""
175173
List of transitions available in current model state
176174
with all conditions met and user have rights on it
@@ -199,12 +197,12 @@ def get_transition(self, source: _StateValue) -> Transition | None:
199197

200198
def add_transition(
201199
self,
202-
method: Callable[..., _StateValue],
200+
method: _TransitionFunc,
203201
source: _StateValue,
204202
target: _StateValue,
205203
on_error: _StateValue | None = None,
206204
conditions: list[_Condition] | None = None,
207-
permission: str | Callable[[_FSMModel, UserWithPermissions], bool] | None = None,
205+
permission: _Permission | None = None,
208206
custom: dict[str, typing.Any] | None = None,
209207
) -> None:
210208
if source in self.transitions:
@@ -320,7 +318,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
320318
super().__init__(*args, **kwargs)
321319

322320
@override
323-
def deconstruct(self) -> tuple[str, str, Sequence[typing.Any], dict[str, typing.Any]]:
321+
def deconstruct(self) -> tuple[str, str, typing.Sequence[typing.Any], dict[str, typing.Any]]:
324322
name, path, args, kwargs = super().deconstruct()
325323
if self.protected:
326324
kwargs["protected"] = self.protected
@@ -414,7 +412,7 @@ def change_state(
414412

415413
return result
416414

417-
def get_all_transitions(self, instance_cls: type[_FSMModel]) -> Generator[Transition]:
415+
def get_all_transitions(self, instance_cls: type[_FSMModel]) -> typing.Generator[Transition]:
418416
"""
419417
Returns [(source, target, name, method)] for all field transitions
420418
"""
@@ -458,7 +456,7 @@ def _collect_transitions(self, *args: typing.Any, **kwargs: typing.Any) -> None:
458456
if not issubclass(sender, self.base_cls):
459457
return
460458

461-
def is_field_transition_method(attr: Callable[[typing.Any], typing.Any]) -> bool:
459+
def is_field_transition_method(attr: _TransitionFunc) -> bool:
462460
return (
463461
(inspect.ismethod(attr) or inspect.isfunction(attr))
464462
and hasattr(attr, "_django_fsm")
@@ -567,7 +565,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
567565
self._update_initial_state()
568566

569567
@property
570-
def state_fields(self) -> Iterable[FSMFieldMixin]:
568+
def state_fields(self) -> typing.Iterable[FSMFieldMixin]:
571569
return filter(lambda field: isinstance(field, FSMFieldMixin), self._meta.fields) # type: ignore[arg-type]
572570

573571
@override
@@ -576,8 +574,8 @@ def _do_update(
576574
base_qs: QuerySet[Self],
577575
using: str | None,
578576
pk_val: typing.Any,
579-
values: Collection[tuple[_Field, type[models.Model] | None, typing.Any]],
580-
update_fields: Iterable[str] | None,
577+
values: typing.Collection[tuple[_Field, type[models.Model] | None, typing.Any]],
578+
update_fields: typing.Iterable[str] | None,
581579
forced_update: bool,
582580
returning_fields: bool | None = None,
583581
) -> bool:
@@ -644,13 +642,13 @@ def save(self, *args: typing.Any, **kwargs: typing.Any) -> None:
644642

645643
def transition(
646644
field: FSMFieldMixin | str,
647-
source: _StateValue | Sequence[_StateValue] = "*",
645+
source: _StateValue | typing.Sequence[_StateValue] = "*",
648646
target: _StateValue | State | None = None,
649647
on_error: _StateValue | None = None,
650648
conditions: list[_Condition] | None = None,
651649
permission: _Permission | None = None,
652650
custom: dict[str, typing.Any] | None = None,
653-
) -> Callable[[typing.Any], typing.Any]:
651+
) -> typing.Callable[[typing.Any], typing.Any]:
654652
"""
655653
Method decorator to mark allowed transitions.
656654
@@ -728,14 +726,14 @@ def has_transition_perm(bound_method: typing.Any, user: UserWithPermissions) ->
728726

729727

730728
class State:
731-
allowed_states: Sequence[_StateValue]
729+
allowed_states: typing.Sequence[_StateValue]
732730

733731
def get_state(
734732
self,
735733
model: _FSMModel,
736734
transition: Transition,
737735
result: typing.Any,
738-
args: Sequence[typing.Any] | None = None,
736+
args: typing.Sequence[typing.Any] | None = None,
739737
kwargs: dict[str, typing.Any] | None = None,
740738
) -> typing.Any:
741739
raise NotImplementedError
@@ -750,7 +748,7 @@ def get_state(
750748
model: _FSMModel,
751749
transition: Transition,
752750
result: typing.Any,
753-
args: Sequence[typing.Any] | None = None,
751+
args: typing.Sequence[typing.Any] | None = None,
754752
kwargs: dict[str, typing.Any] | None = None,
755753
) -> typing.Any:
756754
if self.allowed_states is not None and result not in self.allowed_states:
@@ -763,8 +761,8 @@ def get_state(
763761
class GET_STATE(State): # noqa: N801
764762
def __init__(
765763
self,
766-
func: Callable[..., _StateValue],
767-
states: Sequence[_StateValue] | None = None,
764+
func: typing.Callable[..., _StateValue],
765+
states: typing.Sequence[_StateValue] | None = None,
768766
) -> None:
769767
self.func = func
770768
self.allowed_states = states or []
@@ -774,7 +772,7 @@ def get_state(
774772
model: _FSMModel,
775773
transition: Transition,
776774
result: _StateValue,
777-
args: Sequence[typing.Any] | None = None,
775+
args: typing.Sequence[typing.Any] | None = None,
778776
kwargs: dict[str, typing.Any] | None = None,
779777
) -> typing.Any:
780778
if args is None:

django_fsm/admin.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
else:
3838
_ModelAdmin = admin.ModelAdmin
3939

40-
_FormType = type[Form | ModelForm[typing.Any]]
40+
_FormType: typing.TypeAlias = type[Form | ModelForm[fsm._FSMModel]]
4141

4242

4343
@dataclass
@@ -85,7 +85,7 @@ def __init__(self, model: type[fsm._FSMModel], admin_site: admin.AdminSite) -> N
8585

8686
@override
8787
def get_readonly_fields(
88-
self, request: http.HttpRequest, obj: typing.Any = None
88+
self, request: http.HttpRequest, obj: fsm._FSMModel | None = None
8989
) -> tuple[str, ...]:
9090
"""Ensures 'protected' fields are 'readonly'"""
9191

@@ -141,7 +141,7 @@ def change_view(
141141
)
142142

143143
@override
144-
def response_change(self, request: http.HttpRequest, obj: typing.Any) -> http.HttpResponse:
144+
def response_change(self, request: http.HttpRequest, obj: fsm._FSMModel) -> http.HttpResponse:
145145
transition_name = request.POST.get(self.fsm_post_param)
146146
if not transition_name:
147147
return super().response_change(request=request, obj=obj)
@@ -172,7 +172,7 @@ def response_change(self, request: http.HttpRequest, obj: typing.Any) -> http.Ht
172172
"preserved_filters": self.get_preserved_filters(request),
173173
"opts": self.model._meta,
174174
},
175-
url=self.get_fsm_redirect_url(request=request, obj=obj),
175+
url=request.path,
176176
)
177177
)
178178

@@ -191,9 +191,6 @@ def get_fsm_label(self, transition: fsm.Transition) -> str:
191191
def get_help_text(self, transition: fsm.Transition) -> str | None:
192192
return transition.custom.get("help_text")
193193

194-
def get_fsm_redirect_url(self, request: http.HttpRequest, obj: typing.Any) -> str:
195-
return request.path
196-
197194
def get_fsm_transition_form(self, transition: fsm.Transition) -> _FormType | None:
198195
"""Get transition form class with error handling."""
199196
form = self.fsm_forms.get(transition.name, transition.custom.get("form"))
@@ -210,7 +207,7 @@ def get_fsm_transition_form(self, transition: fsm.Transition) -> _FormType | Non
210207
# Transition helpers
211208

212209
def _get_fsm_extra_context(
213-
self, *, request: http.HttpRequest, obj: typing.Any
210+
self, *, request: http.HttpRequest, obj: fsm._FSMModel | None
214211
) -> typing.Generator[FSMObjectTransition]:
215212
for field_name in sorted(self.fsm_fields):
216213
transitions_func = getattr(obj, f"get_available_user_{field_name}_transitions", None)
@@ -232,10 +229,10 @@ def _get_fsm_extra_context(
232229
)
233230

234231
def _get_fsm_transition_func(
235-
self, *, obj: typing.Any, transition_name: str
236-
) -> typing.Callable[..., typing.Any]:
232+
self, *, obj: fsm._FSMModel, transition_name: str
233+
) -> fsm._TransitionFunc:
237234
try:
238-
transition_func: typing.Callable[..., typing.Any] = getattr(obj, transition_name)
235+
transition_func: fsm._TransitionFunc = getattr(obj, transition_name)
239236
except AttributeError:
240237
raise AttributeError(
241238
f"{obj.__class__.__name__} has no transition method '{transition_name}'."
@@ -251,7 +248,7 @@ def _get_fsm_transition_func(
251248
return transition_func
252249

253250
def _get_fsm_transition_by_name(
254-
self, *, obj: typing.Any, transition_name: str
251+
self, *, obj: fsm._FSMModel, transition_name: str
255252
) -> fsm.Transition:
256253
transition_func = self._get_fsm_transition_func(obj=obj, transition_name=transition_name)
257254
transitions = transition_func._django_fsm.transitions # type: ignore[attr-defined]
@@ -271,7 +268,7 @@ def _is_fsm_log_enabled() -> bool:
271268
def _execute_fsm_transition(
272269
self,
273270
*,
274-
transition_func: typing.Callable[..., typing.Any],
271+
transition_func: fsm._TransitionFunc,
275272
request: http.HttpRequest,
276273
kwargs: typing.Mapping[str, typing.Any] | None = None,
277274
) -> None:
@@ -287,7 +284,7 @@ def _execute_fsm_transition(
287284
def _apply_fsm_transition(
288285
self,
289286
*,
290-
obj: typing.Any,
287+
obj: fsm._FSMModel,
291288
transition_name: str,
292289
request: http.HttpRequest,
293290
kwargs: typing.Mapping[str, typing.Any] | None = None,
@@ -344,16 +341,14 @@ def _apply_fsm_transition(
344341
# Form handling
345342

346343
def fsm_transition_view(
347-
self, request: http.HttpRequest, *args: typing.Any, **kwargs: typing.Any
344+
self, request: http.HttpRequest, object_id: str, transition_name: str, **kwargs: typing.Any
348345
) -> http.HttpResponse:
349346
"""Handle FSM transition form view with enhanced validation."""
350-
object_id = kwargs["object_id"]
347+
351348
obj = self.get_object(request, object_id)
352349
if obj is None:
353350
return self._get_obj_does_not_exist_redirect(request, self.opts, object_id) # type: ignore[no-any-return, attr-defined]
354351

355-
transition_name = kwargs["transition_name"]
356-
357352
transition = self._get_fsm_transition_by_name(obj=obj, transition_name=transition_name)
358353

359354
if not transition.has_perm(obj, user=request.user):

tests/testapp/models.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

3-
import typing
4-
3+
from django.contrib.auth.models import AbstractUser
54
from django.db import models
65
from django_fsm_log.decorators import fsm_log_by
76
from django_fsm_log.decorators import fsm_log_description
@@ -13,9 +12,6 @@
1312
from django_fsm import FSMKeyField
1413
from django_fsm import transition
1514

16-
if typing.TYPE_CHECKING:
17-
from django.contrib.auth.models import AbstractUser
18-
1915

2016
class Application(models.Model):
2117
"""
@@ -210,8 +206,10 @@ class Meta:
210206
("can_remove_post", "Can remove post"),
211207
]
212208

213-
def can_restore(self: models.Model, user: AbstractUser) -> bool:
214-
return bool(user.is_superuser or user.is_staff)
209+
def can_restore(self: models.Model, user: fsm.UserWithPermissions) -> bool:
210+
if isinstance(user, AbstractUser):
211+
return bool(user.is_superuser or user.is_staff)
212+
return False
215213

216214
@transition(
217215
field=state,

tests/testapp/tests/test_admin.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -194,15 +194,6 @@ def setUp(self):
194194
def test_protected_fields_are_readonly(self):
195195
assert self.model_admin.get_readonly_fields(request=self.request) == ("state",)
196196

197-
def test_get_fsm_redirect_url(self):
198-
assert (
199-
self.model_admin.get_fsm_redirect_url(
200-
request=RequestFactory().get(path="/path"),
201-
obj=None,
202-
)
203-
== "/path"
204-
)
205-
206197
# Execution
207198
def test_execute_fsm_transition_falls_back_to_plain_call(self) -> None:
208199
called: dict[str, str] = {}

0 commit comments

Comments
 (0)