cpdbench.task.MetricExecutionTask
import inspect
from collections.abc import Iterable

from cpdbench.exception.ValidationException import InputValidationException, MetricValidationException
from cpdbench.task.Task import Task
from cpdbench.utils.Utils import get_name_of_function


class MetricExecutionTask(Task):
    """Task that wraps a metric function.

    The wrapped function is called with the three positional arguments
    ``(indexes, scores, ground_truths)`` and its result is returned unchanged.
    """

    def __init__(self, function, counter, param_dict=None):
        """Create a metric task.

        :param function: the metric function executed by this task
        :param counter: number appended to the task name (handled by ``Task``)
        :param param_dict: optional parameter dictionary, forwarded to ``Task``
        """
        super().__init__(function, counter, param_dict)

    def execute(self, indexes: Iterable, scores: Iterable, ground_truths: Iterable) -> float:
        """Call the metric function and return its result.

        :param indexes: passed unchanged as the metric's first positional argument
        :param scores: passed unchanged as the metric's second positional argument
        :param ground_truths: passed unchanged as the metric's third positional argument
        :return: whatever the metric function returns (annotated as float)
        """
        return self._function(indexes, scores, ground_truths)

    def validate_task(self) -> None:
        """Statically check that the metric function takes exactly three positional arguments.

        Parameters are counted via ``inspect.signature``. Unlike
        ``inspect.getfullargspec``, it does not count an already-bound ``self``.
        Bound methods and callable objects whose ``__call__`` is
        ``(self, indexes, scores, ground_truths)`` are therefore accepted, which
        matches how :meth:`execute` invokes the function (three positional arguments).
        Errors raised by ``inspect.signature`` itself (e.g. for non-callables) propagate.

        :raises InputValidationException: if the number of positional parameters is not 3
        """
        parameters = inspect.signature(self._function).parameters.values()
        # Only these kinds can receive the positional arguments passed by execute();
        # this mirrors what getfullargspec().args used to count (posonly + regular args).
        positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        arg_count = sum(1 for parameter in parameters if parameter.kind in positional_kinds)
        if arg_count != 3:
            # Wrong number of arguments
            function_name = get_name_of_function(self._function)
            raise InputValidationException("The number of arguments for the metric task '{0}' is {1} but should be "
                                           "3: (indexes, scores, ground_truth)"
                                           .format(function_name, arg_count))

    def validate_input(self, indexes: Iterable, scores: Iterable, ground_truths: Iterable) -> float:
        """Run the metric function on the given inputs to validate it.

        :param indexes: first positional argument passed to the metric
        :param scores: second positional argument passed to the metric
        :param ground_truths: third positional argument passed to the metric
        :return: the metric's result if the call succeeded
        :raises MetricValidationException: if the metric raises any ``Exception``;
            the original exception is chained as the cause
        """
        try:
            res = self._function(indexes, scores, ground_truths)
        except Exception as e:
            raise MetricValidationException(f"The validation of {get_name_of_function(self._function)} failed.") \
                from e
        else:
            return res

    def get_task_name(self) -> str:
        """Return the task name prefixed with ``metric:``.

        :return: task name as string
        """
        return f"metric:{self._task_name}"
class MetricExecutionTask(Task):
    """Task that wraps a metric function.

    The wrapped function is called with the three positional arguments
    ``(indexes, scores, ground_truths)`` and its result is returned unchanged.
    """

    def __init__(self, function, counter, param_dict=None):
        """Create a metric task.

        :param function: the metric function executed by this task
        :param counter: number appended to the task name (handled by ``Task``)
        :param param_dict: optional parameter dictionary, forwarded to ``Task``
        """
        super().__init__(function, counter, param_dict)

    def execute(self, indexes: Iterable, scores: Iterable, ground_truths: Iterable) -> float:
        """Call the metric function and return its result.

        :param indexes: passed unchanged as the metric's first positional argument
        :param scores: passed unchanged as the metric's second positional argument
        :param ground_truths: passed unchanged as the metric's third positional argument
        :return: whatever the metric function returns (annotated as float)
        """
        return self._function(indexes, scores, ground_truths)

    def validate_task(self) -> None:
        """Statically check that the metric function takes exactly three positional arguments.

        Parameters are counted via ``inspect.signature``. Unlike
        ``inspect.getfullargspec``, it does not count an already-bound ``self``.
        Bound methods and callable objects whose ``__call__`` is
        ``(self, indexes, scores, ground_truths)`` are therefore accepted, which
        matches how :meth:`execute` invokes the function (three positional arguments).
        Errors raised by ``inspect.signature`` itself (e.g. for non-callables) propagate.

        :raises InputValidationException: if the number of positional parameters is not 3
        """
        parameters = inspect.signature(self._function).parameters.values()
        # Only these kinds can receive the positional arguments passed by execute();
        # this mirrors what getfullargspec().args used to count (posonly + regular args).
        positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        arg_count = sum(1 for parameter in parameters if parameter.kind in positional_kinds)
        if arg_count != 3:
            # Wrong number of arguments
            function_name = get_name_of_function(self._function)
            raise InputValidationException("The number of arguments for the metric task '{0}' is {1} but should be "
                                           "3: (indexes, scores, ground_truth)"
                                           .format(function_name, arg_count))

    def validate_input(self, indexes: Iterable, scores: Iterable, ground_truths: Iterable) -> float:
        """Run the metric function on the given inputs to validate it.

        :param indexes: first positional argument passed to the metric
        :param scores: second positional argument passed to the metric
        :param ground_truths: third positional argument passed to the metric
        :return: the metric's result if the call succeeded
        :raises MetricValidationException: if the metric raises any ``Exception``;
            the original exception is chained as the cause
        """
        try:
            res = self._function(indexes, scores, ground_truths)
        except Exception as e:
            raise MetricValidationException(f"The validation of {get_name_of_function(self._function)} failed.") \
                from e
        else:
            return res

    def get_task_name(self) -> str:
        """Return the task name prefixed with ``metric:``.

        :return: task name as string
        """
        return f"metric:{self._task_name}"
Task subclass that wraps a metric function, called with (indexes, scores, ground_truths). Like every Task, it is a work package executed by the framework: it has a name, can be validated and executed, and can have parameters.
MetricExecutionTask(function, counter, param_dict=None)
def __init__(self, function, counter, param_dict=None):
    """Create a metric task; all arguments are forwarded to the ``Task`` constructor.

    :param function: the metric function executed by this task
    :param counter: number appended to the task name
    :param param_dict: optional parameter dictionary for the task
    """
    super().__init__(function, counter, param_dict)
General constructor for all task objects.
Parameters
- function: The function handle to be executed as task content
- counter: A number which is appended to the task name. Useful if multiple tasks with the same name exist.
- param_dict: An optional parameter dictionary for the task
def
execute( self, indexes: collections.abc.Iterable, scores: collections.abc.Iterable, ground_truths: collections.abc.Iterable) -> float:
def execute(self, indexes: Iterable, scores: Iterable, ground_truths: Iterable) -> float:
    """Evaluate the wrapped metric on one set of inputs.

    :param indexes: first positional argument handed to the metric
    :param scores: second positional argument handed to the metric
    :param ground_truths: third positional argument handed to the metric
    :return: the metric's return value, unchanged
    """
    metric = self._function
    result = metric(indexes, scores, ground_truths)
    return result
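Executes the task by calling the metric function with (indexes, scores, ground_truths) and returning its result (annotated as float). The generic Task contract allows an arbitrary number of arguments and any result; this implementation takes exactly these three.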
def
validate_task(self) -> None:
def validate_task(self) -> None:
    """Statically check that the metric function takes exactly three positional arguments.

    Parameters are counted via ``inspect.signature``. Unlike
    ``inspect.getfullargspec``, it does not count an already-bound ``self``.
    Bound methods and callable objects whose ``__call__`` is
    ``(self, indexes, scores, ground_truths)`` are therefore accepted, which
    matches how ``execute`` invokes the function (three positional arguments).

    :raises InputValidationException: if the number of positional parameters is not 3
    """
    parameters = inspect.signature(self._function).parameters.values()
    # Only these kinds can receive the positional arguments passed by execute();
    # this mirrors what getfullargspec().args used to count (posonly + regular args).
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    arg_count = sum(1 for parameter in parameters if parameter.kind in positional_kinds)
    if arg_count != 3:
        # Wrong number of arguments
        function_name = get_name_of_function(self._function)
        raise InputValidationException("The number of arguments for the metric task '{0}' is {1} but should be "
                                       "3: (indexes, scores, ground_truth)"
                                       .format(function_name, arg_count))
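Validates the task statically, before it is run, by checking that the metric function takes exactly three positional arguments (indexes, scores, ground_truths). Throws an InputValidationException if the validation fails.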
def
validate_input( self, indexes: collections.abc.Iterable, scores: collections.abc.Iterable, ground_truths: collections.abc.Iterable) -> float:
def validate_input(self, indexes: Iterable, scores: Iterable, ground_truths: Iterable) -> float:
    """Dry-run the metric on concrete inputs to check that it works.

    :param indexes: first positional argument handed to the metric
    :param scores: second positional argument handed to the metric
    :param ground_truths: third positional argument handed to the metric
    :return: the metric's return value when the call succeeds
    :raises MetricValidationException: wrapping (as ``__cause__``) any ``Exception`` the metric raised
    """
    try:
        return self._function(indexes, scores, ground_truths)
    except Exception as failure:
        message = f"The validation of {get_name_of_function(self._function)} failed."
        raise MetricValidationException(message) from failure
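Validates the task in combination with the given input arguments by running the metric function on (indexes, scores, ground_truths) and returning its result. Any exception raised by the metric is re-raised as a MetricValidationException.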
def
get_task_name(self) -> str:
Returns a descriptive name for the task, of the form `metric:<task name>`.
Returns
task name as string