TFTDataset
yunxileo opened this issue · 0 comments
yunxileo commented
In the `TFTDataset` class:
class TFTDataset(Dataset, ElectricityFormatter):
    """Dataset Basic Structure for Temporal Fusion Transformer."""

    def __init__(self, data_df):
        """Build the column mappings and sample index for the dataset.

        Args:
            data_df (pandas.DataFrame): Input data. It is expected to contain the
                columns listed in ``self._column_definition`` (identifier, time,
                target and input columns).
        """
        # NOTE(review): super(ElectricityFormatter, self) begins the MRO lookup
        # *after* ElectricityFormatter. So neither Dataset.__init__ nor
        # ElectricityFormatter.__init__ (if one is defined) runs here; the call
        # goes to whatever follows ElectricityFormatter in the MRO. Kept as-is
        # to preserve behavior — confirm this is intended.
        super(ElectricityFormatter, self).__init__()

        # Attribute loading the data; rows are renumbered 0..n-1 so positional
        # and label indexing agree.
        self.data = data_df.reset_index(drop=True)
        self.id_col = get_single_col_by_input_type(InputTypes.ID, self._column_definition)  # id
        self.time_col = get_single_col_by_input_type(InputTypes.TIME, self._column_definition)
        self.target_col = get_single_col_by_input_type(InputTypes.TARGET, self._column_definition)  # power_usage

        # Every column whose input type is neither ID nor TIME is treated as a
        # model input. Entries of _column_definition are presumably
        # (name, data_type, input_type) tuples — TODO confirm.
        self.input_cols = [
            tup[0]
            for tup in self._column_definition
            if tup[2] not in {InputTypes.ID, InputTypes.TIME}
        ]
        self.col_mappings = {
            'identifier': [self.id_col],
            'time': [self.time_col],
            'outputs': [self.target_col],
            'inputs': self.input_cols,
        }

        self.lookback = self.get_time_steps()
        self.num_encoder_steps = self.get_num_encoder_steps()
        self.data_index = self.get_index_filtering()

        # Mean number of rows per identifier. GroupBy.size() counts rows per
        # group directly, with the same result as apply(lambda x: x.shape[0])
        # but without a Python callback per group.
        self.group_size = self.data.groupby([self.id_col]).size().mean()

        # Keep only index entries whose `end_rel` is strictly below the *mean*
        # per-identifier row count. reset_index() (without drop=True) also adds
        # an 'index' column holding the pre-filter labels.
        # NOTE(review): if `end_rel` is a window's end offset within its own
        # series, comparing it with the average length across all identifiers
        # also drops the late windows of longer-than-average series.
        # Confirm this is intended rather than a per-identifier bound.
        self.data_index = self.data_index[self.data_index.end_rel < self.group_size].reset_index()
Why is this restrictive condition needed: `self.data_index = self.data_index[self.data_index.end_rel < self.group_size].reset_index()`?