From 2cca7cbf07d20fb8ab5677c3e9e85e9420989fa9 Mon Sep 17 00:00:00 2001 From: Tim Date: Wed, 16 Oct 2024 13:48:12 +0200 Subject: [PATCH] added optional to set custom residual function in base filter, mainly for angular states/measurements --- filterpy2/kalman/kalman_filter.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/filterpy2/kalman/kalman_filter.py b/filterpy2/kalman/kalman_filter.py index 48a291ec..2820a78d 100644 --- a/filterpy2/kalman/kalman_filter.py +++ b/filterpy2/kalman/kalman_filter.py @@ -123,8 +123,9 @@ from copy import deepcopy from math import log, exp, sqrt import sys +from typing import Callable import numpy as np -from numpy import dot, zeros, eye, isscalar, shape +from numpy import dot, zeros, eye, isscalar, shape, subtract, typing as npt import numpy.linalg as linalg from filterpy2.stats import logpdf from filterpy2.common import pretty_str, reshape_z @@ -384,7 +385,7 @@ class KalmanFilter(object): """ - def __init__(self, dim_x, dim_z, dim_u=0): + def __init__(self, dim_x: int, dim_z: int, dim_u: int=0, residual_z_fn: Callable[[npt.NDArray], npt.NDArray]=subtract): if dim_x < 1: raise ValueError("dim_x must be 1 or greater") if dim_z < 1: @@ -406,6 +407,7 @@ def __init__(self, dim_x, dim_z, dim_u=0): self._alpha_sq = 1.0 # fading memory control self.M = np.zeros((dim_x, dim_z)) # process-measurement cross correlation self.z = np.array([[None] * self.dim_z]).T + self.residual_z_fn = residual_z_fn # gain and residual are computed during the innovation step. We # save them so that in case you want to inspect them for various @@ -527,7 +529,7 @@ def update(self, z, R=None, H=None): # y = z - Hx # error (residual) between measurement and prediction - self.y = z - dot(H, self.x) + self.y = self.residual_z_fn(z, dot(H, self.x)) # common subexpression for speed PHT = dot(self.P, H.T)