Skip to content

Commit

Permalink
Merge branch 'sonic-net:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
rameshraghupathy authored Nov 11, 2024
2 parents 093bf00 + ff73070 commit 1dbf35e
Show file tree
Hide file tree
Showing 9 changed files with 753 additions and 18 deletions.
135 changes: 135 additions & 0 deletions host_modules/image_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""
This module provides services related to SONiC images, including:
1) Downloading images
2) Installing images
3) Calculating checksums for images
"""

import errno
import hashlib
import logging
import os
import requests
import stat
import subprocess

from host_modules import host_service
import tempfile

MOD_NAME = "image_service"

DEFAULT_IMAGE_SAVE_AS = "/tmp/downloaded-sonic.bin"

logger = logging.getLogger(__name__)


class ImageService(host_service.HostModule):
"""DBus endpoint that handles downloading and installing SONiC images"""

@host_service.method(
host_service.bus_name(MOD_NAME), in_signature="ss", out_signature="is"
)
def download(self, image_url, save_as):
"""
Download a SONiC image.
Args:
image_url: url for remote image.
save_as: local path for the downloaded image. The directory must exist and be *all* writable.
"""
logger.info("Download new sonic image from {} as {}".format(image_url, save_as))
# Check if the directory exists, is absolute and has write permission.
if not os.path.isabs(save_as):
logger.error("The path {} is not an absolute path".format(save_as))
return errno.EINVAL, "Path is not absolute"
dir = os.path.dirname(save_as)
if not os.path.isdir(dir):
logger.error("Directory {} does not exist".format(dir))
return errno.ENOENT, "Directory does not exist"
st_mode = os.stat(dir).st_mode
if (
not (st_mode & stat.S_IWUSR)
or not (st_mode & stat.S_IWGRP)
or not (st_mode & stat.S_IWOTH)
):
logger.error("Directory {} is not all writable {}".format(dir, st_mode))
return errno.EACCES, "Directory is not all writable"
try:
response = requests.get(image_url, stream=True)
if response.status_code != 200:
logger.error(
"Failed to download image: HTTP status code {}".format(
response.status_code
)
)
return errno.EIO, "HTTP error: {}".format(response.status_code)

with tempfile.NamedTemporaryFile(dir="/tmp", delete=False) as tmp_file:
for chunk in response.iter_content(chunk_size=8192):
tmp_file.write(chunk)
temp_file_path = tmp_file.name
os.replace(temp_file_path, save_as)
return 0, "Download successful"
except Exception as e:
logger.error("Failed to write downloaded image to disk: {}".format(e))
return errno.EIO, str(e)

@host_service.method(
host_service.bus_name(MOD_NAME), in_signature="s", out_signature="is"
)
def install(self, where):
"""
Install a a sonic image:
Args:
where: either a local path or a remote url pointing to the image.
"""
logger.info("Using sonic-installer to install the image at {}.".format(where))
cmd = ["/usr/local/bin/sonic-installer", "install", "-y", where]
result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
msg = ""
if result.returncode:
lines = result.stderr.decode().split("\n")
for line in lines:
if "Error" in line:
msg = line
break
return result.returncode, msg

@host_service.method(
host_service.bus_name(MOD_NAME), in_signature="ss", out_signature="is"
)
def checksum(self, file_path, algorithm):
"""
Calculate the checksum of a file.
Args:
file_path: path to the file.
algorithm: checksum algorithm to use (sha256, sha512, md5).
"""

logger.info("Calculating {} checksum for file {}".format(algorithm, file_path))

if not os.path.isfile(file_path):
logger.error("File {} does not exist".format(file_path))
return errno.ENOENT, "File does not exist"

hash_func = None
if algorithm == "sha256":
hash_func = hashlib.sha256()
elif algorithm == "sha512":
hash_func = hashlib.sha512()
elif algorithm == "md5":
hash_func = hashlib.md5()
else:
logger.error("Unsupported algorithm: {}".format(algorithm))
return errno.EINVAL, "Unsupported algorithm"

try:
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_func.update(chunk)
return 0, hash_func.hexdigest()
except Exception as e:
logger.error("Failed to calculate checksum: {}".format(e))
return errno.EIO, str(e)
22 changes: 22 additions & 0 deletions host_modules/systemd_service.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
"""Systemd service handler"""

from enum import Enum
from host_modules import host_service
import subprocess

MOD_NAME = 'systemd'
ALLOWED_SERVICES = ['snmp', 'swss', 'dhcp_relay', 'radv', 'restapi', 'lldp', 'sshd', 'pmon', 'rsyslog', 'telemetry']
EXIT_FAILURE = 1

# Define an Enum for Reboot Methods which are defined as in
# https://github.com/openconfig/gnoi/blob/main/system/system.pb.go#L27
class RebootMethod(Enum):
COLD = 1
HALT = 3

class SystemdService(host_service.HostModule):
"""
Expand Down Expand Up @@ -48,3 +54,19 @@ def stop_service(self, service):
if result.returncode:
msg = result.stderr.decode()
return result.returncode, msg

