dehoyosb/temporal_fusion_transformer_pytorch

TFTDataset

yunxileo opened this issue · 0 comments

In the `TFTDataset` class definition below:
class TFTDataset(Dataset, ElectricityFormatter):
    """Dataset Basic Structure for Temporal Fusion Transformer"""

    def __init__(self, data_df):
        """Prepare column metadata and the sample index for ``data_df``.

        Args:
            data_df: Tabular input data (a pandas ``DataFrame``, judging by the
                ``reset_index`` / ``groupby`` calls below) containing the
                columns named in ``self._column_definition``. Its index is
                discarded and replaced by a fresh 0..n-1 range.
        """
        # NOTE(review): super(ElectricityFormatter, self) starts the MRO lookup
        # *after* ElectricityFormatter, so any __init__ defined on Dataset or on
        # ElectricityFormatter itself is not run here -- confirm intentional.
        super(ElectricityFormatter, self).__init__()

        # Attribute loading the data (new frame with a clean RangeIndex).
        self.data = data_df.reset_index(drop=True)

        self.id_col = get_single_col_by_input_type(InputTypes.ID, self._column_definition)  # id
        self.time_col = get_single_col_by_input_type(InputTypes.TIME, self._column_definition)
        self.target_col = get_single_col_by_input_type(InputTypes.TARGET, self._column_definition)  # power_usage

        # Every column except the identifier and time columns is treated as a
        # model input -- this includes the target column.
        # Assumes _column_definition entries look like (name, _, input_type);
        # TODO confirm against the formatter definition.
        self.input_cols = [
            tup[0]
            for tup in self._column_definition
            if tup[2] not in {InputTypes.ID, InputTypes.TIME}
        ]
        self.col_mappings = {
            'identifier': [self.id_col],
            'time': [self.time_col],
            'outputs': [self.target_col],
            'inputs': self.input_cols,
        }
        self.lookback = self.get_time_steps()
        self.num_encoder_steps = self.get_num_encoder_steps()

        self.data_index = self.get_index_filtering()
        # Mean number of rows per identifier (average series length).
        self.group_size = self.data.groupby([self.id_col]).apply(lambda x: x.shape[0]).mean()
        # Keep only samples whose ``end_rel`` is strictly below the *mean*
        # series length. NOTE(review): if ``end_rel`` is an offset within each
        # series (it comes from get_index_filtering, not shown here), series
        # longer than average lose their later samples -- confirm intended.
        # ``reset_index()`` without ``drop=True`` keeps the pre-filter row
        # labels as an ``index`` column.
        self.data_index = self.data_index[self.data_index.end_rel < self.group_size].reset_index()

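Why is this restrictive condition needed: `self.data_index = self.data_index[self.data_index.end_rel < self.group_size].reset_index()`? Since `group_size` is the *mean* number of rows per identifier, this filter appears to drop every sample whose `end_rel` is at or beyond the average series length.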