# Source code for flambe.logging.logging

import sys
import os
import pathlib
import logging
from logging import handlers
from typing import Type, Any, List, AnyStr, Optional, Dict  # noqa: F401
from types import TracebackType

from tqdm import tqdm

from flambe.logging.handler.tensorboard import TensorboardXHandler
from flambe.logging.handler.contextual_file import ContextualFileHandler
from flambe.logging.datatypes import DataLoggingFilter
from flambe.const import FLAMBE_GLOBAL_FOLDER

# Number of bytes in one mebibyte (used below to size the rotating log file).
MB = 1 << 20
def setup_global_logging(console_log_level: int = logging.NOTSET) -> None:
    """Set up flambe logging with a Stream handler and a Rotating File handler.

    This function should be called before any logger is used, as it sets
    the basic configuration for all future logs. After executing it, the
    root logger will have the following handlers:

    * Stream handler: prints to stdout all flambe / ``__main__`` logs at
      or above ``console_log_level``.
    * Rotating File handler: 10MB log file (5 backups) located in the
      Flambe global folder. Configured to store all logs (min level DEBUG).

    Note that ``logging.basicConfig`` does nothing if the root logger
    already has handlers, so calling this late (or twice) will not
    reconfigure logging.

    Parameters
    ----------
    console_log_level: int
        The minimum log level for the Stream handler

    """
    colorize_exceptions()

    logs_dir = os.path.join(FLAMBE_GLOBAL_FOLDER, 'logs')
    pathlib.Path(logs_dir).mkdir(parents=True, exist_ok=True)

    # FILE HANDLER: everything (DEBUG and up) goes to the rotating log file
    fh = handlers.RotatingFileHandler(
        os.path.join(logs_dir, 'log.log'),
        maxBytes=10 * MB,
        backupCount=5
    )
    file_formatter = logging.Formatter(
        '%(asctime)s | %(levelname)-8s | %(name)-8s | %(lineno)04d | %(message)s'
    )
    fh.setFormatter(file_formatter)

    # CONSOLE HANDLER: written through tqdm so log lines don't break
    # active progress bars
    ch = logging.StreamHandler(stream=TqdmFileWrapper(sys.stdout))  # type: ignore
    console_formatter = logging.Formatter('%(asctime)s | %(message)s', "%H:%M:%S")
    # Only flambe logs in stdout
    ch.addFilter(FlambeFilter())
    ch.setLevel(console_log_level)
    ch.setFormatter(console_formatter)

    # Route `warnings.warn` calls through the logging system
    logging.captureWarnings(True)
    logging.basicConfig(level=logging.DEBUG, handlers=[fh, ch])
