Runtime error in prompter.py
imJunaidAfzal opened this issue · 3 comments
imJunaidAfzal commented
I'm trying to use it, but I get the following error:
File "python3.9/site-packages/mario_gpt/prompter.py", line 113, in output_hidden
self.feature_extraction(prompt, return_tensors="pt")[0]
AttributeError: 'list' object has no attribute 'mean'
code:
from mario_gpt import MarioLM, SampleOutput
# pretrained_model = shyamsn97/Mario-GPT2-700-context-length
mario_lm = MarioLM()
# use cuda to speed stuff up
# import torch
# device = torch.device('cuda')
# mario_lm = mario_lm.to(device)
prompts = ["many pipes, many enemies, some blocks, high elevation"]
# generate level of size 1400, pump temperature up to ~2.4 for more stochastic but playable levels
generated_level = mario_lm.sample(
    prompts=prompts,
    num_steps=1400,
    temperature=2.0,
    use_tqdm=True
)
# show string list
generated_level.level
# show PIL image
generated_level.img
# save image
generated_level.img.save("generated_level.png")
# save text level to file
generated_level.save("generated_level.txt")
# play in interactive
generated_level.play()
# run Astar agent
generated_level.run_astar()
# Continue generation
generated_level_continued = mario_lm.sample(
    seed=generated_level,
    prompts=prompts,
    num_steps=1400,
    temperature=2.0,
    use_tqdm=True
)
# load from text file
loaded_level = SampleOutput.load("generated_level.txt")
# play from loaded (should be the same level that we generated)
loaded_level.play()
Can you check what's wrong, or whether I'm doing something wrong?
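For context, the AttributeError means the object whose `.mean` was called was a plain Python list rather than a tensor, and lists have no `.mean` method. Here is a minimal pure-Python sketch of that failure. It assumes the indexed result is a list of per-token feature vectors and that the call is `.mean(0)`, since the traceback excerpt stops before the `.mean` call. `mean_over_tokens` is a hypothetical helper for illustration only and is not part of mario_gpt. The fix suggested in this thread is upgrading the package.

```python
from typing import List, Sequence


def mean_over_tokens(features: Sequence[Sequence[float]]) -> List[float]:
    """Average a (num_tokens x hidden_size) nested list along the token axis.

    Hypothetical helper: computes what `.mean(0)` would on a 2-D tensor.
    """
    if not features:
        raise ValueError("features must contain at least one token vector")
    hidden_size = len(features[0])
    if any(len(vec) != hidden_size for vec in features):
        raise ValueError("all token vectors must have the same length")
    num_tokens = len(features)
    # Column-wise average: one value per hidden feature.
    return [sum(vec[i] for vec in features) / num_tokens for i in range(hidden_size)]


# A plain nested list, standing in for the value the traceback indexes into.
features = [[1.0, 2.0], [3.0, 4.0]]

try:
    features.mean(0)  # lists have no .mean method, so this raises AttributeError
except AttributeError as err:
    print(err)  # 'list' object has no attribute 'mean'

print(mean_over_tokens(features))  # [2.0, 3.0]
```

Averaging along axis 0 collapses the token dimension and leaves one value per hidden feature.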
shyamsn97 commented
Hey! Do you mind sharing a full stack trace?
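Python already prints the full traceback when a script crashes, so it's usually enough to paste all of it rather than only the last few lines. If the error surfaces somewhere that hides part of it, such as a notebook cell or a wrapper that catches exceptions, one option is to capture it with the standard library's `traceback` module. A minimal sketch follows. `capture_traceback` is a hypothetical helper and not something mario_gpt provides.

```python
import traceback
from typing import Callable, Optional


def capture_traceback(fn: Callable[[], object]) -> Optional[str]:
    """Call fn() and return the full formatted traceback if it raises, else None."""
    try:
        fn()
    except Exception:
        # format_exc() formats the traceback of the exception currently being handled.
        return traceback.format_exc()
    return None


if __name__ == "__main__":
    # Stand-in for the failing call, e.g. lambda: mario_lm.sample(prompts=prompts, num_steps=1400)
    report = capture_traceback(lambda: [].mean())
    print(report or "no exception raised")
```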
shyamsn97 commented
But you might want to try upgrading your installation, either by installing from the repo source or by running pip install --upgrade mario-gpt.
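To confirm the upgrade landed in the environment that actually runs the script, one option is to print the installed versions with the standard library's `importlib.metadata` (Python 3.8+). A minimal sketch follows. `installed_versions` is a hypothetical helper. Checking `transformers` alongside `mario-gpt` and `torch` is an assumption, since this thread doesn't say which dependency changed.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(dists: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it isn't installed."""
    result: Dict[str, Optional[str]] = {}
    for name in dists:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    # Run before and after `pip install --upgrade mario-gpt` to check the upgrade took effect.
    # Including "transformers" here is an assumption about which dependency matters.
    for name, ver in installed_versions(["mario-gpt", "transformers", "torch"]).items():
        print(f"{name}: {ver or 'not installed'}")
```

Run it with the same interpreter that raises the error, because pip can install into a different environment than the one your script uses.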
shyamsn97 commented
Closing this as a version issue. Feel free to reopen if you still see the problem!