Skip to content

Commit 15bc06d

Browse files
Improve ALP parameter selection cost model and sampler to match C++
- Switch findBestFloatParams/findBestDoubleParams from minimizing exception count to minimizing estimated compressed size in bits (length * bitWidth + exceptions * (typeSize + 16), where typeSize is the value width in bits and 16 bits is the 2-byte exception position), matching the C++ ALP cost model. This closes the ~4-5% compression gap vs C++.
- Rewrite the sampler to collect evenly spaced sample vectors and run findBestParams on each, then rank the resulting combinations by win count. This matches C++ AlpSampler behavior more closely than the previous HashMap-based approach.
- Minor fixes: IOExceptionUtils null check, MemoryManager volatile scale, Files utility cleanup, parquet-cli dependency update.
1 parent 56ce400 commit 15bc06d

File tree

9 files changed

+212
-67
lines changed

9 files changed

+212
-67
lines changed

parquet-cli/pom.xml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,18 @@
8787
</dependency>
8888

8989
<!-- Protobuf dependencies for CLI Tests -->
90+
<dependency>
91+
<groupId>org.mockito</groupId>
92+
<artifactId>mockito-core</artifactId>
93+
<version>4.11.0</version>
94+
<scope>test</scope>
95+
</dependency>
96+
<dependency>
97+
<groupId>org.mockito</groupId>
98+
<artifactId>mockito-inline</artifactId>
99+
<version>4.11.0</version>
100+
<scope>test</scope>
101+
</dependency>
90102
<dependency>
91103
<groupId>org.apache.parquet</groupId>
92104
<artifactId>parquet-protobuf</artifactId>

parquet-column/src/main/java/org/apache/parquet/column/values/alp/AlpConstants.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,11 @@ private AlpConstants() {
5353
static final int FLOAT_MAX_EXPONENT = 10;
5454
static final int DOUBLE_MAX_EXPONENT = 18;
5555

56-
// Preset caching: full search for the first N vectors, then lock in the top combos
57-
static final int SAMPLER_SAMPLE_VECTORS = 8;
56+
// Sampler constants matching C++ AlpConstants.
57+
// Sample SAMPLER_SAMPLE_VECTORS_PER_ROWGROUP vectors evenly distributed across a rowgroup
58+
// of SAMPLER_ROWGROUP_SIZE values, then lock in top MAX_PRESET_COMBINATIONS combos.
59+
static final int SAMPLER_ROWGROUP_SIZE = 122_880;
60+
static final int SAMPLER_SAMPLE_VECTORS_PER_ROWGROUP = 8;
5861
static final int MAX_PRESET_COMBINATIONS = 5;
5962

6063
// Magic numbers for the fast-rounding trick (see ALP paper, Section 3.2)

parquet-column/src/main/java/org/apache/parquet/column/values/alp/AlpEncoderDecoder.java

Lines changed: 98 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -170,26 +170,51 @@ public static class EncodingParams {
170170
}
171171
}
172172

