⚠ This page is served via a proxy. Original site: https://github.com
This service does not collect credentials or authentication data.
Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ float dot256(VectorFloat<?> v1, int offset1, VectorFloat<?> v2, int offset2) {
return a.mul(b).reduceLanes(VectorOperators.ADD);
}

float dot512(VectorFloat<?> v1, int offset1, VectorFloat<?> v2, int offset2) {
var a = fromVectorFloat(FloatVector.SPECIES_512, v1, offset1);
var b = fromVectorFloat(FloatVector.SPECIES_512, v2, offset2);
return a.mul(b).reduceLanes(VectorOperators.ADD);
}

float dotPreferred(VectorFloat<?> v1, int offset1, VectorFloat<?> v2, int offset2) {
var a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, v1, offset1);
var b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, v2, offset2);
Expand All @@ -158,9 +164,10 @@ public float dotProduct(VectorFloat<?> v1, int v1offset, VectorFloat<?> v2, int
return dotProduct64(v1, v1offset, v2, v2offset, length);
else if (length < FloatVector.SPECIES_256.length())
return dotProduct128(v1, v1offset, v2, v2offset, length);
else
else if (length < FloatVector.SPECIES_512.length())
return dotProduct256(v1, v1offset, v2, v2offset, length);

else
return dotProduct512(v1, v1offset, v2, v2offset, length);
}

float dotProduct64(VectorFloat<?> v1, int v1offset, VectorFloat<?> v2, int v2offset, int length) {
Expand Down Expand Up @@ -238,6 +245,31 @@ float dotProduct256(VectorFloat<?> v1, int v1offset, VectorFloat<?> v2, int v2of
return res;
}

float dotProduct512(VectorFloat<?> v1, int v1offset, VectorFloat<?> v2, int v2offset, int length) {

if (length == FloatVector.SPECIES_512.length())
return dot512(v1, v1offset, v2, v2offset);

final int vectorizedLength = FloatVector.SPECIES_512.loopBound(length);
FloatVector sum = FloatVector.zero(FloatVector.SPECIES_512);

int i = 0;
// Process the vectorized part
for (; i < vectorizedLength; i += FloatVector.SPECIES_512.length()) {
FloatVector a = fromVectorFloat(FloatVector.SPECIES_512, v1, v1offset + i);
FloatVector b = fromVectorFloat(FloatVector.SPECIES_512, v2, v2offset + i);
sum = a.fma(b, sum);
}

float res = sum.reduceLanes(VectorOperators.ADD);

// Process the tail
for (; i < length; ++i)
res += v1.get(v1offset + i) * v2.get(v2offset + i);

return res;
}

float dotProductPreferred(VectorFloat<?> va, int vaoffset, VectorFloat<?> vb, int vboffset, int length) {
if (length == FloatVector.SPECIES_PREFERRED.length())
return dotPreferred(va, vaoffset, vb, vboffset);
Expand Down