diff --git a/python/tvm/contrib/msc/core/runtime/runner.py b/python/tvm/contrib/msc/core/runtime/runner.py index dcf24225fe1a..5228b06b10d3 100644 --- a/python/tvm/contrib/msc/core/runtime/runner.py +++ b/python/tvm/contrib/msc/core/runtime/runner.py @@ -414,6 +414,14 @@ def apply_tool(self, tool_type: str, data_loader: Any = None) -> str: self.run(inputs, ret_type="native") break plan = pruner.finalize() + elif tool_type == ToolType.QUANTIZER: + quantizer = self.get_tool(ToolType.QUANTIZER) + while not quantizer.calibrated: + assert data_loader, "data_loader should be given to plan prune" + for inputs in data_loader(): + self.run(inputs, ret_type="native") + quantizer.calibrate() + plan = quantizer.finalize() else: plan = self.get_tool(tool_type).finalize() assert plan, "Failed to create plan for {}".format(tool_type) diff --git a/python/tvm/contrib/msc/core/tools/__init__.py b/python/tvm/contrib/msc/core/tools/__init__.py index 0524e4c82362..e97771cf6c06 100644 --- a/python/tvm/contrib/msc/core/tools/__init__.py +++ b/python/tvm/contrib/msc/core/tools/__init__.py @@ -19,4 +19,5 @@ from .tool import * from .execute import * from .prune import * +from .quantize import * from .track import * diff --git a/python/tvm/contrib/msc/core/tools/quantize/__init__.py b/python/tvm/contrib/msc/core/tools/quantize/__init__.py new file mode 100644 index 000000000000..1aad17c0553c --- /dev/null +++ b/python/tvm/contrib/msc/core/tools/quantize/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.tools.quantize""" + +from .quantizer import * +from .method import * diff --git a/python/tvm/contrib/msc/core/tools/quantize/method.py b/python/tvm/contrib/msc/core/tools/quantize/method.py new file mode 100644 index 000000000000..970185826711 --- /dev/null +++ b/python/tvm/contrib/msc/core/tools/quantize/method.py @@ -0,0 +1,472 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
# pylint: disable=unused-argument
"""tvm.contrib.msc.core.tools.quantize.method"""

from typing import Union, Any
import numpy as np

from tvm.contrib.msc.core.tools.tool import ToolType, BaseTool
from tvm.contrib.msc.core.utils.namespace import MSCFramework
from tvm.contrib.msc.core import utils as msc_utils


def _range_to_scale(valid_range: int, abs_max: float) -> float:
    """Get the scale that maps abs_max onto valid_range

    Parameters
    ----------
    valid_range: int
        The largest quantized magnitude.
    abs_max: float
        The largest observed magnitude.

    Returns
    -------
    scale: float
        valid_range / abs_max, or valid_range itself when abs_max is exactly 0.
    """

    # An all-zero tensor (or channel) quantizes to zeros with any finite scale,
    # so pick a finite one instead of raising ZeroDivisionError.
    if abs_max == 0:
        return float(valid_range)
    return valid_range / abs_max


class QuantizeMethod(object):
    """Default quantize method"""

    @classmethod
    def amplify_data(
        cls,
        data: np.ndarray,
        scale: Union[float, np.ndarray],
        min_val: float,
        max_val: float,
        rounding: str = "round",
    ) -> np.ndarray:
        """Amplify the data

        Parameters
        ----------
        data: np.ndarray
            The source data.
        scale: float or np.ndarray
            The scale factor, multiplied with data (broadcast when it is a tensor).
        min_val: float
            The lower clip bound.
        max_val: float
            The upper clip bound.
        rounding: str
            The round method, one of null| floor| ceil| round| trunc| logic_round.

        Returns
        -------
        data: np.ndarray
            The scaled, rounded and clipped data.

        Raises
        ------
        TypeError
            If the rounding is not supported.
        """

        simple_rounders = {
            "floor": np.floor,
            "ceil": np.ceil,
            "round": np.round,
            "trunc": np.trunc,
        }
        if rounding == "null":
            return np.clip(data * scale, min_val, max_val)
        if rounding in simple_rounders:
            return np.clip(simple_rounders[rounding](data * scale), min_val, max_val)
        if rounding == "logic_round":
            # fractional part >= 0.5 is ceiled, everything else is floored;
            # negative exact halves are handled separately and added back.
            data = np.clip(data * scale, min_val, max_val)
            negative_half = np.logical_and(data < 0, (data - np.floor(data)) == 0.5)
            negative_ceil = np.where(negative_half, np.ceil(data), 0)
            data = np.where(negative_half, 0, data)
            data = np.where((data - np.floor(data)) >= 0.5, np.ceil(data), data)
            data = np.where((data - np.floor(data)) < 0.5, np.floor(data), data)
            return data + negative_ceil
        raise TypeError("Unexpected rounding " + str(rounding))

    @classmethod
    def get_scale_tensor(
        cls,
        data: Any,
        scale: Union[float, list],
        axis: int = -1,
        epsilon: float = 1.0 / (1 << 24),
        expand_dims: bool = True,
    ) -> Union[float, np.ndarray]:
        """Get the scale tensor

        Parameters
        ----------
        data: array_like
            The source data, only used for its dtype and shape.
        scale: float or list
            The scale factor, a list means one scale per channel.
        axis: int
            The channel axis of the per-channel scale, negative values count from the end.
        epsilon: float
            Scales not larger than epsilon are replaced by 0.
        expand_dims: bool
            Whether to reshape the per-channel scale so that it broadcasts with data.

        Returns
        -------
        scale_tensor: float or np.ndarray
            The processed scale.
        """

        data = msc_utils.cast_array(data)
        if isinstance(scale, (list, tuple)):
            scale_tensor = np.array(scale).astype(data.dtype)
            if expand_dims:
                # normalize negative axis, otherwise no dim would ever match it
                if axis < 0:
                    axis += len(data.shape)
                scale_shape = [s if idx == axis else 1 for idx, s in enumerate(data.shape)]
                scale_tensor = scale_tensor.reshape(scale_shape)
            if scale_tensor.min() <= epsilon:
                scale_tensor[scale_tensor <= epsilon] = 0
        elif scale <= epsilon:
            scale_tensor = 0
        else:
            scale_tensor = scale
        return scale_tensor

    @classmethod
    def gather_maxmin(
        cls,
        quantizer: BaseTool,
        data: np.ndarray,
        name: str,
        consumer: str,
        plan: dict,
        nbits: int = 8,
    ) -> dict:
        """Gather the data by max/min

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        data: np.ndarray
            The source data.
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        plan: dict
            The pre-calibrated plan.
        nbits: int
            The number bits for quantize.

        Returns
        -------
        plan: dict
            The plan of the tensor, with the statistics of this data appended
            to abs_max_list/ max_list/ min_list.
        """

        plan = plan or {}
        # build new lists so the plan passed in is left untouched
        return {
            "abs_max_list": plan.get("abs_max_list", []) + [float(np.abs(data).max())],
            "max_list": plan.get("max_list", []) + [float(data.max())],
            "min_list": plan.get("min_list", []) + [float(data.min())],
            "calibrated": False,
        }

    @classmethod
    def gather_kl_divergence(
        cls,
        quantizer: BaseTool,
        data: np.ndarray,
        name: str,
        consumer: str,
        plan: dict,
        nbits: int = 8,
        bins: int = 4096,
    ) -> dict:
        """Gather the data by kl_divergence

        Without "abs_max" in the plan the max/min statistics are gathered,
        otherwise the histogram of data over [-abs_max, abs_max] is recorded.

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        data: np.ndarray
            The source data.
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        plan: dict
            The pre-calibrated plan.
        nbits: int
            The number bits for quantize.
        bins: int
            The number bins.

        Returns
        -------
        plan: dict
            The plan of the tensor.
        """

        if not plan or "abs_max" not in plan:
            # argument order must follow gather_maxmin(quantizer, data, name, consumer, plan)
            return cls.gather_maxmin(quantizer, data, name, consumer, plan, nbits)
        hist, edge = np.histogram(data, bins=bins, range=[-plan["abs_max"], plan["abs_max"]])
        hist_list = plan.get("hist_list", [])
        # plan is unpacked first so the fresh hist_list/ edge are not overwritten
        # by the ones recorded for previous batches
        return {**plan, "hist_list": hist_list + [hist], "edge": edge}

    @classmethod
    def gather_max_per_channel(
        cls,
        quantizer: BaseTool,
        data: np.ndarray,
        name: str,
        consumer: str,
        plan: dict,
        nbits: int = 8,
        channel: str = "O",
        auto_unsign: bool = False,
    ) -> dict:
        """Gather the data by max_per_channel

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        data: np.ndarray
            The source data.
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        plan: dict
            The pre-calibrated plan.
        nbits: int
            The number bits for quantize.
        channel: str
            The channel reference.
        auto_unsign: bool
            Whether to use auto unsign.

        Returns
        -------
        plan: dict
            The plan of the tensor, calibrated in a single pass.
        """

        weight = quantizer.find_tensor(name)
        axis = weight.layout_of(channel)
        # max of abs over every axis but the channel one, same as splitting per channel
        channel_axis = int(axis) % data.ndim
        reduce_axes = tuple(idx for idx in range(data.ndim) if idx != channel_axis)
        channel_max = np.abs(data).max(axis=reduce_axes).tolist()
        # plain bool keeps the plan serializable
        sign = bool(data.min() < 0) if auto_unsign else True
        valid_range = 2 ** (nbits - int(sign)) - 1
        scale = [_range_to_scale(valid_range, float(m)) for m in channel_max]
        return {"scale": scale, "sign": sign, "axis": axis, "calibrated": True}

    @classmethod
    def calibrate_maxmin(
        cls,
        quantizer: BaseTool,
        name: str,
        consumer: str,
        plan: dict,
        nbits: int = 8,
        auto_unsign: bool = False,
    ) -> dict:
        """Calibrate the data by max/min

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        plan: dict
            The pre-calibrated plan, as gathered by gather_maxmin.
        nbits: int
            The number bits for quantize.
        auto_unsign: bool
            Whether to use auto unsign.

        Returns
        -------
        plan: dict
            The plan of the tensor.
        """

        if auto_unsign:
            # gather_maxmin records "min_list", a reduced "min" is only given by other stages
            min_val = plan["min"] if "min" in plan else float(np.min(plan["min_list"]))
            sign = bool(min_val < 0)
        else:
            sign = True
        valid_range = 2 ** (nbits - int(sign)) - 1
        abs_max = float(np.array(plan["abs_max_list"]).max())
        return {"scale": _range_to_scale(valid_range, abs_max), "sign": sign, "calibrated": True}

    @classmethod
    def calibrate_kl_divergence(
        cls,
        quantizer: BaseTool,
        name: str,
        consumer: str,
        plan: dict,
        nbits: int = 8,
        bins: int = 4096,
        auto_unsign: bool = False,
    ) -> dict:
        """Calibrate the data by kl_divergence

        The first call reduces the gathered max/min lists (not calibrated yet),
        the next call finds the scale from the gathered histograms.

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        plan: dict
            The pre-calibrated plan.
        nbits: int
            The number bits for quantize.
        bins: int
            The number bins.
        auto_unsign: bool
            Whether to use auto unsign.

        Returns
        -------
        plan: dict
            The plan of the tensor.
        """

        if plan and "abs_max_list" in plan:
            return {
                "abs_max": float(np.array(plan["abs_max_list"]).max()),
                "max": float(np.array(plan["max_list"]).max()),
                "min": float(np.array(plan["min_list"]).min()),
                "calibrated": False,
            }

        # pylint: disable=import-outside-toplevel
        import ctypes
        from tvm.relay import quantize as _quantize

        def get_pointer(arr, ctypes_type):
            ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes_type))
            return ctypes.cast(ptr, ctypes.c_void_p)

        sign = bool(plan["min"] < 0) if auto_unsign else True
        # keep the converted arrays in locals so the buffers stay alive during the call
        hist_arr = np.array(plan["hist_list"]).sum(axis=0).astype(np.int64)
        edge_arr = np.asarray(plan["edge"]).astype(np.float32)
        hist_ptr = get_pointer(hist_arr, ctypes.c_int64)
        edge_ptr = get_pointer(edge_arr, ctypes.c_float)
        valid_range = 2 ** (nbits - int(sign)) - 1
        threshold = _quantize._quantize.FindScaleByKLMinimization(
            hist_ptr, edge_ptr, bins, valid_range
        )
        return {
            "scale": _range_to_scale(valid_range, threshold),
            "sign": sign,
            "calibrated": True,
        }

    @classmethod
    def quantize_normal(
        cls,
        quantizer: BaseTool,
        data: np.ndarray,
        name: str,
        consumer: str,
        scale: Union[float, list],
        nbits: int = 8,
        axis: int = -1,
        sign: bool = True,
        rounding: str = "round",
        epsilon: float = 1.0 / (1 << 24),
    ) -> np.ndarray:
        """Fake quantize the data: scale, round, clip, then scale back

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        data: np.ndarray
            The source data.
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        scale: float or list
            The scale factor, a list means one scale per channel.
        nbits: int
            The number bits for quantize.
        axis: int
            The axis.
        sign: bool
            Whether to use sign.
        rounding: str
            The rounding method.
        epsilon: float
            The epsilon for get scale.

        Returns
        -------
        data: array like
            The processed tensor.
        """

        valid_range = 2 ** (nbits - int(sign)) - 1
        min_val = -valid_range if sign else 0
        scale_tensor = quantizer._get_tensor_cache(name, consumer, "scale_tensor")
        if scale_tensor is None:
            scale_tensor = cls.get_scale_tensor(data, scale, axis, epsilon)
            quantizer._save_tensor_cache(name, consumer, "scale_tensor", scale_tensor)
        data = cls.amplify_data(data, scale_tensor, min_val, valid_range, rounding)
        # Scale back with the tensor used to amplify, so per-channel scales broadcast
        # on the channel axis. Entries with a 0 scale were amplified to 0 and stay 0.
        if isinstance(scale_tensor, np.ndarray):
            divisor = np.where(scale_tensor == 0, np.ones_like(scale_tensor), scale_tensor)
            return data / divisor
        if scale_tensor == 0:
            return data
        return data / scale_tensor

    @classmethod
    def dequantize_normal(
        cls,
        quantizer: BaseTool,
        data: np.ndarray,
        name: str,
        consumer: str,
        scale: float = -1.0,
        nbits: int = 8,
        axis: int = -1,
        sign: bool = True,
        rounding: str = "round",
        epsilon: float = 1.0 / (1 << 24),
    ) -> np.ndarray:
        """Dequantize the data, the default method returns the data unchanged

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        data: np.ndarray
            The source data.
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        scale: float
            The scale factor
        nbits: int
            The number bits for quantize.
        axis: int
            The axis.
        sign: bool
            Whether to use sign.
        rounding: str
            The rounding method.
        epsilon: float
            The epsilon for get scale.

        Returns
        -------
        data: array like
            The processed tensor.
        """

        return data

    @classmethod
    def framework(cls):
        """Get the framework of the method, MSCFramework.MSC"""

        return MSCFramework.MSC

    @classmethod
    def tool_type(cls):
        """Get the tool type of the method, ToolType.QUANTIZER"""

        return ToolType.QUANTIZER


msc_utils.register_tool_method(QuantizeMethod)

# Patch residue that shares the last collapsed source line with this module:
# the header of the next file in the diff and the start of its license.
# diff --git a/python/tvm/contrib/msc/core/tools/quantize/quantizer.py b/python/tvm/contrib/msc/core/tools/quantize/quantizer.py
# new file mode 100644
# index 000000000000..bee8e6fa42eb
# --- /dev/null
# +++ b/python/tvm/contrib/msc/core/tools/quantize/quantizer.py
# @@ -0,0 +1,249 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.tools.quantize.quantizer""" + +from typing import List, Dict, Any + +from tvm.contrib.msc.core.tools.tool import ToolType, BaseTool, ToolStrategy +from tvm.contrib.msc.core import utils as msc_utils + + +class QuantizeStage: + GATHER = "gather" + CALIBRATE = "calibrate" + + +class BaseQuantizer(BaseTool): + """Base quantizer for all""" + + def setup(self) -> dict: + """Setup the tool + + Returns + ------- + info: dict + The setup info. + """ + + if self._plan: + self._calibrated = True + self.change_stage(msc_utils.MSCStage.QUANTIZE) + else: + self._calibrated = False + self._calibrate_plan = {} + self.change_stage(QuantizeStage.GATHER) + return super().setup() + + def calibrate(self) -> dict: + """Calibrate the datas + + Returns + ------- + plan: dict + The calibrated plan. 
+ """ + + new_plan = {} + self.change_stage(QuantizeStage.CALIBRATE) + for tensor_id, plan in self._calibrate_plan.items(): + if plan.get("calibrated", False): + new_plan[tensor_id] = plan + continue + name, consumer = self.from_tensor_id(tensor_id) + strategy = self._get_tensor_strategy(name, consumer) + new_plan[tensor_id] = strategy(self, name, consumer, plan) + if any(not plan.get("calibrated", False) for plan in new_plan.values()): + self._calibrate_plan = new_plan + self.change_stage(QuantizeStage.GATHER) + else: + self._calibrated = True + for name, plan in new_plan.items(): + self._plan[name] = {k: v for k, v in plan.items() if k not in ("calibrated")} + self.change_stage(msc_utils.MSCStage.QUANTIZE) + self._forward_cnt = 0 + return new_plan + + def _parse_strategys(self, strategy_list: dict) -> Dict[str, ToolStrategy]: + """Parse the strategy to get valid strategy + + Parameters + ------- + strategy_list: dict + The given strategy + + Returns + ------- + strategys: dict + The parsed strategy. + """ + + def _update_stages(strategy): + if "stages" not in strategy: + strategy["stages"] = [msc_utils.MSCStage.QUANTIZE] + return strategy + + return super()._parse_strategys([_update_stages(s) for s in strategy_list]) + + def _check_tensor(self, name: str, consumer: str) -> bool: + """Check if the tensor should be processed + + Parameters + ------- + name: str + The name of the tensor. + consumer: str + The name of the consumer. + + Returns + ------- + vaild: bool + Whether to process the tensor. + """ + + strategys = self._get_tensor_strategys(name, consumer) + if not strategys: + return False + if any(s.get_config().get("nbits", 8) == -1 for s in strategys): + return False + return True + + def _process_tensor( + self, tensor: Any, name: str, consumer: str, scope: str, strategys: List[ToolStrategy] + ) -> Any: + """Process tensor + + Parameters + ------- + tensor: Any + Tensor in framework + name: str + The name of the tensor. 
+ consumer: str + The name of the consumer. + scope: str + The scope mark teacher| student| null. + strategys: list + The strategys for the tensor. + + Returns + ------- + tensor: Any + The processed tensor. + """ + + if not self._calibrated: + return self._gather_tensor(tensor, name, consumer, strategys) + return self._quantize_tensor(tensor, name, consumer, strategys) + + def _gather_tensor( + self, tensor: Any, name: str, consumer: str, strategys: List[ToolStrategy] + ) -> Any: + """Gather tensor datas + + Parameters + ------- + tensor: Any + Tensor in framework + name: str + The name of the tensor. + consumer: str + The name of the consumer. + strategys: list + The strategys for the tensor. + + Returns + ------- + tensor: Any + The processed tensor. + """ + + assert len(strategys) == 1, "gather should only has 1 strategy, get " + str(strategys) + tensor_id = self.to_tensor_id(name, consumer) + plan = self._calibrate_plan.get(tensor_id, {}) + if plan.get("calibrated", False): + return tensor + self._calibrate_plan[tensor_id] = strategys[0](self, tensor, name, consumer, plan) + return tensor + + def _quantize_tensor( + self, tensor: Any, name: str, consumer: str, strategys: List[ToolStrategy] + ) -> Any: + """Quantize tensor + + Parameters + ------- + tensor: Any + Tensor in framework + name: str + The name of the tensor. + consumer: str + The name of the consumer. + strategys: list + The strategys for the tensor. + + Returns + ------- + tensor: Any + The processed tensor. + """ + + tensor_id = self.to_tensor_id(name, consumer) + for strategy in strategys: + tensor = strategy(self, tensor, name, consumer, **self._plan[tensor_id]) + return tensor + + def create_tasks(self, **kwargs) -> List[dict]: + """Create tasks for gym + + Parameters + ---------- + kwargs: dict + The kwargs for create tasks. + + Returns + ------- + tasks: list + The tasks. 
+ """ + + tasks, recorded = [], set() + for tensor_id, plan in self._plan.items(): + name, _ = self.from_tensor_id(tensor_id) + if self.is_weight(name) and not kwargs.get("quantize_weights", False): + continue + if name not in recorded: + tasks.append({"name": tensor_id, **plan}) + if self._cache_processed: + recorded.add(name) + return tasks + + @property + def calibrated(self): + return self._calibrated + + @classmethod + def tool_type(cls): + return ToolType.QUANTIZER + + +class DefaultQuantizer(BaseQuantizer): + @classmethod + def tool_style(cls): + return "default" + + +msc_utils.register_tool_cls(DefaultQuantizer) diff --git a/python/tvm/contrib/msc/core/utils/file.py b/python/tvm/contrib/msc/core/utils/file.py index 146cfaf50434..446efd4724f3 100644 --- a/python/tvm/contrib/msc/core/utils/file.py +++ b/python/tvm/contrib/msc/core/utils/file.py @@ -72,6 +72,8 @@ def __str__(self): return "{}(Cleanup: {}): {} Files".format(self._path, self._cleanup, len(self.listdir())) def __enter__(self): + if not os.path.isdir(self._path): + os.mkdir(self._path) os.chdir(self._path) return self @@ -105,6 +107,9 @@ def add_file(self, name: str, contains: str) -> str: """ file_path = self.relpath(name) + base_dir = os.path.dirname(name) + if base_dir and not os.path.isdir(base_dir): + os.makedirs(base_dir) with open(file_path, "w") as f: f.write(contains) return file_path diff --git a/python/tvm/contrib/msc/core/utils/info.py b/python/tvm/contrib/msc/core/utils/info.py index 440789f8568c..d1b5cd1a2644 100644 --- a/python/tvm/contrib/msc/core/utils/info.py +++ b/python/tvm/contrib/msc/core/utils/info.py @@ -234,6 +234,10 @@ def load_dict(str_dict: str, flavor: str = "json") -> dict: dict_obj = json.load(f) elif isinstance(str_dict, str): dict_obj = json.loads(str_dict) + elif isinstance(str_dict, dict): + dict_obj = copy_dict(str_dict) + else: + raise Exception("Unexpected str_dict {}({})".format(str_dict, type(str_dict))) assert flavor == "json", "Unexpected flavor for 
load_dict: " + str(flavor) return dict_obj diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py b/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py index d25cfd4e674c..e5ebe509564c 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py +++ b/python/tvm/contrib/msc/framework/tensorflow/tools/__init__.py @@ -17,4 +17,5 @@ """tvm.contrib.msc.framework.tensorflow.tools""" from .prune import * +from .quantize import * from .track import * diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/__init__.py b/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/__init__.py new file mode 100644 index 000000000000..ed458ef8381d --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""tvm.contrib.msc.framework.tensorflow.tools.quantize""" + +from .quantizer import * diff --git a/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/quantizer.py b/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/quantizer.py new file mode 100644 index 000000000000..dd6f2aac38d2 --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorflow/tools/quantize/quantizer.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tensorflow.tools.quantize.quantizer""" + +from tvm.contrib.msc.core.tools.tool import ToolType +from tvm.contrib.msc.core.tools.quantize import BaseQuantizer +from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils + + +class TensorflowQuantizerFactory(object): + """Quantizer factory for tensorflow""" + + def create(self, base_cls: BaseQuantizer) -> BaseQuantizer: + """Create adaptive quantizer + + Parameters + ---------- + base_cls: BaseQuantizer + The base quantizer class + + Returns + ------- + quantizer_cls: BaseQuantizer + The quantizer class. 
+ """ + + class Quantizer(base_cls): + """Adaptive quantizer for tensorflow""" + + @classmethod + def framework(cls): + return MSCFramework.TENSORFLOW + + return Quantizer + + +factory = TensorflowQuantizerFactory() +tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all") +for tool in tools.values(): + msc_utils.register_tool_cls(factory.create(tool)) diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py b/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py index b6497e9258b7..a5df42f78b17 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py +++ b/python/tvm/contrib/msc/framework/tensorrt/codegen/sources.py @@ -302,6 +302,171 @@ def get_trt_common_cc_code() -> str: """ +def get_trt_quantize_h_code(): + """Create trt_quantize header file codes + + Returns + ------- + source: str + The trt_quantize header source. + """ + + return """#ifndef TVM_CONTRIB_MSC_UTILS_TRT_QUANTIZE_H_ +#define TVM_CONTRIB_MSC_UTILS_TRT_QUANTIZE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "NvInfer.h" +#include "base.h" +#include "trt_common.h" + +namespace tvm { +namespace contrib { +namespace msc { + +using namespace nvinfer1; + +class CalibrateHelper { + public: + CalibrateHelper(const std::string& range_file, const std::string& folder, int max_size = -1); + + ~CalibrateHelper() { + for (const auto& buffer : cpu_buffers_) { + free(buffer); + } + for (const auto& buffer : gpu_buffers_) { + CHECK(cudaFree(buffer)); + } + } + + bool GetBatch(void* bindings[], const char* names[], int nbBindings); + + const void* ReadCache(size_t& length); + + void WriteCache(const void* cache, size_t length); + + private: + std::unique_ptr reader_; + std::string range_file_; + std::vector cache_; + std::vector cpu_buffers_; + std::vector gpu_buffers_; +}; + +#define CALIBRATE_MEMBERS(Calibrator) \\ + public: \\ + Calibrator(const std::string& range_file, const std::string& 
folder, int max_size = -1) { \\ + helper_.reset(new CalibrateHelper(range_file, folder, max_size)); \\ + } \\ + \\ + virtual ~Calibrator() {} \\ + \\ + int getBatchSize() const noexcept override { return 1; } \\ + \\ + bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override { \\ + return helper_->GetBatch(bindings, names, nbBindings); \\ + } \\ + \\ + const void* readCalibrationCache(size_t& length) noexcept override { \\ + return helper_->ReadCache(length); \\ + } \\ + \\ + void writeCalibrationCache(const void* cache, size_t length) noexcept override { \\ + return helper_->WriteCache(cache, length); \\ + } \\ + \\ + private: \\ + std::unique_ptr helper_; + +class MSCInt8EntropyCalibrator : public IInt8EntropyCalibrator { + CALIBRATE_MEMBERS(MSCInt8EntropyCalibrator) +}; + +class MSCInt8EntropyCalibrator2 : public IInt8EntropyCalibrator2 { + CALIBRATE_MEMBERS(MSCInt8EntropyCalibrator2) +}; + +} // namespace msc +} // namespace contrib +} // namespace tvm + +#endif // TVM_CONTRIB_MSC_UTILS_TRT_QUANTIZE_H_ +""" + + +def get_trt_quantize_cc_code(): + """Create trt_quantize cc file codes + + Returns + ------- + source: str + The trt_quantize cc source. 
+ """ + + return """#include "trt_quantize.h" + +namespace tvm { +namespace contrib { +namespace msc { + +using namespace nvinfer1; + +CalibrateHelper::CalibrateHelper(const std::string& range_file, const std::string& folder, + int max_size) { + range_file_ = range_file; + reader_.reset(new DatasetReader(folder, max_size)); + const auto& tensor_names = reader_->GetTensorNames(); + cpu_buffers_.resize(tensor_names.size()); + gpu_buffers_.resize(tensor_names.size()); + for (size_t i = 0; i < tensor_names.size(); i++) { + size_t tensor_size = reader_->GetTensorSize(tensor_names[i]); + cpu_buffers_[i] = malloc(tensor_size); + CHECK(cudaMalloc(&gpu_buffers_[i], tensor_size)); + } +} + +bool CalibrateHelper::GetBatch(void* bindings[], const char* names[], int nbBindings) { + if (!reader_->ReadNext(cpu_buffers_.data())) { + return false; + } + for (size_t i = 0; i < nbBindings; i++) { + CHECK(cudaMemcpy(gpu_buffers_[i], cpu_buffers_[i], reader_->GetTensorSize(names[i]), + cudaMemcpyHostToDevice)); + bindings[i] = gpu_buffers_[i]; + } + return true; +} + +const void* CalibrateHelper::ReadCache(size_t& length) { + cache_.clear(); + std::ifstream in_file(range_file_, std::ifstream::binary); + if (!in_file.is_open()) { + return nullptr; + } + in_file >> std::noskipws; + std::copy(std::istream_iterator(in_file), std::istream_iterator(), + std::back_inserter(cache_)); + length = cache_.size(); + return length > 0 ? 
&cache_[0] : nullptr; +} + +void CalibrateHelper::WriteCache(const void* cache, size_t length) { + std::ofstream output(range_file_, std::ios::binary); + output.write(reinterpret_cast(cache), length); +} + +} // namespace msc +} // namespace contrib +} // namespace tvm +""" + + def get_trt_sources() -> Dict[str, str]: """Create trt sources for cpp codegen @@ -313,6 +478,11 @@ def get_trt_sources() -> Dict[str, str]: sources = get_base_sources() sources.update( - {"trt_common.h": get_trt_common_h_code(), "trt_common.cc": get_trt_common_cc_code()} + { + "trt_common.h": get_trt_common_h_code(), + "trt_common.cc": get_trt_common_cc_code(), + "trt_quantize.h": get_trt_quantize_h_code(), + "trt_quantize.cc": get_trt_quantize_cc_code(), + } ) return sources diff --git a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py index 15a42b2cf967..c66f8d145035 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py @@ -17,9 +17,14 @@ # pylint: disable=unused-import """tvm.contrib.msc.framework.tensorrt.runtime.runner""" +from typing import Any, List, Dict + import tvm +from tvm.contrib.msc.core.ir import MSCGraph from tvm.contrib.msc.core.runtime import BYOCRunner +from tvm.contrib.msc.core.tools import ToolType from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils from tvm.contrib.msc.framework.tensorrt.frontend import ( partition_for_tensorrt, transform_for_tensorrt, @@ -44,6 +49,54 @@ def setup(self) -> dict: self._device = "cuda" return super().setup() + def apply_tool(self, tool_type: str, data_loader: Any = None) -> dict: + """Execute tool and get plan + + Parameters + ------- + tool_type: str + The tool type, should be in ToolType + data_loader: + The data loader + """ + + assert tool_type in self._tools, "Can not find tool " + str(tool_type) + if tool_type == 
ToolType.QUANTIZER: + quantizer = self.get_tool(ToolType.QUANTIZER) + assert data_loader, "data_loader should be given to plan prune" + for inputs in data_loader(): + self.run(inputs) + self._generate_model() + quantizer.calibrate() + assert quantizer.calibrated, "Failed to calibrate the tenosrrt quantizer" + return super().apply_tool(tool_type, data_loader) + + def _generate_model( + self, graphs: List[MSCGraph] = None, weights: List[Dict[str, tvm.nd.array]] = None + ) -> Any: + """Codegen the model according to framework + + Parameters + ------- + graphs: list + The msc graphs. + weights: list> + The weights + + Returns + ------- + model: Any + The meta model + """ + + codegen = self._generate_config.get("codegen") + if not isinstance(codegen, (list, tuple)): + self._generate_config["codegen"] = [msc_utils.copy_dict(codegen)] * len(self._graphs) + for tool in self.get_tools(): + self._generate_config = tool.config_generate(self._generate_config) + + return super()._generate_model(graphs, weights) + @classmethod def target_transform(cls, mod: tvm.IRModule): """Transform the mod by target. diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py b/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py index ecc82bc40f5e..c010a42004fd 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py +++ b/python/tvm/contrib/msc/framework/tensorrt/tools/__init__.py @@ -17,4 +17,5 @@ """tvm.contrib.msc.framework.tensorrt.tools""" from .prune import * +from .quantize import * from .track import * diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/__init__.py b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/__init__.py new file mode 100644 index 000000000000..e47b3324c9ee --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tensorrt.tools.quantize""" + +from .quantizer import * +from .method import * diff --git a/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/method.py b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/method.py new file mode 100644 index 000000000000..0feb836d1350 --- /dev/null +++ b/python/tvm/contrib/msc/framework/tensorrt/tools/quantize/method.py @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
class TensorRTQuantizeMethod(QuantizeMethod):
    """Default quantize method for tensorrt"""

    @classmethod
    def quantize_normal(
        cls,
        quantizer: BaseQuantizer,
        tensor_ctx: Dict[str, str],
        name: str,
        consumer: str,
        scale: float,
        nbits: int = 8,
        axis: int = -1,
        sign: bool = True,
        rounding: str = "round",
        epsilon: float = 1.0 / (1 << 24),
    ) -> Dict[str, str]:
        """Append precision and dynamic-range statements for a tensor

        Weights are returned untouched. For other tensors two statements are
        appended to tensor_ctx["processed"]: a setPrecision call on the
        producer and a symmetric setDynamicRange(-scale, scale) on the tensor.

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        tensor_ctx: dict
            Tensor describe items.
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        scale: float
            The scale factor
        nbits: int
            The number bits for quantize.
        axis: int
            The axis. Unused here.
        sign: bool
            Whether to use sign. Unused here.
        rounding str
            The rounding method. Unused here.
        epsilon: float
            The epsilon for get scale. Unused here.

        Returns
        -------
        tensor_ctx: dict
            Tensor describe items.

        Raises
        ------
        TypeError
            If nbits is not 8 and the tensor dtype is neither float16 nor float32.
        """

        if quantizer.is_weight(name):
            return tensor_ctx
        dtype = quantizer.find_tensor(name).dtype_name
        if nbits == 8:
            precision = "DataType::kINT8"
        elif dtype == "float16":
            precision = "DataType::kHALF"
        elif dtype == "float32":
            precision = "DataType::kFLOAT"
        else:
            # Name both values: this branch is reached because of the dtype.
            raise TypeError("nbits {} with dtype {} is not supported".format(nbits, dtype))
        tensor_ctx["processed"].extend(
            [
                "{}->setPrecision({})".format(tensor_ctx["producer"], precision),
                "{0}->setDynamicRange(-{1}, {1})".format(tensor_ctx["tensor"], scale),
            ]
        )
        return tensor_ctx

    @classmethod
    def dequantize_normal(
        cls,
        quantizer: BaseQuantizer,
        tensor_ctx: Dict[str, str],
        name: str,
        consumer: str,
        scale: float,
        nbits: int = 8,
        axis: int = -1,
        sign: bool = True,
        rounding: str = "round",
        epsilon: float = 1.0 / (1 << 24),
    ) -> Dict[str, str]:
        """Dequantize counterpart, delegates to quantize_normal

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        tensor_ctx: dict
            Tensor describe items.
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        scale: float
            The scale factor
        nbits: int
            The number bits for quantize.
        axis: int
            The axis.
        sign: bool
            Whether to use sign.
        rounding str
            The rounding method.
        epsilon: float
            The epsilon for get scale.

        Returns
        -------
        tensor_ctx: dict
            Tensor describe items.
        """

        return cls.quantize_normal(
            quantizer, tensor_ctx, name, consumer, scale, nbits, axis, sign, rounding, epsilon
        )

    @classmethod
    def framework(cls):
        return MSCFramework.TENSORRT


msc_utils.register_tool_method(TensorRTQuantizeMethod)
def setup(self) -> dict:
    """Setup the tool

    Range files are used unless a plan exists in which some entry does not
    enable "use_range".

    Returns
    -------
    info: dict
        The setup info.
    """

    if self._plan:
        self._use_range = all(info.get("use_range", False) for info in self._plan.values())
    else:
        self._use_range = True
    return super().setup()


def _reset(
    self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]]
) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]:
    """Reset the tool

    Parameters
    ----------
    graphs: list<MSCGraph>
        The msc graphs.
    weights: list<dict<str, tvm.nd.array>>
        The weights

    Returns
    -------
    graphs: list<MSCGraph>
        The msc graphs.
    weights: list<dict<str, tvm.nd.array>>
        The weights
    """

    config_folder = msc_utils.get_config_dir()
    self._range_files = [config_folder.relpath(g.name + ".range") for g in graphs]
    calibrate_root = msc_utils.get_dataset_dir().create_dir("Calibrate")
    self._calibrate_folders = [calibrate_root.relpath(g.name) for g in graphs]
    if self._calibrated:
        if self._use_range:
            for r_file, graph in zip(self._range_files, graphs):
                # Regenerate a missing range file from the plan.
                if not os.path.isfile(r_file):
                    self._plan_to_range(graph, r_file)
                self._logger.debug(
                    "G[%s](%s) use range file: %s",
                    graph.name,
                    self._stage,
                    r_file,
                )
        else:
            self._quantized_tensors = set()
    elif self._stage == QuantizeStage.GATHER:
        # One saver per graph collects the inputs seen while gathering.
        self._calibrate_savers = []
        for folder, graph in zip(self._calibrate_folders, graphs):
            saver_options = {"input_names": [i.name for i in graph.get_inputs()]}
            saver = msc_utils.IODataSaver(folder, saver_options)
            self._calibrate_savers.append(saver)
            self._logger.debug(
                "G[%s](%s) create calibrate saver: %s",
                graph.name,
                self._stage,
                saver,
            )
    else:
        assert all(
            msc_utils.is_io_dataset(f) for f in self._calibrate_folders
        ), "Some IODataset missing: " + str(self._calibrate_folders)
    return super()._reset(graphs, weights)


def _execute_after_build(self, codegen_context: dict) -> dict:
    """Execute after model build

    Appends the statements that attach an int8 calibrator to
    codegen_context["processed"], unless nothing was gathered yet or range
    files are not used.

    Parameters
    ----------
    codegen_context: dict
        The context.

    Returns
    ----------
    codegen_context: dict
        The processed context.
    """

    if self._stage == QuantizeStage.GATHER and self._forward_cnt == 0:
        return codegen_context
    if not self._use_range:
        return codegen_context
    processed = ["// Set int8 calibrator"]
    range_file = self.get_graph().name + ".range"
    version = [int(v) for v in codegen_context["version"].split(".")]
    # From version 6.0.0 on the calibrator is set on the config object,
    # before that on the builder.
    if msc_utils.compare_version(version, [6, 0, 0]) >= 0:
        configer = codegen_context["config"]
    else:
        configer = codegen_context["builder"]
    # check the range file if calibrated
    if self._calibrated:
        processed.extend(
            [
                'if (!FileUtils::FileExist("{}")) {{'.format(range_file),
                ' logger.log(ILogger::Severity::kERROR, "{} not exist!");'.format(range_file),
                " return -1;",
                "}",
            ]
        )
    processed.extend(
        [
            'MSCInt8EntropyCalibrator2 calibrator("{}", "{}");'.format(
                range_file, self._calibrate_folders[self._graph_id]
            ),
            "{}->setInt8Calibrator(&calibrator);".format(configer),
        ]
    )
    codegen_context["processed"].extend(processed)
    return codegen_context


def _execute_before_forward(self, step_context: dict) -> dict:
    """Execute before model forward

    While gathering, the step datas are saved as one calibration batch.

    Parameters
    ----------
    step_context: dict
        The context.

    Returns
    ----------
    step_context: dict
        The processed context.
    """

    if self._stage == QuantizeStage.GATHER:
        saver = self._calibrate_savers[self._graph_id]
        saver.save_batch({name: data.asnumpy() for name, data in step_context["datas"].items()})
        for name, data in step_context["datas"].items():
            self.debug_tensor(data, name, "any", "ctx_gathered")
    # Hand back the base result; dropping it made this method return None
    # although it is declared to return the processed context.
    return super()._execute_before_forward(step_context)


def _quantize_tensor(
    self,
    tensor_ctx: Dict[str, str],
    name: str,
    consumer: str,
    strategys: List[ToolStrategy],
) -> Dict[str, str]:
    """Quantize tensor

    A tensor is only processed when range files are not used, and then only
    the first time its name is seen.

    Parameters
    -------
    tensor_ctx: dict
        Tensor describe items.
    name: str
        The name of the tensor.
    consumer: str
        The name of the consumer.
    strategys: list<ToolStrategy>
        The strategys for the tensor.

    Returns
    -------
    tensor_ctx: dict
        Tensor items with processed.
    """

    if not self._use_range and name not in self._quantized_tensors:
        self._quantized_tensors.add(name)
        return super()._quantize_tensor(tensor_ctx, name, consumer, strategys)
    return tensor_ctx


def calibrate(self) -> dict:
    """Calibrate the datas

    Loads every graph's range file into the plan, marks the tool as
    calibrated and switches to the "quantize" stage.

    Returns
    -------
    plan: dict
        The calibrated plan.
    """

    for r_file, graph in zip(self._range_files, self._graphs):
        self._range_to_plan(graph, r_file)
    self._calibrated, self._forward_cnt = True, 0
    self.change_stage("quantize")
    return self._plan


def config_generate(self, generate_config: Dict[str, Any]) -> Dict[str, Any]:
    """Update the generate configs

    Parameters
    ----------
    generate_config: dict
        The generate_config. Its "codegen" entry is iterated per graph.

    Returns
    -------
    generate_config: dict
        The updated generate_config.
    """

    if self._calibrated:
        if self._use_range:
            for config, r_file in zip(generate_config["codegen"], self._range_files):
                if os.path.isfile(r_file):
                    config.update({"range_file": r_file, "precision": "int8"})
    elif self._stage == QuantizeStage.GATHER and self._forward_cnt > 0:
        for config, saver, r_file in zip(
            generate_config["codegen"], self._calibrate_savers, self._range_files
        ):
            saver.finalize()
            self._logger.debug(
                "%ssave %d datas to %s",
                self.msg_mark(in_forward=False),
                self._forward_cnt,
                saver.folder,
            )
            config.update({"dataset": saver.folder, "range_file": r_file, "precision": "int8"})
        # Gathered data is finalized, the next build calibrates with it.
        self.change_stage(QuantizeStage.CALIBRATE)
    return generate_config
@classmethod
def gather_maxmin(
    cls,
    quantizer: BaseQuantizer,
    data: torch.Tensor,
    name: str,
    consumer: str,
    plan: dict,
    nbits: int = 8,
) -> dict:
    """Gather the data by max/min

    Appends the abs-max, max and min of the data to the lists found in the
    given plan (new lists are started when a key is missing).

    Parameters
    ----------
    quantizer: BaseQuantizer
        The quantizer
    data: torch.Tensor
        The source data.
    name: str
        The name of the tensor.
    consumer: str
        The name of the consumer.
    plan: dict
        The pre-calibrated plan.
    nbits: int
        The number bits for quantize.

    Returns
    -------
    plan: dict
        The plan of the tensor, not calibrated yet.
    """

    abs_max_list = plan.get("abs_max_list", [])
    abs_max_list.append(float(torch.abs(data).max()))
    max_list = plan.get("max_list", [])
    max_list.append(float(data.max()))
    min_list = plan.get("min_list", [])
    min_list.append(float(data.min()))
    return {
        "abs_max_list": abs_max_list,
        "max_list": max_list,
        "min_list": min_list,
        "calibrated": False,
    }


@classmethod
def gather_max_per_channel(
    cls,
    quantizer: BaseQuantizer,
    data: torch.Tensor,
    name: str,
    consumer: str,
    plan: dict,
    nbits: int = 8,
    channel: str = "O",
    auto_unsign: bool = False,
) -> dict:
    """Gather the data by max_per_channel

    Parameters
    ----------
    quantizer: BaseQuantizer
        The quantizer
    data: torch.Tensor
        The source data.
    name: str
        The name of the tensor.
    consumer: str
        The name of the consumer.
    plan: dict
        The pre-calibrated plan.
    nbits: int
        The number bits for quantize.
    channel: str
        The channel reference.
    auto_unsign: bool
        Whether to use auto unsign.

    Returns
    -------
    plan: dict
        The plan of the tensor, one scale per slice along the channel axis.
    """

    weight = quantizer.find_tensor(name)
    axis = weight.layout_of(channel)
    channel_max = [torch.abs(d).max() for d in torch.chunk(data, data.shape[axis], dim=axis)]
    # Keep the plan made of plain python values: the comparison yields a
    # 0-dim tensor, which would otherwise be stored in the plan as-is.
    sign = bool(data.min() < 0) if auto_unsign else True
    valid_range = 2 ** (nbits - int(sign)) - 1
    scale = [valid_range / float(m) for m in channel_max]
    return {"scale": scale, "sign": sign, "axis": axis, "calibrated": True}


@classmethod
def quantize_normal(
    cls,
    quantizer: BaseQuantizer,
    data: torch.Tensor,
    name: str,
    consumer: str,
    scale: float,
    nbits: int = 8,
    axis: int = -1,
    sign: bool = True,
    rounding: str = "round",
    epsilon: float = 1.0 / (1 << 24),
) -> torch.Tensor:
    """Quantize the data and scale it back

    The data is amplified by the scale tensor (rounded and clamped to the
    valid range), then divided by the same scale tensor. The scale tensor is
    cached per tensor and consumer.

    Parameters
    ----------
    quantizer: BaseQuantizer
        The quantizer
    data: torch.Tensor
        The source data.
    name: str
        The name of the tensor.
    consumer: str
        The name of the consumer.
    scale: float
        The scale factor
    nbits: int
        The number bits for quantize.
    axis: int
        The axis.
    sign: bool
        Whether to use sign.
    rounding str
        The rounding method.
    epsilon: float
        The epsilon for get scale.

    Returns
    -------
    data: torch.Tensor
        The processed tensor.
    """

    valid_range = 2 ** (nbits - int(sign)) - 1
    min_val = -valid_range if sign else 0
    scale_tensor = quantizer._get_tensor_cache(name, consumer, "scale_tensor")
    if scale_tensor is None:
        scale_tensor = cls.get_scale_tensor(data, scale, axis, epsilon)
        if isinstance(scale_tensor, np.ndarray):
            scale_tensor = torch.from_numpy(scale_tensor).to(data.device)
        quantizer._save_tensor_cache(name, consumer, "scale_tensor", scale_tensor)
    data = cls.amplify_data(data, scale_tensor, min_val, valid_range, rounding)
    return data / scale_tensor
class TorchQuantizerFactory(object):
    """Quantizer factory for torch"""

    def create(self, base_cls: BaseQuantizer) -> BaseQuantizer:
        """Create adaptive quantizer

        Parameters
        ----------
        base_cls: BaseQuantizer
            The base quantizer class

        Returns
        -------
        quantizer_cls: BaseQuantizer
            The quantizer class, a subclass of base_cls reporting the torch framework.
        """

        class Quantizer(base_cls):
            """Adaptive quantizer for torch"""

            @classmethod
            def framework(cls):
                return MSCFramework.TORCH

        return Quantizer


# Import-time side effect: every quantizer class registered for the MSC
# framework gets a torch variant registered as well.
factory = TorchQuantizerFactory()
tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all")
for tool in tools.values():
    msc_utils.register_tool_cls(factory.create(tool))
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.framework.tvm.tools.quantize""" + +from .quantizer import * +from .method import * diff --git a/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py b/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py new file mode 100644 index 000000000000..9966e9c1af5d --- /dev/null +++ b/python/tvm/contrib/msc/framework/tvm/tools/quantize/method.py @@ -0,0 +1,204 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
class TVMQuantizeMethod(QuantizeMethod):
    """Default quantize method for tvm"""

    @classmethod
    def get_quantize_cache(
        cls,
        quantizer: BaseQuantizer,
        data: tvm.relax.Var,
        name: str,
        consumer: str,
        scale: float,
        axis: int = -1,
        epsilon: float = 1.0 / (1 << 24),
    ) -> Tuple[tvm.relax.Constant, tvm.relax.Constant]:
        """Get the scale and zero point constants, creating and caching them if missing

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        data: tvm.relax.Var
            The source data.
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        scale: float
            The scale factor
        axis: int
            The axis.
        epsilon: float
            The epsilon for get scale.

        Returns
        -------
        scale_tensor: tvm.relax.Constant
            The scale_tensor.
        zero_point: tvm.relax.Constant
            The zero_point.
        """

        name_prefix = name if quantizer._cache_processed else quantizer.to_tensor_id(name, consumer)
        scale_tensor = quantizer._get_tensor_cache(name, consumer, "scale_tensor")
        zero_point = quantizer._get_tensor_cache(name, consumer, "zero_point")
        # Only the scale is checked: both constants are always saved together below.
        if scale_tensor is None:
            scale_tensor = cls.get_scale_tensor(data, scale, axis, epsilon, expand_dims=False)
            if isinstance(scale_tensor, float):
                scale_tensor = np.array(scale_tensor)
            # The scale takes the dtype of the tensor, the zero point is all-zero int8.
            scale_tensor = scale_tensor.astype(quantizer.find_tensor(name).dtype_name)
            zero_point = np.zeros_like(scale_tensor).astype("int8")
            # Spans carry the constant names: <prefix>_scale and <prefix>_zero_point.
            scale_span = _ffi_api.SpanCreateWithAttr("name", name_prefix + "_scale")
            scale_tensor = tvm.relax.Constant(tvm.nd.array(scale_tensor), span=scale_span)
            zp_span = _ffi_api.SpanCreateWithAttr("name", name_prefix + "_zero_point")
            zero_point = tvm.relax.Constant(tvm.nd.array(zero_point), span=zp_span)
            quantizer._save_tensor_cache(name, consumer, "scale_tensor", scale_tensor)
            quantizer._save_tensor_cache(name, consumer, "zero_point", zero_point)
        return scale_tensor, zero_point

    @classmethod
    def quantize_normal(
        cls,
        quantizer: BaseQuantizer,
        data: tvm.relax.Var,
        name: str,
        consumer: str,
        scale: float,
        nbits: int = 8,
        axis: int = -1,
        sign: bool = True,
        rounding: str = "round",
        epsilon: float = 1.0 / (1 << 24),
    ) -> tvm.relax.Var:
        """Emit a relax quantize op for the data

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        data: tvm.relax.Var
            The source data.
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        scale: float
            The scale factor
        nbits: int
            The number bits for quantize, only 8 is accepted.
        axis: int
            The axis.
        sign: bool
            Whether to use sign. Not read by this method.
        rounding str
            The rounding method. Not read by this method.
        epsilon: float
            The epsilon for get scale.

        Returns
        -------
        data: tvm.relax.Var
            The processed tensor.

        Raises
        ------
        TypeError
            If nbits is not 8.
        """

        if nbits == 8:
            dtype = "int8"
        else:
            raise TypeError("Unexpected nbits " + str(nbits))
        name_prefix = name if quantizer._cache_processed else quantizer.to_tensor_id(name, consumer)
        scale_tensor, zero_point = cls.get_quantize_cache(
            quantizer, data, name, consumer, scale, axis, epsilon
        )
        expr = relax_op.quantize(data, scale_tensor, zero_point, axis, dtype)
        return quantizer._block_builder.emit(expr, name_hint=name_prefix + "_quantize")

    @classmethod
    def dequantize_normal(
        cls,
        quantizer: BaseQuantizer,
        data: tvm.relax.Var,
        name: str,
        consumer: str,
        scale: float = -1.0,
        nbits: int = 8,
        axis: int = -1,
        sign: bool = True,
        rounding: str = "round",
        epsilon: float = 1.0 / (1 << 24),
    ) -> tvm.relax.Var:
        """Emit a relax dequantize op for the data

        The output dtype is the dtype of the tensor found by name.

        Parameters
        ----------
        quantizer: BaseQuantizer
            The quantizer
        data: tvm.relax.Var
            The source data.
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        scale: float
            The scale factor
        nbits: int
            The number bits for quantize. Not read by this method.
        axis: int
            The axis.
        sign: bool
            Whether to use sign. Not read by this method.
        rounding str
            The rounding method. Not read by this method.
        epsilon: float
            The epsilon for get scale.

        Returns
        -------
        data: tvm.relax.Var
            The processed tensor.
        """

        name_prefix = name if quantizer._cache_processed else quantizer.to_tensor_id(name, consumer)
        scale_tensor, zero_point = cls.get_quantize_cache(
            quantizer, data, name, consumer, scale, axis, epsilon
        )
        expr = relax_op.dequantize(
            data, scale_tensor, zero_point, axis, quantizer.find_tensor(name).dtype
        )
        return quantizer._block_builder.emit(expr, name_hint=name_prefix + "_dequantize")

    @classmethod
    def framework(cls):
        return MSCFramework.TVM


