Skip to content

Commit 93ad1be

Browse files
committed
Add zero-kernel-dimension guard and widen ih_num/iw_num casts to i64 in ConvTranspose
1 parent d908f2f commit 93ad1be

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

rust/jstprove_circuits/src/circuit_functions/layers/conv_transpose.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ impl<C: Config, Builder: RootAPI<C>> LayerOp<C, Builder> for ConvTransposeLayer
216216
for ow in 0..out_w {
217217
for ci in 0..c_in {
218218
for ki_h in 0..kh {
219-
let ih_num = oh as i32 + pad_h_begin as i32 - (ki_h * dil_h) as i32;
219+
let ih_num = oh as i64 + pad_h_begin as i64 - (ki_h * dil_h) as i64;
220220
if ih_num < 0 || (ih_num as usize) % stride_h != 0 {
221221
continue;
222222
}
@@ -226,7 +226,7 @@ impl<C: Config, Builder: RootAPI<C>> LayerOp<C, Builder> for ConvTransposeLayer
226226
}
227227
for ki_w in 0..kw {
228228
let iw_num =
229-
ow as i32 + pad_w_begin as i32 - (ki_w * dil_w) as i32;
229+
ow as i64 + pad_w_begin as i64 - (ki_w * dil_w) as i64;
230230
if iw_num < 0 || (iw_num as usize) % stride_w != 0 {
231231
continue;
232232
}
@@ -316,6 +316,17 @@ impl<C: Config, Builder: RootAPI<C>> LayerOp<C, Builder> for ConvTransposeLayer
316316
.into());
317317
}
318318

319+
if kernel_shape.contains(&0) {
320+
return Err(LayerError::UnsupportedConfig {
321+
layer: LayerKind::ConvTranspose,
322+
msg: format!(
323+
"layer '{}': kernel_shape contains a zero dimension: {:?}",
324+
layer.name, kernel_shape
325+
),
326+
}
327+
.into());
328+
}
329+
319330
let default_zeros: Vec<u32> = vec![0; 2 * spatial_rank];
320331
let default_op: Vec<u32> = vec![0; spatial_rank];
321332
let default_ones: Vec<u32> = vec![1; spatial_rank];

0 commit comments

Comments
 (0)