PySpark Helper Functions

Useful helper functions for PySpark DataFrame operations

[For less verbose and foolproof operations]

try:
    from pyspark import SparkConf
except ImportError:
    # Install PySpark on the fly if it is not already available
    ! pip install pyspark==3.2.1

from pyspark import SparkConf
from pyspark.sql import SparkSession, types as st
from IPython.display import HTML

import spark.helpers as sh
# Setup Spark

conf = SparkConf().setMaster("local[1]").setAppName("examples")
spark = SparkSession.builder.config(conf=conf).getOrCreate()
spark.sparkContext.setLogLevel('ERROR')
# Load example datasets

dataframe_1 = spark.read.options(header=True).csv("./data/dataset_1.csv")
dataframe_2 = spark.read.options(header=True).csv("./data/dataset_2.csv")
html = (
    "<div style='float:left'><h4>Dataset 1:</h4>" +
    dataframe_1.toPandas().to_html() +
    "</div><div style='float:left; margin-left:50px;'><h4>Dataset 2:</h4>" +
    dataframe_2.toPandas().to_html() +
    "</div>"
)
HTML(html)

Dataset 1:

   x1  x2   x3   x4     x5
0   A   J  734  499  595.0
1   B   J  357  202  525.0
2   C   H  864  568  433.5
3   D   J  530  703  112.3
4   E   H   61  521  906.0
5   F   H  482  496   13.0
6   G   A  350  279  941.0
7   H   C  171  267  423.0
8   I   C  755  133  600.0
9   J   A  228  765    7.0

Dataset 2:

   x1  x3   x4   x6      x7
0   W   K  391  140   872.0
1   X   G   88  483   707.1
2   Y   M  144  476   714.3
3   Z   J  896   68   902.0
4   A   O  946  187   431.0
5   B   P  692  523   503.5
6   C   Q  550  988  181.05
7   D   R   50  419    42.0
8   E   S  824  805   558.2
9   F   T   69  722   721.0

1. Pandas-like group by

for group, data in sh.group_iterator(dataframe_1, "x2"):
    print(group, " => ", data.toPandas().shape[0])
A  =>  2
C  =>  2
H  =>  3
J  =>  3

[Group by on multiple columns]

for group, data in sh.group_iterator(dataframe_1, ["x1", "x2"]):
    print(group, " => ", data.toPandas().shape[0])
('A', 'J')  =>  1
('B', 'J')  =>  1
('C', 'H')  =>  1
('D', 'J')  =>  1
('E', 'H')  =>  1
('F', 'H')  =>  1
('G', 'A')  =>  1
('H', 'C')  =>  1
('I', 'C')  =>  1
('J', 'A')  =>  1
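
For reference, a rough sketch of how such an iterator could be built with plain PySpark. This is an assumed re-implementation for illustration, not `sh.group_iterator` itself (the sorted key order is also an assumption based on the output above):

from pyspark.sql import functions as sf

def group_iterator_sketch(dataframe, columns):
    # Hypothetical stand-in for sh.group_iterator
    single = isinstance(columns, str)
    cols = [columns] if single else list(columns)
    # Collect the distinct group keys to the driver, then yield one
    # filtered dataframe per key (this triggers one Spark job per group)
    for row in dataframe.select(*cols).distinct().orderBy(*cols).collect():
        key = row[cols[0]] if single else tuple(row[c] for c in cols)
        condition = sf.lit(True)
        for c in cols:
            condition = condition & (sf.col(c) == row[c])
        yield key, dataframe.filter(condition)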

2. Bulk-change schema

before = [(x["name"], x["type"]) for x in dataframe_1.schema.jsonValue()["fields"]]

schema = {
    "x2": st.IntegerType(),
    "x5": st.FloatType(),
}
new_dataframe = sh.change_schema(dataframe_1, schema)

after = [(x["name"], x["type"]) for x in new_dataframe.schema.jsonValue()["fields"]]
check = [
    ('x1', 'string'),
    ('x2', 'integer'),
    ('x3', 'string'),
    ('x4', 'string'),
    ('x5', 'float')
]

assert before != after
assert after == check
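
Conceptually this is just one cast per column; a minimal sketch of the idea (assumed behaviour, not the helper's actual code):

def change_schema_sketch(dataframe, schema):
    # Hypothetical stand-in for sh.change_schema: cast each listed
    # column to its target type, leaving all other columns untouched
    for column, data_type in schema.items():
        dataframe = dataframe.withColumn(column, dataframe[column].cast(data_type))
    return dataframe

Note that Spark casts are permissive: casting a non-numeric string column such as x2 yields nulls instead of raising an error.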

3. Improved joins

joined = sh.join(dataframe_1.select("x2", "x5"), dataframe_2, sh.JoinStatement("x2", "x1"))
joined.toPandas()
   x1  x2  x3   x4     x5   x6      x7
0   A   A   O  946    7.0  187   431.0
1   A   A   O  946  941.0  187   431.0
2   C   C   Q  550  600.0  988  181.05
3   C   C   Q  550  423.0  988  181.05
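
For comparison, the same join spelled out against the raw PySpark API (assuming an inner join, which the output above suggests; the helper may order the resulting columns differently):

left = dataframe_1.select("x2", "x5")
plain = left.join(dataframe_2, left["x2"] == dataframe_2["x1"], how="inner")
plain.toPandas()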

[When there are overlapping columns]

try:
    joined = sh.join(dataframe_1, dataframe_2, sh.JoinStatement("x1"))
except ValueError as error:
    print(f"Error raised as expected: {error}")
    joined = sh.join(dataframe_1, dataframe_2, sh.JoinStatement("x1"), overwrite_strategy="left")
joined.toPandas()
Error raised as expected: 

Overlapping columns found in the dataframes: ['x1', 'x3', 'x4']
Please provide the `overwrite_strategy` argument therefore, to select a selection strategy:
	* "left": Use all the intersecting columns from the left dataframe
	* "right": Use all the intersecting columns from the right dataframe
	* [["x_in_left", "y_in_left"], ["z_in_right"]]: Provide column names for both
   x1  x2   x3   x4     x5   x6      x7
0   A   J  734  499  595.0  187   431.0
1   B   J  357  202  525.0  523   503.5
2   C   H  864  568  433.5  988  181.05
3   D   J  530  703  112.3  419    42.0
4   E   H   61  521  906.0  805   558.2
5   F   H  482  496   13.0  722   721.0

[Keeping the duplicate columns from the right dataframe]

joined = sh.join(dataframe_1, dataframe_2, sh.JoinStatement("x1"), overwrite_strategy="right")
joined.toPandas()
   x1  x2  x3   x4     x5   x6      x7
0   A   J   O  946  595.0  187   431.0
1   B   J   P  692  525.0  523   503.5
2   C   H   Q  550  433.5  988  181.05
3   D   J   R   50  112.3  419    42.0
4   E   H   S  824  906.0  805   558.2
5   F   H   T   69   13.0  722   721.0

[Keeping the duplicate columns from both]

joined = sh.join(
    dataframe_1, dataframe_2, sh.JoinStatement("x1"), 
    overwrite_strategy=[["x1", "x3"], ["x4"]]
)
joined.toPandas()
   x1  x2   x3   x4     x5   x6      x7
0   A   J  734  946  595.0  187   431.0
1   B   J  357  692  525.0  523   503.5
2   C   H  864  550  433.5  988  181.05
3   D   J  530   50  112.3  419    42.0
4   E   H   61  824  906.0  805   558.2
5   F   H  482   69   13.0  722   721.0
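
A sketch of how an overwrite strategy like this could be resolved by hand in raw PySpark. This illustrates the idea only; the helper's internals are assumptions here:

from pyspark.sql import functions as sf

def overwrite_sketch(left, right, key, keep_left, keep_right):
    # The overlapping columns must be fully split between the two sides
    overlap = set(left.columns) & set(right.columns)
    assert overlap == set(keep_left) | set(keep_right)
    joined = left.alias("l").join(right.alias("r"), sf.col(f"l.{key}") == sf.col(f"r.{key}"))
    picked = []
    for column in left.columns:
        side = "l" if column not in overlap or column in keep_left else "r"
        picked.append(sf.col(f"{side}.{column}").alias(column))
    # Non-overlapping columns of the right dataframe are always kept
    picked += [sf.col(f"r.{c}").alias(c) for c in right.columns if c not in overlap]
    return joined.select(picked)

# Mirrors overwrite_strategy=[["x1", "x3"], ["x4"]] above
overwrite_sketch(dataframe_1, dataframe_2, "x1", ["x1", "x3"], ["x4"]).toPandas()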

[Complex join]

x1_x1 = sh.JoinStatement("x1")
x1_x3 = sh.JoinStatement("x1", "x3")
statement = sh.JoinStatement(x1_x1, x1_x3, "or")
joined = sh.join(dataframe_1, dataframe_2, statement, overwrite_strategy="left")
joined.toPandas()
   x1  x2   x3   x4     x5   x6      x7
0   A   J  734  499  595.0  187   431.0
1   B   J  357  202  525.0  523   503.5
2   C   H  864  568  433.5  988  181.05
3   D   J  530  703  112.3  419    42.0
4   E   H   61  521  906.0  805   558.2
5   F   H  482  496   13.0  722   721.0
6   G   A  350  279  941.0  483   707.1
7   J   A  228  765    7.0   68   902.0
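
In raw PySpark the composite statement corresponds to a compound join condition, roughly as below (note that sh.join additionally de-duplicates the overlapping columns, which this bare form does not):

from pyspark.sql import functions as sf

l, r = dataframe_1.alias("l"), dataframe_2.alias("r")
condition = (sf.col("l.x1") == sf.col("r.x1")) | (sf.col("l.x1") == sf.col("r.x3"))
l.join(r, condition).toPandas()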

[Further nested joins are not supported]

(Perform sequential joins instead)

x1_x1 = sh.JoinStatement("x1")
x1_x3 = sh.JoinStatement("x1", "x3")
statement = sh.JoinStatement(x1_x1, x1_x3, "or")
statement_complex = sh.JoinStatement(statement, statement, "and")
try:
    joined = sh.join(dataframe_1, dataframe_2, statement_complex, overwrite_strategy="left")
except NotImplementedError as error:
    print(f"Error raised as expected: [{error}]")
Error raised as expected: [Recursive JoinStatement not implemented]
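
Besides sequential sh.join calls, another way around the limitation is to express the deeply nested condition directly as a raw PySpark expression; the following is an assumed equivalent of statement_complex:

from pyspark.sql import functions as sf

l, r = dataframe_1.alias("l"), dataframe_2.alias("r")
inner = (sf.col("l.x1") == sf.col("r.x1")) | (sf.col("l.x1") == sf.col("r.x3"))
# The OR block ANDed with itself, i.e. JoinStatement(statement, statement, "and")
l.join(r, inner & inner).toPandas()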