NaN may occur in the CEC2022 SO benchmark
Description
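In the CEC2022 single-objective benchmark, F10 to F12 all use the compose_operat function in the OperatFunc class, where the code below can produce NaN in the evaluated fitness. When x coincides with one of the component optima (a row of Os_mat), diff_square_sum is 0 at that index, so term1 = 1 / jnp.sqrt(0) is inf, not NaN. W then contains inf, and W / jnp.sum(W) evaluates inf / inf = NaN, which propagates into the final fitness.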
diff_square_sum = jnp.sum((x - Os_mat) ** 2, axis=1)
term1 = 1 / jnp.sqrt(diff_square_sum)
term2 = jnp.exp(-0.5 * diff_square_sum / (sigma**2 * D))
W = term1 * term2
W_norm = W / jnp.sum(W)
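A minimal reproduction sketch, assuming hypothetical toy values for Os_mat and sigma and taking D from the length of x (the benchmark's real shapes and parameters may differ). compose_weights is a hypothetical name that repeats the arithmetic of the snippet above. When x equals a row of Os_mat, the weight at that index becomes NaN and the others become 0.

```python
import jax.numpy as jnp


def compose_weights(x, Os_mat, sigma):
    # Same arithmetic as the compose_operat snippet; D taken from x (assumption)
    D = x.shape[0]
    diff_square_sum = jnp.sum((x - Os_mat) ** 2, axis=1)  # shape (n,)
    term1 = 1 / jnp.sqrt(diff_square_sum)  # inf where diff_square_sum == 0
    term2 = jnp.exp(-0.5 * diff_square_sum / (sigma**2 * D))
    W = term1 * term2
    return W / jnp.sum(W)  # inf / inf = NaN if any term1 is inf


# Hypothetical toy data: three component optima in 2-D and their sigma values
Os_mat = jnp.array([[0.0, 0.0], [1.0, 1.0], [-2.0, 3.0]])
sigma = jnp.array([10.0, 20.0, 30.0])

w_ok = compose_weights(jnp.array([0.5, 0.5]), Os_mat, sigma)  # finite, sums to 1
w_bad = compose_weights(Os_mat[1], Os_mat, sigma)  # NaN at index 1, 0 elsewhere
print(w_ok, w_bad)
```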
Possible Solution
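import jax.numpy as jnp

diff_square_sum = jnp.sum((x - Os_mat) ** 2, axis=1)
term1 = 1 / jnp.sqrt(diff_square_sum)  # inf (not NaN) where x coincides with a row of Os_mat
term2 = jnp.exp(-0.5 * diff_square_sum / (sigma**2 * D))
W = term1 * term2
term1_inf = jnp.isinf(term1)
# If x hits one or more optima, split the weight evenly among them; otherwise normalize W as usual
W_norm = jnp.where(jnp.any(term1_inf), term1_inf / jnp.count_nonzero(term1_inf), W / jnp.sum(W))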
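A sketch of the proposed fix packaged as a standalone function, under stated assumptions: compose_weights_safe is a hypothetical name, D is taken from the length of x, and the toy Os_mat and sigma values are made up. It detects the coinciding optimum with jnp.isinf, since 1 / jnp.sqrt(0) yields inf rather than NaN. It selects branches with jnp.where, which accepts a scalar condition, whereas jnp.select expects sequences of conditions and choices. The uniform fallback for the case where every W underflows to 0 (possible in float32 for points far from all optima) is an extra, hypothetical guard and not part of the proposal in this issue.

```python
import jax.numpy as jnp


def compose_weights_safe(x, Os_mat, sigma):
    # x: shape (D,), Os_mat: shape (n, D), sigma: shape (n,); D taken from x (assumption)
    D = x.shape[0]
    n = Os_mat.shape[0]
    diff_square_sum = jnp.sum((x - Os_mat) ** 2, axis=1)
    term1 = 1 / jnp.sqrt(diff_square_sum)  # inf where x equals a row of Os_mat
    term2 = jnp.exp(-0.5 * diff_square_sum / (sigma**2 * D))
    W = term1 * term2
    W_sum = jnp.sum(W)
    hit = jnp.isinf(term1)
    # Case 1: x coincides with one or more optima, so split the weight evenly among them
    hit_weights = hit / jnp.maximum(jnp.count_nonzero(hit), 1)
    # Case 2 (hypothetical extra guard): every W underflowed to 0, so use uniform weights
    uniform = jnp.full_like(W, 1.0 / n)
    # Case 3: ordinary normalization; a zero denominator is swapped for 1 to avoid 0 / 0
    normal = W / jnp.where(W_sum > 0, W_sum, 1.0)
    return jnp.where(jnp.any(hit), hit_weights, jnp.where(W_sum > 0, normal, uniform))


# Hypothetical toy data: three component optima in 2-D and their sigma values
Os_mat = jnp.array([[0.0, 0.0], [1.0, 1.0], [-2.0, 3.0]])
sigma = jnp.array([10.0, 20.0, 30.0])

w_hit = compose_weights_safe(Os_mat[1], Os_mat, sigma)  # all weight on index 1
w_mid = compose_weights_safe(jnp.array([0.5, 0.5]), Os_mat, sigma)  # ordinary case
w_far = compose_weights_safe(jnp.array([1e4, 1e4]), Os_mat, sigma)  # all W underflow
print(w_hit, w_mid, w_far)
```

jnp.where still evaluates every branch, so the ordinary branch computes inf / inf when x hits an optimum; the selection only discards that result, which is enough for fitness evaluation.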
Location
problems/numerical/cec2022_so.py, line 121
Is this the intended behavior of the CEC2022 single-objective benchmark, or a bug?
Note
The CEC2022 benchmark code as a whole is hard to read because of missing documentation and unclear abbreviations. Please address these problems if possible.
Yes, as you mentioned, NaN values can be produced unexpectedly, which we want to avoid. Your proposed solution addresses this issue well.
Parts of the CEC2022 code do have poor readability. To keep the numerical results of this code consistent with those of the original CEC2022 code, I had to port some complicated parts over directly. I apologize for that.
Please take a look at #131. If there is no objection, I will merge the PR and close this issue.