Skip to content

Commit 00dd1b7

Browse files
[WASM] implement mul_u8x16 and mul_i8x16 (#12)
`mul_u8x16` and `mul_i8x16` are implemented as truncating (wrapping) multiplication: only the low 8 bits of each lane's product are kept, matching the NEON implementation.
1 parent b07fbff commit 00dd1b7

File tree

3 files changed

+56
-27
lines changed

3 files changed

+56
-27
lines changed

fearless_simd/src/generated/wasm.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,10 @@ impl Simd for WasmSimd128 {
174174
}
175175
#[inline(always)]
176176
// Lane-wise multiply of two i8x16 vectors; per the commit description the result is truncating.
fn mul_i8x16(self, a: i8x16<Self>, b: i8x16<Self>) -> i8x16<Self> {
177-
todo!()
177+
// Widening multiply of the low half of the lanes: each 8-bit pair yields a 16-bit product.
let low = i16x8_extmul_low_i8x16(a.into(), b.into());
178+
// Same widening multiply for the high half of the lanes.
let high = i16x8_extmul_high_i8x16(a.into(), b.into());
179+
// Shuffle indices 0..=14 pick bytes from `low`, 16..=30 pick bytes from `high`.
// Only even byte indices are selected, i.e. one byte out of every 16-bit product
// (the low byte, given WASM's little-endian lane layout), which truncates each product to 8 bits.
u8x16_shuffle::<0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30>(low, high)
180+
.simd_into(self)
178181
}
179182
#[inline(always)]
180183
fn and_i8x16(self, a: i8x16<Self>, b: i8x16<Self>) -> i8x16<Self> {
@@ -265,7 +268,10 @@ impl Simd for WasmSimd128 {
265268
}
266269
#[inline(always)]
267270
// Lane-wise multiply of two u8x16 vectors; per the commit description the result is truncating.
fn mul_u8x16(self, a: u8x16<Self>, b: u8x16<Self>) -> u8x16<Self> {
268-
todo!()
271+
// Widening unsigned multiply of the low half of the lanes: each 8-bit pair yields a 16-bit product.
let low = u16x8_extmul_low_u8x16(a.into(), b.into());
272+
// Same widening multiply for the high half of the lanes.
let high = u16x8_extmul_high_u8x16(a.into(), b.into());
273+
// Shuffle indices 0..=14 pick bytes from `low`, 16..=30 pick bytes from `high`.
// Only even byte indices are selected, i.e. one byte out of every 16-bit product
// (the low byte, given WASM's little-endian lane layout), which truncates each product to 8 bits.
u8x16_shuffle::<0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30>(low, high)
274+
.simd_into(self)
269275
}
270276
#[inline(always)]
271277
fn and_u8x16(self, a: u8x16<Self>, b: u8x16<Self>) -> u8x16<Self> {

fearless_simd_gen/src/mk_wasm.rs

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -107,33 +107,36 @@ fn mk_simd_impl(level: Level) -> TokenStream {
107107
}
108108
OpSig::Binary => {
109109
let args = [quote! { a.into() }, quote! { b.into() }];
110-
if method == "mul"
111-
&& (vec_ty
112-
== (&VecType {
113-
scalar: ScalarType::Unsigned,
114-
scalar_bits: 8,
115-
len: 16,
116-
})
117-
|| vec_ty
118-
== (&VecType {
119-
scalar: ScalarType::Int,
120-
scalar_bits: 8,
121-
len: 16,
122-
}))
123-
{
124-
quote! {
125-
#[inline(always)]
126-
fn #method_ident(self, a: #ty<Self>, b: #ty<Self>) -> #ret_ty {
127-
// TODO: WASM doesn't have `i8x16_mul` or `u8x16_mul`.
128-
todo!()
110+
match method {
111+
"mul" if vec_ty.scalar_bits == 8 && vec_ty.len == 16 => {
112+
let (extmul_low, extmul_high) = match vec_ty.scalar {
113+
ScalarType::Unsigned => (
114+
quote! { u16x8_extmul_low_u8x16 },
115+
quote! { u16x8_extmul_high_u8x16 },
116+
),
117+
ScalarType::Int => (
118+
quote! { i16x8_extmul_low_i8x16 },
119+
quote! { i16x8_extmul_high_i8x16 },
120+
),
121+
_ => unreachable!(),
122+
};
123+
124+
quote! {
125+
#[inline(always)]
126+
fn #method_ident(self, a: #ty<Self>, b: #ty<Self>) -> #ret_ty {
127+
let low = #extmul_low(a.into(), b.into());
128+
let high = #extmul_high(a.into(), b.into());
129+
u8x16_shuffle::<0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30>(low, high).simd_into(self)
130+
}
129131
}
130132
}
131-
} else {
132-
let expr = Wasm.expr(method, vec_ty, &args);
133-
quote! {
134-
#[inline(always)]
135-
fn #method_ident(self, a: #ty<Self>, b: #ty<Self>) -> #ret_ty {
136-
#expr.simd_into(self)
133+
_ => {
134+
let expr = Wasm.expr(method, vec_ty, &args);
135+
quote! {
136+
#[inline(always)]
137+
fn #method_ident(self, a: #ty<Self>, b: #ty<Self>) -> #ret_ty {
138+
#expr.simd_into(self)
139+
}
137140
}
138141
}
139142
}

fearless_simd_tests/tests/wasm.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,26 @@ test_wasm_simd_parity! {
6060
}
6161
}
6262

63+
// Exercises `u8x16 * u8x16` through the `test_wasm_simd_parity!` macro.
// NOTE(review): the macro is defined elsewhere; presumably it compares the WASM
// result against a reference implementation — confirm against its definition.
test_wasm_simd_parity! {
64+
fn mul_u8x16() {
65+
|s| -> [u8; 16] {
66+
// Lane values 0..=15, multiplied lane-wise by 0, 1, 2, 3 (four lanes each).
let a = u8x16::from_slice(s, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
67+
let b = u8x16::from_slice(s, &[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]);
68+
// NOTE(review): the largest product here is 15 * 3 = 45, so no lane exceeds
// `u8::MAX`; the truncating (overflow) behavior is not exercised by these inputs.
(a * b).into()
69+
}
70+
}
71+
}
72+
73+
// Exercises `i8x16 * i8x16` through the `test_wasm_simd_parity!` macro.
// NOTE(review): the macro is defined elsewhere; presumably it compares the WASM
// result against a reference implementation — confirm against its definition.
test_wasm_simd_parity! {
74+
fn mul_i8x16() {
75+
|s| -> [i8; 16] {
76+
// NOTE(review): for integers `-0` is the same value as `0`, so the lanes below
// only contain 0, 3, and -3. The products are therefore 0, 9, or -9: sign
// combinations are covered, but overflow/truncation is not.
let a = i8x16::from_slice(s, &[0, -0, 3, -3, 0, -0, 3, -3, 0, -0, 3, -3, 0, -0, 3, -3]);
77+
let b = i8x16::from_slice(s, &[0, 0, 0, 0, -0, -0, -0, -0, 3, 3, 3, 3, -3, -3, -3, -3]);
78+
(a * b).into()
79+
}
80+
}
81+
}
82+
6383
test_wasm_simd_parity! {
6484
fn splat_f32x4() {
6585
|s| -> [f32; 4] {

0 commit comments

Comments
 (0)