forked from mk-fg/fgtk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pam-run
executable file
·312 lines (263 loc) · 10.7 KB
/
pam-run
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
#!/usr/bin/env python3
import itertools as it, operator as op, functools as ft
import os, sys, errno, ctypes as c, enum, logging, signal, contextlib, pwd, grp
class LogMessage:
    """Log message that is formatted lazily with ``str.format``.

    The format string is only expanded when ``str()`` is called on the
    instance, which the logging module does when a record is emitted.
    """

    def __init__(self, fmt, a, k):
        # fmt: format string; a: positional format args; k: keyword format args
        self.fmt, self.a, self.k = fmt, a, k

    def __str__(self):
        # Without any arguments the text is returned verbatim, so messages
        # that were formatted by the caller may contain literal braces.
        return self.fmt.format(*self.a, **self.k) if self.a or self.k else self.fmt


class LogStyleAdapter(logging.LoggerAdapter):
    """Logger adapter using ``{}``-style placeholders instead of ``%``-style."""

    def __init__(self, logger, extra=None):
        super().__init__(logger, extra or {})

    def log(self, level, msg, *args, **kws):
        """Log *msg* at *level*, formatting it with *args* and *kws*.

        ``exc_info`` is passed through to the underlying logger; all
        other keyword arguments are used as ``str.format`` fields.
        """
        if not self.isEnabledFor(level):
            return
        log_kws = {} if 'exc_info' not in kws else dict(exc_info=kws.pop('exc_info'))
        msg, kws = self.process(msg, kws)
        # process() stores the adapter's "extra" dict in kws. It is not a
        # format field: leaving it there would make every message look like
        # it has arguments and break on literal braces in argument-less ones.
        log_kws['extra'] = kws.pop('extra', None)
        self.logger._log(level, LogMessage(msg, args, kws), (), **log_kws)


def get_logger(name):
    """Return a LogStyleAdapter wrapping the stdlib logger called *name*."""
    return LogStyleAdapter(logging.getLogger(name))
# ctypes declarations used for libpam calls.
# NOTE(review): field layouts presumably mirror the structs declared in
# libpam's public headers - verify against the installed headers.

class pam_handle_t(c.Structure):
    # Single pointer-sized field holding the handle value.
    _fields_ = [
        ('handle', c.c_void_p),
    ]


class pam_message_t(c.Structure):
    _fields_ = [
        ('msg_style', c.c_int),
        ('msg', c.c_char_p),
    ]


class pam_response_t(c.Structure):
    # "resp" is declared as a raw pointer rather than c_char_p, so that
    # an address of a separately allocated buffer can be stored in it.
    _fields_ = [
        ('resp', c.c_void_p),
        ('resp_retcode', c.c_int),
    ]


# Callback prototype: int func(int, message**, response**, void*)
pam_conv_func_t = c.CFUNCTYPE(
    c.c_int,
    c.c_int,
    c.POINTER(c.POINTER(pam_message_t)),
    c.POINTER(c.POINTER(pam_response_t)),
    c.c_void_p,
)


class pam_conv_t(c.Structure):
    # Callback plus an opaque pointer field stored alongside it.
    _fields_ = [
        ('conv', pam_conv_func_t),
        ('appdata_ptr', c.c_void_p),
    ]
@enum.unique
class pam_t(enum.IntEnum):
    """Integer constants passed to and compared against libpam calls.

    NOTE(review): values presumably match the PAM_* item types and
    return codes from the libpam headers - confirm against them.
    """

    # Values used as item types with set_item/get_item calls
    service = 1
    user = 2
    tty = 3
    rhost = 4
    conv = 5
    authtok = 6
    oldauthtok = 7
    ruser = 8
    user_prompt = 9
    fail_delay = 10
    xdisplay = 11
    xauthdata = 12
    authtok_type = 13

    # Values used as call/callback result codes
    success = 0
    conv_err = 19
def force_str(s):
    """Return bytes *s* decoded to str, any other value unchanged."""
    if isinstance(s, bytes):
        return s.decode()
    return s


