Skip to content

Commit

Permalink
Merge pull request #11 from mrufsvold/refactor/compose-iters
Browse files Browse the repository at this point in the history
Refactor/compose-iters
  • Loading branch information
mrufsvold committed Mar 20, 2023
2 parents bcb4c64 + edadb7c commit cd25f4e
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 103 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
develop
*/Manifest.toml
runner.jl
archive
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ExpandNestedData"
uuid = "8a7d223a-a7dc-4abf-8bc1-b0ce2ace9adc"
authors = ["Micah Rufsvold <mjrufsvold@protonmail.com>"]
version = "0.1.0"
version = "0.1.2"

[deps]
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
Expand Down
19 changes: 12 additions & 7 deletions src/ConfiguredProcessing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ for each column.
function expand(data, column_defs::ColumnDefs; lazy_columns::Bool = false, column_style::ColumnStyle=flat_columns)
# TODO we should parse the user's column definitions into a graph before processing
columns = process_node(data, column_defs)
return ExpandedTable(columns, column_defs, lazy_columns, column_style)
return ExpandedTable(columns, column_defs; lazy_columns = lazy_columns, column_style = column_style)
end


Expand All @@ -30,9 +30,10 @@ function process_node(::D, data, col_defs::ColumnDefs) where D <: NameValueConta
(names, names_with_children) = analyze_column_defs(col_defs)
columns = ColumnSet()
data_names = get_names(data)
multiplier = 1
multiplier_container = Ref{Int}(1)
for name in names
# This creates a view of configured columns to pass down
multiplier = multiplier_container[]
# This creates a copy of configured columns to pass down
child_col_defs = make_column_def_child_copies(col_defs, name)

# Get child columns in 1 of 3 cases:
Expand All @@ -46,16 +47,20 @@ function process_node(::D, data, col_defs::ColumnDefs) where D <: NameValueConta
col_def = first(child_col_defs)
new_column = NestedIterator(child_data;
flatten_arrays = flatten_arrays(col_def), default_value=default_value(col_def))
Dict([] => new_column)
columnset(new_column)
end
prepend_name!(child_columns, name)
child_columns
else
make_missing_column_set(child_col_defs, path_index(first(col_defs)))
end
repeat_each!.(values(child_columns), multiplier)
multiplier *= column_length(child_columns)
merge!(columns, child_columns)

match_len_child_cols = Dict(
key => repeat_each(col, multiplier)
for (key, col) in child_columns
)
multiplier_container[] = multiplier * column_length(match_len_child_cols)
merge!(columns, match_len_child_cols)
end
# catch up short columns with the total length for this group
cycle_columns_to_length!(columns)
Expand Down
145 changes: 87 additions & 58 deletions src/ExpandTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@ function has_namevaluecontainer_element(itr)
return itr |> eltype |> get_member_types .|> is_NameValueContainer |> any
end
end
get_member_types(T) = T isa Union ? Base.uniontypes(T) : [T]
"""
    get_member_types(T)

Return the member types of `T`: the result of `Base.uniontypes(T)` when `T` is a `Union`,
otherwise a one-element vector holding `T` itself.
"""
function get_member_types(::Type{T}) where T
    if T isa Union
        return Base.uniontypes(T)
    end
    return [T]
end

"""Define a pairs iterator for all DataType structs"""
get_pairs(x::T) where T = get_pairs(StructTypes.StructType(T), x)
get_pairs(::StructTypes.DataType, x) = ((p, getproperty(x, p)) for p in fieldnames(typeof(x)))
get_pairs(::StructTypes.DataType, x::T) where T = ((p, getproperty(x, p)) for p in fieldnames(T))
get_pairs(::StructTypes.DictType, x) = pairs(x)

"""Get the keys/names of any NameValueContainer"""
get_names(x::T) where T = get_names(StructTypes.StructType(T), x)
get_names(::StructTypes.DataType, x) = (n for n in fieldnames(typeof(x)))
get_names(::StructTypes.DataType, x::T) where T = (n for n in fieldnames(T))
get_names(::StructTypes.DictType, x) = keys(x)

get_value(x::T, name) where T = get_value(StructTypes.StructType(T), x, name)
Expand All @@ -37,60 +37,82 @@ get_value(::StructTypes.DictType, x, name) = x[name]
##########################

"""NestedIterator is a container for instructions that build columns"""
mutable struct NestedIterator{T} <: AbstractArray{T, 1}
struct NestedIterator{T} <: AbstractArray{T, 1}
get_index::Function
column_length::Int64
unique_values::Set{T}
el_type::Type{T}
one_value::Bool
unique_val::Ref{T}
end
# Array-style accessors for NestedIterator: the stored `column_length` is the only dimension,
# and element lookup is delegated to the stored `get_index` callable.
Base.size(iter::NestedIterator) = (iter.column_length,)
Base.length(iter::NestedIterator) = iter.column_length
Base.eachindex(iter::NestedIterator) = 1:length(iter)
function Base.getindex(iter::NestedIterator, i)
    return iter.get_index(i)
