"""Solve the 2D heat equation using a finite difference explicit scheme"""

# Inspired by https://levelup.gitconnected.com/solving-2d-heat-equation-numerically-using-python-3334004aa01a

import numpy as np
from matplotlib import pyplot as plt
from matplotlib import animation


class Plot:
    """Animated ``imshow`` view of a ``Heat`` simulation's temperature field.

    The colour scale is fixed to [0, 1], the range spanned by the initial
    condition (0) and the boundary value (1) used by ``Heat.initialize``.
    """

    def __init__(self, heat, grid=False):
        """Create the figure, image, colour bar and axis decorations.

        Parameters
        ----------
        heat : Heat
            Simulation to display. Its ``u`` array is drawn, and it is
            advanced (or re-initialized) in place by :meth:`update`.
        grid : bool, optional
            If True, draw white lines between grid cells.
        """
        self.heat = heat
        self.grid = grid
        self.fig = plt.figure(figsize=(10, 8))
        # Attach the axes to *this* figure explicitly rather than relying on
        # pyplot's global "current figure" state.
        self.ax = self.fig.add_subplot()
        self.h = self.ax.imshow(self.heat.u, cmap=plt.cm.jet,
                                vmin=0., vmax=1.)
        self.fig.colorbar(self.h, ax=self.ax)

        # Label with the simulation's actual step rather than a hard-coded 0.
        self.set_title(self.heat.n)
        self.ax.set_xlabel("x")
        self.ax.set_ylabel("y")
        self.ax.axis('scaled')
        if self.grid:
            # Minor ticks on cell edges. imshow maps array rows to the y axis
            # and columns to the x axis, so take the sizes accordingly.
            n_rows, n_cols = self.heat.u.shape
            self.ax.set_xticks(np.arange(-.5, n_cols, 1), minor=True)
            self.ax.set_yticks(np.arange(-.5, n_rows, 1), minor=True)
            self.ax.grid(which='minor', color='w', linestyle='-', linewidth=2)

    def set_title(self, n):
        """Set the axes title to show time step *n*."""
        self.ax.set_title(f"Température à t = {n}")

    def update(self, n, draw=False):
        """Show the temperature field at time step *n*.

        The simulation is advanced with ``one_step`` until it reaches *n*.
        If it is already past *n* (e.g. when an animation is replayed or
        saved a second time), it is re-initialized first. This keeps the
        displayed data consistent with the title.

        Parameters
        ----------
        n : int
            Target time step (the animation frame value).
        draw : bool, optional
            If True, redraw the figure immediately and pause briefly, for
            interactive use outside ``FuncAnimation``.
        """
        if n < self.heat.n:
            # The explicit scheme cannot step backwards: restart from t = 0.
            self.heat.initialize()
        while self.heat.n < n:
            self.heat.one_step()
        self.h.set_data(self.heat.u)
        self.set_title(n)

        if draw:
            self.fig.canvas.draw()
            plt.draw()
            plt.pause(0.01)


class Heat:
    """Explicit finite-difference solver for the 2D heat equation on a square.

    The grid holds ``nx`` x ``nx`` interior points surrounded by a one-cell
    layer of fixed (Dirichlet) boundary values, so ``u`` has shape
    ``(nx + 2, nx + 2)``. ``n`` counts the time steps taken since the last
    :meth:`initialize`.
    """

    def __init__(self, nx, nmax):
        """
        Parameters
        ----------
        nx : int
            Number of interior points per side (the domain is square).
        nmax : int
            Number of time steps covered by ``get_center``, ``plot_center``
            and ``get_animation``.
        """
        self.nx = nx
        self.ny = self.nx   # Square
        self.nmax = nmax
        self.initialize()

    def initialize(self):
        """Reset ``u`` to the initial/boundary conditions and ``n`` to 0."""
        # Initial condition: u(t=0) = 0 in the interior.
        self.u = np.zeros((self.nx + 2, self.ny + 2))

        # Dirichlet boundary conditions u_BC = 1 on all four edges.
        self.u[-1, :] = 1.  # last row
        self.u[:, 0] = 1.   # first column
        self.u[0, :] = 1.   # first row
        self.u[:, -1] = 1.  # last column

        self.n = 0  # time iteration

    def one_step(self):
        """Advance ``u`` by one explicit time step in place and return it.

        Each interior value becomes the mean of its four neighbours from the
        previous step. This equals the FTCS update
        ``u += r * (N + S + E + W - 4 u)`` with ``r = alpha*dt/dx**2 = 1/4``,
        the largest ``r`` for which the 2D explicit scheme is stable.
        Boundary cells are never modified.
        """
        u = self.u
        self.n += 1
        # NumPy evaluates the right-hand side into a temporary before the
        # assignment, so all neighbour values come from the previous step.
        u[1:-1, 1:-1] = 0.25 * (
            u[2:, 1:-1] + u[:-2, 1:-1] +
            u[1:-1, 2:] + u[1:-1, :-2]
        )
        return u

    def get_animation(self, grid=False):
        """Build a ``FuncAnimation`` of the simulation from step 0 to ``nmax``.

        Frames are sampled every ``max(nmax // 100, 1)`` steps, and the final
        step ``nmax`` is always included. The simulation is re-initialized
        first. The figure is closed so that a later ``plt.show()`` does not
        display it; the returned animation can still be saved.

        Parameters
        ----------
        grid : bool, optional
            Forwarded to :class:`Plot` to draw cell borders.
        """
        # The message has no trailing newline, so flush explicitly; otherwise
        # it may only appear after the (slow) rendering has finished.
        print("Generating animation...", end='', flush=True)
        self.initialize()
        plot_anim = Plot(self, grid=grid)
        stride = max(self.nmax // 100, 1)
        frames = list(range(0, self.nmax + 1, stride))
        if frames and frames[-1] != self.nmax:
            # Don't silently drop the final state when nmax % stride != 0.
            frames.append(self.nmax)
        anim = animation.FuncAnimation(
            plot_anim.fig,
            plot_anim.update,
            interval=50,
            blit=False,
            frames=frames,
            repeat=False)
        plt.close(plot_anim.fig)
        return anim

    def get_center(self):
        """Return the centre-cell temperature at steps 0..nmax (inclusive).

        The simulation is re-initialized first and left at step ``nmax``.
        The returned float64 array has ``nmax + 1`` entries.
        """
        self.initialize()
        ix, iy = (size // 2 for size in self.u.shape)
        center = np.empty((self.nmax + 1,), dtype=np.float64)
        for i in range(self.nmax + 1):
            center[i] = self.u[ix, iy]
            # Skip the step after the last sample; it would be wasted work.
            if i < self.nmax:
                self.one_step()
        return center

    def plot_center(self):
        """Plot time evolution of temperature on center point"""
        _, ax = plt.subplots()
        ax.set_title("Température du centre en fonction du temps")
        ax.set_xlabel("t")
        ax.set_ylabel("T")
        # get_center() re-initializes the simulation itself.
        ax.plot(np.arange(self.nmax + 1), self.get_center())
        plt.show()


if __name__ == '__main__':
    # Demo: 7x7 interior grid, 50 explicit time steps.
    heat = Heat(nx=7, nmax=50)
    anim = heat.get_animation(grid=True)
    # 'pillow' ships with Matplotlib (Pillow is a required dependency),
    # whereas 'imagemagick' needs an external ImageMagick installation.
    anim.save("heat2d.gif", dpi=80, writer='pillow')
    print(" done")  # terminate the "Generating animation..." progress line
    heat.plot_center()
