Overriding of nested groups doesn't work without using + operator
sorenmc opened this issue · 3 comments
In #635 I finally managed to make Nested Dataclasses initialize properly and made CLI usable
from dataclasses import dataclass, field
from typing import Literal
from hydra_zen import store, make_custom_builds_fn, zen
# Project-wide `builds` helper: created with populate_full_signature=True so the
# generated configs expose every parameter of the target, not only the ones
# passed explicitly at the call site.
builds = make_custom_builds_fn(populate_full_signature=True)
@dataclass
class C:
    """Innermost settings object; stored in the ``c`` field of ``B``."""

    c_1: int = 1
    c_2: str = "c text"
    # Annotated as one of exactly three string values.
    c_3: Literal["a", "b", "c"] = "a"
@dataclass
class B:
    """Mid-level settings object that nests a ``C`` instance."""

    b_1: float = 1.0
    b_2: str = "b text"
    # A factory gives every B its own fresh C instead of one shared instance.
    c: C = field(default_factory=lambda: C())
@dataclass
class A:
    """Flat settings object with two numeric fields."""

    a_1: float = 1.0
    a_2: int = 1
@dataclass
class Config:
    """Configuration handed to ``Model``; groups one ``A`` and one ``B``."""

    # Factories so that every Config owns separate A and B instances.
    a: A = field(default_factory=lambda: A())
    b: B = field(default_factory=lambda: B())
@dataclass
class WrapsConfig:
    """Top-level config whose only field, ``config``, holds the ``Config``."""

    config: Config = field(default_factory=lambda: Config())
class Model:
    """Holds a ``Config`` and exposes it read-only through ``config``."""

    def __init__(self, config: Config):
        """Keep a reference to ``config``; it is neither copied nor validated."""
        self._config = config

    @property
    def config(self):
        """Return the object that was passed to the constructor."""
        return self._config
# Group "config/b/c": three C configs that differ only in c_2.
builds_c = builds(C)
c_text_1, c_text_2, c_text_3 = (
    builds_c(c_2=text) for text in ("c text 1", "c text 2", "c text 3")
)
c_store = store(group="config/b/c")
for _name, _node in (
    ("c_text_1", c_text_1),
    ("c_text_2", c_text_2),
    ("c_text_3", c_text_3),
):
    c_store(_node, name=_name)

# Group "config/b": each B config embeds the C config with the same index.
builds_b = builds(B)
b_1_1, b_1_2, b_1_3 = (
    builds_b(b_1=value, c=nested)
    for value, nested in ((1, c_text_1), (2, c_text_2), (3, c_text_3))
)
b_store = store(group="config/b")
for _name, _node in (("b_1_1", b_1_1), ("b_1_2", b_1_2), ("b_1_3", b_1_3)):
    b_store(_node, name=_name)

# Group "config/a": three A configs that differ only in a_1.
builds_a = builds(A)
a_1_1, a_1_2, a_1_3 = (builds_a(a_1=value) for value in (1.2, 2.2, 3.2))
a_store = store(group="config/a")
for _name, _node in (("a_1_1", a_1_1), ("a_1_2", a_1_2), ("a_1_3", a_1_3)):
    a_store(_node, name=_name)

# Group "config": pairs the A and B configs that share an index.
builds_config = builds(Config)
config_1, config_2, config_3 = (
    builds_config(a=a_node, b=b_node)
    for a_node, b_node in ((a_1_1, b_1_1), (a_1_2, b_1_2), (a_1_3, b_1_3))
)
config_store = store(group="config")
for _name, _node in (
    ("config_1", config_1),
    ("config_2", config_2),
    ("config_3", config_3),
):
    config_store(_node, name=_name)

# Root config, stored without a group under the name that the entry point
# passes as config_name.
builds_wraps_config = builds(WrapsConfig, config=config_1)
store(builds_wraps_config, name="default_config")

