-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathformatter_fortran.py
More file actions
116 lines (90 loc) · 3.58 KB
/
formatter_fortran.py
File metadata and controls
116 lines (90 loc) · 3.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from __future__ import annotations
from functools import singledispatchmethod
from formatter_base import Formatter
from chain import (
ONE, Function, Const, ConstName, Var, Sum, Prod, Times, Minus, Frac, Pow,
Sqrt, Inv, Sin, Cos, Arctan, Sgn, Abs, Mod, NamedFunction,
ExplicitFunction, OpaqueFunction, PiecewiseFunction
)
class FortranFormatter(Formatter):
    """Format Functions into simple Fortran-style code/expressions.

    This is intentionally conservative and aims for readable, valid
    Fortran-like expressions (real signatures using real(8)).

    Numeric constants are emitted as double-precision literals (``D``
    exponent letter). Standard Fortran forbids two consecutive operators
    (e.g. ``a*-b`` or ``x**-2``). Any operand that would start with a unary
    minus right after a binary operator is therefore wrapped in parentheses.
    """

    def __init__(self, prefix_der: str = "_d") -> None:
        # Separator placed between a function name and each of its
        # derivative variables when naming derivative functions.
        self._prefix_der = prefix_der

    @property
    def prefix_der(self) -> str:
        """Separator joining a function name with its derivative variables."""
        return self._prefix_der

    @staticmethod
    def _paren_if_negative(text: str) -> str:
        """Return *text* wrapped in parentheses if it begins with a unary minus."""
        return f"({text})" if text.startswith('-') else text

    # Every handler is re-registered here because @singledispatchmethod
    # doesn't work well with inheritance. The dispatcher defined on this
    # class has its own registry, separate from the base class's.
    @singledispatchmethod
    def format(self, f: Function) -> str:
        """Return the Fortran expression text for *f*.

        Types with no handler registered on this class fall back to the
        base class ``Formatter.format``.

        Raises:
            TypeError: if *f* is a PiecewiseFunction (see ``_piecewise``).
        """
        return super().format(f)

    @format.register
    def _piecewise(self, f: PiecewiseFunction) -> str:
        # Fortran doesn't have a conditional operator (ternary operator),
        # so a piecewise function cannot be defined as an expression.
        # This is registered explicitly rather than checked with isinstance
        # in the fallback. Dispatch then picks this handler even if
        # PiecewiseFunction subclasses another type registered below.
        raise TypeError('PiecewiseFunction is not supported in regular Fortran expressions.')

    @format.register
    def _const(self, f: Const) -> str:
        text = str(f.number)
        # Python prints very small or large floats in exponent form
        # (str(1e-05) == '1e-05'). Appending "D0" would then give the
        # invalid literal "1e-05D0". Instead, the exponent letter itself
        # becomes the double-precision "D" (e.g. "1D-05").
        if 'e' in text or 'E' in text:
            return text.replace('e', 'D').replace('E', 'D')
        return f'{text}D0'

    @format.register
    def _const_name(self, f: ConstName) -> str:
        return str(f.name)

    @format.register
    def _var(self, f: Var) -> str:
        return f.name

    @format.register
    def _sum(self, f: Sum) -> str:
        parts = [self.format(x) for x in f.args]
        # Rewrite "a + -b" as "a - b". Nested sums were already normalised
        # by their own recursive call, and only sums emit " + ".
        s = ' + '.join(parts).replace('+ -', '- ')
        return f'({s})'

    @format.register
    def _prod(self, f: Prod) -> str:
        parts = [self.format(x) for x in f.args]
        # Only the leading factor may keep a bare unary minus. A later
        # factor would produce "a*-b", which is non-standard Fortran.
        return '*'.join(parts[:1] + [self._paren_if_negative(p) for p in parts[1:]])

    @format.register
    def _times(self, f: Times) -> str:
        # The coefficient leads, so only the second operand can need parentheses.
        return f"{self.format(f.n)}*{self._paren_if_negative(self.format(f.f))}"

    @format.register
    def _minus(self, f: Minus) -> str:
        # Avoid "--x" when the operand is itself negative.
        return f"-{self._paren_if_negative(self.format(f.f))}"

    @format.register
    def _frac(self, f: Frac) -> str:
        return f"({self.format(f.num)})/({self.format(f.den)})"

    @format.register
    def _pow(self, f: Pow) -> str:
        # Fortran uses ** for power. An integral exponent is emitted as an
        # integer literal ("x**2", not "x**2.0D0"). This reads
        # f.exp.number, so the exponent is assumed to be a Const.
        if f.exp.number.is_integer():
            exponent = str(int(f.exp.number))
        else:
            exponent = self.format(f.exp)
        base = self.format(f.base)
        # "x**-2" is two consecutive operators; emit "x**(-2)" instead.
        return f"({base})**{self._paren_if_negative(exponent)}"

    @format.register
    def _sqrt(self, f: Sqrt) -> str:
        return f"sqrt({self.format(f.f)})"

    @format.register
    def _inv(self, f: Inv) -> str:
        # "1" is an integer literal. Fortran mixed-mode arithmetic promotes
        # it to the kind of a real operand.
        return f"1/({self.format(f.f)})"

    @format.register
    def _sin(self, f: Sin) -> str:
        return f"sin({self.format(f.f)})"

    @format.register
    def _cos(self, f: Cos) -> str:
        return f"cos({self.format(f.f)})"

    @format.register
    def _atan(self, f: Arctan) -> str:
        return f"atan({self.format(f.f)})"

    @format.register
    def _sgn(self, f: Sgn) -> str:
        # SIGN(A, B) returns |A| with the sign of B.
        # NOTE(review): for B == 0, Fortran's SIGN yields +|A| (i.e. 1), not 0.
        # Confirm this matches the intended value of Sgn at zero.
        return f"sign({self.format(ONE)}, {self.format(f.f)})"

    @format.register
    def _abs(self, f: Abs) -> str:
        return f"abs({self.format(f.f)})"

    @format.register
    def _mod(self, f: Mod) -> str:
        # MODULO (not MOD): the Fortran result takes the sign of the divisor.
        return f"modulo({self.format(f.f)}, {self.format(f.n)})"

    def _get_function_str_with_ders(self, f: NamedFunction) -> str:
        """Join the function name and each derivative (as str) with prefix_der."""
        return self.prefix_der.join([f.name, *[str(v) for v in f.ders]])

    @format.register
    def _explicit_function(self, f: ExplicitFunction) -> str:
        # Emitted as a bare identifier with no argument list. Presumably the
        # generated code defines a variable of that name; confirm.
        return self._get_function_str_with_ders(f)

    @format.register
    def _opaque_function(self, f: OpaqueFunction) -> str:
        # Emitted as a bare identifier with no argument list, the same as
        # for explicit functions.
        return self._get_function_str_with_ders(f)