|
| 1 | +# ===== pivot_longer ===== |
| 2 | +# Transforms wide data to long format by pivoting named columns into key/value rows. |
| 3 | + |
"""
    EnumerablePivotLonger{T,S,COLS,ID_COLS}

Iterator wrapper around `source` (of type `S`) used by `pivot_longer`.

- `T` is the element type reported by `eltype` for this iterator.
- `COLS` is the tuple of column names being pivoted; each source row yields
  `length(COLS)` output rows.
- `ID_COLS` is the tuple of column names copied unchanged into every output row.
"""
struct EnumerablePivotLonger{T,S,COLS,ID_COLS} <: Enumerable
    source::S
end
| 7 | + |
# The size trait is derived from the wrapped source type `S` via `haslength`.
function Base.IteratorSize(
    ::Type{EnumerablePivotLonger{T,S,COLS,ID_COLS}},
) where {T,S,COLS,ID_COLS}
    return haslength(S)
end
| 9 | + |
# Elements produced by the iterator are of the precomputed row type `T`.
function Base.eltype(
    ::Type{EnumerablePivotLonger{T,S,COLS,ID_COLS}},
) where {T,S,COLS,ID_COLS}
    return T
end
| 11 | + |
# Every source row expands into one output row per pivoted column.
function Base.length(
    iter::EnumerablePivotLonger{T,S,COLS,ID_COLS},
) where {T,S,COLS,ID_COLS}
    rows_in_source = length(iter.source)
    return rows_in_source * length(COLS)
end
| 14 | + |
"""
    pivot_longer(source::Enumerable, cols::NTuple{N,Symbol};
                 names_to::Symbol=:variable, values_to::Symbol=:value)

Reshape `source` from wide to long form: each source row produces one output row
per entry of `cols`.

Every output row contains the fields of the source element type that are not
listed in `cols` (the id columns, in their original order), followed by
`names_to` (the pivoted column's name as a `Symbol`) and `values_to` (that
column's value, converted to the `promote_type` of all pivoted column types).

Returns an `EnumerablePivotLonger` wrapping `source`; rows are produced when the
result is iterated.

# Throws
- `ErrorException` if `cols` is empty.
- `ErrorException` if any entry of `cols` is not a field of `eltype(source)`.
- `ErrorException` if the resulting output column names are not unique, e.g.
  `names_to == values_to` or either one equals an id column name.
"""
function pivot_longer(source::Enumerable, cols::NTuple{N,Symbol};
                      names_to::Symbol=:variable, values_to::Symbol=:value) where N
    N == 0 && error("pivot_longer requires at least one column to pivot")
    TS = eltype(source)
    all_fields = fieldnames(TS)

    # Fail early with a message that names the offending columns, rather than
    # letting `fieldtype` raise an error that does not mention the pivot.
    unknown = [c for c in cols if c ∉ all_fields]
    isempty(unknown) ||
        error("pivot_longer: column(s) $(join(unknown, ", ")) not found in source fields $(all_fields)")

    id_cols = tuple((f for f in all_fields if f ∉ cols)...)

    # A NamedTuple cannot have repeated field names, so reject collisions between
    # the id columns and the two new columns before building the output type.
    out_names = (id_cols..., names_to, values_to)
    allunique(out_names) ||
        error("pivot_longer: output column names must be unique, got $(out_names)")

    # All pivoted values share one output column, so use their common promoted type.
    value_type = reduce(promote_type, (fieldtype(TS, c) for c in cols))

    out_types = Tuple{(fieldtype(TS, f) for f in id_cols)..., Symbol, value_type}
    T = NamedTuple{out_names, out_types}

    return EnumerablePivotLonger{T, typeof(source), cols, id_cols}(source)
end
| 30 | + |
| 31 | +# Type-stable row construction: generates a static if/elseif chain over COLS at compile time. |
"""
    _pivot_longer_row(row, Val(ID_COLS), col_idx, Val(COLS), T)

Build the output row of type `T` for the `col_idx`-th pivoted column of `row`.

The generated body is a chain of comparisons of `col_idx` against each position
in `COLS`. The matching branch returns `T` built from the id fields of `row`,
the pivoted column's name, and its value converted to the last field type of
`T`. If no position matches, an error is raised.
"""
@generated function _pivot_longer_row(row::NamedTuple, ::Val{ID_COLS}, col_idx::Int,
                                      ::Val{COLS}, ::Type{T}) where {ID_COLS, COLS, T}
    # Output layout is (id fields..., names_to, values_to): the value slot comes
    # two positions after the id fields.
    val_T = fieldtype(T, length(ID_COLS) + 2)
    id_getters = Any[:(getfield(row, $(QuoteNode(name)))) for name in ID_COLS]

    # Build the chain back-to-front so the innermost `else` is the range error.
    chain = :(error("pivot_longer: col_idx out of range"))
    for k in length(COLS):-1:1
        sym = QuoteNode(COLS[k])
        chain = quote
            if col_idx == $k
                return T(($(id_getters...), $sym,
                          convert($val_T, getfield(row, $sym))))
            else
                $chain
            end
        end
    end

    return chain
