modules/py_ml: Optimize threshold using Helium.

The majority of the time spent processing the model output is in the
max, threshold, and nonzero operations. Combining these together
in C using Helium provides a massive speedup for larger arrays.
Kwabena W. Agyeman 2025-07-25 19:12:51 -07:00
parent a40976e1b3
commit 77950f03e1
3 changed files with 309 additions and 19 deletions
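For context, the fused C kernel added below implements the same semantics as this pure-Python/ulab reference (a sketch only; `threshold_reference` is a hypothetical name, and the dequantization `(x - zp) * s` mirrors what the kernel applies per lane):

    # Reference semantics of the fused call for a 2D quantized output:
    # dequantize, take the per-row max, compare against the threshold,
    # and return the indices of the passing rows. The C version fuses
    # these three ndarray passes into one predicated Helium loop.
    from ulab import numpy as np

    def threshold_reference(x, s, zp, t):
        scores = (x - zp) * s              # dequantize to float
        row_max = np.max(scores, axis=1)   # per-row maximum
        return np.nonzero(row_max > t)[0]  # rows whose max exceeds t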


@@ -867,6 +867,26 @@ static inline v128_t vadd_s32(v128_t v0, v128_t v1) {
    #endif
}

static inline v128_t vadd_n_u32(v128_t v0, uint32_t x) {
    #if (__ARM_ARCH >= 8)
    return (v128_t) vaddq_n_u32(v0.u32, x);
    #else
    return (v128_t) {
        .u32 = v0.u32 + x
    };
    #endif
}

static inline v128_t vadd_n_s32(v128_t v0, int32_t x) {
    #if (__ARM_ARCH >= 8)
    return (v128_t) vaddq_n_s32(v0.s32, x);
    #else
    return (v128_t) {
        .s32 = v0.s32 + x
    };
    #endif
}

static inline v128_t vsub_u8(v128_t v0, v128_t v1) {
    #if (__ARM_ARCH >= 8)
    return (v128_t) vsubq_u8(v0.u8, v1.u8);
@@ -923,6 +943,26 @@ static inline v128_t vsub_s16(v128_t v0, v128_t v1) {
    #endif
}

static inline v128_t vsub_n_u32(v128_t v0, uint32_t x) {
    #if (__ARM_ARCH >= 8)
    return (v128_t) vsubq_n_u32(v0.u32, x);
    #else
    return (v128_t) {
        .u32 = v0.u32 - x
    };
    #endif
}

static inline v128_t vsub_n_s32(v128_t v0, int32_t x) {
    #if (__ARM_ARCH >= 8)
    return (v128_t) vsubq_n_s32(v0.s32, x);
    #else
    return (v128_t) {
        .s32 = v0.s32 - x
    };
    #endif
}
#if (__ARM_ARCH >= 8)
#define vsli_u8(v0, v1, n) ((v128_t) vsliq_n_u8(v0.u8, v1.u8, n))
#else
@@ -1229,6 +1269,16 @@ static inline v128_t vmul_n_s32(v128_t v0, int32_t x) {
    #endif
}

static inline v128_t vmul_n_f32(v128_t v0, float32_t x) {
    #if (__ARM_ARCH >= 8)
    return (v128_t) vmulq_n_f32(v0.f32, x);
    #else
    return (v128_t) {
        .f32 = v0.f32 * x
    };
    #endif
}

static inline v128_t vmla_n_u16(v128_t v0, uint16_t x, v128_t v2) {
    #if (__ARM_ARCH >= 8)
    return (v128_t) vmlaq_n_u16(v2.u16, v0.u16, x);
@@ -1373,6 +1423,50 @@ static inline int32_t vmladava_s16(v128_t v0, v128_t v1, int32_t acc) {
    #endif
}

static inline v128_t vcvt_f32_u32(v128_t v0) {
    #if (__ARM_ARCH >= 8)
    return (v128_t) vcvtq(v0.u32);
    #else
    return (v128_t) {
        .f32 = { (float32_t) v0.u32[0] }
    };
    #endif
}

static inline v128_t vcvt_f32_s32(v128_t v0) {
    #if (__ARM_ARCH >= 8)
    return (v128_t) vcvtq(v0.s32);
    #else
    return (v128_t) {
        .f32 = { (float32_t) v0.s32[0] }
    };
    #endif
}

static inline float vminv_f32_pred(v128_t v, float min, v128_predicate_t pred) {
    #if (__ARM_ARCH >= 8)
    return vminnmvq_p_f32(min, v.f32, pred);
    #else
    if (pred > 0) {
        min = (v.f32[0] < min) ? v.f32[0] : min;
    }
    return min;
    #endif
}

static inline float vmaxv_f32_pred(v128_t v, float max, v128_predicate_t pred) {
    #if (__ARM_ARCH >= 8)
    return vmaxnmvq_p_f32(max, v.f32, pred);
    #else
    if (pred > 0) {
        max = (v.f32[0] > max) ? v.f32[0] : max;
    }
    return max;
    #endif
}

static inline v128_t vldr_u8(const uint8_t *p) {
    #if (__ARM_ARCH >= 8)
    return (v128_t) vldrbq_u8(p);
@@ -1404,6 +1498,30 @@ static inline v128_t vldr_u8_pred(const uint8_t *p, v128_predicate_t pred) {
    #endif
}

static inline v128_t vldr_u8_widen_u32_gather_pred(const uint8_t *p,
                                                   v128_t offsets,
                                                   v128_predicate_t pred) {
    #if (__ARM_ARCH >= 8)
    return (v128_t) vldrbq_gather_offset_z_u32(p, offsets.u32, pred);
    #else
    return (v128_t) {
        .u32 = { *(p + offsets.u32[0]) }
    };
    #endif
}

static inline v128_t vldr_s8_widen_s32_gather_pred(const int8_t *p,
                                                   v128_t offsets,
                                                   v128_predicate_t pred) {
    #if (__ARM_ARCH >= 8)
    return (v128_t) vldrbq_gather_offset_z_s32(p, offsets.u32, pred);
    #else
    return (v128_t) {
        .s32 = { *(p + offsets.u32[0]) }
    };
    #endif
}

static inline void vstr_u8(uint8_t *p, v128_t v0) {
    #if (__ARM_ARCH >= 8)
    vstrbq(p, v0.u8);
@@ -1455,6 +1573,30 @@ static inline v128_t vldr_u16_pred(const uint16_t *p, v128_predicate_t pred) {
    #endif
}

static inline v128_t vldr_u16_widen_u32_gather_pred(const uint16_t *p,
                                                    v128_t offsets,
                                                    v128_predicate_t pred) {
    #if (__ARM_ARCH >= 8)
    return (v128_t) vldrhq_gather_shifted_offset_z_u32(p, offsets.u32, pred);
    #else
    return (v128_t) {
        .u32 = { *(p + offsets.u32[0]) }
    };
    #endif
}

static inline v128_t vldr_s16_widen_s32_gather_pred(const int16_t *p,
                                                    v128_t offsets,
                                                    v128_predicate_t pred) {
    #if (__ARM_ARCH >= 8)
    return (v128_t) vldrhq_gather_shifted_offset_z_s32(p, offsets.u32, pred);
    #else
    return (v128_t) {
        .s32 = { *(p + offsets.u32[0]) }
    };
    #endif
}

static inline void vstr_u16(uint16_t *p, v128_t v0) {
    #if (__ARM_ARCH >= 8)
    vstrhq(p, v0.u16);
@@ -1570,6 +1712,18 @@ static inline v4x_rows_t vldr_u32_gather_pred_x4_unaligned(v4x_row_ptrs_t rowptrs
    return rows;
}

static inline v128_t vldr_f32_gather_pred(const float32_t *p,
                                          v128_t offsets,
                                          v128_predicate_t pred) {
    #if (__ARM_ARCH >= 8)
    return (v128_t) vldrwq_gather_shifted_offset_z_f32(p, offsets.u32, pred);
    #else
    return (v128_t) {
        .f32 = { *(p + offsets.u32[0]) }
    };
    #endif
}

static inline void vstr_f32_scatter(float32_t *p, v128_t offsets, v128_t v0) {
    #if (__ARM_ARCH >= 8)
    vstrwq_scatter_shifted_offset(p, offsets.u32, v0.f32);
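The scalar behavior of the new predicated gather/reduce helpers above can be modeled in plain Python (a sketch; `gather_s8_widen_s32` and `maxv_f32_pred` are illustrative names, and the zeroing of inactive lanes follows the MVE `_z` predicated-load semantics):

    # Scalar model of vldr_s8_widen_s32_gather_pred: load an int8 from each
    # active lane's element offset and widen it to int32; inactive lanes
    # (beyond the predicate's active count) read as zero.
    def gather_s8_widen_s32(buf, offsets, active):
        return [buf[o] if i < active else 0 for i, o in enumerate(offsets)]

    # Scalar model of vmaxv_f32_pred: fold only the active lanes into a
    # running scalar maximum, as vmaxnmvq_p_f32() does.
    def maxv_f32_pred(lanes, running_max, active):
        for v in lanes[:active]:
            running_max = max(running_max, v)
        return running_max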


@@ -50,6 +50,7 @@
#include "file_utils.h"
#include "py_ml.h"
#include "ulab/code/ndarray.h"
#include "simd.h"
#define IMLIB_ML_MODEL_ALIGN (OMV_CACHE_LINE_SIZE)
@@ -426,11 +427,159 @@ static MP_DEFINE_CONST_OBJ_TYPE(
    locals_dict, &py_ml_model_locals_dict
);
extern const mp_obj_type_t py_ml_nms_type;
// This function finds the maximum value in each row of a 2D ndarray and returns the indices
// of the rows where that maximum exceeds a threshold, which is the most CPU-intensive
// post-processing step. It handles both regular and transposed ndarrays, where the strides
// result in non-contiguous memory access patterns. TODO: Further performance optimizations
// are possible for contiguous int8/uint8/int16/uint16 ndarrays by processing contiguous
// rows using contiguous memory loads and integer SIMD operations.
static mp_obj_t py_ml_threshold(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
    OMV_PROFILE_START();
    enum { ARG_output_scale, ARG_output_zero_point, ARG_threshold };
    static const mp_arg_t allowed_args[] = {
        { MP_QSTR_output_scale, MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
        { MP_QSTR_output_zero_point, MP_ARG_INT, {.u_int = 0 } },
        { MP_QSTR_threshold, MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} }
    };

    if (!MP_OBJ_IS_TYPE(pos_args[0], &ulab_ndarray_type)) {
        mp_raise_msg(&mp_type_ValueError, MP_ERROR_TEXT("Expected an ndarray."));
    }

    ndarray_obj_t *input = MP_OBJ_TO_PTR(pos_args[0]);
    size_t height, width, row_stride, value_stride;

    if (input->ndim == 1) {
        height = input->shape[ULAB_MAX_DIMS - 1];
        width = 1;
        row_stride = input->strides[ULAB_MAX_DIMS - 1] / input->itemsize;
        value_stride = 1;
    } else if (input->ndim == 2) {
        height = input->shape[ULAB_MAX_DIMS - 2];
        width = input->shape[ULAB_MAX_DIMS - 1];
        row_stride = input->strides[ULAB_MAX_DIMS - 2] / input->itemsize;
        value_stride = input->strides[ULAB_MAX_DIMS - 1] / input->itemsize;
    } else {
        mp_raise_msg(&mp_type_ValueError, MP_ERROR_TEXT("Expected a 1D or 2D ndarray."));
    }

    mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
    mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);

    float output_scale = py_helper_arg_to_float(args[ARG_output_scale].u_obj, 1.0f);
    int output_zero_point = args[ARG_output_zero_point].u_int;
    float threshold = py_helper_arg_to_float(args[ARG_threshold].u_obj, 0.1f);

    fb_alloc_mark();
    uint16_t *output_array = (uint16_t *) fb_alloc(height * sizeof(uint16_t), FB_ALLOC_PREFER_SPEED);

    v128_t offsets_start = vmul_n_u32(vidup_u32(0, 1), value_stride);
    size_t offsets_inc = FLOAT32_VECTOR_SIZE * value_stride;
    size_t count = 0;
    if (input->dtype == 'f') {
        for (size_t y = 0; y < height; y++) {
            v128_t offsets = vadd_n_u32(offsets_start, y * row_stride);
            float row_max = -FLT_MAX;
            for (size_t x = 0; x < width; x += FLOAT32_VECTOR_SIZE) {
                v128_predicate_t pred = vpredicate_32(width - x);
                v128_t v = vldr_f32_gather_pred((float32_t *) input->array, offsets, pred);
                row_max = vmaxv_f32_pred(v, row_max, pred);
                offsets = vadd_n_u32(offsets, offsets_inc);
            }
            if (row_max > threshold) {
                output_array[count++] = y;
            }
        }
    } else if (input->dtype == 'b') {
        for (size_t y = 0; y < height; y++) {
            v128_t offsets = vadd_n_u32(offsets_start, y * row_stride);
            float row_max = -FLT_MAX;
            for (size_t x = 0; x < width; x += FLOAT32_VECTOR_SIZE) {
                v128_predicate_t pred = vpredicate_32(width - x);
                v128_t v = vldr_s8_widen_s32_gather_pred((int8_t *) input->array, offsets, pred);
                v = vmul_n_f32(vcvt_f32_s32(vsub_n_s32(v, output_zero_point)), output_scale);
                row_max = vmaxv_f32_pred(v, row_max, pred);
                offsets = vadd_n_u32(offsets, offsets_inc);
            }
            if (row_max > threshold) {
                output_array[count++] = y;
            }
        }
    } else if (input->dtype == 'B') {
        for (size_t y = 0; y < height; y++) {
            v128_t offsets = vadd_n_u32(offsets_start, y * row_stride);
            float row_max = -FLT_MAX;
            for (size_t x = 0; x < width; x += FLOAT32_VECTOR_SIZE) {
                v128_predicate_t pred = vpredicate_32(width - x);
                v128_t v = vldr_u8_widen_u32_gather_pred((uint8_t *) input->array, offsets, pred);
                v = vmul_n_f32(vcvt_f32_s32(vsub_n_s32(v, output_zero_point)), output_scale);
                row_max = vmaxv_f32_pred(v, row_max, pred);
                offsets = vadd_n_u32(offsets, offsets_inc);
            }
            if (row_max > threshold) {
                output_array[count++] = y;
            }
        }
    } else if (input->dtype == 'h') {
        for (size_t y = 0; y < height; y++) {
            v128_t offsets = vadd_n_u32(offsets_start, y * row_stride);
            float row_max = -FLT_MAX;
            for (size_t x = 0; x < width; x += FLOAT32_VECTOR_SIZE) {
                v128_predicate_t pred = vpredicate_32(width - x);
                v128_t v = vldr_s16_widen_s32_gather_pred((int16_t *) input->array, offsets, pred);
                v = vmul_n_f32(vcvt_f32_s32(vsub_n_s32(v, output_zero_point)), output_scale);
                row_max = vmaxv_f32_pred(v, row_max, pred);
                offsets = vadd_n_u32(offsets, offsets_inc);
            }
            if (row_max > threshold) {
                output_array[count++] = y;
            }
        }
    } else if (input->dtype == 'H') {
        for (size_t y = 0; y < height; y++) {
            v128_t offsets = vadd_n_u32(offsets_start, y * row_stride);
            float row_max = -FLT_MAX;
            for (size_t x = 0; x < width; x += FLOAT32_VECTOR_SIZE) {
                v128_predicate_t pred = vpredicate_32(width - x);
                v128_t v = vldr_u16_widen_u32_gather_pred((uint16_t *) input->array, offsets, pred);
                v = vmul_n_f32(vcvt_f32_s32(vsub_n_s32(v, output_zero_point)), output_scale);
                row_max = vmaxv_f32_pred(v, row_max, pred);
                offsets = vadd_n_u32(offsets, offsets_inc);
            }
            if (row_max > threshold) {
                output_array[count++] = y;
            }
        }
    } else {
        mp_raise_ValueError(MP_ERROR_TEXT("Unsupported dtype"));
    }
    // Copy the output array to a new ndarray.
    size_t output_shape[ULAB_MAX_DIMS] = {};
    output_shape[ULAB_MAX_DIMS - 1] = count;
    ndarray_obj_t *output = ndarray_new_dense_ndarray(1, output_shape, NDARRAY_UINT16);
    memcpy(output->array, output_array, count * sizeof(uint16_t));

    fb_alloc_free_till_mark();
    OMV_PROFILE_PRINT();
    return MP_OBJ_FROM_PTR(output);
}
static MP_DEFINE_CONST_FUN_OBJ_KW(py_ml_threshold_obj, 1, py_ml_threshold);
static const mp_rom_map_elem_t py_ml_globals_dict_table[] = {
    { MP_ROM_QSTR(MP_QSTR___name__), MP_OBJ_NEW_QSTR(MP_QSTR_ml) },
    { MP_ROM_QSTR(MP_QSTR_Model), MP_ROM_PTR(&py_ml_model_type) },
    { MP_ROM_QSTR(MP_QSTR_threshold), MP_ROM_PTR(&py_ml_threshold_obj) },
};
static MP_DEFINE_CONST_DICT(py_ml_globals_dict, py_ml_globals_dict_table);
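For reference, the new function is invoked from Python as shown in the postprocessing changes below (a sketch; `row_outputs`, `s`, and `zp` stand in for a (rows, classes) quantized score ndarray and its model.output_scale / model.output_zero_point values):

    import uml

    # Fused dequantize + per-row max + compare + nonzero in a single C call.
    # Returns a uint16 ndarray holding the indices of the rows whose
    # dequantized maximum, (raw - zp) * s, exceeds the threshold.
    indices = uml.threshold(row_outputs, s, zp, 0.4)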


@@ -26,6 +26,7 @@
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import uml
import ml.utils
from micropython import const
from ulab import numpy as np
@@ -75,17 +76,6 @@ def softmax(x):
    return e_x / np.sum(e_x, axis=1, keepdims=True)


def threshold(scores, threshold, scale, find_max=False, find_max_axis=1):
    if scale > 0:
        if find_max:
            scores = np.max(scores, axis=find_max_axis)
        return np.nonzero(scores > threshold)[0]
    else:
        if find_max:
            scores = np.min(scores, axis=find_max_axis)
        return np.nonzero(scores < threshold)[0]
class fomo_postprocess:
    def __init__(self, threshold=0.4, w_scale=1.414214, h_scale=1.414214,
                 nms_threshold=0.1, nms_sigma=0.001,
@@ -102,14 +92,13 @@ class fomo_postprocess:
        s = model.output_scale[0]
        zp = model.output_zero_point[0]
        dt = model.output_dtype[0]
        t = (self.threshold / s) + zp

        # Reshape the output to a 2D array
        row_outputs = outputs[0].reshape((oh * ow, oc))

        # Threshold all the scores
        score_indices = row_outputs[:, _FOMO_CLASSES:]
        score_indices = threshold(score_indices, t, s, find_max=True, find_max_axis=1)
        score_indices = uml.threshold(score_indices, s, zp, self.threshold)

        if not len(score_indices):
            return _NO_DETECTION
@@ -169,7 +158,7 @@ class yolo_v2_postprocess:
        # Threshold all the scores
        score_indices = row_outputs[:, _YOLO_V2_SCORE]
        score_indices = threshold(score_indices, t, s)
        score_indices = uml.threshold(score_indices, s, zp, logit(self.threshold))

        if not len(score_indices):
            return _NO_DETECTION
@@ -228,7 +217,6 @@ class yolo_v5_postprocess:
        s = model.output_scale[0]
        zp = model.output_zero_point[0]
        dt = model.output_dtype[0]
        t = (self.threshold / s) + zp

        class_count = oc - _YOLO_V5_CLASSES

        # Reshape the output to a 2D array
@@ -236,7 +224,7 @@
        # Threshold all the scores
        score_indices = row_outputs[:, _YOLO_V5_SCORE]
        score_indices = threshold(score_indices, t, s)
        score_indices = uml.threshold(score_indices, s, zp, self.threshold)

        if not len(score_indices):
            return _NO_DETECTION
@@ -273,7 +261,6 @@
        s = model.output_scale[0]
        zp = model.output_zero_point[0]
        dt = model.output_dtype[0]
        t = (self.threshold / s) + zp

        class_count = ow - _YOLO_V8_CLASSES

        # Reshape the output to a 2D array
@@ -281,7 +268,7 @@
        # Threshold all the scores
        score_indices = row_outputs[:, _YOLO_V8_CLASSES:]
        score_indices = threshold(score_indices, t, s, find_max=True, find_max_axis=1)
        score_indices = uml.threshold(score_indices, s, zp, self.threshold)

        if not len(score_indices):
            return _NO_DETECTION