Skip to content

Commit

Permalink
Made ggev functions as non-scalar functions and support for lapack-sr…
Browse files Browse the repository at this point in the history
…c v0.13 (#28)

* Made ggev functions as non-scalar functions and fixed the issue with lsame

Arguments are now in lower case

alpha in larfg is corrected as a scalar

* Unify the rules in is_scalar

* Remove redundant formatting

* Bump the version number

* Exclude lsame

* Remove a redundancy

* Update lapack-sys

* Regenerate the functions

* Refactor the generator

* Unify the rules in is_scalar

* Fix vl and vr

* Fix k+

* Fix dif

Co-authored-by: Ivan Ukhov <ivan.ukhov@gmail.com>
  • Loading branch information
selvavm and IvanUkhov authored Apr 18, 2021
1 parent 4d31631 commit 0276f89
Show file tree
Hide file tree
Showing 6 changed files with 27,257 additions and 21,315 deletions.
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
[package]
name = "lapack"
version = "0.18.0"
version = "0.19.0"
license = "Apache-2.0/MIT"
authors = [
"Andrew Straw <strawman@astraw.com>",
"Crozet Sébastien <developer@crozet.re>",
"David Greenberg <dsg123456789@gmail.com>",
"Ivan Ukhov <ivan.ukhov@gmail.com>",
"Pavel Potocek <pavelpotocek@gmail.com>",
"Selvavignesh Vedamanickam <selvavm@hotmail.com>",
"Toshiki Teramura <toshiki.teramura@gmail.com>",
]
description = "The package provides wrappers for LAPACK (Fortran)."
Expand All @@ -26,5 +27,5 @@ version = "0.4"
default-features = false

[dependencies.lapack-sys]
version = "0.12"
version = "0.14"
default-features = false
6 changes: 2 additions & 4 deletions bin/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
return_re = re.compile('(?:\s*->\s*([^;]+))?')


class Function(object):
class Function():

def __init__(self, name, args, ret):
self.name = name
Expand All @@ -25,8 +25,6 @@ def parse(line):
arg, aty, line = pull_argument(line)
if arg is None:
break
if arg == 'matrix_layout':
arg = 'layout'
args.append((arg, aty))
line = line.strip()

Expand Down Expand Up @@ -55,7 +53,7 @@ def pull_return(s):
return match.group(1), s[match.end(1):]


def read_functions(path):
def read(path):
lines = []
with open(path) as file:
append = False
Expand Down
164 changes: 90 additions & 74 deletions bin/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import re

from function import Function
from function import read_functions
from function import read

select_re = re.compile('LAPACK_(\w)_SELECT(\d)')


def is_scalar(name, cty, f):
return (
return ( \
'c_char' in cty or
name in [
'abnrm',
Expand All @@ -20,7 +20,6 @@ def is_scalar(name, cty, f):
'anorm',
'bbnrm',
'colcnd',
'dif',
'ihi',
'il',
'ilo',
Expand All @@ -45,29 +44,80 @@ def is_scalar(name, cty, f):
'tryrac',
'vu',
] or
name == 'q' and 'lapack_int' in cty or
not (
name in [
'alpha',
] and (
'larfg' in f.name
) or
name in [
'dif',
] and not (
'tgsen' in f.name or
'tgsna' in f.name
) or
name in [
'p',
] and not (
'tgevc' in f.name
) or
name in [
'q'
] and (
'lapack_int' in cty
) or
name in [
'vl',
'vr',
] and not (
'geev' in f.name or
'ggev' in f.name or
'hsein' in f.name or
'tgevc' in f.name or
'tgsna' in f.name or
'trevc' in f.name or
'trsna' in f.name
) and name in [
'vl',
'vr',
] or
not ('tgevc' in f.name) and name in [
'p',
] or
name.startswith('alpha') or
name.startswith('beta') or
) or
name.startswith('k') and not (
'lapmr' in f.name or
'lapmt' in f.name
) or
name.startswith('inc') or
name.startswith('k') or
name.startswith('ld') or
name.startswith('tol') or
name.startswith('vers')
)


def translate_argument(name, cty, f):
def translate_name(name):
return name.lower()


def translate_base_type(cty):
cty = cty.replace('__BindgenComplex<f32>', 'lapack_complex_float')
cty = cty.replace('__BindgenComplex<f64>', 'lapack_complex_double')
cty = cty.replace('lapack_float_return', 'c_float')
cty = cty.replace('f32', 'c_float')
cty = cty.replace('f64', 'c_double')

if 'c_char' in cty:
return 'u8'
elif 'c_int' in cty:
return 'i32'
elif 'c_float' in cty:
return 'f32'
elif 'c_double' in cty:
return 'f64'
elif 'lapack_complex_float' in cty:
return 'c32'
elif 'lapack_complex_double' in cty:
return 'c64'
elif 'size_t' in cty:
return 'size_t'

assert False, 'cannot translate `{}`'.format(cty)


def translate_signature_type(name, cty, f):
m = select_re.match(cty)
if m is not None:
if m.group(1) == 'S':
Expand All @@ -79,7 +129,7 @@ def translate_argument(name, cty, f):
elif m.group(1) == 'Z':
return 'Select{}C64'.format(m.group(2))

base = translate_type_base(cty)
base = translate_base_type(cty)
if '*const' in cty:
if is_scalar(name, cty, f):
return base
Expand All @@ -94,30 +144,6 @@ def translate_argument(name, cty, f):
return base


def translate_type_base(cty):
cty = cty.replace('__BindgenComplex<f32>', 'lapack_complex_float')
cty = cty.replace('__BindgenComplex<f64>', 'lapack_complex_double')
cty = cty.replace('f32', 'c_float')
cty = cty.replace('f64', 'c_double')

if 'c_char' in cty:
return 'u8'
elif 'c_int' in cty:
return 'i32'
elif 'c_float' in cty:
return 'f32'
elif 'c_double' in cty:
return 'f64'
elif 'lapack_complex_float' in cty:
return 'c32'
elif 'lapack_complex_double' in cty:
return 'c64'
elif 'size_t' in cty:
return 'libc::c_ulong'

assert False, 'cannot translate `{}`'.format(cty)


def translate_body_argument(name, rty):
if rty.startswith('Select'):
return 'transmute({})'.format(name)
Expand Down Expand Up @@ -154,66 +180,56 @@ def translate_body_argument(name, rty):
elif rty.startswith('&mut [c'):
return '{}.as_mut_ptr() as *mut _'.format(name)

elif rty.startswith('libc::'):
return '&{}'.format(name)
elif rty == 'size_t':
return name

assert False, 'cannot translate `{}: {}`'.format(name, rty)


def translate_return_type(cty):
cty = cty.replace('lapack_float_return', 'c_float')
cty = cty.replace('f64', 'c_double')

if cty == 'c_int':
return 'i32'
elif cty == 'c_float':
return 'f32'
elif cty == 'c_double':
return 'f64'

assert False, 'cannot translate `{}`'.format(cty)


def format_header(f):
args = format_header_arguments(f)
def format_signature(f):
args = format_signature_arguments(f)
if f.ret is None:
return 'pub unsafe fn {}({})'.format(f.name, args)
else:
return 'pub unsafe fn {}({}) -> {}'.format(f.name, args,
translate_return_type(f.ret))


def format_body(f):
return 'ffi::{}_({})'.format(f.name, format_body_arguments(f))
translate_base_type(f.ret))


def format_header_arguments(f):
def format_signature_arguments(f):
s = []
for arg in f.args:
s.append('{}: {}'.format(arg[0], translate_argument(*arg, f=f)))
for name, cty in f.args:
name = translate_name(name)
s.append('{}: {}'.format(name, translate_signature_type(name, cty, f)))
return ', '.join(s)


def format_body(f):
return 'ffi::{}_({})'.format(f.name, format_body_arguments(f))


def format_body_arguments(f):
s = []
for arg in f.args:
rty = translate_argument(*arg, f=f)
s.append(translate_body_argument(arg[0], rty))
for name, cty in f.args:
name = translate_name(name)
rty = translate_signature_type(name, cty, f)
s.append(translate_body_argument(name, rty))
return ', '.join(s)


def prepare(code):
def process(code):
lines = filter(lambda line: not re.match(r'^\s*//.*', line),
code.split('\n'))
lines = re.sub(r'\s+', ' ', ''.join(lines)).strip().split(';')
lines = filter(lambda line: not re.match(r'^\s*$', line), lines)
return [Function.parse(line) for line in lines]


def do(functions):
def write(functions):
for f in functions:
if f.name in ['lsame']:
continue
print('\n#[inline]')
print(format_header(f) + ' {')
print(format_signature(f) + ' {')
print(' ' + format_body(f) + '\n}')


Expand All @@ -222,4 +238,4 @@ def do(functions):
parser.add_argument('--sys', default='lapack-sys')
arguments = parser.parse_args()
path = os.path.join(arguments.sys, 'src', 'lapack.rs')
do(prepare(read_functions(path)))
write(process(read(path)))
2 changes: 1 addition & 1 deletion lapack-sys
Loading

0 comments on commit 0276f89

Please sign in to comment.