def force_bytes(s):
    """Return str *s* encoded to bytes, any other value unchanged."""
    if isinstance(s, str):
        return s.encode()
    return s
class c_str_p_type:
    """String type adapter for ctypes function declarations.

    Python str values are encoded to bytes via from_param(), and calling
    the instance on a value decodes bytes back to str.
    """

    # Underlying ctypes type for values of this kind
    c_type = c.c_char_p

    def __call__(self, val):
        """Convert a value received from C into str."""
        return force_str(val)

    def from_param(self, val):
        """Convert a Python value into bytes to be passed to C."""
        return force_bytes(val)


# Shared instance used in function signature definitions
c_str_p = c_str_p_type()
class PAMError(Exception):
    """Error raised for a failed libpam call."""
class PAMSession:
    """PAM transaction wrapper around libpam, usable as a context manager.

    open() starts the transaction, runs authenticate/acct_mgmt/open_session
    calls (except for the ones listed in "skip"), picks up user name as
    mapped by PAM and exports PAM environment to os.environ.
    close() closes the session (if it was opened) and ends the transaction.

    Methods here log via module-level "log" object, which has to be
    set before any libpam call is made.
    """

    # func_def ::= arg_types_list | (arg_types_list, res_spec) | (res_spec, arg_types_list)
    # res_spec ::= ctypes_restype
    #  | res_proc_func | (ctypes_restype, res_proc_func)
    #  | res_spec_name_str | (ctypes_restype, res_spec_name_str)
    # res_spec_name_str ::= 'int_check_ge0' | ...
    func_defs = dict(
        strerror=([pam_handle_t, c.c_int], c_str_p),
        start=([
            c_str_p, c_str_p,
            c.POINTER(pam_conv_t), c.POINTER(pam_handle_t) ]),
        end=([pam_handle_t, c.c_int]),
        set_item=([pam_handle_t, c.c_int, c_str_p]),
        # Declared same as other (handle, flags) calls, so that its
        # arguments are type-checked instead of being passed through as-is.
        authenticate=([pam_handle_t, c.c_int]),
        acct_mgmt=([pam_handle_t, c.c_int]),
        open_session=([pam_handle_t, c.c_int]),
        close_session=([pam_handle_t, c.c_int]),
        get_item=([pam_handle_t, c.c_int, c.POINTER(c.c_char_p)]),
        getenvlist=([pam_handle_t], c.POINTER(c.c_char_p)) )

    service = user = tty = session = None
    _lib = _pam_h = _pam_err = None
    _lib_pam = _libc = None

    @classmethod
    def _func_wrapper(cls, func_name, func, args=(), res_proc=None):
        """Set argtypes/restype on ctypes *func* and return python wrapper for it.

        res_proc can be a ctypes type (used as restype), an object with
        "c_type" attribute (restype is c_type, result is passed through
        the object), or a (restype, res_proc) tuple.
        Result type defaults to int if not specified in any of these ways.
        """
        func.restype, func.argtypes = None, args
        if isinstance(res_proc, tuple):
            func.restype, res_proc = res_proc
        if isinstance(res_proc, str):
            raise NotImplementedError
        elif not func.restype:
            if res_proc:
                if hasattr(res_proc, 'c_type'):
                    func.restype = res_proc.c_type
                else:
                    func.restype, res_proc = res_proc, None
            else:
                func.restype = c.c_int

        def _wrapper(*call_args):
            res = func(*call_args)
            if res_proc:
                res = res_proc(res)
            return res
        _wrapper.__name__ = 'libpam.{}'.format(func_name)
        return _wrapper

    @classmethod
    def _get_lib(cls):
        """Load libpam/libc once per class and return wrapped libpam object."""
        if cls._lib_pam is None:
            lib_pam = c.CDLL('libpam.so.0')
            for k, spec in cls.func_defs.items():
                k = 'pam_{}'.format(k)
                func, args, res_proc = getattr(lib_pam, k), None, None
                if spec:
                    if not isinstance(spec, tuple):
                        spec = (spec,)
                    for v in spec:
                        assert v, [k, spec, v]
                        if isinstance(v, list):
                            args = v
                        else:
                            res_proc = v
                setattr(lib_pam, k, cls._func_wrapper(k, func, args, res_proc))
            # Stored only after all functions are wrapped, so that failure
            # above does not leave half-initialized library object cached.
            cls._lib_pam = lib_pam
        if cls._libc is None:
            libc = c.CDLL('libc.so.6')
            libc.malloc.restype = c.c_void_p
            libc.malloc.argtypes = [c.c_size_t]
            cls._libc = libc
        return cls._lib_pam

    def __init__(self, service, user, tty=None, password=None, skip=None):
        """Store session parameters and load libraries.

        skip is an iterable of libpam call names (without "pam_" prefix)
        which will not be made. Authentication is skipped without password.
        """
        self.service, self.user, self.tty, self.pw = service, user, tty, password
        self.skip = set(skip or list())
        if self.pw is None:
            self.skip |= {'authenticate'}
        self._lib = self._get_lib()

    def _pam(self, func, *args, no_handle=False, raw=False):
        """Call libpam function *func*, with PAM handle as first arg by default.

        Returns None for calls listed in self.skip, without making them.
        With raw=True, result of the call is returned as-is.
        Otherwise raises PAMError(code, message) for any non-success result.
        """
        if func in self.skip:
            return
        if not no_handle:
            args = (self._pam_h,) + args
        if not func.startswith('pam_'):
            func = 'pam_{}'.format(func)
        log.debug('libpam call: {} {}', func, args)
        res = getattr(self._lib, func)(*args)
        if raw:
            return res
        if res != pam_t.success:
            self._pam_err = res  # passed to pam_end call in close()
            raise PAMError(res, self._lib.pam_strerror(self._pam_h, res))

    def open(self):
        """Start PAM transaction and session, export PAM env to os.environ.

        Raises PAMError if any of the libpam calls fail.
        """
        self._pam_h = pam_handle_t()
        self._conv = pam_conv_t(pam_conv_func_t(self.conv_func), None)
        self._pam(
            'start', self.service, self.user,
            c.pointer(self._conv), c.pointer(self._pam_h), no_handle=True )
        if self.tty:
            self._pam('set_item', pam_t.tty, self.tty)
        self._pam('authenticate', 0)
        self._pam('acct_mgmt', 0)
        self._pam('open_session', 0)
        # Session should only be closed later if it was actually opened here
        self.session = 'open_session' not in self.skip

        # Get mapped user name - can be updated by one of the modules
        user = c.c_char_p()
        self._pam('get_item', pam_t.user, c.pointer(user))
        if user.value:
            self.user = user.value.decode()

        # Result is a NULL pointer on failure or None if call was skipped
        envlist = self._pam('getenvlist', raw=True)
        if envlist:
            n = 0
            while envlist[n]:
                k, v = envlist[n].decode().split('=', 1)
                log.debug('Exporting env: {} = {!r}'.format(k, v))
                os.environ[k] = v
                n += 1
            self._libc.free(envlist)

    def close(self):
        """Close session and end PAM transaction, if these were started.

        State is always reset, even if libpam calls here raise PAMError,
        so that repeated close() calls (e.g. from __del__) are no-op.
        """
        try:
            if self.session:
                self.session = None
                self._pam('close_session', 0)
        finally:
            try:
                # Handle value stays NULL if pam_start call has failed
                if self._pam_h is not None and self._pam_h.handle:
                    self._pam('end', self._pam_err or pam_t.success)
            finally:
                self._pam_h = self._pam_err = None
                self._lib = None

    def conv_func(self, num_msg, msg, resp, appdata_ptr):
        """PAM conversation callback, responding with password to all messages.

        Returns conv_err code if there is no password or allocation fails.
        """
        if self.pw is None:
            return pam_t.conv_err
        # Pointers allocated in python don't work with free() in pam, hence libc.malloc
        pw = self.pw.encode('utf-8') + b'\0'
        # One response struct is allocated for each of the messages, as an
        #  array of these is what callers of conversation function expect.
        count = max(num_msg, 1)
        arr_size = c.sizeof(pam_response_t) * count
        arr_addr = self._libc.malloc(arr_size)
        if not arr_addr:
            return pam_t.conv_err
        c.memset(arr_addr, 0, arr_size)
        responses = c.cast(arr_addr, c.POINTER(pam_response_t))
        for n in range(count):
            pw_buff = self._libc.malloc(len(pw))
            if not pw_buff:
                # Nothing is returned on failure, so free all that was allocated
                for m in range(n):
                    self._libc.free(c.c_void_p(responses[m].resp))
                self._libc.free(c.c_void_p(arr_addr))
                return pam_t.conv_err
            (c.c_char * len(pw)).from_address(pw_buff)[:] = pw
            responses[n].resp = pw_buff
            responses[n].resp_retcode = 0
        resp[0] = responses
        return pam_t.success

    def __enter__(self):
        try:
            self.open()
        except BaseException:
            # __exit__ is not called if __enter__ fails, so cleanup is done
            #  here, without masking original error by any failure to do so.
            with contextlib.suppress(PAMError):
                self.close()
            raise
        return self

    def __exit__(self, *err):
        self.close()

    def __del__(self):
        self.close()
