Skip to content

Commit 79002ca

Browse files
committed
convert to onnx
Signed-off-by: Pranav Doma <pranavreddy2327@gmail.com>
1 parent 2b95829 commit 79002ca

1 file changed

Lines changed: 45 additions & 16 deletions

File tree

Models/exports/convert_pytorch_to_onnx.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
import onnx
66
from argparse import ArgumentParser
77
import sys
8-
sys.path.append('..')
8+
from pathlib import Path
9+
_REPO_ROOT = Path(__file__).resolve().parents[2]
10+
sys.path.insert(0, str(_REPO_ROOT))
911
from Models.model_components.scene_seg_network import SceneSegNetwork
1012
from Models.model_components.scene_3d_network import Scene3DNetwork
1113
from Models.model_components.domain_seg_network import DomainSegNetwork
1214
from Models.model_components.auto_speed.auto_speed_network import AutoSpeedNetwork
1315
from Models.model_components.ego_lanes_network import EgoLanesNetwork
1416
from Models.model_components.auto_steer.auto_steer_network import AutoSteerNetwork
17+
from Models.model_components.autodrive.autodrive_network import AutoDrive
1518
def main():
1619

1720
# Argument parser for data root path and save path
@@ -59,6 +62,9 @@ def main():
5962
elif (model_name == 'AutoSteer'):
6063
print('Processing AutoSteer Network')
6164
model = AutoSteerNetwork().build_model(version='n')
65+
elif (model_name == 'AutoDrive'):
66+
print('Processing AutoDrive Network')
67+
model = AutoDrive()
6268
else:
6369
raise Exception("Model name not specified correctly, please check")
6470

@@ -80,33 +86,56 @@ def main():
8086
model = model.to(device)
8187
model = model.eval()
8288

83-
# Fake input data (AutoSpeed uses 1024x512)
89+
# Fake input data
8490
if model_name == 'AutoSpeed':
8591
input_shape=(1, 3, 512, 1024)
8692
elif model_name == 'AutoSteer':
8793
input_shape=(1, 3, 512, 1024)
94+
elif model_name == 'AutoDrive':
95+
input_shape=(1, 3, 512, 1024)
8896
else:
8997
input_shape=(1, 3, 320, 640)
90-
input_data = torch.randn(input_shape)
91-
input_data = input_data.to(device)
98+
input_data = torch.randn(input_shape).to(device)
99+
input_data_prev = torch.randn(input_shape).to(device)
92100

93101
# Test inference
94102
print('Testing inference')
95-
_ = model(input_data)
103+
if model_name == 'AutoDrive':
104+
_ = model(input_data_prev, input_data)
105+
else:
106+
_ = model(input_data)
96107

97108
# Export FP32 model to onnx
98109
print('Converting model to ONNX at FP32 and exporting')
99-
torch.onnx.export(model, # model
100-
input_data, # model input
101-
onnx_model_path, # path
102-
export_params=True, # store the trained parameter weights inside the model file
103-
opset_version=18, # the ONNX version to export the model to
104-
do_constant_folding=True, # constant folding for optimization
105-
input_names = ['input'], # input names
106-
output_names = ['output'], # output names
107-
dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes
108-
'output' : {0 : 'batch_size'}},
109-
external_data=False)
110+
if model_name == 'AutoDrive':
111+
torch.onnx.export(model, # model
112+
(input_data_prev, input_data), # model input tuple
113+
onnx_model_path, # path
114+
export_params=True, # store the trained parameter weights inside the model file
115+
opset_version=18, # the ONNX version to export the model to
116+
do_constant_folding=True, # constant folding for optimization
117+
input_names=['image_prev', 'image_curr'], # input names
118+
output_names=['distance', 'curvature', 'flag_logit'], # output names
119+
dynamic_axes={
120+
'image_prev': {0: 'batch_size'},
121+
'image_curr': {0: 'batch_size'},
122+
'distance': {0: 'batch_size'},
123+
'curvature': {0: 'batch_size'},
124+
'flag_logit': {0: 'batch_size'},
125+
},
126+
external_data=False)
127+
else:
128+
torch.onnx.export(model, # model
129+
input_data, # model input
130+
onnx_model_path, # path
131+
export_params=True, # store the trained parameter weights inside the model file
132+
opset_version=18, # the ONNX version to export the model to
133+
do_constant_folding=True, # constant folding for optimization
134+
input_names = ['input'], # input names
135+
output_names = ['output'], # output names
136+
dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes
137+
'output' : {0 : 'batch_size'}},
138+
external_data=False)
110139

111140
# Run checks on exported FP32 ONNX network
112141
ONNX_network = onnx.load(onnx_model_path)

0 commit comments

Comments
 (0)