Fix wandb logging

This commit is contained in:
Srini Iyer 2025-02-06 00:07:59 +00:00
parent c79b1fdbd0
commit a27ab3de8e

View file

@ -4,7 +4,6 @@
import json import json
import logging import logging
from collections import namedtuple from collections import namedtuple
from dataclasses import asdict
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any, Union from typing import Any, Union
@ -68,8 +67,8 @@ class MetricLogger:
and get_is_master() and get_is_master()
): ):
run = wandb.init( run = wandb.init(
config=asdict(self.args), config=self.args.model_dump(),
**asdict(self.args.logging.wandb), **self.args.logging.wandb.model_dump(),
) )
def log(self, metrics: dict[str, Any]): def log(self, metrics: dict[str, Any]):