CTUAvastLab/JsonGrinder.jl

Adding support for word embeddings

Closed this issue · 1 comments

I needed support for language models (pre-trained word embeddings), which was relatively simple to add as follows.

using Embeddings
using WordTokenizers
# Pre-trained embedding table, loaded once at top level.
const embtable = load_embeddings(FastText_Text) # or load_embeddings(Word2Vec) or ...
# Lookup from each vocabulary word to its position in `embtable.vocab`
# (used below as the column index into `embtable.embeddings`).
const get_word_index = Dict(token => idx for (idx, token) in enumerate(embtable.vocab))

"""
    get_embedding(word)

Return the column of `embtable.embeddings` whose index `get_word_index` assigns to
`word`. The lookup uses plain `Dict` indexing, so a `word` that is not a key of
`get_word_index` raises a `KeyError`.
"""
get_embedding(word) = embtable.embeddings[:, get_word_index[word]]

"""
    ExtractEmbedded{E} <: AbstractExtractor

Extractor that represents text by the pre-trained embedding vectors of its tokens.

# Fields
- `n::Int`: value reported by `length`; the convenience constructor sets it to the
  number of rows of `embtable.embeddings`.
- `uniontypes::Bool`: flag stored on the extractor; it takes part in `hash` and `==`.
- `embtable::E`: the embedding table this extractor was created with.
"""
struct ExtractEmbedded{E} <: AbstractExtractor
    n::Int
    uniontypes::Bool
    embtable::E
end

# Convenience constructor: derives `n` from the number of rows of
# `embtable.embeddings`. Note that the default argument calls `load_embeddings`
# each time the constructor is invoked without an explicit table.
function ExtractEmbedded(embtable = load_embeddings(FastText_Text), uniontypes = true)
    dim = size(embtable.embeddings, 1)
    return ExtractEmbedded(dim, uniontypes, embtable)
end

# Missing / nothing input: an `ArrayNode` holding one all-`missing` column of
# height `e.n`, with `[missing]` as metadata. `store_input` is accepted but not used.
# NOTE(review): the other call methods return a `BagNode`; confirm that returning an
# `ArrayNode` here is intended.
function (e::ExtractEmbedded)(::JsonGrinder.MissingOrNothing; store_input=false)
    placeholder = fill(missing, e.n, 1)
    return ArrayNode(placeholder, [missing])
end
# Empty input: a `BagNode` with a single empty bag (`0:-1`) over an `e.n × 0`
# Float32 matrix and no metadata. `store_input` is accepted but not used.
function (e::ExtractEmbedded)(::JsonGrinder.ExtractEmpty; store_input=false)
    no_columns = Matrix{Float32}(undef, e.n, 0)
    no_instances = AlignedBags([0:-1])
    return BagNode(ArrayNode(no_columns), no_instances, nothing)
end
# Extract a value by tokenizing it and embedding every known token.
#
# `s` is passed to `tokenize`; tokens that are not keys of `get_word_index` are
# dropped. If no token remains, the result of the `ExtractEmpty` method is returned.
# Otherwise the result is a `BagNode` with a single bag whose instances are the
# embedding vectors of the remaining tokens, one column per token. When
# `store_input` is true the kept tokens are stored as the node's metadata,
# otherwise the metadata is `nothing`.
#
# NOTE(review): the lookup goes through the global `get_word_index` /
# `get_embedding`, not through `e.embtable` — confirm this is intended.
function (e::ExtractEmbedded)(s::JsonGrinder.HierarchicType; store_input=false)
    # Use a separate name so `s` does not change type from the input to a token vector.
    tokens = filter(x -> haskey(get_word_index, x), tokenize(s))
    isempty(tokens) && return e(JsonGrinder.ExtractEmpty(); store_input)
    x = reduce(hcat, get_embedding.(tokens))
    metadata = store_input ? tokens : nothing
    # Wrap the matrix in an `ArrayNode`, as the `ExtractEmpty` method does, so the
    # bag's data is a Mill node rather than a raw matrix.
    return BagNode(ArrayNode(x), AlignedBags([1:size(x, 2)]), metadata)
end

# The length of an extractor is its stored `n`.
function Base.length(e::ExtractEmbedded)
    return e.n
end
# Hashing and equality consider only `n` and `uniontypes`; the `embtable` field is
# not compared, so two extractors with equal `n` and `uniontypes` are `==` even if
# they hold different tables.
function Base.hash(e::ExtractEmbedded, h::UInt)
    return hash((e.n, e.uniontypes), h)
end

function Base.:(==)(e1::ExtractEmbedded, e2::ExtractEmbedded)
    return (e1.n, e1.uniontypes) === (e2.n, e2.uniontypes)
end

# Export the explanation of a `BagNode` produced by `ExtractEmbedded`.
#
# Returns `nothing` when no observation is selected in `exportobs` or when `ds` has
# no bags. Otherwise returns one entry per bag of `ds`: the metadata entries of the
# bag's instances that are kept by `prunemask(mk.mask)`, or `nothing` if the bag has
# no kept instance.
#
# NOTE(review): `exportobs` is only tested with `any`; the result still covers every
# bag in `ds.bags` — confirm whether it should be restricted to the selected bags.
# NOTE(review): indexing `ds.metadata` presumably requires the node to have been
# extracted with `store_input=true` — verify against callers.
function ExplainMill.yarason(ds::BagNode, mk::ExplainMill.BagMask, e::ExtractEmbedded, exportobs=fill(true, nobs(ds)))
    (any(exportobs) && !isempty(ds.bags)) || return nothing
    kept = findall(prunemask(mk.mask))
    return map(ds.bags) do bag
        selected = intersect(kept, bag)
        if isempty(selected)
            nothing
        else
            ds.metadata[selected]
        end
    end
end

Should we polish it and add it to the library with a little bit more love?

Closing in favor of #136