63 lines
1.7 KiB
Python
63 lines
1.7 KiB
Python
# pylint:disable=invalid-name
|
|
import os
import random

import numpy as np
import matplotlib.pyplot as plt
|
|
|
|
# Number of proposal draws per run; also the number of points resampled
# from them and part of the saved histogram's file name.
SAMPLING_M: int = 1_000

# Proposal selector used by the entry point:
# True -> Uniform(0, 15) proposal, False -> Normal(mean 5, std 4) proposal.
UNIFORM: bool = False
|
|
|
|
def normal(x, mean, std):
    """Evaluate the Gaussian probability density with the given mean and std at ``x``.

    Only arithmetic operators and NumPy ufuncs are used, so ``x`` may be a
    scalar or a NumPy array (evaluated element-wise).
    """
    squared_dev = (x - mean) ** 2
    spread = 2 * std ** 2
    normaliser = np.sqrt(2 * np.pi) * std
    return np.exp(-(squared_dev / spread)) / normaliser
|
|
|
|
def p(x):
    """Target density: a three-component Gaussian mixture evaluated at ``x``.

    Components as (weight, mean, std): (0.3, 2, 1), (0.4, 5, 2), (0.3, 9, 1).
    The weights sum to 1, so the mixture integrates to 1.
    """
    total = 0.3 * normal(x, 2.0, 1.0)
    # Remaining components, added in the same left-to-right order.
    for weight, centre, spread in ((0.4, 5.0, 2.0), (0.3, 9.0, 1.0)):
        total = total + weight * normal(x, centre, spread)
    return total
|
|
|
|
def gen_weights_uniform(samples, *, low=0.0, high=15.0):
    """Return importance weights p(x) / q(x) for a Uniform(low, high) proposal.

    Args:
        samples: iterable of points drawn from the uniform proposal.
        low, high: support of the proposal. The defaults match the
            ``np.random.uniform(0, 15, ...)`` draw in ``uniform_sampling``.

    Returns:
        A list with one weight per sample, in input order.

    Raises:
        ValueError: if ``high <= low`` (the proposal would have no density).
    """
    width = high - low
    if width <= 0:
        raise ValueError(f"empty proposal support: low={low}, high={high}")
    # The uniform proposal density is constant across its support.
    proposal_density = 1.0 / width
    return [p(x) / proposal_density for x in samples]
|
|
|
|
def gen_weights_normal(samples, *, mean=5.0, std=4.0):
    """Return importance weights p(x) / q(x) for a Normal(mean, std) proposal.

    Args:
        samples: iterable of points drawn from the normal proposal.
        mean, std: proposal parameters. The defaults match the
            ``np.random.normal(5, 4, ...)`` draw in ``normal_sampling``.

    Returns:
        A list with one weight per sample, in input order.

    Raises:
        ValueError: if ``std <= 0`` (the proposal density is undefined).
    """
    if std <= 0:
        raise ValueError(f"proposal std must be positive, got {std}")
    return [p(x) / normal(x, mean, std) for x in samples]
|
|
|
|
def make_histogram(samples, new_samples):
    """Plot proposal and resampled histograms side by side and save them as PNG.

    Left panel (blue): ``samples``, the raw proposal draws. Right panel (red):
    ``new_samples``, the resampled points, with the target density ``p``
    overlaid in black. Both histograms use 25 density-normalised bins over
    [-5, 20] and share the y axis.

    The image is written to ``histograms/<proposal>-<SAMPLING_M>.png``, where
    <proposal> is "uniform" when UNIFORM is true and "normal" otherwise.
    """
    hist_range = (-5, 20)
    fig, axs = plt.subplots(1, 2, sharey=True, tight_layout=True)
    axs[0].hist(samples, color=(0, 0, 1), bins=25, range=hist_range, density=True)
    axs[1].hist(new_samples, color=(1, 0, 0), bins=25, range=hist_range, density=True)

    # Overlay the target density across the whole histogram range. The
    # previous grid stopped at 15.9 while the bins extend to 20.
    x_points = np.linspace(hist_range[0], hist_range[1], 251)
    axs[1].plot(x_points, p(x_points), color=(0, 0, 0))

    proposal = "uniform" if UNIFORM else "normal"
    # savefig does not create missing directories.
    os.makedirs("histograms", exist_ok=True)
    fig.savefig(f"histograms/{proposal}-{SAMPLING_M}.png")
    # Release the figure; pyplot otherwise keeps every figure it created open.
    plt.close(fig)
|
|
|
|
def uniform_sampling():
    """Run sampling-importance-resampling with a Uniform(0, 15) proposal.

    Draws SAMPLING_M proposal points and weights them with
    ``gen_weights_uniform``. It then resamples SAMPLING_M points with
    replacement, in proportion to those weights, and passes both sets to
    ``make_histogram``.
    """
    # Step 1: draw from the proposal.
    proposal = np.random.uniform(low=0, high=15, size=SAMPLING_M)
    # Step 2: importance weight for each proposal point.
    importance = gen_weights_uniform(proposal)
    # Step 3: weighted resampling with replacement.
    resampled = random.choices(proposal, weights=importance, k=SAMPLING_M)
    make_histogram(proposal, resampled)
|
|
|
|
def normal_sampling():
    """Run sampling-importance-resampling with a Normal(mean 5, std 4) proposal.

    Draws SAMPLING_M proposal points and weights them with
    ``gen_weights_normal``. It then resamples SAMPLING_M points with
    replacement, in proportion to those weights, and passes both sets to
    ``make_histogram``.
    """
    # Step 1: draw from the proposal.
    proposal = np.random.normal(loc=5, scale=4, size=SAMPLING_M)
    # Step 2: importance weight for each proposal point.
    importance = gen_weights_normal(proposal)
    # Step 3: weighted resampling with replacement.
    resampled = random.choices(proposal, weights=importance, k=SAMPLING_M)
    make_histogram(proposal, resampled)
|
|
|
|
|
|
if __name__ == "__main__":
    # Select the proposal-specific pipeline from the UNIFORM flag, then run it.
    run_experiment = uniform_sampling if UNIFORM else normal_sampling
    run_experiment()
|