@host_service.method(host_service.bus_name(MOD_NAME), in_signature='i', out_signature='is')
def execute_reboot(self, rebootmethod):
if rebootmethod == RebootMethod.COLD:
cmd = ['/usr/local/bin/reboot']
elif rebootmethod == RebootMethod.HALT:
cmd = ['/usr/local/bin/reboot','-p']
else:
return EXIT_FAILURE, "{}: Invalid reboot method: {}".format(MOD_NAME, rebootmethod)

result = subprocess.run(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
msg = ''
if result.returncode:
msg = result.stderr.decode()

return result.returncode, msg
92 changes: 92 additions & 0 deletions scripts/hostcfgd
Original file line number Diff line number Diff line change
Expand Up @@ -1749,6 +1749,84 @@ class SerialConsoleCfg:
return


class BannerCfg(object):
"""
Banner Config Daemon
Handles changes in BANNER_MESSAGE table.
1) Handle change of feature state
2) Handle change of login message
3) Handle change of MOTD message
4) Handle change of logout message
"""

def __init__(self):
self.cache = {}

def load(self, banner_messages_config: dict):
"""Banner messages configuration
Force load banner configuration. Login messages should be taken at boot-time by
SSH daemon.
Args:
banners_message_config: Configured banner messages.
"""

syslog.syslog(syslog.LOG_INFO, 'BannerCfg: load initial')

if not banner_messages_config:
banner_messages_config = {}

# Force load banner messages.
# Login messages show be taken at boot-time by SSH daemon.
state_data = banner_messages_config.get("state", {})
login_data = banner_messages_config.get("login", {})
motd_data = banner_messages_config.get("motd", {})
logout_data = banner_messages_config.get("logout", {})

self.banner_message("state", state_data)
self.banner_message("login", login_data)
self.banner_message("motd", motd_data)
self.banner_message("logout", logout_data)

def banner_message(self, key, data):
"""
Apply banner message handler.
Args:
cache: Cache to compare/save data.
db: DB instance.
table: DB table that was changed.
key: DB table's key that was triggered change.
data: Read table data.
"""
# Handling state, login/logout and MOTD messages. Data should be a dict
if type(data) != dict:
# Nothing to handle
return

update_required = False
# Check with cache
for k,v in data.items():
if v != self.cache.get(k):
update_required = True
break

if update_required == False:
return

try:
run_cmd(["systemctl", "restart", "banner-config"], True, True)
except Exception:
syslog.syslog(syslog.LOG_ERR, 'BannerCfg: Failed to restart '
'banner-config service')
return

# Update cache
for k,v in data.items():
self.cache[k] = v


class HostConfigDaemon:
def __init__(self):
self.state_db_conn = DBConnector(STATE_DB, 0)
Expand Down Expand Up @@ -1803,6 +1881,9 @@ class HostConfigDaemon:
# Initialize SerialConsoleCfg
self.serialconscfg = SerialConsoleCfg()

# Initialize BannerCfg
self.bannermsgcfg = BannerCfg()

def load(self, init_data):
aaa = init_data['AAA']
tacacs_global = init_data['TACPLUS']
Expand All @@ -1826,6 +1907,7 @@ class HostConfigDaemon:
ntp_servers = init_data.get(swsscommon.CFG_NTP_SERVER_TABLE_NAME)
ntp_keys = init_data.get(swsscommon.CFG_NTP_KEY_TABLE_NAME)
serial_console = init_data.get('SERIAL_CONSOLE', {})
banner_messages = init_data.get(swsscommon.CFG_BANNER_MESSAGE_TABLE_NAME)

self.aaacfg.load(aaa, tacacs_global, tacacs_server, radius_global, radius_server, ldap_global, ldap_server)
self.iptables.load(lpbk_table)
Expand All @@ -1839,6 +1921,8 @@ class HostConfigDaemon:
self.fipscfg.load(fips_cfg)
self.ntpcfg.load(ntp_global, ntp_servers, ntp_keys)
self.serialconscfg.load(serial_console)
self.bannermsgcfg.load(banner_messages)

self.pamLimitsCfg.update_config_file()

# Update AAA with the hostname
Expand Down Expand Up @@ -1992,6 +2076,10 @@ class HostConfigDaemon:
syslog.syslog(syslog.LOG_INFO, 'SERIAL_CONSOLE table handler...')
self.serialconscfg.update_serial_console_cfg(key, data)

def banner_handler(self, key, op, data):
syslog.syslog(syslog.LOG_INFO, 'BANNER_MESSAGE table handler...')
self.bannermsgcfg.banner_message(key, data)

def wait_till_system_init_done(self):
# No need to print the output in the log file so using the "--quiet"
# flag
Expand Down Expand Up @@ -2059,6 +2147,10 @@ class HostConfigDaemon:
self.config_db.subscribe(swsscommon.CFG_NTP_KEY_TABLE_NAME,
make_callback(self.ntp_srv_key_handler))

# Handle BANNER_MESSAGE changes
self.config_db.subscribe(swsscommon.CFG_BANNER_MESSAGE_TABLE_NAME,
make_callback(self.banner_handler))

syslog.syslog(syslog.LOG_INFO,
"Waiting for systemctl to finish initialization")
self.wait_till_system_init_done()
Expand Down
Loading

0 comments on commit 1dbf35e

Please sign in to comment.