ぜろといち

ポンコツ理系大学院生による雑多なブログ

カルマンフィルタ実装してみた

カルマンフィルタ実装してみた。

import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy

# Add Gaussian noise
def add_noise(x, mean=0, deviation=1., seed=None):
    """Return ``x`` plus i.i.d. Gaussian noise, sampled element-wise.

    Parameters
    ----------
    x : array_like
        Values to perturb; the noise array has the same shape as ``x``.
    mean : float, default 0
        Mean of the Gaussian noise.
    deviation : float, default 1.0
        Standard deviation of the Gaussian noise (not the variance).
    seed : int or None, default None
        If given, NumPy's global legacy RNG is re-seeded with this value
        before sampling. This makes the result reproducible, and as a side
        effect it changes the global RNG state.

    Returns
    -------
    numpy.ndarray
        ``x + noise``.
    """
    # The default is a literal: a module-level name used here would have to
    # exist before this ``def`` is executed, which would raise NameError otherwise.
    if seed is not None:
        np.random.seed(seed=seed)

    x = np.asarray(x)
    noise = np.random.normal(loc=mean, scale=deviation, size=x.shape)
    return x + noise

# Measurement time interval
dt = 1.
n_steps = 100  # number of simulated time steps

# Actual positions and velocity
#
# NOTE(review): the true velocity is 1 + N(0, v_**2), drawn independently at
# every step, while the filter below assumes a nearly constant velocity (small Q).
# This model mismatch likely limits how well the estimates can track the truth.
v_ = 10.  # standard deviation passed to add_noise when generating true_velocity

true_velocity = add_noise(np.ones((n_steps, 1), dtype=float), deviation=v_, seed=0)
# Integrate velocity into positions; axis=0 keeps the (n_steps, 1) column shape.
true_positions = np.cumsum(true_velocity * dt, axis=0)

# True state trajectory: column 0 = position, column 1 = velocity.
true_x = np.concatenate((true_positions, true_velocity), axis=1)

# Observation

R = 1.  # variance of the measurement noise
# add_noise expects a standard deviation, hence sqrt(R). The seed differs from
# the one used for true_velocity: with the same seed, the measurement noise
# would be an exact rescaling of the velocity perturbation, not independent of it.
observed_x = add_noise(true_positions, deviation=np.sqrt(R), seed=1)

# Initialization

z = observed_x.copy()
x_init = true_x[0]
P_init = np.zeros((2, 2))
A = np.array([[1., dt], [0., 1.]])  # constant-velocity state transition
# Process noise covariance as a 2x2 diagonal matrix. A plain list [0.1, 0.1]
# would be broadcast row-wise onto P, adding 0.1 to every entry.
Q = np.diag([0.1, 0.1])
H = np.array([1., 0.])  # observation matrix: only the position is measured

#Kalman Filter
def kalman_filter(z, x_init, P_init, Q, R, *, transition=None, observation=None):
    """Run a linear Kalman filter with one scalar measurement per time step.

    The first returned estimate is ``(x_init, P_init)`` itself. Measurements
    ``z[1:]`` are then assimilated one step at a time; ``z[0]`` is not used.

    Parameters
    ----------
    z : array_like, shape (T,) or (T, 1)
        Measurements, one scalar observation per time step.
    x_init : array_like, shape (n,)
        Initial state estimate.
    P_init : array_like, shape (n, n)
        Initial state covariance.
    Q : scalar, array_like of shape (n,), or array_like of shape (n, n)
        Process noise covariance. A scalar ``q`` means ``q * I``; a 1-D
        array is interpreted as the diagonal of the covariance.
    R : float
        Measurement noise variance.
    transition : array_like, shape (n, n), optional
        State transition matrix. Defaults to the module-level ``A``.
    observation : array_like, shape (n,), optional
        Row vector mapping a state to the scalar measurement. Defaults to
        the module-level ``H``.

    Returns
    -------
    x_filtered : numpy.ndarray, shape (max(T, 1), n)
        ``x_filtered[t]`` is the state estimate after assimilating ``z[t]``.
    P_filtered : numpy.ndarray, shape (max(T, 1), n, n)
        Matching posterior covariances.
    """
    # The conditional expressions only read the module globals when no
    # explicit matrix is given.
    A_ = np.asarray(A if transition is None else transition, dtype=float)
    H_ = np.asarray(H if observation is None else observation, dtype=float)

    x = np.asarray(x_init, dtype=float)
    P = np.asarray(P_init, dtype=float)
    I = np.identity(len(x))

    Q = np.asarray(Q, dtype=float)
    if Q.ndim == 0:
        Q = Q * I
    elif Q.ndim == 1:
        Q = np.diag(Q)

    # Collect estimates in lists; calling np.append in the loop would copy the
    # whole history on every step.
    xs = [x]
    Ps = [P]
    for z_t in np.asarray(z, dtype=float)[1:]:
        # Time update (predict): x = A x,  P = A P A^T + Q
        x = A_.dot(x)
        P = A_.dot(P).dot(A_.T) + Q

        # Measurement update (correct)
        S = H_.dot(P).dot(H_) + R          # innovation variance (scalar)
        K = P.dot(H_) / S                  # Kalman gain, shape (n,)
        innovation = np.squeeze(z_t) - H_.dot(x)
        x = x + K * innovation
        # K and H are 1-D, so the gain matrix K H needs an outer product;
        # np.dot would return their inner product (a scalar).
        P = (I - np.outer(K, H_)).dot(P)

        xs.append(x)
        Ps.append(P)

    return np.array(xs), np.array(Ps)

# Run the filter; only the state estimates are plotted, so the covariances are discarded.
filtered_x, _ = kalman_filter(z, x_init, P_init, Q=Q, R=R)

# Compare the true, measured and estimated trajectories on a single chart.
plt.plot(true_x[:, 0], 'r--', label='Actual Positions')
plt.plot(true_x[:, 1], 'b--', label='Actual Velocity')
plt.plot(observed_x[:, 0], 'y', label='Measured Positions')
plt.plot(filtered_x[:, 0], 'gray', label='Estimated Positions')
plt.plot(filtered_x[:, 1], 'g', label='Estimated Velocity')
plt.title('Positions of tolloco')
plt.xlabel('time step')
plt.ylabel('position / velocity')  # both quantities share this axis
plt.legend(loc='best')
plt.show()  # needed to display the figure when run as a plain script

描画結果

[f:id:ray2480:20181118205016p:plain]

思ったほどうまく推定できていない。説明など色々と追記が必要なので、後ほど更新する。