Custom learning tasks tutorial gives error
usiam opened this issue · 0 comments
usiam commented
using Pkg;
Pkg.activate(".")
using FastAI, FastVision, Random, Images
import CairoMakie;
CairoMakie.activate!(type="png");
path = FastAI.load(datasets()["oxford-iiit-pet"])
im_path = joinpath(path, "images")
files = loadfolderdata(im_path; filterfn=FastVision.isimagefile)
function transform_image(image, sz=224)
image_resized = imresize(convert.(RGB{N0f8}, image), (sz, sz))
a = permuteddimsview(channelview(image_resized), (2, 3, 1))
end
p = getobs(files, 1)
image = loadfile(p)
label_func(path) = match(r"^(.*)_\d+\.jpg$", pathname(path))[1]
label_func(p)
labels = map(label_func, files)
length(unique(labels))
data = mapobs(files) do file
return (loadfile(file), label_func(file))
end
idxs = shuffle(1:length(files))
cut = round(Int, 0.8 * length(idxs))
trainidxs, valididxs = idxs[1:cut], idxs[cut+1:end]
trainfiles, validfiles = files[trainidxs], files[valididxs]
summary.((trainfiles, validfiles))
import FastAI.MLUtils
struct SiamesePairs
labels
same
other
valid
end
function SiamesePairs(labels; valid=false)
ulabels = unique(labels)
same = Dict(
label => [i for (i, l) in enumerate(labels) if l == label]
for label in ulabels)
other = Dict(
label => [i for (i, l) in enumerate(labels) if l != label]
for label in ulabels)
return SiamesePairs(labels, same, other, valid)
end
function MLUtils.getobs(si::SiamesePairs, idx::Int)
rng = si.valid ? MersenneTwister(idx) : Random.GLOBAL_RNG
if rand(rng) > 0.5
return ((idx, rand(rng, si.same[si.labels[idx]])), true)
else
return ((idx, rand(rng, si.other[si.labels[idx]])), false)
end
end
MLUtils.numobs(si::SiamesePairs) = length(si.labels)
function siamesedata(files; valid = false, transformfn = identity)
labels = map(label_func, files)
si = SiamesePairs(labels; valid = valid)
return mapobs(si) do obs
(i, j), same = obs
image1 = transformfn(loadfile(getobs(files, i)))
image2 = transformfn(loadfile(getobs(files, j)))
return ((image1, image2), same)
end
end
traindata = siamesedata(trainfiles; transformfn=transform_image)
validdata = siamesedata(validfiles; transformfn=transform_image, valid=true);
traindl = FastAI.MLUtils.DataLoader(traindata, 16)
ERROR: MethodError: no method matching MLUtils.DataLoader(::MLUtils.MappedData{:auto, var"#75#76"{typeof(transform_image), ObsView{MLDatasets.FileDataset{typeof(identity), String}, Vector{Int64}}}, SiamesePairs}, ::Int64)
I was trying to recreate the Siamese example in the docs and could not figure out why I am getting this error? And how do I fix this?