Adding support for word embeddings
Closed this issue · 1 comment
pevnak commented
I needed support for language models, which was relatively simple to add as follows:
using Embeddings
using WordTokenizers
# Word-embedding table, loaded once when this code is evaluated. Other Embeddings.jl
# tables can be substituted here, e.g. load_embeddings(GloVe) or
# load_embeddings(Word2Vec).
const embtable = load_embeddings(FastText_Text)
# Maps each vocabulary word to its column index in `embtable.embeddings`
# (the column order follows `embtable.vocab`).
const get_word_index = Dict(word=>ii for (ii,word) in enumerate(embtable.vocab))
"""
    get_embedding(word)

Return the embedding of `word`: the column of the global `embtable.embeddings`
that `get_word_index` assigns to it. A word outside the vocabulary raises a
`KeyError` from the dictionary lookup.
"""
get_embedding(word) = embtable.embeddings[:, get_word_index[word]]
"""
    ExtractEmbedded{E} <: AbstractExtractor

Extractor that turns a piece of text into a bag of word embeddings, with one
instance per token found in the vocabulary.

# Fields
- `n::Int`: embedding dimension, i.e. the number of rows of the instance matrix.
- `uniontypes::Bool`: flag kept on the extractor; it participates in `hash` and `==`.
- `embtable::E`: the embedding table the extractor was constructed with.

NOTE(review): the text-extraction method looks words up through the global
`get_word_index` / `embtable`, not through this `embtable` field. Confirm the two
are meant to be the same table.
"""
struct ExtractEmbedded{E} <: AbstractExtractor
    n::Int
    uniontypes::Bool
    embtable::E
end
"""
    ExtractEmbedded(table = embtable, uniontypes = true)

Build an `ExtractEmbedded` whose dimension `n` is the number of rows of
`table.embeddings`.

The default `table` is the module-level `const embtable`. The embedding file is
therefore not loaded a second time, and the stored table is the same one that
`get_word_index` indexes during extraction.
"""
function ExtractEmbedded(table = embtable, uniontypes = true)
    return ExtractEmbedded(size(table.embeddings, 1), uniontypes, table)
end
# A missing or `nothing` value is extracted as an empty bag, the same `BagNode`
# layout the `ExtractEmpty` method produces. Every method of this extractor then
# returns a `BagNode`, so results from different documents can be combined.
# Previously an `ArrayNode` of missings was returned here, which does not match
# the bag structure produced for text.
function (e::ExtractEmbedded)(::JsonGrinder.MissingOrNothing; store_input=false)
    return e(JsonGrinder.ExtractEmpty(); store_input)
end
# Empty input: one bag that contains no instances, backed by an `n × 0` Float32
# matrix wrapped in an `ArrayNode`. No metadata is attached.
function (e::ExtractEmbedded)(::JsonGrinder.ExtractEmpty; store_input=false)
    no_instances = ArrayNode(zeros(Float32, e.n, 0))
    single_empty_bag = AlignedBags([0:-1])
    return BagNode(no_instances, single_empty_bag, nothing)
end
"""
Extract `s` into a single bag holding one embedding instance per token of `s`
that is present in `get_word_index`. Out-of-vocabulary tokens are dropped. If no
token survives, the empty-bag encoding from the `ExtractEmpty` method is returned.
With `store_input=true`, the surviving tokens are stored as the bag's metadata.

NOTE(review): assumes `s` is text accepted by `WordTokenizers.tokenize`. Confirm
that non-string `HierarchicType` values never reach this method.
"""
function (e::ExtractEmbedded)(s::JsonGrinder.HierarchicType; store_input=false)
    tokens = filter(t -> haskey(get_word_index, t), tokenize(s))
    isempty(tokens) && return e(JsonGrinder.ExtractEmpty(); store_input)
    # Slicing with a vector of column indices picks the same columns that
    # `get_embedding` returns, in a single allocation. The result is always an
    # `n × k` Matrix, even when k == 1.
    cols = [get_word_index[t] for t in tokens]
    x = embtable.embeddings[:, cols]
    metadata = store_input ? tokens : nothing
    # BagNode data must be a Mill node, so wrap the raw matrix in an ArrayNode,
    # as the `ExtractEmpty` method does.
    return BagNode(ArrayNode(x), AlignedBags([1:size(x, 2)]), metadata)
end
# The length of the extractor is its embedding dimension.
function Base.length(e::ExtractEmbedded)
    return e.n
end
# Hashing and equality both look only at `(n, uniontypes)`. The embedding table
# is excluded from both, which keeps `==` and `hash` consistent with each other.
function Base.hash(e::ExtractEmbedded, h::UInt)
    key = (e.n, e.uniontypes)
    return hash(key, h)
end
function Base.:(==)(e1::ExtractEmbedded, e2::ExtractEmbedded)
    return (e1.n, e1.uniontypes) === (e2.n, e2.uniontypes)
end
# Export an explanation for a bag produced by `ExtractEmbedded`.
# For each bag in `ds.bags`, return the entries of `ds.metadata` at the instance
# indices selected by `prunemask(mk.mask)` (presumably the instances kept after
# pruning). Return `nothing` for a bag where none are selected.
# Return `nothing` outright when no observation is flagged in `exportobs` or
# `ds` has no bags.
#
# NOTE(review): `exportobs` is only tested with `any`; individual bags are not
# filtered by it. Confirm this is intended.
# NOTE(review): `ds.metadata` holds the tokens only when extraction ran with
# `store_input=true`. Otherwise it is `nothing` and the indexing below throws.
# NOTE(review): assumes instance indices of `ds` line up with positions in
# `ds.metadata`. TODO confirm for nodes concatenated from several documents.
function ExplainMill.yarason(ds::BagNode, mk::ExplainMill.BagMask, e::ExtractEmbedded, exportobs=fill(true, nobs(ds)))
    if !any(exportobs) || isempty(ds.bags)
        return(nothing)
    end
    # Indices of the instances whose mask entries are set after pruning.
    items = findall(prunemask(mk.mask))
    map(ds.bags) do b
        # Selected instances that belong to this bag.
        ii = intersect(items, b)
        isempty(ii) ? nothing : ds.metadata[ii]
    end
end
Should we polish it and add it to the library with a little more love?
simonmandlik commented
Closing in favor of #136