mit-ll-responsible-ai/hydra-zen

Overriding of nested groups doesn't work without using + operator

sorenmc opened this issue · 3 comments

In #635 I finally managed to make nested dataclasses initialize properly and made the CLI usable:

from dataclasses import dataclass, field
from typing import Literal

from hydra_zen import store, make_custom_builds_fn, zen

builds = make_custom_builds_fn(populate_full_signature=True)

@dataclass
class C:
    """Innermost settings node; used as the type of the ``c`` field on ``B``."""

    c_1: int = 1
    c_2: str = "c text"
    # Annotated as one of exactly three literal choices.
    c_3: Literal["a", "b", "c"] = "a"

@dataclass
class B:
    """Mid-level settings node that embeds a ``C`` instance under ``c``."""

    b_1: float = 1.0
    b_2: str = "b text"
    # The class itself serves as the factory, so every ``B`` gets its own
    # freshly constructed default ``C()``.
    c: C = field(default_factory=C)

@dataclass
class A:
    """Settings node with two numeric fields; used as the ``a`` field of ``Config``."""

    a_1: float = 1.0
    a_2: int = 1

@dataclass
class Config:
    """Aggregate config made of an ``A`` node and a ``B`` node."""

    # Each field defaults to a freshly constructed instance of its class.
    a: A = field(default_factory=A)
    b: B = field(default_factory=B)

@dataclass
class WrapsConfig:
    """Root config whose only field is the nested ``config`` node.

    No ``defaults`` field is declared on this class.
    """

    config: Config = field(default_factory=Config)

class Model:
    """Minimal object that holds a ``Config`` and exposes it read-only."""

    def __init__(self, config: Config) -> None:
        """Store ``config`` on the instance."""
        self._config = config

    @property
    def config(self) -> Config:
        """Return the config object passed to the constructor."""
        return self._config
    
# Group "config/b/c": three C variants that differ only in c_2.
builds_c = builds(C)
c_text_1 = builds_c(c_2="c text 1")
c_text_2 = builds_c(c_2="c text 2")
c_text_3 = builds_c(c_2="c text 3")

c_store = store(group="config/b/c")
for _name, _node in (
    ("c_text_1", c_text_1),
    ("c_text_2", c_text_2),
    ("c_text_3", c_text_3),
):
    c_store(_node, name=_name)


# Group "config/b": each B variant nests the C variant with the same index.
builds_b = builds(B)
b_1_1 = builds_b(b_1=1, c=c_text_1)
b_1_2 = builds_b(b_1=2, c=c_text_2)
b_1_3 = builds_b(b_1=3, c=c_text_3)

b_store = store(group="config/b")
for _name, _node in (
    ("b_1_1", b_1_1),
    ("b_1_2", b_1_2),
    ("b_1_3", b_1_3),
):
    b_store(_node, name=_name)


# Group "config/a": three A variants that differ only in a_1.
builds_a = builds(A)
a_1_1 = builds_a(a_1=1.2)
a_1_2 = builds_a(a_1=2.2)
a_1_3 = builds_a(a_1=3.2)

a_store = store(group="config/a")
for _name, _node in (
    ("a_1_1", a_1_1),
    ("a_1_2", a_1_2),
    ("a_1_3", a_1_3),
):
    a_store(_node, name=_name)


# Group "config": each Config variant pairs the A and B variants of one index.
builds_config = builds(Config)
config_1 = builds_config(a=a_1_1, b=b_1_1)
config_2 = builds_config(a=a_1_2, b=b_1_2)
config_3 = builds_config(a=a_1_3, b=b_1_3)

config_store = store(group="config")
for _name, _node in (
    ("config_1", config_1),
    ("config_2", config_2),
    ("config_3", config_3),
):
    config_store(_node, name=_name)


# Top-level entry: WrapsConfig pre-populated with config_1, stored without a
# group under the name that hydra_main is given as config_name.
builds_wraps_config = builds(WrapsConfig, config=config_1)
store(builds_wraps_config, name="default_config")

def task_function(config: Config):
    """Wrap ``config`` in a ``Model`` and print the config the model exposes."""
    print(Model(config).config)


if __name__ == "__main__":
    # Register the stored configs with Hydra before the app is launched.
    store.add_to_hydra_store()
    app = zen(task_function)
    app.hydra_main(
        config_path=None,
        config_name="default_config",
        version_base="1.3",
    )

I can run this app from the command line overriding c with our c config group as

python config.py +config/b/c=c_text_1

If I run it without the +, i.e. python config.py config/b/c=c_text_1, I get the following error:

Could not override 'config/b/c'. No match in the defaults list.
To append to your default list use +config/b/c=c_text_1

If I run

python config.py config.b.c=c_text_1

I get the output

Config(a=A(a_1=1.2, a_2=1), b=B(b_1=1.0, b_2='b text', c='c_text_1'))

So it overrides c with the plain string "c_text_1" instead of using the c_text_1 option of the config group.

I would guess this is related to Hydra's defaults list. How do I override using the config groups I created from the CLI without having to use the + hack?

I'm not sure if the underlying error is hydra or hydra-zen yet, but the following update will work:

@dataclass
class WrapsConfig:
    """Root config that declares a defaults list next to the ``config`` node."""

    # The group is listed with a None placeholder after "_self_".
    defaults: list[Any] = field(
        default_factory=lambda: [
            "_self_",
            {"config/b/c": None},
        ]
    )
    config: Config = field(default_factory=Config)

You'd have to do that for all the config groups.

This appears to be a Hydra issue; see the statement right above the "Config Inheritance" section here: https://hydra.cc/docs/tutorials/structured_config/config_groups/#config-inheritance

I have now followed what you said, and it works perfectly. Thank you for the quick reply.