end


Base.collect(x::NestedIterator, pool_arrays) = pool_arrays && !(x.unique_values isa Nothing) ? PooledArray(x) : Vector(x)
Base.collect(x::NestedIterator, pool_arrays) = pool_arrays ? PooledArray(x) : Vector(x)

"""
    InstructionCapture

Supertype of the callable structs in this file (`Seed`, `UnrepeatEach`, `Uncycle`, `Unstack`).
It subtypes `Function`, so instances can be stored in `NestedIterator`'s `get_index::Function`
field.
"""
abstract type InstructionCapture <: Function end

"""repeat_each!(c, N) will return an array where each source element appears N times in a row"""
function repeat_each!(c::NestedIterator, n)
# when there is only one unique value, we can skip composing the unrepeat_each step
if length(c.unique_values) != 1
c.get_index = c.get_index ((i) -> unrepeat_each(i, n))
end
c.column_length *= n
"""
    Seed(data)

Callable wrapper around the backing collection `data`; calling a `Seed` with an index `i`
returns `data[i]`.
"""
struct Seed{T} <: InstructionCapture
# the wrapped collection that indices are looked up in
data::T
end
unrepeat_each(i, n) = ceil(Int64, i/n)
# Calling a Seed looks up element `i` of the wrapped data.
function (seed::Seed)(i)
    return seed.data[i]
end

"""
    UnrepeatEach(n)

Index-translation step for a column whose elements each appear `n` times in a row. Calling it
with an integer position `i` of the repeated column returns the position of the source
element: positions `1:n` map to `1`, `n+1:2n` map to `2`, and so on.
"""
struct UnrepeatEach <: InstructionCapture
    n::Int64
end
# `cld` is exact integer ceiling division. Going through `i / n` converts the index to
# Float64 first, which cannot represent every integer above 2^53.
(u::UnrepeatEach)(i) = cld(i, u.n)

"""cycle!(c, n) cycles through an array N times"""
function cycle!(c::NestedIterator, n)
"""repeat_each(c, N) will return an array where each source element appears N times in a row"""
function repeat_each(c::NestedIterator{T}, n) where T
    # Both branches build a new NestedIterator of length `c.column_length * n`; `c` itself is
    # left untouched.
    # When there is only one unique value, we can skip composing the repeat_each step.
    return if c.one_value
        NestedIterator(c.get_index, c.column_length * n, T, true, c.unique_val)
    else
        # Compose with `∘` so that UnrepeatEach translates the repeated index back to the
        # source index before the existing `get_index` is applied.
        NestedIterator(
            c.get_index ∘ UnrepeatEach(n), c.column_length * n, T, false, c.unique_val
        )
    end
end

"""
    Uncycle(n)

Index-translation step for a cycled column. Calling it with position `i` returns
`mod(i - 1, n) + 1`, which for positive `n` wraps `i` back into `1:n`.
"""
struct Uncycle <: InstructionCapture
    n::Int64
end
function (wrap::Uncycle)(i)
    zero_based = mod(i - 1, wrap.n)
    return zero_based + 1
end
"""cycle(c, n) cycles through an array N times"""
function cycle(c::NestedIterator{T}, n) where T
# when there is only one unique value, we can skip composing the uncycle step
if length(c.unique_values) != 1
return if c.one_value && !(typeof(c.get_index) <: Seed)
NestedIterator(c.get_index, c.column_length * n, T, true, c.unique_val)
else
l = length(c)
c.get_index = c.get_index ((i::Int64) -> uncycle(i, l))
NestedIterator(c.get_index Uncycle(l), c.column_length * n, T, false, c.unique_val)
end
c.column_length *= n
end
uncycle(i,n) = mod((i-1),n) + 1


"""
    Unstack(f_len, f, g)

Index-translation step for two stacked columns. A position `i` greater than `f_len` is
forwarded to `g` as `i - f_len`; any other position is forwarded to `f` unchanged.
"""
struct Unstack{F, G} <: InstructionCapture
    f_len::Int64
    f::F
    g::G
end
function (stacked::Unstack)(i)
    if i > stacked.f_len
        return stacked.g(i - stacked.f_len)
    else
        return stacked.f(i)
    end
end