msc_utils.register_tool_method(TVMQuantizeMethod)
class TVMQuantizerFactory(object):
    """Quantizer factory for tvm"""

    def create(self, base_cls: BaseQuantizer) -> BaseQuantizer:
        """Create adaptive quantizer

        Parameters
        ----------
        base_cls: BaseQuantizer
            The base quantizer class

        Returns
        -------
        quantizer_cls: BaseQuantizer
            The quantizer class.
        """

        class Quantizer(base_cls):
            """Adaptive quantizer for tvm"""

            def _execute_before_build(self, block_builder: tvm.relax.BlockBuilder):
                """Execute before model build

                Keeps the block builder and starts with empty gather records.

                Parameters
                ----------
                block_builder: tvm.relax.BlockBuilder
                    The block builder.
                """

                self._block_builder = block_builder
                self._gather_tensors = {}
                self._gather_names = []
                super()._execute_before_build(block_builder)

            def _execute_after_build(
                self, output: Union[tvm.relax.Var, List[tvm.relax.DataflowVar]]
            ) -> List[tvm.relax.Var]:
                """Execute after model build

                Until calibrated, the recorded tensors are appended to the
                model outputs, ordered by sorted tensor name.

                Parameters
                ----------
                output: var or list<var>
                    The output var of the model.

                Returns
                -------
                outputs: list<var>
                    The modified outputs var.
                """

                if self._calibrated:
                    return super()._execute_after_build(output)
                self._gather_names = sorted(self._gather_tensors.keys())
                extra_outputs = [self._gather_tensors[n]["tensor"] for n in self._gather_names]
                model_outputs = [output] if isinstance(output, tvm.relax.Var) else output
                return super()._execute_after_build(model_outputs + extra_outputs)

            def _execute_after_forward(
                self, outputs: List[tvm.runtime.NDArray]
            ) -> Union[tvm.runtime.NDArray, List[tvm.runtime.NDArray]]:
                """Execute after model forward

                Until calibrated, the trailing outputs are the recorded tensors:
                they are gathered per consumer and removed from the result.

                Parameters
                ----------
                outputs: list<tvm.runtime.NDArray>
                    The output datas.

                Returns
                -------
                output: tvm.runtime.NDArray or list<tvm.runtime.NDArray>
                    The modified output ndarray.
                """

                if self._calibrated:
                    return super()._execute_after_forward(outputs)
                output_num = len(outputs) - len(self._gather_names)
                for name, data in zip(self._gather_names, outputs[output_num:]):
                    for consumer in self._gather_tensors[name]["consumers"]:
                        strategys = self._get_tensor_strategys(name, consumer)
                        self._gather_tensor(data, name, consumer, strategys)
                # A single model output is passed on bare, several as a list.
                model_outputs = outputs[0] if output_num == 1 else outputs[:output_num]
                return super()._execute_after_forward(model_outputs)

            def _process_tensor(
                self,
                tensor: tvm.relax.DataflowVar,
                name: str,
                consumer: str,
                scope: str,
                strategys: List[ToolStrategy],
            ) -> tvm.relax.DataflowVar:
                """Process tensor

                Once calibrated the tensor is quantized. Before that, weights
                are gathered right away and other tensors are recorded with
                their consumers and returned unchanged.

                Parameters
                -------
                tensor: Any
                    Tensor in framework
                name: str
                    The name of the tensor.
                consumer: str
                    The name of the consumer.
                scope: str
                    The scope mark teacher| student| null.
                strategys: list<ToolStrategy>
                    The strategys for the tensor.

                Returns
                -------
                tensor: Any
                    The processed tensor.
                """

                if self._calibrated:
                    return self._quantize_tensor(tensor, name, consumer, strategys)
                if self.is_weight(name):
                    return self._gather_tensor(self.get_data(name), name, consumer, strategys)
                if name in self._gather_tensors:
                    self._gather_tensors[name]["consumers"].append(consumer)
                else:
                    self._gather_tensors[name] = {"consumers": [consumer], "tensor": tensor}
                    self._gather_names.append(name)
                return tensor

            @classmethod
            def framework(cls):
                return MSCFramework.TVM

        return Quantizer