end
| 52 | + |
# Start iteration: fetch the first source row (if any) and emit its first
# pivoted column. The state keeps the current row, the source's own state and
# the index of the column just emitted.
function Base.iterate(iter::EnumerablePivotLonger{T,S,COLS,ID_COLS}) where {T,S,COLS,ID_COLS}
    first_step = iterate(iter.source)
    if first_step === nothing
        return nothing
    end
    src_row, src_state = first_step
    emitted = _pivot_longer_row(src_row, Val(ID_COLS), 1, Val(COLS), T)
    return emitted, (row=src_row, source_state=src_state, col_idx=1)
end
| 60 | + |
# Continue iteration: emit the next pivoted column of the current row, or move
# on to the next source row once every column of the current row has been emitted.
function Base.iterate(iter::EnumerablePivotLonger{T,S,COLS,ID_COLS}, state) where {T,S,COLS,ID_COLS}
    upcoming = state.col_idx + 1

    if upcoming > length(COLS)
        # Current row is exhausted; pull the next one from the source.
        advanced = iterate(iter.source, state.source_state)
        advanced === nothing && return nothing
        src_row, src_state = advanced
        emitted = _pivot_longer_row(src_row, Val(ID_COLS), 1, Val(COLS), T)
        return emitted, (row=src_row, source_state=src_state, col_idx=1)
    end

    emitted = _pivot_longer_row(state.row, Val(ID_COLS), upcoming, Val(COLS), T)
    return emitted, (row=state.row, source_state=state.source_state, col_idx=upcoming)
end
| 74 | + |
| 75 | +# ===== pivot_wider ===== |
| 76 | +# Transforms long data to wide format by spreading a key column into multiple value columns. |
| 77 | + |
"""
    EnumerablePivotWider{T}

Enumerable holding the already-computed output rows of `pivot_wider` in
`results`; iterating it walks that vector in order.
"""
struct EnumerablePivotWider{T} <: Enumerable
    results::Vector{T}
end
| 81 | + |
# Results are stored in a vector, so the length is always known.
function Base.IteratorSize(::Type{EnumerablePivotWider{T}}) where {T}
    return Base.HasLength()
end
| 83 | + |
# Element type is the row type the results vector was built with.
function Base.eltype(::Type{EnumerablePivotWider{T}}) where {T}
    return T
end
| 85 | + |
# Number of output rows equals the number of stored results.
function Base.length(iter::EnumerablePivotWider{T}) where {T}
    return length(iter.results)
end
| 87 | + |
"""
    pivot_wider(source::Enumerable, names_from::Symbol, values_from::Symbol;
                id_cols=nothing)

Reshape `source` from long to wide form. The distinct values of the
`names_from` field (converted with `Symbol`) become new columns, filled with
the corresponding `values_from` values.

`source` is collected eagerly. Rows are grouped by their id-column values; when
`id_cols` is `nothing` the id columns are all fields other than `names_from`
and `values_from`. Groups and new columns both appear in order of first
appearance. Each new column has type `DataValues.DataValue{V}`, where `V` is
the field type of `values_from`; a group lacking a value for some column gets
an empty `DataValue`. If a group has several rows for the same name, the last
one wins.

Returns an `EnumerablePivotWider` holding the computed rows.
"""
function pivot_wider(source::Enumerable, names_from::Symbol, values_from::Symbol;
                     id_cols=nothing)
    TS = eltype(source)
    all_fields = fieldnames(TS)

    id_col_names = if id_cols === nothing
        tuple((f for f in all_fields if !(f == names_from || f == values_from))...)
    else
        tuple(id_cols...)
    end

    val_type = fieldtype(TS, values_from)
    out_val_type = DataValues.DataValue{val_type}

    all_rows = collect(source)

    # Distinct pivot names, kept in order of first appearance.
    new_col_names = tuple(unique(Symbol(getfield(r, names_from)) for r in all_rows)...)

    out_names = (id_col_names..., new_col_names...)
    id_types = map(f -> fieldtype(TS, f), id_col_names)
    spread_types = map(_ -> out_val_type, new_col_names)
    T = NamedTuple{out_names, Tuple{id_types..., spread_types...}}

    # Bucket the values of each id-key by pivot name; insertion order of the
    # outer dictionary determines output row order.
    grouped = OrderedDict{Any, Dict{Symbol, val_type}}()
    for r in all_rows
        key = map(c -> getfield(r, c), id_col_names)
        haskey(grouped, key) || (grouped[key] = Dict{Symbol, val_type}())
        grouped[key][Symbol(getfield(r, names_from))] = getfield(r, values_from)
    end

    # Shared placeholder for (id-key, name) combinations with no source row.
    na = out_val_type()
    results = T[
        T((key..., map(nm -> haskey(vals, nm) ? out_val_type(vals[nm]) : na,
                       new_col_names)...))
        for (key, vals) in grouped
    ]

    return EnumerablePivotWider{T}(results)
