Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lifelong learning supporting non-structure #352

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ jobs:
verify-and-lint:
runs-on: ubuntu-latest
name: Verify codegen/vendor/licenses, do lint
strategy:
matrix:
python-version: [ "3.6", "3.7", "3.8", "3.9" ]
env:
GOPATH: ${{ github.workspace }}

Expand Down Expand Up @@ -50,6 +53,17 @@ jobs:
run: pycodestyle lib
working-directory: ${{ env.CODE_DIR }}

- name: Set python version
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}

- name: Run python lint test
run: |
python -m pip install pylint
pylint lib
working-directory: ${{ env.CODE_DIR }}

build:
runs-on: ubuntu-latest
name: build gm and lc
Expand Down
38 changes: 38 additions & 0 deletions examples/lifelong_learning/RFNet/accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from basemodel import val_args
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The import of the relative path should be adjusted.

from utils.metrics import Evaluator
from tqdm import tqdm
from dataloaders import make_data_loader
from sedna.common.class_factory import ClassType, ClassFactory
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note the order of import


__all__ = ('accuracy')

@ClassFactory.register(ClassType.GENERAL)
def accuracy(y_true, y_pred, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Common keyword. Use alias while register.

args = val_args()
_, _, test_loader, num_class = make_data_loader(args, test_data=y_true)
evaluator = Evaluator(num_class)

tbar = tqdm(test_loader, desc='\r')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

useless

for i, (sample, img_path) in enumerate(tbar):
if args.depth:
image, depth, target = sample['image'], sample['depth'], sample['label']
else:
image, target = sample['image'], sample['label']
if args.cuda:
image, target = image.cuda(), target.cuda()
if args.depth:
depth = depth.cuda()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check whether the device supports GPU.


target[target > evaluator.num_class-1] = 255
target = target.cpu().numpy()
# Add batch sample into evaluator
evaluator.add_batch(target, y_pred[i])

# Test during the training
# Acc = evaluator.Pixel_Accuracy()
CPA = evaluator.Pixel_Accuracy_Class()
mIoU = evaluator.Mean_Intersection_over_Union()
FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union()

print("CPA:{}, mIoU:{}, fwIoU: {}".format(CPA, mIoU, FWIoU))
return CPA
Loading