Skip to content

Commit d79cf6b

Browse files
committed
lfx
1 parent 645c836 commit d79cf6b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+8326
-8323
lines changed

core/testcasecontroller/algorithm/paradigm/lifelong_learning/lifelong_learning.py

Lines changed: 441 additions & 441 deletions
Large diffs are not rendered by default.
Lines changed: 122 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,119 +1,122 @@
1-
# Copyright 2022 The KubeEdge Authors.
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
6-
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
14-
15-
"""Test Case"""
16-
17-
import os
18-
import uuid
19-
20-
from core.common.constant import SystemMetricType
21-
from core.testcasecontroller.metrics import get_metric_func
22-
23-
24-
class TestCase:
25-
"""
26-
Test Case:
27-
Consists of a test environment and a test algorithm
28-
29-
Parameters
30-
----------
31-
test_env : instance
32-
The test environment of benchmarking,
33-
including dataset, Post-processing algorithms like metric computation.
34-
algorithm : instance
35-
Typical distributed-synergy AI algorithm paradigm.
36-
"""
37-
38-
def __init__(self, test_env, algorithm):
39-
# pylint: disable=C0103
40-
self.id = uuid.uuid1()
41-
self.test_env = test_env
42-
self.algorithm = algorithm
43-
self.output_dir = None
44-
45-
def _get_output_dir(self, workspace):
46-
output_dir = os.path.join(workspace, self.algorithm.name)
47-
flag = True
48-
while flag:
49-
output_dir = os.path.join(workspace, self.algorithm.name, str(self.id))
50-
if not os.path.exists(output_dir):
51-
flag = False
52-
return output_dir
53-
54-
def run(self, workspace):
55-
"""
56-
Run the test case
57-
58-
Returns
59-
-------
60-
test result: dict
61-
e.g.: {"f1_score": 0.89}
62-
"""
63-
64-
try:
65-
dataset = self.test_env.dataset
66-
test_env_config = {}
67-
# pylint: disable=C0103
68-
for k, v in self.test_env.__dict__.items():
69-
test_env_config[k] = v
70-
71-
self.output_dir = self._get_output_dir(workspace)
72-
paradigm = self.algorithm.paradigm(workspace=self.output_dir,
73-
**test_env_config)
74-
res, system_metric_info = paradigm.run()
75-
test_result = self.compute_metrics(res, dataset, **system_metric_info)
76-
77-
except Exception as err:
78-
paradigm_type = self.algorithm.paradigm_type
79-
raise RuntimeError(
80-
f"(paradigm={paradigm_type}) pipeline runs failed, error: {err}") from err
81-
return test_result
82-
83-
def compute_metrics(self, paradigm_result, dataset, **kwargs):
84-
"""
85-
Compute metrics of paradigm result
86-
87-
Parameters
88-
----------
89-
paradigm_result: numpy.ndarray
90-
dataset: instance
91-
kwargs: dict
92-
information needed to compute system metrics.
93-
94-
Returns
95-
-------
96-
dict
97-
e.g.: {"f1_score": 0.89}
98-
"""
99-
100-
metric_funcs = {}
101-
for metric_dict in self.test_env.metrics:
102-
metric_name, metric_func = get_metric_func(metric_dict=metric_dict)
103-
if callable(metric_func):
104-
metric_funcs.update({metric_name: metric_func})
105-
106-
test_dataset_file = dataset.test_url
107-
test_dataset = dataset.load_data(test_dataset_file,
108-
data_type="eval overall",
109-
label=dataset.label)
110-
111-
metric_res = {}
112-
system_metric_types = [e.value for e in SystemMetricType.__members__.values()]
113-
for metric_name, metric_func in metric_funcs.items():
114-
if metric_name in system_metric_types:
115-
metric_res[metric_name] = metric_func(kwargs)
116-
else:
117-
metric_res[metric_name] = metric_func(test_dataset.y, paradigm_result)
118-
119-
return metric_res
1+
# Copyright 2022 The KubeEdge Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Test Case"""
16+
17+
import os
18+
import uuid
19+
20+
from core.common.constant import SystemMetricType
21+
from core.testcasecontroller.metrics import get_metric_func
22+
23+
24+
class TestCase:
25+
"""
26+
Test Case:
27+
Consists of a test environment and a test algorithm
28+
29+
Parameters
30+
----------
31+
test_env : instance
32+
The test environment of benchmarking,
33+
including dataset, Post-processing algorithms like metric computation.
34+
algorithm : instance
35+
Typical distributed-synergy AI algorithm paradigm.
36+
"""
37+
38+
def __init__(self, test_env, algorithm):
39+
# pylint: disable=C0103
40+
self.id = uuid.uuid1()
41+
self.test_env = test_env
42+
self.algorithm = algorithm
43+
self.output_dir = None
44+
45+
def _get_output_dir(self, workspace):
46+
output_dir = os.path.join(workspace, self.algorithm.name)
47+
flag = True
48+
while flag:
49+
output_dir = os.path.join(workspace, self.algorithm.name, str(self.id))
50+
if not os.path.exists(output_dir):
51+
flag = False
52+
return output_dir
53+
54+
def run(self, workspace):
55+
"""
56+
Run the test case
57+
58+
Returns
59+
-------
60+
test result: dict
61+
e.g.: {"f1_score": 0.89}
62+
"""
63+
64+
try:
65+
dataset = self.test_env.dataset
66+
test_env_config = {}
67+
# pylint: disable=C0103
68+
for k, v in self.test_env.__dict__.items():
69+
test_env_config[k] = v
70+
71+
self.output_dir = self._get_output_dir(workspace)
72+
paradigm = self.algorithm.paradigm(workspace=self.output_dir,
73+
**test_env_config)
74+
res, system_metric_info = paradigm.run()
75+
test_result = self.compute_metrics(res, dataset, **system_metric_info)
76+
77+
except Exception as err:
78+
paradigm_type = self.algorithm.paradigm_type
79+
raise RuntimeError(
80+
f"(paradigm={paradigm_type}) pipeline runs failed, error: {err}") from err
81+
return test_result
82+
83+
def compute_metrics(self, paradigm_result, dataset, **kwargs):
84+
"""
85+
Compute metrics of paradigm result
86+
87+
Parameters
88+
----------
89+
paradigm_result: numpy.ndarray
90+
dataset: instance
91+
kwargs: dict
92+
information needed to compute system metrics.
93+
94+
Returns
95+
-------
96+
dict
97+
e.g.: {"f1_score": 0.89}
98+
"""
99+
100+
metric_funcs = {}
101+
for metric_dict in self.test_env.metrics:
102+
metric_name, metric_func = get_metric_func(metric_dict=metric_dict)
103+
if callable(metric_func):
104+
metric_funcs.update({metric_name: metric_func})
105+
106+
test_dataset_file = dataset.test_url
107+
test_dataset = dataset.load_data(test_dataset_file,
108+
data_type="eval overall",
109+
label=dataset.label)
110+
111+
metric_res = {}
112+
system_metric_types = [e.value for e in SystemMetricType.__members__.values()]
113+
for metric_name, metric_func in metric_funcs.items():
114+
if metric_name in system_metric_types:
115+
metric_res[metric_name] = metric_func(kwargs)
116+
else:
117+
if paradigm_result is None:
118+
continue
119+
metric_res[metric_name] = metric_func(test_dataset.y, paradigm_result)
120+
if paradigm_result is None:
121+
metric_res["accuracy"] = metric_res["task_avg_acc"]
122+
return metric_res

0 commit comments

Comments
 (0)