55import onnx
66from argparse import ArgumentParser
77import sys
8- sys .path .append ('..' )
8+ from pathlib import Path
9+ _REPO_ROOT = Path (__file__ ).resolve ().parents [2 ]
10+ sys .path .insert (0 , str (_REPO_ROOT ))
911from Models .model_components .scene_seg_network import SceneSegNetwork
1012from Models .model_components .scene_3d_network import Scene3DNetwork
1113from Models .model_components .domain_seg_network import DomainSegNetwork
1214from Models .model_components .auto_speed .auto_speed_network import AutoSpeedNetwork
1315from Models .model_components .ego_lanes_network import EgoLanesNetwork
1416from Models .model_components .auto_steer .auto_steer_network import AutoSteerNetwork
17+ from Models .model_components .autodrive .autodrive_network import AutoDrive
1518def main ():
1619
1720 # Argument parser for data root path and save path
@@ -59,6 +62,9 @@ def main():
5962 elif (model_name == 'AutoSteer' ):
6063 print ('Processing AutoSteer Network' )
6164 model = AutoSteerNetwork ().build_model (version = 'n' )
65+ elif (model_name == 'AutoDrive' ):
66+ print ('Processing AutoDrive Network' )
67+ model = AutoDrive ()
6268 else :
6369 raise Exception ("Model name not specified correctly, please check" )
6470
@@ -80,33 +86,56 @@ def main():
8086 model = model .to (device )
8187 model = model .eval ()
8288
83- # Fake input data (AutoSpeed uses 1024x512)
89+ # Fake input data
8490 if model_name == 'AutoSpeed' :
8591 input_shape = (1 , 3 , 512 , 1024 )
8692 elif model_name == 'AutoSteer' :
8793 input_shape = (1 , 3 , 512 , 1024 )
94+ elif model_name == 'AutoDrive' :
95+ input_shape = (1 , 3 , 512 , 1024 )
8896 else :
8997 input_shape = (1 , 3 , 320 , 640 )
90- input_data = torch .randn (input_shape )
91- input_data = input_data .to (device )
98+ input_data = torch .randn (input_shape ). to ( device )
99+ input_data_prev = torch . randn ( input_shape ) .to (device )
92100
93101 # Test inference
94102 print ('Testing inference' )
95- _ = model (input_data )
103+ if model_name == 'AutoDrive' :
104+ _ = model (input_data_prev , input_data )
105+ else :
106+ _ = model (input_data )
96107
97108 # Export FP32 model to onnx
98109 print ('Converting model to ONNX at FP32 and exporting' )
99- torch .onnx .export (model , # model
100- input_data , # model input
101- onnx_model_path , # path
102- export_params = True , # store the trained parameter weights inside the model file
103- opset_version = 18 , # the ONNX version to export the model to
104- do_constant_folding = True , # constant folding for optimization
105- input_names = ['input' ], # input names
106- output_names = ['output' ], # output names
107- dynamic_axes = {'input' : {0 : 'batch_size' }, # variable length axes
108- 'output' : {0 : 'batch_size' }},
109- external_data = False )
110+ if model_name == 'AutoDrive' :
111+ torch .onnx .export (model , # model
112+ (input_data_prev , input_data ), # model input tuple
113+ onnx_model_path , # path
114+ export_params = True , # store the trained parameter weights inside the model file
115+ opset_version = 18 , # the ONNX version to export the model to
116+ do_constant_folding = True , # constant folding for optimization
117+ input_names = ['image_prev' , 'image_curr' ], # input names
118+ output_names = ['distance' , 'curvature' , 'flag_logit' ], # output names
119+ dynamic_axes = {
120+ 'image_prev' : {0 : 'batch_size' },
121+ 'image_curr' : {0 : 'batch_size' },
122+ 'distance' : {0 : 'batch_size' },
123+ 'curvature' : {0 : 'batch_size' },
124+ 'flag_logit' : {0 : 'batch_size' },
125+ },
126+ external_data = False )
127+ else :
128+ torch .onnx .export (model , # model
129+ input_data , # model input
130+ onnx_model_path , # path
131+ export_params = True , # store the trained parameter weights inside the model file
132+ opset_version = 18 , # the ONNX version to export the model to
133+ do_constant_folding = True , # constant folding for optimization
134+ input_names = ['input' ], # input names
135+ output_names = ['output' ], # output names
136+ dynamic_axes = {'input' : {0 : 'batch_size' }, # variable length axes
137+ 'output' : {0 : 'batch_size' }},
138+ external_data = False )
110139
111140 # Run checks on exported FP32 ONNX network
112141 ONNX_network = onnx .load (onnx_model_path )
0 commit comments