173-
/** Try all (exponent, factor) combos and pick the one with fewest exceptions. */
173+
/**
174+
* Try all (exponent, factor) combos and pick the one with the smallest estimated compressed size.
175+
*
176+
* <p>Estimated size (in bits) = {@code length * bitWidth + exceptions * (Float.SIZE + Short.SIZE)},
177+
* where bitWidth is the number of bits needed to represent the unsigned range of non-exception
178+
* encoded values after frame-of-reference subtraction. This matches the C++ ALP cost model and
179+
* produces better compression ratios than minimizing exception count alone.
180+
*/
174181
static EncodingParams findBestFloatParams(float[] values, int offset, int length) {
175182
int bestExponent = 0;
176183
int bestFactor = 0;
177184
int bestExceptions = length;
185+
long bestEstimatedSize = Long.MAX_VALUE;
178186

179187
for (int e = 0; e <= FLOAT_MAX_EXPONENT; e++) {
180188
for (int f = 0; f <= e; f++) {
181189
int exceptions = 0;
190+
int minEncoded = Integer.MAX_VALUE;
191+
int maxEncoded = Integer.MIN_VALUE;
182192
for (int i = 0; i < length; i++) {
183-
if (isFloatException(values[offset + i], e, f)) {
193+
float value = values[offset + i];
194+
if (isFloatException(value, e, f)) {
184195
exceptions++;
196+
} else {
197+
int encoded = encodeFloat(value, e, f);
198+
if (encoded < minEncoded) minEncoded = encoded;
199+
if (encoded > maxEncoded) maxEncoded = encoded;
185200
}
186201
}
187-
if (exceptions < bestExceptions) {
202+
int nonExceptions = length - exceptions;
203+
if (nonExceptions == 0) continue;
204+
long delta = (nonExceptions < 2) ? 0 :
205+
Integer.toUnsignedLong(maxEncoded) - Integer.toUnsignedLong(minEncoded);
206+
int bitsPerValue = (delta == 0) ? 0 : (64 - Long.numberOfLeadingZeros(delta));
207+
long estimatedSize = (long) length * bitsPerValue
208+
+ (long) exceptions * (Float.SIZE + Short.SIZE);
209+
if (estimatedSize < bestEstimatedSize
210+
|| (estimatedSize == bestEstimatedSize
211+
&& (e > bestExponent || (e == bestExponent && f > bestFactor)))) {
212+
bestEstimatedSize = estimatedSize;
188213
bestExponent = e;
189214
bestFactor = f;
190215
bestExceptions = exceptions;
191-
if (bestExceptions == 0) {
192-
return new EncodingParams(bestExponent, bestFactor, bestExceptions);
216+
if (bestExceptions == 0 && bitsPerValue == 0) {
217+
return new EncodingParams(bestExponent, bestFactor, 0);
193218
}
194219
}
195220
}
@@ -202,74 +227,130 @@ static EncodingParams findBestFloatParamsWithPresets(float[] values, int offset,
202227
int bestExponent = presets[0][0];
203228
int bestFactor = presets[0][1];
204229
int bestExceptions = length;
230+
long bestEstimatedSize = Long.MAX_VALUE;
205231

206232
for (int[] preset : presets) {
207233
int e = preset[0];
208234
int f = preset[1];
209235
int exceptions = 0;
236+
int minEncoded = Integer.MAX_VALUE;
237+
int maxEncoded = Integer.MIN_VALUE;
210238
for (int i = 0; i < length; i++) {
211-
if (isFloatException(values[offset + i], e, f)) {
239+
float value = values[offset + i];
240+
if (isFloatException(value, e, f)) {
212241
exceptions++;
242+
} else {
243+
int encoded = encodeFloat(value, e, f);
244+
if (encoded < minEncoded) minEncoded = encoded;
245+
if (encoded > maxEncoded) maxEncoded = encoded;
213246
}
214247
}
215-
if (exceptions < bestExceptions) {
248+
int nonExceptions = length - exceptions;
249+
if (nonExceptions == 0) continue;
250+
long delta = (nonExceptions < 2) ? 0 :
251+
Integer.toUnsignedLong(maxEncoded) - Integer.toUnsignedLong(minEncoded);
252+
int bitsPerValue = (delta == 0) ? 0 : (64 - Long.numberOfLeadingZeros(delta));
253+
long estimatedSize = (long) length * bitsPerValue
254+
+ (long) exceptions * (Float.SIZE + Short.SIZE);
255+
if (estimatedSize < bestEstimatedSize
256+
|| (estimatedSize == bestEstimatedSize
257+
&& (e > bestExponent || (e == bestExponent && f > bestFactor)))) {
258+
bestEstimatedSize = estimatedSize;
216259
bestExponent = e;
217260
bestFactor = f;
218261
bestExceptions = exceptions;
219-
if (bestExceptions == 0) {
220-
return new EncodingParams(bestExponent, bestFactor, bestExceptions);
262+
if (bestExceptions == 0 && bitsPerValue == 0) {
263+
return new EncodingParams(bestExponent, bestFactor, 0);
221264
}
222265
}
223266
}
224267
return new EncodingParams(bestExponent, bestFactor, bestExceptions);
225268
}
226269

270+
/** Try all (exponent, factor) combos and pick the one with the smallest estimated compressed size. */
227271
static EncodingParams findBestDoubleParams(double[] values, int offset, int length) {
228272
int bestExponent = 0;
229273
int bestFactor = 0;
230274
int bestExceptions = length;
275+
long bestEstimatedSize = Long.MAX_VALUE;
231276

232277
for (int e = 0; e <= DOUBLE_MAX_EXPONENT; e++) {
233278
for (int f = 0; f <= e; f++) {
234279
int exceptions = 0;
280+
long minEncoded = Long.MAX_VALUE;
281+
long maxEncoded = Long.MIN_VALUE;
235282
for (int i = 0; i < length; i++) {
236-
if (isDoubleException(values[offset + i], e, f)) {
283+
double value = values[offset + i];
284+
if (isDoubleException(value, e, f)) {
237285
exceptions++;
286+
} else {
287+
long encoded = encodeDouble(value, e, f);
288+
if (encoded < minEncoded) minEncoded = encoded;
289+
if (encoded > maxEncoded) maxEncoded = encoded;
238290
}
239291
}
240-
if (exceptions < bestExceptions) {
292+
int nonExceptions = length - exceptions;
293+
if (nonExceptions == 0) continue;
294+
// delta as signed subtraction; Long.numberOfLeadingZeros handles the unsigned bit width
295+
// correctly even when the subtraction overflows (large range → penalized with 64 bits).
296+
long delta = (nonExceptions < 2) ? 0 : (maxEncoded - minEncoded);
297+
int bitsPerValue = (delta == 0) ? 0 : (64 - Long.numberOfLeadingZeros(delta));
298+
long estimatedSize = (long) length * bitsPerValue
299+
+ (long) exceptions * (Double.SIZE + Short.SIZE);
300+
if (estimatedSize < bestEstimatedSize
301+
|| (estimatedSize == bestEstimatedSize
302+
&& (e > bestExponent || (e == bestExponent && f > bestFactor)))) {
303+
bestEstimatedSize = estimatedSize;
241304
bestExponent = e;
242305
bestFactor = f;
243306
bestExceptions = exceptions;
244-
if (bestExceptions == 0) {
245-
return new EncodingParams(bestExponent, bestFactor, bestExceptions);
307+
if (bestExceptions == 0 && bitsPerValue == 0) {
308+
return new EncodingParams(bestExponent, bestFactor, 0);
246309
}
247310
}
248311
}
249312
}
250313
return new EncodingParams(bestExponent, bestFactor, bestExceptions);
251314
}
252315

316+
/** Same as findBestDoubleParams but only tries the cached preset combos. */
253317
static EncodingParams findBestDoubleParamsWithPresets(double[] values, int offset, int length, int[][] presets) {
254318
int bestExponent = presets[0][0];
255319
int bestFactor = presets[0][1];
256320
int bestExceptions = length;
321+
long bestEstimatedSize = Long.MAX_VALUE;
257322

258323
for (int[] preset : presets) {
259324
int e = preset[0];
260325
int f = preset[1];
261326
int exceptions = 0;
327+
long minEncoded = Long.MAX_VALUE;
328+
long maxEncoded = Long.MIN_VALUE;
262329
for (int i = 0; i < length; i++) {
263-
if (isDoubleException(values[offset + i], e, f)) {
330+
double value = values[offset + i];
331+
if (isDoubleException(value, e, f)) {
264332
exceptions++;
333+
} else {
334+
long encoded = encodeDouble(value, e, f);
335+
if (encoded < minEncoded) minEncoded = encoded;
336+
if (encoded > maxEncoded) maxEncoded = encoded;
265337
}
266338
}
267-
if (exceptions < bestExceptions) {
339+
int nonExceptions = length - exceptions;
340+
if (nonExceptions == 0) continue;
341+
long delta = (nonExceptions < 2) ? 0 : (maxEncoded - minEncoded);
342+
int bitsPerValue = (delta == 0) ? 0 : (64 - Long.numberOfLeadingZeros(delta));
343+
long estimatedSize = (long) length * bitsPerValue
344+
+ (long) exceptions * (Double.SIZE + Short.SIZE);
345+
if (estimatedSize < bestEstimatedSize
346+
|| (estimatedSize == bestEstimatedSize
347+
&& (e > bestExponent || (e == bestExponent && f > bestFactor)))) {
348+
bestEstimatedSize = estimatedSize;
268349
bestExponent = e;
269350
bestFactor = f;
270351
bestExceptions = exceptions;
271-
if (bestExceptions == 0) {
272-
return new EncodingParams(bestExponent, bestFactor, bestExceptions);
352+
if (bestExceptions == 0 && bitsPerValue == 0) {
353+
return new EncodingParams(bestExponent, bestFactor, 0);
273354
}
274355
}
275356
}

parquet-column/src/main/java/org/apache/parquet/column/values/alp/AlpValuesReaderForDouble.java

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
public class AlpValuesReaderForDouble extends AlpValuesReader {
3535

3636
private double[] decodedValues;
37+
private long[] deltasBuffer;
38+
private int[] excPositionsBuffer;
39+
private final long[] unpackPadBuf = new long[8];
40+
private byte[] unpackByteBuf;
3741

3842
public AlpValuesReaderForDouble() {
3943
super();
@@ -42,6 +46,9 @@ public AlpValuesReaderForDouble() {
4246
@Override
4347
protected void allocateDecodedBuffer(int capacity) {
4448
this.decodedValues = new double[capacity];
49+
this.deltasBuffer = new long[capacity];
50+
this.excPositionsBuffer = new int[capacity];
51+
this.unpackByteBuf = new byte[Long.SIZE]; // one group of 8 values at the max bit width (64 bits each) occupies 64 bytes
4552
}
4653

4754
@Override
@@ -69,24 +76,24 @@ protected void decodeVector(int vectorIdx) {
6976
int bitWidth = vectorsData.get(pos + 8) & 0xFF;
7077
pos += DOUBLE_FOR_INFO_SIZE;
7178

72-
long[] deltas = new long[vectorLen];
7379
if (bitWidth > 0) {
74-
pos = unpackLongsWithBytePacker(vectorsData, pos, deltas, vectorLen, bitWidth);
80+
pos = unpackLongsWithBytePacker(vectorsData, pos, deltasBuffer, vectorLen, bitWidth);
81+
} else {
82+
java.util.Arrays.fill(deltasBuffer, 0, vectorLen, 0L);
7583
}
7684

7785
for (int i = 0; i < vectorLen; i++) {
78-
long encoded = deltas[i] + frameOfReference;
86+
long encoded = deltasBuffer[i] + frameOfReference;
7987
decodedValues[i] = AlpEncoderDecoder.decodeDouble(encoded, exponent, factor);
8088
}
8189

8290
if (numExceptions > 0) {
83-
int[] excPositions = new int[numExceptions];
8491
for (int e = 0; e < numExceptions; e++) {
85-
excPositions[e] = getShortLE(vectorsData, pos) & 0xFFFF;
92+
excPositionsBuffer[e] = getShortLE(vectorsData, pos) & 0xFFFF;
8693
pos += Short.BYTES;
8794
}
8895
for (int e = 0; e < numExceptions; e++) {
89-
decodedValues[excPositions[e]] = getDoubleLE(vectorsData, pos);
96+
decodedValues[excPositionsBuffer[e]] = getDoubleLE(vectorsData, pos);
9097
pos += Double.BYTES;
9198
}
9299
}
@@ -109,14 +116,15 @@ private int unpackLongsWithBytePacker(ByteBuffer buf, int pos, long[] output, in
109116
int alreadyRead = numFullGroups * bitWidth;
110117
int partialBytes = totalPackedBytes - alreadyRead;
111118

112-
byte[] padded = new byte[bitWidth];
113119
for (int i = 0; i < partialBytes; i++) {
114-
padded[i] = buf.get(pos + i);
120+
unpackByteBuf[i] = buf.get(pos + i);
121+
}
122+
for (int i = partialBytes; i < bitWidth; i++) {
123+
unpackByteBuf[i] = 0;
115124
}
116125

117-
long[] temp = new long[8];
118-
packer.unpack8Values(padded, 0, temp, 0);
119-
System.arraycopy(temp, 0, output, numFullGroups * 8, remaining);
126+
packer.unpack8Values(unpackByteBuf, 0, unpackPadBuf, 0);
127+
System.arraycopy(unpackPadBuf, 0, output, numFullGroups * 8, remaining);
120128
pos += partialBytes;
121129
}
122130

parquet-column/src/main/java/org/apache/parquet/column/values/alp/AlpValuesReaderForFloat.java

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
public class AlpValuesReaderForFloat extends AlpValuesReader {
3535

3636
private float[] decodedValues;
37+
private int[] deltasBuffer;
38+
private int[] excPositionsBuffer;
39+
private final int[] unpackPadBuf = new int[8];
40+
private byte[] unpackByteBuf;
3741

3842
public AlpValuesReaderForFloat() {
3943
super();
@@ -42,6 +46,9 @@ public AlpValuesReaderForFloat() {
4246
@Override
4347
protected void allocateDecodedBuffer(int capacity) {
4448
this.decodedValues = new float[capacity];
49+
this.deltasBuffer = new int[capacity];
50+
this.excPositionsBuffer = new int[capacity];
51+
this.unpackByteBuf = new byte[Integer.SIZE]; // one group of 8 values at the max bit width (32 bits each) occupies 32 bytes
4552
}
4653

4754
@Override
@@ -69,26 +76,25 @@ protected void decodeVector(int vectorIdx) {
6976
int bitWidth = vectorsData.get(pos + 4) & 0xFF;
7077
pos += FLOAT_FOR_INFO_SIZE;
7178

72-
int[] deltas = new int[vectorLen];
7379
if (bitWidth > 0) {
74-
pos = unpackIntsWithBytePacker(vectorsData, pos, deltas, vectorLen, bitWidth);
80+
pos = unpackIntsWithBytePacker(vectorsData, pos, deltasBuffer, vectorLen, bitWidth);
81+
} else {
82+
java.util.Arrays.fill(deltasBuffer, 0, vectorLen, 0);
7583
}
7684

77-
// Reverse the frame-of-reference subtraction, then decimal-decode
7885
for (int i = 0; i < vectorLen; i++) {
79-
int encoded = deltas[i] + frameOfReference;
86+
int encoded = deltasBuffer[i] + frameOfReference;
8087
decodedValues[i] = AlpEncoderDecoder.decodeFloat(encoded, exponent, factor);
8188
}
8289

8390
// Overwrite exception slots with their original float values
8491
if (numExceptions > 0) {
85-
int[] excPositions = new int[numExceptions];
8692
for (int e = 0; e < numExceptions; e++) {
87-
excPositions[e] = getShortLE(vectorsData, pos) & 0xFFFF;
93+
excPositionsBuffer[e] = getShortLE(vectorsData, pos) & 0xFFFF;
8894
pos += Short.BYTES;
8995
}
9096
for (int e = 0; e < numExceptions; e++) {
91-
decodedValues[excPositions[e]] = getFloatLE(vectorsData, pos);
97+
decodedValues[excPositionsBuffer[e]] = getFloatLE(vectorsData, pos);
9298
pos += Float.BYTES;
9399
}
94100
}
@@ -110,14 +116,15 @@ private int unpackIntsWithBytePacker(ByteBuffer buf, int pos, int[] output, int
110116
int alreadyRead = numFullGroups * bitWidth;
111117
int partialBytes = totalPackedBytes - alreadyRead;
112118

113-
byte[] padded = new byte[bitWidth];
114119
for (int i = 0; i < partialBytes; i++) {
115-
padded[i] = buf.get(pos + i);
120+
unpackByteBuf[i] = buf.get(pos + i);
121+
}
122+
for (int i = partialBytes; i < bitWidth; i++) {
123+
unpackByteBuf[i] = 0;
116124
}
117125

118-
int[] temp = new int[8];
119-
packer.unpack8Values(padded, 0, temp, 0);
120-
System.arraycopy(temp, 0, output, numFullGroups * 8, remaining);
126+
packer.unpack8Values(unpackByteBuf, 0, unpackPadBuf, 0);
127+
System.arraycopy(unpackPadBuf, 0, output, numFullGroups * 8, remaining);
121128
pos += partialBytes;
122129
}
123130

0 commit comments

Comments
 (0)