Skip to content

Commit

Permalink
Make the hard_example_mining algorithm a common interface
Browse files Browse the repository at this point in the history
Signed-off-by: JoeyHwong <joeyhwong@gknow.cn>
  • Loading branch information
JoeyHwong-gk committed Jul 31, 2021
1 parent e833a23 commit 99b6cb3
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import os
import json
import time
import warnings

Expand All @@ -25,6 +24,7 @@
from sedna.core.incremental_learning import IncrementalLearning
from interface import Estimator


he_saved_url = Context.get_parameters("HE_SAVED_URL", '/tmp')
rsl_saved_url = Context.get_parameters("RESULT_SAVED_URL", '/tmp')
class_names = ['person', 'helmet', 'helmet_on', 'helmet_off']
Expand Down Expand Up @@ -105,22 +105,11 @@ def deal_infer_rsl(model_output):
def run():
camera_address = Context.get_parameters('video_url')

hard_example_name = Context.get_parameters('HEM_NAME', "IBT")
hem_parameters = Context.get_parameters('HEM_PARAMETERS')

try:
hem_parameters = json.loads(hem_parameters)
hem_parameters = {
p["key"]: p.get("value", "")
for p in hem_parameters if "key" in p
}
except:
hem_parameters = {}

hard_example_mining = {
"method": hard_example_name,
"param": hem_parameters
}
# get hard example mining algorithm from config, e.g.:
# {"method": "IBT", "param": {"threshold_img": 0.9}}
hard_example_mining = IncrementalLearning.get_hem_algorithm(
threshold_img=0.9
)

input_shape_str = Context.get_parameters("input_shape")
input_shape = tuple(int(v) for v in input_shape_str.split(","))
Expand Down
24 changes: 24 additions & 0 deletions lib/sedna/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,27 @@ def get_parameters(cls, param, default=None):
value = cls.parameters.get(
param) or cls.parameters.get(str(param).upper())
return value if value else default

@classmethod
def get_crd_algorithm(cls, algorithm, **param) -> dict:
"""get the algorithm and parameter which define in crd"""
hard_example_name = cls.get_parameters(f'{algorithm}_NAME')
hem_parameters = cls.get_parameters(f'{algorithm}_PARAMETERS')

try:
hem_parameters = json.loads(hem_parameters)
hem_parameters = {
p["key"]: p.get("value", "")
for p in hem_parameters if "key" in p
}
except:
hem_parameters = {}

hem_parameters.update(**param)

hard_example_mining = {
"method": hard_example_name,
"param": hem_parameters
}

return hard_example_mining
13 changes: 12 additions & 1 deletion lib/sedna/core/incremental_learning/incremental_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from copy import deepcopy

from sedna.common.file_ops import FileOps
Expand Down Expand Up @@ -50,6 +49,18 @@ def __init__(self, estimator, hard_example_mining: dict = None):
ClassType.HEM, hem
)(**hem_parameters)

@classmethod
def get_hem_algorithm(cls, **param):
"""
get the `algorithm` name and `param` of hard_example_mining from crd
:param param: update value in parameters of hard_example_mining
:return: dict, e.g.: {"method": "IBT", "param": {"threshold_img": 0.5}}
"""
return cls.parameters.get_crd_algorithm(
algorithm="HEM",
**param
)

def train(self, train_data,
valid_data=None,
post_process=None,
Expand Down

0 comments on commit 99b6cb3

Please sign in to comment.