import logging
from typing import Any, Dict

from tensorboardX import SummaryWriter

from flambe.logging.datatypes import ScalarT, ScalarsT, HistogramT, TextT, \
    ImageT, EmbeddingT, PRCurveT, GraphT, DataLoggingFilter

class TensorboardXHandler(logging.Handler):
    """Implements Tensorboard message logging via TensorboardX

    One ``SummaryWriter`` is created lazily per log directory found on
    incoming records, and reused for subsequent records with the same
    log directory.

    Parameters
    ----------
    *args : Any
        Other positional args for `logging.Handler`
    **kwargs : Any
        Other kwargs for `logging.Handler`

    Attributes
    ----------
    writers : Dict[str, SummaryWriter]
        Initialized TensorboardX writers, keyed by log directory

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # TODO remove dont_include when we add support for saving graph
        self.addFilter(DataLoggingFilter(default=False, dont_include=(GraphT,)))
        self.writers: Dict[str, SummaryWriter] = {}

    def _get_writer(self, record: logging.LogRecord) -> "SummaryWriter":
        """Return the cached writer for the record's log directory.

        A new ``SummaryWriter`` is created (and cached) the first time a
        log directory is seen; if the record carries non-empty
        ``_tf_hparams``, ``add_hparams_start`` is called on the new writer.

        Parameters
        ----------
        record : logging.LogRecord
            Record carrying a ``_tf_log_dir`` attribute

        Returns
        -------
        SummaryWriter
            The writer associated with ``record._tf_log_dir``

        """
        log_dir = record._tf_log_dir  # type: ignore
        writer = self.writers.get(log_dir)
        if writer is None:
            writer = SummaryWriter(log_dir=log_dir)
            hparams = getattr(record, "_tf_hparams", dict())
            if hparams:
                writer.add_hparams_start(hparams=hparams)
            self.writers[log_dir] = writer
        return writer

    def emit(self, record: logging.LogRecord) -> None:
        """Save to tensorboard logging directory

        Overrides `logging.Handler.emit`. Records lacking a
        ``raw_msg_obj`` attribute (provided by flambe logging) or a
        ``_tf_log_dir`` attribute (provided by the logging context) are
        ignored.

        Parameters
        ----------
        record : logging.LogRecord
            LogRecord with data relevant to Tensorboard

        Returns
        -------
        None

        """
        # Handler relies on access to raw objects which flambe logging
        # provides
        if not hasattr(record, "raw_msg_obj"):
            return
        message = record.raw_msg_obj  # type: ignore

        # The log directory from the logging context selects which
        # writer (and therefore which directory on disk) receives the
        # data; without one there is nowhere to write
        if not hasattr(record, "_tf_log_dir"):
            return
        writer = self._get_writer(record)

        if isinstance(message, (ScalarT, HistogramT, TextT, EmbeddingT,
                                ImageT, PRCurveT)):
            # Datatypes with a standard `tag` field: their fields map
            # directly onto the keyword arguments of the writer method
            log_fn = {
                ScalarT: writer.add_scalar,
                HistogramT: writer.add_histogram,
                TextT: writer.add_text,
                EmbeddingT: writer.add_embedding,
                ImageT: writer.add_image,
                PRCurveT: writer.add_pr_curve,
            }[type(message)]
            log_fn(**message._asdict())
        elif isinstance(message, ScalarsT):
            # Datatype with a special `main_tag` field
            writer.add_scalars(**message._asdict())
        elif isinstance(message, GraphT):
            # Datatype without a tag field; its nested `kwargs` field is
            # flattened into the call's keyword arguments
            kwargs = message._asdict()
            extra = kwargs.pop('kwargs', None) or {}
            kwargs.update(extra)
            # TensorboardX's SummaryWriter exposes `add_graph`
            # (there is no `add_model` method)
            writer.add_graph(**kwargs)
        writer.file_writer.flush()

    @staticmethod
    def _flush_writer(writer: Any) -> None:
        """Flush one TensorboardX writer's main and auxiliary file writers.

        Older TensorboardX writers have no public `flush` method, so this
        copies the flushing logic from the TensorboardX `SummaryWriter`.
        ``all_writers`` may be ``None`` (e.g. for a writer that has already
        been closed); it is guarded so flushing never raises on it.
        """
        if writer.file_writer:
            writer.file_writer.flush()
        for sub_writer in (writer.all_writers or {}).values():
            sub_writer.flush()

    def close(self) -> None:
        """Teardown writers and teardown super

        Returns
        -------
        None

        """
        # Use built-in writer `close` method to flush and close
        for writer in self.writers.values():
            # NOTE(review): add_hparams_end is called even for writers
            # that never had add_hparams_start called; confirm harmless
            writer.add_hparams_end()
            writer.close()
        # Drop closed writers so a later flush()/close() does not touch
        # them again
        self.writers.clear()
        super().close()

    def flush(self) -> None:
        """Call flush on every Tensorboard writer

        Returns
        -------
        None

        """
        for writer in self.writers.values():
            self._flush_writer(writer)
        super().flush()