"""gec_metrics.analysis.attributor: attributor classes and lookup by id."""

from .base import AttributorBase
from .add import AttributorAdd
from .sub import AttributorSub
from .shapley import AttributorShapley
from .shapley_sampling import AttributorShapleySampling

# Public API of this module: the attributor classes plus the two lookup
# helpers (get_attributor, get_attributor_ids) defined further down.
__all__ = [
    "AttributorBase",
    "AttributorAdd",
    "AttributorSub",
    "AttributorShapley",
    "AttributorShapleySampling",
    "get_attributor",
    "get_attributor_ids",
]

# Concrete attributors that can be selected by id. AttributorBase is not
# listed, so it gets no id (presumably it is an abstract base — TODO confirm).
CLS = [
    AttributorAdd,
    AttributorSub,
    AttributorShapley,
    AttributorShapleySampling,
]

# Maps id -> class. The id is the lower-cased class name with the substring
# "attributor" removed, e.g. AttributorShapleySampling -> "shapleysampling".
NAME2CLS = {
    c.__name__.lower().replace('attributor', ''): c for c in CLS
}

def get_attributor_ids() -> list:
    """Return the registered attributor ids in sorted order.

    The ids are the keys of ``NAME2CLS`` (lower-cased class names with
    ``attributor`` removed, e.g. ``'add'``). They are the values accepted by
    ``get_attributor``.

    Returns:
        A new, sorted list of id strings.
    """
    # Iterating a dict yields its keys; no intermediate list is needed.
    return sorted(NAME2CLS)
class _UnknownAttributorError(ValueError, AssertionError):
    """Raised by ``get_attributor`` when the requested id is not registered.

    It subclasses ``ValueError``, the conventional type for a bad argument
    value. It also subclasses ``AssertionError``, which the earlier
    ``assert``-based check raised, so existing ``except AssertionError``
    handlers keep working.
    """


def get_attributor(name: str):
    """Return the attributor class registered under ``name``.

    Args:
        name: An id from ``get_attributor_ids()``, e.g. ``'shapley'``.

    Returns:
        The class object (not an instance) mapped to ``name`` in ``NAME2CLS``.

    Raises:
        ValueError: If ``name`` is not a registered id. The raised error is
            also an ``AssertionError``, for backward compatibility.
    """
    try:
        return NAME2CLS[name]
    except (KeyError, TypeError):
        # An explicit raise, unlike ``assert``, is not stripped by ``python -O``.
        # TypeError covers unhashable input, which the old membership test
        # also rejected with an AssertionError.
        raise _UnknownAttributorError(
            f"Unknown attributor {name!r}; choose from {sorted(NAME2CLS)}"
        ) from None