Skip to content

Commit 014dc33

Browse files
authored
Merge pull request #48 from queryverse/pivot
Add pivot longer and wider
2 parents accbb28 + f3e4436 commit 014dc33

9 files changed

Lines changed: 387 additions & 99 deletions

File tree

docs/make.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
using Documenter, QueryOperators
22

3+
# Configure DocMeta to automatically import QueryOperators for all doctests
4+
DocMeta.setdocmeta!(QueryOperators, :DocTestSetup, :(using QueryOperators); recursive=true)
5+
36
makedocs(
47
modules = [QueryOperators],
58
sitename = "QueryOperators.jl",
@@ -8,7 +11,8 @@ makedocs(
811
),
912
pages = [
1013
"Introduction" => "index.md"
11-
]
14+
],
15+
warnonly = [:missing_docs]
1216
)
1317

1418
deploydocs(

src/NamedTupleUtilities.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,11 +185,11 @@ end
185185
oftype(a::NamedTuple, b::DataType)
186186
Returns a NamedTuple which retains the fields whose elements have type `b`.
187187
```jldoctest
188-
julia> QueryOperators.NamedTupleUtilities.oftype((a = [4,5,6], b = [3.,2.,1.], c = ["He","llo","World!"]), Val(Int64))
189-
(a = [4, 5, 6],)
190-
julia> QueryOperators.NamedTupleUtilities.oftype((a = [4,5,6], b = [3.,2.,1.], c = ["He","llo","World!"]), Val(Number))
191-
(a = [4, 5, 6], b = [3., 2., 1.])
192-
julia> QueryOperators.NamedTupleUtilities.oftype((a = [4,5,6], b = [3.,2.,1.], c = ["He","llo","World!"]), Val(Float32))
188+
julia> QueryOperators.NamedTupleUtilities.oftype((a = 4, b = 3., c = "He"), Val(Int64))
189+
(a = 4,)
190+
julia> QueryOperators.NamedTupleUtilities.oftype((a = 4, b = 3., c = "He"), Val(Number))
191+
(a = 4, b = 3.0)
192+
julia> QueryOperators.NamedTupleUtilities.oftype((a = 4, b = 3., c = "He"), Val(Float32))
193193
NamedTuple()
194194
```
195195
"""

src/QueryOperators.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@ include("enumerable/enumerable_join.jl")
1616
include("enumerable/enumerable_groupjoin.jl")
1717
include("enumerable/enumerable_orderby.jl")
1818
include("enumerable/enumerable_map.jl")
19-
include("enumerable/enumerable_gather.jl")
2019
include("enumerable/enumerable_filter.jl")
2120
include("enumerable/enumerable_mapmany.jl")
2221
include("enumerable/enumerable_defaultifempty.jl")
2322
include("enumerable/enumerable_count.jl")
2423
include("enumerable/enumerable_take.jl")
2524
include("enumerable/enumerable_drop.jl")
2625
include("enumerable/enumerable_unique.jl")
26+
include("enumerable/enumerable_pivot.jl")
2727
include("enumerable/show.jl")
2828

2929
include("source_iterable.jl")

src/enumerable/enumerable_gather.jl

Lines changed: 0 additions & 81 deletions
This file was deleted.

src/enumerable/enumerable_pivot.jl

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
# ===== pivot_longer =====
2+
# Transforms wide data to long format by pivoting named columns into key/value rows.
3+
4+
# Lazy wide-to-long view over a wrapped source.
#
# Type parameters:
#   T       — element type reported by `eltype` (the output row type)
#   S       — type of the wrapped `source`
#   COLS    — value whose `length` is the number of output rows per source row
#   ID_COLS — carried in the type only; not used by the methods in this block
struct EnumerablePivotLonger{T,S,COLS,ID_COLS} <: Enumerable
    source::S
end

# The size trait is whatever `haslength` reports for the wrapped source type.
function Base.IteratorSize(
    ::Type{EnumerablePivotLonger{T,S,COLS,ID_COLS}}
) where {T,S,COLS,ID_COLS}
    return haslength(S)
end

# Output rows are always of type `T`.
function Base.eltype(
    ::Type{EnumerablePivotLonger{T,S,COLS,ID_COLS}}
) where {T,S,COLS,ID_COLS}
    return T
end

# Each source row expands into `length(COLS)` output rows.
function Base.length(
    iter::EnumerablePivotLonger{T,S,COLS,ID_COLS}
) where {T,S,COLS,ID_COLS}
    per_row = length(COLS)
    return length(iter.source) * per_row
end
14+
15+
# Lazily reshape `source` from wide to long form.
#
# Arguments:
#   source    — Enumerable whose `eltype` exposes its columns through
#               `fieldnames` / `fieldtype`.
#   cols      — non-empty tuple of column names to pivot.
#   names_to  — name of the output column that holds the pivoted column's name
#               (typed `Symbol` in the output row).
#   values_to — name of the output column that holds the pivoted value; its type is
#               the `promote_type` of all pivoted columns' types.
#
# Returns an `EnumerablePivotLonger` wrapping `source`; no rows are read here.
# Calls `error` when `cols` is empty or names a column the source does not have.
function pivot_longer(source::Enumerable, cols::NTuple{N,Symbol};
                      names_to::Symbol=:variable, values_to::Symbol=:value) where N
    N == 0 && error("pivot_longer requires at least one column to pivot")
    TS = eltype(source)
    all_fields = fieldnames(TS)

    # Report unknown columns by name up front rather than letting `fieldtype`
    # fail further down with a less specific message.
    unknown = [c for c in cols if !(c in all_fields)]
    isempty(unknown) ||
        error("pivot_longer: column(s) not found in source: $(join(unknown, ", "))")

    # Every field that is not pivoted is carried through as an id column, in source
    # order. Plain ASCII `in` keeps this membership test robust to encoding loss.
    id_cols = tuple((f for f in all_fields if !(f in cols))...)

    # All pivoted values land in a single column, so they need one common type.
    value_type = reduce(promote_type, (fieldtype(TS, c) for c in cols))

    # Output row layout: id columns..., names_to, values_to.
    out_names = (id_cols..., names_to, values_to)
    out_types = Tuple{(fieldtype(TS, f) for f in id_cols)..., Symbol, value_type}
    T = NamedTuple{out_names, out_types}

    return EnumerablePivotLonger{T, typeof(source), cols, id_cols}(source)
end
30+
31+
# Type-stable row construction: generates a static if/elseif chain over COLS at compile time.
32+
@generated function _pivot_longer_row(row::NamedTuple, ::Val{ID_COLS}, col_idx::Int,
33+
::Val{COLS}, ::Type{T}) where {ID_COLS, COLS, T}
34+
id_exprs = [:(getfield(row, $(QuoteNode(f)))) for f in ID_COLS]
35+
value_type = fieldtype(T, length(ID_COLS) + 2) # fields: id..., names_to, values_to
36+
37+
branches = Expr[]
38+
for (i, col) in enumerate(COLS)
39+
push!(branches, quote
40+
if col_idx == $i
41+
return T(($(id_exprs...), $(QuoteNode(col)),
42+
convert($value_type, getfield(row, $(QuoteNode(col))))))
43+
end
44+
end)
45+
end
46+
47+
return quote
48+
$(branches...)
49+
error("pivot_longer: col_idx out of range")
50+
end
51+
end
52+
53+
# Start iteration: pull the first source row and emit its first pivoted column.
# The state is a NamedTuple `(row, source_state, col_idx)` holding the source row
# being expanded, the wrapped source's own iteration state, and the index of the
# column that was just emitted.
function Base.iterate(iter::EnumerablePivotLonger{T,S,COLS,ID_COLS}) where {T,S,COLS,ID_COLS}
    first_step = iterate(iter.source)
    if first_step === nothing
        return nothing
    end
    current_row, upstream_state = first_step
    emitted = _pivot_longer_row(current_row, Val(ID_COLS), 1, Val(COLS), T)
    cursor = (row=current_row, source_state=upstream_state, col_idx=1)
    return emitted, cursor
end
60+
61+
# Continue iteration from `state` (a NamedTuple with `row`, `source_state`, `col_idx`).
# While the current source row still has pivoted columns left, emit the next one;
# once they are exhausted, advance the wrapped source and restart at column 1.
function Base.iterate(iter::EnumerablePivotLonger{T,S,COLS,ID_COLS}, state) where {T,S,COLS,ID_COLS}
    upcoming = state.col_idx + 1

    if upcoming > length(COLS)
        # Current row is fully expanded: move on to the next source row, if any.
        advanced = iterate(iter.source, state.source_state)
        if advanced === nothing
            return nothing
        end
        fresh_row, upstream_state = advanced
        emitted = _pivot_longer_row(fresh_row, Val(ID_COLS), 1, Val(COLS), T)
        return emitted, (row=fresh_row, source_state=upstream_state, col_idx=1)
    end

    emitted = _pivot_longer_row(state.row, Val(ID_COLS), upcoming, Val(COLS), T)
    return emitted, (row=state.row, source_state=state.source_state, col_idx=upcoming)
end
74+
75+
# ===== pivot_wider =====
76+
# Transforms long data to wide format by spreading a key column into multiple value columns.
77+
78+
# Materialized result set: holds the output rows in a vector.
struct EnumerablePivotWider{T} <: Enumerable
    results::Vector{T}
end

# The rows are stored in a vector, so the length is always known.
function Base.IteratorSize(::Type{EnumerablePivotWider{T}}) where T
    return Base.HasLength()
end

# Output rows are of type `T`.
function Base.eltype(::Type{EnumerablePivotWider{T}}) where T
    return T
end

# Number of stored rows.
function Base.length(iter::EnumerablePivotWider{T}) where T
    return length(iter.results)
end
87+
88+
# Reshape `source` from long to wide form, eagerly.
#
# Arguments:
#   source      — Enumerable whose `eltype` exposes its columns through
#                 `fieldnames` / `fieldtype`.
#   names_from  — column whose values (converted with `Symbol`) become new column names.
#   values_from — column whose values fill the new columns.
#   id_cols     — columns identifying an output row; `nothing` (the default) means
#                 every field other than `names_from` and `values_from`.
#
# Returns an `EnumerablePivotWider` with one row per distinct id-column combination,
# in order of first appearance. New columns are ordered by first appearance of their
# name, are typed `DataValue{V}` (V = type of `values_from`), and hold an empty
# `DataValue{V}()` where an id combination has no entry for that name. When the same
# (id, name) pair occurs more than once, the last value seen is kept.
function pivot_wider(source::Enumerable, names_from::Symbol, values_from::Symbol;
                     id_cols=nothing)
    TS = eltype(source)
    all_fields = fieldnames(TS)

    id_col_names = if id_cols === nothing
        tuple((f for f in all_fields if !(f == names_from || f == values_from))...)
    else
        tuple(id_cols...)
    end
    n_ids = length(id_col_names)

    val_type = fieldtype(TS, values_from)
    out_val_type = DataValues.DataValue{val_type}

    # One pass over the materialized rows records both the distinct names and the
    # per-id value buckets; both dictionaries preserve first-appearance order.
    seen_names = OrderedDict{Symbol, Nothing}()
    id_to_values = OrderedDict{Any, Dict{Symbol, val_type}}()
    for row in collect(source)
        name_sym = Symbol(getfield(row, names_from))
        seen_names[name_sym] = nothing

        id_key = ntuple(i -> getfield(row, id_col_names[i]), n_ids)
        bucket = get!(() -> Dict{Symbol, val_type}(), id_to_values, id_key)
        bucket[name_sym] = getfield(row, values_from)
    end
    new_col_names = tuple(keys(seen_names)...)

    # Output row layout: id columns..., then one DataValue column per distinct name.
    out_names = (id_col_names..., new_col_names...)
    out_types = Tuple{(fieldtype(TS, f) for f in id_col_names)...,
                      (out_val_type for _ in new_col_names)...}
    T = NamedTuple{out_names, out_types}

    na = out_val_type()
    results = Vector{T}(undef, length(id_to_values))
    for (i, (id_key, vals_dict)) in enumerate(id_to_values)
        new_vals = map(new_col_names) do col
            haskey(vals_dict, col) ? out_val_type(vals_dict[col]) : na
        end
        results[i] = T((id_key..., new_vals...))
    end

    return EnumerablePivotWider{T}(results)
end
140+
141+
# ===== Column selector resolution =====
142+
# Resolves a tuple of column names from a NamedTuple type according to a list of
143+
# selector instructions (encoded in a Val type parameter for compile-time evaluation).
144+
#
145+
# Each instruction is a 2-tuple (op, arg):
146+
# (:include_name, sym) — include field by name
147+
# (:exclude_name, sym) — exclude field by name
148+
# (:include_position, idx) — include field at 1-based position
149+
# (:exclude_position, idx) — exclude field at 1-based position
150+
# (:include_startswith, prefix_sym) — include fields whose name starts with prefix
151+
# (:exclude_startswith, prefix_sym) — exclude fields whose name starts with prefix
152+
# (:include_endswith, suffix_sym) — include fields whose name ends with suffix
153+
# (:exclude_endswith, suffix_sym) — exclude fields whose name ends with suffix
154+
# (:include_occursin, sub_sym) — include fields whose name contains sub
155+
# (:exclude_occursin, sub_sym) — exclude fields whose name contains sub
156+
# (:include_range, (from, to)) — include field names from :from to :to (inclusive)
157+
# (:include_range_idx, (a, b)) — include fields at 1-based positions a through b
158+
# (:include_all, :_) — include all remaining fields
159+
#
160+
# If all instructions are "exclude" ops, the starting set is ALL field names; otherwise
161+
# the starting set is empty and "include" instructions accumulate into it.
162+
# Resolve selector `instructions` against the field names of `NT` and return the
# selected names as a tuple of Symbols.
#
# All the work happens in the generator, so the compiled method body is just the
# resulting constant tuple.
#
# `instructions` is a tuple of `(op, arg)` pairs applied in order. If at least one op
# is an include op the selection starts empty; otherwise it starts from every field.
# Include ops append in first-insertion order and never add a name twice; exclude ops
# remove names from the current selection. Ops not handled by the chain below are
# ignored. Position ops index `all_names` directly, so an out-of-range position
# raises a `BoundsError`.
@generated function _resolve_pivot_cols(::Type{NT}, ::Val{instructions}) where {NT<:NamedTuple, instructions}
    # Typed literal so that a NamedTuple type with no fields still yields Vector{Symbol}.
    all_names = Symbol[fieldnames(NT)...]

    include_ops = (:include_name, :include_position, :include_startswith,
                   :include_endswith, :include_occursin, :include_all,
                   :include_range, :include_range_idx)
    has_positive = any(inst[1] in include_ops for inst in instructions)

    result = has_positive ? Symbol[] : copy(all_names)

    # Append `n` unless it is already selected. Plain ASCII `in` keeps the membership
    # tests robust to encoding loss.
    add!(n) = n in result || push!(result, n)

    for inst in instructions
        op = inst[1]
        arg = inst[2]

        if op == :include_all
            foreach(add!, all_names)
        elseif op == :include_name
            add!(arg)
        elseif op == :exclude_name
            filter!(!=(arg), result)
        elseif op == :include_position
            add!(all_names[arg])
        elseif op == :exclude_position
            filter!(!=(all_names[arg]), result)
        elseif op == :include_startswith
            for n in all_names
                Base.startswith(String(n), String(arg)) && add!(n)
            end
        elseif op == :exclude_startswith
            filter!(n -> !Base.startswith(String(n), String(arg)), result)
        elseif op == :include_endswith
            for n in all_names
                Base.endswith(String(n), String(arg)) && add!(n)
            end
        elseif op == :exclude_endswith
            filter!(n -> !Base.endswith(String(n), String(arg)), result)
        elseif op == :include_occursin
            for n in all_names
                Base.occursin(String(arg), String(n)) && add!(n)
            end
        elseif op == :exclude_occursin
            filter!(n -> !Base.occursin(String(arg), String(n)), result)
        elseif op == :include_range
            # Inclusive span in declaration order, from `from_sym` through `to_sym`.
            # The scan stops at `to_sym` even if `from_sym` has not been seen yet.
            from_sym, to_sym = arg
            in_range = false
            for n in all_names
                n == from_sym && (in_range = true)
                in_range && add!(n)
                n == to_sym && (in_range = false; break)
            end
        elseif op == :include_range_idx
            from_idx, to_idx = arg
            for i in from_idx:to_idx
                add!(all_names[i])
            end
        end
    end

    names = tuple(result...)
    return :($names)
end
234+
235+
# Start iteration over the stored rows; the state is the index of the next row.
function Base.iterate(iter::EnumerablePivotWider{T}) where T
    rows = iter.results
    if isempty(rows)
        return nothing
    end
    return (rows[1], 2)
end
239+
240+
# Continue iteration: `state` is the index of the row to emit next.
function Base.iterate(iter::EnumerablePivotWider{T}, state) where T
    rows = iter.results
    if state <= length(rows)
        return (rows[state], state + 1)
    end
    return nothing
end

src/operators.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,6 @@ macro filter(source, f)
2626
:(QueryOperators.filter($(esc(source)), $(esc(f)), $(esc(q))))
2727
end
2828

29-
function gather end
30-
31-
macro gather(source, withIndex = False)
32-
:(groupby($(esc(source)), $(esc(withIndex))))
33-
end
34-
3529
function groupby end
3630

3731
macro groupby(source,elementSelector,resultSelector)
@@ -129,3 +123,7 @@ macro unique(source, f)
129123
q = Expr(:quote, f)
130124
:(unique($(esc(source)), $(esc(f)), $(esc(q))))
131125
end
126+
127+
# Declares the generic function `pivot_longer` with no methods attached here.
# The method taking an `Enumerable` source is in src/enumerable/enumerable_pivot.jl.
function pivot_longer end
128+
129+
# Declares the generic function `pivot_wider` with no methods attached here.
# The method taking an `Enumerable` source is in src/enumerable/enumerable_pivot.jl.
function pivot_wider end

0 commit comments

Comments
 (0)