# Copyright (c) 2020, The InferLO authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 - see LICENSE file.
from __future__ import annotations
import math
from typing import List, Callable, Any, TYPE_CHECKING
from inferlo.base.factors.factor import Factor
if TYPE_CHECKING:
from inferlo.base.graph_model import GraphModel
[docs]
class FunctionFactor(Factor):
    """A factor given explicitly by function.

    The factor stores a Python callable in ``func``. Arithmetic operators
    and the helper methods below build new ``FunctionFactor`` objects whose
    callables wrap the callables of their operands.
    """

    def __init__(self, model: GraphModel,
                 var_idx: List[int],
                 func: Callable[[List[float]], float]):
        """Create function factor.

        :param model: Graphical model this factor belongs to.
        :param var_idx: Indices of variables in the model, on which this
          factor depends.
        :param func: Function of this factor (as Python callable). It is
          called with a single list of values.
        """
        super().__init__(model, var_idx)
        self.func = func

    def value(self, values: List[float]) -> float:
        """Evaluate the factor by passing ``values`` to ``func``.

        :param values: Argument list forwarded unchanged to ``func``.
        :return: Whatever ``func`` returns for ``values``.
        """
        return self.func(values)

    # Numeric operations on factors.

    @staticmethod
    def combine_factors(factor1: FunctionFactor, factor2: FunctionFactor,
                        func: Callable[[float, float], float]
                        ) -> FunctionFactor:
        """Returns a factor which is a function of other 2 factors.

        The new factor depends on the union of the variables of both
        factors and evaluates to
        ``func(factor1.func(args1), factor2.func(args2))``, where ``args1``
        and ``args2`` are the values picked from the combined argument list
        for the variables of the respective factor.

        :param factor1: First factor.
        :param factor2: Second factor. Must belong to the same model.
        :param func: Binary function applied to the values of two factors.
        :return: New factor over the union of variables, with variable
          indices in ascending order.
        """
        assert factor1.model == factor2.model, (
            'Cannot combine factors from different models.')

        # Variable indices of the new factor. Sorted because iteration order
        # of a set is unspecified, and the order of ``var_idx`` defines the
        # order of arguments of the new factor.
        new_idx = sorted(set(factor1.var_idx).union(factor2.var_idx))

        # Maps variable index to its position in the new factor.
        new_idx_rev = {var: pos for pos, var in enumerate(new_idx)}

        # Maps position in args list of first factor to position of the same
        # variable in the new factor.
        idx1 = [new_idx_rev[i] for i in factor1.var_idx]

        # Maps position in args list of second factor to position of the same
        # variable in the new factor.
        idx2 = [new_idx_rev[i] for i in factor2.var_idx]

        def new_func(all_args: List[float]) -> float:
            first_factor_args = [all_args[i] for i in idx1]
            second_factor_args = [all_args[i] for i in idx2]
            return func(factor1.func(first_factor_args),
                        factor2.func(second_factor_args))

        return FunctionFactor(factor1.model, new_idx, new_func)

    def apply_function(self, func: Callable[[float], float]
                       ) -> FunctionFactor:
        """Returns factor func(g(x)), where g(x) is given factor.

        :param func: Unary function applied to the value of this factor.
        :return: New factor on the same model and the same variables.
        """
        return FunctionFactor(self.model, self.var_idx,
                              lambda x: func(self.func(x)))

    def combine_with(self, other: Any,
                     func: Callable[[float, float], float]
                     ) -> FunctionFactor:
        """Returns factor func(g(x), other), where g(x) is given factor.

        :param other: Either a number (``int`` or ``float``) or another
          ``FunctionFactor`` (including instances of its subclasses).
        :param func: Binary function; this factor's value is always passed
          as the first argument and ``other`` (or its value) as the second.
        :return: New factor representing the combination.
        :raises TypeError: If ``other`` has any other type.
        """
        if isinstance(other, (int, float)):
            return self.apply_function(lambda x: func(x, other))
        if isinstance(other, FunctionFactor):
            return FunctionFactor.combine_factors(self, other, func)
        raise TypeError(
            'Cannot combine FunctionFactor with %s' % type(other))

    def __add__(self, other: Any) -> FunctionFactor:
        return self.combine_with(other, lambda x, y: x + y)

    def __radd__(self, other: Any) -> FunctionFactor:
        return self.combine_with(other, lambda x, y: x + y)

    def __sub__(self, other: Any) -> FunctionFactor:
        return self.combine_with(other, lambda x, y: x - y)

    # In reflected operators ``x`` is still this factor's value and ``y`` is
    # the left operand, so the arguments are swapped inside the lambda.
    def __rsub__(self, other: Any) -> FunctionFactor:
        return self.combine_with(other, lambda x, y: y - x)

    def __mul__(self, other: Any) -> FunctionFactor:
        return self.combine_with(other, lambda x, y: x * y)

    def __rmul__(self, other: Any) -> FunctionFactor:
        return self.combine_with(other, lambda x, y: x * y)

    def __truediv__(self, other: Any) -> FunctionFactor:
        return self.combine_with(other, lambda x, y: x / y)

    def __rtruediv__(self, other: Any) -> FunctionFactor:
        return self.combine_with(other, lambda x, y: y / x)

    def __pow__(self, other: Any) -> FunctionFactor:
        return self.combine_with(other, lambda x, y: x ** y)

    def __rpow__(self, other: Any) -> FunctionFactor:
        return self.combine_with(other, lambda x, y: y ** x)

    def __neg__(self) -> FunctionFactor:
        return self.apply_function(lambda x: -x)

    def __abs__(self) -> FunctionFactor:
        return self.apply_function(abs)

    def exp(self) -> FunctionFactor:
        """Exponent of this factor."""
        return self.apply_function(math.exp)

    def sin(self) -> FunctionFactor:
        """Sine of this factor."""
        return self.apply_function(math.sin)

    def cos(self) -> FunctionFactor:
        """Cosine of this factor."""
        return self.apply_function(math.cos)