Skip to content

Commit f59819e

Browse files
committed
Add pivot_longer and pivot_wider
1 parent 8a94412 commit f59819e

5 files changed

Lines changed: 376 additions & 0 deletions

File tree

src/QueryOperators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ include("enumerable/enumerable_count.jl")
2323
include("enumerable/enumerable_take.jl")
2424
include("enumerable/enumerable_drop.jl")
2525
include("enumerable/enumerable_unique.jl")
26+
include("enumerable/enumerable_pivot.jl")
2627
include("enumerable/show.jl")
2728

2829
include("source_iterable.jl")

src/enumerable/enumerable_pivot.jl

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
# ===== pivot_longer =====
2+
# Transforms wide data to long format by pivoting named columns into key/value rows.
3+
4+
struct EnumerablePivotLonger{T, S, COLS, ID_COLS} <: Enumerable
5+
source::S
6+
end
7+
8+
Base.IteratorSize(::Type{EnumerablePivotLonger{T,S,COLS,ID_COLS}}) where {T,S,COLS,ID_COLS} = haslength(S)
9+
10+
Base.eltype(::Type{EnumerablePivotLonger{T,S,COLS,ID_COLS}}) where {T,S,COLS,ID_COLS} = T
11+
12+
Base.length(iter::EnumerablePivotLonger{T,S,COLS,ID_COLS}) where {T,S,COLS,ID_COLS} =
13+
length(iter.source) * length(COLS)
14+
15+
function pivot_longer(source::Enumerable, cols::NTuple{N,Symbol};
16+
names_to::Symbol=:variable, values_to::Symbol=:value) where N
17+
N == 0 && error("pivot_longer requires at least one column to pivot")
18+
TS = eltype(source)
19+
all_fields = fieldnames(TS)
20+
id_cols = tuple((f for f in all_fields if f cols)...)
21+
22+
value_type = reduce(promote_type, (fieldtype(TS, c) for c in cols))
23+
24+
out_names = (id_cols..., names_to, values_to)
25+
out_types = Tuple{(fieldtype(TS, f) for f in id_cols)..., Symbol, value_type}
26+
T = NamedTuple{out_names, out_types}
27+
28+
return EnumerablePivotLonger{T, typeof(source), cols, id_cols}(source)
29+
end
30+
31+
# Type-stable row construction: generates a static if/elseif chain over COLS at compile time.
32+
@generated function _pivot_longer_row(row::NamedTuple, ::Val{ID_COLS}, col_idx::Int,
33+
::Val{COLS}, ::Type{T}) where {ID_COLS, COLS, T}
34+
id_exprs = [:(getfield(row, $(QuoteNode(f)))) for f in ID_COLS]
35+
value_type = fieldtype(T, length(ID_COLS) + 2) # fields: id..., names_to, values_to
36+
37+
branches = Expr[]
38+
for (i, col) in enumerate(COLS)
39+
push!(branches, quote
40+
if col_idx == $i
41+
return T(($(id_exprs...), $(QuoteNode(col)),
42+
convert($value_type, getfield(row, $(QuoteNode(col))))))
43+
end
44+
end)
45+
end
46+
47+
return quote
48+
$(branches...)
49+
error("pivot_longer: col_idx out of range")
50+
end
51+
end
52+
53+
function Base.iterate(iter::EnumerablePivotLonger{T,S,COLS,ID_COLS}) where {T,S,COLS,ID_COLS}
54+
source_ret = iterate(iter.source)
55+
source_ret === nothing && return nothing
56+
row, source_state = source_ret
57+
out = _pivot_longer_row(row, Val(ID_COLS), 1, Val(COLS), T)
58+
return out, (row=row, source_state=source_state, col_idx=1)
59+
end
60+
61+
function Base.iterate(iter::EnumerablePivotLonger{T,S,COLS,ID_COLS}, state) where {T,S,COLS,ID_COLS}
62+
next_idx = state.col_idx + 1
63+
if next_idx <= length(COLS)
64+
out = _pivot_longer_row(state.row, Val(ID_COLS), next_idx, Val(COLS), T)
65+
return out, (row=state.row, source_state=state.source_state, col_idx=next_idx)
66+
else
67+
source_ret = iterate(iter.source, state.source_state)
68+
source_ret === nothing && return nothing
69+
row, source_state = source_ret
70+
out = _pivot_longer_row(row, Val(ID_COLS), 1, Val(COLS), T)
71+
return out, (row=row, source_state=source_state, col_idx=1)
72+
end
73+
end
74+
75+
# ===== pivot_wider =====
76+
# Transforms long data to wide format by spreading a key column into multiple value columns.
77+
78+
struct EnumerablePivotWider{T} <: Enumerable
79+
results::Vector{T}
80+
end
81+
82+
Base.IteratorSize(::Type{EnumerablePivotWider{T}}) where T = Base.HasLength()
83+
84+
Base.eltype(::Type{EnumerablePivotWider{T}}) where T = T
85+
86+
Base.length(iter::EnumerablePivotWider{T}) where T = length(iter.results)
87+
88+
function pivot_wider(source::Enumerable, names_from::Symbol, values_from::Symbol;
89+
id_cols=nothing)
90+
TS = eltype(source)
91+
all_fields = fieldnames(TS)
92+
93+
id_col_names = if id_cols === nothing
94+
tuple((f for f in all_fields if f != names_from && f != values_from)...)
95+
else
96+
tuple(id_cols...)
97+
end
98+
99+
val_type = fieldtype(TS, values_from)
100+
out_val_type = DataValues.DataValue{val_type}
101+
102+
all_rows = collect(source)
103+
104+
# Collect unique name values in order of first appearance
105+
seen_names = OrderedDict{Symbol, Nothing}()
106+
for row in all_rows
107+
seen_names[Symbol(getfield(row, names_from))] = nothing
108+
end
109+
new_col_names = tuple(keys(seen_names)...)
110+
111+
out_names = (id_col_names..., new_col_names...)
112+
out_types = Tuple{(fieldtype(TS, f) for f in id_col_names)...,
113+
(out_val_type for _ in new_col_names)...}
114+
T = NamedTuple{out_names, out_types}
115+
116+
# Group rows by their id-column values
117+
id_to_values = OrderedDict{Any, Dict{Symbol, val_type}}()
118+
for row in all_rows
119+
id_key = ntuple(i -> getfield(row, id_col_names[i]), length(id_col_names))
120+
name_sym = Symbol(getfield(row, names_from))
121+
value = getfield(row, values_from)
122+
if !haskey(id_to_values, id_key)
123+
id_to_values[id_key] = Dict{Symbol, val_type}()
124+
end
125+
id_to_values[id_key][name_sym] = value
126+
end
127+
128+
na = out_val_type()
129+
results = Vector{T}(undef, length(id_to_values))
130+
for (i, (id_key, vals_dict)) in enumerate(id_to_values)
131+
new_vals = ntuple(
132+
j -> haskey(vals_dict, new_col_names[j]) ?
133+
out_val_type(vals_dict[new_col_names[j]]) : na,
134+
length(new_col_names))
135+
results[i] = T((id_key..., new_vals...))
136+
end
137+
138+
return EnumerablePivotWider{T}(results)
139+
end
140+
141+
# ===== Column selector resolution =====
142+
# Resolves a tuple of column names from a NamedTuple type according to a list of
143+
# selector instructions (encoded in a Val type parameter for compile-time evaluation).
144+
#
145+
# Each instruction is a 2-tuple (op, arg):
146+
# (:include_name, sym) — include field by name
147+
# (:exclude_name, sym) — exclude field by name
148+
# (:include_position, idx) — include field at 1-based position
149+
# (:exclude_position, idx) — exclude field at 1-based position
150+
# (:include_startswith, prefix_sym) — include fields whose name starts with prefix
151+
# (:exclude_startswith, prefix_sym) — exclude fields whose name starts with prefix
152+
# (:include_endswith, suffix_sym) — include fields whose name ends with suffix
153+
# (:exclude_endswith, suffix_sym) — exclude fields whose name ends with suffix
154+
# (:include_occursin, sub_sym) — include fields whose name contains sub
155+
# (:exclude_occursin, sub_sym) — exclude fields whose name contains sub
156+
# (:include_range, (from, to)) — include field names from :from to :to (inclusive)
157+
# (:include_range_idx, (a, b)) — include fields at 1-based positions a through b
158+
# (:include_all, :_) — include all remaining fields
159+
#
160+
# If all instructions are "exclude" ops, the starting set is ALL field names; otherwise
161+
# the starting set is empty and "include" instructions accumulate into it.
162+
@generated function _resolve_pivot_cols(::Type{NT}, ::Val{instructions}) where {NT<:NamedTuple, instructions}
163+
all_names = collect(fieldnames(NT))
164+
165+
include_ops = (:include_name, :include_position, :include_startswith,
166+
:include_endswith, :include_occursin, :include_all,
167+
:include_range, :include_range_idx)
168+
has_positive = any(inst[1] include_ops for inst in instructions)
169+
170+
result = has_positive ? Symbol[] : copy(all_names)
171+
172+
for inst in instructions
173+
op = inst[1]
174+
arg = inst[2]
175+
176+
if op == :include_all
177+
for n in all_names
178+
n result && push!(result, n)
179+
end
180+
elseif op == :include_name
181+
arg result && push!(result, arg)
182+
elseif op == :exclude_name
183+
filter!(!=( arg), result)
184+
elseif op == :include_position
185+
n = all_names[arg]
186+
n result && push!(result, n)
187+
elseif op == :exclude_position
188+
n = all_names[arg]
189+
filter!(!=(n), result)
190+
elseif op == :include_startswith
191+
for n in all_names
192+
if Base.startswith(String(n), String(arg)) && n result
193+
push!(result, n)
194+
end
195+
end
196+
elseif op == :exclude_startswith
197+
filter!(n -> !Base.startswith(String(n), String(arg)), result)
198+
elseif op == :include_endswith
199+
for n in all_names
200+
if Base.endswith(String(n), String(arg)) && n result
201+
push!(result, n)
202+
end
203+
end
204+
elseif op == :exclude_endswith
205+
filter!(n -> !Base.endswith(String(n), String(arg)), result)
206+
elseif op == :include_occursin
207+
for n in all_names
208+
if Base.occursin(String(arg), String(n)) && n result
209+
push!(result, n)
210+
end
211+
end
212+
elseif op == :exclude_occursin
213+
filter!(n -> !Base.occursin(String(arg), String(n)), result)
214+
elseif op == :include_range
215+
from_sym, to_sym = arg
216+
in_range = false
217+
for n in all_names
218+
n == from_sym && (in_range = true)
219+
in_range && n result && push!(result, n)
220+
n == to_sym && (in_range = false; break)
221+
end
222+
elseif op == :include_range_idx
223+
from_idx, to_idx = arg
224+
for i in from_idx:to_idx
225+
n = all_names[i]
226+
n result && push!(result, n)
227+
end
228+
end
229+
end
230+
231+
names = tuple(result...)
232+
return :($names)
233+
end
234+
235+
function Base.iterate(iter::EnumerablePivotWider{T}) where T
236+
isempty(iter.results) && return nothing
237+
return iter.results[1], 2
238+
end
239+
240+
function Base.iterate(iter::EnumerablePivotWider{T}, state) where T
241+
state > length(iter.results) && return nothing
242+
return iter.results[state], state + 1
243+
end

