@@ -222,6 +222,7 @@ def __init__(self, domain: str = "tunnel.terrateam.dev"):
222222 self .subdomain_to_endpoint : Dict [str , str ] = {}
223223 self .pending_requests : Dict [str , asyncio .Future ] = {}
224224 self .connection_start_times : Dict [str , float ] = {}
225+ self .validated_endpoints : Dict [str , bool ] = {} # Track validation state per subdomain
225226 self .domain = domain
226227 self .lock = threading .Lock ()
227228
@@ -238,6 +239,9 @@ def disconnect(self, subdomain: str):
238239 if subdomain in self .connection_start_times :
239240 del self .connection_start_times [subdomain ]
240241
242+ if subdomain in self .validated_endpoints :
243+ del self .validated_endpoints [subdomain ]
244+
241245 if subdomain in self .subdomain_to_endpoint :
242246 del self .subdomain_to_endpoint [subdomain ]
243247
@@ -356,38 +360,8 @@ async def websocket_endpoint(websocket: WebSocket):
356360 "subdomain" : subdomain
357361 }))
358362
359- # Perform endpoint validation AFTER tunnel is established
360- validation_failed = False
361- if domain_validation_mode and domain_validator :
362- # Parse the full endpoint URL to extract host and port
363- try :
364- from urllib .parse import urlparse
365- parsed_endpoint = urlparse (local_endpoint )
366- validation_host = parsed_endpoint .hostname or 'localhost'
367- validation_port = parsed_endpoint .port or (443 if parsed_endpoint .scheme == 'https' else 80 )
368- validation_endpoint = f"{ validation_host } :{ validation_port } "
369-
370- if not await domain_validator .validate_domain (validation_endpoint ):
371- logger .warning (f" Endpoint validation failed for: { validation_endpoint } " )
372- validation_failed = True
373- except Exception as e :
374- logger .error (f" Failed to parse endpoint URL { local_endpoint } : { e } " )
375- # Continue anyway - don't block on parsing errors
376- logger .info (f" Continuing without validation due to parsing error" )
377-
378- # If validation failed, clean up and close connection
379- if validation_failed :
380- with manager .lock :
381- if subdomain in manager .active_connections :
382- del manager .active_connections [subdomain ]
383- if hostname in manager .hostname_to_subdomain :
384- del manager .hostname_to_subdomain [hostname ]
385- if subdomain in manager .subdomain_to_endpoint :
386- del manager .subdomain_to_endpoint [subdomain ]
387- if subdomain in manager .connection_start_times :
388- del manager .connection_start_times [subdomain ]
389- await websocket .close (code = 1008 , reason = "Endpoint validation failed" )
390- return
363+ # NOTE: Endpoint validation removed here - will be done on first request instead
364+ # This allows the local endpoint to start up after the tunnel is established
391365
392366 while True :
393367 try :
@@ -765,6 +739,65 @@ async def proxy_request(request: Request, path: str):
765739 except Exception as e :
766740 logger .warning (f" Failed to parse endpoint URL for revalidation { local_endpoint } : { e } " )
767741
742+ # Check if this subdomain has been validated (first request validation)
743+ if domain_validation_mode and domain_validator :
744+ with manager .lock :
745+ validated = manager .validated_endpoints .get (subdomain , False )
746+
747+ if not validated :
748+ local_endpoint = manager .get_endpoint_from_subdomain (subdomain )
749+ if local_endpoint :
750+ # Parse endpoint to get validation format
751+ try :
752+ from urllib .parse import urlparse
753+ parsed_endpoint = urlparse (local_endpoint )
754+ validation_host = parsed_endpoint .hostname or 'localhost'
755+ validation_port = parsed_endpoint .port or (443 if parsed_endpoint .scheme == 'https' else 80 )
756+ validation_endpoint = f"{ validation_host } :{ validation_port } "
757+
758+ # Skip validation for local endpoints
759+ host_part = validation_endpoint .split (':' )[0 ] if ':' in validation_endpoint else validation_endpoint
760+ is_local = (
761+ host_part in ['localhost' , '127.0.0.1' , '::1' ] or
762+ host_part .startswith ('192.168.' ) or
763+ host_part .startswith ('10.' ) or
764+ host_part .startswith ('172.' ) or
765+ host_part .startswith ('169.254.' ) or
766+ host_part .startswith ('fc00:' ) or
767+ host_part .startswith ('fe80:' )
768+ )
769+
770+ if not is_local :
771+ logger .info (f" Performing first-request validation for: { validation_endpoint } " )
772+
773+ # Validation with retry logic
774+ max_retries = 3
775+ retry_delay = 1.0
776+ validation_success = False
777+
778+ for attempt in range (max_retries + 1 ):
779+ if await domain_validator .validate_domain (validation_endpoint ):
780+ validation_success = True
781+ logger .info (f" First-request validation successful: { validation_endpoint } " )
782+ break
783+ elif attempt < max_retries :
784+ logger .warning (f" Validation attempt { attempt + 1 } /{ max_retries + 1 } failed for { validation_endpoint } " )
785+ logger .info (f" Retrying validation in { retry_delay } seconds..." )
786+ await asyncio .sleep (retry_delay )
787+ retry_delay *= 2 # Exponential backoff
788+
789+ if not validation_success :
790+ logger .error (f" First-request validation failed after { max_retries + 1 } attempts: { validation_endpoint } " )
791+ raise HTTPException (status_code = 403 , detail = "Endpoint validation failed" )
792+
793+ # Mark as validated
794+ with manager .lock :
795+ manager .validated_endpoints [subdomain ] = True
796+
797+ except Exception as e :
798+ logger .warning (f" Failed to parse endpoint URL for validation { local_endpoint } : { e } " )
799+ # Continue anyway - don't block on parsing errors
800+
768801 body = await request .body ()
769802
770803 request_data = {
0 commit comments