Source code for revscoring.scoring.statistics.classification.micro_macro_stats
import logging
from tabulate import tabulate
from ... import util
from ...model_info import ModelInfo
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Character budget for the estimated header row of the per-label table:
# MicroMacroStats.format_label_table uses the column-major layout while the
# estimate stays below this value and the row-major layout otherwise.
MAX_COLUMNS_WIDTH_CHARS = 80
class MicroMacroStats(ModelInfo):
    """
    Micro- and macro-average of one per-label statistic, plus the per-label
    values themselves, with string and JSON formatting helpers.
    """

    def __init__(self, stats, field):
        """
        Constructs a micro-average and macro-average for a specific statistic
        based on the name.  Works like a dictionary with fields

        * micro : the micro-average
        * macro : the macro-average
        * labels : a mapping of labels to their individual statistics

        If either average cannot be computed (for example, there are no
        labels or the weights sum to zero), a warning is logged and the
        corresponding field is set to `None`.

        :Parameters:
            stats : mapping
                Maps each label to its per-label statistics.  Each value must
                support ``value[field]`` and expose a ``trues`` attribute,
                which is used as the weight in the micro-average.
            field : hashable
                The name of the statistic to average; used as the lookup key
                into every per-label statistics object.
        """
        super().__init__()
        self.field = field

        # Micro-average: per-label values weighted by each label's `trues`.
        # Deliberately best-effort -- any failure degrades to None.
        try:
            self['micro'] = (
                sum(lstats[field] * lstats.trues
                    for lstats in stats.values()) /
                sum(lstats.trues for lstats in stats.values()))
        except Exception as e:
            logger.warning("Could not generate micro-average of {0}: {1}"
                           .format(field, str(e)))
            self['micro'] = None

        # Macro-average: unweighted mean of the per-label values.
        try:
            self['macro'] = (
                sum(lstats[field] for lstats in stats.values()) /
                len(stats))
        except Exception as e:
            logger.warning("Could not generate macro-average of {0}: {1}"
                           .format(field, str(e)))
            self['macro'] = None

        self['labels'] = {label: lstats[field]
                          for label, lstats in stats.items()}

    def format_str(self, path_tree, ndigits=3, **kwargs):
        """
        Formats the statistic as a header line
        ``"<field> (micro=..., macro=...):"`` followed by the per-label
        table, indented with `util.tab_it_in`.

        :Parameters:
            path_tree : sized collection
                Ignored; a warning is logged when it is non-empty.
            ndigits : `int`
                Passed to `util.round` for every reported value.
            kwargs
                Accepted and ignored.
        """
        if path_tree:
            logger.warning("Ignoring path_tree at {0!r}".format(path_tree))
        formatted = "{0} (micro={1}, macro={2}):\n" \
            .format(self.field,
                    util.round(self['micro'], ndigits=ndigits),
                    util.round(self['macro'], ndigits=ndigits))
        table_str = self.format_label_table(ndigits)
        formatted += util.tab_it_in(table_str)
        return formatted

    def format_json(self, path_tree, ndigits=3, **kwargs):
        """
        Builds a `dict` with the keys ``'micro'``, ``'macro'`` and
        ``'labels'``, each value passed through `util.round`.

        :Parameters:
            path_tree : sized collection
                Ignored; a warning is logged when it is non-empty.
            ndigits : `int`
                Passed to `util.round` for every reported value.
            kwargs
                Accepted and ignored, so that the same keyword arguments can
                be handed to both `format_str` and `format_json`.
        """
        if path_tree:
            logger.warning("Ignoring path_tree at {0!r}".format(path_tree))
        return {
            'micro': util.round(self['micro'], ndigits),
            'macro': util.round(self['macro'], ndigits),
            'labels': {label: util.round(stat, ndigits)
                       for label, stat in self['labels'].items()}
        }

    def format_label_table(self, ndigits):
        """
        Renders the per-label values as a table, choosing the layout from an
        estimate of how wide a one-column-per-label header would be.
        """
        # Each column is estimated at the wider of the label text (+2) and
        # a value with `ndigits` digits (+4).
        column_header_width = sum(max(len(str(label)) + 2, ndigits + 4)
                                  for label in self['labels'])
        if column_header_width < MAX_COLUMNS_WIDTH_CHARS:
            return self.format_column_major_table(ndigits)
        else:
            return self.format_row_major_table(ndigits)

    def format_row_major_table(self, ndigits):
        """
        Renders one ``[label, value]`` row per label with `tabulate`.
        """
        return tabulate(
            [[label, util.round(stat, ndigits)]
             for label, stat in self['labels'].items()])

    def format_column_major_table(self, ndigits):
        """
        Renders a single row of values with the labels as column headers.
        """
        return tabulate(
            [[util.round(stat, ndigits)
              for stat in self['labels'].values()]],
            headers=self['labels'].keys())