src/operators.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,7 @@ macro unique(source, f)
123123
q = Expr(:quote, f)
124124
:(unique($(esc(source)), $(esc(f)), $(esc(q))))
125125
end
126+
127+
function pivot_longer end
128+
129+
function pivot_wider end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,6 @@ include("test_enumerable_unique.jl")
156156

157157
include("test_namedtupleutilities.jl")
158158

159+
include("test_pivot.jl")
160+
159161
end

test/test_pivot.jl

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
@testset "pivot_longer" begin
2+
3+
# Basic: pivot all columns
4+
data = QueryOperators.query([(US=1, EU=2, CN=3), (US=4, EU=5, CN=6)])
5+
result = QueryOperators.pivot_longer(data, (:US, :EU, :CN)) |> collect
6+
7+
@test length(result) == 6
8+
@test eltype(result) == NamedTuple{(:variable, :value), Tuple{Symbol, Int64}}
9+
@test result[1] == (variable=:US, value=1)
10+
@test result[2] == (variable=:EU, value=2)
11+
@test result[3] == (variable=:CN, value=3)
12+
@test result[4] == (variable=:US, value=4)
13+
@test result[5] == (variable=:EU, value=5)
14+
@test result[6] == (variable=:CN, value=6)
15+
16+
# With id columns retained
17+
data2 = QueryOperators.query([(year=2017, US=1, EU=2), (year=2018, US=3, EU=4)])
18+
result2 = QueryOperators.pivot_longer(data2, (:US, :EU)) |> collect
19+
20+
@test length(result2) == 4
21+
@test eltype(result2) == NamedTuple{(:year, :variable, :value), Tuple{Int64, Symbol, Int64}}
22+
@test result2[1] == (year=2017, variable=:US, value=1)
23+
@test result2[2] == (year=2017, variable=:EU, value=2)
24+
@test result2[3] == (year=2018, variable=:US, value=3)
25+
@test result2[4] == (year=2018, variable=:EU, value=4)
26+
27+
# Custom names_to and values_to
28+
result3 = QueryOperators.pivot_longer(data2, (:US, :EU); names_to=:country, values_to=:sales) |> collect
29+
30+
@test eltype(result3) == NamedTuple{(:year, :country, :sales), Tuple{Int64, Symbol, Int64}}
31+
@test result3[1] == (year=2017, country=:US, sales=1)
32+
33+
# Type promotion: mixing Int and Float
34+
data3 = QueryOperators.query([(id=1, a=1, b=2.0)])
35+
result4 = QueryOperators.pivot_longer(data3, (:a, :b)) |> collect
36+
37+
@test eltype(result4) == NamedTuple{(:id, :variable, :value), Tuple{Int64, Symbol, Float64}}
38+
@test result4[1] == (id=1, variable=:a, value=1.0)
39+
@test result4[2] == (id=1, variable=:b, value=2.0)
40+
41+
# Type stability
42+
@test Base.return_types(iterate, (QueryOperators.EnumerablePivotLonger,)) |> only <:
43+
Union{Nothing, Tuple}
44+
45+
end
46+
47+
@testset "_resolve_pivot_cols" begin
48+
NT = NamedTuple{(:year, :wk1, :wk2, :total), Tuple{Int,Int,Int,Int}}
49+
50+
# Include by name
51+
@test QueryOperators._resolve_pivot_cols(NT, Val(((:include_name, :wk1), (:include_name, :wk2)))) == (:wk1, :wk2)
52+
53+
# Include by startswith
54+
@test QueryOperators._resolve_pivot_cols(NT, Val(((:include_startswith, :wk),))) == (:wk1, :wk2)
55+
56+
# Include by endswith
57+
@test QueryOperators._resolve_pivot_cols(NT, Val(((:include_endswith, Symbol("1")),))) == (:wk1,)
58+
59+
# Include by occursin
60+
@test QueryOperators._resolve_pivot_cols(NT, Val(((:include_occursin, :wk),))) == (:wk1, :wk2)
61+
62+
# Exclude by name from all (only-negative → starts from all)
63+
@test QueryOperators._resolve_pivot_cols(NT, Val(((:exclude_name, :year), (:exclude_name, :total)))) == (:wk1, :wk2)
64+
65+
# Exclude by startswith from all
66+
@test QueryOperators._resolve_pivot_cols(NT, Val(((:exclude_startswith, :wk),))) == (:year, :total)
67+
68+
# Mix: include startswith then exclude one
69+
@test QueryOperators._resolve_pivot_cols(NT, Val(((:include_startswith, :wk), (:exclude_name, :wk2)))) == (:wk1,)
70+
71+
# Include by position
72+
@test QueryOperators._resolve_pivot_cols(NT, Val(((:include_position, 2), (:include_position, 3)))) == (:wk1, :wk2)
73+
74+
# Include range by index
75+
@test QueryOperators._resolve_pivot_cols(NT, Val(((:include_range_idx, (2, 3)),))) == (:wk1, :wk2)
76+
77+
# Include range by name
78+
@test QueryOperators._resolve_pivot_cols(NT, Val(((:include_range, (:wk1, :wk2)),))) == (:wk1, :wk2)
79+
80+
# include_all adds everything not yet in set
81+
@test QueryOperators._resolve_pivot_cols(NT, Val(((:include_name, :year), (:include_all, :_), (:exclude_name, :total)))) == (:year, :wk1, :wk2)
82+
end
83+
84+
@testset "pivot_wider" begin
85+
86+
# Basic: long to wide
87+
data = QueryOperators.query([
88+
(year=2017, country=:US, value=1),
89+
(year=2017, country=:EU, value=2),
90+
(year=2018, country=:US, value=3),
91+
(year=2018, country=:EU, value=4),
92+
])
93+
result = QueryOperators.pivot_wider(data, :country, :value) |> collect
94+
95+
@test length(result) == 2
96+
T = eltype(result)
97+
@test fieldnames(T) == (:year, :US, :EU)
98+
@test fieldtype(T, :US) == DataValues.DataValue{Int64}
99+
@test result[1].year == 2017
100+
@test result[1].US == DataValues.DataValue(1)
101+
@test result[1].EU == DataValues.DataValue(2)
102+
@test result[2].year == 2018
103+
@test result[2].US == DataValues.DataValue(3)
104+
@test result[2].EU == DataValues.DataValue(4)
105+
106+
# Absent combinations become NA DataValues
107+
data2 = QueryOperators.query([
108+
(year=2017, country=:US, value=1),
109+
(year=2017, country=:EU, value=2),
110+
(year=2018, country=:US, value=3),
111+
# year=2018, country=:EU is absent
112+
])
113+
result2 = QueryOperators.pivot_wider(data2, :country, :value) |> collect
114+
115+
@test length(result2) == 2
116+
@test result2[1].year == 2017
117+
@test result2[1].US == DataValues.DataValue(1)
118+
@test result2[1].EU == DataValues.DataValue(2)
119+
@test result2[2].US == DataValues.DataValue(3)
120+
@test DataValues.isna(result2[2].EU)
121+
122+
# Length is known
123+
wide = QueryOperators.pivot_wider(data, :country, :value)
124+
@test length(wide) == 2
125+
126+
end

0 commit comments

Comments
 (0)