1111from ldap .controls .ppolicy import PasswordPolicyControl
1212from ldap .controls .ppolicy import PasswordPolicyError
1313from ldap .controls .readentry import PostReadControl
14+ from ldappool import ConnectionManager
15+ from ldappool import StateConnector
1416
1517from canaille .app import models
1618from canaille .app .configuration import CheckResult
2729from .utils import python_attrs_to_ldap
2830
2931
32+ def _make_connector_cls (network_timeout ):
33+ """Create a StateConnector subclass that sets OPT_NETWORK_TIMEOUT."""
34+
35+ class _Connector (StateConnector ):
36+ def __init__ (self , * args , ** kwargs ):
37+ super ().__init__ (* args , ** kwargs )
38+ self .set_option (ldap .OPT_NETWORK_TIMEOUT , network_timeout )
39+
40+ return _Connector
41+
42+
3043@contextmanager
3144def ldap_connection (config ):
3245 conn = ldap .initialize (config ["CANAILLE_LDAP" ]["URI" ])
@@ -89,10 +102,22 @@ def default(self, obj):
89102
90103class LDAPBackend (Backend ):
91104 json_encoder = LDAPModelEncoder
105+ pool = None
92106
93107 def __init__ (self , config ):
94108 super ().__init__ (config )
95- self ._connection = None
109+ ldap_config = self .config ["CANAILLE_LDAP" ]
110+ LDAPBackend .pool = ConnectionManager (
111+ uri = ldap_config ["URI" ],
112+ bind = ldap_config ["BIND_DN" ],
113+ passwd = ldap_config ["BIND_PW" ],
114+ size = ldap_config ["POOL_SIZE" ],
115+ max_lifetime = ldap_config ["POOL_MAX_LIFETIME" ],
116+ retry_max = ldap_config ["POOL_RETRY_MAX" ],
117+ retry_delay = ldap_config ["POOL_RETRY_DELAY" ],
118+ timeout = ldap_config ["TIMEOUT" ],
119+ connector_cls = _make_connector_cls (ldap_config ["TIMEOUT" ]),
120+ )
96121
97122 def init_app (self , app , init_backend = None ):
98123 super ().init_app (app , init_backend )
@@ -119,43 +144,25 @@ def setup_schemas(cls, config):
119144 os .path .dirname (__file__ ) + "/schemas/oauth2-openldap.ldif" ,
120145 )
121146
122- @property
147+ @contextmanager
123148 def connection (self ):
124- if self ._connection :
125- return self ._connection
126-
149+ """Get a connection from the pool."""
127150 try :
128- self ._connection = ldap .initialize (self .config ["CANAILLE_LDAP" ]["URI" ])
129- self ._connection .set_option (
130- ldap .OPT_NETWORK_TIMEOUT ,
131- self .config ["CANAILLE_LDAP" ]["TIMEOUT" ],
132- )
133- self ._connection .simple_bind_s (
134- self .config ["CANAILLE_LDAP" ]["BIND_DN" ],
135- self .config ["CANAILLE_LDAP" ]["BIND_PW" ],
136- )
137-
151+ with self .pool .connection () as conn :
152+ yield conn
138153 except ldap .SERVER_DOWN as exc :
139154 message = _ ("Could not connect to the LDAP server '{uri}'" ).format (
140155 uri = self .config ["CANAILLE_LDAP" ]["URI" ]
141156 )
142157 logging .error (message )
143158 raise ConfigurationException (message ) from exc
144-
145159 except ldap .INVALID_CREDENTIALS as exc :
146160 message = _ ("LDAP authentication failed with user '{user}'" ).format (
147161 user = self .config ["CANAILLE_LDAP" ]["BIND_DN" ]
148162 )
149163 logging .error (message )
150164 raise ConfigurationException (message ) from exc
151165
152- return self ._connection
153-
154- def teardown (self ) -> None :
155- if self ._connection : # pragma: no branch
156- self ._connection .unbind_s ()
157- self ._connection = None
158-
159166 @classmethod
160167 def check_network_config (cls , config ):
161168 from canaille .app import models
@@ -274,12 +281,12 @@ def gettext(x):
274281 return result , message
275282
276283 def set_user_password (self , user , password ) -> None :
277- conn = self .connection
278- conn .passwd_s (
279- user .dn ,
280- None ,
281- password .encode ("utf-8" ),
282- )
284+ with self .connection () as conn :
285+ conn .passwd_s (
286+ user .dn ,
287+ None ,
288+ password .encode ("utf-8" ),
289+ )
283290
284291 def do_query (self , model , dn = None , filter = None , * args , ** kwargs ):
285292 from .ldapobjectquery import LDAPObjectQuery
@@ -325,9 +332,10 @@ def do_query(self, model, dn=None, filter=None, *args, **kwargs):
325332 ldapfilter = f"(&{ class_filter } { arg_filter } { filter } )"
326333 base = base or f"{ model .base } ,{ model .root_dn } "
327334 try :
328- result = self .connection .search_s (
329- base , ldap .SCOPE_SUBTREE , ldapfilter or None , ["+" , "*" ]
330- )
335+ with self .connection () as conn :
336+ result = conn .search_s (
337+ base , ldap .SCOPE_SUBTREE , ldapfilter or None , ["+" , "*" ]
338+ )
331339 except ldap .NO_SUCH_OBJECT :
332340 result = []
333341 return LDAPObjectQuery (model , result )
@@ -475,9 +483,10 @@ def do_save(self, instance) -> None:
475483 (ldap .MOD_REPLACE , name , values )
476484 for name , values in formatted_changes .items ()
477485 ]
478- _ , _ , _ , [result ] = self .connection .modify_ext_s (
479- instance .dn , modlist , serverctrls = [read_post_control ]
480- )
486+ with self .connection () as conn :
487+ _ , _ , _ , [result ] = conn .modify_ext_s (
488+ instance .dn , modlist , serverctrls = [read_post_control ]
489+ )
481490
482491 # Object does not exist yet in the LDAP database
483492 else :
@@ -488,24 +497,25 @@ def do_save(self, instance) -> None:
488497 }
489498 formatted_changes = python_attrs_to_ldap (changes , null_allowed = False )
490499 modlist = [(name , values ) for name , values in formatted_changes .items ()]
491- _ , _ , _ , [result ] = self .connection .add_ext_s (
492- instance .dn , modlist , serverctrls = [read_post_control ]
493- )
500+ with self .connection () as conn :
501+ _ , _ , _ , [result ] = conn .add_ext_s (
502+ instance .dn , modlist , serverctrls = [read_post_control ]
503+ )
494504
495505 instance .exists = True
496506 instance .state = {** result .entry , ** instance .changes }
497507 instance .changes = {}
498508
499509 def do_delete (self , instance ) -> None :
500510 try :
501- self .connection .delete_s (instance .dn )
511+ with self .connection () as conn :
512+ conn .delete_s (instance .dn )
502513 except ldap .NO_SUCH_OBJECT :
503514 pass
504515
505516 def do_reload (self , instance ) -> None :
506- result = self .connection .search_s (
507- instance .dn , ldap .SCOPE_SUBTREE , None , ["+" , "*" ]
508- )
517+ with self .connection () as conn :
518+ result = conn .search_s (instance .dn , ldap .SCOPE_SUBTREE , None , ["+" , "*" ])
509519 instance .changes = {}
510520 instance .state = result [0 ][1 ]
511521
0 commit comments