SPFlow/SPFlow

Incorrect value check during if condition

GiuliaGualtieri opened this issue · 1 comments

What did you expect to happen?

The 'elif' branches should check the value of 'pre_proc', not 'ds_context'.

What actually happened?

In file src/spn/algorithms/splitting/Base.py:

def preproc(data, ds_context, pre_proc, ohe):
    if pre_proc:
        f = None
        if pre_proc == "tf-idf":
            f = lambda data: TfidfTransformer().fit_transform(data)
        elif ds_context == "log+1":
            f = lambda data: np.log(data + 1)
        elif ds_context == "sqrt":
            f = lambda data: np.sqrt(data)

        if f is not None:
            data = np.copy(data)
            data[:, ds_context.distribution_family == "poisson"] = f(
                data[:, ds_context.distribution_family == "poisson"]
            )

    if ohe:
        data = getOHE(data, ds_context)

    return data

The first and second 'elif' conditions check the wrong variable: they compare 'ds_context' against "log+1" and "sqrt", but the value to test is 'pre_proc'.
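To illustrate the effect, here is a minimal sketch of only the branch selection, assuming ds_context is an ordinary object and not a string. `FakeContext`, `select_buggy` and `select_fixed` are hypothetical names introduced for this example and are not part of SPFlow; `math` stands in for NumPy so the snippet runs on its own, and the unaffected "tf-idf" branch is left out.

```python
import math


class FakeContext:
    """Hypothetical stand-in for the context object; not the real SPFlow class."""
    distribution_family = "poisson"


def select_buggy(ds_context, pre_proc):
    # Mirrors the reported branches: ds_context is compared to the strings.
    f = None
    if ds_context == "log+1":
        f = lambda x: math.log(x + 1)
    elif ds_context == "sqrt":
        f = lambda x: math.sqrt(x)
    return f


def select_fixed(ds_context, pre_proc):
    # Proposed behaviour: pre_proc decides which transform is used.
    f = None
    if pre_proc == "log+1":
        f = lambda x: math.log(x + 1)
    elif pre_proc == "sqrt":
        f = lambda x: math.sqrt(x)
    return f


ctx = FakeContext()
# An object never equals a string, so the buggy version selects nothing.
print(select_buggy(ctx, "log+1") is None)
# The fixed version returns the requested transform.
print(select_fixed(ctx, "sqrt")(9.0))
```

Under these assumptions the buggy selection always leaves f as None, so "log+1" and "sqrt" are silently skipped without raising an error.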

Describe your attempts to resolve the issue

def preproc(data, ds_context, pre_proc, ohe):
    if pre_proc:
        f = None
        if pre_proc == "tf-idf":
            f = lambda data: TfidfTransformer().fit_transform(data)
        elif pre_proc == "log+1":
            f = lambda data: np.log(data + 1)
        elif pre_proc == "sqrt":
            f = lambda data: np.sqrt(data)

        if f is not None:
            data = np.copy(data)
            data[:, ds_context.distribution_family == "poisson"] = f(
                data[:, ds_context.distribution_family == "poisson"]
            )

    if ohe:
        data = getOHE(data, ds_context)

    return data

Steps to reproduce

def preproc(data, ds_context, pre_proc, ohe):
    if pre_proc:
        f = None
        if pre_proc == "tf-idf":
            f = lambda data: TfidfTransformer().fit_transform(data)
        elif pre_proc == "log+1":
            f = lambda data: np.log(data + 1)
        elif pre_proc == "sqrt":
            f = lambda data: np.sqrt(data)

        if f is not None:
            data = np.copy(data)
            data[:, ds_context.distribution_family == "poisson"] = f(
                data[:, ds_context.distribution_family == "poisson"]
            )

    if ohe:
        data = getOHE(data, ds_context)

    return data

System Information

Python: 3.9.6

Installed Python Packages

numpy scipy sklearn statsmodels networkx joblib matplotlib pydot lark-parser tqdm ete3 geomstats sympy PyQt5 arff pytest dataclasses

Thanks for your contribution! Fixed in d69babd.