# Copyright 2021 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definition for TasksGroup."""
import enum
from typing import Optional, Union
from kfp.components import for_loop
from kfp.components import pipeline_context
from kfp.components import pipeline_channel
from kfp.components import pipeline_task
class TasksGroupType(str, enum.Enum):
    """Enumerates the kinds of TasksGroup.

    Every member is also a ``str`` instance, so a member compares equal to
    its underlying string value.
    """

    PIPELINE = 'pipeline'
    # Group type used by Condition.
    CONDITION = 'condition'
    # Group type used by ParallelFor.
    FOR_LOOP = 'for-loop'
    # Group type used by ExitHandler.
    EXIT_HANDLER = 'exit-handler'
class TasksGroup:
    """Represents a logical group of tasks and groups of TasksGroups.

    This class is the base class for groups of tasks, such as tasks
    sharing an exit handler, a condition branch, or a loop. This class
    is not supposed to be used by pipeline authors. It is useful for
    implementing a compiler.

    Attributes:
        group_type: The type of the TasksGroup.
        tasks: A list of all PipelineTasks in this group.
        groups: A list of TasksGroups in this group.
        display_name: The optional user given name of the group.
        dependencies: A list of tasks or groups this group depends on.
        name: The generated unique name of the group. Only assigned once
            the group has been entered as a context manager.
    """

    def __init__(
        self,
        group_type: TasksGroupType,
        name: Optional[str] = None,
    ):
        """Create a new instance of TasksGroup.

        Args:
            group_type: The type of the group.
            name: Optional; the name of the group. Used as display name in UI.
        """
        self.group_type = group_type
        self.tasks = []
        self.groups = []
        self.display_name = name
        self.dependencies = []

    def __enter__(self):
        """Registers this group on the default pipeline and returns it.

        Raises:
            ValueError: If no default pipeline is defined.
        """
        if not pipeline_context.Pipeline.get_default_pipeline():
            raise ValueError('Default pipeline not defined.')

        # The unique name must be generated before the group is pushed.
        self._make_name_unique()

        pipeline_context.Pipeline.get_default_pipeline().push_tasks_group(self)
        return self

    def __exit__(self, *unused_args):
        """Pops the current tasks group from the default pipeline."""
        pipeline_context.Pipeline.get_default_pipeline().pop_tasks_group()

    def _make_name_unique(self):
        """Generates a unique TasksGroup name in the pipeline.

        The name has the form ``<group type value>-<group id>`` with any
        underscore replaced by a hyphen.

        Raises:
            ValueError: If no default pipeline is defined.
        """
        pipeline = pipeline_context.Pipeline.get_default_pipeline()
        if not pipeline:
            raise ValueError('Default pipeline not defined.')

        group_id = pipeline.get_next_group_id()

        # Format the enum's *value*, not the member itself. Formatting a
        # ``(str, Enum)`` member in an f-string yields the value on older
        # Python versions but ``TasksGroupType.X`` on newer ones (3.11+),
        # which would silently change every generated group name.
        type_label = self.group_type
        if isinstance(type_label, enum.Enum):
            type_label = type_label.value

        self.name = f'{type_label}-{group_id}'.replace('_', '-')

    def remove_task_recursive(self, task: pipeline_task.PipelineTask):
        """Removes a task from this group and from all nested groups.

        Args:
            task: The task to remove. At most one occurrence is removed from
                each group's task list.
        """
        if self.tasks and task in self.tasks:
            self.tasks.remove(task)
        for group in self.groups or []:
            group.remove_task_recursive(task)
class ExitHandler(TasksGroup):
    """Task group with an exit task invoked upon exiting the group's tasks.

    Example:
      ::

        exit_task = ExitComponent(...)
        with ExitHandler(exit_task):
            task1 = MyComponent1(...)
            task2 = MyComponent2(...)

    Attributes:
        exit_task: The exit handler task.
    """

    def __init__(
        self,
        exit_task: pipeline_task.PipelineTask,
        name: Optional[str] = None,
    ):
        """Initializes an exit handler task group.

        Args:
            exit_task: The task invoked when the group of tasks is exited.
            name: Optional; the name of the exit handler group.

        Raises:
            ValueError: Raised if the exit_task depends on other tasks.
        """
        super().__init__(TasksGroupType.EXIT_HANDLER, name)

        has_upstream_tasks = bool(exit_task.dependent_tasks)
        if has_upstream_tasks:
            raise ValueError('exit_task cannot depend on any other tasks.')

        # Detach exit_task from any group it currently belongs to.
        default_pipeline = pipeline_context.Pipeline.get_default_pipeline()
        default_pipeline.remove_task_from_groups(exit_task)

        # Flag the task itself, since the compiler might rely on this
        # attribute.
        exit_task.is_exit_handler = True
        self.exit_task = exit_task
class Condition(TasksGroup):
    """Represents a group of tasks guarded by a condition.

    Example:
      ::

        with Condition(param1=='pizza', '[param1 is pizza]'):
            task1 = MyComponent1(...)
            task2 = MyComponent2(...)

    Attributes:
        condition: The condition expression.
    """

    def __init__(
        self,
        condition: pipeline_channel.ConditionOperator,
        name: Optional[str] = None,
    ):
        """Initializes a conditional task group.

        Args:
            condition: The condition expression stored on the group.
            name: Optional; the name of the condition group.
        """
        super().__init__(TasksGroupType.CONDITION, name)
        self.condition = condition
class ParallelFor(TasksGroup):
    """Represents a parallel for loop over a set of items.

    Example:
      ::

        with dsl.ParallelFor([{'a': 1, 'b': 10}, {'a': 2, 'b': 20}]) as item:
            task1 = MyComponent(..., item.a)
            task2 = MyComponent(..., item.b)

    Here ``item`` is the loop argument returned when the group is entered;
    ``item.a`` and ``item.b`` refer to the fields of the current loop item.

    Attributes:
        loop_argument: The argument for each loop iteration.
        items_is_pipeline_channel: Whether the loop items is PipelineChannel
            instead of raw items.
    """

    def __init__(
        self,
        items: Union[for_loop.ItemList, pipeline_channel.PipelineChannel],
        name: Optional[str] = None,
    ):
        """Initializes a for loop task group.

        Args:
            items: The argument to loop over. It can be either a raw list or a
                pipeline channel.
            name: Optional; the name of the for loop group.
        """
        super().__init__(TasksGroupType.FOR_LOOP, name)

        from_channel = isinstance(items, pipeline_channel.PipelineChannel)
        if from_channel:
            loop_argument = for_loop.LoopArgument.from_pipeline_channel(items)
        else:
            # Raw items: the next group id of the default pipeline is passed
            # as the loop argument's name code.
            default_pipeline = pipeline_context.Pipeline.get_default_pipeline()
            loop_argument = for_loop.LoopArgument.from_raw_items(
                raw_items=items,
                name_code=default_pipeline.get_next_group_id(),
            )

        self.loop_argument = loop_argument
        self.items_is_pipeline_channel = from_channel

    def __enter__(self) -> for_loop.LoopArgument:
        """Enters the group and returns the loop argument, not the group."""
        super().__enter__()
        return self.loop_argument