[incremental learning] example: keep all results whether it is a hard example or not, fixed the issue of using s3 to save model #107

Merged

merged 4 commits · Aug 4, 2021
Changes from 3 commits
@@ -20,8 +20,6 @@

from interface import Estimator

max_epochs = 1


def _load_txt_dataset(dataset_url):
    # use original dataset url,
@@ -43,9 +41,9 @@ def main():
    input_shape = Context.get_parameters("input_shape")
    input_shape = tuple(int(shape) for shape in input_shape.split(','))

    model = IncrementalLearning(estimator=Estimator)
    return model.evaluate(valid_data, class_names=class_names,
                          input_shape=input_shape)
    incremental_instance = IncrementalLearning(estimator=Estimator)
    return incremental_instance.evaluate(valid_data, class_names=class_names,
                                          input_shape=input_shape)


if __name__ == '__main__':
@@ -13,6 +13,7 @@
# limitations under the License.

import os
import json
import time
import warnings

@@ -24,11 +25,11 @@
from sedna.core.incremental_learning import IncrementalLearning
from interface import Estimator


he_saved_url = Context.get_parameters("HE_SAVED_URL")
he_saved_url = Context.get_parameters("HE_SAVED_URL", '/tmp')
rsl_saved_url = Context.get_parameters("RESULT_SAVED_URL", '/tmp')
class_names = ['person', 'helmet', 'helmet_on', 'helmet_off']

FileOps.clean_folder([he_saved_url], clean=False)
FileOps.clean_folder([he_saved_url, rsl_saved_url], clean=False)


def draw_boxes(img, labels, scores, bboxes, class_names, colors):
@@ -59,11 +60,14 @@ def draw_boxes(img, labels, scores, bboxes, class_names, colors):
        p2 = (int(bbox[2]), int(bbox[3]))
        if (p2[0] - p1[0] < 1) or (p2[1] - p1[1] < 1):
            continue
        cv2.rectangle(img, p1[::-1], p2[::-1],
                      colors_code[labels[i]], box_thickness)
        cv2.putText(img, text, (p1[1], p1[0] + 20 * (label + 1)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0),
                    text_thickness, line_type)
        try:
            cv2.rectangle(img, p1[::-1], p2[::-1],
                          colors_code[labels[i]], box_thickness)
            cv2.putText(img, text, (p1[1], p1[0] + 20 * (label + 1)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0),
                        text_thickness, line_type)
        except TypeError as err:
            warnings.warn(f"Draw box fail: {err}")
    return img


@@ -72,12 +76,13 @@ def output_deal(is_hard_example, infer_result, nframe, img_rgb):
    img_rgb = np.array(img_rgb)
    img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    colors = 'yellow,blue,green,red'
    if not is_hard_example:
        return

    lables, scores, bbox_list_pred = infer_result
    img = draw_boxes(img_rgb, lables, scores, bbox_list_pred, class_names,
                     colors)
    cv2.imwrite(f"{he_saved_url}/{nframe}.jpeg", img)
    if is_hard_example:
        cv2.imwrite(f"{he_saved_url}/{nframe}.jpeg", img)
    cv2.imwrite(f"{rsl_saved_url}/{nframe}.jpeg", img)


def mkdir(path):
@@ -100,10 +105,29 @@ def deal_infer_rsl(model_output):
def run():
    camera_address = Context.get_parameters('video_url')

    hard_example_name = Context.get_parameters('HEM_NAME', "IBT")
Member

How does a developer find the env 'HEM_NAME'? If this is a sedna system environment variable, sedna should provide a function or class for the developer instead of exposing an environment name that the developer can only find in the documents.

In sedna, there are two separate parameter entries, one for the deployer and one for the developer, and each parameter should be subject to either the CRD or the LIB.

For example:

# Select one of the next two lines (ignore the names of the functions or variables):
ibt = IL.get_hem_from_crd()  # subject to CRD
ibt = IL.IBT(0.9, 0.9)       # subject to LIB

instance = IL(
    estimator=e,
    hem=ibt
)

Contributor Author

fix
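For orientation, the wiring this diff ends up with can be summarized in a short standalone sketch. The threshold values below are illustrative; the method name "IBT" and the threshold_img/threshold_box parameters come from the hard-example-mining module touched later in this diff:

from sedna.core.incremental_learning import IncrementalLearning

from interface import Estimator

# Hard-example-mining config as a plain dict: "method" selects a registered
# filter such as IBT, and "param" carries its keyword arguments.
hard_example_mining = {
    "method": "IBT",
    "param": {"threshold_img": 0.9, "threshold_box": 0.9},
}

incremental_instance = IncrementalLearning(
    estimator=Estimator,
    hard_example_mining=hard_example_mining,
)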

    hem_parameters = Context.get_parameters('HEM_PARAMETERS')

    try:
        hem_parameters = json.loads(hem_parameters)
        hem_parameters = {
            p["key"]: p.get("value", "")
            for p in hem_parameters if "key" in p
        }
    except:
        hem_parameters = {}

    hard_example_mining = {
        "method": hard_example_name,
        "param": hem_parameters
    }

    input_shape_str = Context.get_parameters("input_shape")
    input_shape = tuple(int(v) for v in input_shape_str.split(","))
    # create little model object
    model = IncrementalLearning(estimator=Estimator)
    # create Incremental Learning instance
    incremental_instance = IncrementalLearning(
        estimator=Estimator, hard_example_mining=hard_example_mining
    )
    # use video streams for testing
    camera = cv2.VideoCapture(camera_address)
    fps = 10
@@ -123,7 +147,7 @@ def run():
        img_rgb = cv2.cvtColor(input_yuv, cv2.COLOR_BGR2RGB)
        nframe += 1
        warnings.warn(f"camera is open, current frame index is {nframe}")
        results, _, is_hard_example = model.inference(
        results, _, is_hard_example = incremental_instance.inference(
            img_rgb, post_process=deal_infer_rsl, input_shape=input_shape)
        output_deal(is_hard_example, results, nframe, img_rgb)

@@ -15,6 +15,7 @@
import os
import six
import logging
from urllib.parse import urlparse

import cv2
import numpy as np
@@ -26,8 +27,21 @@
from yolo3_multiscale import Yolo3
from yolo3_multiscale import YoloConfig


os.environ['BACKEND_TYPE'] = 'TENSORFLOW'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
s3_url = os.getenv("S3_ENDPOINT_URL", "http://s3.amazonaws.com")
if not (s3_url.startswith("http://") or s3_url.startswith("https://")):
_url = f"https://{s3_url}"
s3_url = urlparse(s3_url)
s3_use_ssl = s3_url.scheme == 'https' if s3_url.scheme else True

os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("ACCESS_KEY_ID", "")
os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("SECRET_ACCESS_KEY", "")
os.environ["S3_ENDPOINT"] = s3_url.netloc

If the env ACCESS_KEY_ID is not set, os.getenv("ACCESS_KEY_ID") will be None, and os.environ["AWS_ACCESS_KEY_ID"] = None will raise an error; suggest os.environ.setdefault().
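A minimal sketch of the suggested pattern, assuming the same env names used in this diff (illustration only, not the code in this PR):

import os

# Only export the AWS credentials when the deployment actually provides them;
# assigning None to os.environ raises a TypeError.
access_key = os.getenv("ACCESS_KEY_ID")
secret_key = os.getenv("SECRET_ACCESS_KEY")
if access_key:
    os.environ.setdefault("AWS_ACCESS_KEY_ID", access_key)
if secret_key:
    os.environ.setdefault("AWS_SECRET_ACCESS_KEY", secret_key)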

os.environ["S3_USE_HTTPS"] = "1" if s3_use_ssl else "0"
LOG = logging.getLogger(__name__)
@JimmyYang20 JimmyYang20 Jun 22, 2021

The constant TENSORFLOW should be defined in constant.py, e.g. constant.backend_type.TENSORFLOW = 'TENSORFLOW'.
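A minimal sketch of what such a constant could look like; the module layout and names below are assumptions based on this comment, not an existing sedna API:

# constant.py (hypothetical)
from enum import Enum


class BackendType(Enum):
    TENSORFLOW = 'TENSORFLOW'


# interface.py would then reference the constant instead of a bare string:
# os.environ['BACKEND_TYPE'] = BackendType.TENSORFLOW.value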

Contributor Author

Do developers need to know constants in lib?

flags = tf.flags.FLAGS


def preprocess(image, input_shape):
@@ -89,7 +103,7 @@ def train(self, train_data, valid_data=None, **kwargs):

data_gen = DataGen(yolo_config, train_data.x)

max_epochs = int(kwargs.get("max_epochs", "1"))
max_epochs = int(kwargs.get("epochs", flags.max_epochs))
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True

15 changes: 7 additions & 8 deletions examples/incremental_learning/helmet_detection/training/train.py
@@ -24,7 +24,6 @@


def _load_txt_dataset(dataset_url):

    # use original dataset url,
    # see https://github.com/kubeedge/sedna/issues/35
    original_dataset_url = Context.get_parameters('original_dataset_url')
@@ -93,13 +92,13 @@ def main():
    tf.flags.DEFINE_string('result_url', default=None,
                           help='result url for training')

    model = IncrementalLearning(estimator=Estimator)
    return model.train(train_data=train_data, epochs=epochs,
                       batch_size=batch_size,
                       class_names=class_names,
                       input_shape=input_shape,
                       obj_threshold=obj_threshold,
                       nms_threshold=nms_threshold)
    incremental_instance = IncrementalLearning(estimator=Estimator)
    return incremental_instance.train(train_data=train_data, epochs=epochs,
                                      batch_size=batch_size,
                                      class_names=class_names,
                                      input_shape=input_shape,
                                      obj_threshold=obj_threshold,
                                      nms_threshold=nms_threshold)


if __name__ == '__main__':
140 changes: 1 addition & 139 deletions lib/sedna/algorithms/hard_example_mining/__init__.py
@@ -12,142 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hard Example Mining Algorithms"""
import abc
import math

from sedna.common.class_factory import ClassFactory, ClassType

__all__ = ('ThresholdFilter', 'CrossEntropyFilter', 'IBTFilter')


class BaseFilter(metaclass=abc.ABCMeta):
"""The base class to define unified interface."""

def __call__(self, infer_result=None):
"""predict function, and it must be implemented by
different methods class.

:param infer_result: prediction result
:return: `True` means hard sample, `False` means not a hard sample.
"""
raise NotImplementedError

@classmethod
def data_check(cls, data):
"""Check the data in [0,1]."""
return 0 <= float(data) <= 1


@ClassFactory.register(ClassType.HEM, alias="Threshold")
class ThresholdFilter(BaseFilter, abc.ABC):
def __init__(self, threshold=0.5, **kwargs):
self.threshold = float(threshold)

def __call__(self, infer_result=None):
"""
:param infer_result: [N, 6], (x0, y0, x1, y1, score, class)
:return: `True` means hard sample, `False` means not a hard sample.
"""
# if invalid input, return False
if not (infer_result
and all(map(lambda x: len(x) > 4, infer_result))):
return False

image_score = 0

for bbox in infer_result:
image_score += bbox[4]

average_score = image_score / (len(infer_result) or 1)
return average_score < self.threshold


@ClassFactory.register(ClassType.HEM, alias="CrossEntropy")
class CrossEntropyFilter(BaseFilter, abc.ABC):
""" Implement the hard samples discovery methods named IBT
(image-box-thresholds).

:param threshold_cross_entropy: threshold_cross_entropy to filter img,
whose hard coefficient is less than
threshold_cross_entropy. And its default value is
threshold_cross_entropy=0.5
"""

def __init__(self, threshold_cross_entropy=0.5, **kwargs):
self.threshold_cross_entropy = float(threshold_cross_entropy)

def __call__(self, infer_result=None):
"""judge the img is hard sample or not.

:param infer_result:
prediction classes list,
such as [class1-score, class2-score, class2-score,....],
where class-score is the score corresponding to the class,
class-score value is in [0,1], who will be ignored if its value
not in [0,1].
:return: `True` means a hard sample, `False` means not a hard sample.
"""

if not infer_result:
# if invalid input, return False
return False

log_sum = 0.0
data_check_list = [class_probability for class_probability
in infer_result
if self.data_check(class_probability)]

if len(data_check_list) != len(infer_result):
return False

for class_data in data_check_list:
log_sum += class_data * math.log(class_data)
confidence_score = 1 + 1.0 * log_sum / math.log(
len(infer_result))
return confidence_score < self.threshold_cross_entropy


@ClassFactory.register(ClassType.HEM, alias="IBT")
class IBTFilter(BaseFilter, abc.ABC):
"""Implement the hard samples discovery methods named IBT
(image-box-thresholds).

:param threshold_img: threshold_img to filter img, whose hard coefficient
is less than threshold_img.
:param threshold_box: threshold_box to calculate hard coefficient, formula
is hard coefficient = number(prediction_boxes less than
threshold_box)/number(prediction_boxes)
"""

def __init__(self, threshold_img=0.5, threshold_box=0.5, **kwargs):
self.threshold_box = float(threshold_box)
self.threshold_img = float(threshold_img)

def __call__(self, infer_result=None):
"""Judge the img is hard sample or not.

:param infer_result:
prediction boxes list,
such as [bbox1, bbox2, bbox3,....],
where bbox = [xmin, ymin, xmax, ymax, score, label]
score should be in [0,1], who will be ignored if its value not
in [0,1].
:return: `True` means a hard sample, `False` means not a hard sample.
"""

if not (infer_result
and all(map(lambda x: len(x) > 4, infer_result))):
# if invalid input, return False
return False

data_check_list = [bbox[4] for bbox in infer_result
if self.data_check(bbox[4])]
if len(data_check_list) != len(infer_result):
return False

confidence_score_list = [
float(box_score) for box_score in data_check_list
if float(box_score) <= self.threshold_box]
return (len(confidence_score_list) / len(infer_result)
>= (1 - self.threshold_img))
from .hard_example_mining import *