JuliaPOMDP/POMCPOW.jl

Using POMCPOW with belief states

Closed this issue · 3 comments

odow commented

Hey! I'm trying to solve some POMDPs. I have solved them successfully with other solvers such as QMDP and BasicPOMCP, but I'm stuck trying to use this one.

With all the other solvers, I've simulated the policy using a belief updater constructed as updater(policy). With this solver, I get

ERROR: MethodError: no method matching updater(::POMCPOWPlanner{InventoryManagement,POMCPOW.POWNodeFilter,MaxUCB,MCTS.RandomActionGenerator{Random.MersenneTwister},BasicPOMCP.SolvedPORollout{POMDPPolicies.RandomPolicy{Random.MersenneTwister,InventoryManagement,BeliefUpdaters.NothingUpdater},BeliefUpdaters.NothingUpdater,Random.MersenneTwister},Int64,Float64,POMCPOWSolver})
Closest candidates are:
updater(::POMDPPolicies.AlphaVectorPolicy) at C:\Users\Oscar\.julia\packages\POMDPPolicies\nVH68\src\alpha_vector.jl:45
updater(::POMDPPolicies.FunctionPolicy) at C:\Users\Oscar\.julia\packages\POMDPPolicies\nVH68\src\function.jl:27
updater(::POMDPPolicies.RandomPolicy) at C:\Users\Oscar\.julia\packages\POMDPPolicies\nVH68\src\random.jl:38
...
Stacktrace:
[1] top-level scope at none:0

The README appears to use updater(pomdp) instead, but then I have to define that method myself, and I couldn't find any documentation describing what it should do...

Here is a reproducible example that I'm trying to get working:

using POMDPs, POMDPModelTools, POMDPSimulators, POMCPOW

# Inventory-control POMDP. A state is `(node, inventory_level)`; actions are order
# quantities; observations are integers (see `POMDPs.observation`).
struct InventoryManagement <: POMDP{Tuple{Symbol, Float64}, Float64, Float64}
    nodes::Vector{Symbol}                  # locations a state can be tagged with
    discount::Float64                      # per-step discount factor
    inventory_max::Float64                 # capacity; stock above this is destroyed
    inventory_step::Float64                # grid spacing for inventory and actions
    inventory_offset::Float64              # how far the state grid extends past [0, max]
    inventory::Vector{Float64}             # inventory grid used to enumerate states
    states::Vector{Tuple{Symbol, Float64}} # node-major enumeration of (node, level)
    actions::Vector{Float64}               # order quantities 0:step:max
    demand::Dict{Symbol, Vector{Float64}}  # per-node demand probabilities
    destroy_cost::Float64                  # cost per unit above inventory_max
    lost_demand_cost::Float64              # cost per unit of unmet demand
    unit_cost::Float64                     # cost per unit ordered
    function InventoryManagement()
        node_list = [:A, :B]
        step = 0.1
        upper = 3.0
        offset = 5.0
        # Inventory grid spans [-offset, upper + offset] so that states reached by
        # over-ordering or unmet demand still have an index.
        grid = collect(-offset:step:upper + offset)
        # Node-major: every level for node_list[1], then every level for node_list[2].
        state_list = [(n, level) for n in node_list for level in grid]
        action_list = collect(0.0:step:upper)
        demand_probs = Dict(:A => [0.5, 0.3, 0.2], :B => [0.2, 0.3, 0.5])
        discount_factor = 0.9
        destroy = 2.0
        lost = 10.0
        per_unit = 1.0
        return new(
            node_list,
            discount_factor,
            upper,
            step,
            offset,
            grid,
            state_list,
            action_list,
            demand_probs,
            destroy,
            lost,
            per_unit,
        )
    end
end

# Discount factor stored on the model (set by the constructor).
function POMDPs.discount(m::InventoryManagement)
    return m.discount
end

# Full (node, inventory) state space, precomputed by the constructor.
function POMDPs.states(m::InventoryManagement)
    return m.states
end
# Number of enumerated states.
function POMDPs.n_states(m::InventoryManagement)
    return length(m.states)
