worldbank/REaLTabFormer

Bug in process_datetime_data() when using pandas 2.2.1

efstathios-chatzikyriakidis opened this issue · 1 comment

Hi @avsolatorio,

I hope you are well.

Unfortunately, I am blocked: I need to upgrade pandas to the latest version (2.2.1), but I can't, because REaLTabFormer doesn't work with it. The function process_datetime_data() fails. Pandas has deprecated Series.view(), and the line that uses it produces both a warning and an error:

series.loc[series.notnull()] = (series[series.notnull()].view(int) / 1e9).astype(int)

The good news is that this is the only place in the code that uses Series.view(), so it should be easy to fix.
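Since Series.view(int) only reinterprets the stored nanosecond integers, calling astype("int64") on the non-null values should give the same numbers without the deprecated call. Below is a minimal standalone sketch of that conversion, not the library's actual patch. The helper name to_epoch_seconds is hypothetical, and it assumes nanosecond-resolution (datetime64[ns]) input, as in the traceback. Unlike the original line, it does not assign integers back into the datetime Series; it returns a nullable Int64 Series with pd.NA for the missing rows.

```python
import pandas as pd


def to_epoch_seconds(series: pd.Series) -> pd.Series:
    """Hypothetical helper: datetime64[ns] Series -> whole seconds since
    the Unix epoch, with pd.NA wherever the input is NaT."""
    mask = series.notnull()
    # astype("int64") stands in for the deprecated .view(int): for
    # datetime64[ns] values it returns nanoseconds since 1970-01-01.
    nanos = series[mask].astype("int64")
    # Same arithmetic as the original line: divide by 1e9, truncate to int.
    seconds = (nanos / 1e9).astype("int64")
    # Re-expand to the full index; reindex fills the NaT rows with pd.NA.
    return seconds.astype("Int64").reindex(series.index)


s = pd.Series(["1970-01-01 00:00:10", None, "2000-01-01"], dtype="datetime64[ns]")
print(to_epoch_seconds(s).tolist())  # [10, <NA>, 946684800]
```

Filtering out the nulls before the cast mirrors the original code and avoids relying on how a given pandas version casts NaT to an integer.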

Could you help me with this? I would also need a new PyPI release (0.1.6). Thanks in advance.

WARNING:

C:\Users\me\.conda\envs\test\lib\site-packages\realtabformer\data_utils.py:259: FutureWarning: Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.
   series.loc[series.notnull()] = (series[series.notnull()].view(int) / 1e9).astype(int)

ERROR:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[61], line 1
----> 1 trainer = parent_model.fit(df=training_source_data_df.reset_index(drop = True),
      2                            n_critic=0,
      3                            device=_get_device())

File ~\.conda\envs\test\lib\site-packages\realtabformer\realtabformer.py:455, in REaLTabFormer.fit(self, df, in_df, join_on, resume_from_checkpoint, device, num_bootstrap, frac, frac_max_data, qt_max, qt_max_default, qt_interval, qt_interval_unique, distance, quantile, n_critic, n_critic_stop, gen_rounds, sensitivity_max_col_nums, use_ks, full_sensitivity, sensitivity_orig_frac_multiple, orig_samples_rounds, load_from_best_mean_sensitivity, target_col)
    453 if self.model_type == ModelType.tabular:
    454     if n_critic <= 0:
--> 455         trainer = self._fit_tabular(df, device=device)
    456         trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    457     else:

File ~\.conda\envs\test\lib\site-packages\realtabformer\realtabformer.py:1046, in REaLTabFormer._fit_tabular(self, df, device, num_train_epochs, target_epochs)
   1038 def _fit_tabular(
   1039     self,
   1040     df: pd.DataFrame,
   (...)
   1043     target_epochs: int = None,
   1044 ) -> Trainer:
   1045     self._extract_column_info(df)
-> 1046     df, self.col_transform_data = process_data(
   1047         df,
   1048         numeric_max_len=self.numeric_max_len,
   1049         numeric_precision=self.numeric_precision,
   1050         numeric_nparts=self.numeric_nparts,
   1051         target_col=self.target_col,
   1052     )
   1053     self.processed_columns = df.columns.to_list()
   1054     self.vocab = self._generate_vocab(df)

File ~\.conda\envs\test\lib\site-packages\realtabformer\data_utils.py:486, in process_data(df, numeric_max_len, numeric_precision, numeric_nparts, first_col_type, col_transform_data, target_col)
    483 col_name = encode_processed_column(col_idx[c], ColDataType.DATETIME, c)
    485 _col_transform_data = col_transform_data.get(c)
--> 486 series, transform_data = process_datetime_data(
    487     df[c],
    488     transform_data=_col_transform_data,
    489 )
    490 if _col_transform_data is None:
    491     # This means that no transform data is available
    492     # before the processing.
    493     col_transform_data[c] = transform_data

File ~\.conda\envs\test\lib\site-packages\realtabformer\data_utils.py:259, in process_datetime_data(series, transform_data)
    253 # Convert the datetimes to
    254 # their equivalent timestamp values.
    255 
    256 # Make sure that we don't convert the NaT
    257 # to some integer.
    258 series = series.copy()
--> 259 series.loc[series.notnull()] = (series[series.notnull()].view(int) / 1e9).astype(
    260     int
    261 )
    262 series = series.fillna(pd.NA)
    264 # Take the mean value to re-align the data.
    265 # This will help reduce the scale of the numeric
    266 # data that will need to be generated. Let's just
    267 # add this offset back later before casting.

File ~\.conda\envs\test\lib\site-packages\pandas\core\series.py:965, in Series.view(self, dtype)
    962 # self.array instead of self._values so we piggyback on NumpyExtensionArray
    963 #  implementation
    964 res_values = self.array.view(dtype)
--> 965 res_ser = self._constructor(res_values, index=self.index, copy=False)
    966 if isinstance(res_ser._mgr, SingleBlockManager):
    967     blk = res_ser._mgr._block

File ~\.conda\envs\test\lib\site-packages\pandas\core\series.py:575, in Series.__init__(self, data, index, dtype, name, copy, fastpath)
    573     index = default_index(len(data))
    574 elif is_list_like(data):
--> 575     com.require_length_match(data, index)
    577 # create/copy the manager
    578 if isinstance(data, (SingleBlockManager, SingleArrayManager)):

File ~\.conda\envs\test\lib\site-packages\pandas\core\common.py:573, in require_length_match(data, index)
    569 """
    570 Check the length of data matches the length of the index.
    571 """
    572 if len(data) != len(index):
--> 573     raise ValueError(
    574         "Length of values "
    575         f"({len(data)}) "
    576         "does not match length of index "
    577         f"({len(index)})"
    578     )

ValueError: Length of values (22318) does not match length of index (11159)
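The two lengths differ by exactly a factor of two (22318 = 2 × 11159). That is consistent with int resolving to a 32-bit integer in this Windows / numpy 1.22.4 environment: reinterpreting 8-byte datetime64[ns] values as 4-byte integers yields two integers per timestamp, so the viewed array is twice as long as the index. The NumPy sketch below reproduces the effect with an explicit np.int32; it illustrates the likely cause and is not output from the reporter's environment.

```python
import numpy as np

# Three timestamps, each stored as an 8-byte datetime64[ns] value.
stamps = np.array(
    ["2020-01-01", "2021-06-15", "2022-12-31"], dtype="datetime64[ns]"
)

# Reinterpreting the same bytes as 4-byte integers splits every
# timestamp into two int32 values, doubling the element count.
as_int32 = stamps.view(np.int32)

# Reinterpreting as 8-byte integers keeps one value per timestamp.
as_int64 = stamps.view(np.int64)

print(stamps.itemsize, len(stamps))      # 8 3
print(as_int32.itemsize, len(as_int32))  # 4 6
print(as_int64.itemsize, len(as_int64))  # 8 3
```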

Here is the list of Python packages I am using:

numpy==1.22.4
pandas==2.2.1
multiprocess==0.70.14
dill==0.3.6
transformers==4.27.4
REaLTabFormer==0.1.5
psycopg2==2.9.6
SQLAlchemy==2.0.12
pydantic==1.10.7
jsonschema==4.17.3

Hi @avsolatorio!

Unfortunately, I still have a problem with the fix. I think we need .astype('int64') instead of .astype(int). It is safer to convert datetime64[ns] to int64 explicitly, because on some systems (for example Windows with NumPy 1.x) a bare int translates to int32.
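The platform dependence is easy to check: NumPy maps the builtin int to its default integer type, which is 32-bit on Windows under NumPy 1.x (the reporter has numpy==1.22.4) and 64-bit on most Linux and macOS builds. NumPy 2.0 made the default 64-bit on all 64-bit platforms, including Windows. Spelling out "int64" removes that dependence. A small sketch, assuming datetime64[ns] input:

```python
import numpy as np
import pandas as pd

# np.dtype(int) is platform dependent: int32 on Windows with NumPy 1.x,
# int64 on most other 64-bit platforms (and everywhere with NumPy 2.x).
print("builtin int maps to:", np.dtype(int))

# An explicit "int64" is 8 bytes on every platform.
print("explicit int64 size:", np.dtype("int64").itemsize)  # 8

stamps = pd.Series(["1970-01-01 00:00:01", "2000-01-01"], dtype="datetime64[ns]")
# Nanoseconds since the epoch, then seconds, as in the proposed fix.
seconds = stamps.astype("int64") / 1e9
print(seconds.tolist())  # [1.0, 946684800.0]
```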

The line that needs to change, from:

series = (series.astype(int) / 1e9)

To:

series = (series.astype('int64') / 1e9)

With this change it should run everywhere with the latest pandas.