def parse_user_spec(spec):
    """Resolve "{uname|uid}[:{gname|gid}]" spec via passwd/group databases.

    spec can be a string or a plain integer uid.
    Group defaults to primary group of the user if not specified.

    Returns (uname, uid, gid, home, shell) tuple.
    Raises KeyError if specified user or group does not exist.
    """
    # str() allows to pass integer uid here, e.g. os.getuid() result
    uname, _, group = str(spec).partition(':')

    # All-digit user part is interpreted as uid, anything else as name
    try:
        uid = int(uname)
    except ValueError:
        uent = pwd.getpwnam(uname)
    else:
        uent = pwd.getpwuid(uid)

    if not group:
        gid = uent.pw_gid
    else:
        # Only the group part of the spec is parsed here, not the whole thing
        try:
            gid = int(group)
        except ValueError:
            gid = grp.getgrnam(group).gr_gid

    return uent.pw_name, uent.pw_uid, gid, uent.pw_dir, uent.pw_shell
def main(args=None):
    """Run specified command in a subprocess, wrapped into PAM session.

    args is a list of command-line arguments, defaulting to sys.argv[1:].
    Sets module-level "log" object, which is used by other code here.

    Returns exit code of the command, or 128 + signal number if it was
    killed by a signal that did not terminate this process as well.
    Exits with usage error if no command is specified.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description='Wrap specified command into PAM session.')
    parser.add_argument('command', nargs='?', help='Binary (command) to run.')
    parser.add_argument('args', nargs='*', help='Command arguments.')
    parser.add_argument('-s', '--service',
        default='system-login', metavar='pam-service-name',
        help='PAM service name (matching'
            ' configuration file in /etc/pam.d) to use. Default: %(default)s')
    parser.add_argument('-u', '--user', metavar='{uname|uid}[:{gname|gid}]',
        help='Username to pass to PAM and switch to before running command.'
            ' Can be specified as uid. Group/gid to switch to can also be specified.'
            ' Default is to use current user/uid.')
    parser.add_argument('-t', '--tty', metavar='dev-path-or-X-display',
        help='PAM_TTY value to set for session. Should be something like'
            ' /dev/ttyX or :0 for console or X session, respectively. Not set by default.')
    parser.add_argument('-n', '--no-env', action='store_true',
        help='Do not set/update USER, HOME, PATH and such login env vars for --user.')
    parser.add_argument('-x', '--no-session', action='store_true',
        help='Disable open_session() call, which would usually start "user@..." and such.')
    parser.add_argument('-p', '--password-file', metavar='path',
        help='Read space-stripped first line of'
            ' a specified file as a password for PAM authentication.')
    parser.add_argument('-v', '--verbose', action='store_true', help='Verbose operation mode.')
    parser.add_argument('--debug',
        action='store_true', help='Same as --verbose, but even more info.')
    opts = parser.parse_args(sys.argv[1:] if args is None else args)
    # Checked before doing anything else, to not open PAM session for nothing
    if not opts.command:
        parser.error('No command specified to run within PAM session')

    global log
    if opts.debug: log_level = logging.DEBUG
    elif opts.verbose: log_level = logging.INFO
    else: log_level = logging.WARNING
    logging.basicConfig(level=log_level)
    log = get_logger('main')

    pw, tty = None, opts.tty
    if opts.password_file:
        with open(opts.password_file) as src:
            pw = src.readline().strip()
    sname = opts.service
    # User spec is always passed as a string, same as with -u/--user option
    uname, uid, gid, home, shell = parse_user_spec(opts.user or str(os.getuid()))

    child_pid = None
    sig_trap = {
        getattr(signal, 'SIG{}'.format(sig)): sig
        for sig in ['TERM', 'INT', 'HUP', 'USR1', 'USR2'] }

    def sig_pass(sig_name, sig, frm):
        if not child_pid:
            return
        # Loop variable must not be named "sig", as that would
        #  replace received signal, which is passed to a subprocess below.
        for sig_reset in sig_trap:
            signal.signal(sig_reset, lambda sig, frm: sys.exit(31))  # propagate once
        with contextlib.suppress(OSError):
            log.debug('Propagating SIG{} ({}) to subprocess (pid: {})', sig_name, sig, child_pid)
            os.kill(child_pid, sig)
    for sig, sig_name in sig_trap.items():
        signal.signal(sig, ft.partial(sig_pass, sig_name))

    log.info(
        'Starting PAM session {!r} for user {!r} (uid={}, gid={}, tty={}, pw={})...',
        sname, uname, uid, gid, tty, pw is not None )
    skip = ['open_session'] if opts.no_session else []
    with PAMSession(sname, uname, tty=tty, password=pw, skip=skip) as s:
        if s.user != uname:
            log.info('Updated uname from PAM: {!r} -> {!r}', uname, s.user)
            uname, uid, gid, home, shell = parse_user_spec(s.user)

        if not opts.no_env:  # same basic env stuff that login(1) sets
            os.environ.update(USER=uname, LOGNAME=uname, HOME=home, SHELL=shell)
            os.environ['PATH'] = '/usr/local/bin:/opt/bin:/usr/bin'

        cmd, cmd_args = os.path.expanduser(opts.command), list(opts.args or list())
        child_pid = os.fork()
        if not child_pid:
            # Child must never return into the code of the parent process,
            #  where it would close PAM session and keep running main().
            try:
                # Drop supplementary groups of the current user as well,
                #  which can only be done (and is needed) when running privileged.
                with contextlib.suppress(PermissionError):
                    os.initgroups(uname, gid)
                os.setresgid(gid, gid, gid)
                os.setresuid(uid, uid, uid)
                os.execlp(cmd, opts.command, *cmd_args)
            except BaseException as err:
                log.error('Failed to run command {!r}: {}', opts.command, err)
            finally:
                os._exit(127)

        log.debug( 'Started session subprocess'
            ' (pid: {}): {}', child_pid, ' '.join([opts.command] + cmd_args) )
        pid, status = os.waitpid(child_pid, 0)

    # Wait-status is decoded via os.W* helpers, as
    #  its low byte also includes core-dump flag, not just signal number.
    sig = os.WTERMSIG(status) if os.WIFSIGNALED(status) else 0
    code = os.WEXITSTATUS(status) if os.WIFEXITED(status) else 128 + sig
    log.info( 'PAM session {!r} for user {!r} (uid={}, gid={})'
        ' was closed cleanly, command exit code: {}', sname, uname, uid, gid, code )
    if sig:  # emulate exit-on-signal, same as child_pid
        if sig in sig_trap:
            signal.signal(sig, signal.SIG_DFL)
        os.kill(os.getpid(), sig)
    return code
if __name__ == '__main__':
    # Value returned by main() is used as the exit status.
    sys.exit(main())