MrPowers/mack

Feature Request: Flatten Nested Schema


Having to flatten nested schemas can be cumbersome. A function to flatten the structure with various options would be handy.

Some configurations to include could be (a hypothetical signature is sketched after this list):

  • Include parent struct names as a prefix in the column name
  • Stop after N levels of nesting
  • Break each level into individual dataframes and create primary and foreign key relationships between them
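
As a rough illustration only, here is one shape the API could take. The function name flatten_schema and every parameter below are hypothetical, not an existing mack API; the third option (splitting levels into related dataframes) would probably need its own function.

from typing import Optional
from pyspark.sql import DataFrame


def flatten_schema(
    df: DataFrame,
    separator: str = "_",
    include_parent_names: bool = True,
    max_depth: Optional[int] = None,
) -> DataFrame:
    """Flatten nested struct/array columns into top-level columns.

    separator: string placed between parent and child names in flattened column names.
    include_parent_names: if False, keep only the leaf field name in the output columns.
    max_depth: stop flattening after this many nesting levels; None flattens everything.
    """
    raise NotImplementedError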

Below is a working example of a flatten function I've used previously.

import pyspark.sql.functions as F 

def flatten_dataframe(df):
    has_nested_data = True

    while has_nested_data:
        # Explode every array column so each element gets its own row
        # (note: exploding several array columns in one pass yields one row
        # per combination of their elements)
        array_cols = [c[0] for c in df.dtypes if c[1][:5] == "array"]
        for c in array_cols:
            df = df.withColumn(c, F.explode_outer(c))

        # Flatten all struct columns into top-level columns
        df = child_struct(df)

        # Flattening structs can expose new array columns; loop until none remain
        array_cols = [c[0] for c in df.dtypes if c[1][:5] == "array"]
        if len(array_cols) == 0:
            has_nested_data = False

    return df


def child_struct(nested_df):
    # Stack of (parent path, dataframe) pairs still to process
    list_schema = [((), nested_df)]
    # Final list of flattened column expressions
    flat_columns = []

    while len(list_schema) > 0:
        # Pop the most recently added (parents, dataframe) pair
        parents, df = list_schema.pop()

        # Non-struct columns: select by their dotted path, alias with underscores
        flat_cols = [
            F.col(".".join(parents + (c[0],))).alias("_".join(parents + (c[0],)))
            for c in df.dtypes
            if c[1][:6] != "struct"
        ]

        struct_cols = [c[0] for c in df.dtypes if c[1][:6] == "struct"]

        flat_columns.extend(flat_cols)

        # Push each struct column's fields onto the stack for further flattening
        for i in struct_cols:
            projected_df = df.select(i + ".*")
            list_schema.append((parents + (i,), projected_df))

    return nested_df.select(flat_columns)


sample = """{
  "name":"MSFT","location":"Redmond", "satellites": ["Bay Area", "Shanghai"],
  "goods": {
    "trade":true, "customers":["government", "distributer", "retail"],
    "orders":[
        {"orderId":1,"orderTotal":123.34,"shipped":{"orderItems":[{"itemName":"Laptop","itemQty":20},{"itemName":"Charger","itemQty":2}]}},
        {"orderId":2,"orderTotal":323.34,"shipped":{"orderItems":[{"itemName":"Mice","itemQty":2},{"itemName":"Keyboard","itemQty":1}]}}
    ]}}
{"name":"Company1","location":"Seattle", "satellites": ["New York"],
  "goods":{"trade":false, "customers":["store1", "store2"],
  "orders":[
      {"orderId":4,"orderTotal":123.34,"shipped":{"orderItems":[{"itemName":"Laptop","itemQty":20},{"itemName":"Charger","itemQty":3}]}},
      {"orderId":5,"orderTotal":343.24,"shipped":{"orderItems":[{"itemName":"Chair","itemQty":4},{"itemName":"Lamp","itemQty":2}]}}
    ]}}
{"name": "Company2", "location": "Bellevue",
  "goods": {"trade": true, "customers":["Bank"], "orders": [{"orderId": 4, "orderTotal": 123.34}]}}
{"name": "Company3", "location": "Kirkland"}"""


df = spark.read.json(sc.parallelize([sample]))
flatten_df = flatten_dataframe(df)
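
For what it's worth, with the sample data above the flattened columns come out as underscore-joined paths such as goods_trade, goods_orders_orderId, and goods_orders_shipped_orderItems_itemName; printSchema/show are a quick way to check:

flatten_df.printSchema()
flatten_df.show(truncate=False)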

I didn't see that there was the Quinn library, which would be a more appropriate place for this.