diff --git a/.gitignore b/.gitignore index b182de52..0868d35c 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,9 @@ src/aihwkit/simulator/*.so # Temporary folder for the example downloads. data/ +# Folder for cibuildwheel. +wheelhouse/ + ## From https://github.com/github/gitignore/blob/master/Python.gitignore # Byte-compiled / optimized / DLL files diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e8fd4c5..eb0db585 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,6 +56,7 @@ The format is based on [Keep a Changelog], and this project adheres to * `WeightModifiers` of the `InferenceRPUConfig` are no longer called in the forward pass, but instead in the `post_update_step` method to avoid issues with repeated forward calls. (\#423) +* Fix training `learn_out_scales` issue after checkpoint load. (\#434) ### Changed @@ -81,7 +82,7 @@ The format is based on [Keep a Changelog], and this project adheres to ### Removed * The `_scaled` versions of the weight getter and setter methods are -removed (\#423) +removed (\#423) ## [0.6.0] - 2022/05/16 diff --git a/CMakeLists.txt b/CMakeLists.txt index bf0c52af..a3fd5da6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,13 +13,13 @@ project(aihwkit C CXX) # Project options. option(BUILD_TEST "Build C++ test binaries" OFF) -option(USE_CUDA "Build with CUDA support" OFF) +option(USE_CUDA "Build with CUDA support" $ENV{USE_CUDA}) option(RPU_DEBUG "Enable debug printing" OFF) option(RPU_USE_FASTMOD "Use fast mod" ON) option(RPU_USE_FASTRAND "Use fastrand" OFF) set(RPU_BLAS "OpenBLAS" CACHE STRING "BLAS backend of choice (OpenBLAS, MKL)") -set(RPU_CUDA_ARCHITECTURES "60" CACHE STRING "Target CUDA architectures") +set(RPU_CUDA_ARCHITECTURES "60;70;75;80" CACHE STRING "Target CUDA architectures") # Internal variables. set(CUDA_TARGET_PROPERTIES POSITION_INDEPENDENT_CODE ON diff --git a/requirements.txt b/requirements.txt index 3b225b66..b8cd728c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ # Build dependencies. cmake>=3.18 -setuptools<=46.0; python_version < '3.9' scikit-build>=0.11.1 pybind11>=2.6.2 # Runtime dependencies. diff --git a/src/aihwkit/VERSION.txt b/src/aihwkit/VERSION.txt index a918a2aa..b6160487 100644 --- a/src/aihwkit/VERSION.txt +++ b/src/aihwkit/VERSION.txt @@ -1 +1 @@ -0.6.0 +0.6.2 diff --git a/src/aihwkit/cloud/client/entities.py b/src/aihwkit/cloud/client/entities.py index 5b410b83..6ace6f94 100644 --- a/src/aihwkit/cloud/client/entities.py +++ b/src/aihwkit/cloud/client/entities.py @@ -19,9 +19,11 @@ from aihwkit.cloud.client.exceptions import ExperimentStatusError from aihwkit.cloud.converter.definitions.input_file_pb2 import TrainingInput +from aihwkit.cloud.converter.definitions.i_input_file_pb2 import InferenceInput from aihwkit.cloud.converter.definitions.output_file_pb2 import TrainingOutput from aihwkit.cloud.converter.v1.training import BasicTrainingConverter, BasicTrainingResultConverter -from aihwkit.experiments import BasicTraining +from aihwkit.cloud.converter.v1.inferencing import BasicInferencingConverter +# from aihwkit.experiments import BasicTraining, BasicInferencing class CloudJobStatus(Enum): @@ -62,7 +64,7 @@ class CloudExperiment: input_id: Optional[str] = field(repr=False) job: Optional[CloudJob] = field(repr=False) - def get_experiment(self) -> BasicTraining: + def get_experiment(self) -> Any: """Return a data Experiment. 
Returns: @@ -76,11 +78,16 @@ def get_experiment(self) -> BasicTraining: input_ = self._api_client.input_get(self.input_id) - input_proto = TrainingInput() - input_proto.ParseFromString(input_) + if 'InferenceRPUConfig' in str(input_): + input_proto = InferenceInput() + input_proto.ParseFromString(input_) + proto = BasicInferencingConverter().from_proto(input_proto) + else: + input_proto = TrainingInput() + input_proto.ParseFromString(input_) + proto = BasicTrainingConverter().from_proto(input_proto) - converter = BasicTrainingConverter() - return converter.from_proto(input_proto) + return proto def get_result(self) -> list: """Return the result of an Experiment. diff --git a/src/aihwkit/cloud/converter/v1/i_mappings.py b/src/aihwkit/cloud/converter/v1/i_mappings.py index e1463f34..71426863 100644 --- a/src/aihwkit/cloud/converter/v1/i_mappings.py +++ b/src/aihwkit/cloud/converter/v1/i_mappings.py @@ -24,11 +24,13 @@ from aihwkit.simulator.configs import InferenceRPUConfig from aihwkit.simulator.presets.web import ( - WebComposerInferenceRPUConfig, OldWebComposerInferenceRPUConfig -) + WebComposerInferenceRPUConfig, OldWebComposerInferenceRPUConfig) from aihwkit.cloud.converter.definitions.i_onnx_common_pb2 import AttributeProto from aihwkit.cloud.converter.exceptions import ConversionError -from aihwkit.nn import AnalogConv2d, AnalogLinear +from aihwkit.nn import ( + AnalogConv2d, AnalogConv2dMapped, + AnalogLinear, AnalogLinearMapped +) from aihwkit.optim import AnalogSGD from aihwkit.cloud.converter.v1.rpu_config_info import RPUconfigInfo @@ -139,10 +141,13 @@ def get_field_value_to_proto(self, source: Any, field: str, default: Any = None) if field == 'bias': return getattr(source, 'bias', None) is not None if field == 'rpu_config': - preset_cls = type(source.analog_tile.rpu_config) + # preset_cls = type(source.analog_tile.rpu_config) + analog_tile = next(source.analog_tiles()) + preset_cls = type(analog_tile.rpu_config) if preset_cls not in Mappings.presets: - raise ConversionError('Invalid rpu_config in layer: {} not ' - 'among the presets'.format(preset_cls)) + raise ConversionError('Invalid rpu_config in layer: ' + f'{preset_cls} not ' + 'among the presets') return Mappings.presets[preset_cls] return super().get_field_value_to_proto(source, field, default) @@ -182,12 +187,28 @@ class Mappings: 'bias': bool, 'rpu_config': str, }), + AnalogConv2dMapped: LayerFunction('AnalogConv2dMapped', { + 'in_channels': int, + 'out_channels': int, + 'kernel_size': [int], + 'stride': [int], + 'padding': [int], + 'dilation': [int], + 'bias': bool, + 'rpu_config': str, + }), AnalogLinear: LayerFunction('AnalogLinear', { 'in_features': int, 'out_features': int, 'bias': bool, 'rpu_config': str, }), + AnalogLinearMapped: LayerFunction('AnalogLinearMapped', { + 'in_features': int, + 'out_features': int, + 'bias': bool, + 'rpu_config': str, + }), BatchNorm2d: LayerFunction('BatchNorm2d', { 'num_features': int }), diff --git a/src/aihwkit/cloud/converter/v1/inferencing.py b/src/aihwkit/cloud/converter/v1/inferencing.py index 6062386c..02092531 100644 --- a/src/aihwkit/cloud/converter/v1/inferencing.py +++ b/src/aihwkit/cloud/converter/v1/inferencing.py @@ -68,11 +68,11 @@ def from_proto(self, protobuf: Any) -> BasicInferencing: """Convert a protobuf representation to an `Experiment`.""" dataset = InverseMappings.datasets[protobuf.dataset.dataset_id] - + layers = protobuf.network.layers # build RPUconfig_info to be used when it is instantiated dynamically alog_info = 
AnalogInfo(protobuf.inferencing.analog_info) nm_info = NoiseModelInfo(protobuf.inferencing.noise_model_info) - rc_info = RPUconfigInfo(nm_info, alog_info) + rc_info = RPUconfigInfo(nm_info, alog_info, layers) model = self._model_from_proto(protobuf.network, rc_info) @@ -104,8 +104,8 @@ def _version_to_proto() -> Any: @staticmethod def _dataset_to_proto(dataset: type, batch_size: int) -> Any: - if dataset not in Mappings.datasets.keys(): - raise ConversionError('Unsupported dataset: {}'.format(dataset)) + if dataset not in Mappings.datasets: + raise ConversionError(f'Unsupported dataset: {dataset}') return Dataset( dataset_id=Mappings.datasets[dataset], @@ -121,7 +121,8 @@ def _model_to_proto(model: Module, weight_template_id: str) -> Any: children_types = {type(layer) for layer in model.children()} valid_types = set(Mappings.layers.keys()) | set(Mappings.activation_functions.keys()) if children_types - valid_types: - raise ConversionError('Unsupported layers: {}'.format(children_types - valid_types)) + raise ConversionError('Unsupported layers: ' + f'{children_types - valid_types}') # Create a new input_file pb Network object with weight_template_id network = Network(weight_template_id=weight_template_id) @@ -189,7 +190,7 @@ def rpu_config_info_from_info(analog_info: Dict, nm_info = NoiseModelInfo(BasicInferencingConverter._noise_model_to_proto( noise_model_info)) # type: ignore[name-defined] a_info = AnalogInfo(AnalogProto(**analog_info)) - return RPUconfigInfo(nm_info, a_info) + return RPUconfigInfo(nm_info, a_info, None) @staticmethod def rpu_config_from_info(analog_info: Dict, @@ -200,7 +201,8 @@ def rpu_config_from_info(analog_info: Dict, nm_info = NoiseModelInfo(BasicInferencingConverter._noise_model_to_proto( noise_model_info)) # type: ignore[name-defined] a_info = AnalogInfo(AnalogProto(**analog_info)) - return RPUconfigInfo(nm_info, a_info).create_inference_rpu_config(func_id) + return RPUconfigInfo(nm_info, + a_info, None).create_inference_rpu_config(func_id) @staticmethod def _inferencing_to_proto( diff --git a/src/aihwkit/cloud/converter/v1/rpu_config_info.py b/src/aihwkit/cloud/converter/v1/rpu_config_info.py index e225ab58..a4a79807 100644 --- a/src/aihwkit/cloud/converter/v1/rpu_config_info.py +++ b/src/aihwkit/cloud/converter/v1/rpu_config_info.py @@ -11,20 +11,28 @@ # that they have been altered from the originals. 
"""Creates InferenceRPUConfig to add to nn model""" - +from typing import Dict, Any from collections import OrderedDict from aihwkit.simulator.configs.configs import InferenceRPUConfig -from aihwkit.simulator.presets.web import OldWebComposerInferenceRPUConfig +from aihwkit.simulator.presets.web import ( + WebComposerInferenceRPUConfig, + OldWebComposerInferenceRPUConfig +) from aihwkit.inference.noise.pcm import PCMLikeNoiseModel from aihwkit.inference.noise.custom import StateIndependentNoiseModel from aihwkit.inference.compensation.drift import GlobalDriftCompensation from aihwkit.cloud.converter.v1.analog_info import AnalogInfo from aihwkit.cloud.converter.v1.noise_model_info import NoiseModelInfo -# pylint: disable=too-few-public-methods +RPU_CLASSES = { + 'InferenceRPUConfig': InferenceRPUConfig, + 'WebComposerInferenceRPUConfig': WebComposerInferenceRPUConfig, + 'OldWebComposerInferenceRPUConfig': OldWebComposerInferenceRPUConfig +} +# pylint: disable=too-few-public-methods class NoiseModelDeviceIDException(Exception): """Exception raised if noise model device id is not correct""" @@ -32,13 +40,52 @@ class NoiseModelDeviceIDException(Exception): class RPUconfigInfo: """Data only class for RPUConfig fields""" - def __init__(self, nm_info: NoiseModelInfo, a_info: AnalogInfo): - """"Constructor for this class""" - + def __init__(self, nm_info: NoiseModelInfo, + a_info: AnalogInfo, + layers: Any = None): + """ + The only constructor for this class + """ self._noise_model_info = nm_info self._analog_info = a_info + self._layers = layers self._device_id = '' + @staticmethod + def _get_common_rpucfg_name(layers: Any) -> Any: + """Set common rpu config name by search all analog layers""" + # Use default RPU config for Composer + if layers is None: + return 'WebComposerInferenceRPUConfig' + # Need to loop through protobuf layers and figure out + # common rpu_config value. + names: Dict[str, int] = {} + # pylint: disable=too-many-nested-blocks + for layer_proto in layers: # type: ignore[attr-defined] + if layer_proto.WhichOneof('item') == 'layer': + layer = layer_proto.layer + if layer.id.startswith('Analog'): + # Loop though all AttributeProto objecs in layer.arguments + for argument in layer.arguments: + if argument.name == 'rpu_config': + # stored as UTF8 byte string in attribute s + arg_value = getattr(argument, 's') + # update count of this rpu_config in all analog layers + if arg_value in names: + names[arg_value] += 1 + else: + names[arg_value] = 1 + # pylint: enable=too-many-nested-blocks + # should have exactly on in dictionary 'names' + if len(names) > 1: + print(f'>>> ERROR: more than one rpu_config: {names}') + return None + if len(names) == 1: + # keys() returns dict_keys object, need a list + return list(names.keys())[0].decode('UTF-8') # type: ignore[attr-defined] + print('>>> INFO: experiment has not analog layers') + return '' + def _print_rpu_config( self, rpu_config: InferenceRPUConfig, @@ -92,9 +139,22 @@ def _print_rpu_config( def create_inference_rpu_config(self, func_id: str, verbose: bool = False) -> InferenceRPUConfig: """Creates a InferenceRPUConfig class using noise and analog info""" + # Need to find name of 'common-rpu-conf-class-name' in protobuf + # This should be the consistent across all layers. 
+        # The Composer Validator should have already caught this but
+        # it is checked here for testcases and other unknown environments
+        rpu_class_name = self._get_common_rpucfg_name(self._layers)
+        print(f'>>> INFO: rpu_class_name={rpu_class_name}')
+        if rpu_class_name is None or len(rpu_class_name) == 0:
+            raise Exception('class name error. see previous messages')
+        rpu_config_class = None
+        if rpu_class_name in RPU_CLASSES:
+            rpu_config_class = RPU_CLASSES[rpu_class_name]
+        else:
+            raise Exception(f"rpu class name '{rpu_class_name}' not one of '{RPU_CLASSES.keys()}'")
-        rpu_config = OldWebComposerInferenceRPUConfig()
-
+        # Dynamically create the right InferenceRPUConfig class
+        rpu_config = rpu_config_class()
         # Assign values from AnalogProto
         rpu_config.forward.out_noise = self._analog_info.output_noise_strength
diff --git a/src/aihwkit/experiments/experiments/inferencing.py b/src/aihwkit/experiments/experiments/inferencing.py
index e511e142..fce265cb 100644
--- a/src/aihwkit/experiments/experiments/inferencing.py
+++ b/src/aihwkit/experiments/experiments/inferencing.py
@@ -185,7 +185,7 @@ def get_model(
         if weight_template_id.startswith('http'):
             template_url = weight_template_id
         else:
-            print('weights_template_id: ', weight_template_id)
+            # print('weights_template_id: ', weight_template_id)
             template_path = template_dir + "/" + weight_template_id + ".pth"
             template_url = WEIGHT_TEMPLATE_URL + weight_template_id + ".pth"
             # check if the file exists
@@ -194,7 +194,7 @@ def get_model(
             if not path.exists(template_path):
                 download(template_url, template_path)
-            print('template_path: ', template_path)
+            # print('template_path: ', template_path)
             if path.exists(template_path):
                 model.load_state_dict(load(template_path, map_location=device),
                                       load_rpu_config=False)
@@ -203,7 +203,7 @@ def get_model(
         if self.remap_weights:
             for module in model.analog_modules():
-                module.remap_weights()
+                module.remap_weights(1.0)
 
         return model.to(device)
diff --git a/src/aihwkit/nn/modules/base.py b/src/aihwkit/nn/modules/base.py
index ac548b42..9b5a8cf3 100644
--- a/src/aihwkit/nn/modules/base.py
+++ b/src/aihwkit/nn/modules/base.py
@@ -102,7 +102,8 @@ def register_helper(self, name: str) -> None:
         if name not in self._registered_helper_parameter:
             self._registered_helper_parameter.append(name)
 
-    def register_analog_tile(self, tile: 'BaseTile', name: Optional[str] = None) -> None:
+    def register_analog_tile(self, tile: 'BaseTile', name: Optional[str] = None,
+                             update_only: bool = False) -> None:
         """Register the analog context of the tile.
 
         Note:
@@ -111,12 +112,16 @@ def register_analog_tile(self, tile: 'BaseTile', name: Optional[str] = None) ->
         Args:
             tile: tile to register
-            name: Optional tile name used as the parameter name
+            name: Optional tile name used as the parameter name.
+            update_only: Whether to re-register (does not advance the tile counter)
         """
         if name is None:
             name = str(self._analog_tile_counter)
+        if not update_only:
+            self._analog_tile_counter += 1
+
         ctx_name = self.ANALOG_CTX_PREFIX + name
         self.register_helper(ctx_name)
@@ -125,15 +130,16 @@ def register_analog_tile(self, tile: 'BaseTile', name: Optional[str] = None) ->
         if tile.shared_weights is not None:
             if not isinstance(tile.shared_weights, Parameter):
                 tile.shared_weights = Parameter(tile.shared_weights)
-            par_name = self.ANALOG_SHARED_WEIGHT_PREFIX + str(self._analog_tile_counter)
+            par_name = self.ANALOG_SHARED_WEIGHT_PREFIX + name
             self.register_parameter(par_name, tile.shared_weights)
             self.register_helper(par_name)
-        if tile.get_learned_out_scales() is not None:
-            par_name = self.ANALOG_OUT_SCALING_ALPHA_PREFIX + str(self._analog_tile_counter)
-            self.register_parameter(par_name, tile.get_learned_out_scales())
-
-        self._analog_tile_counter += 1
+        if tile.out_scaling_alpha is not None:
+            if not isinstance(tile.out_scaling_alpha, Parameter):
+                tile.out_scaling_alpha = Parameter(tile.out_scaling_alpha)
+            par_name = self.ANALOG_OUT_SCALING_ALPHA_PREFIX + name
+            self.register_parameter(par_name, tile.out_scaling_alpha)
+            self.register_helper(par_name)
 
     def unregister_parameter(self, param_name: str) -> None:
         """Unregister module parameter from parameters.
@@ -415,15 +421,30 @@ def load_state_dict(self, # pylint: disable=arguments-differ
             For instance, changing the device type might change the expected
             fields in the hidden parameters and result in an error.
+        Returns: see torch's ``load_state_dict``
         Raises:
             ModuleError: in case the rpu_config class mismatches
-                for ``load_rpu_config=False``.
+                or the mapping parameters mismatch for
+                ``load_rpu_config=False``
+
         """
         self._set_load_rpu_config_state(load_rpu_config)
         return super().load_state_dict(state_dict, strict)
 
+    def __setstate__(self, state: Dict) -> None:
+        """Set the state after unpickling.
+
+        Makes sure that the parameters in the tiles are correctly registered.
+
+        """
+        self.__dict__.update(state)
+
+        # update registered parameters
+        for name, analog_tile in list(self.named_analog_tiles()):
+            self.register_analog_tile(analog_tile, name, update_only=True)
+
     def _load_from_state_dict(
             self,
             state_dict: Dict,
@@ -443,6 +464,7 @@ def _load_from_state_dict(
         Raises:
             ModuleError: in case the rpu_config class mismatches.
""" + # pylint: disable=too-many-locals for name, analog_tile in list(self.named_analog_tiles()): key = prefix + self.ANALOG_STATE_PREFIX + name @@ -459,9 +481,26 @@ def _load_from_state_dict( "Tried to replace " f"{analog_state['rpu_config'].__class__.__name__} " f"with {analog_tile.rpu_config.__class__.__name__}") + + if hasattr(analog_state['rpu_config'], 'mapping'): + old_mapping = analog_state['rpu_config'].mapping + new_mapping = analog_tile.rpu_config.mapping + if (old_mapping.max_input_size != new_mapping.max_input_size + or old_mapping.max_output_size != new_mapping.max_output_size + or old_mapping.digital_bias != new_mapping.digital_bias + or (old_mapping.out_scaling_columnwise + != new_mapping.out_scaling_columnwise)): + raise ModuleError("MappingParameter mismatch during loading: " + "Tried to replace " + f"{old_mapping} " + f"with {new_mapping}") + analog_state['rpu_config'] = analog_tile.rpu_config analog_tile.__setstate__(analog_state) + # update registered parameters + self.register_analog_tile(analog_tile, name, update_only=True) + elif strict: missing_keys.append(key) diff --git a/src/aihwkit/simulator/presets/web.py b/src/aihwkit/simulator/presets/web.py index 472f01d8..65210247 100644 --- a/src/aihwkit/simulator/presets/web.py +++ b/src/aihwkit/simulator/presets/web.py @@ -86,7 +86,7 @@ class WebComposerMappingParameter(MappingParameter): weight_scaling_omega: float = 1.0 weight_scaling_columnwise: bool = True learn_out_scaling: bool = True - out_scaling_columnwise: bool = True + out_scaling_columnwise: bool = False max_input_size: int = 512 max_output_size: int = 512 diff --git a/src/aihwkit/simulator/tiles/base.py b/src/aihwkit/simulator/tiles/base.py index eca24d71..75fd591c 100644 --- a/src/aihwkit/simulator/tiles/base.py +++ b/src/aihwkit/simulator/tiles/base.py @@ -51,6 +51,7 @@ class AnalogTileStateNames: # pylint: disable=too-few-public-methods SHARED_WEIGHTS = 'shared_weights' CONTEXT = 'analog_ctx' OUT_SCALING = 'out_scaling_alpha' + MAPPING_SCALES = 'mapping_scales' RPU_CONFIG = 'rpu_config' @@ -191,7 +192,7 @@ def __setstate__(self, state: Dict) -> None: Raises: TileError: if tile class does not match or hidden parameters do not match """ - # pylint: disable=too-many-locals + # pylint: disable=too-many-locals, too-many-statements, too-many-branches # Note: self here is NOT initialized! 
So we need to recreate
         # attributes that were not saved in getstate
@@ -199,6 +200,7 @@ def __setstate__(self, state: Dict) -> None:
         current_dict = state.copy()
         current_dict.pop('image_sizes', None)  # should not be saved
         weights = current_dict.pop(SN.WEIGHTS)
+
         hidden_parameters = current_dict.pop(SN.HIDDEN_PARAMETERS)
         hidden_parameters_names = current_dict.pop(SN.HIDDEN_PARAMETER_NAMES, [])
         alpha_scale = current_dict.pop('analog_alpha_scale', None)  # legacy
@@ -208,6 +210,9 @@ def __setstate__(self, state: Dict) -> None:
         shared_weights = current_dict.pop(SN.SHARED_WEIGHTS)
         shared_weights_if = shared_weights is not None
+        mapping_scales = current_dict.pop(SN.MAPPING_SCALES, None)
+        learned_out_scales = current_dict.pop(SN.OUT_SCALING, None)
+
         current_dict.pop('noise_model', None)  # legacy
         current_dict.pop('drift_compensation', None)  # legacy
@@ -236,10 +241,13 @@ def __setstate__(self, state: Dict) -> None:
             # Check whether names match
             raise TileError('Mismatch with loaded analog state: '
                             'Hidden parameter structure is unexpected.')
+        if not isinstance(weights, Tensor):
+            weights = from_numpy(array(weights))
+        self.tile.set_weights(weights)
+
         if not isinstance(hidden_parameters, Tensor):
             hidden_parameters = from_numpy(array(hidden_parameters))
         self.tile.set_hidden_parameters(hidden_parameters)
-        self.tile.set_weights(weights)
 
         self.tile.set_learning_rate(analog_lr)
@@ -262,6 +270,32 @@ def __setstate__(self, state: Dict) -> None:
         self.analog_ctx.reset(self)
         self.analog_ctx.set_data(analog_ctx.data)
 
+        # set scales
+        self.out_scaling_alpha = None
+        self.mapping_scales = None
+        self.init_mapping_scales()
+        self.init_learned_out_scales()
+
+        if self.out_scaling_alpha is None and learned_out_scales is not None:
+            if mapping_scales is None:
+                mapping_scales = 1.0
+            x = learned_out_scales.view(learned_out_scales.numel()).clone()
+            mapping_scales = mapping_scales * x
+            learned_out_scales = None
+
+        self.set_mapping_scales(mapping_scales)
+        self.set_learned_out_scales(learned_out_scales)
+
+        if alpha_scale is not None:
+            # legacy. We apply the alpha scale instead of the
+            # out_scaling_alpha when loading. The alpha_scale
+            # mechanism is now replaced with the out scaling factors
+            #
+            # Caution: will overwrite the loaded out_scaling_alphas
+            # if they also exist (should not be the case for old checkpoints)
+
+            self.set_mapping_scales(alpha_scale)
+
         if to_device.type.startswith('cuda'):
             self.cuda(to_device)
@@ -592,7 +626,7 @@ def get_mapping_scales(self) -> Optional[Tensor]:
         return self.mapping_scales
 
     @no_grad()
-    def set_mapping_scales(self, mapping_scales: Optional[Tensor]) -> Tensor:
+    def set_mapping_scales(self, mapping_scales: Optional[Union[Tensor, float]]) -> None:
        """Set the scales used for the weight mapping.
Args: @@ -606,11 +640,20 @@ def set_mapping_scales(self, mapping_scales: Optional[Tensor]) -> Tensor: self.mapping_scales = None return - if isinstance(self.mapping_scales, Tensor) and len(mapping_scales) == 1: + if isinstance(mapping_scales, float): + if self.mapping_scales is None: + self.mapping_scales = ones((1, ), + dtype=float32, + device=self.device, + requires_grad=False) self.mapping_scales[:] = mapping_scales return - self.mapping_scales = mapping_scales.flatten() + if isinstance(self.mapping_scales, Tensor) and len(mapping_scales) == 1: + self.mapping_scales[:] = mapping_scales.to(self.device) + return + + self.mapping_scales = mapping_scales.flatten().to(self.device) @no_grad() def init_mapping_scales(self) -> None: @@ -690,13 +733,15 @@ def init_learned_out_scales(self) -> None: mapping = self.rpu_config.mapping # type: ignore if mapping.learn_out_scaling: if mapping.out_scaling_columnwise: - self.out_scaling_alpha = Parameter(ones((self.out_size, ), - dtype=float32, - device=self.device)) + self.out_scaling_alpha = ones((self.out_size, ), + dtype=float32, + device=self.device, + requires_grad=True) else: - self.out_scaling_alpha = Parameter(ones((1, ), - dtype=float32, - device=self.device)) + self.out_scaling_alpha = ones((1, ), + dtype=float32, + device=self.device, + requires_grad=True) @no_grad() def set_learned_out_scales(self, alpha: Union[Tensor, float]) -> None: @@ -739,7 +784,6 @@ def apply_out_scaling(self, values: Tensor, if tensor_view is None: tensor_view = self._get_tensor_view(values.dim(), 0 if self.out_trans else values.dim() - 1) - return values * self.out_scaling_alpha.view(*tensor_view) return values diff --git a/tests/helpers/infer_experiments.py b/tests/helpers/infer_experiments.py deleted file mode 100644 index 154607d8..00000000 --- a/tests/helpers/infer_experiments.py +++ /dev/null @@ -1,321 +0,0 @@ -# -*- coding: utf-8 -*- - -# (C) Copyright 2020, 2021, 2022 IBM. All Rights Reserved. -# -# This code is licensed under the Apache License, Version 2.0. You may -# obtain a copy of this license in the LICENSE.txt file in the root directory -# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. -# -# Any modifications or derivative works of this code must retain this -# copyright notice, and modified files need to carry a notice indicating -# that they have been altered from the originals. 
- -# pylint: disable=missing-function-docstring,too-few-public-methods - -"""Models helpers for aihwkit tests.""" - -from typing import Any - -from torch.nn import ( - BatchNorm2d, Conv2d, Flatten, Linear, LogSoftmax, MaxPool2d, Module, - ReLU, Tanh, NLLLoss -) -from torchvision.datasets import FashionMNIST, SVHN - -from aihwkit.nn import AnalogConv2d, AnalogLinear, AnalogSequential -from aihwkit.simulator.presets.utils import PresetIOParameters -from aihwkit.simulator.configs import InferenceRPUConfig -from aihwkit.simulator.configs.utils import WeightNoiseType, WeightClipType -from aihwkit.inference import PCMLikeNoiseModel -from aihwkit.experiments.experiments.inferencing import BasicInferencing - - -class HwTrainedLenet5: - """Hardware-aware LeNet5; with FashionMNIST.""" - - def get_reference_accuracy(self): - """ Reference for 100 samples without noise """ - return 86.0 - - def get_experiment( - self, - real: bool = False, - rpu_config: Any = InferenceRPUConfig() - ): - """Return a BasicInference experiment.""" - - rpu_config = InferenceRPUConfig(forward=PresetIOParameters()) - rpu_config.forward.w_noise_type = WeightNoiseType.ADDITIVE_CONSTANT - rpu_config.clip.type = WeightClipType.FIXED_VALUE - rpu_config.clip.fixed_value = 1.0 - rpu_config.forward.w_noise = 0.0175 - - argv = { - 'dataset': FashionMNIST, - 'model': self.get_model(rpu_config), - 'batch_size': 8, - 'loss_function': NLLLoss, - 'weight_template_id': 'hw-trained-lenet5', - 'inference_repeats': 10, - 'inference_time': 3600 - } - - if not real: - argv['inference_repeats'] = 2 - - return BasicInferencing(**argv) - - def get_model(self, rpu_config: Any = InferenceRPUConfig) -> Module: - # set the InferenceRPUConfig - channel = [16, 32, 512, 128] - return AnalogSequential( - AnalogConv2d(in_channels=1, out_channels=channel[0], kernel_size=5, stride=1, - rpu_config=rpu_config), - Tanh(), - MaxPool2d(kernel_size=2), - AnalogConv2d(in_channels=channel[0], out_channels=channel[1], kernel_size=5, stride=1, - rpu_config=rpu_config), - Tanh(), - MaxPool2d(kernel_size=2), - Tanh(), - Flatten(), - AnalogLinear(in_features=channel[2], out_features=channel[3], rpu_config=rpu_config), - Tanh(), - AnalogLinear(in_features=channel[3], out_features=10, rpu_config=rpu_config), - LogSoftmax(dim=1) - ) - - -class DigitalTrainedLenet5: - """Hardware-aware LeNet5; with FashionMNIST.""" - - def get_reference_accuracy(self): - """ Reference for 100 samples without noise """ - return 88.0 - - def get_experiment( - self, - real: bool = False, - rpu_config: Any = InferenceRPUConfig() - ): - """Return a BasicInference experiment.""" - - rpu_config = InferenceRPUConfig(forward=PresetIOParameters()) - rpu_config.forward.w_noise_type = WeightNoiseType.ADDITIVE_CONSTANT - rpu_config.clip.type = WeightClipType.FIXED_VALUE - rpu_config.clip.fixed_value = 1.0 - rpu_config.forward.w_noise = 0.0175 - - argv = { - 'dataset': FashionMNIST, - 'model': self.get_model(rpu_config), - 'batch_size': 8, - 'loss_function': NLLLoss, - 'weight_template_id': 'digital-trained-lenet5', - 'inference_repeats': 10, - 'inference_time': 3600 - } - - if not real: - argv['inference_repeats'] = 2 - - return BasicInferencing(**argv) - - def get_model(self, rpu_config: Any = InferenceRPUConfig) -> Module: - # set the InferenceRPUConfig - channel = [16, 32, 512, 128] - return AnalogSequential( - AnalogConv2d(in_channels=1, out_channels=channel[0], kernel_size=5, stride=1, - rpu_config=rpu_config), - Tanh(), - MaxPool2d(kernel_size=2), - AnalogConv2d(in_channels=channel[0], 
out_channels=channel[1], kernel_size=5, stride=1, - rpu_config=rpu_config), - Tanh(), - MaxPool2d(kernel_size=2), - Tanh(), - Flatten(), - AnalogLinear(in_features=channel[2], out_features=channel[3], rpu_config=rpu_config), - Tanh(), - AnalogLinear(in_features=channel[3], out_features=10, rpu_config=rpu_config), - LogSoftmax(dim=1) - ) - - -class HwTrainedVgg8: - """Vgg8; with SVHN.""" - - def get_reference_accuracy(self): - """ Reference for 100 samples without noise """ - return 93.0 - - def get_experiment( - self, - real: bool = False, - rpu_config: Any = InferenceRPUConfig() - ): - """Return a BasicInference experiment.""" - output_noise_strength = 0.03999999910593033 - adc = 9 - dac = 7 - programming_noise_scale = 1.0 - read_noise_scale = 1.0 - drift_scale = 1.0 - poly_first_order_coef = 1.965000033378601 - poly_second_order_coef = -1.1730999946594238 - poly_constant_coef = 0.26350000500679016 - rpu_config.forward.out_noise = output_noise_strength - rpu_config.forward.out_res = 1/(2**adc - 2) - rpu_config.forward.inp_res = 1/(2**dac - 2) - rpu_config.noise_model.prog_noise_scale = programming_noise_scale - rpu_config.noise_model.read_noise_scale = read_noise_scale - rpu_config.noise_model.drift_scale = drift_scale - rpu_config.noise_model = PCMLikeNoiseModel(g_max=25.0, - prog_coeff=[poly_constant_coef, - poly_first_order_coef, - poly_second_order_coef]) - - argv = { - 'dataset': SVHN, - 'model': self.get_model(rpu_config), - 'batch_size': 128, - 'loss_function': NLLLoss, - 'weight_template_id': 'hw-trained-vgg8', - 'inference_repeats': 10, - 'inference_time': 3600 - - } - - if not real: - argv['inference_repeats'] = 2 - - return BasicInferencing(**argv) - - def get_model(self, rpu_config: Any = InferenceRPUConfig) -> Module: - return AnalogSequential( - Conv2d(in_channels=3, out_channels=48, - kernel_size=3, stride=1, padding=1), - ReLU(), - AnalogConv2d(in_channels=48, out_channels=48, - kernel_size=3, stride=1, padding=1, - rpu_config=rpu_config), - BatchNorm2d(48), - ReLU(), - MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1), - AnalogConv2d(in_channels=48, out_channels=96, - kernel_size=3, stride=1, padding=1, - rpu_config=rpu_config), - ReLU(), - AnalogConv2d(in_channels=96, out_channels=96, - kernel_size=3, stride=1, padding=1, - rpu_config=rpu_config), - BatchNorm2d(96), - ReLU(), - MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1), - AnalogConv2d(in_channels=96, out_channels=144, - kernel_size=3, stride=1, padding=1, - rpu_config=rpu_config), - ReLU(), - AnalogConv2d(in_channels=144, out_channels=144, - kernel_size=3, stride=1, padding=1, - rpu_config=rpu_config), - BatchNorm2d(144), - ReLU(), - MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1), - Flatten(), - AnalogLinear(in_features=16 * 144, out_features=384, - rpu_config=rpu_config), - ReLU(), - Linear(in_features=384, out_features=10), - LogSoftmax(dim=1) - ) - - -class DigitalTrainedVgg8: - """Vgg8; with SVHN.""" - - def get_reference_accuracy(self): - """ Reference for 100 samples without noise """ - return 6.0 - - def get_experiment( - self, - real: bool = False, - rpu_config: Any = InferenceRPUConfig() - ): - """Return a BasicInference experiment.""" - - output_noise_strength = 0.03999999910593033 - adc = 9 - dac = 7 - programming_noise_scale = 1.0 - read_noise_scale = 1.0 - drift_scale = 1.0 - poly_first_order_coef = 1.965000033378601 - poly_second_order_coef = -1.1730999946594238 - poly_constant_coef = 0.26350000500679016 - rpu_config.forward.out_noise = output_noise_strength - 
rpu_config.forward.out_res = 1/(2**adc - 2) - rpu_config.forward.inp_res = 1/(2**dac - 2) - rpu_config.noise_model.prog_noise_scale = programming_noise_scale - rpu_config.noise_model.read_noise_scale = read_noise_scale - rpu_config.noise_model.drift_scale = drift_scale - rpu_config.noise_model = PCMLikeNoiseModel(g_max=25.0, - prog_coeff=[poly_constant_coef, - poly_first_order_coef, - poly_second_order_coef]) - - argv = { - 'dataset': SVHN, - 'model': self.get_model(rpu_config), - 'batch_size': 128, - 'loss_function': NLLLoss, - 'weight_template_id': 'digital-trained-vgg8', - 'inference_repeats': 10, - 'inference_time': 3600 - - } - - if not real: - argv['inference_repeats'] = 2 - - return BasicInferencing(**argv) - - def get_model(self, rpu_config: Any = InferenceRPUConfig) -> Module: - return AnalogSequential( - Conv2d(in_channels=3, out_channels=48, - kernel_size=3, stride=1, padding=1), - ReLU(), - AnalogConv2d(in_channels=48, out_channels=48, - kernel_size=3, stride=1, padding=1, - rpu_config=rpu_config), - BatchNorm2d(48), - ReLU(), - MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1), - AnalogConv2d(in_channels=48, out_channels=96, - kernel_size=3, stride=1, padding=1, - rpu_config=rpu_config), - ReLU(), - AnalogConv2d(in_channels=96, out_channels=96, - kernel_size=3, stride=1, padding=1, - rpu_config=rpu_config), - BatchNorm2d(96), - ReLU(), - MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1), - AnalogConv2d(in_channels=96, out_channels=144, - kernel_size=3, stride=1, padding=1, - rpu_config=rpu_config), - ReLU(), - AnalogConv2d(in_channels=144, out_channels=144, - kernel_size=3, stride=1, padding=1, - rpu_config=rpu_config), - BatchNorm2d(144), - ReLU(), - MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1), - Flatten(), - AnalogLinear(in_features=16 * 144, out_features=384, - rpu_config=rpu_config), - ReLU(), - Linear(in_features=384, out_features=10), - LogSoftmax(dim=1) - ) diff --git a/tests/helpers/tiles.py b/tests/helpers/tiles.py index 28083ae3..b27ed237 100644 --- a/tests/helpers/tiles.py +++ b/tests/helpers/tiles.py @@ -343,6 +343,25 @@ def get_tile(self, out_size, in_size, rpu_config=None, **kwargs): return InferenceTile(out_size, in_size, rpu_config, **kwargs) +class InferenceLearnOutScaling: + """Inference tile (perfect forward).""" + + simulator_tile_class = tiles.AnalogTile + first_hidden_field = None + use_cuda = False + + def get_rpu_config(self): + rpu_config = InferenceRPUConfig() + rpu_config.forward.is_perfect = True + rpu_config.mapping.learn_out_scaling = True + + return rpu_config + + def get_tile(self, out_size, in_size, rpu_config=None, **kwargs): + rpu_config = rpu_config or self.get_rpu_config() + return InferenceTile(out_size, in_size, rpu_config, **kwargs) + + class FloatingPointCuda: """FloatingPointTile.""" diff --git a/tests/test_cloud_runner.py b/tests/test_cloud_runner.py index f733fe23..5d8a4f58 100644 --- a/tests/test_cloud_runner.py +++ b/tests/test_cloud_runner.py @@ -18,6 +18,7 @@ from aihwkit.cloud.client.entities import CloudExperiment, CloudJobStatus from aihwkit.cloud.client.utils import ClientConfiguration from aihwkit.experiments.experiments.training import BasicTraining +from aihwkit.experiments.experiments.inferencing import BasicInferencing from aihwkit.experiments.runners.cloud import CloudRunner from .helpers.decorators import parametrize_over_experiments @@ -66,7 +67,10 @@ def test_get_cloud_experiment_experiment(self): raise SkipTest('No executions found') cloud_experiment = experiments[-1].get_experiment() - 
self.assertIsInstance(cloud_experiment, BasicTraining) + if 'BasicInferencing' in str(cloud_experiment): + self.assertIsInstance(cloud_experiment, BasicInferencing) + else: + self.assertIsInstance(cloud_experiment, BasicTraining) def test_get_cloud_experiment_result(self): """Test getting the result from a cloud experiment.""" diff --git a/tests/test_utils.py b/tests/test_utils.py index becb45c6..fbfeadc1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -43,13 +43,13 @@ LinearMappedCuda, Conv2dMapped, Conv2dMappedCuda ) from .helpers.testcases import ParametrizedTestCase, SKIP_CUDA_TESTS -from .helpers.tiles import FloatingPoint, ConstantStep, Inference +from .helpers.tiles import FloatingPoint, ConstantStep, Inference, InferenceLearnOutScaling @parametrize_over_layers( layers=[Linear, Conv2d, LinearMapped, LinearCuda, LinearMappedCuda, Conv2dCuda, Conv2dMapped, Conv2dMappedCuda], - tiles=[FloatingPoint, ConstantStep, Inference], + tiles=[FloatingPoint, ConstantStep, Inference, InferenceLearnOutScaling], biases=['analog', 'digital', None], ) class SerializationTest(ParametrizedTestCase): @@ -59,14 +59,13 @@ class SerializationTest(ParametrizedTestCase): def train_model(model, loss_func, x_b, y_b): """Train the model.""" opt = AnalogSGD(model.parameters(), lr=0.5) - opt.regroup_param_groups(model) - + # opt.regroup_param_groups(model) epochs = 3 for _ in range(epochs): opt.zero_grad() pred = model(x_b) - loss = loss_func(pred, y_b) + loss = loss_func(pred, y_b) loss.backward() opt.step() return loss @@ -90,16 +89,12 @@ def get_layer_and_tile_weights(model): else: bias = None - analog_weight, analog_bias = model.analog_tile.get_weights() - analog_weight = analog_weight.detach().cpu().numpy().reshape(weight.shape) - if model.analog_bias: + analog_weight, analog_bias = model.get_weights() + analog_weight = analog_weight.detach().cpu().numpy() + if analog_bias is not None: analog_bias = analog_bias.detach().cpu().numpy() - elif model.digital_bias: - analog_bias = model.bias.detach().cpu().numpy() - else: - analog_bias = None - return weight, bias, analog_weight, analog_bias, model.digital_bias + return weight, bias, analog_weight.reshape(weight.shape), analog_bias, True @staticmethod def get_analog_tile(model): @@ -194,6 +189,7 @@ def test_save_load_state_dict_train_after(self): assert_array_almost_equal(tile_biases, new_tile_biases) new_loss = self.train_model(new_model, loss_func, input_x, input_y) + loss = self.train_model(model, loss_func, input_x, input_y) if self.tile_class != ConstantStep: self.assertTensorAlmostEqual(loss, new_loss)
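
Illustrative note (not part of the diff): a minimal sketch of the checkpoint round-trip that the `learn_out_scales` fix above is meant to cover. It only uses API that appears in this patch (`InferenceRPUConfig` with `mapping.learn_out_scaling`, `AnalogLinear`, `AnalogSGD`, `state_dict`/`load_state_dict`); the layer sizes, data, and hyperparameters are arbitrary placeholders.

# Sketch: train a layer that learns output scales, checkpoint it, reload it,
# and keep training. With the change above, out_scaling_alpha is re-registered
# on load, so the optimizer keeps updating it after the checkpoint is restored.
from torch import randn
from torch.nn.functional import mse_loss

from aihwkit.nn import AnalogLinear
from aihwkit.optim import AnalogSGD
from aihwkit.simulator.configs import InferenceRPUConfig

rpu_config = InferenceRPUConfig()
rpu_config.mapping.learn_out_scaling = True

model = AnalogLinear(4, 2, rpu_config=rpu_config)
opt = AnalogSGD(model.parameters(), lr=0.1)

x, y = randn(8, 4), randn(8, 2)
for _ in range(3):
    opt.zero_grad()
    mse_loss(model(x), y).backward()
    opt.step()

checkpoint = model.state_dict()

# Fresh model with the same mapping settings; loading restores the learned
# out scales and keeps them registered as parameters.
new_model = AnalogLinear(4, 2, rpu_config=rpu_config)
new_model.load_state_dict(checkpoint)

new_opt = AnalogSGD(new_model.parameters(), lr=0.1)
for _ in range(3):  # training resumes without the previous registration issue
    new_opt.zero_grad()
    mse_loss(new_model(x), y).backward()
    new_opt.step()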