Feature Request: Flatten Nested Schema
Closed this issue · 1 comments
gardnmi commented
Having to flatten nested schemas can be cumbersome. A function to flatten the structure with various options would be handy.
Some configurations to include could be:
- Include parent struct names as a prefix in the column name
- Stop after N levels of nesting
- Break each level into an individual dataframe and create primary/foreign key relationships between the dataframes
I've included below a working example of an existing flatten function I've used previously.
import pyspark.sql.functions as F
def flatten_dataframe(df):
    """Fully flatten a nested DataFrame into scalar columns.

    Repeatedly explodes every array column (``explode_outer`` keeps rows
    whose array is null or empty) and flattens every struct column via
    ``child_struct``, until no array columns remain.

    NOTE(review): exploding several array columns in the same pass yields
    a cross product of their elements — this matches the original
    behavior, but the row count can grow quickly.

    :param df: a (possibly nested) pyspark DataFrame
    :return: a DataFrame with no array or struct columns
    """
    while True:
        # Explode every top-level array column; each explode emits one
        # row per element (or a single null row for empty/null arrays).
        array_cols = [name for name, dtype in df.dtypes if dtype.startswith("array")]
        for name in array_cols:
            df = df.withColumn(name, F.explode_outer(name))
        # Flatten all struct columns exposed by the explodes above.
        # Always runs at least once, so struct-only inputs are flattened too.
        df = child_struct(df)
        # Flattening structs may surface new array columns; loop until none.
        if not any(dtype.startswith("array") for _, dtype in df.dtypes):
            return df
def child_struct(nested_df):
    """Flatten every struct column of *nested_df*, at any depth, into
    top-level columns named ``parent_child_..._leaf``.

    Uses an explicit stack of ``(ancestor names, projected dataframe)``
    pairs instead of recursion. Array columns are left untouched; only
    struct nesting is unwrapped here.

    :param nested_df: DataFrame whose struct columns should be flattened
    :return: a new DataFrame selecting only the non-struct (leaf) columns
    """
    # Stack of (tuple of ancestor field names, dataframe projected to that level).
    stack = [((), nested_df)]
    # Accumulates fully-qualified leaf columns, aliased with "_"-joined names.
    flat_columns = []
    while stack:
        parents, df = stack.pop()
        # Non-struct columns at this level become output columns:
        # select by dotted path against nested_df, alias with underscores.
        flat_columns.extend(
            F.col(".".join(parents + (name,))).alias("_".join(parents + (name,)))
            for name, dtype in df.dtypes
            if not dtype.startswith("struct")
        )
        # Project each struct column one level deeper and push it for a
        # later pass (LIFO order preserves the original column ordering).
        for name in [n for n, dtype in df.dtypes if dtype.startswith("struct")]:
            stack.append((parents + (name,), df.select(name + ".*")))
    return nested_df.select(flat_columns)
# Sample input: newline-delimited JSON records mixing nested structs
# ("goods", "shipped") and arrays ("satellites", "customers", "orders",
# "orderItems"). Later records omit fields to exercise null handling.
sample = """{
"name":"MSFT","location":"Redmond", "satellites": ["Bay Area", "Shanghai"],
"goods": {
"trade":true, "customers":["government", "distributer", "retail"],
"orders":[
{"orderId":1,"orderTotal":123.34,"shipped":{"orderItems":[{"itemName":"Laptop","itemQty":20},{"itemName":"Charger","itemQty":2}]}},
{"orderId":2,"orderTotal":323.34,"shipped":{"orderItems":[{"itemName":"Mice","itemQty":2},{"itemName":"Keyboard","itemQty":1}]}}
]}}
{"name":"Company1","location":"Seattle", "satellites": ["New York"],
"goods":{"trade":false, "customers":["store1", "store2"],
"orders":[
{"orderId":4,"orderTotal":123.34,"shipped":{"orderItems":[{"itemName":"Laptop","itemQty":20},{"itemName":"Charger","itemQty":3}]}},
{"orderId":5,"orderTotal":343.24,"shipped":{"orderItems":[{"itemName":"Chair","itemQty":4},{"itemName":"Lamp","itemQty":2}]}}
]}}
{"name": "Company2", "location": "Bellevue",
"goods": {"trade": true, "customers":["Bank"], "orders": [{"orderId": 4, "orderTotal": 123.34}]}}
{"name": "Company3", "location": "Kirkland"}"""
# NOTE(review): assumes a live SparkSession (`spark`) and SparkContext
# (`sc`) are already in scope, e.g. a Databricks notebook or spark shell.
df = spark.read.json(sc.parallelize([sample]))
flatten_df = flatten_dataframe(df)