end
| 140 | + |
| 141 | +# ===== Column selector resolution ===== |
| 142 | +# Resolves a tuple of column names from a NamedTuple type according to a list of |
| 143 | +# selector instructions (encoded in a Val type parameter for compile-time evaluation). |
| 144 | +# |
| 145 | +# Each instruction is a 2-tuple (op, arg): |
| 146 | +# (:include_name, sym) — include field by name |
| 147 | +# (:exclude_name, sym) — exclude field by name |
| 148 | +# (:include_position, idx) — include field at 1-based position |
| 149 | +# (:exclude_position, idx) — exclude field at 1-based position |
| 150 | +# (:include_startswith, prefix_sym) — include fields whose name starts with prefix |
| 151 | +# (:exclude_startswith, prefix_sym) — exclude fields whose name starts with prefix |
| 152 | +# (:include_endswith, suffix_sym) — include fields whose name ends with suffix |
| 153 | +# (:exclude_endswith, suffix_sym) — exclude fields whose name ends with suffix |
| 154 | +# (:include_occursin, sub_sym) — include fields whose name contains sub |
| 155 | +# (:exclude_occursin, sub_sym) — exclude fields whose name contains sub |
| 156 | +# (:include_range, (from, to)) — include field names from :from to :to (inclusive) |
| 157 | +# (:include_range_idx, (a, b)) — include fields at 1-based positions a through b |
| 158 | +# (:include_all, :_) — include all remaining fields |
| 159 | +# |
| 160 | +# If all instructions are "exclude" ops, the starting set is ALL field names; otherwise |
| 161 | +# the starting set is empty and "include" instructions accumulate into it. |
"""
    _resolve_pivot_cols(NT, Val(instructions))

Return the tuple of field names of the NamedTuple type `NT` selected by
`instructions`, a tuple of `(op, arg)` pairs applied in order.

The selection starts empty when at least one instruction is an `include_*` op
and starts as all field names otherwise. Included names are appended only if
not already selected; instructions with an unrecognized `op` are ignored.
"""
@generated function _resolve_pivot_cols(::Type{NT}, ::Val{instructions}) where {NT<:NamedTuple, instructions}
    all_names = collect(fieldnames(NT))

    additive_ops = (:include_name, :include_position, :include_startswith,
                    :include_endswith, :include_occursin, :include_all,
                    :include_range, :include_range_idx)
    starts_empty = any(inst -> inst[1] ∈ additive_ops, instructions)
    selected = starts_empty ? Symbol[] : copy(all_names)

    # Small helpers shared by every op below.
    include!(name) = (name ∈ selected || push!(selected, name); nothing)
    include_where!(pred) = foreach(include!, filter(pred, all_names))
    exclude_where!(pred) = filter!(n -> !pred(n), selected)

    for (op, arg) in instructions
        if op === :include_all
            foreach(include!, all_names)
        elseif op === :include_name
            include!(arg)
        elseif op === :exclude_name
            exclude_where!(n -> n == arg)
        elseif op === :include_position
            include!(all_names[arg])
        elseif op === :exclude_position
            target = all_names[arg]
            exclude_where!(n -> n == target)
        elseif op === :include_startswith
            include_where!(n -> Base.startswith(String(n), String(arg)))
        elseif op === :exclude_startswith
            exclude_where!(n -> Base.startswith(String(n), String(arg)))
        elseif op === :include_endswith
            include_where!(n -> Base.endswith(String(n), String(arg)))
        elseif op === :exclude_endswith
            exclude_where!(n -> Base.endswith(String(n), String(arg)))
        elseif op === :include_occursin
            include_where!(n -> Base.occursin(String(arg), String(n)))
        elseif op === :exclude_occursin
            exclude_where!(n -> Base.occursin(String(arg), String(n)))
        elseif op === :include_range
            from_sym, to_sym = arg
            lo = findfirst(n -> n == from_sym, all_names)
            hi = findfirst(n -> n == to_sym, all_names)
            # Nothing is selected when `from` is absent or `to` precedes it; a
            # missing `to` extends the range to the last field.
            if lo !== nothing
                stop = hi === nothing ? length(all_names) : hi
                foreach(include!, all_names[lo:stop])
            end
        elseif op === :include_range_idx
            from_idx, to_idx = arg
            foreach(i -> include!(all_names[i]), from_idx:to_idx)
        end
    end

    resolved = tuple(selected...)
    return :($resolved)
end
| 234 | + |
# Start iteration over the stored rows; the state is the index of the next row.
function Base.iterate(iter::EnumerablePivotWider{T}) where {T}
    rows = iter.results
    if isempty(rows)
        return nothing
    end
    return rows[1], 2
end
| 239 | + |
# Continue iteration: `state` is the index of the row to emit next.
function Base.iterate(iter::EnumerablePivotWider{T}, state) where {T}
    rows = iter.results
    if state > length(rows)
        return nothing
    end
    return rows[state], state + 1
end
0 commit comments