Here is an updated example with all groups added:

from dataclasses import dataclass, field
from typing import Literal

from hydra_zen import store, make_custom_builds_fn, zen
from traitlets import Any

builds = make_custom_builds_fn(populate_full_signature=True)


@dataclass
class C:
    """Innermost settings node; used as the type of the ``c`` field on ``B``."""

    c_1: int = 1
    c_2: str = "c text"
    # Annotated as one of exactly three literal choices.
    c_3: Literal["a", "b", "c"] = "a"


@dataclass
class B:
    """Mid-level settings node that embeds a ``C`` instance under ``c``."""

    b_1: float = 1.0
    b_2: str = "b text"
    # The class itself serves as the factory, so every ``B`` gets its own
    # freshly constructed default ``C()``.
    c: C = field(default_factory=C)


@dataclass
class A:
    """Settings node with two numeric fields; used as the ``a`` field of ``Config``."""

    a_1: float = 1.0
    a_2: int = 1


@dataclass
class Config:
    """Aggregate config made of an ``A`` node and a ``B`` node."""

    # Each field defaults to a freshly constructed instance of its class.
    a: A = field(default_factory=A)
    b: B = field(default_factory=B)


@dataclass
class WrapsConfig:
    """Root config: a defaults list plus the nested ``config`` node."""

    # NOTE(review): `Any` is imported from traitlets at the top of this
    # script; typing.Any is presumably what was intended -- confirm.
    #
    # Every config group is listed with a None placeholder after "_self_".
    # With these entries present, the groups can be selected from the command
    # line without the "+" prefix.
    defaults: Any = field(
        default_factory=lambda: [
            "_self_",
            {"config/b": None},
            {"config/b/c": None},
            {"config/a": None},
        ]
    )
    config: Config = field(default_factory=Config)


class Model:
    """Minimal object that holds a ``Config`` and exposes it read-only."""

    def __init__(self, config: Config) -> None:
        """Store ``config`` on the instance."""
        self._config = config

    @property
    def config(self) -> Config:
        """Return the config object passed to the constructor."""
        return self._config


# Group "config/b/c": three C variants that differ only in c_2.
builds_c = builds(C)
c_text_1 = builds_c(c_2="c text 1")
c_text_2 = builds_c(c_2="c text 2")
c_text_3 = builds_c(c_2="c text 3")

c_store = store(group="config/b/c")
for _name, _node in (
    ("c_text_1", c_text_1),
    ("c_text_2", c_text_2),
    ("c_text_3", c_text_3),
):
    c_store(_node, name=_name)


# Group "config/b": each B variant nests the C variant with the same index.
builds_b = builds(B)
b_1_1 = builds_b(b_1=1, c=c_text_1)
b_1_2 = builds_b(b_1=2, c=c_text_2)
b_1_3 = builds_b(b_1=3, c=c_text_3)

b_store = store(group="config/b")
for _name, _node in (
    ("b_1_1", b_1_1),
    ("b_1_2", b_1_2),
    ("b_1_3", b_1_3),
):
    b_store(_node, name=_name)


# Group "config/a": three A variants that differ only in a_1.
builds_a = builds(A)
a_1_1 = builds_a(a_1=1.2)
a_1_2 = builds_a(a_1=2.2)
a_1_3 = builds_a(a_1=3.2)

a_store = store(group="config/a")
for _name, _node in (
    ("a_1_1", a_1_1),
    ("a_1_2", a_1_2),
    ("a_1_3", a_1_3),
):
    a_store(_node, name=_name)


# Group "config": each Config variant pairs the A and B variants of one index.
builds_config = builds(Config)
config_1 = builds_config(a=a_1_1, b=b_1_1)
config_2 = builds_config(a=a_1_2, b=b_1_2)
config_3 = builds_config(a=a_1_3, b=b_1_3)

config_store = store(group="config")
for _name, _node in (
    ("config_1", config_1),
    ("config_2", config_2),
    ("config_3", config_3),
):
    config_store(_node, name=_name)


# Top-level entry: WrapsConfig pre-populated with config_1, stored without a
# group under the name that hydra_main is given as config_name.
builds_wraps_config = builds(WrapsConfig, config=config_1)
store(builds_wraps_config, name="default_config")


def task_function(config: Config):
    """Wrap ``config`` in a ``Model`` and print the config the model exposes."""
    print(Model(config).config)


if __name__ == "__main__":
    # Register the stored configs with Hydra before the app is launched.
    store.add_to_hydra_store()
    app = zen(task_function)
    app.hydra_main(
        config_path=None,
        config_name="default_config",
        version_base="1.3",
    )

This will allow us to run

> python config.py config/b/c=c_text_2
Config(a=A(a_1=1.2, a_2=1), b=B(b_1=1.0, b_2='b text', c=C(c_1=1, c_2='c text 2', c_3='a')))

> python config.py config/b=b_1_3            
Config(a=A(a_1=1.2, a_2=1), b=B(b_1=3.0, b_2='b text', c=C(c_1=1, c_2='c text 3', c_3='a')))
> python config.py config/b=b_1_3 config/b/c=c_text_2
Config(a=A(a_1=1.2, a_2=1), b=B(b_1=3.0, b_2='b text', c=C(c_1=1, c_2='c text 2', c_3='a')))

> python config.py config/b/c=c_text_2 config/b=b_1_3
Config(a=A(a_1=1.2, a_2=1), b=B(b_1=3.0, b_2='b text', c=C(c_1=1, c_2='c text 2', c_3='a')))

As you can see, b_1_3 on its own sets c_2="c text 3", but if we also pass config/b/c=c_text_2, c_2 becomes "c text 2" regardless of the order of the arguments.