# Source code for openpnm.solvers._base
from numpy.linalg import norm
__all__ = ['BaseSolver', 'DirectSolver', 'IterativeSolver']
class BaseSolver:
    """
    Root class shared by every solver.

    It carries no state; concrete solvers are expected to override
    :meth:`solve`, which raises ``NotImplementedError`` here.
    """

    def __init__(self):
        pass

    def solve(self):
        """
        Solve the linear system of equations ``Ax = b``.

        Raises
        ------
        NotImplementedError
            Always, since this base class provides no solving logic.
        """
        raise NotImplementedError()
class DirectSolver(BaseSolver):
    """
    Base class for all direct (non-iterative) solvers.

    Adds no behavior of its own on top of :class:`BaseSolver`.
    """
class IterativeSolver(BaseSolver):
    """
    Base class for iterative solvers.

    Parameters
    ----------
    tol : float, optional
        Tolerance relative to ``norm(b)``; it is scaled into an absolute
        tolerance by :meth:`_get_atol`. Default is ``1e-8``.
    maxiter : int, optional
        Iteration cap, stored as ``self.maxiter`` (presumably consumed by
        concrete subclasses). Default is ``1000``.
    """

    def __init__(self, tol=1e-8, maxiter=1000):
        super().__init__()
        self.tol = tol
        self.maxiter = maxiter
        # Both depend on the system being solved (``b`` for atol; ``A``,
        # ``b`` and ``x0`` for rtol), so they are evaluated later.
        self.atol = None
        self.rtol = None

    def _get_atol(self, b):
        r"""
        Returns the absolute tolerance ``atol`` that corresponds to the
        given tolerance ``tol``.

        Notes
        -----
        ``atol`` is calculated to satisfy the following stopping criterion:
        ``norm(A @ x - b)`` <= ``atol``, where ``atol = norm(b) * tol``.
        """
        return norm(b) * self.tol

    def _get_rtol(self, A, b, x0):
        r"""
        Returns the relative tolerance ``rtol`` that corresponds to the
        given tolerance ``tol``.

        Notes
        -----
        ``rtol`` is defined based on the following formula:
        ``rtol = residual(@x_final) / residual(@x0)``, where the final
        residual is taken to be ``atol`` (see :meth:`_get_atol`).
        """
        res0 = self._get_residual(A, b, x0)
        atol = self._get_atol(b)
        # NOTE(review): if x0 already solves the system (res0 == 0) this
        # division does not produce a finite rtol; confirm how downstream
        # solvers should treat an already-converged initial guess.
        return atol / res0

    def _get_residual(self, A, b, x):
        r"""
        Calculates the residual based on the given ``x`` using:
        ``res = norm(A @ x - b)``

        ``@`` (matrix product) is used instead of ``*``: for dense
        ``ndarray`` and scipy sparse *array* inputs ``*`` is element-wise,
        while for scipy sparse *matrix* inputs both operators agree.
        """
        return norm(A @ x - b)