# Patch summary (leading text before the "diff --git" header; not part of any hunk).
#
# Target file:
#   jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java
#
# Hunk 1 (old line 136): adds dot512(v1, offset1, v2, offset2).
#   Loads one FloatVector.SPECIES_512 vector from each operand at the given
#   offsets through fromVectorFloat (a helper that is not part of this patch),
#   multiplies them lane-wise, and returns the sum of the lanes
#   (reduceLanes(VectorOperators.ADD)). It has the same shape as the
#   neighbouring dot256/dotPreferred context lines, with the species swapped.
#
# Hunk 2 (old line 158): extends the if/else-if chain at the end of dotProduct.
#   The former unconditional "else -> dotProduct256" becomes
#   "else if (length < FloatVector.SPECIES_512.length()) -> dotProduct256",
#   and a new final "else" dispatches to dotProduct512. The blank line that
#   preceded the closing brace is removed.
#   NOTE(review): the start of dotProduct lies outside this hunk, so any
#   checks that run before this chain cannot be seen here -- confirm under
#   which conditions the new 512-bit branch is actually reached.
#
# Hunk 3 (old line 238): adds dotProduct512(v1, v1offset, v2, v2offset, length).
#   * If length equals FloatVector.SPECIES_512.length(), it delegates to dot512.
#   * Otherwise it walks i from 0 up to SPECIES_512.loopBound(length) in steps
#     of the species length, loading a block from each operand and folding it
#     into a per-lane accumulator with a.fma(b, sum); the accumulator starts
#     as FloatVector.zero(SPECIES_512).
#   * The accumulator is then reduced with reduceLanes(ADD), and the remaining
#     elements (i < length) are added one at a time using v1.get / v2.get.
#
# NOTE(review): line breaks, indentation and blank lines below were restored
# from a whitespace-collapsed copy of the patch (4-space Java indentation is
# assumed). Verify the context lines against the target file before applying.
# NOTE(review): parameter types appear as bare "VectorFloat"; if the target
# file declares type arguments on them, this copy has lost those -- confirm.
diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java
index dc18c4bf2..6be9f383c 100644
--- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java
+++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java
@@ -136,6 +136,12 @@ float dot256(VectorFloat v1, int offset1, VectorFloat v2, int offset2) {
         return a.mul(b).reduceLanes(VectorOperators.ADD);
     }

+    float dot512(VectorFloat v1, int offset1, VectorFloat v2, int offset2) {
+        var a = fromVectorFloat(FloatVector.SPECIES_512, v1, offset1);
+        var b = fromVectorFloat(FloatVector.SPECIES_512, v2, offset2);
+        return a.mul(b).reduceLanes(VectorOperators.ADD);
+    }
+
     float dotPreferred(VectorFloat v1, int offset1, VectorFloat v2, int offset2) {
         var a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, v1, offset1);
         var b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, v2, offset2);
@@ -158,9 +164,10 @@ public float dotProduct(VectorFloat v1, int v1offset, VectorFloat v2, int
             return dotProduct64(v1, v1offset, v2, v2offset, length);
         else if (length < FloatVector.SPECIES_256.length())
             return dotProduct128(v1, v1offset, v2, v2offset, length);
-        else
+        else if (length < FloatVector.SPECIES_512.length())
             return dotProduct256(v1, v1offset, v2, v2offset, length);
-
+        else
+            return dotProduct512(v1, v1offset, v2, v2offset, length);
     }

     float dotProduct64(VectorFloat v1, int v1offset, VectorFloat v2, int v2offset, int length) {
@@ -238,6 +245,31 @@ float dotProduct256(VectorFloat v1, int v1offset, VectorFloat v2, int v2of
         return res;
     }

+    float dotProduct512(VectorFloat v1, int v1offset, VectorFloat v2, int v2offset, int length) {
+
+        if (length == FloatVector.SPECIES_512.length())
+            return dot512(v1, v1offset, v2, v2offset);
+
+        final int vectorizedLength = FloatVector.SPECIES_512.loopBound(length);
+        FloatVector sum = FloatVector.zero(FloatVector.SPECIES_512);
+
+        int i = 0;
+        // Process the vectorized part
+        for (; i < vectorizedLength; i += FloatVector.SPECIES_512.length()) {
+            FloatVector a = fromVectorFloat(FloatVector.SPECIES_512, v1, v1offset + i);
+            FloatVector b = fromVectorFloat(FloatVector.SPECIES_512, v2, v2offset + i);
+            sum = a.fma(b, sum);
+        }
+
+        float res = sum.reduceLanes(VectorOperators.ADD);
+
+        // Process the tail
+        for (; i < length; ++i)
+            res += v1.get(v1offset + i) * v2.get(v2offset + i);
+
+        return res;
+    }
+
     float dotProductPreferred(VectorFloat va, int vaoffset, VectorFloat vb, int vboffset, int length) {
         if (length == FloatVector.SPECIES_PREFERRED.length())
             return dotPreferred(va, vaoffset, vb, vboffset);