Skip to content

Commit

Permalink
Support exporting the inference model after saving a checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyubo0722 committed Sep 9, 2024
1 parent 5b54ac4 commit d8479a4
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 7 deletions.
7 changes: 7 additions & 0 deletions ppocr/utils/save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import os
import pickle
import six
import json

import paddle

Expand Down Expand Up @@ -248,6 +249,12 @@ def save_model(
if prefix == "best_accuracy":
arch.backbone.model.save_pretrained(best_model_path)

save_model_info = kwargs.pop("save_model_info", False)
if save_model_info:
with open(os.path.join(model_path, f"{prefix}.info.json"), "w") as f:
json.dump(kwargs, f)
logger.info("Already save model info in {}".format(model_path))

# save metric and config
with open(metric_prefix + ".states", "wb") as f:
pickle.dump(kwargs, f, protocol=2)
Expand Down
77 changes: 76 additions & 1 deletion tools/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,88 @@
import yaml
import json
import paddle
import paddle.nn as nn
from paddle.jit import to_static
from collections import OrderedDict
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser


class ArgsParser(ArgumentParser):
    """Command-line parser for export: a config file plus ``key=value`` overrides."""

    def __init__(self):
        super(ArgsParser, self).__init__(formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument("-o", "--opt", nargs="+", help="set configuration options")
        self.add_argument(
            "-p",
            "--profiler_options",
            type=str,
            default=None,
            help="The option of profiler, which should be in format "
            '"key1=value1;key2=value2;key3=value3".',
        )

    def parse_args(self, argv=None):
        """Parse ``argv``; --config is mandatory, -o overrides become a dict."""
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Turn ``["a.b=1", ...]`` into ``{"a.b": 1, ...}`` with YAML-typed values.

        Returns an empty dict when no overrides were given.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            # Split on the FIRST '=' only, so override values may themselves
            # contain '=' (the original unbounded split() raised ValueError
            # "too many values to unpack" on such inputs).
            k, v = s.split("=", 1)
            # NOTE(review): yaml.Loader can construct arbitrary Python objects;
            # acceptable for trusted CLI input, but never feed it untrusted
            # data (yaml.safe_load would be the safe choice).
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config


def load_config(file_path):
    """
    Load config from yml/yaml file.

    Args:
        file_path (str): Path of the config file to be loaded.

    Returns:
        dict: parsed configuration.

    Raises:
        AssertionError: if the file extension is not .yml/.yaml.
    """
    _, ext = os.path.splitext(file_path)
    assert ext in [".yml", ".yaml"], "only support yaml files for now"
    # Use a context manager so the handle is closed even if parsing raises;
    # the original left an unclosed file object to the garbage collector.
    with open(file_path, "rb") as f:
        config = yaml.load(f, Loader=yaml.Loader)
    return config


def merge_config(config, opts):
    """
    Merge override options into the global config, in place.

    Args:
        config (dict): Global config to be updated.
        opts (dict): Overrides; keys may be dotted paths like "Global.epoch".

    Returns:
        dict: the same ``config`` object, after merging.
    """
    for dotted_key, new_value in opts.items():
        # Plain key: merge dicts shallowly, otherwise overwrite/insert.
        if "." not in dotted_key:
            if isinstance(new_value, dict) and dotted_key in config:
                config[dotted_key].update(new_value)
            else:
                config[dotted_key] = new_value
            continue
        # Dotted key: walk down to the parent node, then assign the leaf.
        parts = dotted_key.split(".")
        assert parts[0] in config, (
            "the sub_keys can only be one of global_config: {}, but get: "
            "{}, please check your running command".format(
                config.keys(), parts[0]
            )
        )
        node = config[parts[0]]
        for intermediate in parts[1:-1]:
            node = node[intermediate]
        node[parts[-1]] = new_value
    return config


def export_single_model(
Expand Down
82 changes: 76 additions & 6 deletions tools/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tqdm import tqdm
import cv2
import numpy as np
import copy
from argparse import ArgumentParser, RawDescriptionHelpFormatter

from ppocr.utils.stats import TrainingStats
Expand All @@ -36,6 +37,8 @@
from ppocr.utils.loggers import WandbLogger, Loggers
from ppocr.utils import profiler
from ppocr.data import build_dataloader
from .export_model import dump_infer_config
from .export_model import export_single_model


class ArgsParser(ArgumentParser):
Expand Down Expand Up @@ -205,6 +208,7 @@ def train(
eval_batch_epoch = config["Global"].get("eval_batch_epoch", None)
profiler_options = config["profiler_options"]
print_mem_info = config["Global"].get("print_mem_info", True)
model_export_enabled = config["Global"].get("model_export_enabled", False)

global_step = 0
if "global_step" in pre_best_model_dict:
Expand Down Expand Up @@ -303,6 +307,7 @@ def train(
)

for idx, batch in enumerate(train_dataloader):
model.train()
profiler.add_profiler_step(profiler_options)
train_reader_cost += time.time() - reader_start
if idx >= max_iter:
Expand Down Expand Up @@ -484,18 +489,26 @@ def train(
if cur_metric[main_indicator] >= best_model_dict[main_indicator]:
best_model_dict.update(cur_metric)
best_model_dict["best_epoch"] = epoch
prefix = "best_accuracy"
save_model(
model,
optimizer,
save_model_dir,
(
os.path.join(save_model_dir, prefix)
if model_export_enabled
else save_model_dir
),
logger,
config,
is_best=True,
prefix="best_accuracy",
prefix=prefix,
save_model_info=model_export_enabled,
best_model_dict=best_model_dict,
epoch=epoch,
global_step=global_step,
)
if model_export_enabled:
export(config, model, os.path.join(save_model_dir, prefix))
best_str = "best metric, {}".format(
", ".join(
["{}: {}".format(k, v) for k, v in best_model_dict.items()]
Expand All @@ -520,35 +533,51 @@ def train(

reader_start = time.time()
if dist.get_rank() == 0:
prefix = "latest"
save_model(
model,
optimizer,
save_model_dir,
(
os.path.join(save_model_dir, prefix)
if model_export_enabled
else save_model_dir
),
logger,
config,
is_best=False,
prefix="latest",
prefix=prefix,
save_model_info=model_export_enabled,
best_model_dict=best_model_dict,
epoch=epoch,
global_step=global_step,
)
if model_export_enabled:
export(config, model, os.path.join(save_model_dir, prefix))

if log_writer is not None:
log_writer.log_model(is_best=False, prefix="latest")

if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
prefix = "iter_epoch_{}".format(epoch)
save_model(
model,
optimizer,
save_model_dir,
(
os.path.join(save_model_dir, prefix)
if model_export_enabled
else save_model_dir
),
logger,
config,
is_best=False,
prefix="iter_epoch_{}".format(epoch),
prefix=prefix,
save_model_info=model_export_enabled,
best_model_dict=best_model_dict,
epoch=epoch,
global_step=global_step,
)
if model_export_enabled:
export(config, model, os.path.join(save_model_dir, prefix))
if log_writer is not None:
log_writer.log_model(
is_best=False, prefix="iter_epoch_{}".format(epoch)
Expand Down Expand Up @@ -842,3 +871,44 @@ def preprocess(is_train=False):

logger.info("train with paddle {} and device {}".format(paddle.__version__, device))
return config, device, logger, log_writer


def export(config, base_model, save_path):
    """Export a trained model (or each sub-model of a distillation model) to
    static-graph inference format under ``save_path``, plus an inference.yml.

    The model is deep-copied first so the training model is left untouched.
    """
    model = copy.deepcopy(base_model)
    logger = get_logger()
    yaml_path = os.path.join(save_path, "inference.yml")
    model.eval()

    arch_config = config["Architecture"]
    algorithm = arch_config["algorithm"]

    # Work out a fixed input shape for the algorithms that require one;
    # everything else exports with input_shape=None.
    input_shape = None
    if algorithm in ("SVTR", "CPPD") and arch_config["Head"]["name"] != "MultiHead":
        # assumes SVTRRecResizeImg is the second-to-last eval transform — TODO confirm
        eval_transforms = config["Eval"]["dataset"]["transforms"]
        input_shape = eval_transforms[-2]["SVTRRecResizeImg"]["image_shape"]
    elif algorithm.lower() == "abinet":
        for transform in config["Eval"]["dataset"]["transforms"]:
            if "ABINetRecResizeImg" in transform:
                input_shape = transform["ABINetRecResizeImg"]["image_shape"]
                break

    if algorithm in [
        "Distillation",
    ]:  # distillation model: export each sub-model into its own directory
        sub_arch_list = list(arch_config["Models"].values())
        for idx, sub_name in enumerate(model.model_name_list):
            sub_save_path = os.path.join(save_path, sub_name, "inference")
            export_single_model(
                model.model_list[idx], sub_arch_list[idx], sub_save_path, logger
            )
    else:
        save_path = os.path.join(save_path, "inference")
        export_single_model(
            model, arch_config, save_path, logger, input_shape=input_shape
        )

    dump_infer_config(config, yaml_path, logger)

0 comments on commit d8479a4

Please sign in to comment.