"""stack(c1::NestedIterator, c2::NestedIterator)
Return a single NestedIterator which is the result of vcat(c1,c2)
"""
function stack(c1::NestedIterator, c2::NestedIterator)
type = Union{eltype(c1), eltype(c2)}
function stack(c1::NestedIterator{T}, c2::NestedIterator{U}) where {T, U}
type = Union{T, U}
len = (c1,c2) .|> length |> sum

continue_tracking_uniques = 0 < length(c1.unique_values) < 100 &&
0 < length(c2.unique_values) < 100
values = continue_tracking_uniques ? union(c1.unique_values, c2.unique_values) : Set{type}([])

f = length(values) == 1 ?
c1.get_index :
((i::Int64) -> unstack(i, length(c1), c1.get_index, c2.get_index))

return NestedIterator{type}(f, len, values)
if T <: U
only_one_value = c1.one_value && c2.one_value && isequal(c1.unique_val[], c2.unique_val[])
if only_one_value
return NestedIterator(c1.get_index, len, type, true, c1.unique_val)
end
end
NestedIterator(Unstack(length(c1), c1.get_index, c2.get_index), len, type, false, Ref{type}())
end
unstack(i::Int64, c1_len::Int64, f1::Function, f2::Function) = i > c1_len ? f2(i-c1_len) : f1(i)



"""
Expand All @@ -101,40 +123,42 @@ data::Any: seed value
flatten_arrays::Bool: if data is an array, flatten_arrays==false will treat the array as a single value when
cycling the columns values
"""
function NestedIterator(data; flatten_arrays=false, total_length=nothing, default_value=missing)
@debug "creating new NestedIterator with dtype: $(typeof(data))"
value = if flatten_arrays && typeof(data) <: AbstractArray
function NestedIterator(data::T; flatten_arrays=false, total_length=nothing, default_value=missing) where T
value = if flatten_arrays && T <: AbstractArray
length(data) >= 1 ? data : [default_value]
else
[data]
end
len = length(value)
@debug "after ensuring value is wrapped in vector, we have the following number of elements $len"
len == 0 && @debug "data was $data"
type = eltype(value)
f = len == 1 ? ((::Int64) -> value[1]) : ((i::Int64) -> value[i])
ni = NestedIterator{type}(f, len, Set(value))
if !(total_length isa Nothing)
cycle!(ni, total_length)
ncycle = total_length isa Nothing ? 1 : total_length ÷ len
return _NestedIterator(value, len, ncycle)
end

function _NestedIterator(value::AbstractArray{T}, len::Int64, ncycle::Int64) where T
f = Seed(value)
is_one = len == 1
unique_val = Ref{T}()
if is_one
unique_val[] = first(value)::T
end
return ni
ni = NestedIterator{T}(f, len, T, is_one, unique_val)
return cycle(ni, ncycle)
end


function missing_column(default, len=1)
col = NestedIterator(default)
cycle!(col, len)
return col
end
"""
    missing_column(default, len=1)

Build a `NestedIterator` from `default`, passing `len` as its `total_length`.
"""
function missing_column(default, len=1)
    return NestedIterator(default; total_length=len)
end


##### ColumnDefinition #####
############################

"""ColumnDefinition provides a mechanism for specifying details for extracting data from a nested data source"""
struct ColumnDefinition
# Path to values
field_path
# Index of current level TODO: should be removed and stored externally
path_index::Int64
# name of this column in the table once expanded
column_name::Symbol
flatten_arrays::Bool
default_value
Expand Down Expand Up @@ -196,6 +220,8 @@ function analyze_column_defs(col_defs::ColumnDefs)
return (unique_names, names_with_children)
end

