AICPS/roadscene2vec

About new pretrained models

xinxinlv opened this issue · 3 comments

Hi, I have downloaded your newly updated 'use_case_data' folder, but I couldn't find 'sequence_classification_example_model.pt' in it.
I would like to know whether the previous version of the 'use_case_data' folder can still be used.

Hi, you can use 5_fold_271_carla_sequence_classification_example_model.pt in that folder. It is the same model, just named differently. The old version should still work as well, but the new one has updated model weights that may perform better.
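If a script needs to work with either release of the folder, one option is to look for the new filename first and fall back to the old one. This is a hypothetical sketch: `resolve_model_path` is not part of roadscene2vec, and it assumes both checkpoints sit directly inside the folder you pass in.

```python
from pathlib import Path

# Hypothetical helper (not part of roadscene2vec): prefer the renamed checkpoint
# from the updated use_case_data folder, fall back to the old filename.
CANDIDATES = [
    "5_fold_271_carla_sequence_classification_example_model.pt",  # new name
    "sequence_classification_example_model.pt",                   # old name
]


def resolve_model_path(folder):
    """Return the first known checkpoint file found in `folder`."""
    for name in CANDIDATES:
        path = Path(folder) / name
        if path.is_file():
            return path
    raise FileNotFoundError(f"no known checkpoint found in {folder}")
```

The returned path can then be used wherever your setup expects the checkpoint location.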

Hi, thank you for your reply.
I modified “use_case_2.py” to directly load the new pretrained model you provided, “5_fold_271_carla_sequence_classification_example_model.pt”, and used it to predict “use_case_data/lanechange” without any training:
```python
# configuration, extract_seq, Scenegraph_Trainer and format_use_case_model_input
# come from the rest of use_case_2.py.

def risk_assess():
    # create scenegraph extraction config object
    scenegraph_extraction_config = configuration(r"use_case_2_scenegraph_extraction_config.yaml", from_function=True)
    # extract scenegraphs for each frame of the given sequence into a ScenegraphDataset
    extracted_scenegraphs = extract_seq(scenegraph_extraction_config)
    # create training config object
    training_config = configuration(r"use_case_2_learning_config.yaml", from_function=True)
    # create trainer object using config
    trainer = Scenegraph_Trainer(training_config)
    # trainer.split_dataset()  # split ScenegraphDataset specified in learning config into training, testing data
    # trainer.build_model()    # build model specified in learning config
    # trainer.learn()

    trainer.load_model()

    # turn the original sequence's extracted ScenegraphDataset into model input
    model_input = format_use_case_model_input(extracted_scenegraphs, trainer)
    # output risk assessment for the original sequence
    output, _ = trainer.model.forward(*model_input)
    return output
```

The correct label for “use_case_data/lanechange/22_lanchange” is risky, but the direct prediction is “safe”, which is incorrect. I don't know what went wrong. Here is the output:

```
Model loaded from file.
/home/liuxx/anaconda3/envs/av/lib/python3.9/site-packages/torch_geometric/deprecation.py:12: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead
  warnings.warn(out)
tensor([-0.3435, -1.2354], device='cuda:0', grad_fn=<LogSoftmaxBackward0>)
```
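The printed tensor ends in a LogSoftmax (see `grad_fn=<LogSoftmaxBackward0>`), so each entry is a log-probability. Exponentiating recovers the class probabilities, and the predicted class is the index with the largest value. The sketch below uses only the standard library; the index-to-label mapping (0 = safe, 1 = risky) is an assumption based on how the output is read above, so check it against your dataset's labels.

```python
import math

# Log-softmax scores copied from the output above, one value per class.
log_probs = [-0.3435, -1.2354]

# Assumption: index 0 = "safe", index 1 = "risky".
LABELS = ["safe", "risky"]


def interpret(log_probs, labels):
    """Convert log-softmax scores to probabilities and pick the top class."""
    # exp() undoes the log, giving probabilities that sum to about 1.
    probs = [math.exp(lp) for lp in log_probs]
    # The predicted class is the index with the highest probability.
    pred = max(range(len(probs)), key=lambda i: probs[i])
    return labels[pred], probs


label, probs = interpret(log_probs, LABELS)
print(label, [round(p, 3) for p in probs])
```

For this sequence the model puts roughly 0.71 on index 0 and 0.29 on index 1, so it is fairly, but not overwhelmingly, confident in the “safe” prediction.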

It is a machine learning model, so it will not always predict the correct result :). In our experiments it predicts correctly about 90% of the time.
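A single sequence says little about a model that is right about 90% of the time; accuracy only makes sense over many labeled sequences. Below is a minimal sketch of that calculation plus a per-class error count. The functions are generic helpers, not roadscene2vec APIs, and the labels in the usage lines are made-up placeholders, not results from the project.

```python
from collections import Counter


def accuracy(predictions, labels):
    """Fraction of predictions that match the ground-truth labels."""
    if len(predictions) != len(labels) or not labels:
        raise ValueError("need two non-empty lists of equal length")
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


def errors_by_true_label(predictions, labels):
    """Count misclassified samples, grouped by their true label."""
    return Counter(y for p, y in zip(predictions, labels) if p != y)


# Made-up example: 10 sequences, one risky sequence predicted as safe.
labels = ["risky"] * 5 + ["safe"] * 5
predictions = ["safe"] + ["risky"] * 4 + ["safe"] * 5
print(accuracy(predictions, labels))  # 0.9
print(errors_by_true_label(predictions, labels))
```

Even at 90% accuracy, about one sequence in ten is expected to be misclassified, as in this toy example, so a single wrong prediction like the one on 22_lanchange is not by itself a sign that something is broken.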