muellerzr/fastinference

NameError: name 'cat_names' is not defined

hududed opened this issue · 2 comments

Hello

I ran this code:

from fastai.tabular.all import *
from fastinference.inference.export import to_fastinference

path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')

splits = RandomSplitter()(range_of(df))
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
y_names = 'salary'

dls = TabularPandas(df, procs=procs, cat_names=cat_names, cont_names=cont_names,
                   y_names=y_names, splits=splits).dataloaders()
learn = tabular_learner(dls, layers=[200,100])

learn.to_fastinference()

and got the error:

---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-11-55f457cbac14> in <module>
----> 1 learn.to_fastinference()

~/.virtualenvs/temp/lib/python3.6/site-packages/fastinference/inference/export.py in to_fastinference(x, data_fname, model_fname, path)
     89     "Export data for `fastinference_onnx` or `_pytorch` to use"
     90     if not isinstance(path,Path): path = Path(path)
---> 91     dicts = get_information(x.dls)
     92     with open(path/f'{data_fname}.pkl', 'wb') as handle:
     93         pickle.dump(dicts, handle, protocol=pickle.HIGHEST_PROTOCOL)

~/.virtualenvs/temp/lib/python3.6/site-packages/fastinference/inference/export.py in get_information(dls)
     45 
     46 # Cell
---> 47 def get_information(dls): return _extract_tfm_dicts(dls[0])
     48 
     49 # Cell

~/.virtualenvs/temp/lib/python3.6/site-packages/fastcore/dispatch.py in __call__(self, *args, **kwargs)
    127         elif self.inst is not None: f = MethodType(f, self.inst)
    128         elif self.owner is not None: f = MethodType(f, self.owner)
--> 129         return f(*args, **kwargs)
    130 
    131     def __get__(self, inst, owner):

~/.virtualenvs/temp/lib/python3.6/site-packages/fastinference/inference/export.py in _extract_tfm_dicts(dl)
     60     name2idx = {name:n for n,name in enumerate(dl.dataset) if name in dl.cat_names or name in dl.cont_names}
     61     idx2name = {v:k for k,v in name2idx.items()}
---> 62     cat_idxs = {name2idx[name]:name for name in cat_names}
     63     cont_idxs = {name2idx[name]:name for name in cont_names}
     64     names = {'cats':cat_idxs, 'conts':cont_idxs}

NameError: name 'cat_names' is not defined
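The error looks surprising at first because `cat_names` is defined in the notebook. However, Python looks up a free name inside a function in two places only: the globals of the module where the function was defined (here `fastinference/inference/export.py`), then builtins. It never checks the caller's namespace. The sketch below reproduces this with a throwaway module. `fake_export` and `get_cats` are hypothetical names used only for illustration; they are not part of fastinference.

```python
import types

# Build a stand-in module (hypothetical) whose function refers to a free
# name `cat_names` that the module itself never defines.
mod = types.ModuleType("fake_export")
exec("def get_cats():\n    return list(cat_names)\n", mod.__dict__)

# Defining the same name in the caller's namespace does not help: free names
# inside get_cats are resolved in fake_export's globals, then builtins.
cat_names = ['workclass', 'education']

try:
    mod.get_cats()
except NameError as e:
    print(e)  # prints: name 'cat_names' is not defined
```

So redefining or renaming variables in the notebook cannot work around the bug. The library code itself has to change.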

Please advise

Thanks! Can you try installing fastinference from the GitHub repository like so and running it again?

pip install git+https://github.com/muellerzr/fastinference

Fixed in the latest release.
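The thread does not show the actual patch. Judging only from the last traceback frame, the failing lines use the bare names `cat_names` and `cont_names`, whereas the DataLoader attributes `dl.cat_names` and `dl.cont_names` (already used two lines earlier) seem to be what was intended. Below is a minimal sketch of that correction, assuming it is the only change needed. It is not fastinference's real code. `extract_name_indices` is a hypothetical name, and a `SimpleNamespace` stands in for a real fastai DataLoader.

```python
from types import SimpleNamespace

def extract_name_indices(dl):
    """Hypothetical helper mirroring the failing frame, but reading the
    column names from `dl` instead of undefined module-level globals."""
    # Map each categorical/continuous column name to its position in the dataset
    name2idx = {name: n for n, name in enumerate(dl.dataset)
                if name in dl.cat_names or name in dl.cont_names}
    # Use the DataLoader's own attributes, not bare `cat_names`/`cont_names`
    cat_idxs = {name2idx[name]: name for name in dl.cat_names}
    cont_idxs = {name2idx[name]: name for name in dl.cont_names}
    return {'cats': cat_idxs, 'conts': cont_idxs}

# Stand-in for a DataLoader: iterating `dataset` yields column names
dl = SimpleNamespace(dataset=['workclass', 'age', 'salary', 'education'],
                     cat_names=['workclass', 'education'],
                     cont_names=['age'])
print(extract_name_indices(dl))
# prints: {'cats': {0: 'workclass', 3: 'education'}, 'conts': {1: 'age'}}
```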