# Do not leave the loop helpers behind as module-level names.
del _name, _node
def task_function(config: Config):
    """Wrap ``config`` in a ``Model`` and print the config the model holds.

    Args:
        config: The configuration object to hand to ``Model``.
    """
    print(Model(config).config)
if __name__ == "__main__":
    # Everything registered through `store` above has to be handed to Hydra
    # before the app starts; otherwise "default_config" cannot be looked up.
    store.add_to_hydra_store()
    # config_path=None: no config directory is given, only stored configs.
    zen(task_function).hydra_main(
        config_name="default_config", version_base="1.3", config_path=None
    )
I can run this app from the command line overriding c with our c config group as
python config.py +config/b/c=c_text_1
If I run it without the +,
as python config.py config/b/c=c_text_1
I get the following error:
Could not override 'config/b/c'. No match in the defaults list.
To append to your default list use +config/b/c=c_text_1
If I run
python config.py config.b.c=c_text_1
I get the output
Config(a=A(a_1=1.2, a_2=1), b=B(b_1=1.0, b_2='b text', c='c_text_1'))
So it overrides c with the plain string "c_text_1" instead of using the c_text_1 entry of the config group,
which I would guess is related to Hydra's defaults list. How do I override using the config groups I created from the CLI without having to use the + hack?
I'm not sure if the underlying error is hydra or hydra-zen yet, but the following update will work:
@dataclass
class WrapsConfig:
    """Top-level config with ``config/b/c`` declared in the defaults list."""

    # The group is listed with `None`, i.e. no option is selected by default;
    # having the entry present is what lets `config/b/c=<name>` be given on the
    # command line without the `+` prefix.
    defaults: list[Any] = field(
        default_factory=lambda: ["_self_", {"config/b/c": None}]
    )
    config: Config = field(default_factory=lambda: Config())
You'd have to do that for all the config groups.
This appears to be a Hydra issue; see the statement right above the "Config Inheritance" section here: https://hydra.cc/docs/tutorials/structured_config/config_groups/#config-inheritance
I have now followed what you said, and it works perfectly. Thank you for the quick reply.
Here is an updated example with all groups added:
from dataclasses import dataclass, field
from typing import Literal
from hydra_zen import store, make_custom_builds_fn, zen
from traitlets import Any
# Project-wide `builds` helper: created with populate_full_signature=True so the
# generated configs expose every parameter of the target, not only the ones
# passed explicitly at the call site.
builds = make_custom_builds_fn(populate_full_signature=True)
@dataclass
class C:
    """Innermost settings object; stored in the ``c`` field of ``B``."""

    c_1: int = 1
    c_2: str = "c text"
    # Annotated as one of exactly three string values.
    c_3: Literal["a", "b", "c"] = "a"
@dataclass
class B:
    """Mid-level settings object that nests a ``C`` instance."""

    b_1: float = 1.0
    b_2: str = "b text"
    # A factory gives every B its own fresh C instead of one shared instance.
    c: C = field(default_factory=lambda: C())
@dataclass
class A:
    """Flat settings object with two numeric fields."""

    a_1: float = 1.0
    a_2: int = 1
@dataclass
class Config:
    """Configuration handed to ``Model``; groups one ``A`` and one ``B``."""

    # Factories so that every Config owns separate A and B instances.
    a: A = field(default_factory=lambda: A())
    b: B = field(default_factory=lambda: B())
@dataclass
class WrapsConfig:
    """Top-level config: a Hydra defaults list plus the wrapped ``Config``."""

    # Every config group is listed with `None`, i.e. no option is selected by
    # default; having the entry present is what lets `config/b=<name>`,
    # `config/b/c=<name>` and `config/a=<name>` be given on the command line
    # without the `+` prefix.
    # NOTE(review): `Any` is imported from `traitlets` in this script;
    # `typing.Any` was presumably intended -- confirm.
    defaults: Any = field(
        default_factory=lambda: [
            "_self_",
            {"config/b": None},
            {"config/b/c": None},
            {"config/a": None},
        ]
    )
    config: Config = field(default_factory=lambda: Config())