factory = TVMQuantizerFactory()
tools = msc_utils.get_registered_tool_cls(MSCFramework.MSC, ToolType.QUANTIZER, tool_style="all")
for tool in tools.values():
    msc_utils.register_tool_cls(factory.create(tool))
"gather_maxmin", + "op_types": ["nn.conv2d", "msc.linear"], + "tensor_types": ["input", "output"], + "stages": [QuantizeStage.GATHER], + }, + { + "method": "gather_max_per_channel", + "op_types": ["nn.conv2d", "msc.linear"], + "tensor_types": ["weight"], + "stages": [QuantizeStage.GATHER], + }, + { + "method": "calibrate_maxmin", + "op_types": ["nn.conv2d", "msc.linear"], + "tensor_types": ["input", "output"], + "stages": [QuantizeStage.CALIBRATE], + }, + { + "method": "quantize_normal", + "op_types": ["nn.conv2d", "msc.linear"], + "tensor_types": ["input", "weight"], + }, + { + "method": "dequantize_normal", + "op_types": ["nn.conv2d", "msc.linear"], + "tensor_types": ["output"], + }, + ], + } elif tool_type == ToolType.TRACKER: config = { "plan_file": "msc_tracker.json", @@ -183,7 +218,7 @@ def get_model_info(compile_type): raise TypeError("Unexpected compile_type " + str(compile_type)) -@pytest.mark.parametrize("tool_type", [ToolType.PRUNER, ToolType.TRACKER]) +@pytest.mark.parametrize("tool_type", [ToolType.PRUNER, ToolType.QUANTIZER, ToolType.TRACKER]) def test_tvm_tool(tool_type): """Test tools for tvm""" @@ -194,7 +229,10 @@ def test_tvm_tool(tool_type): @requires_tensorrt -@pytest.mark.parametrize("tool_type", [ToolType.PRUNER, ToolType.TRACKER]) +@pytest.mark.parametrize( + "tool_type", + [ToolType.PRUNER, ToolType.QUANTIZER, ToolType.TRACKER], +) def test_tensorrt_tool(tool_type): """Test tools for tensorrt"""