-
Notifications
You must be signed in to change notification settings - Fork 0
/
FEM.py
85 lines (63 loc) · 2.55 KB
/
FEM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
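"""
Finite element solver for the Poisson equation on a 2D rectangular domain.

Given the width and height of the rectangle and the number of element
divisions in each direction, this module assembles the global stiffness
matrix and load vector, applies Dirichlet boundary conditions, and solves
the resulting sparse linear system.

Functions:
- plot_solution(X_plot, Y_plot, xi): Plot the finite element solution as a filled contour map.
- FEM_solver(w, h, nx, ny, source_function): Solve the finite element problem and
  return the solution vector together with the mesh nodes.
"""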
## Raw Finite Element implementation to solve Poisson equation in a 2D rectangular domain.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import mesh as mesh
import basis as c_basis
import basis_baryo as b_basis
import stiffness as stiffness
import force as force
import c_boundaries as boundaries
from scipy.sparse.linalg import spsolve
import final_eval as final_eval
def plot_solution(X_plot, Y_plot, xi, levels=10, cmap='inferno'):
    """
    Plot the finite element solution as a filled contour map.

    Parameters:
    X_plot (ndarray): X coordinates of the plotting grid.
    Y_plot (ndarray): Y coordinates of the plotting grid (same shape as X_plot).
    xi (array_like): Nodal values of the finite element solution. It must
        contain exactly as many entries as X_plot, because it is reshaped
        to X_plot's shape. NOTE(review): this assumes the node ordering
        produced by mesh.generate_mesh matches the grid layout of X_plot;
        confirm against the mesh module.
    levels (int): Number of contour levels passed to contourf (default 10).
    cmap (str): Matplotlib colormap name (default 'inferno').

    Returns:
    None. The figure is displayed with plt.show().

    Raises:
    ValueError: If xi cannot be reshaped to X_plot's shape.
    """
    # np.asarray accepts lists as well as ndarrays. np.shape does the same
    # for the grid, so both are handled uniformly.
    values = np.asarray(xi).reshape(np.shape(X_plot))

    plt.figure()
    plt.contourf(X_plot, Y_plot, values, levels=levels, cmap=cmap)
    plt.colorbar()
    # Equal axis scaling keeps the rectangle's aspect ratio undistorted.
    plt.axis('equal')
    plt.title('FE Solution $u_h(x,y)$')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.show()
def FEM_solver(w, h, nx, ny, source_function):
    """
    Solve the Poisson problem on a w x h rectangle with the finite element method.

    The pipeline is:
    1. Generate the mesh.
    2. Assemble the global stiffness matrix and load vector.
    3. Identify the boundary nodes and apply Dirichlet conditions.
    4. Solve the sparse system with scipy's spsolve.

    Parameters:
    w (float): Width of the domain.
    h (float): Height of the domain.
    nx (int): Number of elements in the x-direction.
    ny (int): Number of elements in the y-direction.
    source_function (callable): Source term of the equation. It is passed to
        force.assemble_load_vector; its expected signature (presumably
        f(x, y)) is defined there, so confirm against the force module.

    Returns:
    tuple: (xi, nodes)
        xi (ndarray): Solution vector returned by spsolve (presumably one
            value per mesh node).
        nodes: Mesh node coordinates, as produced by mesh.generate_mesh.
    """
    # Generate the mesh. The plotting grids (last two return values) are not
    # needed here, so they are discarded.
    nodes, triang_elements, num_nodes, num_elements, _, _ = mesh.generate_mesh(w, h, nx, ny)

    # Assemble the global stiffness matrix.
    A_global = stiffness.calculate_global_stiffness(nodes, triang_elements, num_nodes, num_elements)

    # Assemble the global load vector from the source term.
    F = force.assemble_load_vector(nodes, triang_elements, source_function)

    # Domain bounds used to find boundary nodes: [x_min, x_max, y_min, y_max].
    # NOTE(review): this assumes generate_mesh places the rectangle with its
    # lower-left corner at the origin; confirm against the mesh module.
    mesh_bounds = [0, w, 0, h]
    boundary_nodes = boundaries.identify_b_nodes_by_coord(nodes, mesh_bounds)

    # Apply Dirichlet boundary conditions. No boundary values are passed, so
    # the prescribed values are whatever c_boundaries.apply_dirichlet imposes
    # (presumably homogeneous, u = 0). Confirm against that module.
    A_final, F_final = boundaries.apply_dirichlet(A_global, F, boundary_nodes)

    # Solve the sparse linear system A_final @ xi = F_final.
    xi = spsolve(A_final, F_final)
    return xi, nodes