try:
    from pyspark import SparkConf
except ImportError:
    ! pip install pyspark==3.2.1
    from pyspark import SparkConf

from pyspark.sql import SparkSession, types as st
from IPython.display import HTML

import spark.helpers as sh

# Setup Spark
conf = SparkConf().setMaster("local[1]").setAppName("examples")
spark = SparkSession.builder.config(conf=conf).getOrCreate()
spark.sparkContext.setLogLevel('ERROR')

# Load example datasets
dataframe_1 = spark.read.options(header=True).csv("./data/dataset_1.csv")
dataframe_2 = spark.read.options(header=True).csv("./data/dataset_2.csv")

html = (
    "<div style='float:left'><h4>Dataset 1:</h4>" +
    dataframe_1.toPandas().to_html() +
    "</div><div style='float:left; margin-left:50px;'><h4>Dataset 2:</h4>" +
    dataframe_2.toPandas().to_html() +
    "</div>"
)
HTML(html)
Dataset 1:
  x1 x2   x3   x4     x5
0  A  J  734  499  595.0
1  B  J  357  202  525.0
2  C  H  864  568  433.5
3  D  J  530  703  112.3
4  E  H   61  521  906.0
5  F  H  482  496   13.0
6  G  A  350  279  941.0
7  H  C  171  267  423.0
8  I  C  755  133  600.0
9  J  A  228  765    7.0
Dataset 2:
  x1 x3   x4   x6      x7
0  W  K  391  140   872.0
1  X  G   88  483   707.1
2  Y  M  144  476   714.3
3  Z  J  896   68   902.0
4  A  O  946  187   431.0
5  B  P  692  523   503.5
6  C  Q  550  988  181.05
7  D  R   50  419    42.0
8  E  S  824  805   558.2
9  F  T   69  722   721.0
for group, data in sh.group_iterator(dataframe_1, "x2"):
    print(group, " => ", data.toPandas().shape[0])
A => 2
C => 2
H => 3
J => 3
for group, data in sh.group_iterator(dataframe_1, ["x1", "x2"]):
    print(group, " => ", data.toPandas().shape[0])
('A', 'J') => 1
('B', 'J') => 1
('C', 'H') => 1
('D', 'J') => 1
('E', 'H') => 1
('F', 'H') => 1
('G', 'A') => 1
('H', 'C') => 1
('I', 'C') => 1
('J', 'A') => 1
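For reference, the grouped iteration above (a single column or a list of columns) can be approximated with plain PySpark. The sketch below only illustrates equivalent behaviour under the assumption that group_iterator yields (group key, sub-dataframe) pairs; it is not how the helper is implemented.

# Hedged sketch: reproduce the (key, sub-dataframe) iteration with plain PySpark.
from pyspark.sql import functions as sf

def naive_group_iterator(df, columns):
    columns = [columns] if isinstance(columns, str) else list(columns)
    # Collect the distinct key combinations, sorted for a reproducible order.
    for row in df.select(*columns).distinct().sort(*columns).collect():
        key = tuple(row[c] for c in columns)
        condition = sf.lit(True)
        for column, value in zip(columns, key):
            condition = condition & (sf.col(column) == value)
        yield (key if len(key) > 1 else key[0]), df.filter(condition)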
before = [(x["name"], x["type"]) for x in dataframe_1.schema.jsonValue()["fields"]]

schema = {
    "x2": st.IntegerType(),
    "x5": st.FloatType(),
}
new_dataframe = sh.change_schema(dataframe_1, schema)

after = [(x["name"], x["type"]) for x in new_dataframe.schema.jsonValue()["fields"]]

check = [
    ('x1', 'string'),
    ('x2', 'integer'),
    ('x3', 'string'),
    ('x4', 'string'),
    ('x5', 'float')
]

assert before != after
assert after == check
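The same column casts can also be written directly with PySpark's cast; a minimal sketch for comparison (not necessarily what change_schema does internally):

# Hedged sketch: cast the listed columns in place with plain PySpark.
from pyspark.sql import functions as sf

casted = dataframe_1
for column_name, data_type in schema.items():
    casted = casted.withColumn(column_name, sf.col(column_name).cast(data_type))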
joined = sh.join(dataframe_1.select("x2", "x5"), dataframe_2, sh.JoinStatement("x2", "x1"))
joined.toPandas()
  x1 x2 x3   x4     x5   x6      x7
0  A  A  O  946    7.0  187   431.0
1  A  A  O  946  941.0  187   431.0
2  C  C  Q  550  600.0  988  181.05
3  C  C  Q  550  423.0  988  181.05
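Assuming JoinStatement("x2", "x1") expresses equality between the left-hand x2 and the right-hand x1, and that the default join is an inner join, the call above corresponds roughly to this plain-PySpark join:

# Hedged sketch of the equivalent plain-PySpark join.
left = dataframe_1.select("x2", "x5")
plain_joined = left.join(dataframe_2, left["x2"] == dataframe_2["x1"], how="inner")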
try:
    joined = sh.join(dataframe_1, dataframe_2, sh.JoinStatement("x1"))
except ValueError as error:
    print(f"Error raised as expected: {error}")

joined = sh.join(dataframe_1, dataframe_2, sh.JoinStatement("x1"), overwrite_strategy="left")
joined.toPandas()
Error raised as expected:
Overlapping columns found in the dataframes: ['x1', 'x3', 'x4']
Please provide the `overwrite_strategy` argument therefore, to select a selection strategy:
* "left": Use all the intersecting columns from the left dataframe
* "right": Use all the intersecting columns from the right dataframe
* [["x_in_left", "y_in_left"], ["z_in_right"]]: Provide column names for both
  x1 x2   x3   x4     x5   x6      x7
0  A  J  734  499  595.0  187   431.0
1  B  J  357  202  525.0  523   503.5
2  C  H  864  568  433.5  988  181.05
3  D  J  530  703  112.3  419    42.0
4  E  H   61  521  906.0  805   558.2
5  F  H  482  496   13.0  722   721.0
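The "left" strategy resolves every overlapping column (x1, x3, x4) from the left dataframe; the "right" strategy shown next is the mirror image. Purely to illustrate the column resolution, a rough plain-PySpark equivalent could look like this:

# Hedged sketch: join on x1, then keep the overlapping columns from the left side only.
from pyspark.sql import functions as sf

raw = dataframe_1.alias("l").join(
    dataframe_2.alias("r"), sf.col("l.x1") == sf.col("r.x1"), how="inner"
)
left_resolved = raw.select(
    sf.col("l.x1"), sf.col("l.x2"), sf.col("l.x3"), sf.col("l.x4"),
    sf.col("l.x5"), sf.col("r.x6"), sf.col("r.x7"),
)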
joined = sh.join(dataframe_1, dataframe_2, sh.JoinStatement("x1"), overwrite_strategy="right")
joined.toPandas()
  x1 x2 x3   x4     x5   x6      x7
0  A  J  O  946  595.0  187   431.0
1  B  J  P  692  525.0  523   503.5
2  C  H  Q  550  433.5  988  181.05
3  D  J  R   50  112.3  419    42.0
4  E  H  S  824  906.0  805   558.2
5  F  H  T   69   13.0  722   721.0
joined = sh.join(
    dataframe_1, dataframe_2, sh.JoinStatement("x1"),
    overwrite_strategy=[["x1", "x3"], ["x4"]]
)
joined.toPandas()
  x1 x2   x3   x4     x5   x6      x7
0  A  J  734  946  595.0  187   431.0
1  B  J  357  692  525.0  523   503.5
2  C  H  864  550  433.5  988  181.05
3  D  J  530   50  112.3  419    42.0
4  E  H   61  824  906.0  805   558.2
5  F  H  482   69   13.0  722   721.0
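The list form picks the overlapping columns explicitly: x1 and x3 come from the left dataframe, x4 from the right. In plain PySpark terms the selection would look roughly like this (a sketch, not the helper's implementation):

# Hedged sketch: explicit per-column resolution of the overlapping columns.
from pyspark.sql import functions as sf

raw = dataframe_1.alias("l").join(
    dataframe_2.alias("r"), sf.col("l.x1") == sf.col("r.x1"), how="inner"
)
custom_resolved = raw.select(
    sf.col("l.x1"), sf.col("l.x2"), sf.col("l.x3"), sf.col("r.x4"),
    sf.col("l.x5"), sf.col("r.x6"), sf.col("r.x7"),
)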
x1_x1 = sh.JoinStatement("x1")
x1_x3 = sh.JoinStatement("x1", "x3")
statement = sh.JoinStatement(x1_x1, x1_x3, "or")
joined = sh.join(dataframe_1, dataframe_2, statement, overwrite_strategy="left")
joined.toPandas()
  x1 x2   x3   x4     x5   x6      x7
0  A  J  734  499  595.0  187   431.0
1  B  J  357  202  525.0  523   503.5
2  C  H  864  568  433.5  988  181.05
3  D  J  530  703  112.3  419    42.0
4  E  H   61  521  906.0  805   558.2
5  F  H  482  496   13.0  722   721.0
6  G  A  350  279  941.0  483   707.1
7  J  A  228  765    7.0   68   902.0
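Composite statements combine two equality conditions; assuming "or" maps to a logical OR of the two, the join above is roughly equivalent to:

# Hedged sketch: match rows where left.x1 == right.x1 OR left.x1 == right.x3.
from pyspark.sql import functions as sf

condition = (sf.col("l.x1") == sf.col("r.x1")) | (sf.col("l.x1") == sf.col("r.x3"))
raw = dataframe_1.alias("l").join(dataframe_2.alias("r"), condition, how="inner")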
Recursive (nested) JoinStatement objects are not supported; perform sequential joins instead:
x1_x1 = sh.JoinStatement("x1")
x1_x3 = sh.JoinStatement("x1", "x3")
statement = sh.JoinStatement(x1_x1, x1_x3, "or")
statement_complex = sh.JoinStatement(statement, statement, "and")

try:
    joined = sh.join(dataframe_1, dataframe_2, statement_complex, overwrite_strategy="left")
except NotImplementedError as error:
    print(f"Error raised as expected: [{error}]")
Error raised as expected: [Recursive JoinStatement not implemented]
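If a nested condition is genuinely needed, one option is to express it directly in plain PySpark rather than nesting JoinStatement objects. This is a hedged sketch of that workaround, not a feature of the helper library:

# Hedged sketch: build the nested boolean condition manually in plain PySpark.
from pyspark.sql import functions as sf

inner_or = (sf.col("l.x1") == sf.col("r.x1")) | (sf.col("l.x1") == sf.col("r.x3"))
nested = inner_or & inner_or  # the "and" of the two composite statements above
raw = dataframe_1.alias("l").join(dataframe_2.alias("r"), nested, how="inner")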