# Running this script in a dedicated conda environment is recommended:
#   conda create -n pymc_env python=3.14 pymc arviz matplotlib
#   # if `conda activate` is not available, you may first need to run
#   # `conda init` or `source /usr/etc/profile.d/conda.sh`
#   conda activate pymc_env
#


import numpy as np
import pymc as pm
import arviz as az
import matplotlib.pyplot as plt
from multiprocessing import freeze_support


def _simulate_data(rng, n_points, slope, intercept, noise_sd):
    """Return ``(x, y)``: noisy samples of ``y = slope * x + intercept``.

    ``x`` holds ``n_points`` evenly spaced values on [0, 10]; the noise is
    Gaussian with standard deviation ``noise_sd`` and is drawn from ``rng``.
    """
    x = np.linspace(0, 10, n_points)
    y = slope * x + intercept + rng.normal(0, noise_sd, size=n_points)
    return x, y


def _sample_posterior(x, y, seed):
    """Build the linear-regression model for ``(x, y)`` and sample it.

    Priors: ``a`` (slope) and ``b`` (intercept) are Normal(0, 5) and the
    noise scale ``sigma`` is HalfCauchy(beta=2). Returns the object produced
    by ``pm.sample`` (requested as InferenceData).
    """
    with pm.Model():
        a = pm.Normal("a", mu=0.0, sigma=5.0)
        b = pm.Normal("b", mu=0.0, sigma=5.0)
        sigma = pm.HalfCauchy("sigma", beta=2.0)

        mu = a * x + b
        pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)

        # Three chains on three cores; the script's __main__ guard is what
        # keeps this safe to re-import in worker processes.
        trace = pm.sample(
            2000,
            tune=1000,
            chains=3,
            cores=3,
            target_accept=0.9,
            random_seed=seed,
            return_inferencedata=True,
        )
    return trace


def _report_posterior(trace):
    """Print the ArviZ summary and posterior means; return ``(a, b)`` draws.

    The returned arrays are the posterior samples of ``a`` and ``b``
    flattened across chains and draws.
    """
    print(az.summary(trace, var_names=["a", "b", "sigma"], round_to=3))

    a_samps = trace.posterior["a"].values.flatten()
    b_samps = trace.posterior["b"].values.flatten()
    sigma_samps = trace.posterior["sigma"].values.flatten()

    print("\nPosterior means:")
    print(f"a     ≈ {a_samps.mean():.3f}")
    print(f"b     ≈ {b_samps.mean():.3f}")
    print(f"sigma ≈ {sigma_samps.mean():.3f}")
    return a_samps, b_samps


def _plot_fit(rng, x, y, a_samps, b_samps, a_true, b_true, n_draws=100):
    """Show the data, the true line, the posterior-mean line and sample lines.

    Up to ``n_draws`` posterior ``(a, b)`` pairs, chosen without replacement
    using ``rng``, are overlaid as faint lines.
    """
    plt.figure(figsize=(8, 5))
    plt.scatter(x, y, label="data")
    plt.plot(x, a_true * x + b_true, linewidth=3, label="true line")
    plt.plot(x, a_samps.mean() * x + b_samps.mean(), linewidth=3, label="posterior mean")

    # Sampling without replacement raises if more draws are requested than
    # exist, so never ask for more than the posterior holds.
    n_overlay = min(n_draws, len(a_samps))
    inds = rng.choice(len(a_samps), size=n_overlay, replace=False)
    for j in inds:
        plt.plot(x, a_samps[j] * x + b_samps[j], alpha=0.12)

    plt.xlabel("x")
    plt.ylabel("y")
    plt.title("Bayesian linear regression (PyMC)")
    plt.legend()
    plt.tight_layout()
    plt.show()


def main():
    """Run the Bayesian linear-regression demo end to end.

    Generates 30 noisy points from the line ``y = 2x - 1``, samples the
    posterior of slope, intercept and noise scale with PyMC, prints a
    summary, then shows a trace plot followed by a plot of the fit. Takes no
    arguments and returns nothing; both ``plt.show()`` calls open a figure.
    """
    # One generator drives both the synthetic noise and the later choice of
    # which posterior draws to plot, so the whole run is reproducible.
    rng = np.random.default_rng(1234)

    # 1. Create synthetic data
    a_true = 2.0
    b_true = -1.0
    sigma_true = 1.0
    x, y = _simulate_data(rng, 30, a_true, b_true, sigma_true)

    # 2. Define Bayesian model and 3. sample posterior
    trace = _sample_posterior(x, y, seed=1234)

    a_samps, b_samps = _report_posterior(trace)

    az.plot_trace(trace, var_names=["a", "b", "sigma"])
    plt.tight_layout()
    plt.show()

    _plot_fit(rng, x, y, a_samps, b_samps, a_true, b_true)


# Script entry point: only run the demo when executed directly, not on import.
if __name__ == "__main__":
    # Needed only when the program is frozen into an executable that uses
    # multiprocessing; it has no effect when run by a normal interpreter.
    freeze_support()
    main()
    