class FlambeFilter(logging.Filter):
    """Keep only log records emitted by flambe or by the ``__main__`` module.

    A record passes when its logger name starts with ``"flambe"`` or
    ``"__main__"``; records from any other logger are rejected.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        logger_name = record.name
        return logger_name.startswith(("flambe", "__main__"))
class TrialLogging:
    """Context manager that configures root logging for a single trial.

    While active, the following are attached to the root logger:

    * a console handler (written through tqdm),
    * a ``console.out`` contextual file handler, and
    * a TensorboardX handler, when available.

    It also installs a log record factory that exposes the raw message
    object as ``record.raw_msg_obj``. It injects the trial context
    (console prefix, log dir, hyperparameters) into every record via a
    ``ContextInjection`` filter. ``__exit__`` undoes all of this.

    Parameters
    ----------
    log_dir: str
        Directory exposed to handlers as ``logging.root._log_dir`` and as
        the ``_tf_log_dir`` / ``_console_log_dir`` record attributes.
    verbose: bool
        When True every trial handler uses level NOTSET. Otherwise the
        console handler is set to ERROR and the file / TensorBoard
        handlers to INFO.
    root_log_level: Optional[int]
        Level set on the root logger while the context is active.
        Defaults to ``logging.NOTSET`` when not given.
    capture_warnings: bool
        Passed to ``logging.captureWarnings`` on enter.
    console_prefix: Optional[str]
        Injected as ``_console_prefix`` (used by the console format).
    hyper_params: Optional[Dict]
        Injected as ``_tf_hparams``; defaults to an empty dict.

    """

    def __init__(self,
                 log_dir: str,
                 verbose: bool = False,
                 root_log_level: Optional[int] = None,
                 capture_warnings: bool = True,
                 console_prefix: Optional[str] = None,
                 hyper_params: Optional[Dict] = None) -> None:
        self.log_dir = log_dir
        self.verbose = verbose
        # Honour the requested root level. Fall back to NOTSET so that, by
        # default, the root logger enables every level and leaves
        # filtering to the individual handlers.
        self.log_level = logging.NOTSET if root_log_level is None else root_log_level
        self.capture_warnings = capture_warnings
        # Declared but never assigned in this class
        self.listener: handlers.QueueListener
        self.console_prefix = console_prefix
        self.handlers: List[logging.Handler] = []
        # Declared but never assigned in this class
        self.queue_handler: handlers.QueueHandler
        self.old_root_log_level: int = logging.NOTSET
        self.hyper_params: Dict = hyper_params or {}

    def __enter__(self) -> logging.Logger:
        """Attach the trial handlers and filters to the root logger and return it."""
        colorize_exceptions()
        logger = logging.root
        self.old_root_log_level = logger.level
        if self.log_level is not None:
            logger.setLevel(self.log_level)

        console_log_level = logging.NOTSET if self.verbose else logging.ERROR
        console_data_log_level = logging.NOTSET if self.verbose else logging.ERROR
        console_file_log_level = logging.NOTSET if self.verbose else logging.INFO
        tensorboard_log_level = logging.NOTSET if self.verbose else logging.INFO

        # CONSOLE LOGGING (through tqdm so progress bars are not broken)
        console = logging.StreamHandler(stream=TqdmFileWrapper(sys.stdout))  # type: ignore
        console.setLevel(console_log_level)
        console_formatter = logging.Formatter('%(name)s [block_%(_console_prefix)s] %(message)s')
        console.setFormatter(console_formatter)
        console.addFilter(DataLoggingFilter(level=console_data_log_level))
        self.handlers.append(console)

        # RECORD VERBOSE CONSOLE OUTPUT TO CONTEXT-SPECIFIC FILE
        console_splitter = ContextualFileHandler(canonical_name="console.out", mode='a')
        console_splitter.setLevel(console_file_log_level)
        console_splitter.setFormatter(console_formatter)
        self.handlers.append(console_splitter)

        # TENSORBOARDX LOGGING
        try:
            tbx = TensorboardXHandler()
            tbx.setLevel(tensorboard_log_level)
            self.handlers.append(tbx)
        except ModuleNotFoundError:
            print("TensorboardX not found. Disabling logging handler.")

        for handler in self.handlers:
            logger.addHandler(handler)

        # Route built-in Python warnings through our logger, defaulting
        # to WARN severity. This means warnings that come from 3rd party
        # code (e.g. Pandas) or custom user code can be properly filtered
        # and logged.
        logging.captureWarnings(self.capture_warnings)

        old_factory = logging.getLogRecordFactory()
        self.old_factory = old_factory

        def record_factory(name, level, fn, lno, msg, args, exc_info,  # type: ignore
                           func=None, sinfo=None, **kwargs):
            # Forward func/sinfo unchanged; passing None here would blank
            # out ``funcName`` and ``stack_info`` on every record.
            record = old_factory(name, level, fn, lno, msg, args, exc_info,
                                 func=func, sinfo=sinfo, **kwargs)
            # Always make raw message available by default so that
            # handlers can manipulate native objects instead of
            # string representations
            record.raw_msg_obj = msg  # type: ignore
            return record

        logging.setLogRecordFactory(record_factory)
        logging.root._log_dir = self.log_dir  # type: ignore

        self.context_filter = ContextInjection(_console_prefix=self.console_prefix,
                                               _tf_log_dir=self.log_dir,
                                               _tf_hparams=self.hyper_params,
                                               _console_log_dir=self.log_dir)
        # Inject context on every root handler (including pre-existing ones)
        # and on the root logger itself
        for handler in logger.handlers:
            handler.addFilter(self.context_filter)
        logger.addFilter(self.context_filter)
        self.logger = logger
        return logger

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """Remove and close the trial handlers and restore the original logging config."""
        for handler in self.logger.handlers:
            handler.removeFilter(self.context_filter)
        self.logger.removeFilter(self.context_filter)
        for handler in self.handlers:
            self.logger.removeHandler(handler)
            # Release resources (e.g. open files) held by the handler
            handler.close()
        # Forget the closed handlers so a re-entered context starts fresh
        self.handlers = []
        self.logger.setLevel(self.old_root_log_level)
        logging.setLogRecordFactory(self.old_factory)
        delattr(logging.root, '_log_dir')
class ContextInjection:
    """Attach a fixed set of attributes to every log record it sees.

    Usable both as a ``logging`` filter object (via ``filter``) and as a
    plain callable. It never rejects a record.

    Parameters
    ----------
    **attrs : Any
        Name/value pairs that are set on each record (with ``setattr``)
        for use in downstream handlers

    """

    def __init__(self, **attrs) -> None:
        self.attrs = attrs

    def filter(self, record: logging.LogRecord) -> int:
        injected = self.attrs
        for attr_name in injected:
            setattr(record, attr_name, injected[attr_name])
        # Decorate only; always let the record through
        return True

    def __call__(self, record: logging.LogRecord) -> int:
        return self.filter(record)
class TqdmFileWrapper:
    """File-like object that forwards writes through ``tqdm.write``.

    Writing via ``tqdm.write`` keeps log output from corrupting active
    tqdm progress bars. Based on the canonical tqdm example.

    Parameters
    ----------
    file: Any
        The underlying stream (e.g. ``sys.stdout``) handed to ``tqdm.write``.

    """

    def __init__(self, file: Any) -> None:
        self.file = file

    def write(self, x: AnyStr) -> int:
        """Write ``x`` through tqdm, dropping whitespace-only chunks.

        Returns
        -------
        int
            ``len(x)`` when ``x`` was written, 0 when it was dropped.
            ``tqdm.write`` itself returns None, so the count is computed
            here to honour the file-like ``write`` contract.

        """
        # Avoid print() second call (useless \n)
        if len(x.rstrip()) > 0:
            tqdm.write(x, file=self.file)
            return len(x)
        return 0

    def flush(self) -> Any:
        """Flush the underlying stream if it supports flushing."""
        return getattr(self.file, "flush", lambda: None)()
def colorize_exceptions() -> None:
    """Colorize uncaught-exception tracebacks on stderr using pygments.

    Installs a ``sys.excepthook`` that formats the traceback and highlights
    it with pygments' ``pytb`` lexer before writing it to stderr. If
    pygments is not installed, nothing is changed.
    """
    try:
        import traceback
        from pygments import highlight
        from pygments.lexers import get_lexer_by_name
        from pygments.formatters import TerminalFormatter
    except ModuleNotFoundError:
        # pygments unavailable: keep the current excepthook
        return

    def colorized_excepthook(type_: Type[BaseException],
                             value: BaseException,
                             tb: TracebackType) -> None:
        plain_text = ''.join(traceback.format_exception(type_, value, tb))
        colored = highlight(plain_text,
                            get_lexer_by_name("pytb", stripall=True),
                            TerminalFormatter())
        sys.stderr.write(colored)

    sys.excepthook = colorized_excepthook  # type: ignore