Skip to content

Commit de34b9f

Browse files
authored
Use PSRAM for internal buffers if available (#1)
* Add support for external PSRAM
* Don't cast the malloc pointer
* Fall back to internal memory when PSRAM allocation fails
* Bump version
1 parent 7dfb2ae commit de34b9f

5 files changed

Lines changed: 225 additions & 70 deletions

File tree

library.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "ESPMicroSpeechFeatures",
3-
"version": "1.0.0",
3+
"version": "1.1.0",
44
"description": "Generates TensorFlow micro spectrogram features from audio samples",
55
"keywords": "tensorflow, spectrogram, audio",
66
"repository":

src/fft_util.c

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,51 +17,84 @@ limitations under the License.
1717

1818
#include <stdio.h>
1919

20+
#ifdef USE_ESP32
21+
#include <esp_heap_caps.h>
22+
#endif
23+
2024
#include "kiss_fftr.h"
2125

22-
int FftPopulateState(struct FftState *state, size_t input_size) {
26+
int FftPopulateState(struct FftState *state, size_t input_size)
27+
{
2328
state->input_size = input_size;
2429
state->fft_size = 1;
25-
while (state->fft_size < state->input_size) {
30+
while (state->fft_size < state->input_size)
31+
{
2632
state->fft_size <<= 1;
2733
}
2834

29-
state->input = (int16_t *) (malloc(state->fft_size * sizeof(*state->input)));
30-
if (state->input == NULL) {
35+
#ifdef USE_ESP32
36+
state->input =
37+
heap_caps_malloc(state->fft_size * sizeof(*state->input), MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
38+
if (state->input == NULL)
39+
#endif
40+
{
41+
state->input = malloc(state->fft_size * sizeof(*state->input));
42+
}
43+
if (state->input == NULL)
44+
{
3145
fprintf(stderr, "Failed to alloc fft input buffer\n");
3246
return 0;
3347
}
3448

35-
// Cast to int32_t instead of the original complex_int16_t, easy way to fix issues converting the original file into C (may not be the best/proper way!)
36-
state->output = (int32_t *) (malloc((state->fft_size / 2 + 1) * sizeof(*state->output) * 2));
37-
if (state->output == NULL) {
49+
#ifdef USE_ESP32
50+
state->output = heap_caps_malloc((state->fft_size / 2 + 1) * sizeof(*state->output) * 2,
51+
MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
52+
if (state->output == NULL)
53+
#endif
54+
{
55+
state->output = malloc((state->fft_size / 2 + 1) * sizeof(*state->output) * 2);
56+
}
57+
if (state->output == NULL)
58+
{
3859
fprintf(stderr, "Failed to alloc fft output buffer\n");
3960
return 0;
4061
}
4162

4263
// Ask kissfft how much memory it wants.
4364
size_t scratch_size = 0;
4465
kiss_fftr_cfg kfft_cfg = kiss_fftr_alloc(state->fft_size, 0, NULL, &scratch_size);
45-
if (kfft_cfg != NULL) {
66+
if (kfft_cfg != NULL)
67+
{
4668
fprintf(stderr, "Kiss memory sizing failed.\n");
4769
return 0;
4870
}
49-
state->scratch = malloc(scratch_size);
50-
if (state->scratch == NULL) {
71+
72+
#ifdef USE_ESP32
73+
state->scratch = heap_caps_malloc(scratch_size, MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
74+
if (state->scratch == NULL)
75+
#endif
76+
{
77+
state->scratch = malloc(scratch_size);
78+
}
79+
80+
if (state->scratch == NULL)
81+
{
5182
fprintf(stderr, "Failed to alloc fft scratch buffer\n");
5283
return 0;
5384
}
5485
state->scratch_size = scratch_size;
5586
// Let kissfft configure the scratch space we just allocated
5687
kfft_cfg = kiss_fftr_alloc(state->fft_size, 0, state->scratch, &scratch_size);
57-
if (kfft_cfg != state->scratch) {
88+
if (kfft_cfg != state->scratch)
89+
{
5890
fprintf(stderr, "Kiss memory preallocation strategy failed.\n");
5991
return 0;
6092
}
6193
return 1;
6294
}
6395

64-
void FftFreeStateContents(struct FftState *state) {
96+
void FftFreeStateContents(struct FftState *state)
97+
{
6598
free(state->input);
6699
free(state->output);
67100
free(state->scratch);

src/filterbank_util.c

Lines changed: 104 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
Modifications copyright 2024 Kevin Ahrendt.
23
34
Licensed under the Apache License, Version 2.0 (the "License");
45
you may not use this file except in compliance with the License.
@@ -18,10 +19,15 @@ limitations under the License.
1819
#include <math.h>
1920
#include <stdio.h>
2021

22+
#ifdef USE_ESP32
23+
#include <esp_heap_caps.h>
24+
#endif
25+
2126
#define kFilterbankIndexAlignment 4
2227
#define kFilterbankChannelBlockSize 4
2328

24-
void FilterbankFillConfigWithDefaults(struct FilterbankConfig *config) {
29+
void FilterbankFillConfigWithDefaults(struct FilterbankConfig *config)
30+
{
2531
config->num_channels = 32;
2632
config->lower_band_limit = 125.0f;
2733
config->upper_band_limit = 7500.0f;
@@ -31,46 +37,101 @@ void FilterbankFillConfigWithDefaults(struct FilterbankConfig *config) {
3137
static float FreqToMel(float freq) { return 1127.0f * log1pf(freq / 700.0f); }
3238

3339
// Writes `num_channels` evenly spaced mel-scale values into
// `center_frequencies`.
//
// The span between the mel values of the lower and upper frequency limits is
// divided into `num_channels` equal steps; entry k receives the lower mel
// value plus (k + 1) steps, so the last entry equals the mel value of the
// upper limit. The limits must satisfy 0 <= lower < upper (checked with
// assert). `center_frequencies` must have room for `num_channels` floats.
static void CalculateCenterFrequencies(const int num_channels, const float lower_frequency_limit,
                                       const float upper_frequency_limit, float *center_frequencies)
{
  assert(lower_frequency_limit >= 0.0f);
  assert(upper_frequency_limit > lower_frequency_limit);

  const float mel_floor = FreqToMel(lower_frequency_limit);
  const float mel_ceiling = FreqToMel(upper_frequency_limit);
  const float mel_range = mel_ceiling - mel_floor;
  const float mel_step = mel_range / ((float)num_channels);

  for (int band = 1; band <= num_channels; ++band)
  {
    center_frequencies[band - 1] = mel_floor + (mel_step * band);
  }
}
4755

48-
static void QuantizeFilterbankWeights(const float float_weight, int16_t *weight, int16_t *unweight) {
56+
static void QuantizeFilterbankWeights(const float float_weight, int16_t *weight, int16_t *unweight)
57+
{
4958
*weight = floorf(float_weight * (1 << kFilterbankBits) + 0.5f);
5059
*unweight = floorf((1.0f - float_weight) * (1 << kFilterbankBits) + 0.5f);
5160
}
5261

5362
int FilterbankPopulateState(const struct FilterbankConfig *config, struct FilterbankState *state, int sample_rate,
54-
int spectrum_size) {
63+
int spectrum_size)
64+
{
5565
state->num_channels = config->num_channels;
5666
const int num_channels_plus_1 = config->num_channels + 1;
5767

5868
// How should we align things to index counts given the byte alignment?
5969
const int index_alignment =
6070
(kFilterbankIndexAlignment < sizeof(int16_t) ? 1 : kFilterbankIndexAlignment / sizeof(int16_t));
6171

62-
state->channel_frequency_starts = (int16_t *) malloc(num_channels_plus_1 * sizeof(*state->channel_frequency_starts));
63-
state->channel_weight_starts = (int16_t *) malloc(num_channels_plus_1 * sizeof(*state->channel_weight_starts));
64-
state->channel_widths = (int16_t *) malloc(num_channels_plus_1 * sizeof(*state->channel_widths));
65-
state->work = (uint64_t *) malloc(num_channels_plus_1 * sizeof(*state->work));
72+
#ifdef USE_ESP32
73+
state->channel_frequency_starts = heap_caps_malloc(num_channels_plus_1 * sizeof(*state->channel_frequency_starts), MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
74+
if (state->channel_frequency_starts == NULL)
75+
#endif
76+
{
77+
state->channel_frequency_starts = malloc(num_channels_plus_1 * sizeof(*state->channel_frequency_starts));
78+
}
79+
80+
#ifdef USE_ESP32
81+
state->channel_weight_starts = heap_caps_malloc(num_channels_plus_1 * sizeof(*state->channel_weight_starts), MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
82+
if (state->channel_weight_starts == NULL)
83+
#endif
84+
{
85+
state->channel_weight_starts = malloc(num_channels_plus_1 * sizeof(*state->channel_weight_starts));
86+
}
87+
88+
#ifdef USE_ESP32
89+
state->channel_widths = heap_caps_malloc(num_channels_plus_1 * sizeof(*state->channel_widths), MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
90+
if (state->channel_widths == NULL)
91+
#endif
92+
{
93+
state->channel_widths = malloc(num_channels_plus_1 * sizeof(*state->channel_widths));
94+
}
95+
96+
#ifdef USE_ESP32
97+
state->work = heap_caps_malloc(num_channels_plus_1 * sizeof(*state->work), MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
98+
if (state->work == NULL)
99+
#endif
100+
{
101+
state->work = malloc(num_channels_plus_1 * sizeof(*state->work));
102+
}
103+
104+
float *center_mel_freqs = NULL;
105+
#ifdef USE_ESP32
106+
center_mel_freqs = heap_caps_malloc(num_channels_plus_1 * sizeof(*center_mel_freqs), MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
107+
if (center_mel_freqs == NULL)
108+
#endif
109+
{
110+
center_mel_freqs = malloc(num_channels_plus_1 * sizeof(*center_mel_freqs));
111+
}
66112

67-
float *center_mel_freqs = (float *) malloc(num_channels_plus_1 * sizeof(*center_mel_freqs));
68-
int16_t *actual_channel_starts = (int16_t *) malloc(num_channels_plus_1 * sizeof(*actual_channel_starts));
69-
int16_t *actual_channel_widths = (int16_t *) malloc(num_channels_plus_1 * sizeof(*actual_channel_widths));
113+
int16_t *actual_channel_starts = NULL;
114+
#ifdef USE_ESP32
115+
actual_channel_starts = heap_caps_malloc(num_channels_plus_1 * sizeof(*actual_channel_starts), MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
116+
if (actual_channel_starts == NULL)
117+
#endif
118+
{
119+
actual_channel_starts = malloc(num_channels_plus_1 * sizeof(*actual_channel_starts));
120+
}
121+
122+
int16_t *actual_channel_widths = NULL;
123+
#ifdef USE_ESP32
124+
actual_channel_widths = heap_caps_malloc(num_channels_plus_1 * sizeof(*actual_channel_widths), MALLOC_CAP_SPIRAM | MALLOC_CAP_8BIT);
125+
if (actual_channel_widths == NULL)
126+
#endif
127+
{
128+
actual_channel_widths = malloc(num_channels_plus_1 * sizeof(*actual_channel_widths));
129+
}
70130

71131
if (state->channel_frequency_starts == NULL || state->channel_weight_starts == NULL ||
72132
state->channel_widths == NULL || center_mel_freqs == NULL || actual_channel_starts == NULL ||
73-
actual_channel_widths == NULL) {
133+
actual_channel_widths == NULL)
134+
{
74135
free(center_mel_freqs);
75136
free(actual_channel_starts);
76137
free(actual_channel_widths);
@@ -81,9 +142,9 @@ int FilterbankPopulateState(const struct FilterbankConfig *config, struct Filter
81142
CalculateCenterFrequencies(num_channels_plus_1, config->lower_band_limit, config->upper_band_limit, center_mel_freqs);
82143

83144
// Always exclude DC.
84-
const float hz_per_sbin = 0.5f * sample_rate / ((float) spectrum_size - 1);
145+
const float hz_per_sbin = 0.5f * sample_rate / ((float)spectrum_size - 1);
85146
state->start_index = 1.5f + config->lower_band_limit / hz_per_sbin;
86-
state->end_index = 0; // Initialized to zero here, but actually set below.
147+
state->end_index = 0; // Initialized to zero here, but actually set below.
87148

88149
// For each channel, we need to figure out what frequencies belong to it, and
89150
// how much padding we need to add so that we can efficiently multiply the
@@ -96,18 +157,21 @@ int FilterbankPopulateState(const struct FilterbankConfig *config, struct Filter
96157
int needs_zeros = 0;
97158

98159
int chan;
99-
for (chan = 0; chan < num_channels_plus_1; ++chan) {
160+
for (chan = 0; chan < num_channels_plus_1; ++chan)
161+
{
100162
// Keep jumping frequencies until we overshoot the bound on this channel.
101163
int freq_index = chan_freq_index_start;
102-
while (FreqToMel((freq_index) *hz_per_sbin) <= center_mel_freqs[chan]) {
164+
while (FreqToMel((freq_index)*hz_per_sbin) <= center_mel_freqs[chan])
165+
{
103166
++freq_index;
104167
}
105168

106169
const int width = freq_index - chan_freq_index_start;
107170
actual_channel_starts[chan] = chan_freq_index_start;
108171
actual_channel_widths[chan] = width;
109172

110-
if (width == 0) {
173+
if (width == 0)
174+
{
111175
// This channel doesn't actually get anything from the frequencies, it's
112176
// always zero. We need then to insert some 'zero' weights into the
113177
// output, and just redirect this channel to do a single multiplication at
@@ -117,15 +181,19 @@ int FilterbankPopulateState(const struct FilterbankConfig *config, struct Filter
117181
state->channel_frequency_starts[chan] = 0;
118182
state->channel_weight_starts[chan] = 0;
119183
state->channel_widths[chan] = kFilterbankChannelBlockSize;
120-
if (!needs_zeros) {
184+
if (!needs_zeros)
185+
{
121186
needs_zeros = 1;
122187
int j;
123-
for (j = 0; j < chan; ++j) {
188+
for (j = 0; j < chan; ++j)
189+
{
124190
state->channel_weight_starts[j] += kFilterbankChannelBlockSize;
125191
}
126192
weight_index_start += kFilterbankChannelBlockSize;
127193
}
128-
} else {
194+
}
195+
else
196+
{
129197
// How far back do we need to go to ensure that we have the proper
130198
// alignment?
131199
const int aligned_start = (chan_freq_index_start / index_alignment) * index_alignment;
@@ -143,11 +211,12 @@ int FilterbankPopulateState(const struct FilterbankConfig *config, struct Filter
143211
// Allocate the two arrays to store the weights - weight_index_start contains
144212
// the index of what would be the next set of weights that we would need to
145213
// add, so that's how many weights we need to allocate.
146-
state->weights = (int16_t *) calloc(weight_index_start, sizeof(*state->weights));
147-
state->unweights = (int16_t *) calloc(weight_index_start, sizeof(*state->unweights));
214+
state->weights = (int16_t *)calloc(weight_index_start, sizeof(*state->weights));
215+
state->unweights = (int16_t *)calloc(weight_index_start, sizeof(*state->unweights));
148216

149217
// If the alloc failed, we also need to nuke the arrays.
150-
if (state->weights == NULL || state->unweights == NULL) {
218+
if (state->weights == NULL || state->unweights == NULL)
219+
{
151220
free(center_mel_freqs);
152221
free(actual_channel_starts);
153222
free(actual_channel_widths);
@@ -159,38 +228,43 @@ int FilterbankPopulateState(const struct FilterbankConfig *config, struct Filter
159228
// zero, we only need to fill in the weights that correspond to some frequency
160229
// for a channel.
161230
const float mel_low = FreqToMel(config->lower_band_limit);
162-
for (chan = 0; chan < num_channels_plus_1; ++chan) {
231+
for (chan = 0; chan < num_channels_plus_1; ++chan)
232+
{
163233
int frequency = actual_channel_starts[chan];
164234
const int num_frequencies = actual_channel_widths[chan];
165235
const int frequency_offset = frequency - state->channel_frequency_starts[chan];
166236
const int weight_start = state->channel_weight_starts[chan];
167237
const float denom_val = (chan == 0) ? mel_low : center_mel_freqs[chan - 1];
168238

169239
int j;
170-
for (j = 0; j < num_frequencies; ++j, ++frequency) {
240+
for (j = 0; j < num_frequencies; ++j, ++frequency)
241+
{
171242
const float weight =
172243
(center_mel_freqs[chan] - FreqToMel(frequency * hz_per_sbin)) / (center_mel_freqs[chan] - denom_val);
173244

174245
// Make the float into an integer for the weights (and unweights).
175246
const int weight_index = weight_start + frequency_offset + j;
176247
QuantizeFilterbankWeights(weight, state->weights + weight_index, state->unweights + weight_index);
177248
}
178-
if (frequency > state->end_index) {
249+
if (frequency > state->end_index)
250+
{
179251
state->end_index = frequency;
180252
}
181253
}
182254

183255
free(center_mel_freqs);
184256
free(actual_channel_starts);
185257
free(actual_channel_widths);
186-
if (state->end_index >= spectrum_size) {
258+
if (state->end_index >= spectrum_size)
259+
{
187260
fprintf(stderr, "Filterbank end_index is above spectrum size.\n");
188261
return 0;
189262
}
190263
return 1;
191264
}
192265

193-
void FilterbankFreeStateContents(struct FilterbankState *state) {
266+
void FilterbankFreeStateContents(struct FilterbankState *state)
267+
{
194268
free(state->channel_frequency_starts);
195269
free(state->channel_weight_starts);
196270
free(state->channel_widths);

0 commit comments

Comments
 (0)