end
# Return the 1-based position of `s = (node, inventory)` in `states(m)`.
#
# `states(m)` is built node-major over `m.nodes` (all inventory levels of the first
# node, then all of the second, ...), so the index is the node's block offset plus the
# position of the inventory level on the `m.inventory` grid. The offset is derived from
# `m.nodes` rather than hard-coding `:A`/`:B`, so it stays consistent with how the
# states are enumerated. `round` absorbs floating-point drift in inventory values
# produced by arithmetic such as `clamp(...) + a - demand` in `transition`.
#
# Throws `ArgumentError` if `node` is not in `m.nodes`. Previously, any node other
# than `:A` was silently mapped into the second node's range.
function POMDPs.stateindex(m::InventoryManagement, s)
    (node, inventory) = s
    node_idx = findfirst(isequal(node), m.nodes)
    if node_idx === nothing
        throw(ArgumentError("unknown node $(repr(node)); expected one of $(m.nodes)"))
    end
    level_idx = round(Int, (inventory + m.inventory_offset) / m.inventory_step + 1)
    return (node_idx - 1) * length(m.inventory) + level_idx
end

# Order quantities on the grid 0:inventory_step:inventory_max (built in the constructor).
function POMDPs.actions(m::InventoryManagement)
    return m.actions
end
# Number of available order quantities.
function POMDPs.n_actions(m::InventoryManagement)
    return length(m.actions)
end
# Map an order quantity to its 1-based position in `actions(m)`, rounding to the
# nearest grid point to tolerate floating-point error.
function POMDPs.actionindex(m::InventoryManagement, a)
    steps_from_zero = a / m.inventory_step
    return round(Int, steps_from_zero + 1)
end

# Next-state distribution. The current level is clamped to [0, inventory_max], the
# order `a` is added, and a demand `d` is subtracted. `d` ranges over the indices
# 1..n of `m.demand[node]`, with probability `m.demand[node][d]`. The node never
# changes.
# NOTE(review): demand magnitudes are the vector indices (1, 2, 3), not values stored
# in `m.demand`; confirm that a demand of 0 was not intended.
function POMDPs.transition(m::InventoryManagement, s, a)
    node, level = s
    probs = m.demand[node]
    stocked = clamp(level, 0.0, m.inventory_max) + a
    outcomes = Tuple{Symbol, Float64}[(node, stocked - d) for d in 1:length(probs)]
    return POMDPModelTools.SparseCat(outcomes, probs)
end

# Negated cost, because POMDPs.jl maximizes reward. The cost is the ordering cost,
# plus a penalty per unit above `inventory_max` in the next state, plus a penalty per
# unit of negative next-state inventory (unmet demand).
function POMDPs.reward(m::InventoryManagement, s, a, s′)
    level = s′[2]
    unmet = max(-level, 0.0)
    excess = max(level - m.inventory_max, 0.0)
    cost = m.unit_cost * a + m.destroy_cost * excess + m.lost_demand_cost * unmet
    return -cost
end

# Start with zero inventory at node :A or :B, each with probability 0.5.
function POMDPs.initialstate_distribution(m::InventoryManagement)
    start_states = [(:A, 0.0), (:B, 0.0)]
    start_probs = [0.5, 0.5]
    return POMDPModelTools.SparseCat(start_states, start_probs)
end

# Observation distribution for state `s`: an integer in 1:n, where n is the number of
# demand entries for node `s[1]`, weighted by that node's demand probabilities. It
# depends only on the node, not on the inventory level.
function POMDPs.observation(m::InventoryManagement, s)
    probs = m.demand[first(s)]
    return POMDPModelTools.SparseCat(collect(1:length(probs)), probs)
end

pomdp = InventoryManagement()

# Other solvers tried on this model: QMDPSolver(; max_iterations = 500,
# tolerance = 1e-9, verbose = true) from QMDP, and POMCPSolver() from BasicPOMCP.
solver = POMCPOWSolver()
policy = solve(solver, pomdp)

# Roll out 10 steps, tracking beliefs with the updater supplied by the policy.
# updater(policy) requires a POMCPOW version that defines updater(::POMCPOWPlanner).
# Older READMEs suggested updater(pomdp), which is obsolete.
hist = simulate(
    HistoryRecorder(; max_steps = 10),
    pomdp,
    policy,
    updater(policy),
    initialstate_distribution(pomdp),
)

Hi @odow , thanks for reporting the issue! I added updater(::POMCPOWPlanner), so the standard way should work on the master branch now. I'll also tag a new version soon. updater(::POMDP) was left over from an earlier era and should have been removed from the readme.

In general, you can get more control over the updater by creating your own updater (e.g. a DiscreteUpdater or SIRParticleFilter) rather than using updater(policy).

Let me know if you run into any other issues!

odow commented

Thanks! I'll keep playing. (Btw, I sent an email to your Stanford address; is that an okay one to use?)

odow commented

Can confirm working. Thanks!