# TODO: This is a huge source of unnecessary allocations. We should be storing level outside this struct
# and passing along the same defs without copying
function make_column_def_child_copies(column_defs::ColumnDefs, name)
return filter((def -> is_current_name(def, name)), column_defs) .|>
(def -> ColumnDefinition(
Expand Down Expand Up @@ -238,22 +264,23 @@ the input columns. i.e.
column_set_product!(
Dict(
[:a] => [1,2],
[:b] =? [3,4,5]
[:b] => [3,4,5]
)
)
returns
Dict(
[:a] => [1,1,1,2,2,2],
[:b] =? [3,4,5,3,4,5]
[:b] => [3,4,5,3,4,5]
)
"""
function column_set_product!(cols::ColumnSet)
multiplier = 1
for child_column in values(cols)
repeat_each!(child_column, multiplier)
for (key, child_column) in pairs(cols)
cols[key] = repeat_each(child_column, multiplier)
multiplier *= length(child_column)
end
cycle_columns_to_length!(cols)
cols = cycle_columns_to_length!(cols)
return cols
end


Expand All @@ -264,11 +291,13 @@ Given a column set where the length of all columns is some factor of the length
column, cycle all the short columns to match the length of the longest
"""
function cycle_columns_to_length!(cols::ColumnSet)
longest = cols |> values .|> length |> maximum
for child_column in values(cols)
catchup_mult = Int(longest / length(child_column))
cycle!(child_column, catchup_mult)
col_lengths = cols |> values .|> length
longest = col_lengths |> maximum
for (key, child_column) in pairs(cols)
catchup_mult = longest ÷ length(child_column)
cols[key] = cycle(child_column, catchup_mult)
end
return cols
end

"""Return a missing column for each member of a ColumnDefs"""
Expand Down Expand Up @@ -323,7 +352,7 @@ function make_path_nodes(column_defs)

children_col_defs = make_column_def_child_copies(matching_defs, unique_name)
if any(are_value_nodes)
throw(ArgumentError("The path name $unique_name refers a value in one branch and to nested child(ren): $(field_path.(children_names))"))
throw(ArgumentError("The path name $unique_name refers a value field in one branch and to nested child(ren) fields in another: $(field_path.(children_col_defs))"))
end
nodes[i] = PathNode(unique_name, make_path_nodes(children_col_defs))
end
Expand Down
31 changes: 13 additions & 18 deletions src/ExpandedTable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,9 @@ using TypedTables

@enum ColumnStyle flat_columns nested_columns

"""
The functionality we want here is:

t = ExpandedTable(columns, column_names)
t.a_c == [2, missing, 1, missing]
t.a.c == [2, missing, 1, missing]
eachrow(t, :flatten) |> first == (a_b = 1, a_c = 2, d= 4)
eachrow(t, :nested) |> first == (a = (b = 1, c = 2), d = 4)
names(t) ==(a_b, a_c, d)
"""
struct ExpandedTable
col_lookup # Dict( column_name => path )
col_lookup::Dict{Symbol, Vector} # Name of column => path into nested data
columns # TypedTable, nested in the same pattern as src_data
end

Expand All @@ -39,23 +27,28 @@ end


"""Construct an ExpandedTable from the results of `expand`"""
function ExpandedTable(columns::Dict{Vector{<:Any}, T} , column_names::Dict, lazy_columns, pool_arrays, column_style) where {T<: NestedIterator{<:Any}}
function ExpandedTable(columns::Dict{Vector, T}, col_defs; lazy_columns=false, pool_arrays=false, column_style=flat_columns) where {T<: NestedIterator{<:Any}}
sym_key_columns = Dict(
Symbol.(k) => v
for (k, v) in pairs(columns)
)
paths = keys(sym_key_columns)
return ExpandedTable(sym_key_columns, col_defs; lazy_columns =lazy_columns, pool_arrays=pool_arrays, column_style = column_style)
end
function ExpandedTable(columns::Dict{Vector{Symbol}, T}, column_names::Dict; lazy_columns=false, pool_arrays=false, column_style=flat_columns) where {T<: NestedIterator{<:Any}}
paths = keys(columns)
col_defs = ColumnDefinition.(paths, Ref(column_names); pool_arrays=pool_arrays)
return ExpandedTable(sym_key_columns, col_defs, lazy_columns, column_style)
return ExpandedTable(columns, col_defs; lazy_columns =lazy_columns, column_style = column_style)
end
function ExpandedTable(columns::Dict{Vector{Symbol}, T} , column_defs::ColumnDefs, lazy_columns, column_style) where {T<: NestedIterator{<:Any}}
function ExpandedTable(columns::Dict{Vector{Symbol}, T}, column_defs::ColumnDefs; kwargs...) where {T<: NestedIterator{<:Any}}
path_graph = make_path_graph(column_defs)
column_tuple = make_column_tuple(columns, path_graph, lazy_columns)
column_tuple = make_column_tuple(columns, path_graph, kwargs[:lazy_columns])
col_lookup = Dict(
column_name(def) => field_path(def)
for def in column_defs
)
expanded_table = ExpandedTable(col_lookup, column_tuple)

column_style = kwargs[:column_style]
if column_style == flat_columns
return as_flat_table(expanded_table)
elseif column_style == nested_columns
Expand All @@ -68,6 +61,8 @@ end
"""Return the columns of `t` exactly as stored, i.e. `t.columns`."""
function as_nested_table(t::ExpandedTable)
    return t.columns
end
function as_flat_table(t::ExpandedTable)
return NamedTuple(
# foldl applies getproperty starting from t.columns (a nested TypedTable) and walks down
# the path stored in col_lookup to reach the column that matches the name
name => foldl(getproperty, path, init=t.columns)
for (name, path) in pairs(t.col_lookup)
)
Expand Down
Loading

0 comments on commit cd25f4e

Please sign in to comment.