class Model:
    """Holds a ``Config`` and exposes it read-only through ``config``."""

    def __init__(self, config: Config):
        """Keep a reference to ``config``; it is neither copied nor validated."""
        self._config = config

    @property
    def config(self):
        """Return the object that was passed to the constructor."""
        return self._config
# Group "config/b/c": three C configs that differ only in c_2.
builds_c = builds(C)
c_text_1, c_text_2, c_text_3 = (
    builds_c(c_2=text) for text in ("c text 1", "c text 2", "c text 3")
)
c_store = store(group="config/b/c")
for _name, _node in (
    ("c_text_1", c_text_1),
    ("c_text_2", c_text_2),
    ("c_text_3", c_text_3),
):
    c_store(_node, name=_name)

# Group "config/b": each B config embeds the C config with the same index.
builds_b = builds(B)
b_1_1, b_1_2, b_1_3 = (
    builds_b(b_1=value, c=nested)
    for value, nested in ((1, c_text_1), (2, c_text_2), (3, c_text_3))
)
b_store = store(group="config/b")
for _name, _node in (("b_1_1", b_1_1), ("b_1_2", b_1_2), ("b_1_3", b_1_3)):
    b_store(_node, name=_name)

# Group "config/a": three A configs that differ only in a_1.
builds_a = builds(A)
a_1_1, a_1_2, a_1_3 = (builds_a(a_1=value) for value in (1.2, 2.2, 3.2))
a_store = store(group="config/a")
for _name, _node in (("a_1_1", a_1_1), ("a_1_2", a_1_2), ("a_1_3", a_1_3)):
    a_store(_node, name=_name)

# Group "config": pairs the A and B configs that share an index.
builds_config = builds(Config)
config_1, config_2, config_3 = (
    builds_config(a=a_node, b=b_node)
    for a_node, b_node in ((a_1_1, b_1_1), (a_1_2, b_1_2), (a_1_3, b_1_3))
)
config_store = store(group="config")
for _name, _node in (
    ("config_1", config_1),
    ("config_2", config_2),
    ("config_3", config_3),
):
    config_store(_node, name=_name)

# Root config, stored without a group under the name that the entry point
# passes as config_name.
builds_wraps_config = builds(WrapsConfig, config=config_1)
store(builds_wraps_config, name="default_config")

# Do not leave the loop helpers behind as module-level names.
del _name, _node
def task_function(config: Config):
    """Wrap ``config`` in a ``Model`` and print the config the model holds.

    Args:
        config: The configuration object to hand to ``Model``.
    """
    print(Model(config).config)
if __name__ == "__main__":
    # Everything registered through `store` above has to be handed to Hydra
    # before the app starts; otherwise "default_config" cannot be looked up.
    store.add_to_hydra_store()
    # config_path=None: no config directory is given, only stored configs.
    zen(task_function).hydra_main(
        config_name="default_config", version_base="1.3", config_path=None
    )
This will allow us to run
> python config.py config/b/c=c_text_2
Config(a=A(a_1=1.2, a_2=1), b=B(b_1=1.0, b_2='b text', c=C(c_1=1, c_2='c text 2', c_3='a')))
> python config.py config/b=b_1_3
Config(a=A(a_1=1.2, a_2=1), b=B(b_1=3.0, b_2='b text', c=C(c_1=1, c_2='c text 3', c_3='a')))
> python config.py config/b=b_1_3 config/b/c=c_text_2
Config(a=A(a_1=1.2, a_2=1), b=B(b_1=3.0, b_2='b text', c=C(c_1=1, c_2='c text 2', c_3='a')))
> python config.py config/b/c=c_text_2 config/b=b_1_3
Config(a=A(a_1=1.2, a_2=1), b=B(b_1=3.0, b_2='b text', c=C(c_1=1, c_2='c text 2', c_3='a')))
As you can see, b_1_3 on its own sets c_2="c text 3", but when config/b/c=c_text_2 is also given, c_2 becomes "c text 2" regardless of the order of the arguments.