Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 59 additions & 26 deletions gladier/state_models.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from __future__ import annotations

import copy
import logging
import typing as t
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum

from pydantic import BaseModel, Extra

from .helpers import (
JSONObject,
eliminate_none_values,
ensure_json_path,
ensure_parameter_values,
)
from .helpers import JSONObject, ensure_json_path, ensure_parameter_values


logger = logging.getLogger(__name__)


class BaseState(ABC, BaseModel):
Expand Down Expand Up @@ -103,19 +101,6 @@ def get_flow_transition_states(self) -> t.List[str]:
return []


class BaseCompositeState(BaseState):
"""
A class which allows combining two or more base states
"""

state_type: str = "CompositeVirtualState"
state_name_prefix: str = ""

@abstractmethod
def get_flow_definition(self) -> JSONObject:
return super().get_flow_definition()


class StateWithNextOrEnd(BaseState):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -157,11 +142,13 @@ def next(
new_next_state = old_next
if self.next_state is None:
self.next_state = new_next_state
elif (
isinstance(self.next_state, StateWithNextOrEnd)
and new_next_state is not None
):
self.next_state.next(new_next_state)
elif new_next_state is not None:
try:
self.next_state.next(new_next_state)
except AttributeError:
logger.warn(
f"Unable to set next for state {self.next_state.valid_state_name}"
)
return self

def get_child_states(self) -> t.List[BaseState]:
Expand Down Expand Up @@ -268,3 +255,49 @@ def result_path_for_step(self) -> str:
)
result_path = ensure_json_path(result_path)
return result_path


class BaseCompositeState(BaseState):
state_type: str = "CompositeVirtualState"
state_name_prefix: str = ""

@abstractmethod
def construct_flow(self) -> BaseState:
raise ValueError(
f"construct_flow method not implemented on class {type(self).__name__}"
)

def get_flow_definition(self) -> JSONObject:
start_state = self.construct_flow()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If construct_flow is called here, does it need its own method?

flow_definition = start_state.get_flow_definition()
if self.state_name_prefix:
new_states: JSONObject = {}
flow_states: JSONObject = flow_definition["States"]
for state_name, state_def in flow_states.items():
new_states[self.state_name_prefix + state_name] = copy.deepcopy(
state_def
)
for new_state_name, new_state_def in new_states.items():
for state_def_key, state_def_val in new_state_def.items():
# Replace all references to old state names with prefixed value
if state_def_val in flow_states.keys():
new_states[new_state_name][state_def_key] = (
self.state_name_prefix + state_def_val
)
flow_definition["States"] = new_states

return flow_definition
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The state renaming code might be nice to have in its own separate method for clarity if people are running super().get_flow_definition() a lot. I'm not sure if we'd ever want to call only the state naming functionality out-of-band, but it could possibly be useful. Maybe call it generate_state_names()?


@abstractmethod
def next(
self,
next_state: BaseState,
for_state: t.Optional[t.Union[str, BaseState]] = None,
) -> BaseState:
"""Set the next state for the composite state/sub-flow. The must be implemented
by each composite state. if for_state is None (the default) the next_state
should be set on all out-going links of the composite state. If for_state is
provided, the next should only be set on it for the state spcified by state_name
or by instance.
"""
...
21 changes: 18 additions & 3 deletions gladier/tests/test_builtin_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@
from dataclasses import dataclass

import pytest
from pydantic import ValidationError

from gladier.tools import (
ActionState,
AndRule,
ChoiceOption,
ChoiceState,
ComparisonRule,
ExpressionEvalState,
PassState,
WaitState,
ExpressionEvalState,
)
from gladier.tools.builtins import ChoiceSkipState
from pydantic import ValidationError


@dataclass
Expand Down Expand Up @@ -137,3 +137,18 @@ def test_expression_eval():
state_param_keys = state_def["Parameters"].keys()
expected_keys = {k if k.endswith(".=") else k + ".=" for k in parameters.keys()}
assert state_param_keys == expected_keys


def test_choice_skip_state():
choice_state_name = "ChoiceSkip"
skip_state = ChoiceSkipState(
state_name=choice_state_name,
rule=ComparisonRule(Variable="$.input.should_i", BooleanEquals=True),
state_for_rule=PassState(state_name="DoThisIfIShould"),
)
skip_state.next(PassState(state_name="ChoiceSkipTarget"))
flow_def = skip_state.get_flow_definition()
states = flow_def["States"]
assert flow_def["StartAt"] == choice_state_name
for state_name in {choice_state_name, "DoThisIfIShould", "ChoiceSkipTarget"}:
assert state_name in states, states.keys()
2 changes: 2 additions & 0 deletions gladier/tools/builtins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ComparisonRule,
NotRule,
OrRule,
ChoiceSkipState,
)
from .expression_eval import ExpressionEvalState
from .fail import FailState
Expand All @@ -29,6 +30,7 @@
ComparisonRule,
NotRule,
OrRule,
ChoiceSkipState,
ExpressionEvalState,
PassState,
WaitState,
Expand Down
26 changes: 25 additions & 1 deletion gladier/tools/builtins/choice_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import BaseModel, validator
from pydantic.fields import ModelField

from gladier import BaseState, JSONObject
from gladier import BaseCompositeState, BaseState, JSONObject, StateWithNextOrEnd
from gladier.tools.helpers import exclusive_validator_generator, validate_path_property


Expand Down Expand Up @@ -217,3 +217,27 @@ def get_child_states(self) -> t.List[BaseState]:
if self.default is not None
else []
)


class ChoiceSkipState(BaseCompositeState):
rule: ChoiceRule
state_for_rule: StateWithNextOrEnd

def construct_flow(self) -> ChoiceState:
if not hasattr(self, "_next_state"):
raise ValueError(
f"For state {self.state_name} next() must be set prior to "
"generating the flow definition"
)
choice_state = ChoiceState(state_name=self.state_name, default=self._next_state)
choice_state.choice(ChoiceOption(rule=self.rule, next=self.state_for_rule))
self.state_for_rule.next(self._next_state, replace_next=True)
return choice_state

def next(
self,
next_state: BaseState,
for_state: t.Optional[t.Union[str, BaseState]] = None,
) -> BaseState:
self._next_state = next_state
return self