Can TFT be used with the SHAP method from the XAI module?
IgnoranceSmile opened this issue · 2 comments
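Does the TFT model in this project only support interval (probabilistic) prediction, with no way to switch it to point prediction? When running SHAP analysis on a TFT model with: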
se = ShapExplainer(model, train_data)
the following error is raised:
[2024-03-01 01:26:49,864] [paddlets] [ERROR] ValueError: Only support point prediction but not probability prediction!
ValueError Traceback (most recent call last)
/tmp/ipykernel_53563/2232659875.py in
1 # se = ShapExplainer(model, train_data, background_sample_number=100, keep_index=True, use_paddleloader=False)
----> 2 se = ShapExplainer(model, train_data)
/usr/local/lib/python3.7/dist-packages/paddlets/xai/post_hoc/shap_explainer.py in __init__(self, model, background_data, background_sample_number, shap_method, task_type, seed, use_paddleloader, **kwargs)
77 # Judge whether it is probability prediction
78 if hasattr(_model_obj, "_output_mode"):
---> 79 raise_if(_model_obj._output_mode == 'quantiles', 'Only support point prediction but not probability prediction!')
80
81 # Base parameter
/usr/local/lib/python3.7/dist-packages/paddlets/logger/logger.py in raise_if(condition, message, logger)
154
155 """
--> 156 raise_if_not(not condition, message, logger)
/usr/local/lib/python3.7/dist-packages/paddlets/logger/logger.py in raise_if_not(condition, message, logger)
133 if not condition:
134 logger.error("ValueError: " + message)
--> 135 raise ValueError(message)
136
137
ValueError: Only support point prediction but not probability prediction!
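However, I could not find a way to switch TFT to point prediction.
Hello, could you share the exact commands you ran? That will help us reproduce and fix the issue.
Here is the code: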
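The traceback shows why the call fails: inside `ShapExplainer.__init__`, the explainer rejects any model object (called `_model_obj` in the traceback) that has an `_output_mode` attribute equal to `'quantiles'`. The paddlets-free sketch below mirrors only those two visible lines; `raise_if` is a simplified stand-in for the paddlets logger helper, and the two stub model classes are hypothetical.

```python
# Sketch of the guard visible in shap_explainer.py (lines 78-79 of the traceback).
# raise_if is a simplified stand-in; the stub model classes are hypothetical.

MESSAGE = "Only support point prediction but not probability prediction!"


def raise_if(condition: bool, message: str) -> None:
    """Simplified stand-in for paddlets.logger.raise_if: raise when condition holds."""
    if condition:
        raise ValueError(message)


def check_point_prediction(model_obj) -> None:
    """Reject objects that expose _output_mode == 'quantiles'."""
    if hasattr(model_obj, "_output_mode"):
        raise_if(model_obj._output_mode == "quantiles", MESSAGE)


class QuantileModelStub:
    """Hypothetical stand-in for a network that emits quantile forecasts."""
    _output_mode = "quantiles"


class PointModelStub:
    """Hypothetical stand-in for a network without an _output_mode attribute."""


if __name__ == "__main__":
    check_point_prediction(PointModelStub())  # passes silently
    try:
        check_point_prediction(QuantileModelStub())
    except ValueError as err:
        print(err)  # same message as in the traceback
```

The check looks only at the attribute, so the rejection is decided by how the fitted model reports its output mode, before any SHAP computation starts.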
import pandas as pd
import numpy as np
from paddlets import TSDataset
x = np.linspace(-np.pi, np.pi, 200)
sinx = np.sin(x) * 4 + np.random.randn(200)
df = pd.DataFrame(
    {
        'time_col': pd.date_range('2022-01-01', periods=200, freq='1h'),
        'value': sinx,
        'known_cov_1': sinx + 4,
        'known_cov_2': sinx + 5,
        'observed_cov': sinx + 8,
        'static_cov': [1. for i in range(200)],
    }
)
target_cov_dataset = TSDataset.load_from_dataframe(
    df,
    time_col='time_col',
    target_cols='value',
    known_cov_cols=['known_cov_1', 'known_cov_2'],
    observed_cov_cols='observed_cov',
    static_cov_cols='static_cov',
    freq='1h'
)
target_cov_dataset.plot(['value', 'known_cov_1', 'known_cov_2', 'observed_cov'])
from paddlets.xai.post_hoc.shap_explainer import ShapExplainer
from paddlets.models.forecasting import TFTModel, RNNBlockRegressor, DeepARModel
import paddle.nn.functional as F
in_chunk_len = 24
out_chunk_len = 24
skip_chunk_len = 0
sampling_stride = 24
max_epochs = 10
patience = 5
model = TFTModel(
    in_chunk_len=in_chunk_len,
    out_chunk_len=out_chunk_len,
    max_epochs=max_epochs,
    patience=patience,
    # loss_fn=F.mse_loss
)
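# train_data / val_data are not defined in the snippet as posted; a 70/30 split is assumed here
train_data, val_data = target_cov_dataset.split(0.7)
model.fit(train_data, val_data)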
se = ShapExplainer(model, train_data)
Error message:
[2024-03-04 08:37:07,564] [paddlets] [ERROR] ValueError: Only support point prediction but not probability prediction!
ValueError Traceback (most recent call last)
/tmp/ipykernel_94700/2232659875.py in
1 # se = ShapExplainer(model, train_data, background_sample_number=100, keep_index=True, use_paddleloader=False)
----> 2 se = ShapExplainer(model, train_data)
/usr/local/lib/python3.7/dist-packages/paddlets/xai/post_hoc/shap_explainer.py in __init__(self, model, background_data, background_sample_number, shap_method, task_type, seed, use_paddleloader, **kwargs)
77 # Judge whether it is probability prediction
78 if hasattr(_model_obj, "_output_mode"):
---> 79 raise_if(_model_obj._output_mode == 'quantiles', 'Only support point prediction but not probability prediction!')
80
81 # Base parameter
/usr/local/lib/python3.7/dist-packages/paddlets/logger/logger.py in raise_if(condition, message, logger)
154
155 """
--> 156 raise_if_not(not condition, message, logger)
/usr/local/lib/python3.7/dist-packages/paddlets/logger/logger.py in raise_if_not(condition, message, logger)
133 if not condition:
134 logger.error("ValueError: " + message)
--> 135 raise ValueError(message)
136
137
ValueError: Only support point prediction but not probability prediction!
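SHAP attributes one scalar output per sample, which is why a model that emits several quantiles per step is rejected. One generic workaround, which this thread does not confirm paddlets supports, is to collapse the quantile forecast to a point forecast (for example the median) and explain that scalar. The sketch below assumes a hypothetical forecaster returning one value per quantile level in `[0.1, 0.5, 0.9]`, wraps it to return the median, and computes exact Shapley values relative to a single baseline vector by enumerating every feature coalition. This brute-force computation is a stand-in for what SHAP estimators approximate over background data, and it is only feasible for a handful of features.

```python
import itertools
import math

import numpy as np

QUANTILES = [0.1, 0.5, 0.9]  # assumed quantile levels of the hypothetical forecaster


def quantile_forecaster(x: np.ndarray) -> np.ndarray:
    """Hypothetical model: one forecast value per quantile level for features x."""
    center = 2.0 * x[0] - 1.0 * x[1] + 0.5 * x[2]
    spread = 1.0 + abs(x[2])
    # Lower quantile, median, upper quantile (monotone in the quantile level).
    return np.array([center - spread, center, center + spread])


def point_forecaster(x: np.ndarray) -> float:
    """Collapse the quantile forecast to a point forecast by taking the median."""
    return float(quantile_forecaster(x)[QUANTILES.index(0.5)])


def exact_shapley(f, x: np.ndarray, baseline: np.ndarray) -> np.ndarray:
    """Exact Shapley values of f at x; features outside a coalition take baseline values."""
    n = len(x)
    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions of this size.
            weight = math.factorial(size) * math.factorial(n - size - 1) / math.factorial(n)
            for subset in itertools.combinations(others, size):
                z = baseline.astype(float).copy()
                z[list(subset)] = x[list(subset)]
                without_i = f(z)
                z[i] = x[i]
                with_i = f(z)
                phi[i] += weight * (with_i - without_i)
    return phi


if __name__ == "__main__":
    x = np.array([1.0, 2.0, -1.0])
    baseline = np.zeros(3)
    print(exact_shapley(point_forecaster, x, baseline))
```

Because the median in this toy model is linear in `x`, each Shapley value reduces to its coefficient times `x_i - baseline_i`, i.e. 2, -2 and -0.5, and the values sum to `f(x) - f(baseline)`. Explaining a different quantile would only require changing which column `point_forecaster` selects.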