Bläddra i källkod

feat: Parallel Execution of Nodes in Workflows (#8192)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
takatost 1 år sedan
förälder
incheckning
dabfd74622
100 ändrade filer med 8782 tillägg och 4858 borttagningar
  1. 1 1
      api/configs/packaging/__init__.py
  2. 52 117
      api/core/app/apps/advanced_chat/app_generator.py
  3. 164 126
      api/core/app/apps/advanced_chat/app_runner.py
  4. 229 433
      api/core/app/apps/advanced_chat/generate_task_pipeline.py
  5. 0 203
      api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py
  6. 1 1
      api/core/app/apps/base_app_generate_response_converter.py
  7. 3 3
      api/core/app/apps/base_app_runner.py
  8. 34 31
      api/core/app/apps/workflow/app_generator.py
  9. 72 73
      api/core/app/apps/workflow/app_runner.py
  10. 169 264
      api/core/app/apps/workflow/generate_task_pipeline.py
  11. 0 200
      api/core/app/apps/workflow/workflow_event_trigger_callback.py
  12. 379 0
      api/core/app/apps/workflow_app_runner.py
  13. 187 97
      api/core/app/apps/workflow_logging_callback.py
  14. 165 17
      api/core/app/entities/queue_entities.py
  15. 78 78
      api/core/app/entities/task_entities.py
  16. 8 6
      api/core/app/task_pipeline/based_generate_task_pipeline.py
  17. 10 26
      api/core/app/task_pipeline/message_cycle_manage.py
  18. 352 313
      api/core/app/task_pipeline/workflow_cycle_manage.py
  19. 0 16
      api/core/app/task_pipeline/workflow_cycle_state_manager.py
  20. 0 290
      api/core/app/task_pipeline/workflow_iteration_cycle_manage.py
  21. 33 0
      api/core/model_runtime/entities/llm_entities.py
  22. 4 4
      api/core/moderation/output_moderation.py
  23. 3 1
      api/core/tools/tool/workflow_tool.py
  24. 2 0
      api/core/tools/tool_engine.py
  25. 1 2
      api/core/tools/tool_manager.py
  26. 1 0
      api/core/tools/utils/message_transformer.py
  27. 6 107
      api/core/workflow/callbacks/base_workflow_callback.py
  28. 1 1
      api/core/workflow/entities/base_node_data_entities.py
  29. 30 4
      api/core/workflow/entities/node_entities.py
  30. 54 32
      api/core/workflow/entities/variable_pool.py
  31. 2 3
      api/core/workflow/entities/workflow_entities.py
  32. 4 6
      api/core/workflow/errors.py
  33. 0 0
      api/core/workflow/graph_engine/__init__.py
  34. 0 0
      api/core/workflow/graph_engine/condition_handlers/__init__.py
  35. 31 0
      api/core/workflow/graph_engine/condition_handlers/base_handler.py
  36. 28 0
      api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py
  37. 32 0
      api/core/workflow/graph_engine/condition_handlers/condition_handler.py
  38. 35 0
      api/core/workflow/graph_engine/condition_handlers/condition_manager.py
  39. 0 0
      api/core/workflow/graph_engine/entities/__init__.py
  40. 163 0
      api/core/workflow/graph_engine/entities/event.py
  41. 692 0
      api/core/workflow/graph_engine/entities/graph.py
  42. 21 0
      api/core/workflow/graph_engine/entities/graph_init_params.py
  43. 27 0
      api/core/workflow/graph_engine/entities/graph_runtime_state.py
  44. 13 0
      api/core/workflow/graph_engine/entities/next_graph_node.py
  45. 21 0
      api/core/workflow/graph_engine/entities/run_condition.py
  46. 111 0
      api/core/workflow/graph_engine/entities/runtime_route_state.py
  47. 716 0
      api/core/workflow/graph_engine/graph_engine.py
  48. 19 72
      api/core/workflow/nodes/answer/answer_node.py
  49. 169 0
      api/core/workflow/nodes/answer/answer_stream_generate_router.py
  50. 221 0
      api/core/workflow/nodes/answer/answer_stream_processor.py
  51. 71 0
      api/core/workflow/nodes/answer/base_stream_processor.py
  52. 35 7
      api/core/workflow/nodes/answer/entities.py
  53. 61 135
      api/core/workflow/nodes/base_node.py
  54. 15 9
      api/core/workflow/nodes/code/code_node.py
  55. 12 50
      api/core/workflow/nodes/end/end_node.py
  56. 148 0
      api/core/workflow/nodes/end/end_stream_generate_router.py
  57. 191 0
      api/core/workflow/nodes/end/end_stream_processor.py
  58. 16 0
      api/core/workflow/nodes/end/entities.py
  59. 20 0
      api/core/workflow/nodes/event.py
  60. 19 9
      api/core/workflow/nodes/http_request/http_request_node.py
  61. 1 14
      api/core/workflow/nodes/if_else/entities.py
  62. 25 380
      api/core/workflow/nodes/if_else/if_else_node.py
  63. 8 1
      api/core/workflow/nodes/iteration/entities.py
  64. 338 91
      api/core/workflow/nodes/iteration/iteration_node.py
  65. 39 0
      api/core/workflow/nodes/iteration/iteration_start_node.py
  66. 22 10
      api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
  67. 112 54
      api/core/workflow/nodes/llm/llm_node.py
  68. 22 8
      api/core/workflow/nodes/loop/loop_node.py
  69. 37 0
      api/core/workflow/nodes/node_mapping.py
  70. 24 11
      api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
  71. 40 10
      api/core/workflow/nodes/question_classifier/question_classifier_node.py
  72. 15 7
      api/core/workflow/nodes/start/start_node.py
  73. 17 9
      api/core/workflow/nodes/template_transform/template_transform_node.py
  74. 18 5
      api/core/workflow/nodes/tool/tool_node.py
  75. 18 7
      api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py
  76. 6 7
      api/core/workflow/nodes/variable_assigner/node.py
  77. 0 0
      api/core/workflow/utils/condition/__init__.py
  78. 17 0
      api/core/workflow/utils/condition/entities.py
  79. 383 0
      api/core/workflow/utils/condition/processor.py
  80. 0 1005
      api/core/workflow/workflow_engine_manager.py
  81. 314 0
      api/core/workflow/workflow_entry.py
  82. 35 0
      api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py
  83. 3 0
      api/models/workflow.py
  84. 2 1
      api/services/app_dsl_service.py
  85. 3 4
      api/services/app_generate_service.py
  86. 69 91
      api/services/workflow_service.py
  87. 175 144
      api/tests/integration_tests/workflow/nodes/test_code.py
  88. 96 60
      api/tests/integration_tests/workflow/nodes/test_http.py
  89. 82 51
      api/tests/integration_tests/workflow/nodes/test_llm.py
  90. 87 105
      api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
  91. 61 23
      api/tests/integration_tests/workflow/nodes/test_template_transform.py
  92. 61 23
      api/tests/integration_tests/workflow/nodes/test_tool.py
  93. 17 0
      api/tests/unit_tests/conftest.py
  94. 0 0
      api/tests/unit_tests/core/workflow/graph_engine/__init__.py
  95. 791 0
      api/tests/unit_tests/core/workflow/graph_engine/test_graph.py
  96. 505 0
      api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
  97. 0 0
      api/tests/unit_tests/core/workflow/nodes/answer/__init__.py
  98. 82 0
      api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
  99. 109 0
      api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py
  100. 216 0
      api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py

+ 1 - 1
api/configs/packaging/__init__.py

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
 
     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.7.3",
+        default="0.8.0",
     )
 
     COMMIT_SHA: str = Field(

+ 52 - 117
api/core/app/apps/advanced_chat/app_generator.py

@@ -4,12 +4,10 @@ import os
 import threading
 import uuid
 from collections.abc import Generator
-from typing import Literal, Union, overload
+from typing import Any, Literal, Optional, Union, overload
 
 from flask import Flask, current_app
 from pydantic import ValidationError
-from sqlalchemy import select
-from sqlalchemy.orm import Session
 
 import contexts
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -20,20 +18,15 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
 from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
-from core.app.entities.app_invoke_entities import (
-    AdvancedChatAppGenerateEntity,
-    InvokeFrom,
-)
+from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
 from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
 from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
 from extensions.ext_database import db
 from models.account import Account
 from models.model import App, Conversation, EndUser, Message
-from models.workflow import ConversationVariable, Workflow
+from models.workflow import Workflow
 
 logger = logging.getLogger(__name__)
 
@@ -60,13 +53,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
     ) -> dict: ...
 
     def generate(
-        self, app_model: App,
-        workflow: Workflow,
-        user: Union[Account, EndUser],
-        args: dict,
-        invoke_from: InvokeFrom,
-        stream: bool = True,
-    ):
+            self,
+            app_model: App,
+            workflow: Workflow,
+            user: Union[Account, EndUser],
+            args: dict,
+            invoke_from: InvokeFrom,
+            stream: bool = True,
+    )  -> dict[str, Any] | Generator[str, Any, None]:
         """
         Generate App response.
 
@@ -154,7 +148,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                                   node_id: str,
                                   user: Account,
                                   args: dict,
-                                  stream: bool = True):
+                                  stream: bool = True) \
+            -> dict[str, Any] | Generator[str, Any, None]:
         """
         Generate App response.
 
@@ -171,16 +166,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         if args.get('inputs') is None:
             raise ValueError('inputs is required')
 
-        extras = {
-            "auto_generate_conversation_name": False
-        }
-
-        # get conversation
-        conversation = None
-        conversation_id = args.get('conversation_id')
-        if conversation_id:
-            conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
-
         # convert to app config
         app_config = AdvancedChatAppConfigManager.get_app_config(
             app_model=app_model,
@@ -191,14 +176,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         application_generate_entity = AdvancedChatAppGenerateEntity(
             task_id=str(uuid.uuid4()),
             app_config=app_config,
-            conversation_id=conversation.id if conversation else None,
+            conversation_id=None,
             inputs={},
             query='',
             files=[],
             user_id=user.id,
             stream=stream,
             invoke_from=InvokeFrom.DEBUGGER,
-            extras=extras,
+            extras={
+                "auto_generate_conversation_name": False
+            },
             single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
                 node_id=node_id,
                 inputs=args['inputs']
@@ -211,17 +198,28 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             user=user,
             invoke_from=InvokeFrom.DEBUGGER,
             application_generate_entity=application_generate_entity,
-            conversation=conversation,
+            conversation=None,
             stream=stream
         )
 
     def _generate(self, *,
-                 workflow: Workflow,
-                 user: Union[Account, EndUser],
-                 invoke_from: InvokeFrom,
-                 application_generate_entity: AdvancedChatAppGenerateEntity,
-                 conversation: Conversation | None = None,
-                 stream: bool = True):
+                  workflow: Workflow,
+                  user: Union[Account, EndUser],
+                  invoke_from: InvokeFrom,
+                  application_generate_entity: AdvancedChatAppGenerateEntity,
+                  conversation: Optional[Conversation] = None,
+                  stream: bool = True) \
+            -> dict[str, Any] | Generator[str, Any, None]:
+        """
+        Generate App response.
+
+        :param workflow: Workflow
+        :param user: account or end user
+        :param invoke_from: invoke from source
+        :param application_generate_entity: application generate entity
+        :param conversation: conversation
+        :param stream: is stream
+        """
         is_first_conversation = False
         if not conversation:
             is_first_conversation = True
@@ -236,7 +234,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             # update conversation features
             conversation.override_model_configs = workflow.features
             db.session.commit()
-            # db.session.refresh(conversation)
+            db.session.refresh(conversation)
 
         # init queue manager
         queue_manager = MessageBasedAppQueueManager(
@@ -248,67 +246,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             message_id=message.id
         )
 
-        # Init conversation variables
-        stmt = select(ConversationVariable).where(
-            ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
-        )
-        with Session(db.engine) as session:
-            conversation_variables = session.scalars(stmt).all()
-            if not conversation_variables:
-                # Create conversation variables if they don't exist.
-                conversation_variables = [
-                    ConversationVariable.from_variable(
-                        app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
-                    )
-                    for variable in workflow.conversation_variables
-                ]
-                session.add_all(conversation_variables)
-            # Convert database entities to variables.
-            conversation_variables = [item.to_variable() for item in conversation_variables]
-
-            session.commit()
-
-            # Increment dialogue count.
-            conversation.dialogue_count += 1
-
-            conversation_id = conversation.id
-            conversation_dialogue_count = conversation.dialogue_count
-            db.session.commit()
-            db.session.refresh(conversation)
-
-        inputs = application_generate_entity.inputs
-        query = application_generate_entity.query
-        files = application_generate_entity.files
-
-        user_id = None
-        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
-            end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
-            if end_user:
-                user_id = end_user.session_id
-        else:
-            user_id = application_generate_entity.user_id
-
-        # Create a variable pool.
-        system_inputs = {
-            SystemVariableKey.QUERY: query,
-            SystemVariableKey.FILES: files,
-            SystemVariableKey.CONVERSATION_ID: conversation_id,
-            SystemVariableKey.USER_ID: user_id,
-            SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
-        }
-        variable_pool = VariablePool(
-            system_variables=system_inputs,
-            user_inputs=inputs,
-            environment_variables=workflow.environment_variables,
-            conversation_variables=conversation_variables,
-        )
-        contexts.workflow_variable_pool.set(variable_pool)
-
         # new thread
         worker_thread = threading.Thread(target=self._generate_worker, kwargs={
-            'flask_app': current_app._get_current_object(),
+            'flask_app': current_app._get_current_object(), # type: ignore
             'application_generate_entity': application_generate_entity,
             'queue_manager': queue_manager,
+            'conversation_id': conversation.id,
             'message_id': message.id,
             'context': contextvars.copy_context(),
         })
@@ -334,6 +277,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
     def _generate_worker(self, flask_app: Flask,
                          application_generate_entity: AdvancedChatAppGenerateEntity,
                          queue_manager: AppQueueManager,
+                         conversation_id: str,
                          message_id: str,
                          context: contextvars.Context) -> None:
         """
@@ -349,28 +293,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             var.set(val)
         with flask_app.app_context():
             try:
-                runner = AdvancedChatAppRunner()
-                if application_generate_entity.single_iteration_run:
-                    single_iteration_run = application_generate_entity.single_iteration_run
-                    runner.single_iteration_run(
-                        app_id=application_generate_entity.app_config.app_id,
-                        workflow_id=application_generate_entity.app_config.workflow_id,
-                        queue_manager=queue_manager,
-                        inputs=single_iteration_run.inputs,
-                        node_id=single_iteration_run.node_id,
-                        user_id=application_generate_entity.user_id
-                    )
-                else:
-                    # get message
-                    message = self._get_message(message_id)
-
-                    # chatbot app
-                    runner = AdvancedChatAppRunner()
-                    runner.run(
-                        application_generate_entity=application_generate_entity,
-                        queue_manager=queue_manager,
-                        message=message
-                    )
+                # get conversation and message
+                conversation = self._get_conversation(conversation_id)
+                message = self._get_message(message_id)
+
+                # chatbot app
+                runner = AdvancedChatAppRunner(
+                    application_generate_entity=application_generate_entity,
+                    queue_manager=queue_manager,
+                    conversation=conversation,
+                    message=message
+                )
+
+                runner.run()
             except GenerateTaskStoppedException:
                 pass
             except InvokeAuthorizationError:

+ 164 - 126
api/core/app/apps/advanced_chat/app_runner.py

@@ -1,49 +1,67 @@
 import logging
 import os
-import time
 from collections.abc import Mapping
-from typing import Any, Optional, cast
+from typing import Any, cast
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
-from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
-from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
-from core.app.apps.base_app_runner import AppRunner
+from core.app.apps.base_app_queue_manager import AppQueueManager
+from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
 from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
 from core.app.entities.app_invoke_entities import (
     AdvancedChatAppGenerateEntity,
     InvokeFrom,
 )
-from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
+from core.app.entities.queue_entities import (
+    QueueAnnotationReplyEvent,
+    QueueStopEvent,
+    QueueTextChunkEvent,
+)
 from core.moderation.base import ModerationException
 from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.nodes.base_node import UserFrom
-from core.workflow.workflow_engine_manager import WorkflowEngineManager
+from core.workflow.entities.node_entities import UserFrom
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariableKey
+from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
-from models import App, Message, Workflow
+from models.model import App, Conversation, EndUser, Message
+from models.workflow import ConversationVariable, WorkflowType
 
 logger = logging.getLogger(__name__)
 
 
-class AdvancedChatAppRunner(AppRunner):
+class AdvancedChatAppRunner(WorkflowBasedAppRunner):
     """
     AdvancedChat Application Runner
     """
 
-    def run(
-        self,
-        application_generate_entity: AdvancedChatAppGenerateEntity,
-        queue_manager: AppQueueManager,
-        message: Message,
+    def __init__(
+            self,
+            application_generate_entity: AdvancedChatAppGenerateEntity,
+            queue_manager: AppQueueManager,
+            conversation: Conversation,
+            message: Message
     ) -> None:
         """
-        Run application
         :param application_generate_entity: application generate entity
         :param queue_manager: application queue manager
         :param conversation: conversation
         :param message: message
+        """
+        super().__init__(queue_manager)
+
+        self.application_generate_entity = application_generate_entity
+        self.conversation = conversation
+        self.message = message
+
+    def run(self) -> None:
+        """
+        Run application
         :return:
         """
-        app_config = application_generate_entity.app_config
+        app_config = self.application_generate_entity.app_config
         app_config = cast(AdvancedChatAppConfig, app_config)
 
         app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
@@ -54,101 +72,133 @@ class AdvancedChatAppRunner(AppRunner):
         if not workflow:
             raise ValueError('Workflow not initialized')
 
-        inputs = application_generate_entity.inputs
-        query = application_generate_entity.query
+        user_id = None
+        if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+            end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
+            if end_user:
+                user_id = end_user.session_id
+        else:
+            user_id = self.application_generate_entity.user_id
 
-        # moderation
-        if self.handle_input_moderation(
-            queue_manager=queue_manager,
-            app_record=app_record,
-            app_generate_entity=application_generate_entity,
-            inputs=inputs,
-            query=query,
-            message_id=message.id,
-        ):
-            return
+        workflow_callbacks: list[WorkflowCallback] = []
+        if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
+            workflow_callbacks.append(WorkflowLoggingCallback())
 
-        # annotation reply
-        if self.handle_annotation_reply(
-            app_record=app_record,
-            message=message,
-            query=query,
-            queue_manager=queue_manager,
-            app_generate_entity=application_generate_entity,
-        ):
-            return
+        if self.application_generate_entity.single_iteration_run:
+            # if only single iteration run is requested
+            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
+                workflow=workflow,
+                node_id=self.application_generate_entity.single_iteration_run.node_id,
+                user_inputs=self.application_generate_entity.single_iteration_run.inputs
+            )
+        else:
+            inputs = self.application_generate_entity.inputs
+            query = self.application_generate_entity.query
+            files = self.application_generate_entity.files
 
-        db.session.close()
+            # moderation
+            if self.handle_input_moderation(
+                    app_record=app_record,
+                    app_generate_entity=self.application_generate_entity,
+                    inputs=inputs,
+                    query=query,
+                    message_id=self.message.id
+            ):
+                return
 
-        workflow_callbacks: list[WorkflowCallback] = [
-            WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
-        ]
+            # annotation reply
+            if self.handle_annotation_reply(
+                    app_record=app_record,
+                    message=self.message,
+                    query=query,
+                    app_generate_entity=self.application_generate_entity
+            ):
+                return
 
-        if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
-            workflow_callbacks.append(WorkflowLoggingCallback())
+            # Init conversation variables
+            stmt = select(ConversationVariable).where(
+                ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id
+            )
+            with Session(db.engine) as session:
+                conversation_variables = session.scalars(stmt).all()
+                if not conversation_variables:
+                    # Create conversation variables if they don't exist.
+                    conversation_variables = [
+                        ConversationVariable.from_variable(
+                            app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
+                        )
+                        for variable in workflow.conversation_variables
+                    ]
+                    session.add_all(conversation_variables)
+                # Convert database entities to variables.
+                conversation_variables = [item.to_variable() for item in conversation_variables]
 
-        # RUN WORKFLOW
-        workflow_engine_manager = WorkflowEngineManager()
-        workflow_engine_manager.run_workflow(
-            workflow=workflow,
-            user_id=application_generate_entity.user_id,
-            user_from=UserFrom.ACCOUNT
-            if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
-            else UserFrom.END_USER,
-            invoke_from=application_generate_entity.invoke_from,
-            callbacks=workflow_callbacks,
-            call_depth=application_generate_entity.call_depth,
-        )
+                session.commit()
 
-    def single_iteration_run(
-        self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
-    ) -> None:
-        """
-        Single iteration run
-        """
-        app_record = db.session.query(App).filter(App.id == app_id).first()
-        if not app_record:
-            raise ValueError('App not found')
+            # Increment dialogue count.
+            self.conversation.dialogue_count += 1
 
-        workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
-        if not workflow:
-            raise ValueError('Workflow not initialized')
+            conversation_dialogue_count = self.conversation.dialogue_count
+            db.session.commit()
 
-        workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
+            # Create a variable pool.
+            system_inputs = {
+                SystemVariableKey.QUERY: query,
+                SystemVariableKey.FILES: files,
+                SystemVariableKey.CONVERSATION_ID: self.conversation.id,
+                SystemVariableKey.USER_ID: user_id,
+                SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
+            }
 
-        workflow_engine_manager = WorkflowEngineManager()
-        workflow_engine_manager.single_step_run_iteration_workflow_node(
-            workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
+            # init variable pool
+            variable_pool = VariablePool(
+                system_variables=system_inputs,
+                user_inputs=inputs,
+                environment_variables=workflow.environment_variables,
+                conversation_variables=conversation_variables,
+            )
+
+            # init graph
+            graph = self._init_graph(graph_config=workflow.graph_dict)
+
+        db.session.close()
+
+        # RUN WORKFLOW
+        workflow_entry = WorkflowEntry(
+            tenant_id=workflow.tenant_id,
+            app_id=workflow.app_id,
+            workflow_id=workflow.id,
+            workflow_type=WorkflowType.value_of(workflow.type),
+            graph=graph,
+            graph_config=workflow.graph_dict,
+            user_id=self.application_generate_entity.user_id,
+            user_from=(
+                UserFrom.ACCOUNT
+                if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
+                else UserFrom.END_USER
+            ),
+            invoke_from=self.application_generate_entity.invoke_from,
+            call_depth=self.application_generate_entity.call_depth,
+            variable_pool=variable_pool,
         )
 
-    def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
-        """
-        Get workflow
-        """
-        # fetch workflow by workflow_id
-        workflow = (
-            db.session.query(Workflow)
-            .filter(
-                Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
-            )
-            .first()
+        generator = workflow_entry.run(
+            callbacks=workflow_callbacks,
         )
 
-        # return workflow
-        return workflow
+        for event in generator:
+            self._handle_event(workflow_entry, event)
 
     def handle_input_moderation(
-        self,
-        queue_manager: AppQueueManager,
-        app_record: App,
-        app_generate_entity: AdvancedChatAppGenerateEntity,
-        inputs: Mapping[str, Any],
-        query: str,
-        message_id: str,
+            self,
+            app_record: App,
+            app_generate_entity: AdvancedChatAppGenerateEntity,
+            inputs: Mapping[str, Any],
+            query: str,
+            message_id: str
     ) -> bool:
         """
         Handle input moderation
-        :param queue_manager: application queue manager
         :param app_record: app record
         :param app_generate_entity: application generate entity
         :param inputs: inputs
@@ -167,30 +217,23 @@ class AdvancedChatAppRunner(AppRunner):
                 message_id=message_id,
             )
         except ModerationException as e:
-            self._stream_output(
-                queue_manager=queue_manager,
+            self._complete_with_stream_output(
                 text=str(e),
-                stream=app_generate_entity.stream,
-                stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION,
+                stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
             )
             return True
 
         return False
 
-    def handle_annotation_reply(
-        self,
-        app_record: App,
-        message: Message,
-        query: str,
-        queue_manager: AppQueueManager,
-        app_generate_entity: AdvancedChatAppGenerateEntity,
-    ) -> bool:
+    def handle_annotation_reply(self, app_record: App,
+                                message: Message,
+                                query: str,
+                                app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
         """
         Handle annotation reply
         :param app_record: app record
         :param message: message
         :param query: query
-        :param queue_manager: application queue manager
         :param app_generate_entity: application generate entity
         """
         # annotation reply
@@ -203,37 +246,32 @@ class AdvancedChatAppRunner(AppRunner):
         )
 
         if annotation_reply:
-            queue_manager.publish(
-                QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER
+            self._publish_event(
+                QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)
             )
 
-            self._stream_output(
-                queue_manager=queue_manager,
+            self._complete_with_stream_output(
                 text=annotation_reply.content,
-                stream=app_generate_entity.stream,
-                stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY,
+                stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
             )
             return True
 
         return False
 
-    def _stream_output(
-        self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy
-    ) -> None:
+    def _complete_with_stream_output(self,
+                                     text: str,
+                                     stopped_by: QueueStopEvent.StopBy) -> None:
         """
         Direct output
-        :param queue_manager: application queue manager
         :param text: text
-        :param stream: stream
         :return:
         """
-        if stream:
-            index = 0
-            for token in text:
-                queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)
-                index += 1
-                time.sleep(0.01)
-        else:
-            queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER)
+        self._publish_event(
+            QueueTextChunkEvent(
+                text=text
+            )
+        )
 
-        queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER)
+        self._publish_event(
+            QueueStopEvent(stopped_by=stopped_by)
+        )

+ 229 - 433
api/core/app/apps/advanced_chat/generate_task_pipeline.py

@@ -2,9 +2,8 @@ import json
 import logging
 import time
 from collections.abc import Generator
-from typing import Any, Optional, Union, cast
+from typing import Any, Optional, Union
 
-import contexts
 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
 from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -22,6 +21,9 @@ from core.app.entities.queue_entities import (
     QueueNodeFailedEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
+    QueueParallelBranchRunFailedEvent,
+    QueueParallelBranchRunStartedEvent,
+    QueueParallelBranchRunSucceededEvent,
     QueuePingEvent,
     QueueRetrieverResourcesEvent,
     QueueStopEvent,
@@ -31,34 +33,28 @@ from core.app.entities.queue_entities import (
     QueueWorkflowSucceededEvent,
 )
 from core.app.entities.task_entities import (
-    AdvancedChatTaskState,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
-    ChatflowStreamGenerateRoute,
     ErrorStreamResponse,
     MessageAudioEndStreamResponse,
     MessageAudioStreamResponse,
     MessageEndStreamResponse,
     StreamResponse,
+    WorkflowTaskState,
 )
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
 from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
-from core.file.file_obj import FileVar
-from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.entities.node_entities import NodeType
 from core.workflow.enums import SystemVariableKey
-from core.workflow.nodes.answer.answer_node import AnswerNode
-from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from events.message_event import message_was_created
 from extensions.ext_database import db
 from models.account import Account
 from models.model import Conversation, EndUser, Message
 from models.workflow import (
     Workflow,
-    WorkflowNodeExecution,
     WorkflowRunStatus,
 )
 
@@ -69,16 +65,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     """
     AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
     """
-    _task_state: AdvancedChatTaskState
+    _task_state: WorkflowTaskState
     _application_generate_entity: AdvancedChatAppGenerateEntity
     _workflow: Workflow
     _user: Union[Account, EndUser]
-    # Deprecated
     _workflow_system_variables: dict[SystemVariableKey, Any]
-    _iteration_nested_relations: dict[str, list[str]]
 
     def __init__(
-            self, application_generate_entity: AdvancedChatAppGenerateEntity,
+            self,
+            application_generate_entity: AdvancedChatAppGenerateEntity,
             workflow: Workflow,
             queue_manager: AppQueueManager,
             conversation: Conversation,
@@ -106,7 +101,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         self._workflow = workflow
         self._conversation = conversation
         self._message = message
-        # Deprecated
         self._workflow_system_variables = {
             SystemVariableKey.QUERY: message.query,
             SystemVariableKey.FILES: application_generate_entity.files,
@@ -114,12 +108,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             SystemVariableKey.USER_ID: user_id,
         }
 
-        self._task_state = AdvancedChatTaskState(
-            usage=LLMUsage.empty_usage()
-        )
+        self._task_state = WorkflowTaskState()
 
-        self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
-        self._stream_generate_routes = self._get_stream_generate_routes()
         self._conversation_name_generate_thread = None
 
     def process(self):
@@ -140,6 +130,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         generator = self._wrapper_process_stream_response(
             trace_manager=self._application_generate_entity.trace_manager
         )
+
         if self._stream:
             return self._to_stream_response(generator)
         else:
@@ -199,17 +190,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
             Generator[StreamResponse, None, None]:
 
-        publisher = None
+        tts_publisher = None
         task_id = self._application_generate_entity.task_id
         tenant_id = self._application_generate_entity.app_config.tenant_id
         features_dict = self._workflow.features_dict
 
         if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
                 'text_to_speech'].get('autoPlay') == 'enabled':
-            publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
-        for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
+            tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
+
+        for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
             while True:
-                audio_response = self._listenAudioMsg(publisher, task_id=task_id)
+                audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
                 if audio_response:
                     yield audio_response
                 else:
@@ -220,9 +212,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         # timeout
         while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
             try:
-                if not publisher:
+                if not tts_publisher:
                     break
-                audio_trunk = publisher.checkAndGetAudio()
+                audio_trunk = tts_publisher.checkAndGetAudio()
                 if audio_trunk is None:
                     # release cpu
                     # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
@@ -240,34 +232,34 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
     def _process_stream_response(
             self,
-            publisher: AppGeneratorTTSPublisher,
+            tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
             trace_manager: Optional[TraceQueueManager] = None
     ) -> Generator[StreamResponse, None, None]:
         """
         Process stream response.
         :return:
         """
-        for message in self._queue_manager.listen():
-            if (message.event
-                    and getattr(message.event, 'metadata', None)
-                    and message.event.metadata.get('is_answer_previous_node', False)
-                    and publisher):
-                publisher.publish(message=message)
-            elif (hasattr(message.event, 'execution_metadata')
-                  and message.event.execution_metadata
-                  and message.event.execution_metadata.get('is_answer_previous_node', False)
-                  and publisher):
-                publisher.publish(message=message)
-            event = message.event
-
-            if isinstance(event, QueueErrorEvent):
+        # init fake graph runtime state
+        graph_runtime_state = None
+        workflow_run = None
+
+        for queue_message in self._queue_manager.listen():
+            event = queue_message.event
+
+            if isinstance(event, QueuePingEvent):
+                yield self._ping_stream_response()
+            elif isinstance(event, QueueErrorEvent):
                 err = self._handle_error(event, self._message)
                 yield self._error_to_stream_response(err)
                 break
             elif isinstance(event, QueueWorkflowStartedEvent):
-                workflow_run = self._handle_workflow_start()
+                # override graph runtime state
+                graph_runtime_state = event.graph_runtime_state
+
+                # init workflow run
+                workflow_run = self._handle_workflow_run_start()
 
-                self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
+                self._refetch_message()
                 self._message.workflow_run_id = workflow_run.id
 
                 db.session.commit()
@@ -279,133 +271,242 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     workflow_run=workflow_run
                 )
             elif isinstance(event, QueueNodeStartedEvent):
-                workflow_node_execution = self._handle_node_start(event)
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
 
-                # search stream_generate_routes if node id is answer start at node
-                if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
-                    self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
-                    # reset current route position to 0
-                    self._task_state.current_stream_generate_state.current_route_position = 0
+                workflow_node_execution = self._handle_node_execution_start(
+                    workflow_run=workflow_run,
+                    event=event
+                )
 
-                    # generate stream outputs when node started
-                    yield from self._generate_stream_outputs_when_node_started()
+                response = self._workflow_node_start_to_stream_response(
+                    event=event,
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_node_execution=workflow_node_execution
+                )
 
-                yield self._workflow_node_start_to_stream_response(
+                if response:
+                    yield response
+            elif isinstance(event, QueueNodeSucceededEvent):
+                workflow_node_execution = self._handle_workflow_node_execution_success(event)
+
+                response = self._workflow_node_finish_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution
                 )
-            elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
-                workflow_node_execution = self._handle_node_finished(event)
 
-                # stream outputs when node finished
-                generator = self._generate_stream_outputs_when_node_finished()
-                if generator:
-                    yield from generator
+                if response:
+                    yield response
+            elif isinstance(event, QueueNodeFailedEvent):
+                workflow_node_execution = self._handle_workflow_node_execution_failed(event)
 
-                yield self._workflow_node_finish_to_stream_response(
+                response = self._workflow_node_finish_to_stream_response(
+                    event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution
                 )
 
-                if isinstance(event, QueueNodeFailedEvent):
-                    yield from self._handle_iteration_exception(
-                        task_id=self._application_generate_entity.task_id,
-                        error=f'Child node failed: {event.error}'
-                    )
-            elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
-                if isinstance(event, QueueIterationNextEvent):
-                    # clear ran node execution infos of current iteration
-                    iteration_relations = self._iteration_nested_relations.get(event.node_id)
-                    if iteration_relations:
-                        for node_id in iteration_relations:
-                            self._task_state.ran_node_execution_infos.pop(node_id, None)
-
-                yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
-                self._handle_iteration_operation(event)
-            elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
-                workflow_run = self._handle_workflow_finished(
-                    event, conversation_id=self._conversation.id, trace_manager=trace_manager
+                if response:
+                    yield response
+            elif isinstance(event, QueueParallelBranchRunStartedEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                yield self._workflow_parallel_branch_start_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
                 )
-                if workflow_run:
-                    yield self._workflow_finish_to_stream_response(
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_run=workflow_run
-                    )
+            elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
 
-                    if workflow_run.status == WorkflowRunStatus.FAILED.value:
-                        err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
-                        yield self._error_to_stream_response(self._handle_error(err_event, self._message))
-                        break
+                yield self._workflow_parallel_branch_finished_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
+            elif isinstance(event, QueueIterationStartEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
 
-                if isinstance(event, QueueStopEvent):
-                    # Save message
-                    self._save_message()
+                yield self._workflow_iteration_start_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
+            elif isinstance(event, QueueIterationNextEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
 
-                    yield self._message_end_to_stream_response()
-                    break
-                else:
-                    self._queue_manager.publish(
-                        QueueAdvancedChatMessageEndEvent(),
-                        PublishFrom.TASK_PIPELINE
+                yield self._workflow_iteration_next_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
+            elif isinstance(event, QueueIterationCompletedEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                yield self._workflow_iteration_completed_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
+            elif isinstance(event, QueueWorkflowSucceededEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                if not graph_runtime_state:
+                    raise Exception('Graph runtime state not initialized.')
+
+                workflow_run = self._handle_workflow_run_success(
+                    workflow_run=workflow_run,
+                    start_at=graph_runtime_state.start_at,
+                    total_tokens=graph_runtime_state.total_tokens,
+                    total_steps=graph_runtime_state.node_run_steps,
+                    outputs=json.dumps(event.outputs) if event.outputs else None,
+                    conversation_id=self._conversation.id,
+                    trace_manager=trace_manager,
+                )
+
+                yield self._workflow_finish_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run
+                )
+
+                self._queue_manager.publish(
+                    QueueAdvancedChatMessageEndEvent(),
+                    PublishFrom.TASK_PIPELINE
+                )
+            elif isinstance(event, QueueWorkflowFailedEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                if not graph_runtime_state:
+                    raise Exception('Graph runtime state not initialized.')
+
+                workflow_run = self._handle_workflow_run_failed(
+                    workflow_run=workflow_run,
+                    start_at=graph_runtime_state.start_at,
+                    total_tokens=graph_runtime_state.total_tokens,
+                    total_steps=graph_runtime_state.node_run_steps,
+                    status=WorkflowRunStatus.FAILED,
+                    error=event.error,
+                    conversation_id=self._conversation.id,
+                    trace_manager=trace_manager,
+                )
+
+                yield self._workflow_finish_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run
+                )
+
+                err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
+                yield self._error_to_stream_response(self._handle_error(err_event, self._message))
+                break
+            elif isinstance(event, QueueStopEvent):
+                if workflow_run and graph_runtime_state:
+                    workflow_run = self._handle_workflow_run_failed(
+                        workflow_run=workflow_run,
+                        start_at=graph_runtime_state.start_at,
+                        total_tokens=graph_runtime_state.total_tokens,
+                        total_steps=graph_runtime_state.node_run_steps,
+                        status=WorkflowRunStatus.STOPPED,
+                        error=event.get_stop_reason(),
+                        conversation_id=self._conversation.id,
+                        trace_manager=trace_manager,
+                    )
+
+                    yield self._workflow_finish_to_stream_response(
+                        task_id=self._application_generate_entity.task_id,
+                        workflow_run=workflow_run
                     )
-            elif isinstance(event, QueueAdvancedChatMessageEndEvent):
-                output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
-                if output_moderation_answer:
-                    self._task_state.answer = output_moderation_answer
-                    yield self._message_replace_to_stream_response(answer=output_moderation_answer)
 
                 # Save message
-                self._save_message()
+                self._save_message(graph_runtime_state=graph_runtime_state)
 
                 yield self._message_end_to_stream_response()
+                break
             elif isinstance(event, QueueRetrieverResourcesEvent):
                 self._handle_retriever_resources(event)
+
+                self._refetch_message()
+
+                self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
+                    if self._task_state.metadata else None
+
+                db.session.commit()
+                db.session.refresh(self._message)
+                db.session.close()
             elif isinstance(event, QueueAnnotationReplyEvent):
                 self._handle_annotation_reply(event)
+
+                self._refetch_message()
+
+                self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
+                    if self._task_state.metadata else None
+
+                db.session.commit()
+                db.session.refresh(self._message)
+                db.session.close()
             elif isinstance(event, QueueTextChunkEvent):
                 delta_text = event.text
                 if delta_text is None:
                     continue
 
-                if not self._is_stream_out_support(
-                        event=event
-                ):
-                    continue
-
                 # handle output moderation chunk
                 should_direct_answer = self._handle_output_moderation_chunk(delta_text)
                 if should_direct_answer:
                     continue
 
+                # only publish tts message at text chunk streaming
+                if tts_publisher:
+                    tts_publisher.publish(message=queue_message)
+
                 self._task_state.answer += delta_text
                 yield self._message_to_stream_response(delta_text, self._message.id)
             elif isinstance(event, QueueMessageReplaceEvent):
+                # published by moderation
                 yield self._message_replace_to_stream_response(answer=event.text)
-            elif isinstance(event, QueuePingEvent):
-                yield self._ping_stream_response()
+            elif isinstance(event, QueueAdvancedChatMessageEndEvent):
+                if not graph_runtime_state:
+                    raise Exception('Graph runtime state not initialized.')
+
+                output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
+                if output_moderation_answer:
+                    self._task_state.answer = output_moderation_answer
+                    yield self._message_replace_to_stream_response(answer=output_moderation_answer)
+
+                # Save message
+                self._save_message(graph_runtime_state=graph_runtime_state)
+
+                yield self._message_end_to_stream_response()
             else:
                 continue
-        if publisher:
-            publisher.publish(None)
+
+        # publish None when task finished
+        if tts_publisher:
+            tts_publisher.publish(None)
+
         if self._conversation_name_generate_thread:
             self._conversation_name_generate_thread.join()
 
-    def _save_message(self) -> None:
+    def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
         """
         Save message.
         :return:
         """
-        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
+        self._refetch_message()
 
         self._message.answer = self._task_state.answer
         self._message.provider_response_latency = time.perf_counter() - self._start_at
         self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
             if self._task_state.metadata else None
 
-        if self._task_state.metadata and self._task_state.metadata.get('usage'):
-            usage = LLMUsage(**self._task_state.metadata['usage'])
-
+        if graph_runtime_state and graph_runtime_state.llm_usage:
+            usage = graph_runtime_state.llm_usage
             self._message.message_tokens = usage.prompt_tokens
             self._message.message_unit_price = usage.prompt_unit_price
             self._message.message_price_unit = usage.prompt_price_unit
@@ -432,7 +533,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         """
         extras = {}
         if self._task_state.metadata:
-            extras['metadata'] = self._task_state.metadata
+            extras['metadata'] = self._task_state.metadata.copy()
+
+            if 'annotation_reply' in extras['metadata']:
+                del extras['metadata']['annotation_reply']
 
         return MessageEndStreamResponse(
             task_id=self._application_generate_entity.task_id,
@@ -440,323 +544,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             **extras
         )
 
-    def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
-        """
-        Get stream generate routes.
-        :return:
-        """
-        # find all answer nodes
-        graph = self._workflow.graph_dict
-        answer_node_configs = [
-            node for node in graph['nodes']
-            if node.get('data', {}).get('type') == NodeType.ANSWER.value
-        ]
-
-        # parse stream output node value selectors of answer nodes
-        stream_generate_routes = {}
-        for node_config in answer_node_configs:
-            # get generate route for stream output
-            answer_node_id = node_config['id']
-            generate_route = AnswerNode.extract_generate_route_selectors(node_config)
-            start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id)
-            if not start_node_ids:
-                continue
-
-            for start_node_id in start_node_ids:
-                stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
-                    answer_node_id=answer_node_id,
-                    generate_route=generate_route
-                )
-
-        return stream_generate_routes
-
-    def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-            -> list[str]:
-        """
-        Get answer start at node id.
-        :param graph: graph
-        :param target_node_id: target node ID
-        :return:
-        """
-        nodes = graph.get('nodes')
-        edges = graph.get('edges')
-
-        # fetch all ingoing edges from source node
-        ingoing_edges = []
-        for edge in edges:
-            if edge.get('target') == target_node_id:
-                ingoing_edges.append(edge)
-
-        if not ingoing_edges:
-            # check if it's the first node in the iteration
-            target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
-            if not target_node:
-                return []
-
-            node_iteration_id = target_node.get('data', {}).get('iteration_id')
-            # get iteration start node id
-            for node in nodes:
-                if node.get('id') == node_iteration_id:
-                    if node.get('data', {}).get('start_node_id') == target_node_id:
-                        return [target_node_id]
-
-            return []
-
-        start_node_ids = []
-        for ingoing_edge in ingoing_edges:
-            source_node_id = ingoing_edge.get('source')
-            source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
-            if not source_node:
-                continue
-
-            node_type = source_node.get('data', {}).get('type')
-            node_iteration_id = source_node.get('data', {}).get('iteration_id')
-            iteration_start_node_id = None
-            if node_iteration_id:
-                iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
-                iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
-
-            if node_type in [
-                NodeType.ANSWER.value,
-                NodeType.IF_ELSE.value,
-                NodeType.QUESTION_CLASSIFIER.value,
-                NodeType.ITERATION.value,
-                NodeType.LOOP.value
-            ]:
-                start_node_id = target_node_id
-                start_node_ids.append(start_node_id)
-            elif node_type == NodeType.START.value or \
-                    node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
-                start_node_id = source_node_id
-                start_node_ids.append(start_node_id)
-            else:
-                sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id)
-                if sub_start_node_ids:
-                    start_node_ids.extend(sub_start_node_ids)
-
-        return start_node_ids
-
-    def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
-        """
-        Get iteration nested relations.
-        :param graph: graph
-        :return:
-        """
-        nodes = graph.get('nodes')
-
-        iteration_ids = [node.get('id') for node in nodes
-                         if node.get('data', {}).get('type') in [
-                             NodeType.ITERATION.value,
-                             NodeType.LOOP.value,
-                         ]]
-
-        return {
-            iteration_id: [
-                node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
-            ] for iteration_id in iteration_ids
-        }
-
-    def _generate_stream_outputs_when_node_started(self) -> Generator:
-        """
-        Generate stream outputs.
-        :return:
-        """
-        if self._task_state.current_stream_generate_state:
-            route_chunks = self._task_state.current_stream_generate_state.generate_route[
-                           self._task_state.current_stream_generate_state.current_route_position:
-                           ]
-
-            for route_chunk in route_chunks:
-                if route_chunk.type == 'text':
-                    route_chunk = cast(TextGenerateRouteChunk, route_chunk)
-
-                    # handle output moderation chunk
-                    should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
-                    if should_direct_answer:
-                        continue
-
-                    self._task_state.answer += route_chunk.text
-                    yield self._message_to_stream_response(route_chunk.text, self._message.id)
-                else:
-                    break
-
-                self._task_state.current_stream_generate_state.current_route_position += 1
-
-            # all route chunks are generated
-            if self._task_state.current_stream_generate_state.current_route_position == len(
-                    self._task_state.current_stream_generate_state.generate_route
-            ):
-                self._task_state.current_stream_generate_state = None
-
-    def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]:
-        """
-        Generate stream outputs.
-        :return:
-        """
-        if not self._task_state.current_stream_generate_state:
-            return
-
-        route_chunks = self._task_state.current_stream_generate_state.generate_route[
-                       self._task_state.current_stream_generate_state.current_route_position:]
-
-        for route_chunk in route_chunks:
-            if route_chunk.type == 'text':
-                route_chunk = cast(TextGenerateRouteChunk, route_chunk)
-                self._task_state.answer += route_chunk.text
-                yield self._message_to_stream_response(route_chunk.text, self._message.id)
-            else:
-                value = None
-                route_chunk = cast(VarGenerateRouteChunk, route_chunk)
-                value_selector = route_chunk.value_selector
-                if not value_selector:
-                    self._task_state.current_stream_generate_state.current_route_position += 1
-                    continue
-
-                route_chunk_node_id = value_selector[0]
-
-                if route_chunk_node_id == 'sys':
-                    # system variable
-                    value = contexts.workflow_variable_pool.get().get(value_selector)
-                    if value:
-                        value = value.text
-                elif route_chunk_node_id in self._iteration_nested_relations:
-                    # it's a iteration variable
-                    if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
-                        continue
-                    iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
-                    iterator = iteration_state.inputs
-                    if not iterator:
-                        continue
-                    iterator_selector = iterator.get('iterator_selector', [])
-                    if value_selector[1] == 'index':
-                        value = iteration_state.current_index
-                    elif value_selector[1] == 'item':
-                        value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
-                            iterator_selector
-                        ) else None
-                else:
-                    # check chunk node id is before current node id or equal to current node id
-                    if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
-                        break
-
-                    latest_node_execution_info = self._task_state.latest_node_execution_info
-
-                    # get route chunk node execution info
-                    route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
-                    if (route_chunk_node_execution_info.node_type == NodeType.LLM
-                            and latest_node_execution_info.node_type == NodeType.LLM):
-                        # only LLM support chunk stream output
-                        self._task_state.current_stream_generate_state.current_route_position += 1
-                        continue
-
-                    # get route chunk node execution
-                    route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
-                        WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id
-                    ).first()
-
-                    outputs = route_chunk_node_execution.outputs_dict
-
-                    # get value from outputs
-                    value = None
-                    for key in value_selector[1:]:
-                        if not value:
-                            value = outputs.get(key) if outputs else None
-                        else:
-                            value = value.get(key)
-
-                if value is not None:
-                    text = ''
-                    if isinstance(value, str | int | float):
-                        text = str(value)
-                    elif isinstance(value, FileVar):
-                        # convert file to markdown
-                        text = value.to_markdown()
-                    elif isinstance(value, dict):
-                        # handle files
-                        file_vars = self._fetch_files_from_variable_value(value)
-                        if file_vars:
-                            file_var = file_vars[0]
-                            try:
-                                file_var_obj = FileVar(**file_var)
-
-                                # convert file to markdown
-                                text = file_var_obj.to_markdown()
-                            except Exception as e:
-                                logger.error(f'Error creating file var: {e}')
-
-                        if not text:
-                            # other types
-                            text = json.dumps(value, ensure_ascii=False)
-                    elif isinstance(value, list):
-                        # handle files
-                        file_vars = self._fetch_files_from_variable_value(value)
-                        for file_var in file_vars:
-                            try:
-                                file_var_obj = FileVar(**file_var)
-                            except Exception as e:
-                                logger.error(f'Error creating file var: {e}')
-                                continue
-
-                            # convert file to markdown
-                            text = file_var_obj.to_markdown() + ' '
-
-                        text = text.strip()
-
-                        if not text and value:
-                            # other types
-                            text = json.dumps(value, ensure_ascii=False)
-
-                    if text:
-                        self._task_state.answer += text
-                        yield self._message_to_stream_response(text, self._message.id)
-
-            self._task_state.current_stream_generate_state.current_route_position += 1
-
-        # all route chunks are generated
-        if self._task_state.current_stream_generate_state.current_route_position == len(
-                self._task_state.current_stream_generate_state.generate_route
-        ):
-            self._task_state.current_stream_generate_state = None
-
-    def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
-        """
-        Is stream out support
-        :param event: queue text chunk event
-        :return:
-        """
-        if not event.metadata:
-            return True
-
-        if 'node_id' not in event.metadata:
-            return True
-
-        node_type = event.metadata.get('node_type')
-        stream_output_value_selector = event.metadata.get('value_selector')
-        if not stream_output_value_selector:
-            return False
-
-        if not self._task_state.current_stream_generate_state:
-            return False
-
-        route_chunk = self._task_state.current_stream_generate_state.generate_route[
-            self._task_state.current_stream_generate_state.current_route_position]
-
-        if route_chunk.type != 'var':
-            return False
-
-        if node_type != NodeType.LLM:
-            # only LLM support chunk stream output
-            return False
-
-        route_chunk = cast(VarGenerateRouteChunk, route_chunk)
-        value_selector = route_chunk.value_selector
-
-        # check chunk node id is before current node id or equal to current node id
-        if value_selector != stream_output_value_selector:
-            return False
-
-        return True
-
     def _handle_output_moderation_chunk(self, text: str) -> bool:
         """
         Handle output moderation chunk.
@@ -782,3 +569,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 self._output_moderation_handler.append_new_token(text)
 
         return False
+
+    def _refetch_message(self) -> None:
+        """
+        Refetch message.
+        :return:
+        """
+        message = db.session.query(Message).filter(Message.id == self._message.id).first()
+        if message:
+            self._message = message

+ 0 - 203
api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py

@@ -1,203 +0,0 @@
-from typing import Any, Optional
-
-from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
-from core.app.entities.queue_entities import (
-    AppQueueEvent,
-    QueueIterationCompletedEvent,
-    QueueIterationNextEvent,
-    QueueIterationStartEvent,
-    QueueNodeFailedEvent,
-    QueueNodeStartedEvent,
-    QueueNodeSucceededEvent,
-    QueueTextChunkEvent,
-    QueueWorkflowFailedEvent,
-    QueueWorkflowStartedEvent,
-    QueueWorkflowSucceededEvent,
-)
-from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeType
-from models.workflow import Workflow
-
-
-class WorkflowEventTriggerCallback(WorkflowCallback):
-
-    def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
-        self._queue_manager = queue_manager
-
-    def on_workflow_run_started(self) -> None:
-        """
-        Workflow run started
-        """
-        self._queue_manager.publish(
-            QueueWorkflowStartedEvent(),
-            PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_workflow_run_succeeded(self) -> None:
-        """
-        Workflow run succeeded
-        """
-        self._queue_manager.publish(
-            QueueWorkflowSucceededEvent(),
-            PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_workflow_run_failed(self, error: str) -> None:
-        """
-        Workflow run failed
-        """
-        self._queue_manager.publish(
-            QueueWorkflowFailedEvent(
-                error=error
-            ),
-            PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_workflow_node_execute_started(self, node_id: str,
-                                         node_type: NodeType,
-                                         node_data: BaseNodeData,
-                                         node_run_index: int = 1,
-                                         predecessor_node_id: Optional[str] = None) -> None:
-        """
-        Workflow node execute started
-        """
-        self._queue_manager.publish(
-            QueueNodeStartedEvent(
-                node_id=node_id,
-                node_type=node_type,
-                node_data=node_data,
-                node_run_index=node_run_index,
-                predecessor_node_id=predecessor_node_id
-            ),
-            PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_workflow_node_execute_succeeded(self, node_id: str,
-                                           node_type: NodeType,
-                                           node_data: BaseNodeData,
-                                           inputs: Optional[dict] = None,
-                                           process_data: Optional[dict] = None,
-                                           outputs: Optional[dict] = None,
-                                           execution_metadata: Optional[dict] = None) -> None:
-        """
-        Workflow node execute succeeded
-        """
-        self._queue_manager.publish(
-            QueueNodeSucceededEvent(
-                node_id=node_id,
-                node_type=node_type,
-                node_data=node_data,
-                inputs=inputs,
-                process_data=process_data,
-                outputs=outputs,
-                execution_metadata=execution_metadata
-            ),
-            PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_workflow_node_execute_failed(self, node_id: str,
-                                        node_type: NodeType,
-                                        node_data: BaseNodeData,
-                                        error: str,
-                                        inputs: Optional[dict] = None,
-                                        outputs: Optional[dict] = None,
-                                        process_data: Optional[dict] = None) -> None:
-        """
-        Workflow node execute failed
-        """
-        self._queue_manager.publish(
-            QueueNodeFailedEvent(
-                node_id=node_id,
-                node_type=node_type,
-                node_data=node_data,
-                inputs=inputs,
-                outputs=outputs,
-                process_data=process_data,
-                error=error
-            ),
-            PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
-        """
-        Publish text chunk
-        """
-        self._queue_manager.publish(
-            QueueTextChunkEvent(
-                text=text,
-                metadata={
-                    "node_id": node_id,
-                    **metadata
-                }
-            ), PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_workflow_iteration_started(self, 
-                                      node_id: str,
-                                      node_type: NodeType,
-                                      node_run_index: int = 1,
-                                      node_data: Optional[BaseNodeData] = None,
-                                      inputs: dict = None,
-                                      predecessor_node_id: Optional[str] = None,
-                                      metadata: Optional[dict] = None) -> None:
-        """
-        Publish iteration started
-        """
-        self._queue_manager.publish(
-            QueueIterationStartEvent(
-                node_id=node_id,
-                node_type=node_type,
-                node_run_index=node_run_index,
-                node_data=node_data,
-                inputs=inputs,
-                predecessor_node_id=predecessor_node_id,
-                metadata=metadata
-            ),
-            PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_workflow_iteration_next(self, node_id: str, 
-                                   node_type: NodeType,
-                                   index: int, 
-                                   node_run_index: int,
-                                   output: Optional[Any]) -> None:
-        """
-        Publish iteration next
-        """
-        self._queue_manager._publish(
-            QueueIterationNextEvent(
-                node_id=node_id,
-                node_type=node_type,
-                index=index,
-                node_run_index=node_run_index,
-                output=output
-            ),
-            PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_workflow_iteration_completed(self, node_id: str, 
-                                        node_type: NodeType,
-                                        node_run_index: int,
-                                        outputs: dict) -> None:
-        """
-        Publish iteration completed
-        """
-        self._queue_manager._publish(
-            QueueIterationCompletedEvent(
-                node_id=node_id,
-                node_type=node_type,
-                node_run_index=node_run_index,
-                outputs=outputs
-            ),
-            PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_event(self, event: AppQueueEvent) -> None:
-        """
-        Publish event
-        """
-        self._queue_manager.publish(
-            event,
-            PublishFrom.APPLICATION_MANAGER
-        )

+ 1 - 1
api/core/app/apps/base_app_generate_response_converter.py

@@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC):
     def convert(cls, response: Union[
         AppBlockingResponse,
         Generator[AppStreamResponse, Any, None]
-    ], invoke_from: InvokeFrom):
+    ], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]:
         if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
             if isinstance(response, AppBlockingResponse):
                 return cls.convert_blocking_full_response(response)

+ 3 - 3
api/core/app/apps/base_app_runner.py

@@ -1,6 +1,6 @@
 import time
-from collections.abc import Generator
-from typing import TYPE_CHECKING, Optional, Union
+from collections.abc import Generator, Mapping
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -347,7 +347,7 @@ class AppRunner:
             self, app_id: str,
             tenant_id: str,
             app_generate_entity: AppGenerateEntity,
-            inputs: dict,
+            inputs: Mapping[str, Any],
             query: str,
             message_id: str,
     ) -> tuple[bool, dict, str]:

+ 34 - 31
api/core/app/apps/workflow/app_generator.py

@@ -4,7 +4,7 @@ import os
 import threading
 import uuid
 from collections.abc import Generator
-from typing import Literal, Union, overload
+from typing import Any, Literal, Optional, Union, overload
 
 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -40,6 +40,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
         args: dict,
         invoke_from: InvokeFrom,
         stream: Literal[True] = True,
+        call_depth: int = 0,
+        workflow_thread_pool_id: Optional[str] = None
     ) -> Generator[str, None, None]: ...
 
     @overload
@@ -50,16 +52,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
         args: dict,
         invoke_from: InvokeFrom,
         stream: Literal[False] = False,
+        call_depth: int = 0,
+        workflow_thread_pool_id: Optional[str] = None
     ) -> dict: ...
 
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
         args: dict,
         invoke_from: InvokeFrom,
         stream: bool = True,
         call_depth: int = 0,
+        workflow_thread_pool_id: Optional[str] = None
     ):
         """
         Generate App response.
@@ -71,6 +77,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         :param invoke_from: invoke from source
         :param stream: is stream
         :param call_depth: call depth
+        :param workflow_thread_pool_id: workflow thread pool id
         """
         inputs = args['inputs']
 
@@ -118,16 +125,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
             application_generate_entity=application_generate_entity,
             invoke_from=invoke_from,
             stream=stream,
+            workflow_thread_pool_id=workflow_thread_pool_id
         )
 
     def _generate(
-        self, app_model: App,
+        self, *,
+        app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
         application_generate_entity: WorkflowAppGenerateEntity,
         invoke_from: InvokeFrom,
         stream: bool = True,
-    ) -> Union[dict, Generator[str, None, None]]:
+        workflow_thread_pool_id: Optional[str] = None
+    ) -> dict[str, Any] | Generator[str, None, None]:
         """
         Generate App response.
 
@@ -137,6 +147,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         :param application_generate_entity: application generate entity
         :param invoke_from: invoke from source
         :param stream: is stream
+        :param workflow_thread_pool_id: workflow thread pool id
         """
         # init queue manager
         queue_manager = WorkflowAppQueueManager(
@@ -148,10 +159,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
 
         # new thread
         worker_thread = threading.Thread(target=self._generate_worker, kwargs={
-            'flask_app': current_app._get_current_object(),
+            'flask_app': current_app._get_current_object(), # type: ignore
             'application_generate_entity': application_generate_entity,
             'queue_manager': queue_manager,
-            'context': contextvars.copy_context()
+            'context': contextvars.copy_context(),
+            'workflow_thread_pool_id': workflow_thread_pool_id
         })
 
         worker_thread.start()
@@ -175,7 +187,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
                                   node_id: str,
                                   user: Account,
                                   args: dict,
-                                  stream: bool = True):
+                                  stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]:
         """
         Generate App response.
 
@@ -192,10 +204,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         if args.get('inputs') is None:
             raise ValueError('inputs is required')
 
-        extras = {
-            "auto_generate_conversation_name": False
-        }
-
         # convert to app config
         app_config = WorkflowAppConfigManager.get_app_config(
             app_model=app_model,
@@ -211,7 +219,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
             user_id=user.id,
             stream=stream,
             invoke_from=InvokeFrom.DEBUGGER,
-            extras=extras,
+            extras={
+                "auto_generate_conversation_name": False
+            },
             single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
                 node_id=node_id,
                 inputs=args['inputs']
@@ -231,12 +241,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
     def _generate_worker(self, flask_app: Flask,
                          application_generate_entity: WorkflowAppGenerateEntity,
                          queue_manager: AppQueueManager,
-                         context: contextvars.Context) -> None:
+                         context: contextvars.Context,
+                         workflow_thread_pool_id: Optional[str] = None) -> None:
         """
         Generate worker in a new thread.
         :param flask_app: Flask app
         :param application_generate_entity: application generate entity
         :param queue_manager: queue manager
+        :param workflow_thread_pool_id: workflow thread pool id
         :return:
         """
         for var, val in context.items():
@@ -244,22 +256,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
         with flask_app.app_context():
             try:
                 # workflow app
-                runner = WorkflowAppRunner()
-                if application_generate_entity.single_iteration_run:
-                    single_iteration_run = application_generate_entity.single_iteration_run
-                    runner.single_iteration_run(
-                        app_id=application_generate_entity.app_config.app_id,
-                        workflow_id=application_generate_entity.app_config.workflow_id,
-                        queue_manager=queue_manager,
-                        inputs=single_iteration_run.inputs,
-                        node_id=single_iteration_run.node_id,
-                        user_id=application_generate_entity.user_id
-                    )
-                else:
-                    runner.run(
-                        application_generate_entity=application_generate_entity,
-                        queue_manager=queue_manager
-                    )
+                runner = WorkflowAppRunner(
+                    application_generate_entity=application_generate_entity,
+                    queue_manager=queue_manager,
+                    workflow_thread_pool_id=workflow_thread_pool_id
+                )
+
+                runner.run()
             except GenerateTaskStoppedException:
                 pass
             except InvokeAuthorizationError:
@@ -271,14 +274,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
                 logger.exception("Validation Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except (ValueError, InvokeError) as e:
-                if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
+                if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true':
                     logger.exception("Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except Exception as e:
                 logger.exception("Unknown Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             finally:
-                db.session.remove()
+                db.session.close()
 
     def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
                          workflow: Workflow,

+ 72 - 73
api/core/app/apps/workflow/app_runner.py

@@ -4,46 +4,61 @@ from typing import Optional, cast
 
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
-from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback
+from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
 from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
 from core.app.entities.app_invoke_entities import (
     InvokeFrom,
     WorkflowAppGenerateEntity,
 )
 from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
+from core.workflow.entities.node_entities import UserFrom
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
-from core.workflow.nodes.base_node import UserFrom
-from core.workflow.workflow_engine_manager import WorkflowEngineManager
+from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from models.model import App, EndUser
-from models.workflow import Workflow
+from models.workflow import WorkflowType
 
 logger = logging.getLogger(__name__)
 
 
-class WorkflowAppRunner:
+class WorkflowAppRunner(WorkflowBasedAppRunner):
     """
     Workflow Application Runner
     """
 
-    def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None:
+    def __init__(
+            self,
+            application_generate_entity: WorkflowAppGenerateEntity,
+            queue_manager: AppQueueManager,
+            workflow_thread_pool_id: Optional[str] = None
+    ) -> None:
+        """
+        :param application_generate_entity: application generate entity
+        :param queue_manager: application queue manager
+        :param workflow_thread_pool_id: workflow thread pool id
+        """
+        self.application_generate_entity = application_generate_entity
+        self.queue_manager = queue_manager
+        self.workflow_thread_pool_id = workflow_thread_pool_id
+
+    def run(self) -> None:
         """
         Run application
         :param application_generate_entity: application generate entity
         :param queue_manager: application queue manager
         :return:
         """
-        app_config = application_generate_entity.app_config
+        app_config = self.application_generate_entity.app_config
         app_config = cast(WorkflowAppConfig, app_config)
 
         user_id = None
-        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
-            end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
+        if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+            end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
             if end_user:
                 user_id = end_user.session_id
         else:
-            user_id = application_generate_entity.user_id
+            user_id = self.application_generate_entity.user_id
 
         app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
         if not app_record:
@@ -53,80 +68,64 @@ class WorkflowAppRunner:
         if not workflow:
             raise ValueError('Workflow not initialized')
 
-        inputs = application_generate_entity.inputs
-        files = application_generate_entity.files
-
         db.session.close()
 
-        workflow_callbacks: list[WorkflowCallback] = [
-            WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
-        ]
-
+        workflow_callbacks: list[WorkflowCallback] = []
         if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
             workflow_callbacks.append(WorkflowLoggingCallback())
 
-        # Create a variable pool.
-        system_inputs = {
-            SystemVariableKey.FILES: files,
-            SystemVariableKey.USER_ID: user_id,
-        }
-        variable_pool = VariablePool(
-            system_variables=system_inputs,
-            user_inputs=inputs,
-            environment_variables=workflow.environment_variables,
-            conversation_variables=[],
-        )
+        # if only single iteration run is requested
+        if self.application_generate_entity.single_iteration_run:
+            # if only single iteration run is requested
+            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
+                workflow=workflow,
+                node_id=self.application_generate_entity.single_iteration_run.node_id,
+                user_inputs=self.application_generate_entity.single_iteration_run.inputs
+            )
+        else:
 
-        # RUN WORKFLOW
-        workflow_engine_manager = WorkflowEngineManager()
-        workflow_engine_manager.run_workflow(
-            workflow=workflow,
-            user_id=application_generate_entity.user_id,
-            user_from=UserFrom.ACCOUNT
-            if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
-            else UserFrom.END_USER,
-            invoke_from=application_generate_entity.invoke_from,
-            callbacks=workflow_callbacks,
-            call_depth=application_generate_entity.call_depth,
-            variable_pool=variable_pool,
-        )
+            inputs = self.application_generate_entity.inputs
+            files = self.application_generate_entity.files
 
-    def single_iteration_run(
-        self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
-    ) -> None:
-        """
-        Single iteration run
-        """
-        app_record = db.session.query(App).filter(App.id == app_id).first()
-        if not app_record:
-            raise ValueError('App not found')
+            # Create a variable pool.
+            system_inputs = {
+                SystemVariableKey.FILES: files,
+                SystemVariableKey.USER_ID: user_id,
+            }
 
-        if not app_record.workflow_id:
-            raise ValueError('Workflow not initialized')
-
-        workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
-        if not workflow:
-            raise ValueError('Workflow not initialized')
+            variable_pool = VariablePool(
+                system_variables=system_inputs,
+                user_inputs=inputs,
+                environment_variables=workflow.environment_variables,
+                conversation_variables=[],
+            )
 
-        workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
+            # init graph
+            graph = self._init_graph(graph_config=workflow.graph_dict)
 
-        workflow_engine_manager = WorkflowEngineManager()
-        workflow_engine_manager.single_step_run_iteration_workflow_node(
-            workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
+        # RUN WORKFLOW
+        workflow_entry = WorkflowEntry(
+            tenant_id=workflow.tenant_id,
+            app_id=workflow.app_id,
+            workflow_id=workflow.id,
+            workflow_type=WorkflowType.value_of(workflow.type),
+            graph=graph,
+            graph_config=workflow.graph_dict,
+            user_id=self.application_generate_entity.user_id,
+            user_from=(
+                UserFrom.ACCOUNT
+                if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
+                else UserFrom.END_USER
+            ),
+            invoke_from=self.application_generate_entity.invoke_from,
+            call_depth=self.application_generate_entity.call_depth,
+            variable_pool=variable_pool,
+            thread_pool_id=self.workflow_thread_pool_id
         )
 
-    def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
-        """
-        Get workflow
-        """
-        # fetch workflow by workflow_id
-        workflow = (
-            db.session.query(Workflow)
-            .filter(
-                Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
-            )
-            .first()
+        generator = workflow_entry.run(
+            callbacks=workflow_callbacks
         )
 
-        # return workflow
-        return workflow
+        for event in generator:
+            self._handle_event(workflow_entry, event)

+ 169 - 264
api/core/app/apps/workflow/generate_task_pipeline.py

@@ -1,3 +1,4 @@
+import json
 import logging
 import time
 from collections.abc import Generator
@@ -15,10 +16,12 @@ from core.app.entities.queue_entities import (
     QueueIterationCompletedEvent,
     QueueIterationNextEvent,
     QueueIterationStartEvent,
-    QueueMessageReplaceEvent,
     QueueNodeFailedEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
+    QueueParallelBranchRunFailedEvent,
+    QueueParallelBranchRunStartedEvent,
+    QueueParallelBranchRunSucceededEvent,
     QueuePingEvent,
     QueueStopEvent,
     QueueTextChunkEvent,
@@ -32,19 +35,16 @@ from core.app.entities.task_entities import (
     MessageAudioStreamResponse,
     StreamResponse,
     TextChunkStreamResponse,
-    TextReplaceStreamResponse,
     WorkflowAppBlockingResponse,
     WorkflowAppStreamResponse,
     WorkflowFinishStreamResponse,
-    WorkflowStreamGenerateNodes,
+    WorkflowStartStreamResponse,
     WorkflowTaskState,
 )
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.entities.node_entities import NodeType
 from core.workflow.enums import SystemVariableKey
-from core.workflow.nodes.end.end_node import EndNode
 from extensions.ext_database import db
 from models.account import Account
 from models.model import EndUser
@@ -52,8 +52,8 @@ from models.workflow import (
     Workflow,
     WorkflowAppLog,
     WorkflowAppLogCreatedFrom,
-    WorkflowNodeExecution,
     WorkflowRun,
+    WorkflowRunStatus,
 )
 
 logger = logging.getLogger(__name__)
@@ -68,7 +68,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
     _task_state: WorkflowTaskState
     _application_generate_entity: WorkflowAppGenerateEntity
     _workflow_system_variables: dict[SystemVariableKey, Any]
-    _iteration_nested_relations: dict[str, list[str]]
 
     def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
                  workflow: Workflow,
@@ -96,11 +95,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
             SystemVariableKey.USER_ID: user_id
         }
 
-        self._task_state = WorkflowTaskState(
-            iteration_nested_node_ids=[]
-        )
-        self._stream_generate_nodes = self._get_stream_generate_nodes()
-        self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
+        self._task_state = WorkflowTaskState()
 
     def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
         """
@@ -129,23 +124,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
             if isinstance(stream_response, ErrorStreamResponse):
                 raise stream_response.err
             elif isinstance(stream_response, WorkflowFinishStreamResponse):
-                workflow_run = db.session.query(WorkflowRun).filter(
-                    WorkflowRun.id == self._task_state.workflow_run_id).first()
-
                 response = WorkflowAppBlockingResponse(
                     task_id=self._application_generate_entity.task_id,
-                    workflow_run_id=workflow_run.id,
+                    workflow_run_id=stream_response.data.id,
                     data=WorkflowAppBlockingResponse.Data(
-                        id=workflow_run.id,
-                        workflow_id=workflow_run.workflow_id,
-                        status=workflow_run.status,
-                        outputs=workflow_run.outputs_dict,
-                        error=workflow_run.error,
-                        elapsed_time=workflow_run.elapsed_time,
-                        total_tokens=workflow_run.total_tokens,
-                        total_steps=workflow_run.total_steps,
-                        created_at=int(workflow_run.created_at.timestamp()),
-                        finished_at=int(workflow_run.finished_at.timestamp())
+                        id=stream_response.data.id,
+                        workflow_id=stream_response.data.workflow_id,
+                        status=stream_response.data.status,
+                        outputs=stream_response.data.outputs,
+                        error=stream_response.data.error,
+                        elapsed_time=stream_response.data.elapsed_time,
+                        total_tokens=stream_response.data.total_tokens,
+                        total_steps=stream_response.data.total_steps,
+                        created_at=int(stream_response.data.created_at),
+                        finished_at=int(stream_response.data.finished_at)
                     )
                 )
 
@@ -161,9 +153,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         To stream response.
         :return:
         """
+        workflow_run_id = None
         for stream_response in generator:
+            if isinstance(stream_response, WorkflowStartStreamResponse):
+                workflow_run_id = stream_response.workflow_run_id
+
             yield WorkflowAppStreamResponse(
-                workflow_run_id=self._task_state.workflow_run_id,
+                workflow_run_id=workflow_run_id,
                 stream_response=stream_response
             )
 
@@ -178,17 +174,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
     def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
             Generator[StreamResponse, None, None]:
 
-        publisher = None
+        tts_publisher = None
         task_id = self._application_generate_entity.task_id
         tenant_id = self._application_generate_entity.app_config.tenant_id
         features_dict = self._workflow.features_dict
 
         if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
                 'text_to_speech'].get('autoPlay') == 'enabled':
-            publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
-        for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
+            tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
+
+        for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
             while True:
-                audio_response = self._listenAudioMsg(publisher, task_id=task_id)
+                audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
                 if audio_response:
                     yield audio_response
                 else:
@@ -198,9 +195,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         start_listener_time = time.time()
         while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
             try:
-                if not publisher:
+                if not tts_publisher:
                     break
-                audio_trunk = publisher.checkAndGetAudio()
+                audio_trunk = tts_publisher.checkAndGetAudio()
                 if audio_trunk is None:
                     # release cpu
                     # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
@@ -218,69 +215,159 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
 
     def _process_stream_response(
         self,
-        publisher: AppGeneratorTTSPublisher,
+        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
         trace_manager: Optional[TraceQueueManager] = None
     ) -> Generator[StreamResponse, None, None]:
         """
         Process stream response.
         :return:
         """
-        for message in self._queue_manager.listen():
-            if publisher:
-                publisher.publish(message=message)
-            event = message.event
+        graph_runtime_state = None
+        workflow_run = None
 
-            if isinstance(event, QueueErrorEvent):
+        for queue_message in self._queue_manager.listen():
+            event = queue_message.event
+
+            if isinstance(event, QueuePingEvent):
+                yield self._ping_stream_response()
+            elif isinstance(event, QueueErrorEvent):
                 err = self._handle_error(event)
                 yield self._error_to_stream_response(err)
                 break
             elif isinstance(event, QueueWorkflowStartedEvent):
-                workflow_run = self._handle_workflow_start()
+                # override graph runtime state
+                graph_runtime_state = event.graph_runtime_state
+
+                # init workflow run
+                workflow_run = self._handle_workflow_run_start()
                 yield self._workflow_start_to_stream_response(
                     task_id=self._application_generate_entity.task_id,
                     workflow_run=workflow_run
                 )
             elif isinstance(event, QueueNodeStartedEvent):
-                workflow_node_execution = self._handle_node_start(event)
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
 
-                # search stream_generate_routes if node id is answer start at node
-                if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes:
-                    self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id]
+                workflow_node_execution = self._handle_node_execution_start(
+                    workflow_run=workflow_run,
+                    event=event
+                )
+
+                response = self._workflow_node_start_to_stream_response(
+                    event=event,
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_node_execution=workflow_node_execution
+                )
 
-                    # generate stream outputs when node started
-                    yield from self._generate_stream_outputs_when_node_started()
+                if response:
+                    yield response
+            elif isinstance(event, QueueNodeSucceededEvent):
+                workflow_node_execution = self._handle_workflow_node_execution_success(event)
 
-                yield self._workflow_node_start_to_stream_response(
+                response = self._workflow_node_finish_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution
                 )
-            elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
-                workflow_node_execution = self._handle_node_finished(event)
 
-                yield self._workflow_node_finish_to_stream_response(
+                if response:
+                    yield response
+            elif isinstance(event, QueueNodeFailedEvent):
+                workflow_node_execution = self._handle_workflow_node_execution_failed(event)
+
+                response = self._workflow_node_finish_to_stream_response(
+                    event=event,
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution
                 )
 
-                if isinstance(event, QueueNodeFailedEvent):
-                    yield from self._handle_iteration_exception(
-                        task_id=self._application_generate_entity.task_id,
-                        error=f'Child node failed: {event.error}'
-                    )
-            elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
-                if isinstance(event, QueueIterationNextEvent):
-                    # clear ran node execution infos of current iteration
-                    iteration_relations = self._iteration_nested_relations.get(event.node_id)
-                    if iteration_relations:
-                        for node_id in iteration_relations:
-                            self._task_state.ran_node_execution_infos.pop(node_id, None)
-
-                yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
-                self._handle_iteration_operation(event)
-            elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
-                workflow_run = self._handle_workflow_finished(
-                    event, trace_manager=trace_manager
+                if response:
+                    yield response
+            elif isinstance(event, QueueParallelBranchRunStartedEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                yield self._workflow_parallel_branch_start_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
+            elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                yield self._workflow_parallel_branch_finished_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
+            elif isinstance(event, QueueIterationStartEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                yield self._workflow_iteration_start_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
+            elif isinstance(event, QueueIterationNextEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                yield self._workflow_iteration_next_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
+            elif isinstance(event, QueueIterationCompletedEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                yield self._workflow_iteration_completed_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run,
+                    event=event
+                )
+            elif isinstance(event, QueueWorkflowSucceededEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                if not graph_runtime_state:
+                    raise Exception('Graph runtime state not initialized.')
+
+                workflow_run = self._handle_workflow_run_success(
+                    workflow_run=workflow_run,
+                    start_at=graph_runtime_state.start_at,
+                    total_tokens=graph_runtime_state.total_tokens,
+                    total_steps=graph_runtime_state.node_run_steps,
+                    outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None,
+                    conversation_id=None,
+                    trace_manager=trace_manager,
+                )
+
+                # save workflow app log
+                self._save_workflow_app_log(workflow_run)
+
+                yield self._workflow_finish_to_stream_response(
+                    task_id=self._application_generate_entity.task_id,
+                    workflow_run=workflow_run
+                )
+            elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
+                if not workflow_run:
+                    raise Exception('Workflow run not initialized.')
+
+                if not graph_runtime_state:
+                    raise Exception('Graph runtime state not initialized.')
+
+                workflow_run = self._handle_workflow_run_failed(
+                    workflow_run=workflow_run,
+                    start_at=graph_runtime_state.start_at,
+                    total_tokens=graph_runtime_state.total_tokens,
+                    total_steps=graph_runtime_state.node_run_steps,
+                    status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED,
+                    error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
+                    conversation_id=None,
+                    trace_manager=trace_manager,
                 )
 
                 # save workflow app log
@@ -295,22 +382,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 if delta_text is None:
                     continue
 
-                if not self._is_stream_out_support(
-                        event=event
-                ):
-                    continue
+                # only publish tts message at text chunk streaming
+                if tts_publisher:
+                    tts_publisher.publish(message=queue_message)
 
                 self._task_state.answer += delta_text
                 yield self._text_chunk_to_stream_response(delta_text)
-            elif isinstance(event, QueueMessageReplaceEvent):
-                yield self._text_replace_to_stream_response(event.text)
-            elif isinstance(event, QueuePingEvent):
-                yield self._ping_stream_response()
             else:
                 continue
 
-        if publisher:
-            publisher.publish(None)
+        if tts_publisher:
+            tts_publisher.publish(None)
 
 
     def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
@@ -329,15 +411,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
             # not save log for debugging
             return
 
-        workflow_app_log = WorkflowAppLog(
-            tenant_id=workflow_run.tenant_id,
-            app_id=workflow_run.app_id,
-            workflow_id=workflow_run.workflow_id,
-            workflow_run_id=workflow_run.id,
-            created_from=created_from.value,
-            created_by_role=('account' if isinstance(self._user, Account) else 'end_user'),
-            created_by=self._user.id,
-        )
+        workflow_app_log = WorkflowAppLog()
+        workflow_app_log.tenant_id = workflow_run.tenant_id
+        workflow_app_log.app_id = workflow_run.app_id
+        workflow_app_log.workflow_id = workflow_run.workflow_id
+        workflow_app_log.workflow_run_id = workflow_run.id
+        workflow_app_log.created_from = created_from.value
+        workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user'
+        workflow_app_log.created_by = self._user.id
+
         db.session.add(workflow_app_log)
         db.session.commit()
         db.session.close()
@@ -354,180 +436,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         )
 
         return response
-
-    def _text_replace_to_stream_response(self, text: str) -> TextReplaceStreamResponse:
-        """
-        Text replace to stream response.
-        :param text: text
-        :return:
-        """
-        return TextReplaceStreamResponse(
-            task_id=self._application_generate_entity.task_id,
-            text=TextReplaceStreamResponse.Data(text=text)
-        )
-
-    def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]:
-        """
-        Get stream generate nodes.
-        :return:
-        """
-        # find all answer nodes
-        graph = self._workflow.graph_dict
-        end_node_configs = [
-            node for node in graph['nodes']
-            if node.get('data', {}).get('type') == NodeType.END.value
-        ]
-
-        # parse stream output node value selectors of end nodes
-        stream_generate_routes = {}
-        for node_config in end_node_configs:
-            # get generate route for stream output
-            end_node_id = node_config['id']
-            generate_nodes = EndNode.extract_generate_nodes(graph, node_config)
-            start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id)
-            if not start_node_ids:
-                continue
-
-            for start_node_id in start_node_ids:
-                stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes(
-                    end_node_id=end_node_id,
-                    stream_node_ids=generate_nodes
-                )
-
-        return stream_generate_routes
-
-    def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \
-            -> list[str]:
-        """
-        Get end start at node id.
-        :param graph: graph
-        :param target_node_id: target node ID
-        :return:
-        """
-        nodes = graph.get('nodes')
-        edges = graph.get('edges')
-
-        # fetch all ingoing edges from source node
-        ingoing_edges = []
-        for edge in edges:
-            if edge.get('target') == target_node_id:
-                ingoing_edges.append(edge)
-
-        if not ingoing_edges:
-            return []
-
-        start_node_ids = []
-        for ingoing_edge in ingoing_edges:
-            source_node_id = ingoing_edge.get('source')
-            source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
-            if not source_node:
-                continue
-
-            node_type = source_node.get('data', {}).get('type')
-            node_iteration_id = source_node.get('data', {}).get('iteration_id')
-            iteration_start_node_id = None
-            if node_iteration_id:
-                iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
-                iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
-
-            if node_type in [
-                NodeType.IF_ELSE.value,
-                NodeType.QUESTION_CLASSIFIER.value
-            ]:
-                start_node_id = target_node_id
-                start_node_ids.append(start_node_id)
-            elif node_type == NodeType.START.value or \
-                node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
-                start_node_id = source_node_id
-                start_node_ids.append(start_node_id)
-            else:
-                sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id)
-                if sub_start_node_ids:
-                    start_node_ids.extend(sub_start_node_ids)
-
-        return start_node_ids
-
-    def _generate_stream_outputs_when_node_started(self) -> Generator:
-        """
-        Generate stream outputs.
-        :return:
-        """
-        if self._task_state.current_stream_generate_state:
-            stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids
-
-            for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items():
-                if node_id not in stream_node_ids:
-                    continue
-
-                node_execution_info = self._task_state.ran_node_execution_infos[node_id]
-
-                # get chunk node execution
-                route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
-                    WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first()
-
-                if not route_chunk_node_execution:
-                    continue
-
-                outputs = route_chunk_node_execution.outputs_dict
-
-                if not outputs:
-                    continue
-
-                # get value from outputs
-                text = outputs.get('text')
-
-                if text:
-                    self._task_state.answer += text
-                    yield self._text_chunk_to_stream_response(text)
-
-            db.session.close()
-
-    def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
-        """
-        Is stream out support
-        :param event: queue text chunk event
-        :return:
-        """
-        if not event.metadata:
-            return False
-
-        if 'node_id' not in event.metadata:
-            return False
-
-        node_id = event.metadata.get('node_id')
-        node_type = event.metadata.get('node_type')
-        stream_output_value_selector = event.metadata.get('value_selector')
-        if not stream_output_value_selector:
-            return False
-
-        if not self._task_state.current_stream_generate_state:
-            return False
-
-        if node_id not in self._task_state.current_stream_generate_state.stream_node_ids:
-            return False
-
-        if node_type != NodeType.LLM:
-            # only LLM support chunk stream output
-            return False
-
-        return True
-
-    def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
-        """
-        Get iteration nested relations.
-        :param graph: graph
-        :return:
-        """
-        nodes = graph.get('nodes')
-
-        iteration_ids = [node.get('id') for node in nodes
-                         if node.get('data', {}).get('type') in [
-                             NodeType.ITERATION.value,
-                             NodeType.LOOP.value,
-                        ]]
-
-        return {
-            iteration_id: [
-                node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
-            ] for iteration_id in iteration_ids
-        }

+ 0 - 200
api/core/app/apps/workflow/workflow_event_trigger_callback.py

@@ -1,200 +0,0 @@
-from typing import Any, Optional
-
-from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
-from core.app.entities.queue_entities import (
-    AppQueueEvent,
-    QueueIterationCompletedEvent,
-    QueueIterationNextEvent,
-    QueueIterationStartEvent,
-    QueueNodeFailedEvent,
-    QueueNodeStartedEvent,
-    QueueNodeSucceededEvent,
-    QueueTextChunkEvent,
-    QueueWorkflowFailedEvent,
-    QueueWorkflowStartedEvent,
-    QueueWorkflowSucceededEvent,
-)
-from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeType
-from models.workflow import Workflow
-
-
-class WorkflowEventTriggerCallback(WorkflowCallback):
-
-    def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
-        self._queue_manager = queue_manager
-
-    def on_workflow_run_started(self) -> None:
-        """
-        Workflow run started
-        """
-        self._queue_manager.publish(
-            QueueWorkflowStartedEvent(),
-            PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_workflow_run_succeeded(self) -> None:
-        """
-        Workflow run succeeded
-        """
-        self._queue_manager.publish(
-            QueueWorkflowSucceededEvent(),
-            PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_workflow_run_failed(self, error: str) -> None:
-        """
-        Workflow run failed
-        """
-        self._queue_manager.publish(
-            QueueWorkflowFailedEvent(
-                error=error
-            ),
-            PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_workflow_node_execute_started(self, node_id: str,
-                                         node_type: NodeType,
-                                         node_data: BaseNodeData,
-                                         node_run_index: int = 1,
-                                         predecessor_node_id: Optional[str] = None) -> None:
-        """
-        Workflow node execute started
-        """
-        self._queue_manager.publish(
-            QueueNodeStartedEvent(
-                node_id=node_id,
-                node_type=node_type,
-                node_data=node_data,
-                node_run_index=node_run_index,
-                predecessor_node_id=predecessor_node_id
-            ),
-            PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_workflow_node_execute_succeeded(self, node_id: str,
-                                           node_type: NodeType,
-                                           node_data: BaseNodeData,
-                                           inputs: Optional[dict] = None,
-                                           process_data: Optional[dict] = None,
-                                           outputs: Optional[dict] = None,
-                                           execution_metadata: Optional[dict] = None) -> None:
-        """
-        Workflow node execute succeeded
-        """
-        self._queue_manager.publish(
-            QueueNodeSucceededEvent(
-                node_id=node_id,
-                node_type=node_type,
-                node_data=node_data,
-                inputs=inputs,
-                process_data=process_data,
-                outputs=outputs,
-                execution_metadata=execution_metadata
-            ),
-            PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_workflow_node_execute_failed(self, node_id: str,
-                                        node_type: NodeType,
-                                        node_data: BaseNodeData,
-                                        error: str,
-                                        inputs: Optional[dict] = None,
-                                        outputs: Optional[dict] = None,
-                                        process_data: Optional[dict] = None) -> None:
-        """
-        Workflow node execute failed
-        """
-        self._queue_manager.publish(
-            QueueNodeFailedEvent(
-                node_id=node_id,
-                node_type=node_type,
-                node_data=node_data,
-                inputs=inputs,
-                outputs=outputs,
-                process_data=process_data,
-                error=error
-            ),
-            PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
-        """
-        Publish text chunk
-        """
-        self._queue_manager.publish(
-            QueueTextChunkEvent(
-                text=text,
-                metadata={
-                    "node_id": node_id,
-                    **metadata
-                }
-            ), PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_workflow_iteration_started(self, 
-                                      node_id: str,
-                                      node_type: NodeType,
-                                      node_run_index: int = 1,
-                                      node_data: Optional[BaseNodeData] = None,
-                                      inputs: dict = None,
-                                      predecessor_node_id: Optional[str] = None,
-                                      metadata: Optional[dict] = None) -> None:
-        """
-        Publish iteration started
-        """
-        self._queue_manager.publish(
-            QueueIterationStartEvent(
-                node_id=node_id,
-                node_type=node_type,
-                node_run_index=node_run_index,
-                node_data=node_data,
-                inputs=inputs,
-                predecessor_node_id=predecessor_node_id,
-                metadata=metadata
-            ),
-            PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_workflow_iteration_next(self, node_id: str, 
-                                   node_type: NodeType,
-                                   index: int, 
-                                   node_run_index: int,
-                                   output: Optional[Any]) -> None:
-        """
-        Publish iteration next
-        """
-        self._queue_manager.publish(
-            QueueIterationNextEvent(
-                node_id=node_id,
-                node_type=node_type,
-                index=index,
-                node_run_index=node_run_index,
-                output=output
-            ),
-            PublishFrom.APPLICATION_MANAGER
-        )
-
-    def on_workflow_iteration_completed(self, node_id: str, 
-                                        node_type: NodeType,
-                                        node_run_index: int,
-                                        outputs: dict) -> None:
-        """
-        Publish iteration completed
-        """
-        self._queue_manager.publish(
-            QueueIterationCompletedEvent(
-                node_id=node_id,
-                node_type=node_type,
-                node_run_index=node_run_index,
-                outputs=outputs
-            ),
-            PublishFrom.APPLICATION_MANAGER
-        )
-        
-    def on_event(self, event: AppQueueEvent) -> None:
-        """
-        Publish event
-        """
-        pass

+ 379 - 0
api/core/app/apps/workflow_app_runner.py

@@ -0,0 +1,379 @@
+from collections.abc import Mapping
+from typing import Any, Optional, cast
+
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.base_app_runner import AppRunner
+from core.app.entities.queue_entities import (
+    AppQueueEvent,
+    QueueIterationCompletedEvent,
+    QueueIterationNextEvent,
+    QueueIterationStartEvent,
+    QueueNodeFailedEvent,
+    QueueNodeStartedEvent,
+    QueueNodeSucceededEvent,
+    QueueParallelBranchRunFailedEvent,
+    QueueParallelBranchRunStartedEvent,
+    QueueParallelBranchRunSucceededEvent,
+    QueueRetrieverResourcesEvent,
+    QueueTextChunkEvent,
+    QueueWorkflowFailedEvent,
+    QueueWorkflowStartedEvent,
+    QueueWorkflowSucceededEvent,
+)
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine.entities.event import (
+    GraphEngineEvent,
+    GraphRunFailedEvent,
+    GraphRunStartedEvent,
+    GraphRunSucceededEvent,
+    IterationRunFailedEvent,
+    IterationRunNextEvent,
+    IterationRunStartedEvent,
+    IterationRunSucceededEvent,
+    NodeRunFailedEvent,
+    NodeRunRetrieverResourceEvent,
+    NodeRunStartedEvent,
+    NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
+    ParallelBranchRunFailedEvent,
+    ParallelBranchRunStartedEvent,
+    ParallelBranchRunSucceededEvent,
+)
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.iteration.entities import IterationNodeData
+from core.workflow.nodes.node_mapping import node_classes
+from core.workflow.workflow_entry import WorkflowEntry
+from extensions.ext_database import db
+from models.model import App
+from models.workflow import Workflow
+
+
+class WorkflowBasedAppRunner(AppRunner):
+    def __init__(self, queue_manager: AppQueueManager):
+        self.queue_manager = queue_manager
+
+    def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
+        """
+        Init graph
+        """
+        if 'nodes' not in graph_config or 'edges' not in graph_config:
+            raise ValueError('nodes or edges not found in workflow graph')
+
+        if not isinstance(graph_config.get('nodes'), list):
+            raise ValueError('nodes in workflow graph must be a list')
+
+        if not isinstance(graph_config.get('edges'), list):
+            raise ValueError('edges in workflow graph must be a list')
+        # init graph
+        graph = Graph.init(
+            graph_config=graph_config
+        )
+
+        if not graph:
+            raise ValueError('graph not found in workflow')
+        
+        return graph
+
+    def _get_graph_and_variable_pool_of_single_iteration(
+            self, 
+            workflow: Workflow,
+            node_id: str,
+            user_inputs: dict,
+        ) -> tuple[Graph, VariablePool]:
+        """
+        Get variable pool of single iteration
+        """
+        # fetch workflow graph
+        graph_config = workflow.graph_dict
+        if not graph_config:
+            raise ValueError('workflow graph not found')
+        
+        graph_config = cast(dict[str, Any], graph_config)
+
+        if 'nodes' not in graph_config or 'edges' not in graph_config:
+            raise ValueError('nodes or edges not found in workflow graph')
+
+        if not isinstance(graph_config.get('nodes'), list):
+            raise ValueError('nodes in workflow graph must be a list')
+
+        if not isinstance(graph_config.get('edges'), list):
+            raise ValueError('edges in workflow graph must be a list')
+
+        # filter nodes only in iteration
+        node_configs = [
+            node for node in graph_config.get('nodes', []) 
+            if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id
+        ]
+
+        graph_config['nodes'] = node_configs
+
+        node_ids = [node.get('id') for node in node_configs]
+
+        # filter edges only in iteration
+        edge_configs = [
+            edge for edge in graph_config.get('edges', []) 
+            if (edge.get('source') is None or edge.get('source') in node_ids) 
+            and (edge.get('target') is None or edge.get('target') in node_ids) 
+        ]
+
+        graph_config['edges'] = edge_configs
+
+        # init graph
+        graph = Graph.init(
+            graph_config=graph_config,
+            root_node_id=node_id
+        )
+
+        if not graph:
+            raise ValueError('graph not found in workflow')
+        
+        # fetch node config from node id
+        iteration_node_config = None
+        for node in node_configs:
+            if node.get('id') == node_id:
+                iteration_node_config = node
+                break
+
+        if not iteration_node_config:
+            raise ValueError('iteration node id not found in workflow graph')
+        
+        # Get node class
+        node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type'))
+        node_cls = node_classes.get(node_type)
+        node_cls = cast(type[BaseNode], node_cls)
+
+        # init variable pool
+        variable_pool = VariablePool(
+            system_variables={},
+            user_inputs={},
+            environment_variables=workflow.environment_variables,
+        )
+
+        try:
+            variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
+                graph_config=workflow.graph_dict, 
+                config=iteration_node_config
+            )
+        except NotImplementedError:
+            variable_mapping = {}
+
+        WorkflowEntry.mapping_user_inputs_to_variable_pool(
+            variable_mapping=variable_mapping,
+            user_inputs=user_inputs,
+            variable_pool=variable_pool,
+            tenant_id=workflow.tenant_id,
+            node_type=node_type,
+            node_data=IterationNodeData(**iteration_node_config.get('data', {}))
+        )
+
+        return graph, variable_pool
+
+    def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
+        """
+        Handle event
+        :param workflow_entry: workflow entry
+        :param event: event
+        """
+        if isinstance(event, GraphRunStartedEvent):
+            self._publish_event(
+                QueueWorkflowStartedEvent(
+                    graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state
+                )
+            )
+        elif isinstance(event, GraphRunSucceededEvent):
+            self._publish_event(
+                QueueWorkflowSucceededEvent(outputs=event.outputs)
+            )
+        elif isinstance(event, GraphRunFailedEvent):
+            self._publish_event(
+                QueueWorkflowFailedEvent(error=event.error)
+            )
+        elif isinstance(event, NodeRunStartedEvent):
+            self._publish_event(
+                QueueNodeStartedEvent(
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_data=event.node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    start_at=event.route_node_state.start_at,
+                    node_run_index=event.route_node_state.index,
+                    predecessor_node_id=event.predecessor_node_id,
+                    in_iteration_id=event.in_iteration_id
+                )
+            )
+        elif isinstance(event, NodeRunSucceededEvent):
+            self._publish_event(
+                QueueNodeSucceededEvent(
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_data=event.node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    start_at=event.route_node_state.start_at,
+                    inputs=event.route_node_state.node_run_result.inputs
+                    if event.route_node_state.node_run_result else {},
+                    process_data=event.route_node_state.node_run_result.process_data
+                    if event.route_node_state.node_run_result else {},
+                    outputs=event.route_node_state.node_run_result.outputs
+                    if event.route_node_state.node_run_result else {},
+                    execution_metadata=event.route_node_state.node_run_result.metadata
+                    if event.route_node_state.node_run_result else {},
+                    in_iteration_id=event.in_iteration_id
+                )
+            )
+        elif isinstance(event, NodeRunFailedEvent):
+            self._publish_event(
+                QueueNodeFailedEvent(
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_data=event.node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    start_at=event.route_node_state.start_at,
+                    inputs=event.route_node_state.node_run_result.inputs
+                    if event.route_node_state.node_run_result else {},
+                    process_data=event.route_node_state.node_run_result.process_data
+                    if event.route_node_state.node_run_result else {},
+                    outputs=event.route_node_state.node_run_result.outputs
+                    if event.route_node_state.node_run_result else {},
+                    error=event.route_node_state.node_run_result.error
+                    if event.route_node_state.node_run_result
+                       and event.route_node_state.node_run_result.error
+                    else "Unknown error",
+                    in_iteration_id=event.in_iteration_id
+                )
+            )
+        elif isinstance(event, NodeRunStreamChunkEvent):
+            self._publish_event(
+                QueueTextChunkEvent(
+                    text=event.chunk_content,
+                    from_variable_selector=event.from_variable_selector,
+                    in_iteration_id=event.in_iteration_id
+                )
+            )
+        elif isinstance(event, NodeRunRetrieverResourceEvent):
+            self._publish_event(
+                QueueRetrieverResourcesEvent(
+                    retriever_resources=event.retriever_resources,
+                    in_iteration_id=event.in_iteration_id
+                )
+            )
+        elif isinstance(event, ParallelBranchRunStartedEvent):
+            self._publish_event(
+                QueueParallelBranchRunStartedEvent(
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    in_iteration_id=event.in_iteration_id
+                )
+            )
+        elif isinstance(event, ParallelBranchRunSucceededEvent):
+            self._publish_event(
+                QueueParallelBranchRunSucceededEvent(
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    in_iteration_id=event.in_iteration_id
+                )
+            )
+        elif isinstance(event, ParallelBranchRunFailedEvent):
+            self._publish_event(
+                QueueParallelBranchRunFailedEvent(
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    in_iteration_id=event.in_iteration_id,
+                    error=event.error
+                )
+            )
+        elif isinstance(event, IterationRunStartedEvent):
+            self._publish_event(
+                QueueIterationStartEvent(
+                    node_execution_id=event.iteration_id,
+                    node_id=event.iteration_node_id,
+                    node_type=event.iteration_node_type,
+                    node_data=event.iteration_node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    start_at=event.start_at,
+                    node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
+                    inputs=event.inputs,
+                    predecessor_node_id=event.predecessor_node_id,
+                    metadata=event.metadata
+                )
+            )
+        elif isinstance(event, IterationRunNextEvent):
+            self._publish_event(
+                QueueIterationNextEvent(
+                    node_execution_id=event.iteration_id,
+                    node_id=event.iteration_node_id,
+                    node_type=event.iteration_node_type,
+                    node_data=event.iteration_node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    index=event.index,
+                    node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
+                    output=event.pre_iteration_output,
+                )
+            )
+        elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
+            self._publish_event(
+                QueueIterationCompletedEvent(
+                    node_execution_id=event.iteration_id,
+                    node_id=event.iteration_node_id,
+                    node_type=event.iteration_node_type,
+                    node_data=event.iteration_node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    start_at=event.start_at,
+                    node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
+                    inputs=event.inputs,
+                    outputs=event.outputs,
+                    metadata=event.metadata,
+                    steps=event.steps,
+                    error=event.error if isinstance(event, IterationRunFailedEvent) else None
+                )
+            )
+
+    def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
+        """
+        Get workflow
+        """
+        # fetch workflow by workflow_id
+        workflow = (
+            db.session.query(Workflow)
+            .filter(
+                Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
+            )
+            .first()
+        )
+
+        # return workflow
+        return workflow
+    
+    def _publish_event(self, event: AppQueueEvent) -> None:
+        self.queue_manager.publish(
+            event,
+            PublishFrom.APPLICATION_MANAGER
+        )

+ 187 - 97
api/core/app/apps/workflow_logging_callback.py

@@ -1,10 +1,24 @@
 from typing import Optional
 
-from core.app.entities.queue_entities import AppQueueEvent
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeType
+from core.workflow.graph_engine.entities.event import (
+    GraphEngineEvent,
+    GraphRunFailedEvent,
+    GraphRunStartedEvent,
+    GraphRunSucceededEvent,
+    IterationRunFailedEvent,
+    IterationRunNextEvent,
+    IterationRunStartedEvent,
+    IterationRunSucceededEvent,
+    NodeRunFailedEvent,
+    NodeRunStartedEvent,
+    NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
+    ParallelBranchRunFailedEvent,
+    ParallelBranchRunStartedEvent,
+    ParallelBranchRunSucceededEvent,
+)
 
 _TEXT_COLOR_MAPPING = {
     "blue": "36;1",
@@ -20,127 +34,203 @@ class WorkflowLoggingCallback(WorkflowCallback):
     def __init__(self) -> None:
         self.current_node_id = None
 
-    def on_workflow_run_started(self) -> None:
-        """
-        Workflow run started
-        """
-        self.print_text("\n[on_workflow_run_started]", color='pink')
-
-    def on_workflow_run_succeeded(self) -> None:
+    def on_event(
+            self,
+            event: GraphEngineEvent
+    ) -> None:
+        if isinstance(event, GraphRunStartedEvent):
+            self.print_text("\n[GraphRunStartedEvent]", color='pink')
+        elif isinstance(event, GraphRunSucceededEvent):
+            self.print_text("\n[GraphRunSucceededEvent]", color='green')
+        elif isinstance(event, GraphRunFailedEvent):
+            self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red')
+        elif isinstance(event, NodeRunStartedEvent):
+            self.on_workflow_node_execute_started(
+                event=event
+            )
+        elif isinstance(event, NodeRunSucceededEvent):
+            self.on_workflow_node_execute_succeeded(
+                event=event
+            )
+        elif isinstance(event, NodeRunFailedEvent):
+            self.on_workflow_node_execute_failed(
+                event=event
+            )
+        elif isinstance(event, NodeRunStreamChunkEvent):
+            self.on_node_text_chunk(
+                event=event
+            )
+        elif isinstance(event, ParallelBranchRunStartedEvent):
+            self.on_workflow_parallel_started(
+                event=event
+            )
+        elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
+            self.on_workflow_parallel_completed(
+                event=event
+            )
+        elif isinstance(event, IterationRunStartedEvent):
+            self.on_workflow_iteration_started(
+                event=event
+            )
+        elif isinstance(event, IterationRunNextEvent):
+            self.on_workflow_iteration_next(
+                event=event
+            )
+        elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
+            self.on_workflow_iteration_completed(
+                event=event
+            )
+        else:
+            self.print_text(f"\n[{event.__class__.__name__}]", color='blue')
+
+    def on_workflow_node_execute_started(
+            self,
+            event: NodeRunStartedEvent
+    ) -> None:
         """
-        Workflow run succeeded
+        Workflow node execute started
         """
-        self.print_text("\n[on_workflow_run_succeeded]", color='green')
+        self.print_text("\n[NodeRunStartedEvent]", color='yellow')
+        self.print_text(f"Node ID: {event.node_id}", color='yellow')
+        self.print_text(f"Node Title: {event.node_data.title}", color='yellow')
+        self.print_text(f"Type: {event.node_type.value}", color='yellow')
 
-    def on_workflow_run_failed(self, error: str) -> None:
+    def on_workflow_node_execute_succeeded(
+            self,
+            event: NodeRunSucceededEvent
+    ) -> None:
         """
-        Workflow run failed
+        Workflow node execute succeeded
         """
-        self.print_text("\n[on_workflow_run_failed]", color='red')
-
-    def on_workflow_node_execute_started(self, node_id: str,
-                                         node_type: NodeType,
-                                         node_data: BaseNodeData,
-                                         node_run_index: int = 1,
-                                         predecessor_node_id: Optional[str] = None) -> None:
+        route_node_state = event.route_node_state
+
+        self.print_text("\n[NodeRunSucceededEvent]", color='green')
+        self.print_text(f"Node ID: {event.node_id}", color='green')
+        self.print_text(f"Node Title: {event.node_data.title}", color='green')
+        self.print_text(f"Type: {event.node_type.value}", color='green')
+
+        if route_node_state.node_run_result:
+            node_run_result = route_node_state.node_run_result
+            self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
+                            color='green')
+            self.print_text(
+                f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
+                color='green')
+            self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
+                            color='green')
+            self.print_text(
+                f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
+                color='green')
+
+    def on_workflow_node_execute_failed(
+            self,
+            event: NodeRunFailedEvent
+    ) -> None:
         """
-        Workflow node execute started
+        Workflow node execute failed
         """
-        self.print_text("\n[on_workflow_node_execute_started]", color='yellow')
-        self.print_text(f"Node ID: {node_id}", color='yellow')
-        self.print_text(f"Type: {node_type.value}", color='yellow')
-        self.print_text(f"Index: {node_run_index}", color='yellow')
-        if predecessor_node_id:
-            self.print_text(f"Predecessor Node ID: {predecessor_node_id}", color='yellow')
-
-    def on_workflow_node_execute_succeeded(self, node_id: str,
-                                           node_type: NodeType,
-                                           node_data: BaseNodeData,
-                                           inputs: Optional[dict] = None,
-                                           process_data: Optional[dict] = None,
-                                           outputs: Optional[dict] = None,
-                                           execution_metadata: Optional[dict] = None) -> None:
+        route_node_state = event.route_node_state
+
+        self.print_text("\n[NodeRunFailedEvent]", color='red')
+        self.print_text(f"Node ID: {event.node_id}", color='red')
+        self.print_text(f"Node Title: {event.node_data.title}", color='red')
+        self.print_text(f"Type: {event.node_type.value}", color='red')
+
+        if route_node_state.node_run_result:
+            node_run_result = route_node_state.node_run_result
+            self.print_text(f"Error: {node_run_result.error}", color='red')
+            self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
+                            color='red')
+            self.print_text(
+                f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
+                color='red')
+            self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
+                            color='red')
+
+    def on_node_text_chunk(
+            self,
+            event: NodeRunStreamChunkEvent
+    ) -> None:
         """
-        Workflow node execute succeeded
+        Publish text chunk
         """
-        self.print_text("\n[on_workflow_node_execute_succeeded]", color='green')
-        self.print_text(f"Node ID: {node_id}", color='green')
-        self.print_text(f"Type: {node_type.value}", color='green')
-        self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='green')
-        self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='green')
-        self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='green')
-        self.print_text(f"Metadata: {jsonable_encoder(execution_metadata) if execution_metadata else ''}",
-                        color='green')
-
-    def on_workflow_node_execute_failed(self, node_id: str,
-                                        node_type: NodeType,
-                                        node_data: BaseNodeData,
-                                        error: str,
-                                        inputs: Optional[dict] = None,
-                                        outputs: Optional[dict] = None,
-                                        process_data: Optional[dict] = None) -> None:
+        route_node_state = event.route_node_state
+        if not self.current_node_id or self.current_node_id != route_node_state.node_id:
+            self.current_node_id = route_node_state.node_id
+            self.print_text('\n[NodeRunStreamChunkEvent]')
+            self.print_text(f"Node ID: {route_node_state.node_id}")
+
+            node_run_result = route_node_state.node_run_result
+            if node_run_result:
+                self.print_text(
+                    f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}")
+
+        self.print_text(event.chunk_content, color="pink", end="")
+
+    def on_workflow_parallel_started(
+            self,
+            event: ParallelBranchRunStartedEvent
+    ) -> None:
         """
-        Workflow node execute failed
+        Publish parallel started
         """
-        self.print_text("\n[on_workflow_node_execute_failed]", color='red')
-        self.print_text(f"Node ID: {node_id}", color='red')
-        self.print_text(f"Type: {node_type.value}", color='red')
-        self.print_text(f"Error: {error}", color='red')
-        self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='red')
-        self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='red')
-        self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='red')
+        self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue')
+        self.print_text(f"Parallel ID: {event.parallel_id}", color='blue')
+        self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue')
+        if event.in_iteration_id:
+            self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue')
 
-    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
+    def on_workflow_parallel_completed(
+            self,
+            event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
+    ) -> None:
         """
-        Publish text chunk
+        Publish parallel completed
         """
-        if not self.current_node_id or self.current_node_id != node_id:
-            self.current_node_id = node_id
-            self.print_text('\n[on_node_text_chunk]')
-            self.print_text(f"Node ID: {node_id}")
-            self.print_text(f"Metadata: {jsonable_encoder(metadata) if metadata else ''}")
+        if isinstance(event, ParallelBranchRunSucceededEvent):
+            color = 'blue'
+        elif isinstance(event, ParallelBranchRunFailedEvent):
+            color = 'red'
 
-        self.print_text(text, color="pink", end="")
+        self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color)
+        self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
+        self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
+        if event.in_iteration_id:
+            self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
 
-    def on_workflow_iteration_started(self, 
-                                      node_id: str,
-                                      node_type: NodeType,
-                                      node_run_index: int = 1,
-                                      node_data: Optional[BaseNodeData] = None,
-                                      inputs: dict = None,
-                                      predecessor_node_id: Optional[str] = None,
-                                      metadata: Optional[dict] = None) -> None:
+        if isinstance(event, ParallelBranchRunFailedEvent):
+            self.print_text(f"Error: {event.error}", color=color)
+
+    def on_workflow_iteration_started(
+            self,
+            event: IterationRunStartedEvent
+    ) -> None:
         """
         Publish iteration started
         """
-        self.print_text("\n[on_workflow_iteration_started]", color='blue')
-        self.print_text(f"Node ID: {node_id}", color='blue')
+        self.print_text("\n[IterationRunStartedEvent]", color='blue')
+        self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
 
-    def on_workflow_iteration_next(self, node_id: str, 
-                                   node_type: NodeType,
-                                   index: int, 
-                                   node_run_index: int,
-                                   output: Optional[dict]) -> None:
+    def on_workflow_iteration_next(
+            self,
+            event: IterationRunNextEvent
+    ) -> None:
         """
         Publish iteration next
         """
-        self.print_text("\n[on_workflow_iteration_next]", color='blue')
+        self.print_text("\n[IterationRunNextEvent]", color='blue')
+        self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
+        self.print_text(f"Iteration Index: {event.index}", color='blue')
 
-    def on_workflow_iteration_completed(self, node_id: str, 
-                                        node_type: NodeType,
-                                        node_run_index: int,
-                                        outputs: dict) -> None:
+    def on_workflow_iteration_completed(
+            self,
+            event: IterationRunSucceededEvent | IterationRunFailedEvent
+    ) -> None:
         """
         Publish iteration completed
         """
-        self.print_text("\n[on_workflow_iteration_completed]", color='blue')
-
-    def on_event(self, event: AppQueueEvent) -> None:
-        """
-        Publish event
-        """
-        self.print_text("\n[on_workflow_event]", color='blue')
-        self.print_text(f"Event: {jsonable_encoder(event)}", color='blue')
+        self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue')
+        self.print_text(f"Node ID: {event.iteration_id}", color='blue')
 
     def print_text(
             self, text: str, color: Optional[str] = None, end: str = "\n"

+ 165 - 17
api/core/app/entities/queue_entities.py

@@ -1,3 +1,4 @@
+from datetime import datetime
 from enum import Enum
 from typing import Any, Optional
 
@@ -5,7 +6,8 @@ from pydantic import BaseModel, field_validator
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeType
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 
 
 class QueueEvent(str, Enum):
@@ -31,6 +33,9 @@ class QueueEvent(str, Enum):
     ANNOTATION_REPLY = "annotation_reply"
     AGENT_THOUGHT = "agent_thought"
     MESSAGE_FILE = "message_file"
+    PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
+    PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
+    PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
     ERROR = "error"
     PING = "ping"
     STOP = "stop"
@@ -38,7 +43,7 @@ class QueueEvent(str, Enum):
 
 class AppQueueEvent(BaseModel):
     """
-    QueueEvent entity
+    QueueEvent abstract entity
     """
     event: QueueEvent
 
@@ -46,6 +51,7 @@ class AppQueueEvent(BaseModel):
 class QueueLLMChunkEvent(AppQueueEvent):
     """
     QueueLLMChunkEvent entity
+    Only for basic mode apps
     """
     event: QueueEvent = QueueEvent.LLM_CHUNK
     chunk: LLMResultChunk
@@ -55,14 +61,24 @@ class QueueIterationStartEvent(AppQueueEvent):
     QueueIterationStartEvent entity
     """
     event: QueueEvent = QueueEvent.ITERATION_START
+    node_execution_id: str
     node_id: str
     node_type: NodeType
     node_data: BaseNodeData
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    start_at: datetime
 
     node_run_index: int
-    inputs: dict = None
+    inputs: Optional[dict[str, Any]] = None
     predecessor_node_id: Optional[str] = None
-    metadata: Optional[dict] = None
+    metadata: Optional[dict[str, Any]] = None
 
 class QueueIterationNextEvent(AppQueueEvent):
     """
@@ -71,8 +87,18 @@ class QueueIterationNextEvent(AppQueueEvent):
     event: QueueEvent = QueueEvent.ITERATION_NEXT
 
     index: int
+    node_execution_id: str
     node_id: str
     node_type: NodeType
+    node_data: BaseNodeData
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
 
     node_run_index: int
     output: Optional[Any] = None # output for the current iteration
@@ -93,13 +119,30 @@ class QueueIterationCompletedEvent(AppQueueEvent):
     """
     QueueIterationCompletedEvent entity
     """
-    event:QueueEvent = QueueEvent.ITERATION_COMPLETED
+    event: QueueEvent = QueueEvent.ITERATION_COMPLETED
 
+    node_execution_id: str
     node_id: str
     node_type: NodeType
+    node_data: BaseNodeData
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    start_at: datetime
     
     node_run_index: int
-    outputs: dict
+    inputs: Optional[dict[str, Any]] = None
+    outputs: Optional[dict[str, Any]] = None
+    metadata: Optional[dict[str, Any]] = None
+    steps: int = 0
+
+    error: Optional[str] = None
+
 
 class QueueTextChunkEvent(AppQueueEvent):
     """
@@ -107,7 +150,10 @@ class QueueTextChunkEvent(AppQueueEvent):
     """
     event: QueueEvent = QueueEvent.TEXT_CHUNK
     text: str
-    metadata: Optional[dict] = None
+    from_variable_selector: Optional[list[str]] = None
+    """from variable selector"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
 
 
 class QueueAgentMessageEvent(AppQueueEvent):
@@ -132,6 +178,8 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
     """
     event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
     retriever_resources: list[dict]
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
 
 
 class QueueAnnotationReplyEvent(AppQueueEvent):
@@ -162,6 +210,7 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
     QueueWorkflowStartedEvent entity
     """
     event: QueueEvent = QueueEvent.WORKFLOW_STARTED
+    graph_runtime_state: GraphRuntimeState
 
 
 class QueueWorkflowSucceededEvent(AppQueueEvent):
@@ -169,6 +218,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
     QueueWorkflowSucceededEvent entity
     """
     event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
+    outputs: Optional[dict[str, Any]] = None
 
 
 class QueueWorkflowFailedEvent(AppQueueEvent):
@@ -185,11 +235,23 @@ class QueueNodeStartedEvent(AppQueueEvent):
     """
     event: QueueEvent = QueueEvent.NODE_STARTED
 
+    node_execution_id: str
     node_id: str
     node_type: NodeType
     node_data: BaseNodeData
     node_run_index: int = 1
     predecessor_node_id: Optional[str] = None
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+    start_at: datetime
 
 
 class QueueNodeSucceededEvent(AppQueueEvent):
@@ -198,14 +260,26 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     """
     event: QueueEvent = QueueEvent.NODE_SUCCEEDED
 
+    node_execution_id: str
     node_id: str
     node_type: NodeType
     node_data: BaseNodeData
-
-    inputs: Optional[dict] = None
-    process_data: Optional[dict] = None
-    outputs: Optional[dict] = None
-    execution_metadata: Optional[dict] = None
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+    start_at: datetime
+
+    inputs: Optional[dict[str, Any]] = None
+    process_data: Optional[dict[str, Any]] = None
+    outputs: Optional[dict[str, Any]] = None
+    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
 
     error: Optional[str] = None
 
@@ -216,13 +290,25 @@ class QueueNodeFailedEvent(AppQueueEvent):
     """
     event: QueueEvent = QueueEvent.NODE_FAILED
 
+    node_execution_id: str
     node_id: str
     node_type: NodeType
     node_data: BaseNodeData
-
-    inputs: Optional[dict] = None
-    outputs: Optional[dict] = None
-    process_data: Optional[dict] = None
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+    start_at: datetime
+
+    inputs: Optional[dict[str, Any]] = None
+    process_data: Optional[dict[str, Any]] = None
+    outputs: Optional[dict[str, Any]] = None
 
     error: str
 
@@ -274,10 +360,23 @@ class QueueStopEvent(AppQueueEvent):
     event: QueueEvent = QueueEvent.STOP
     stopped_by: StopBy
 
+    def get_stop_reason(self) -> str:
+        """
+        To stop reason
+        """
+        reason_mapping = {
+            QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.',
+            QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.',
+            QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
+            QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
+        }
+
+        return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
+
 
 class QueueMessage(BaseModel):
     """
-    QueueMessage entity
+    QueueMessage abstract entity
     """
     task_id: str
     app_mode: str
@@ -297,3 +396,52 @@ class WorkflowQueueMessage(QueueMessage):
     WorkflowQueueMessage entity
     """
     pass
+
+
+class QueueParallelBranchRunStartedEvent(AppQueueEvent):
+    """
+    QueueParallelBranchRunStartedEvent entity
+    """
+    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
+
+    parallel_id: str
+    parallel_start_node_id: str
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+
+
+class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
+    """
+    QueueParallelBranchRunSucceededEvent entity
+    """
+    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
+
+    parallel_id: str
+    parallel_start_node_id: str
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+
+
+class QueueParallelBranchRunFailedEvent(AppQueueEvent):
+    """
+    QueueParallelBranchRunFailedEvent entity
+    """
+    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
+
+    parallel_id: str
+    parallel_start_node_id: str
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+    error: str

+ 78 - 78
api/core/app/entities/task_entities.py

@@ -3,40 +3,11 @@ from typing import Any, Optional
 
 from pydantic import BaseModel, ConfigDict
 
-from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeType
-from core.workflow.nodes.answer.entities import GenerateRouteChunk
 from models.workflow import WorkflowNodeExecutionStatus
 
 
-class WorkflowStreamGenerateNodes(BaseModel):
-    """
-    WorkflowStreamGenerateNodes entity
-    """
-    end_node_id: str
-    stream_node_ids: list[str]
-
-
-class ChatflowStreamGenerateRoute(BaseModel):
-    """
-    ChatflowStreamGenerateRoute entity
-    """
-    answer_node_id: str
-    generate_route: list[GenerateRouteChunk]
-    current_route_position: int = 0
-
-
-class NodeExecutionInfo(BaseModel):
-    """
-    NodeExecutionInfo entity
-    """
-    workflow_node_execution_id: str
-    node_type: NodeType
-    start_at: float
-
-
 class TaskState(BaseModel):
     """
     TaskState entity
@@ -57,27 +28,6 @@ class WorkflowTaskState(TaskState):
     """
     answer: str = ""
 
-    workflow_run_id: Optional[str] = None
-    start_at: Optional[float] = None
-    total_tokens: int = 0
-    total_steps: int = 0
-
-    ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
-    latest_node_execution_info: Optional[NodeExecutionInfo] = None
-
-    current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None
-
-    iteration_nested_node_ids: list[str] = None
-
-
-class AdvancedChatTaskState(WorkflowTaskState):
-    """
-    AdvancedChatTaskState entity
-    """
-    usage: LLMUsage
-
-    current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None
-
 
 class StreamEvent(Enum):
     """
@@ -97,6 +47,8 @@ class StreamEvent(Enum):
     WORKFLOW_FINISHED = "workflow_finished"
     NODE_STARTED = "node_started"
     NODE_FINISHED = "node_finished"
+    PARALLEL_BRANCH_STARTED = "parallel_branch_started"
+    PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
     ITERATION_STARTED = "iteration_started"
     ITERATION_NEXT = "iteration_next"
     ITERATION_COMPLETED = "iteration_completed"
@@ -267,6 +219,11 @@ class NodeStartStreamResponse(StreamResponse):
         inputs: Optional[dict] = None
         created_at: int
         extras: dict = {}
+        parallel_id: Optional[str] = None
+        parallel_start_node_id: Optional[str] = None
+        parent_parallel_id: Optional[str] = None
+        parent_parallel_start_node_id: Optional[str] = None
+        iteration_id: Optional[str] = None
 
     event: StreamEvent = StreamEvent.NODE_STARTED
     workflow_run_id: str
@@ -286,7 +243,12 @@ class NodeStartStreamResponse(StreamResponse):
                 "predecessor_node_id": self.data.predecessor_node_id,
                 "inputs": None,
                 "created_at": self.data.created_at,
-                "extras": {}
+                "extras": {},
+                "parallel_id": self.data.parallel_id,
+                "parallel_start_node_id": self.data.parallel_start_node_id,
+                "parent_parallel_id": self.data.parent_parallel_id,
+                "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
+                "iteration_id": self.data.iteration_id,
             }
         }
 
@@ -316,6 +278,11 @@ class NodeFinishStreamResponse(StreamResponse):
         created_at: int
         finished_at: int
         files: Optional[list[dict]] = []
+        parallel_id: Optional[str] = None
+        parallel_start_node_id: Optional[str] = None
+        parent_parallel_id: Optional[str] = None
+        parent_parallel_start_node_id: Optional[str] = None
+        iteration_id: Optional[str] = None
 
     event: StreamEvent = StreamEvent.NODE_FINISHED
     workflow_run_id: str
@@ -342,9 +309,58 @@ class NodeFinishStreamResponse(StreamResponse):
                 "execution_metadata": None,
                 "created_at": self.data.created_at,
                 "finished_at": self.data.finished_at,
-                "files": []
+                "files": [],
+                "parallel_id": self.data.parallel_id,
+                "parallel_start_node_id": self.data.parallel_start_node_id,
+                "parent_parallel_id": self.data.parent_parallel_id,
+                "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
+                "iteration_id": self.data.iteration_id,
             }
         }
+    
+
+class ParallelBranchStartStreamResponse(StreamResponse):
+    """
+    ParallelBranchStartStreamResponse entity
+    """
+
+    class Data(BaseModel):
+        """
+        Data entity
+        """
+        parallel_id: str
+        parallel_branch_id: str
+        parent_parallel_id: Optional[str] = None
+        parent_parallel_start_node_id: Optional[str] = None
+        iteration_id: Optional[str] = None
+        created_at: int
+
+    event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
+    workflow_run_id: str
+    data: Data
+
+
+class ParallelBranchFinishedStreamResponse(StreamResponse):
+    """
+    ParallelBranchFinishedStreamResponse entity
+    """
+
+    class Data(BaseModel):
+        """
+        Data entity
+        """
+        parallel_id: str
+        parallel_branch_id: str
+        parent_parallel_id: Optional[str] = None
+        parent_parallel_start_node_id: Optional[str] = None
+        iteration_id: Optional[str] = None
+        status: str
+        error: Optional[str] = None
+        created_at: int
+
+    event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
+    workflow_run_id: str
+    data: Data
 
 
 class IterationNodeStartStreamResponse(StreamResponse):
@@ -364,6 +380,8 @@ class IterationNodeStartStreamResponse(StreamResponse):
         extras: dict = {}
         metadata: dict = {}
         inputs: dict = {}
+        parallel_id: Optional[str] = None
+        parallel_start_node_id: Optional[str] = None
 
     event: StreamEvent = StreamEvent.ITERATION_STARTED
     workflow_run_id: str
@@ -387,6 +405,8 @@ class IterationNodeNextStreamResponse(StreamResponse):
         created_at: int
         pre_iteration_output: Optional[Any] = None
         extras: dict = {}
+        parallel_id: Optional[str] = None
+        parallel_start_node_id: Optional[str] = None
 
     event: StreamEvent = StreamEvent.ITERATION_NEXT
     workflow_run_id: str
@@ -408,8 +428,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
         title: str
         outputs: Optional[dict] = None
         created_at: int
-        extras: dict = None
-        inputs: dict = None
+        extras: Optional[dict] = None
+        inputs: Optional[dict] = None
         status: WorkflowNodeExecutionStatus
         error: Optional[str] = None
         elapsed_time: float
@@ -417,6 +437,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
         execution_metadata: Optional[dict] = None
         finished_at: int
         steps: int
+        parallel_id: Optional[str] = None
+        parallel_start_node_id: Optional[str] = None
 
     event: StreamEvent = StreamEvent.ITERATION_COMPLETED
     workflow_run_id: str
@@ -488,7 +510,7 @@ class WorkflowAppStreamResponse(AppStreamResponse):
     """
     WorkflowAppStreamResponse entity
     """
-    workflow_run_id: str
+    workflow_run_id: Optional[str] = None
 
 
 class AppBlockingResponse(BaseModel):
@@ -562,25 +584,3 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
 
     workflow_run_id: str
     data: Data
-
-
-class WorkflowIterationState(BaseModel):
-    """
-    WorkflowIterationState entity
-    """
-
-    class Data(BaseModel):
-        """
-        Data entity
-        """
-        parent_iteration_id: Optional[str] = None
-        iteration_id: str
-        current_index: int
-        iteration_steps_boundary: list[int] = None
-        node_execution_id: str
-        started_at: float
-        inputs: dict = None
-        total_tokens: int = 0
-        node_data: BaseNodeData
-
-    current_iterations: dict[str, Data] = None

+ 8 - 6
api/core/app/task_pipeline/based_generate_task_pipeline.py

@@ -68,16 +68,18 @@ class BasedGenerateTaskPipeline:
             err = Exception(e.description if getattr(e, 'description', None) is not None else str(e))
 
         if message:
-            message = db.session.query(Message).filter(Message.id == message.id).first()
-            err_desc = self._error_to_desc(err)
-            message.status = 'error'
-            message.error = err_desc
+            refetch_message = db.session.query(Message).filter(Message.id == message.id).first()
 
-            db.session.commit()
+            if refetch_message:
+                err_desc = self._error_to_desc(err)
+                refetch_message.status = 'error'
+                refetch_message.error = err_desc
+
+                db.session.commit()
 
         return err
 
-    def _error_to_desc(cls, e: Exception) -> str:
+    def _error_to_desc(self, e: Exception) -> str:
         """
         Error to desc.
         :param e: exception

+ 10 - 26
api/core/app/task_pipeline/message_cycle_manage.py

@@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import (
     AgentChatAppGenerateEntity,
     ChatAppGenerateEntity,
     CompletionAppGenerateEntity,
-    InvokeFrom,
 )
 from core.app.entities.queue_entities import (
     QueueAnnotationReplyEvent,
@@ -16,11 +15,11 @@ from core.app.entities.queue_entities import (
     QueueRetrieverResourcesEvent,
 )
 from core.app.entities.task_entities import (
-    AdvancedChatTaskState,
     EasyUITaskState,
     MessageFileStreamResponse,
     MessageReplaceStreamResponse,
     MessageStreamResponse,
+    WorkflowTaskState,
 )
 from core.llm_generator.llm_generator import LLMGenerator
 from core.tools.tool_file_manager import ToolFileManager
@@ -36,7 +35,7 @@ class MessageCycleManage:
         AgentChatAppGenerateEntity,
         AdvancedChatAppGenerateEntity
     ]
-    _task_state: Union[EasyUITaskState, AdvancedChatTaskState]
+    _task_state: Union[EasyUITaskState, WorkflowTaskState]
 
     def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]:
         """
@@ -45,6 +44,9 @@ class MessageCycleManage:
         :param query: query
         :return: thread
         """
+        if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
+            return None
+
         is_first_message = self._application_generate_entity.conversation_id is None
         extras = self._application_generate_entity.extras
         auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True)
@@ -52,7 +54,7 @@ class MessageCycleManage:
         if auto_generate_conversation_name and is_first_message:
             # start generate thread
             thread = Thread(target=self._generate_conversation_name_worker, kwargs={
-                'flask_app': current_app._get_current_object(),
+                'flask_app': current_app._get_current_object(), # type: ignore
                 'conversation_id': conversation.id,
                 'query': query
             })
@@ -75,6 +77,9 @@ class MessageCycleManage:
                 .first()
             )
 
+            if not conversation:
+                return
+
             if conversation.mode != AppMode.COMPLETION.value:
                 app_model = conversation.app
                 if not app_model:
@@ -121,34 +126,13 @@ class MessageCycleManage:
         if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
             self._task_state.metadata['retriever_resources'] = event.retriever_resources
 
-    def _get_response_metadata(self) -> dict:
-        """
-        Get response metadata by invoke from.
-        :return:
-        """
-        metadata = {}
-
-        # show_retrieve_source
-        if 'retriever_resources' in self._task_state.metadata:
-            metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
-
-        # show annotation reply
-        if 'annotation_reply' in self._task_state.metadata:
-            metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']
-
-        # show usage
-        if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
-            metadata['usage'] = self._task_state.metadata['usage']
-
-        return metadata
-
     def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
         """
         Message file to stream response.
         :param event: event
         :return:
         """
-        message_file: MessageFile = (
+        message_file = (
             db.session.query(MessageFile)
             .filter(MessageFile.id == event.message_file_id)
             .first()

+ 352 - 313
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -1,33 +1,41 @@
 import json
 import time
 from datetime import datetime, timezone
-from typing import Optional, Union, cast
+from typing import Any, Optional, Union, cast
 
-from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
 from core.app.entities.queue_entities import (
+    QueueIterationCompletedEvent,
+    QueueIterationNextEvent,
+    QueueIterationStartEvent,
     QueueNodeFailedEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
-    QueueStopEvent,
-    QueueWorkflowFailedEvent,
-    QueueWorkflowSucceededEvent,
+    QueueParallelBranchRunFailedEvent,
+    QueueParallelBranchRunStartedEvent,
+    QueueParallelBranchRunSucceededEvent,
 )
 from core.app.entities.task_entities import (
-    NodeExecutionInfo,
+    IterationNodeCompletedStreamResponse,
+    IterationNodeNextStreamResponse,
+    IterationNodeStartStreamResponse,
     NodeFinishStreamResponse,
     NodeStartStreamResponse,
+    ParallelBranchFinishedStreamResponse,
+    ParallelBranchStartStreamResponse,
     WorkflowFinishStreamResponse,
     WorkflowStartStreamResponse,
+    WorkflowTaskState,
 )
-from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage
 from core.file.file_obj import FileVar
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.entities.trace_entity import TraceTaskName
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.tools.tool_manager import ToolManager
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.enums import SystemVariableKey
 from core.workflow.nodes.tool.entities import ToolNodeData
-from core.workflow.workflow_engine_manager import WorkflowEngineManager
+from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from models.account import Account
 from models.model import EndUser
@@ -41,54 +49,56 @@ from models.workflow import (
     WorkflowRunStatus,
     WorkflowRunTriggeredFrom,
 )
-from services.workflow_service import WorkflowService
-
-
-class WorkflowCycleManage(WorkflowIterationCycleManage):
-    def _init_workflow_run(self, workflow: Workflow,
-                           triggered_from: WorkflowRunTriggeredFrom,
-                           user: Union[Account, EndUser],
-                           user_inputs: dict,
-                           system_inputs: Optional[dict] = None) -> WorkflowRun:
-        """
-        Init workflow run
-        :param workflow: Workflow instance
-        :param triggered_from: triggered from
-        :param user: account or end user
-        :param user_inputs: user variables inputs
-        :param system_inputs: system inputs, like: query, files
-        :return:
-        """
-        max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \
-                           .filter(WorkflowRun.tenant_id == workflow.tenant_id) \
-                           .filter(WorkflowRun.app_id == workflow.app_id) \
-                           .scalar() or 0
+
+
+class WorkflowCycleManage:
+    _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
+    _workflow: Workflow
+    _user: Union[Account, EndUser]
+    _task_state: WorkflowTaskState
+    _workflow_system_variables: dict[SystemVariableKey, Any]
+
+    def _handle_workflow_run_start(self) -> WorkflowRun:
+        max_sequence = (
+            db.session.query(db.func.max(WorkflowRun.sequence_number))
+            .filter(WorkflowRun.tenant_id == self._workflow.tenant_id)
+            .filter(WorkflowRun.app_id == self._workflow.app_id)
+            .scalar()
+            or 0
+        )
         new_sequence_number = max_sequence + 1
 
-        inputs = {**user_inputs}
-        for key, value in (system_inputs or {}).items():
+        inputs = {**self._application_generate_entity.inputs}
+        for key, value in (self._workflow_system_variables or {}).items():
             if key.value == 'conversation':
                 continue
 
             inputs[f'sys.{key.value}'] = value
-        inputs = WorkflowEngineManager.handle_special_values(inputs)
+
+        inputs = WorkflowEntry.handle_special_values(inputs)
+
+        triggered_from= (
+            WorkflowRunTriggeredFrom.DEBUGGING
+            if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
+            else WorkflowRunTriggeredFrom.APP_RUN
+        )
 
         # init workflow run
-        workflow_run = WorkflowRun(
-            tenant_id=workflow.tenant_id,
-            app_id=workflow.app_id,
-            sequence_number=new_sequence_number,
-            workflow_id=workflow.id,
-            type=workflow.type,
-            triggered_from=triggered_from.value,
-            version=workflow.version,
-            graph=workflow.graph,
-            inputs=json.dumps(inputs),
-            status=WorkflowRunStatus.RUNNING.value,
-            created_by_role=(CreatedByRole.ACCOUNT.value
-                             if isinstance(user, Account) else CreatedByRole.END_USER.value),
-            created_by=user.id
+        workflow_run = WorkflowRun()
+        workflow_run.tenant_id = self._workflow.tenant_id
+        workflow_run.app_id = self._workflow.app_id
+        workflow_run.sequence_number = new_sequence_number
+        workflow_run.workflow_id = self._workflow.id
+        workflow_run.type = self._workflow.type
+        workflow_run.triggered_from = triggered_from.value
+        workflow_run.version = self._workflow.version
+        workflow_run.graph = self._workflow.graph
+        workflow_run.inputs = json.dumps(inputs)
+        workflow_run.status = WorkflowRunStatus.RUNNING.value
+        workflow_run.created_by_role = (
+            CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value
         )
+        workflow_run.created_by = self._user.id
 
         db.session.add(workflow_run)
         db.session.commit()
@@ -97,33 +107,37 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
 
         return workflow_run
 
-    def _workflow_run_success(
-        self, workflow_run: WorkflowRun,
+    def _handle_workflow_run_success(
+        self,
+        workflow_run: WorkflowRun,
+        start_at: float,
         total_tokens: int,
         total_steps: int,
         outputs: Optional[str] = None,
         conversation_id: Optional[str] = None,
-        trace_manager: Optional[TraceQueueManager] = None
+        trace_manager: Optional[TraceQueueManager] = None,
     ) -> WorkflowRun:
         """
         Workflow run success
         :param workflow_run: workflow run
+        :param start_at: start time
         :param total_tokens: total tokens
         :param total_steps: total steps
         :param outputs: outputs
         :param conversation_id: conversation id
         :return:
         """
+        workflow_run = self._refetch_workflow_run(workflow_run.id)
+
         workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
         workflow_run.outputs = outputs
-        workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
+        workflow_run.elapsed_time = time.perf_counter() - start_at
         workflow_run.total_tokens = total_tokens
         workflow_run.total_steps = total_steps
         workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
 
         db.session.commit()
         db.session.refresh(workflow_run)
-        db.session.close()
 
         if trace_manager:
             trace_manager.add_trace_task(
@@ -135,34 +149,58 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
                 )
             )
 
+        db.session.close()
+
         return workflow_run
 
-    def _workflow_run_failed(
-        self, workflow_run: WorkflowRun,
+    def _handle_workflow_run_failed(
+        self,
+        workflow_run: WorkflowRun,
+        start_at: float,
         total_tokens: int,
         total_steps: int,
         status: WorkflowRunStatus,
         error: str,
         conversation_id: Optional[str] = None,
-        trace_manager: Optional[TraceQueueManager] = None
+        trace_manager: Optional[TraceQueueManager] = None,
     ) -> WorkflowRun:
         """
         Workflow run failed
         :param workflow_run: workflow run
+        :param start_at: start time
         :param total_tokens: total tokens
         :param total_steps: total steps
         :param status: status
         :param error: error message
         :return:
         """
+        workflow_run = self._refetch_workflow_run(workflow_run.id)
+
         workflow_run.status = status.value
         workflow_run.error = error
-        workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
+        workflow_run.elapsed_time = time.perf_counter() - start_at
         workflow_run.total_tokens = total_tokens
         workflow_run.total_steps = total_steps
         workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
 
         db.session.commit()
+
+        running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter(
+            WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
+            WorkflowNodeExecution.app_id == workflow_run.app_id,
+            WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
+            WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
+            WorkflowNodeExecution.workflow_run_id == workflow_run.id,
+            WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value
+        ).all()
+
+        for workflow_node_execution in running_workflow_node_executions:
+            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
+            workflow_node_execution.error = error
+            workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds()
+            db.session.commit()
+
         db.session.refresh(workflow_run)
         db.session.close()
 
@@ -178,39 +216,24 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
 
         return workflow_run
 
-    def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun,
-                                               node_id: str,
-                                               node_type: NodeType,
-                                               node_title: str,
-                                               node_run_index: int = 1,
-                                               predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution:
-        """
-        Init workflow node execution from workflow run
-        :param workflow_run: workflow run
-        :param node_id: node id
-        :param node_type: node type
-        :param node_title: node title
-        :param node_run_index: run index
-        :param predecessor_node_id: predecessor node id if exists
-        :return:
-        """
+    def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
         # init workflow node execution
-        workflow_node_execution = WorkflowNodeExecution(
-            tenant_id=workflow_run.tenant_id,
-            app_id=workflow_run.app_id,
-            workflow_id=workflow_run.workflow_id,
-            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
-            workflow_run_id=workflow_run.id,
-            predecessor_node_id=predecessor_node_id,
-            index=node_run_index,
-            node_id=node_id,
-            node_type=node_type.value,
-            title=node_title,
-            status=WorkflowNodeExecutionStatus.RUNNING.value,
-            created_by_role=workflow_run.created_by_role,
-            created_by=workflow_run.created_by,
-            created_at=datetime.now(timezone.utc).replace(tzinfo=None)
-        )
+        workflow_node_execution = WorkflowNodeExecution()
+        workflow_node_execution.tenant_id = workflow_run.tenant_id
+        workflow_node_execution.app_id = workflow_run.app_id
+        workflow_node_execution.workflow_id = workflow_run.workflow_id
+        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
+        workflow_node_execution.workflow_run_id = workflow_run.id
+        workflow_node_execution.predecessor_node_id = event.predecessor_node_id
+        workflow_node_execution.index = event.node_run_index
+        workflow_node_execution.node_execution_id = event.node_execution_id
+        workflow_node_execution.node_id = event.node_id
+        workflow_node_execution.node_type = event.node_type.value
+        workflow_node_execution.title = event.node_data.title
+        workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
+        workflow_node_execution.created_by_role = workflow_run.created_by_role
+        workflow_node_execution.created_by = workflow_run.created_by
+        workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
 
         db.session.add(workflow_node_execution)
         db.session.commit()
@@ -219,33 +242,26 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
 
         return workflow_node_execution
 
-    def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
-                                         start_at: float,
-                                         inputs: Optional[dict] = None,
-                                         process_data: Optional[dict] = None,
-                                         outputs: Optional[dict] = None,
-                                         execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution:
+    def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
         """
         Workflow node execution success
-        :param workflow_node_execution: workflow node execution
-        :param start_at: start time
-        :param inputs: inputs
-        :param process_data: process data
-        :param outputs: outputs
-        :param execution_metadata: execution metadata
+        :param event: queue node succeeded event
         :return:
         """
-        inputs = WorkflowEngineManager.handle_special_values(inputs)
-        outputs = WorkflowEngineManager.handle_special_values(outputs)
+        workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
+
+        inputs = WorkflowEntry.handle_special_values(event.inputs)
+        outputs = WorkflowEntry.handle_special_values(event.outputs)
 
         workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
-        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
         workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
-        workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
+        workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
         workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
-        workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
-            if execution_metadata else None
+        workflow_node_execution.execution_metadata = (
+            json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
+        )
         workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
 
         db.session.commit()
         db.session.refresh(workflow_node_execution)
@@ -253,33 +269,24 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
 
         return workflow_node_execution
 
-    def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
-                                        start_at: float,
-                                        error: str,
-                                        inputs: Optional[dict] = None,
-                                        process_data: Optional[dict] = None,
-                                        outputs: Optional[dict] = None,
-                                        execution_metadata: Optional[dict] = None
-                                        ) -> WorkflowNodeExecution:
+    def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
         """
         Workflow node execution failed
-        :param workflow_node_execution: workflow node execution
-        :param start_at: start time
-        :param error: error message
+        :param event: queue node failed event
         :return:
         """
-        inputs = WorkflowEngineManager.handle_special_values(inputs)
-        outputs = WorkflowEngineManager.handle_special_values(outputs)
+        workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
+
+        inputs = WorkflowEntry.handle_special_values(event.inputs)
+        outputs = WorkflowEntry.handle_special_values(event.outputs)
 
         workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
-        workflow_node_execution.error = error
-        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
+        workflow_node_execution.error = event.error
         workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
         workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
-        workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
+        workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
         workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
-        workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
-            if execution_metadata else None
+        workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
 
         db.session.commit()
         db.session.refresh(workflow_node_execution)
@@ -287,8 +294,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
 
         return workflow_node_execution
 
-    def _workflow_start_to_stream_response(self, task_id: str,
-                                           workflow_run: WorkflowRun) -> WorkflowStartStreamResponse:
+    #################################################
+    #             to stream responses               #
+    #################################################
+
+    def _workflow_start_to_stream_response(
+        self, task_id: str, workflow_run: WorkflowRun
+    ) -> WorkflowStartStreamResponse:
         """
         Workflow start to stream response.
         :param task_id: task id
@@ -302,13 +314,14 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
                 id=workflow_run.id,
                 workflow_id=workflow_run.workflow_id,
                 sequence_number=workflow_run.sequence_number,
-                inputs=workflow_run.inputs_dict,
-                created_at=int(workflow_run.created_at.timestamp())
-            )
+                inputs=workflow_run.inputs_dict or {},
+                created_at=int(workflow_run.created_at.timestamp()),
+            ),
         )
 
-    def _workflow_finish_to_stream_response(self, task_id: str,
-                                            workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse:
+    def _workflow_finish_to_stream_response(
+        self, task_id: str, workflow_run: WorkflowRun
+    ) -> WorkflowFinishStreamResponse:
         """
         Workflow finish to stream response.
         :param task_id: task id
@@ -320,16 +333,16 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
             created_by_account = workflow_run.created_by_account
             if created_by_account:
                 created_by = {
-                    "id": created_by_account.id,
-                    "name": created_by_account.name,
-                    "email": created_by_account.email,
+                    'id': created_by_account.id,
+                    'name': created_by_account.name,
+                    'email': created_by_account.email,
                 }
         else:
             created_by_end_user = workflow_run.created_by_end_user
             if created_by_end_user:
                 created_by = {
-                    "id": created_by_end_user.id,
-                    "user": created_by_end_user.session_id,
+                    'id': created_by_end_user.id,
+                    'user': created_by_end_user.session_id,
                 }
 
         return WorkflowFinishStreamResponse(
@@ -348,14 +361,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
                 created_by=created_by,
                 created_at=int(workflow_run.created_at.timestamp()),
                 finished_at=int(workflow_run.finished_at.timestamp()),
-                files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict)
-            )
+                files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}),
+            ),
         )
 
-    def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent,
-                                                task_id: str,
-                                                workflow_node_execution: WorkflowNodeExecution) \
-            -> NodeStartStreamResponse:
+    def _workflow_node_start_to_stream_response(
+        self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution
+    ) -> Optional[NodeStartStreamResponse]:
         """
         Workflow node start to stream response.
         :param event: queue node started event
@@ -363,6 +375,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
         :param workflow_node_execution: workflow node execution
         :return:
         """
+        if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
+            return None
+
         response = NodeStartStreamResponse(
             task_id=task_id,
             workflow_run_id=workflow_node_execution.workflow_run_id,
@@ -374,8 +389,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
                 index=workflow_node_execution.index,
                 predecessor_node_id=workflow_node_execution.predecessor_node_id,
                 inputs=workflow_node_execution.inputs_dict,
-                created_at=int(workflow_node_execution.created_at.timestamp())
-            )
+                created_at=int(workflow_node_execution.created_at.timestamp()),
+                parallel_id=event.parallel_id,
+                parallel_start_node_id=event.parallel_start_node_id,
+                parent_parallel_id=event.parent_parallel_id,
+                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                iteration_id=event.in_iteration_id,
+            ),
         )
 
         # extras logic
@@ -384,19 +404,27 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
             response.data.extras['icon'] = ToolManager.get_tool_icon(
                 tenant_id=self._application_generate_entity.app_config.tenant_id,
                 provider_type=node_data.provider_type,
-                provider_id=node_data.provider_id
+                provider_id=node_data.provider_id,
             )
 
         return response
 
-    def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \
-            -> NodeFinishStreamResponse:
+    def _workflow_node_finish_to_stream_response(
+        self, 
+        event: QueueNodeSucceededEvent | QueueNodeFailedEvent, 
+        task_id: str, 
+        workflow_node_execution: WorkflowNodeExecution
+    ) -> Optional[NodeFinishStreamResponse]:
         """
         Workflow node finish to stream response.
+        :param event: queue node succeeded or failed event
         :param task_id: task id
         :param workflow_node_execution: workflow node execution
         :return:
         """
+        if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
+            return None
+        
         return NodeFinishStreamResponse(
             task_id=task_id,
             workflow_run_id=workflow_node_execution.workflow_run_id,
@@ -416,181 +444,155 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
                 execution_metadata=workflow_node_execution.execution_metadata_dict,
                 created_at=int(workflow_node_execution.created_at.timestamp()),
                 finished_at=int(workflow_node_execution.finished_at.timestamp()),
-                files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict)
-            )
+                files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
+                parallel_id=event.parallel_id,
+                parallel_start_node_id=event.parallel_start_node_id,
+                parent_parallel_id=event.parent_parallel_id,
+                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                iteration_id=event.in_iteration_id,
+            ),
         )
-
-    def _handle_workflow_start(self) -> WorkflowRun:
-        self._task_state.start_at = time.perf_counter()
-
-        workflow_run = self._init_workflow_run(
-            workflow=self._workflow,
-            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
-            if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
-            else WorkflowRunTriggeredFrom.APP_RUN,
-            user=self._user,
-            user_inputs=self._application_generate_entity.inputs,
-            system_inputs=self._workflow_system_variables
+    
+    def _workflow_parallel_branch_start_to_stream_response(
+            self,
+            task_id: str,
+            workflow_run: WorkflowRun,
+            event: QueueParallelBranchRunStartedEvent
+        ) -> ParallelBranchStartStreamResponse:
+        """
+        Workflow parallel branch start to stream response
+        :param task_id: task id
+        :param workflow_run: workflow run
+        :param event: parallel branch run started event
+        :return:
+        """
+        return ParallelBranchStartStreamResponse(
+            task_id=task_id,
+            workflow_run_id=workflow_run.id,
+            data=ParallelBranchStartStreamResponse.Data(
+                parallel_id=event.parallel_id,
+                parallel_branch_id=event.parallel_start_node_id,
+                parent_parallel_id=event.parent_parallel_id,
+                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                iteration_id=event.in_iteration_id,
+                created_at=int(time.time()),
+            )
         )
-
-        self._task_state.workflow_run_id = workflow_run.id
-
-        db.session.close()
-
-        return workflow_run
-
-    def _handle_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
-        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
-        workflow_node_execution = self._init_node_execution_from_workflow_run(
-            workflow_run=workflow_run,
-            node_id=event.node_id,
-            node_type=event.node_type,
-            node_title=event.node_data.title,
-            node_run_index=event.node_run_index,
-            predecessor_node_id=event.predecessor_node_id
+    
+    def _workflow_parallel_branch_finished_to_stream_response(
+            self,
+            task_id: str,
+            workflow_run: WorkflowRun,
+            event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent
+        ) -> ParallelBranchFinishedStreamResponse:
+        """
+        Workflow parallel branch finished to stream response
+        :param task_id: task id
+        :param workflow_run: workflow run
+        :param event: parallel branch run succeeded or failed event
+        :return:
+        """
+        return ParallelBranchFinishedStreamResponse(
+            task_id=task_id,
+            workflow_run_id=workflow_run.id,
+            data=ParallelBranchFinishedStreamResponse.Data(
+                parallel_id=event.parallel_id,
+                parallel_branch_id=event.parallel_start_node_id,
+                parent_parallel_id=event.parent_parallel_id,
+                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                iteration_id=event.in_iteration_id,
+                status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed',
+                error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
+                created_at=int(time.time()),
+            )
         )
 
-        latest_node_execution_info = NodeExecutionInfo(
-            workflow_node_execution_id=workflow_node_execution.id,
-            node_type=event.node_type,
-            start_at=time.perf_counter()
+    def _workflow_iteration_start_to_stream_response(
+            self,
+            task_id: str,
+            workflow_run: WorkflowRun,
+            event: QueueIterationStartEvent
+        ) -> IterationNodeStartStreamResponse:
+        """
+        Workflow iteration start to stream response
+        :param task_id: task id
+        :param workflow_run: workflow run
+        :param event: iteration start event
+        :return:
+        """
+        return IterationNodeStartStreamResponse(
+            task_id=task_id,
+            workflow_run_id=workflow_run.id,
+            data=IterationNodeStartStreamResponse.Data(
+                id=event.node_id,
+                node_id=event.node_id,
+                node_type=event.node_type.value,
+                title=event.node_data.title,
+                created_at=int(time.time()),
+                extras={},
+                inputs=event.inputs or {},
+                metadata=event.metadata or {},
+                parallel_id=event.parallel_id,
+                parallel_start_node_id=event.parallel_start_node_id,
+            )
         )
 
-        self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
-        self._task_state.latest_node_execution_info = latest_node_execution_info
-
-        self._task_state.total_steps += 1
-
-        db.session.close()
-
-        return workflow_node_execution
-
-    def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
-        current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
-        workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
-            WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
-
-        execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None
-
-        if self._iteration_state and self._iteration_state.current_iterations:
-            if not execution_metadata:
-                execution_metadata = {}
-            current_iteration_data = None
-            for iteration_node_id in self._iteration_state.current_iterations:
-                data = self._iteration_state.current_iterations[iteration_node_id]
-                if data.parent_iteration_id == None:
-                    current_iteration_data = data
-                    break
-
-            if current_iteration_data:
-                execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id
-                execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index
-
-        if isinstance(event, QueueNodeSucceededEvent):
-            workflow_node_execution = self._workflow_node_execution_success(
-                workflow_node_execution=workflow_node_execution,
-                start_at=current_node_execution.start_at,
-                inputs=event.inputs,
-                process_data=event.process_data,
-                outputs=event.outputs,
-                execution_metadata=execution_metadata
+    def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse:
+        """
+        Workflow iteration next to stream response
+        :param task_id: task id
+        :param workflow_run: workflow run
+        :param event: iteration next event
+        :return:
+        """
+        return IterationNodeNextStreamResponse(
+            task_id=task_id,
+            workflow_run_id=workflow_run.id,
+            data=IterationNodeNextStreamResponse.Data(
+                id=event.node_id,
+                node_id=event.node_id,
+                node_type=event.node_type.value,
+                title=event.node_data.title,
+                index=event.index,
+                pre_iteration_output=event.output,
+                created_at=int(time.time()),
+                extras={},
+                parallel_id=event.parallel_id,
+                parallel_start_node_id=event.parallel_start_node_id,
             )
+        )
 
-            if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
-                self._task_state.total_tokens += (
-                    int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
-
-                if self._iteration_state:
-                    for iteration_node_id in self._iteration_state.current_iterations:
-                        data = self._iteration_state.current_iterations[iteration_node_id]
-                        if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
-                            data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))
-
-            if workflow_node_execution.node_type == NodeType.LLM.value:
-                outputs = workflow_node_execution.outputs_dict
-                usage_dict = outputs.get('usage', {})
-                self._task_state.metadata['usage'] = usage_dict
-        else:
-            workflow_node_execution = self._workflow_node_execution_failed(
-                workflow_node_execution=workflow_node_execution,
-                start_at=current_node_execution.start_at,
-                error=event.error,
-                inputs=event.inputs,
-                process_data=event.process_data,
+    def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse:
+        """
+        Workflow iteration completed to stream response
+        :param task_id: task id
+        :param workflow_run: workflow run
+        :param event: iteration completed event
+        :return:
+        """
+        return IterationNodeCompletedStreamResponse(
+            task_id=task_id,
+            workflow_run_id=workflow_run.id,
+            data=IterationNodeCompletedStreamResponse.Data(
+                id=event.node_id,
+                node_id=event.node_id,
+                node_type=event.node_type.value,
+                title=event.node_data.title,
                 outputs=event.outputs,
-                execution_metadata=execution_metadata
+                created_at=int(time.time()),
+                extras={},
+                inputs=event.inputs or {},
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                error=None,
+                elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
+                total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0,
+                execution_metadata=event.metadata,
+                finished_at=int(time.time()),
+                steps=event.steps,
+                parallel_id=event.parallel_id,
+                parallel_start_node_id=event.parallel_start_node_id,
             )
-
-        db.session.close()
-
-        return workflow_node_execution
-
-    def _handle_workflow_finished(
-        self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent,
-        conversation_id: Optional[str] = None,
-        trace_manager: Optional[TraceQueueManager] = None
-    ) -> Optional[WorkflowRun]:
-        workflow_run = db.session.query(WorkflowRun).filter(
-            WorkflowRun.id == self._task_state.workflow_run_id).first()
-        if not workflow_run:
-            return None
-
-        if conversation_id is None:
-            conversation_id = self._application_generate_entity.inputs.get('sys.conversation_id')
-        if isinstance(event, QueueStopEvent):
-            workflow_run = self._workflow_run_failed(
-                workflow_run=workflow_run,
-                total_tokens=self._task_state.total_tokens,
-                total_steps=self._task_state.total_steps,
-                status=WorkflowRunStatus.STOPPED,
-                error='Workflow stopped.',
-                conversation_id=conversation_id,
-                trace_manager=trace_manager
-            )
-
-            latest_node_execution_info = self._task_state.latest_node_execution_info
-            if latest_node_execution_info:
-                workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
-                    WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first()
-                if (workflow_node_execution
-                        and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value):
-                    self._workflow_node_execution_failed(
-                        workflow_node_execution=workflow_node_execution,
-                        start_at=latest_node_execution_info.start_at,
-                        error='Workflow stopped.'
-                    )
-        elif isinstance(event, QueueWorkflowFailedEvent):
-            workflow_run = self._workflow_run_failed(
-                workflow_run=workflow_run,
-                total_tokens=self._task_state.total_tokens,
-                total_steps=self._task_state.total_steps,
-                status=WorkflowRunStatus.FAILED,
-                error=event.error,
-                conversation_id=conversation_id,
-                trace_manager=trace_manager
-            )
-        else:
-            if self._task_state.latest_node_execution_info:
-                workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
-                    WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first()
-                outputs = workflow_node_execution.outputs
-            else:
-                outputs = None
-
-            workflow_run = self._workflow_run_success(
-                workflow_run=workflow_run,
-                total_tokens=self._task_state.total_tokens,
-                total_steps=self._task_state.total_steps,
-                outputs=outputs,
-                conversation_id=conversation_id,
-                trace_manager=trace_manager
-            )
-
-        self._task_state.workflow_run_id = workflow_run.id
-
-        db.session.close()
-
-        return workflow_run
+        )
 
     def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
         """
@@ -647,3 +649,40 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
             return value.to_dict()
 
         return None
+
+    def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
+        """
+        Refetch workflow run
+        :param workflow_run_id: workflow run id
+        :return:
+        """
+        workflow_run = db.session.query(WorkflowRun).filter(
+            WorkflowRun.id == workflow_run_id).first()
+
+        if not workflow_run:
+            raise Exception(f'Workflow run not found: {workflow_run_id}')
+
+        return workflow_run
+
+    def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
+        """
+        Refetch workflow node execution
+        :param node_execution_id: workflow node execution id
+        :return:
+        """
+        workflow_node_execution = (
+            db.session.query(WorkflowNodeExecution)
+            .filter(
+                WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id,
+                WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id,
+                WorkflowNodeExecution.workflow_id == self._workflow.id,
+                WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
+                WorkflowNodeExecution.node_execution_id == node_execution_id,
+            )
+            .first()
+        )
+
+        if not workflow_node_execution:
+            raise Exception(f'Workflow node execution not found: {node_execution_id}')
+
+        return workflow_node_execution

+ 0 - 16
api/core/app/task_pipeline/workflow_cycle_state_manager.py

@@ -1,16 +0,0 @@
-from typing import Any, Union
-
-from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
-from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
-from core.workflow.enums import SystemVariableKey
-from models.account import Account
-from models.model import EndUser
-from models.workflow import Workflow
-
-
-class WorkflowCycleStateManager:
-    _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
-    _workflow: Workflow
-    _user: Union[Account, EndUser]
-    _task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
-    _workflow_system_variables: dict[SystemVariableKey, Any]

+ 0 - 290
api/core/app/task_pipeline/workflow_iteration_cycle_manage.py

@@ -1,290 +0,0 @@
-import json
-import time
-from collections.abc import Generator
-from datetime import datetime, timezone
-from typing import Optional, Union
-
-from core.app.entities.queue_entities import (
-    QueueIterationCompletedEvent,
-    QueueIterationNextEvent,
-    QueueIterationStartEvent,
-)
-from core.app.entities.task_entities import (
-    IterationNodeCompletedStreamResponse,
-    IterationNodeNextStreamResponse,
-    IterationNodeStartStreamResponse,
-    NodeExecutionInfo,
-    WorkflowIterationState,
-)
-from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager
-from core.workflow.entities.node_entities import NodeType
-from core.workflow.workflow_engine_manager import WorkflowEngineManager
-from extensions.ext_database import db
-from models.workflow import (
-    WorkflowNodeExecution,
-    WorkflowNodeExecutionStatus,
-    WorkflowNodeExecutionTriggeredFrom,
-    WorkflowRun,
-)
-
-
-class WorkflowIterationCycleManage(WorkflowCycleStateManager):
-    _iteration_state: WorkflowIterationState = None
-
-    def _init_iteration_state(self) -> WorkflowIterationState:
-        if not self._iteration_state:
-            self._iteration_state = WorkflowIterationState(
-                current_iterations={}
-            )
-
-    def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \
-    -> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]:
-        """
-        Handle iteration to stream response
-        :param task_id: task id
-        :param event: iteration event
-        :return:
-        """
-        if isinstance(event, QueueIterationStartEvent):
-            return IterationNodeStartStreamResponse(
-                task_id=task_id,
-                workflow_run_id=self._task_state.workflow_run_id,
-                data=IterationNodeStartStreamResponse.Data(
-                    id=event.node_id,
-                    node_id=event.node_id,
-                    node_type=event.node_type.value,
-                    title=event.node_data.title,
-                    created_at=int(time.time()),
-                    extras={},
-                    inputs=event.inputs,
-                    metadata=event.metadata
-                )
-            )
-        elif isinstance(event, QueueIterationNextEvent):
-            current_iteration = self._iteration_state.current_iterations[event.node_id]
-
-            return IterationNodeNextStreamResponse(
-                task_id=task_id,
-                workflow_run_id=self._task_state.workflow_run_id,
-                data=IterationNodeNextStreamResponse.Data(
-                    id=event.node_id,
-                    node_id=event.node_id,
-                    node_type=event.node_type.value,
-                    title=current_iteration.node_data.title,
-                    index=event.index,
-                    pre_iteration_output=event.output,
-                    created_at=int(time.time()),
-                    extras={}
-                )
-            )
-        elif isinstance(event, QueueIterationCompletedEvent):
-            current_iteration = self._iteration_state.current_iterations[event.node_id]
-
-            return IterationNodeCompletedStreamResponse(
-                task_id=task_id,
-                workflow_run_id=self._task_state.workflow_run_id,
-                data=IterationNodeCompletedStreamResponse.Data(
-                    id=event.node_id,
-                    node_id=event.node_id,
-                    node_type=event.node_type.value,
-                    title=current_iteration.node_data.title,
-                    outputs=event.outputs,
-                    created_at=int(time.time()),
-                    extras={},
-                    inputs=current_iteration.inputs,
-                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                    error=None,
-                    elapsed_time=time.perf_counter() - current_iteration.started_at,
-                    total_tokens=current_iteration.total_tokens,
-                    execution_metadata={
-                        'total_tokens': current_iteration.total_tokens,
-                    },
-                    finished_at=int(time.time()),
-                    steps=current_iteration.current_index
-                )
-            )
-        
-    def _init_iteration_execution_from_workflow_run(self, 
-        workflow_run: WorkflowRun,
-        node_id: str,
-        node_type: NodeType,
-        node_title: str,
-        node_run_index: int = 1,
-        inputs: Optional[dict] = None,
-        predecessor_node_id: Optional[str] = None
-    ) -> WorkflowNodeExecution:
-        workflow_node_execution = WorkflowNodeExecution(
-            tenant_id=workflow_run.tenant_id,
-            app_id=workflow_run.app_id,
-            workflow_id=workflow_run.workflow_id,
-            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
-            workflow_run_id=workflow_run.id,
-            predecessor_node_id=predecessor_node_id,
-            index=node_run_index,
-            node_id=node_id,
-            node_type=node_type.value,
-            inputs=json.dumps(inputs) if inputs else None,
-            title=node_title,
-            status=WorkflowNodeExecutionStatus.RUNNING.value,
-            created_by_role=workflow_run.created_by_role,
-            created_by=workflow_run.created_by,
-            execution_metadata=json.dumps({
-                'started_run_index': node_run_index + 1,
-                'current_index': 0,
-                'steps_boundary': [],
-            }),
-            created_at=datetime.now(timezone.utc).replace(tzinfo=None)
-        )
-
-        db.session.add(workflow_node_execution)
-        db.session.commit()
-        db.session.refresh(workflow_node_execution)
-        db.session.close()
-
-        return workflow_node_execution
-    
-    def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution:
-        if isinstance(event, QueueIterationStartEvent):
-            return self._handle_iteration_started(event)
-        elif isinstance(event, QueueIterationNextEvent):
-            return self._handle_iteration_next(event)
-        elif isinstance(event, QueueIterationCompletedEvent):
-            return self._handle_iteration_completed(event)
-    
-    def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution:
-        self._init_iteration_state()
-
-        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
-        workflow_node_execution = self._init_iteration_execution_from_workflow_run(
-            workflow_run=workflow_run,
-            node_id=event.node_id,
-            node_type=NodeType.ITERATION,
-            node_title=event.node_data.title,
-            node_run_index=event.node_run_index,
-            inputs=event.inputs,
-            predecessor_node_id=event.predecessor_node_id
-        )
-
-        latest_node_execution_info = NodeExecutionInfo(
-            workflow_node_execution_id=workflow_node_execution.id,
-            node_type=NodeType.ITERATION,
-            start_at=time.perf_counter()
-        )
-
-        self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
-        self._task_state.latest_node_execution_info = latest_node_execution_info
-
-        self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data(
-            parent_iteration_id=None,
-            iteration_id=event.node_id,
-            current_index=0,
-            iteration_steps_boundary=[],
-            node_execution_id=workflow_node_execution.id,
-            started_at=time.perf_counter(),
-            inputs=event.inputs,
-            total_tokens=0,
-            node_data=event.node_data
-        )
-
-        db.session.close()
-
-        return workflow_node_execution
-    
-    def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution:
-        if event.node_id not in self._iteration_state.current_iterations:
-            return
-        current_iteration = self._iteration_state.current_iterations[event.node_id]
-        current_iteration.current_index = event.index
-        current_iteration.iteration_steps_boundary.append(event.node_run_index)
-        workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
-            WorkflowNodeExecution.id == current_iteration.node_execution_id
-        ).first()
-
-        original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
-        if original_node_execution_metadata:
-            original_node_execution_metadata['current_index'] = event.index
-            original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
-            original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
-            workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
-
-            db.session.commit()
-
-        db.session.close()
-
-    def _handle_iteration_completed(self, event: QueueIterationCompletedEvent):
-        if event.node_id not in self._iteration_state.current_iterations:
-            return
-        
-        current_iteration = self._iteration_state.current_iterations[event.node_id]
-        workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
-            WorkflowNodeExecution.id == current_iteration.node_execution_id
-        ).first()
-
-        workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
-        workflow_node_execution.outputs = json.dumps(WorkflowEngineManager.handle_special_values(event.outputs)) if event.outputs else None
-        workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
-
-        original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
-        if original_node_execution_metadata:
-            original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
-            original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
-            workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
-
-        db.session.commit()
-
-        # remove current iteration
-        self._iteration_state.current_iterations.pop(event.node_id, None)
-
-        # set latest node execution info
-        latest_node_execution_info = NodeExecutionInfo(
-            workflow_node_execution_id=workflow_node_execution.id,
-            node_type=NodeType.ITERATION,
-            start_at=time.perf_counter()
-        )
-
-        self._task_state.latest_node_execution_info = latest_node_execution_info
-        
-        db.session.close()
-
-    def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]:
-        """
-        Handle iteration exception
-        """
-        if not self._iteration_state or not self._iteration_state.current_iterations:
-            return
-        
-        for node_id, current_iteration in self._iteration_state.current_iterations.items():
-            workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
-                WorkflowNodeExecution.id == current_iteration.node_execution_id
-            ).first()
-
-            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
-            workflow_node_execution.error = error
-            workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
-
-            db.session.commit()
-            db.session.close()
-
-            yield IterationNodeCompletedStreamResponse(
-                task_id=task_id,
-                workflow_run_id=self._task_state.workflow_run_id,
-                data=IterationNodeCompletedStreamResponse.Data(
-                    id=node_id,
-                    node_id=node_id,
-                    node_type=NodeType.ITERATION.value,
-                    title=current_iteration.node_data.title,
-                    outputs={},
-                    created_at=int(time.time()),
-                    extras={},
-                    inputs=current_iteration.inputs,
-                    status=WorkflowNodeExecutionStatus.FAILED,
-                    error=error,
-                    elapsed_time=time.perf_counter() - current_iteration.started_at,
-                    total_tokens=current_iteration.total_tokens,
-                    execution_metadata={
-                        'total_tokens': current_iteration.total_tokens,
-                    },
-                    finished_at=int(time.time()),
-                    steps=current_iteration.current_index
-                )
-            )

+ 33 - 0
api/core/model_runtime/entities/llm_entities.py

@@ -63,6 +63,39 @@ class LLMUsage(ModelUsage):
             latency=0.0
         )
 
+    def plus(self, other: 'LLMUsage') -> 'LLMUsage':
+        """
+        Add two LLMUsage instances together.
+
+        :param other: Another LLMUsage instance to add
+        :return: A new LLMUsage instance with summed values
+        """
+        if self.total_tokens == 0:
+            return other
+        else:
+            return LLMUsage(
+                prompt_tokens=self.prompt_tokens + other.prompt_tokens,
+                prompt_unit_price=other.prompt_unit_price,
+                prompt_price_unit=other.prompt_price_unit,
+                prompt_price=self.prompt_price + other.prompt_price,
+                completion_tokens=self.completion_tokens + other.completion_tokens,
+                completion_unit_price=other.completion_unit_price,
+                completion_price_unit=other.completion_price_unit,
+                completion_price=self.completion_price + other.completion_price,
+                total_tokens=self.total_tokens + other.total_tokens,
+                total_price=self.total_price + other.total_price,
+                currency=other.currency,
+                latency=self.latency + other.latency
+            )
+
+    def __add__(self, other: 'LLMUsage') -> 'LLMUsage':
+        """
+        Overload the + operator to add two LLMUsage instances.
+
+        :param other: Another LLMUsage instance to add
+        :return: A new LLMUsage instance with summed values
+        """
+        return self.plus(other)
 
 class LLMResult(BaseModel):
     """

+ 4 - 4
api/core/moderation/output_moderation.py

@@ -34,13 +34,13 @@ class OutputModeration(BaseModel):
     final_output: Optional[str] = None
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
-    def should_direct_output(self):
+    def should_direct_output(self) -> bool:
         return self.final_output is not None
 
-    def get_final_output(self):
-        return self.final_output
+    def get_final_output(self) -> str:
+        return self.final_output or ""
 
-    def append_new_token(self, token: str):
+    def append_new_token(self, token: str) -> None:
         self.buffer += token
 
         if not self.thread:

+ 3 - 1
api/core/tools/tool/workflow_tool.py

@@ -1,7 +1,7 @@
 import json
 import logging
 from copy import deepcopy
-from typing import Any, Union
+from typing import Any, Optional, Union
 
 from core.file.file_obj import FileTransferMethod, FileVar
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
@@ -18,6 +18,7 @@ class WorkflowTool(Tool):
     version: str
     workflow_entities: dict[str, Any]
     workflow_call_depth: int
+    thread_pool_id: Optional[str] = None
 
     label: str
 
@@ -57,6 +58,7 @@ class WorkflowTool(Tool):
             invoke_from=self.runtime.invoke_from,
             stream=False,
             call_depth=self.workflow_call_depth + 1,
+            workflow_thread_pool_id=self.thread_pool_id
         )
 
         data = result.get('data', {})

+ 2 - 0
api/core/tools/tool_engine.py

@@ -128,6 +128,7 @@ class ToolEngine:
                         user_id: str,
                         workflow_tool_callback: DifyWorkflowCallbackHandler,
                         workflow_call_depth: int,
+                        thread_pool_id: Optional[str] = None
                         ) -> list[ToolInvokeMessage]:
         """
         Workflow invokes the tool with the given arguments.
@@ -141,6 +142,7 @@ class ToolEngine:
 
             if isinstance(tool, WorkflowTool):
                 tool.workflow_call_depth = workflow_call_depth + 1
+                tool.thread_pool_id = thread_pool_id
 
             if tool.runtime and tool.runtime.runtime_parameters:
                 tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}

+ 1 - 2
api/core/tools/tool_manager.py

@@ -25,7 +25,6 @@ from core.tools.tool.tool import Tool
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
 from core.tools.utils.tool_parameter_converter import ToolParameterConverter
-from core.workflow.nodes.tool.entities import ToolEntity
 from extensions.ext_database import db
 from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
 from services.tools.tools_transform_service import ToolTransformService
@@ -249,7 +248,7 @@ class ToolManager:
         return tool_entity
 
     @classmethod
-    def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
+    def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: "ToolEntity", invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
         """
             get the workflow tool runtime
         """

+ 1 - 0
api/core/tools/utils/message_transformer.py

@@ -7,6 +7,7 @@ from core.tools.tool_file_manager import ToolFileManager
 
 logger = logging.getLogger(__name__)
 
+
 class ToolFileMessageTransformer:
     @classmethod
     def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage],

+ 6 - 107
api/core/workflow/callbacks/base_workflow_callback.py

@@ -1,116 +1,15 @@
 from abc import ABC, abstractmethod
-from typing import Any, Optional
 
-from core.app.entities.queue_entities import AppQueueEvent
-from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeType
+from core.workflow.graph_engine.entities.event import GraphEngineEvent
 
 
 class WorkflowCallback(ABC):
     @abstractmethod
-    def on_workflow_run_started(self) -> None:
+    def on_event(
+            self,
+            event: GraphEngineEvent
+    ) -> None:
         """
-        Workflow run started
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_workflow_run_succeeded(self) -> None:
-        """
-        Workflow run succeeded
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_workflow_run_failed(self, error: str) -> None:
-        """
-        Workflow run failed
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_workflow_node_execute_started(self, node_id: str,
-                                         node_type: NodeType,
-                                         node_data: BaseNodeData,
-                                         node_run_index: int = 1,
-                                         predecessor_node_id: Optional[str] = None) -> None:
-        """
-        Workflow node execute started
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_workflow_node_execute_succeeded(self, node_id: str,
-                                           node_type: NodeType,
-                                           node_data: BaseNodeData,
-                                           inputs: Optional[dict] = None,
-                                           process_data: Optional[dict] = None,
-                                           outputs: Optional[dict] = None,
-                                           execution_metadata: Optional[dict] = None) -> None:
-        """
-        Workflow node execute succeeded
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_workflow_node_execute_failed(self, node_id: str,
-                                        node_type: NodeType,
-                                        node_data: BaseNodeData,
-                                        error: str,
-                                        inputs: Optional[dict] = None,
-                                        outputs: Optional[dict] = None,
-                                        process_data: Optional[dict] = None) -> None:
-        """
-        Workflow node execute failed
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
-        """
-        Publish text chunk
-        """
-        raise NotImplementedError
-    
-    @abstractmethod
-    def on_workflow_iteration_started(self, 
-                                      node_id: str,
-                                      node_type: NodeType,
-                                      node_run_index: int = 1,
-                                      node_data: Optional[BaseNodeData] = None,
-                                      inputs: Optional[dict] = None,
-                                      predecessor_node_id: Optional[str] = None,
-                                      metadata: Optional[dict] = None) -> None:
-        """
-        Publish iteration started
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_workflow_iteration_next(self, node_id: str, 
-                                   node_type: NodeType,
-                                   index: int, 
-                                   node_run_index: int,
-                                   output: Optional[Any],
-                                   ) -> None:
-        """
-        Publish iteration next
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_workflow_iteration_completed(self, node_id: str, 
-                                        node_type: NodeType,
-                                        node_run_index: int,
-                                        outputs: dict) -> None:
-        """
-        Publish iteration completed
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def on_event(self, event: AppQueueEvent) -> None:
-        """
-        Publish event
+        Published event
         """
         raise NotImplementedError

+ 1 - 1
api/core/workflow/entities/base_node_data_entities.py

@@ -9,7 +9,7 @@ class BaseNodeData(ABC, BaseModel):
     desc: Optional[str] = None
 
 class BaseIterationNodeData(BaseNodeData):
-    start_node_id: str
+    start_node_id: Optional[str] = None
 
 class BaseIterationState(BaseModel):
     iteration_node_id: str

+ 30 - 4
api/core/workflow/entities/node_entities.py

@@ -1,9 +1,9 @@
-from collections.abc import Mapping
 from enum import Enum
 from typing import Any, Optional
 
 from pydantic import BaseModel
 
+from core.model_runtime.entities.llm_entities import LLMUsage
 from models import WorkflowNodeExecutionStatus
 
 
@@ -28,6 +28,7 @@ class NodeType(Enum):
     VARIABLE_ASSIGNER = 'variable-assigner'
     LOOP = 'loop'
     ITERATION = 'iteration'
+    ITERATION_START = 'iteration-start'  # fake start node for iteration
     PARAMETER_EXTRACTOR = 'parameter-extractor'
     CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
 
@@ -56,6 +57,10 @@ class NodeRunMetadataKey(Enum):
     TOOL_INFO = 'tool_info'
     ITERATION_ID = 'iteration_id'
     ITERATION_INDEX = 'iteration_index'
+    PARALLEL_ID = 'parallel_id'
+    PARALLEL_START_NODE_ID = 'parallel_start_node_id'
+    PARENT_PARALLEL_ID = 'parent_parallel_id'
+    PARENT_PARALLEL_START_NODE_ID = 'parent_parallel_start_node_id'
 
 
 class NodeRunResult(BaseModel):
@@ -65,11 +70,32 @@ class NodeRunResult(BaseModel):
 
     status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
 
-    inputs: Optional[Mapping[str, Any]] = None  # node inputs
-    process_data: Optional[dict] = None  # process data
-    outputs: Optional[Mapping[str, Any]] = None  # node outputs
+    inputs: Optional[dict[str, Any]] = None  # node inputs
+    process_data: Optional[dict[str, Any]] = None  # process data
+    outputs: Optional[dict[str, Any]] = None  # node outputs
     metadata: Optional[dict[NodeRunMetadataKey, Any]] = None  # node metadata
+    llm_usage: Optional[LLMUsage] = None  # llm usage
 
     edge_source_handle: Optional[str] = None  # source handle id of node with multiple branches
 
     error: Optional[str] = None  # error message if status is failed
+
+
+class UserFrom(Enum):
+    """
+    User from
+    """
+    ACCOUNT = "account"
+    END_USER = "end-user"
+
+    @classmethod
+    def value_of(cls, value: str) -> "UserFrom":
+        """
+        Value of
+        :param value: value
+        :return:
+        """
+        for item in cls:
+            if item.value == value:
+                return item
+        raise ValueError(f"Invalid value: {value}")

+ 54 - 32
api/core/workflow/entities/variable_pool.py

@@ -2,6 +2,7 @@ from collections import defaultdict
 from collections.abc import Mapping, Sequence
 from typing import Any, Union
 
+from pydantic import BaseModel, Field, model_validator
 from typing_extensions import deprecated
 
 from core.app.segments import Segment, Variable, factory
@@ -16,43 +17,52 @@ ENVIRONMENT_VARIABLE_NODE_ID = "env"
 CONVERSATION_VARIABLE_NODE_ID = "conversation"
 
 
-class VariablePool:
-    def __init__(
-        self,
-        system_variables: Mapping[SystemVariableKey, Any],
-        user_inputs: Mapping[str, Any],
-        environment_variables: Sequence[Variable],
-        conversation_variables: Sequence[Variable] | None = None,
-    ) -> None:
-        # system variables
-        # for example:
-        # {
-        #     'query': 'abc',
-        #     'files': []
-        # }
-
-        # Variable dictionary is a dictionary for looking up variables by their selector.
-        # The first element of the selector is the node id, it's the first-level key in the dictionary.
-        # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
-        # elements of the selector except the first one.
-        self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict)
-
-        # TODO: This user inputs is not used for pool.
-        self.user_inputs = user_inputs
+class VariablePool(BaseModel):
+    # Variable dictionary is a dictionary for looking up variables by their selector.
+    # The first element of the selector is the node id, it's the first-level key in the dictionary.
+    # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
+    # elements of the selector except the first one.
+    variable_dictionary: dict[str, dict[int, Segment]] = Field(
+        description='Variables mapping',
+        default=defaultdict(dict)
+    )
 
+    # TODO: This user inputs is not used for pool.
+    user_inputs: Mapping[str, Any] = Field(
+        description='User inputs',
+    )
+
+    system_variables: Mapping[SystemVariableKey, Any] = Field(
+        description='System variables',
+    )
+
+    environment_variables: Sequence[Variable] = Field(
+        description="Environment variables.",
+        default_factory=list
+    )
+
+    conversation_variables: Sequence[Variable] | None = None
+
+    @model_validator(mode="after")
+    def val_model_after(self):
+        """
+        Append system variables
+        :return:
+        """
         # Add system variables to the variable pool
-        self.system_variables = system_variables
-        for key, value in system_variables.items():
+        for key, value in self.system_variables.items():
             self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
 
         # Add environment variables to the variable pool
-        for var in environment_variables:
+        for var in self.environment_variables or []:
             self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
 
         # Add conversation variables to the variable pool
-        for var in conversation_variables or []:
+        for var in self.conversation_variables or []:
             self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
 
+        return self
+
     def add(self, selector: Sequence[str], value: Any, /) -> None:
         """
         Adds a variable to the variable pool.
@@ -79,7 +89,7 @@ class VariablePool:
             v = factory.build_segment(value)
 
         hash_key = hash(tuple(selector[1:]))
-        self._variable_dictionary[selector[0]][hash_key] = v
+        self.variable_dictionary[selector[0]][hash_key] = v
 
     def get(self, selector: Sequence[str], /) -> Segment | None:
         """
@@ -97,7 +107,7 @@ class VariablePool:
         if len(selector) < 2:
             raise ValueError("Invalid selector")
         hash_key = hash(tuple(selector[1:]))
-        value = self._variable_dictionary[selector[0]].get(hash_key)
+        value = self.variable_dictionary[selector[0]].get(hash_key)
 
         return value
 
@@ -118,7 +128,7 @@ class VariablePool:
         if len(selector) < 2:
             raise ValueError("Invalid selector")
         hash_key = hash(tuple(selector[1:]))
-        value = self._variable_dictionary[selector[0]].get(hash_key)
+        value = self.variable_dictionary[selector[0]].get(hash_key)
         return value.to_object() if value else None
 
     def remove(self, selector: Sequence[str], /):
@@ -134,7 +144,19 @@ class VariablePool:
         if not selector:
             return
         if len(selector) == 1:
-            self._variable_dictionary[selector[0]] = {}
+            self.variable_dictionary[selector[0]] = {}
             return
         hash_key = hash(tuple(selector[1:]))
-        self._variable_dictionary[selector[0]].pop(hash_key, None)
+        self.variable_dictionary[selector[0]].pop(hash_key, None)
+
+    def remove_node(self, node_id: str, /):
+        """
+        Remove all variables associated with a given node id.
+
+        Args:
+            node_id (str): The node id to remove.
+
+        Returns:
+            None
+        """
+        self.variable_dictionary.pop(node_id, None)

+ 2 - 3
api/core/workflow/entities/workflow_entities.py

@@ -66,8 +66,7 @@ class WorkflowRunState:
         self.variable_pool = variable_pool
 
         self.total_tokens = 0
-        self.workflow_nodes_and_results = []
 
-        self.current_iteration_state = None
         self.workflow_node_steps = 1
-        self.workflow_node_runs = []
+        self.workflow_node_runs = []
+        self.current_iteration_state = None

+ 4 - 6
api/core/workflow/errors.py

@@ -1,10 +1,8 @@
-from core.workflow.entities.node_entities import NodeType
+from core.workflow.nodes.base_node import BaseNode
 
 
 class WorkflowNodeRunFailedError(Exception):
-    def __init__(self, node_id: str, node_type: NodeType, node_title: str, error: str):
-        self.node_id = node_id
-        self.node_type = node_type
-        self.node_title = node_title
+    def __init__(self, node_instance: BaseNode, error: str):
+        self.node_instance = node_instance
         self.error = error
-        super().__init__(f"Node {node_title} run failed: {error}")
+        super().__init__(f"Node {node_instance.node_data.title} run failed: {error}")

+ 0 - 0
api/core/workflow/graph_engine/__init__.py


+ 0 - 0
api/core/workflow/graph_engine/condition_handlers/__init__.py


+ 31 - 0
api/core/workflow/graph_engine/condition_handlers/base_handler.py

@@ -0,0 +1,31 @@
+from abc import ABC, abstractmethod
+
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.graph_engine.entities.run_condition import RunCondition
+from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
+
+
+class RunConditionHandler(ABC):
+    def __init__(self,
+                 init_params: GraphInitParams,
+                 graph: Graph,
+                 condition: RunCondition):
+        self.init_params = init_params
+        self.graph = graph
+        self.condition = condition
+
+    @abstractmethod
+    def check(self,
+              graph_runtime_state: GraphRuntimeState,
+              previous_route_node_state: RouteNodeState
+        ) -> bool:
+        """
+        Check if the condition can be executed
+
+        :param graph_runtime_state: graph runtime state
+        :param previous_route_node_state: previous route node state
+        :return: bool
+        """
+        raise NotImplementedError

+ 28 - 0
api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py

@@ -0,0 +1,28 @@
+from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
+
+
+class BranchIdentifyRunConditionHandler(RunConditionHandler):
+
+    def check(self,
+              graph_runtime_state: GraphRuntimeState,
+              previous_route_node_state: RouteNodeState) -> bool:
+        """
+        Check if the condition can be executed
+
+        :param graph_runtime_state: graph runtime state
+        :param previous_route_node_state: previous route node state
+        :return: bool
+        """
+        if not self.condition.branch_identify:
+            raise Exception("Branch identify is required")
+
+        run_result = previous_route_node_state.node_run_result
+        if not run_result:
+            return False
+
+        if not run_result.edge_source_handle:
+            return False
+
+        return self.condition.branch_identify == run_result.edge_source_handle

+ 32 - 0
api/core/workflow/graph_engine/condition_handlers/condition_handler.py

@@ -0,0 +1,32 @@
+from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
+from core.workflow.utils.condition.processor import ConditionProcessor
+
+
+class ConditionRunConditionHandlerHandler(RunConditionHandler):
+    def check(self,
+              graph_runtime_state: GraphRuntimeState,
+              previous_route_node_state: RouteNodeState
+        ) -> bool:
+        """
+        Check if the condition can be executed
+
+        :param graph_runtime_state: graph runtime state
+        :param previous_route_node_state: previous route node state
+        :return: bool
+        """
+        if not self.condition.conditions:
+            return True
+
+        # process condition
+        condition_processor = ConditionProcessor()
+        input_conditions, group_result = condition_processor.process_conditions(
+            variable_pool=graph_runtime_state.variable_pool,
+            conditions=self.condition.conditions
+        )
+
+        # Apply the logical operator for the current case
+        compare_result = all(group_result)
+
+        return compare_result

+ 35 - 0
api/core/workflow/graph_engine/condition_handlers/condition_manager.py

@@ -0,0 +1,35 @@
+from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
+from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
+from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.run_condition import RunCondition
+
+
+class ConditionManager:
+    @staticmethod
+    def get_condition_handler(
+            init_params: GraphInitParams,
+            graph: Graph,
+            run_condition: RunCondition
+    ) -> RunConditionHandler:
+        """
+        Get condition handler
+
+        :param init_params: init params
+        :param graph: graph
+        :param run_condition: run condition
+        :return: condition handler
+        """
+        if run_condition.type == "branch_identify":
+            return BranchIdentifyRunConditionHandler(
+                init_params=init_params,
+                graph=graph,
+                condition=run_condition
+            )
+        else:
+            return ConditionRunConditionHandlerHandler(
+                init_params=init_params,
+                graph=graph,
+                condition=run_condition
+            )

+ 0 - 0
api/core/workflow/graph_engine/entities/__init__.py


+ 163 - 0
api/core/workflow/graph_engine/entities/event.py

@@ -0,0 +1,163 @@
+from datetime import datetime
+from typing import Any, Optional
+
+from pydantic import BaseModel, Field
+
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
+
+
+class GraphEngineEvent(BaseModel):
+    pass
+
+
+###########################################
+# Graph Events
+###########################################
+
+
+class BaseGraphEvent(GraphEngineEvent):
+    pass
+
+
+class GraphRunStartedEvent(BaseGraphEvent):
+    pass
+
+
+class GraphRunSucceededEvent(BaseGraphEvent):
+    outputs: Optional[dict[str, Any]] = None
+    """outputs"""
+
+
+class GraphRunFailedEvent(BaseGraphEvent):
+    error: str = Field(..., description="failed reason")
+
+
+###########################################
+# Node Events
+###########################################
+
+
+class BaseNodeEvent(GraphEngineEvent):
+    id: str = Field(..., description="node execution id")
+    node_id: str = Field(..., description="node id")
+    node_type: NodeType = Field(..., description="node type")
+    node_data: BaseNodeData = Field(..., description="node data")
+    route_node_state: RouteNodeState = Field(..., description="route node state")
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+
+
+class NodeRunStartedEvent(BaseNodeEvent):
+    predecessor_node_id: Optional[str] = None
+    """predecessor node id"""
+
+
+class NodeRunStreamChunkEvent(BaseNodeEvent):
+    chunk_content: str = Field(..., description="chunk content")
+    from_variable_selector: Optional[list[str]] = None
+    """from variable selector"""
+
+
+class NodeRunRetrieverResourceEvent(BaseNodeEvent):
+    retriever_resources: list[dict] = Field(..., description="retriever resources")
+    context: str = Field(..., description="context")
+
+
+class NodeRunSucceededEvent(BaseNodeEvent):
+    pass
+
+
+class NodeRunFailedEvent(BaseNodeEvent):
+    error: str = Field(..., description="error")
+
+
+###########################################
+# Parallel Branch Events
+###########################################
+
+
+class BaseParallelBranchEvent(GraphEngineEvent):
+    parallel_id: str = Field(..., description="parallel id")
+    """parallel id"""
+    parallel_start_node_id: str = Field(..., description="parallel start node id")
+    """parallel start node id"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+
+
+class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
+    pass
+
+
+class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
+    pass
+
+
+class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
+    error: str = Field(..., description="failed reason")
+
+
+###########################################
+# Iteration Events
+###########################################
+
+
+class BaseIterationEvent(GraphEngineEvent):
+    iteration_id: str = Field(..., description="iteration node execution id")
+    iteration_node_id: str = Field(..., description="iteration node id")
+    iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
+    iteration_node_data: BaseNodeData = Field(..., description="node data")
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+
+
+class IterationRunStartedEvent(BaseIterationEvent):
+    start_at: datetime = Field(..., description="start at")
+    inputs: Optional[dict[str, Any]] = None
+    metadata: Optional[dict[str, Any]] = None
+    predecessor_node_id: Optional[str] = None
+
+
+class IterationRunNextEvent(BaseIterationEvent):
+    index: int = Field(..., description="index")
+    pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output")
+
+
+class IterationRunSucceededEvent(BaseIterationEvent):
+    start_at: datetime = Field(..., description="start at")
+    inputs: Optional[dict[str, Any]] = None
+    outputs: Optional[dict[str, Any]] = None
+    metadata: Optional[dict[str, Any]] = None
+    steps: int = 0
+
+
+class IterationRunFailedEvent(BaseIterationEvent):
+    start_at: datetime = Field(..., description="start at")
+    inputs: Optional[dict[str, Any]] = None
+    outputs: Optional[dict[str, Any]] = None
+    metadata: Optional[dict[str, Any]] = None
+    steps: int = 0
+    error: str = Field(..., description="failed reason")
+
+
+InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent

+ 692 - 0
api/core/workflow/graph_engine/entities/graph.py

@@ -0,0 +1,692 @@
+import uuid
+from collections.abc import Mapping
+from typing import Any, Optional, cast
+
+from pydantic import BaseModel, Field
+
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.graph_engine.entities.run_condition import RunCondition
+from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
+from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
+from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter
+from core.workflow.nodes.end.entities import EndStreamParam
+
+
+class GraphEdge(BaseModel):
+    source_node_id: str = Field(..., description="source node id")
+    target_node_id: str = Field(..., description="target node id")
+    run_condition: Optional[RunCondition] = None
+    """run condition"""
+
+
+class GraphParallel(BaseModel):
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
+    start_from_node_id: str = Field(..., description="start from node id")
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id"""
+    end_to_node_id: Optional[str] = None
+    """end to node id"""
+
+
+class Graph(BaseModel):
+    root_node_id: str = Field(..., description="root node id of the graph")
+    node_ids: list[str] = Field(default_factory=list, description="graph node ids")
+    node_id_config_mapping: dict[str, dict] = Field(
+        default_factory=list,
+        description="node configs mapping (node id: node config)"
+    )
+    edge_mapping: dict[str, list[GraphEdge]] = Field(
+        default_factory=dict,
+        description="graph edge mapping (source node id: edges)"
+    )
+    reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
+        default_factory=dict,
+        description="reverse graph edge mapping (target node id: edges)"
+    )
+    parallel_mapping: dict[str, GraphParallel] = Field(
+        default_factory=dict,
+        description="graph parallel mapping (parallel id: parallel)"
+    )
+    node_parallel_mapping: dict[str, str] = Field(
+        default_factory=dict,
+        description="graph node parallel mapping (node id: parallel id)"
+    )
+    answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(
+        ...,
+        description="answer stream generate routes"
+    )
+    end_stream_param: EndStreamParam = Field(
+        ...,
+        description="end stream param"
+    )
+
+    @classmethod
+    def init(cls,
+             graph_config: Mapping[str, Any],
+             root_node_id: Optional[str] = None) -> "Graph":
+        """
+        Init graph
+
+        :param graph_config: graph config
+        :param root_node_id: root node id
+        :return: graph
+        """
+        # edge configs
+        edge_configs = graph_config.get('edges')
+        if edge_configs is None:
+            edge_configs = []
+
+        edge_configs = cast(list, edge_configs)
+
+        # reorganize edges mapping
+        edge_mapping: dict[str, list[GraphEdge]] = {}
+        reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
+        target_edge_ids = set()
+        for edge_config in edge_configs:
+            source_node_id = edge_config.get('source')
+            if not source_node_id:
+                continue
+
+            if source_node_id not in edge_mapping:
+                edge_mapping[source_node_id] = []
+
+            target_node_id = edge_config.get('target')
+            if not target_node_id:
+                continue
+
+            if target_node_id not in reverse_edge_mapping:
+                reverse_edge_mapping[target_node_id] = []
+
+            # is target node id in source node id edge mapping
+            if any(graph_edge.target_node_id == target_node_id for graph_edge in edge_mapping[source_node_id]):
+                continue
+
+            target_edge_ids.add(target_node_id)
+
+            # parse run condition
+            run_condition = None
+            if edge_config.get('sourceHandle') and edge_config.get('sourceHandle') != 'source':
+                run_condition = RunCondition(
+                    type='branch_identify',
+                    branch_identify=edge_config.get('sourceHandle')
+                )
+
+            graph_edge = GraphEdge(
+                source_node_id=source_node_id,
+                target_node_id=target_node_id,
+                run_condition=run_condition
+            )
+
+            edge_mapping[source_node_id].append(graph_edge)
+            reverse_edge_mapping[target_node_id].append(graph_edge)
+
+        # node configs
+        node_configs = graph_config.get('nodes')
+        if not node_configs:
+            raise ValueError("Graph must have at least one node")
+
+        node_configs = cast(list, node_configs)
+
+        # fetch nodes that have no predecessor node
+        root_node_configs = []
+        all_node_id_config_mapping: dict[str, dict] = {}
+        for node_config in node_configs:
+            node_id = node_config.get('id')
+            if not node_id:
+                continue
+
+            if node_id not in target_edge_ids:
+                root_node_configs.append(node_config)
+
+            all_node_id_config_mapping[node_id] = node_config
+
+        root_node_ids = [node_config.get('id') for node_config in root_node_configs]
+
+        # fetch root node
+        if not root_node_id:
+            # if no root node id, use the START type node as root node
+            root_node_id = next((node_config.get("id") for node_config in root_node_configs
+                                 if node_config.get('data', {}).get('type', '') == NodeType.START.value), None)
+
+        if not root_node_id or root_node_id not in root_node_ids:
+            raise ValueError(f"Root node id {root_node_id} not found in the graph")
+        
+        # Check whether it is connected to the previous node
+        cls._check_connected_to_previous_node(
+            route=[root_node_id],
+            edge_mapping=edge_mapping
+        )
+
+        # fetch all node ids from root node
+        node_ids = [root_node_id]
+        cls._recursively_add_node_ids(
+            node_ids=node_ids,
+            edge_mapping=edge_mapping,
+            node_id=root_node_id
+        )
+
+        node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}
+
+        # init parallel mapping
+        parallel_mapping: dict[str, GraphParallel] = {}
+        node_parallel_mapping: dict[str, str] = {}
+        cls._recursively_add_parallels(
+            edge_mapping=edge_mapping,
+            reverse_edge_mapping=reverse_edge_mapping,
+            start_node_id=root_node_id,
+            parallel_mapping=parallel_mapping,
+            node_parallel_mapping=node_parallel_mapping
+        )
+
+        # Check if it exceeds N layers of parallel
+        for parallel in parallel_mapping.values():
+            if parallel.parent_parallel_id:
+                cls._check_exceed_parallel_limit(
+                    parallel_mapping=parallel_mapping,
+                    level_limit=3,
+                    parent_parallel_id=parallel.parent_parallel_id
+                )
+
+        # init answer stream generate routes
+        answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
+            node_id_config_mapping=node_id_config_mapping,
+            reverse_edge_mapping=reverse_edge_mapping
+        )
+
+        # init end stream param
+        end_stream_param = EndStreamGeneratorRouter.init(
+            node_id_config_mapping=node_id_config_mapping,
+            reverse_edge_mapping=reverse_edge_mapping,
+            node_parallel_mapping=node_parallel_mapping
+        )
+
+        # init graph
+        graph = cls(
+            root_node_id=root_node_id,
+            node_ids=node_ids,
+            node_id_config_mapping=node_id_config_mapping,
+            edge_mapping=edge_mapping,
+            reverse_edge_mapping=reverse_edge_mapping,
+            parallel_mapping=parallel_mapping,
+            node_parallel_mapping=node_parallel_mapping,
+            answer_stream_generate_routes=answer_stream_generate_routes,
+            end_stream_param=end_stream_param
+        )
+
+        return graph
+
+    def add_extra_edge(self, source_node_id: str,
+                       target_node_id: str,
+                       run_condition: Optional[RunCondition] = None) -> None:
+        """
+        Add extra edge to the graph
+
+        :param source_node_id: source node id
+        :param target_node_id: target node id
+        :param run_condition: run condition
+        """
+        if source_node_id not in self.node_ids or target_node_id not in self.node_ids:
+            return
+
+        if source_node_id not in self.edge_mapping:
+            self.edge_mapping[source_node_id] = []
+
+        if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]:
+            return
+
+        graph_edge = GraphEdge(
+            source_node_id=source_node_id,
+            target_node_id=target_node_id,
+            run_condition=run_condition
+        )
+
+        self.edge_mapping[source_node_id].append(graph_edge)
+
+    def get_leaf_node_ids(self) -> list[str]:
+        """
+        Get leaf node ids of the graph
+
+        :return: leaf node ids
+        """
+        leaf_node_ids = []
+        for node_id in self.node_ids:
+            if node_id not in self.edge_mapping:
+                leaf_node_ids.append(node_id)
+            elif (len(self.edge_mapping[node_id]) == 1
+                  and self.edge_mapping[node_id][0].target_node_id == self.root_node_id):
+                leaf_node_ids.append(node_id)
+
+        return leaf_node_ids
+
+    @classmethod
+    def _recursively_add_node_ids(cls,
+                                  node_ids: list[str],
+                                  edge_mapping: dict[str, list[GraphEdge]],
+                                  node_id: str) -> None:
+        """
+        Recursively add node ids
+
+        :param node_ids: node ids
+        :param edge_mapping: edge mapping
+        :param node_id: node id
+        """
+        for graph_edge in edge_mapping.get(node_id, []):
+            if graph_edge.target_node_id in node_ids:
+                continue
+
+            node_ids.append(graph_edge.target_node_id)
+            cls._recursively_add_node_ids(
+                node_ids=node_ids,
+                edge_mapping=edge_mapping,
+                node_id=graph_edge.target_node_id
+            )
+
+    @classmethod
+    def _check_connected_to_previous_node(
+        cls, 
+        route: list[str],
+        edge_mapping: dict[str, list[GraphEdge]]
+    ) -> None:
+        """
+        Check whether it is connected to the previous node
+        """
+        last_node_id = route[-1]
+
+        for graph_edge in edge_mapping.get(last_node_id, []):
+            if not graph_edge.target_node_id:
+                continue
+
+            if graph_edge.target_node_id in route:
+                raise ValueError(f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph.")
+
+            new_route = route[:]
+            new_route.append(graph_edge.target_node_id)
+            cls._check_connected_to_previous_node(
+                route=new_route,
+                edge_mapping=edge_mapping,
+            )
+
+    @classmethod
+    def _recursively_add_parallels(
+        cls,
+        edge_mapping: dict[str, list[GraphEdge]],
+        reverse_edge_mapping: dict[str, list[GraphEdge]],
+        start_node_id: str,
+        parallel_mapping: dict[str, GraphParallel],
+        node_parallel_mapping: dict[str, str],
+        parent_parallel: Optional[GraphParallel] = None
+    ) -> None:
+        """
+        Recursively add parallel ids
+
+        :param edge_mapping: edge mapping
+        :param start_node_id: start from node id
+        :param parallel_mapping: parallel mapping
+        :param node_parallel_mapping: node parallel mapping
+        :param parent_parallel: parent parallel
+        """
+        target_node_edges = edge_mapping.get(start_node_id, [])
+        parallel = None
+        if len(target_node_edges) > 1:
+            # fetch all node ids in current parallels
+            parallel_branch_node_ids = []
+            condition_edge_mappings = {}
+            for graph_edge in target_node_edges:
+                if graph_edge.run_condition is None:
+                    parallel_branch_node_ids.append(graph_edge.target_node_id)
+                else:
+                    condition_hash = graph_edge.run_condition.hash
+                    if not condition_hash in condition_edge_mappings:
+                        condition_edge_mappings[condition_hash] = []
+
+                    condition_edge_mappings[condition_hash].append(graph_edge)
+
+            for _, graph_edges in condition_edge_mappings.items():
+                if len(graph_edges) > 1:
+                    for graph_edge in graph_edges:
+                        parallel_branch_node_ids.append(graph_edge.target_node_id)
+
+            # any target node id in node_parallel_mapping
+            if parallel_branch_node_ids:
+                parent_parallel_id = parent_parallel.id if parent_parallel else None
+
+                parallel = GraphParallel(
+                    start_from_node_id=start_node_id,
+                    parent_parallel_id=parent_parallel.id if parent_parallel else None,
+                    parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None
+                )
+                parallel_mapping[parallel.id] = parallel
+
+                in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
+                    edge_mapping=edge_mapping,
+                    reverse_edge_mapping=reverse_edge_mapping,
+                    parallel_branch_node_ids=parallel_branch_node_ids
+                )
+
+                # collect all branches node ids
+                parallel_node_ids = []
+                for _, node_ids in in_branch_node_ids.items():
+                    for node_id in node_ids:
+                        in_parent_parallel = True
+                        if parent_parallel_id:
+                            in_parent_parallel = False
+                            for parallel_node_id, parallel_id in node_parallel_mapping.items():
+                                if parallel_id == parent_parallel_id and parallel_node_id == node_id:
+                                    in_parent_parallel = True
+                                    break
+
+                        if in_parent_parallel:
+                            parallel_node_ids.append(node_id)
+                            node_parallel_mapping[node_id] = parallel.id
+
+                outside_parallel_target_node_ids = set()
+                for node_id in parallel_node_ids:
+                    if node_id == parallel.start_from_node_id:
+                        continue
+
+                    node_edges = edge_mapping.get(node_id)
+                    if not node_edges:
+                        continue
+
+                    if len(node_edges) > 1:
+                        continue
+
+                    target_node_id = node_edges[0].target_node_id
+                    if target_node_id in parallel_node_ids:
+                        continue
+
+                    if parent_parallel_id:
+                        parent_parallel = parallel_mapping.get(parent_parallel_id)
+                        if not parent_parallel:
+                            continue
+
+                    if (
+                        (node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id)
+                        or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id)
+                        or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
+                    ):
+                        outside_parallel_target_node_ids.add(target_node_id)
+
+                if len(outside_parallel_target_node_ids) == 1:
+                    if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id:
+                        parallel.end_to_node_id = None
+                    else:
+                        parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
+
+        for graph_edge in target_node_edges:
+            current_parallel = None
+            if parallel:
+                current_parallel = parallel
+            elif parent_parallel:
+                if not parent_parallel.end_to_node_id or (parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id):
+                    current_parallel = parent_parallel
+                else:
+                    # fetch parent parallel's parent parallel
+                    parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
+                    if parent_parallel_parent_parallel_id:
+                        parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
+                        if (
+                            parent_parallel_parent_parallel 
+                            and (
+                                not parent_parallel_parent_parallel.end_to_node_id
+                                 or (parent_parallel_parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id)
+                            )
+                        ):
+                            current_parallel = parent_parallel_parent_parallel
+
+            cls._recursively_add_parallels(
+                edge_mapping=edge_mapping,
+                reverse_edge_mapping=reverse_edge_mapping,
+                start_node_id=graph_edge.target_node_id,
+                parallel_mapping=parallel_mapping,
+                node_parallel_mapping=node_parallel_mapping,
+                parent_parallel=current_parallel
+            )
+
+    @classmethod
+    def _check_exceed_parallel_limit(
+        cls,
+        parallel_mapping: dict[str, GraphParallel],
+        level_limit: int,
+        parent_parallel_id: str,
+        current_level: int = 1
+    ) -> None:
+        """
+        Check if it exceeds N layers of parallel
+        """
+        parent_parallel = parallel_mapping.get(parent_parallel_id)
+        if not parent_parallel:
+            return
+        
+        current_level += 1
+        if current_level > level_limit:
+            raise ValueError(f"Exceeds {level_limit} layers of parallel")
+        
+        if parent_parallel.parent_parallel_id:
+            cls._check_exceed_parallel_limit(
+                parallel_mapping=parallel_mapping,
+                level_limit=level_limit,
+                parent_parallel_id=parent_parallel.parent_parallel_id,
+                current_level=current_level
+            )
+
+    @classmethod
+    def _recursively_add_parallel_node_ids(cls,
+                                           branch_node_ids: list[str],
+                                           edge_mapping: dict[str, list[GraphEdge]],
+                                           merge_node_id: str,
+                                           start_node_id: str) -> None:
+        """
+        Recursively add node ids
+
+        :param branch_node_ids: in branch node ids
+        :param edge_mapping: edge mapping
+        :param merge_node_id: merge node id
+        :param start_node_id: start node id
+        """
+        for graph_edge in edge_mapping.get(start_node_id, []):
+            if (graph_edge.target_node_id != merge_node_id
+                    and graph_edge.target_node_id not in branch_node_ids):
+                branch_node_ids.append(graph_edge.target_node_id)
+                cls._recursively_add_parallel_node_ids(
+                    branch_node_ids=branch_node_ids,
+                    edge_mapping=edge_mapping,
+                    merge_node_id=merge_node_id,
+                    start_node_id=graph_edge.target_node_id
+                )
+
+    @classmethod
+    def _fetch_all_node_ids_in_parallels(cls,
+                                         edge_mapping: dict[str, list[GraphEdge]],
+                                         reverse_edge_mapping: dict[str, list[GraphEdge]],
+                                         parallel_branch_node_ids: list[str]) -> dict[str, list[str]]:
+        """
+        Fetch all node ids in parallels
+        """
+        routes_node_ids: dict[str, list[str]] = {}
+        for parallel_branch_node_id in parallel_branch_node_ids:
+            routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id]
+
+            # fetch routes node ids
+            cls._recursively_fetch_routes(
+                edge_mapping=edge_mapping,
+                start_node_id=parallel_branch_node_id,
+                routes_node_ids=routes_node_ids[parallel_branch_node_id]
+            )
+
+        # fetch leaf node ids from routes node ids
+        leaf_node_ids: dict[str, list[str]] = {}
+        merge_branch_node_ids: dict[str, list[str]] = {}
+        for branch_node_id, node_ids in routes_node_ids.items():
+            for node_id in node_ids:
+                if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0:
+                    if branch_node_id not in leaf_node_ids:
+                        leaf_node_ids[branch_node_id] = []
+
+                    leaf_node_ids[branch_node_id].append(node_id)
+
+                for branch_node_id2, inner_route2 in routes_node_ids.items():
+                    if (
+                        branch_node_id != branch_node_id2 
+                        and node_id in inner_route2
+                        and len(reverse_edge_mapping.get(node_id, [])) > 1
+                        and cls._is_node_in_routes(
+                            reverse_edge_mapping=reverse_edge_mapping,
+                            start_node_id=node_id,
+                            routes_node_ids=routes_node_ids
+                        )
+                    ):
+                        if node_id not in merge_branch_node_ids:
+                            merge_branch_node_ids[node_id] = []
+
+                        if branch_node_id2 not in merge_branch_node_ids[node_id]:
+                            merge_branch_node_ids[node_id].append(branch_node_id2)
+
+        # sorted merge_branch_node_ids by branch_node_ids length desc
+        merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))
+
+        duplicate_end_node_ids = {}
+        for node_id, branch_node_ids in merge_branch_node_ids.items():
+            for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
+                if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
+                    if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids:
+                        duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids
+                
+        for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
+            # check which node is after
+            if cls._is_node2_after_node1(
+                node1_id=node_id,
+                node2_id=node_id2,
+                edge_mapping=edge_mapping
+            ):
+                if node_id in merge_branch_node_ids:
+                    del merge_branch_node_ids[node_id2]
+            elif cls._is_node2_after_node1(
+                node1_id=node_id2,
+                node2_id=node_id,
+                edge_mapping=edge_mapping
+            ):
+                if node_id2 in merge_branch_node_ids:
+                    del merge_branch_node_ids[node_id]
+
+        branches_merge_node_ids: dict[str, str] = {}
+        for node_id, branch_node_ids in merge_branch_node_ids.items():
+            if len(branch_node_ids) <= 1:
+                continue
+
+            for branch_node_id in branch_node_ids:
+                if branch_node_id in branches_merge_node_ids:
+                    continue
+
+                branches_merge_node_ids[branch_node_id] = node_id
+
+        in_branch_node_ids: dict[str, list[str]] = {}
+        for branch_node_id, node_ids in routes_node_ids.items():
+            in_branch_node_ids[branch_node_id] = []
+            if branch_node_id not in branches_merge_node_ids:
+                # all node ids in current branch is in this thread
+                in_branch_node_ids[branch_node_id].append(branch_node_id)
+                in_branch_node_ids[branch_node_id].extend(node_ids)
+            else:
+                merge_node_id = branches_merge_node_ids[branch_node_id]
+                if merge_node_id != branch_node_id:
+                    in_branch_node_ids[branch_node_id].append(branch_node_id)
+
+                # fetch all node ids from branch_node_id and merge_node_id
+                cls._recursively_add_parallel_node_ids(
+                    branch_node_ids=in_branch_node_ids[branch_node_id],
+                    edge_mapping=edge_mapping,
+                    merge_node_id=merge_node_id,
+                    start_node_id=branch_node_id
+                )
+
+        return in_branch_node_ids
+
+    @classmethod
+    def _recursively_fetch_routes(cls,
+                                  edge_mapping: dict[str, list[GraphEdge]],
+                                  start_node_id: str,
+                                  routes_node_ids: list[str]) -> None:
+        """
+        Recursively fetch route
+        """
+        if start_node_id not in edge_mapping:
+            return
+
+        for graph_edge in edge_mapping[start_node_id]:
+            # find next node ids
+            if graph_edge.target_node_id not in routes_node_ids:
+                routes_node_ids.append(graph_edge.target_node_id)
+
+                cls._recursively_fetch_routes(
+                    edge_mapping=edge_mapping,
+                    start_node_id=graph_edge.target_node_id,
+                    routes_node_ids=routes_node_ids
+                )
+
+    @classmethod
+    def _is_node_in_routes(cls,
+                           reverse_edge_mapping: dict[str, list[GraphEdge]],
+                           start_node_id: str,
+                           routes_node_ids: dict[str, list[str]]) -> bool:
+        """
+        Recursively check if the node is in the routes
+        """
+        if start_node_id not in reverse_edge_mapping:
+            return False
+        
+        all_routes_node_ids = set()
+        parallel_start_node_ids: dict[str, list[str]] = {}
+        for branch_node_id, node_ids in routes_node_ids.items():
+            for node_id in node_ids:
+                all_routes_node_ids.add(node_id)
+            
+            if branch_node_id in reverse_edge_mapping:
+                for graph_edge in reverse_edge_mapping[branch_node_id]:
+                    if graph_edge.source_node_id not in parallel_start_node_ids:
+                        parallel_start_node_ids[graph_edge.source_node_id] = []
+
+                    parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)
+
+        parallel_start_node_id = None
+        for p_start_node_id, branch_node_ids in parallel_start_node_ids.items():
+            if set(branch_node_ids) == set(routes_node_ids.keys()):
+                parallel_start_node_id = p_start_node_id
+                return True
+            
+        if not parallel_start_node_id:
+            raise Exception("Parallel start node id not found")
+        
+        for graph_edge in reverse_edge_mapping[start_node_id]:
+            if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id:
+                return False
+            
+        return True
+
+    @classmethod
+    def _is_node2_after_node1(
+        cls, 
+        node1_id: str, 
+        node2_id: str,
+        edge_mapping: dict[str, list[GraphEdge]]
+    ) -> bool:
+        """
+        is node2 after node1
+        """
+        if node1_id not in edge_mapping:
+            return False
+        
+        for graph_edge in edge_mapping[node1_id]:
+            if graph_edge.target_node_id == node2_id:
+                return True
+            
+            if cls._is_node2_after_node1(
+                node1_id=graph_edge.target_node_id,
+                node2_id=node2_id,
+                edge_mapping=edge_mapping
+            ):
+                return True
+            
+        return False

+ 21 - 0
api/core/workflow/graph_engine/entities/graph_init_params.py

@@ -0,0 +1,21 @@
+from collections.abc import Mapping
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities.node_entities import UserFrom
+from models.workflow import WorkflowType
+
+
+class GraphInitParams(BaseModel):
+    # init params
+    tenant_id: str = Field(..., description="tenant / workspace id")
+    app_id: str = Field(..., description="app id")
+    workflow_type: WorkflowType = Field(..., description="workflow type")
+    workflow_id: str = Field(..., description="workflow id")
+    graph_config: Mapping[str, Any] = Field(..., description="graph config")
+    user_id: str = Field(..., description="user id")
+    user_from: UserFrom = Field(..., description="user from, account or end-user")
+    invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
+    call_depth: int = Field(..., description="call depth")

+ 27 - 0
api/core/workflow/graph_engine/entities/graph_runtime_state.py

@@ -0,0 +1,27 @@
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+from core.model_runtime.entities.llm_entities import LLMUsage
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
+
+
+class GraphRuntimeState(BaseModel):
+    variable_pool: VariablePool = Field(..., description="variable pool")
+    """variable pool"""
+    
+    start_at: float = Field(..., description="start time")
+    """start time"""
+    total_tokens: int = 0
+    """total tokens"""
+    llm_usage: LLMUsage = LLMUsage.empty_usage()
+    """llm usage info"""
+    outputs: dict[str, Any] = {}
+    """outputs"""
+
+    node_run_steps: int = 0
+    """node run steps"""
+
+    node_run_state: RuntimeRouteState = RuntimeRouteState()
+    """node run state"""

+ 13 - 0
api/core/workflow/graph_engine/entities/next_graph_node.py

@@ -0,0 +1,13 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+from core.workflow.graph_engine.entities.graph import GraphParallel
+
+
+class NextGraphNode(BaseModel):
+    node_id: str
+    """next node id"""
+
+    parallel: Optional[GraphParallel] = None
+    """parallel"""

+ 21 - 0
api/core/workflow/graph_engine/entities/run_condition.py

@@ -0,0 +1,21 @@
+import hashlib
+from typing import Literal, Optional
+
+from pydantic import BaseModel
+
+from core.workflow.utils.condition.entities import Condition
+
+
+class RunCondition(BaseModel):
+    type: Literal["branch_identify", "condition"]
+    """condition type"""
+
+    branch_identify: Optional[str] = None
+    """branch identify like: sourceHandle, required when type is branch_identify"""
+
+    conditions: Optional[list[Condition]] = None
+    """conditions to run the node, required when type is condition"""
+
+    @property
+    def hash(self) -> str:
+        return hashlib.sha256(self.model_dump_json().encode()).hexdigest()

+ 111 - 0
api/core/workflow/graph_engine/entities/runtime_route_state.py

@@ -0,0 +1,111 @@
+import uuid
+from datetime import datetime, timezone
+from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+from core.workflow.entities.node_entities import NodeRunResult
+from models.workflow import WorkflowNodeExecutionStatus
+
+
+class RouteNodeState(BaseModel):
+    class Status(Enum):
+        RUNNING = "running"
+        SUCCESS = "success"
+        FAILED = "failed"
+        PAUSED = "paused"
+
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    """node state id"""
+
+    node_id: str
+    """node id"""
+
+    node_run_result: Optional[NodeRunResult] = None
+    """node run result"""
+
+    status: Status = Status.RUNNING
+    """node status"""
+
+    start_at: datetime
+    """start time"""
+
+    paused_at: Optional[datetime] = None
+    """paused time"""
+
+    finished_at: Optional[datetime] = None
+    """finished time"""
+
+    failed_reason: Optional[str] = None
+    """failed reason"""
+
+    paused_by: Optional[str] = None
+    """paused by"""
+
+    index: int = 1
+
+    def set_finished(self, run_result: NodeRunResult) -> None:
+        """
+        Node finished
+
+        :param run_result: run result
+        """
+        if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]:
+            raise Exception(f"Route state {self.id} already finished")
+
+        if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
+            self.status = RouteNodeState.Status.SUCCESS
+        elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
+            self.status = RouteNodeState.Status.FAILED
+            self.failed_reason = run_result.error
+        else:
+            raise Exception(f"Invalid route status {run_result.status}")
+
+        self.node_run_result = run_result
+        self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+
+
+class RuntimeRouteState(BaseModel):
+    routes: dict[str, list[str]] = Field(
+        default_factory=dict,
+        description="graph state routes (source_node_state_id: target_node_state_id)"
+    )
+
+    node_state_mapping: dict[str, RouteNodeState] = Field(
+        default_factory=dict,
+        description="node state mapping (route_node_state_id: route_node_state)"
+    )
+
+    def create_node_state(self, node_id: str) -> RouteNodeState:
+        """
+        Create node state
+
+        :param node_id: node id
+        """
+        state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None))
+        self.node_state_mapping[state.id] = state
+        return state
+
+    def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
+        """
+        Add route to the graph state
+
+        :param source_node_state_id: source node state id
+        :param target_node_state_id: target node state id
+        """
+        if source_node_state_id not in self.routes:
+            self.routes[source_node_state_id] = []
+
+        self.routes[source_node_state_id].append(target_node_state_id)
+
+    def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \
+            -> list[RouteNodeState]:
+        """
+        Get routes with node state by source node id
+
+        :param source_node_state_id: source node state id
+        :return: routes with node state
+        """
+        return [self.node_state_mapping[target_state_id]
+                for target_state_id in self.routes.get(source_node_state_id, [])]

+ 716 - 0
api/core/workflow/graph_engine/graph_engine.py

@@ -0,0 +1,716 @@
+import logging
+import queue
+import time
+import uuid
+from collections.abc import Generator, Mapping
+from concurrent.futures import ThreadPoolExecutor, wait
+from typing import Any, Optional
+
+from flask import Flask, current_app
+
+from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities.node_entities import (
+    NodeRunMetadataKey,
+    NodeType,
+    UserFrom,
+)
+from core.workflow.entities.variable_pool import VariablePool, VariableValue
+from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
+from core.workflow.graph_engine.entities.event import (
+    BaseIterationEvent,
+    GraphEngineEvent,
+    GraphRunFailedEvent,
+    GraphRunStartedEvent,
+    GraphRunSucceededEvent,
+    NodeRunFailedEvent,
+    NodeRunRetrieverResourceEvent,
+    NodeRunStartedEvent,
+    NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
+    ParallelBranchRunFailedEvent,
+    ParallelBranchRunStartedEvent,
+    ParallelBranchRunSucceededEvent,
+)
+from core.workflow.graph_engine.entities.graph import Graph, GraphEdge
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
+from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
+from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
+from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
+from core.workflow.nodes.node_mapping import node_classes
+from extensions.ext_database import db
+from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
+
+logger = logging.getLogger(__name__)
+
+
+class GraphEngineThreadPool(ThreadPoolExecutor):
+    def __init__(self, max_workers=None, thread_name_prefix='',
+                 initializer=None, initargs=(), max_submit_count=100) -> None:
+        super().__init__(max_workers, thread_name_prefix, initializer, initargs)
+        self.max_submit_count = max_submit_count
+        self.submit_count = 0
+
+    def submit(self, fn, *args, **kwargs):
+        self.submit_count += 1
+        self.check_is_full()
+        
+        return super().submit(fn, *args, **kwargs)
+    
+    def check_is_full(self) -> None:
+        print(f"submit_count: {self.submit_count}, max_submit_count: {self.max_submit_count}")
+        if self.submit_count > self.max_submit_count:
+            raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
+
+
+class GraphEngine:
+    workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}
+
+    def __init__(
+            self,
+            tenant_id: str,
+            app_id: str,
+            workflow_type: WorkflowType,
+            workflow_id: str,
+            user_id: str,
+            user_from: UserFrom,
+            invoke_from: InvokeFrom,
+            call_depth: int,
+            graph: Graph,
+            graph_config: Mapping[str, Any],
+            variable_pool: VariablePool,
+            max_execution_steps: int,
+            max_execution_time: int,
+            thread_pool_id: Optional[str] = None
+    ) -> None:
+        thread_pool_max_submit_count = 100
+        thread_pool_max_workers = 10
+
+        ## init thread pool
+        if thread_pool_id:
+            if not thread_pool_id in GraphEngine.workflow_thread_pool_mapping:
+                raise ValueError(f"Max submit count {thread_pool_max_submit_count} of workflow thread pool reached.")
+            
+            self.thread_pool_id = thread_pool_id
+            self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id]
+            self.is_main_thread_pool = False
+        else:
+            self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count)
+            self.thread_pool_id = str(uuid.uuid4())
+            self.is_main_thread_pool = True
+            GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool
+
+        self.graph = graph
+        self.init_params = GraphInitParams(
+            tenant_id=tenant_id,
+            app_id=app_id,
+            workflow_type=workflow_type,
+            workflow_id=workflow_id,
+            graph_config=graph_config,
+            user_id=user_id,
+            user_from=user_from,
+            invoke_from=invoke_from,
+            call_depth=call_depth
+        )
+
+        self.graph_runtime_state = GraphRuntimeState(
+            variable_pool=variable_pool,
+            start_at=time.perf_counter()
+        )
+
+        self.max_execution_steps = max_execution_steps
+        self.max_execution_time = max_execution_time
+
+    def run(self) -> Generator[GraphEngineEvent, None, None]:
+        # trigger graph run start event
+        yield GraphRunStartedEvent()
+
+        try:
+            stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor]
+            if self.init_params.workflow_type == WorkflowType.CHAT:
+                stream_processor_cls = AnswerStreamProcessor
+            else:
+                stream_processor_cls = EndStreamProcessor
+
+            stream_processor = stream_processor_cls(
+                graph=self.graph,
+                variable_pool=self.graph_runtime_state.variable_pool
+            )
+
+            # run graph
+            generator = stream_processor.process(
+                self._run(start_node_id=self.graph.root_node_id)
+            )
+
+            for item in generator:
+                try:
+                    yield item
+                    if isinstance(item, NodeRunFailedEvent):
+                        yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.')
+                        return
+                    elif isinstance(item, NodeRunSucceededEvent):
+                        if item.node_type == NodeType.END:
+                            self.graph_runtime_state.outputs = (item.route_node_state.node_run_result.outputs
+                                             if item.route_node_state.node_run_result
+                                                and item.route_node_state.node_run_result.outputs
+                                             else {})
+                        elif item.node_type == NodeType.ANSWER:
+                            if "answer" not in self.graph_runtime_state.outputs:
+                                self.graph_runtime_state.outputs["answer"] = ""
+
+                            self.graph_runtime_state.outputs["answer"] += "\n" + (item.route_node_state.node_run_result.outputs.get("answer", "")
+                                                               if item.route_node_state.node_run_result
+                                                                  and item.route_node_state.node_run_result.outputs
+                                                               else "")
+                        
+                            self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs["answer"].strip()
+                except Exception as e:
+                    logger.exception(f"Graph run failed: {str(e)}")
+                    yield GraphRunFailedEvent(error=str(e))
+                    return
+
+            # trigger graph run success event
+            yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
+        except GraphRunFailedError as e:
+            yield GraphRunFailedEvent(error=e.error)
+            return
+        except Exception as e:
+            logger.exception("Unknown Error when graph running")
+            yield GraphRunFailedEvent(error=str(e))
+            raise e
+        finally:
+            if self.is_main_thread_pool and self.thread_pool_id in GraphEngine.workflow_thread_pool_mapping:
+                del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id]
+
+    def _run(
+            self, 
+            start_node_id: str, 
+            in_parallel_id: Optional[str] = None,
+            parent_parallel_id: Optional[str] = None,
+            parent_parallel_start_node_id: Optional[str] = None
+        ) -> Generator[GraphEngineEvent, None, None]:
+        parallel_start_node_id = None
+        if in_parallel_id:
+            parallel_start_node_id = start_node_id
+
+        next_node_id = start_node_id
+        previous_route_node_state: Optional[RouteNodeState] = None
+        while True:
+            # max steps reached
+            if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
+                raise GraphRunFailedError('Max steps {} reached.'.format(self.max_execution_steps))
+
+            # or max execution time reached
+            if self._is_timed_out(
+                    start_at=self.graph_runtime_state.start_at,
+                    max_execution_time=self.max_execution_time
+            ):
+                raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))
+
+            # init route node state
+            route_node_state = self.graph_runtime_state.node_run_state.create_node_state(
+                node_id=next_node_id
+            )
+
+            # get node config
+            node_id = route_node_state.node_id
+            node_config = self.graph.node_id_config_mapping.get(node_id)
+            if not node_config:
+                raise GraphRunFailedError(f'Node {node_id} config not found.')
+
+            # convert to specific node
+            node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
+            node_cls = node_classes.get(node_type)
+            if not node_cls:
+                raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
+
+            previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
+
+            # init workflow run state
+            node_instance = node_cls(  # type: ignore
+                id=route_node_state.id,
+                config=node_config,
+                graph_init_params=self.init_params,
+                graph=self.graph,
+                graph_runtime_state=self.graph_runtime_state,
+                previous_node_id=previous_node_id,
+                thread_pool_id=self.thread_pool_id
+            )
+
+            try:
+                # run node
+                generator = self._run_node(
+                    node_instance=node_instance,
+                    route_node_state=route_node_state,
+                    parallel_id=in_parallel_id,
+                    parallel_start_node_id=parallel_start_node_id,
+                    parent_parallel_id=parent_parallel_id,
+                    parent_parallel_start_node_id=parent_parallel_start_node_id
+                )
+
+                for item in generator:
+                    if isinstance(item, NodeRunStartedEvent):
+                        self.graph_runtime_state.node_run_steps += 1
+                        item.route_node_state.index = self.graph_runtime_state.node_run_steps
+
+                    yield item
+
+                self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state
+
+                # append route
+                if previous_route_node_state:
+                    self.graph_runtime_state.node_run_state.add_route(
+                        source_node_state_id=previous_route_node_state.id,
+                        target_node_state_id=route_node_state.id
+                    )
+            except Exception as e:
+                route_node_state.status = RouteNodeState.Status.FAILED
+                route_node_state.failed_reason = str(e)
+                yield NodeRunFailedEvent(
+                    error=str(e),
+                    id=node_instance.id,
+                    node_id=next_node_id,
+                    node_type=node_type,
+                    node_data=node_instance.node_data,
+                    route_node_state=route_node_state,
+                    parallel_id=in_parallel_id,
+                    parallel_start_node_id=parallel_start_node_id,
+                    parent_parallel_id=parent_parallel_id,
+                    parent_parallel_start_node_id=parent_parallel_start_node_id
+                )
+                raise e
+
+            # It may not be necessary, but it is necessary. :)
+            if (self.graph.node_id_config_mapping[next_node_id]
+                    .get("data", {}).get("type", "").lower() == NodeType.END.value):
+                break
+
+            previous_route_node_state = route_node_state
+
+            # get next node ids
+            edge_mappings = self.graph.edge_mapping.get(next_node_id)
+            if not edge_mappings:
+                break
+
+            if len(edge_mappings) == 1:
+                edge = edge_mappings[0]
+
+                if edge.run_condition:
+                    result = ConditionManager.get_condition_handler(
+                        init_params=self.init_params,
+                        graph=self.graph,
+                        run_condition=edge.run_condition,
+                    ).check(
+                        graph_runtime_state=self.graph_runtime_state,
+                        previous_route_node_state=previous_route_node_state
+                    )
+
+                    if not result:
+                        break
+
+                next_node_id = edge.target_node_id
+            else:
+                final_node_id = None
+
+                if any(edge.run_condition for edge in edge_mappings):
+                    # if nodes has run conditions, get node id which branch to take based on the run condition results
+                    condition_edge_mappings = {}
+                    for edge in edge_mappings:
+                        if edge.run_condition:
+                            run_condition_hash = edge.run_condition.hash
+                            if run_condition_hash not in condition_edge_mappings:
+                                condition_edge_mappings[run_condition_hash] = []
+
+                            condition_edge_mappings[run_condition_hash].append(edge)
+
+                    for _, sub_edge_mappings in condition_edge_mappings.items():
+                        if len(sub_edge_mappings) == 0:
+                            continue
+
+                        edge = sub_edge_mappings[0]
+
+                        result = ConditionManager.get_condition_handler(
+                            init_params=self.init_params,
+                            graph=self.graph,
+                            run_condition=edge.run_condition,
+                        ).check(
+                            graph_runtime_state=self.graph_runtime_state,
+                            previous_route_node_state=previous_route_node_state,
+                        )
+
+                        if not result:
+                            continue
+                        
+                        if len(sub_edge_mappings) == 1:
+                            final_node_id = edge.target_node_id
+                        else:
+                            parallel_generator = self._run_parallel_branches(
+                                edge_mappings=sub_edge_mappings,
+                                in_parallel_id=in_parallel_id,
+                                parallel_start_node_id=parallel_start_node_id
+                            )
+
+                            for item in parallel_generator:
+                                if isinstance(item, str):
+                                    final_node_id = item
+                                else:
+                                    yield item
+
+                        break
+
+                    if not final_node_id:
+                        break
+
+                    next_node_id = final_node_id
+                else:
+                    parallel_generator = self._run_parallel_branches(
+                        edge_mappings=edge_mappings,
+                        in_parallel_id=in_parallel_id,
+                        parallel_start_node_id=parallel_start_node_id
+                    )
+
+                    for item in parallel_generator:
+                        if isinstance(item, str):
+                            final_node_id = item
+                        else:
+                            yield item
+
+                    if not final_node_id:
+                        break
+
+                    next_node_id = final_node_id
+
+            if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
+                break
+
+    def _run_parallel_branches(
+            self,
+            edge_mappings: list[GraphEdge],
+            in_parallel_id: Optional[str] = None,
+            parallel_start_node_id: Optional[str] = None,
+    ) -> Generator[GraphEngineEvent | str, None, None]:
+        # if nodes has no run conditions, parallel run all nodes
+        parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
+        if not parallel_id:
+            node_id = edge_mappings[0].target_node_id
+            node_config = self.graph.node_id_config_mapping.get(node_id)
+            if not node_config:
+                raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.')
+
+            node_title = node_config.get('data', {}).get('title')
+            raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.')
+
+        parallel = self.graph.parallel_mapping.get(parallel_id)
+        if not parallel:
+            raise GraphRunFailedError(f'Parallel {parallel_id} not found.')
+
+        # run parallel nodes, run in new thread and use queue to get results
+        q: queue.Queue = queue.Queue()
+
+        # Create a list to store the threads
+        futures = []
+
+        # new thread
+        for edge in edge_mappings:
+            if (
+                edge.target_node_id not in self.graph.node_parallel_mapping
+                or self.graph.node_parallel_mapping.get(edge.target_node_id, '') != parallel_id
+            ):
+                continue
+
+            futures.append(
+                self.thread_pool.submit(self._run_parallel_node, **{
+                    'flask_app': current_app._get_current_object(),  # type: ignore[attr-defined]
+                    'q': q,
+                    'parallel_id': parallel_id,
+                    'parallel_start_node_id': edge.target_node_id,
+                    'parent_parallel_id': in_parallel_id,
+                    'parent_parallel_start_node_id': parallel_start_node_id,
+                })
+            )
+
+        succeeded_count = 0
+        while True:
+            try:
+                event = q.get(timeout=1)
+                if event is None:
+                    break
+
+                yield event
+                if event.parallel_id == parallel_id:
+                    if isinstance(event, ParallelBranchRunSucceededEvent):
+                        succeeded_count += 1
+                        if succeeded_count == len(futures):
+                            q.put(None)
+
+                        continue
+                    elif isinstance(event, ParallelBranchRunFailedEvent):
+                        raise GraphRunFailedError(event.error)
+            except queue.Empty:
+                continue
+        
+        # wait all threads
+        wait(futures)
+
+        # get final node id
+        final_node_id = parallel.end_to_node_id
+        if final_node_id:
+            yield final_node_id
+
+    def _run_parallel_node(
+            self,
+            flask_app: Flask,
+            q: queue.Queue,
+            parallel_id: str,
+            parallel_start_node_id: str,
+            parent_parallel_id: Optional[str] = None,
+            parent_parallel_start_node_id: Optional[str] = None,
+    ) -> None:
+        """
+        Run parallel nodes
+        """
+        with flask_app.app_context():
+            try:
+                q.put(ParallelBranchRunStartedEvent(
+                    parallel_id=parallel_id,
+                    parallel_start_node_id=parallel_start_node_id,
+                    parent_parallel_id=parent_parallel_id,
+                    parent_parallel_start_node_id=parent_parallel_start_node_id
+                ))
+
+                # run node
+                generator = self._run(
+                    start_node_id=parallel_start_node_id,
+                    in_parallel_id=parallel_id,
+                    parent_parallel_id=parent_parallel_id,
+                    parent_parallel_start_node_id=parent_parallel_start_node_id
+                )
+
+                for item in generator:
+                    q.put(item)
+
+                # trigger graph run success event
+                q.put(ParallelBranchRunSucceededEvent(
+                    parallel_id=parallel_id,
+                    parallel_start_node_id=parallel_start_node_id,
+                    parent_parallel_id=parent_parallel_id,
+                    parent_parallel_start_node_id=parent_parallel_start_node_id
+                ))
+            except GraphRunFailedError as e:
+                q.put(ParallelBranchRunFailedEvent(
+                    parallel_id=parallel_id,
+                    parallel_start_node_id=parallel_start_node_id,
+                    parent_parallel_id=parent_parallel_id,
+                    parent_parallel_start_node_id=parent_parallel_start_node_id,
+                    error=e.error
+                ))
+            except Exception as e:
+                logger.exception("Unknown Error when generating in parallel")
+                q.put(ParallelBranchRunFailedEvent(
+                    parallel_id=parallel_id,
+                    parallel_start_node_id=parallel_start_node_id,
+                    parent_parallel_id=parent_parallel_id,
+                    parent_parallel_start_node_id=parent_parallel_start_node_id,
+                    error=str(e)
+                ))
+            finally:
+                db.session.remove()
+
+    def _run_node(
+            self,
+            node_instance: BaseNode,
+            route_node_state: RouteNodeState,
+            parallel_id: Optional[str] = None,
+            parallel_start_node_id: Optional[str] = None,
+            parent_parallel_id: Optional[str] = None,
+            parent_parallel_start_node_id: Optional[str] = None,
+    ) -> Generator[GraphEngineEvent, None, None]:
+        """
+        Run node
+        """
+        # trigger node run start event
+        yield NodeRunStartedEvent(
+            id=node_instance.id,
+            node_id=node_instance.node_id,
+            node_type=node_instance.node_type,
+            node_data=node_instance.node_data,
+            route_node_state=route_node_state,
+            predecessor_node_id=node_instance.previous_node_id,
+            parallel_id=parallel_id,
+            parallel_start_node_id=parallel_start_node_id,
+            parent_parallel_id=parent_parallel_id,
+            parent_parallel_start_node_id=parent_parallel_start_node_id
+        )
+
+        db.session.close()
+
+        try:
+            # run node
+            generator = node_instance.run()
+            for item in generator:
+                if isinstance(item, GraphEngineEvent):
+                    if isinstance(item, BaseIterationEvent):
+                        # add parallel info to iteration event
+                        item.parallel_id = parallel_id
+                        item.parallel_start_node_id = parallel_start_node_id
+                        item.parent_parallel_id = parent_parallel_id
+                        item.parent_parallel_start_node_id = parent_parallel_start_node_id
+
+                    yield item
+                else:
+                    if isinstance(item, RunCompletedEvent):
+                        run_result = item.run_result
+                        route_node_state.set_finished(run_result=run_result)
+
+                        if run_result.status == WorkflowNodeExecutionStatus.FAILED:
+                            yield NodeRunFailedEvent(
+                                error=route_node_state.failed_reason or 'Unknown error.',
+                                id=node_instance.id,
+                                node_id=node_instance.node_id,
+                                node_type=node_instance.node_type,
+                                node_data=node_instance.node_data,
+                                route_node_state=route_node_state,
+                                parallel_id=parallel_id,
+                                parallel_start_node_id=parallel_start_node_id,
+                                parent_parallel_id=parent_parallel_id,
+                                parent_parallel_start_node_id=parent_parallel_start_node_id
+                            )
+                        elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
+                            if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
+                                # plus state total_tokens
+                                self.graph_runtime_state.total_tokens += int(
+                                    run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)  # type: ignore[arg-type]
+                                )
+
+                            if run_result.llm_usage:
+                                # use the latest usage
+                                self.graph_runtime_state.llm_usage += run_result.llm_usage
+
+                            # append node output variables to variable pool
+                            if run_result.outputs:
+                                for variable_key, variable_value in run_result.outputs.items():
+                                    # append variables to variable pool recursively
+                                    self._append_variables_recursively(
+                                        node_id=node_instance.node_id,
+                                        variable_key_list=[variable_key],
+                                        variable_value=variable_value
+                                    )
+
+                            # add parallel info to run result metadata
+                            if parallel_id and parallel_start_node_id:
+                                if not run_result.metadata:
+                                    run_result.metadata = {}
+
+                                run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
+                                run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
+                                if parent_parallel_id and parent_parallel_start_node_id:
+                                    run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
+                                    run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = parent_parallel_start_node_id
+
+                            yield NodeRunSucceededEvent(
+                                id=node_instance.id,
+                                node_id=node_instance.node_id,
+                                node_type=node_instance.node_type,
+                                node_data=node_instance.node_data,
+                                route_node_state=route_node_state,
+                                parallel_id=parallel_id,
+                                parallel_start_node_id=parallel_start_node_id,
+                                parent_parallel_id=parent_parallel_id,
+                                parent_parallel_start_node_id=parent_parallel_start_node_id
+                            )
+
+                        break
+                    elif isinstance(item, RunStreamChunkEvent):
+                        yield NodeRunStreamChunkEvent(
+                            id=node_instance.id,
+                            node_id=node_instance.node_id,
+                            node_type=node_instance.node_type,
+                            node_data=node_instance.node_data,
+                            chunk_content=item.chunk_content,
+                            from_variable_selector=item.from_variable_selector,
+                            route_node_state=route_node_state,
+                            parallel_id=parallel_id,
+                            parallel_start_node_id=parallel_start_node_id,
+                            parent_parallel_id=parent_parallel_id,
+                            parent_parallel_start_node_id=parent_parallel_start_node_id
+                        )
+                    elif isinstance(item, RunRetrieverResourceEvent):
+                        yield NodeRunRetrieverResourceEvent(
+                            id=node_instance.id,
+                            node_id=node_instance.node_id,
+                            node_type=node_instance.node_type,
+                            node_data=node_instance.node_data,
+                            retriever_resources=item.retriever_resources,
+                            context=item.context,
+                            route_node_state=route_node_state,
+                            parallel_id=parallel_id,
+                            parallel_start_node_id=parallel_start_node_id,
+                            parent_parallel_id=parent_parallel_id,
+                            parent_parallel_start_node_id=parent_parallel_start_node_id
+                        )
+        except GenerateTaskStoppedException:
+            # trigger node run failed event
+            route_node_state.status = RouteNodeState.Status.FAILED
+            route_node_state.failed_reason = "Workflow stopped."
+            yield NodeRunFailedEvent(
+                error="Workflow stopped.",
+                id=node_instance.id,
+                node_id=node_instance.node_id,
+                node_type=node_instance.node_type,
+                node_data=node_instance.node_data,
+                route_node_state=route_node_state,
+                parallel_id=parallel_id,
+                parallel_start_node_id=parallel_start_node_id,
+                parent_parallel_id=parent_parallel_id,
+                parent_parallel_start_node_id=parent_parallel_start_node_id
+            )
+            return
+        except Exception as e:
+            logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}")
+            raise e
+        finally:
+            db.session.close()
+
+    def _append_variables_recursively(self,
+                                      node_id: str,
+                                      variable_key_list: list[str],
+                                      variable_value: VariableValue):
+        """
+        Append variables recursively
+        :param node_id: node id
+        :param variable_key_list: variable key list
+        :param variable_value: variable value
+        :return:
+        """
+        self.graph_runtime_state.variable_pool.add(
+            [node_id] + variable_key_list,
+            variable_value
+        )
+
+        # if variable_value is a dict, then recursively append variables
+        if isinstance(variable_value, dict):
+            for key, value in variable_value.items():
+                # construct new key list
+                new_key_list = variable_key_list + [key]
+                self._append_variables_recursively(
+                    node_id=node_id,
+                    variable_key_list=new_key_list,
+                    variable_value=value
+                )
+
+    def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
+        """
+        Check timeout
+        :param start_at: start time
+        :param max_execution_time: max execution time
+        :return:
+        """
+        return time.perf_counter() - start_at > max_execution_time
+
+
+class GraphRunFailedError(Exception):
+    def __init__(self, error: str):
+        self.error = error

+ 19 - 72
api/core/workflow/nodes/answer/answer_node.py

@@ -1,9 +1,8 @@
-from typing import cast
+from collections.abc import Mapping, Sequence
+from typing import Any, cast
 
-from core.prompt.utils.prompt_template_parser import PromptTemplateParser
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
 from core.workflow.nodes.answer.entities import (
     AnswerNodeData,
     GenerateRouteChunk,
@@ -19,24 +18,26 @@ class AnswerNode(BaseNode):
     _node_data_cls = AnswerNodeData
     _node_type: NodeType = NodeType.ANSWER
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         """
         Run node
-        :param variable_pool: variable pool
         :return:
         """
         node_data = self.node_data
         node_data = cast(AnswerNodeData, node_data)
 
         # generate routes
-        generate_routes = self.extract_generate_route_from_node_data(node_data)
+        generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data)
 
         answer = ''
         for part in generate_routes:
-            if part.type == "var":
+            if part.type == GenerateRouteChunk.ChunkType.VAR:
                 part = cast(VarGenerateRouteChunk, part)
                 value_selector = part.value_selector
-                value = variable_pool.get(value_selector)
+                value = self.graph_runtime_state.variable_pool.get(
+                    value_selector
+                )
+
                 if value:
                     answer += value.markdown
             else:
@@ -51,70 +52,16 @@ class AnswerNode(BaseNode):
         )
 
     @classmethod
-    def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
-        """
-        Extract generate route selectors
-        :param config: node config
-        :return:
-        """
-        node_data = cls._node_data_cls(**config.get("data", {}))
-        node_data = cast(AnswerNodeData, node_data)
-
-        return cls.extract_generate_route_from_node_data(node_data)
-
-    @classmethod
-    def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
-        """
-        Extract generate route from node data
-        :param node_data: node data object
-        :return:
-        """
-        variable_template_parser = VariableTemplateParser(template=node_data.answer)
-        variable_selectors = variable_template_parser.extract_variable_selectors()
-
-        value_selector_mapping = {
-            variable_selector.variable: variable_selector.value_selector
-            for variable_selector in variable_selectors
-        }
-
-        variable_keys = list(value_selector_mapping.keys())
-
-        # format answer template
-        template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
-        template_variable_keys = template_parser.variable_keys
-
-        # Take the intersection of variable_keys and template_variable_keys
-        variable_keys = list(set(variable_keys) & set(template_variable_keys))
-
-        template = node_data.answer
-        for var in variable_keys:
-            template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
-
-        generate_routes = []
-        for part in template.split('Ω'):
-            if part:
-                if cls._is_variable(part, variable_keys):
-                    var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
-                    value_selector = value_selector_mapping[var_key]
-                    generate_routes.append(VarGenerateRouteChunk(
-                        value_selector=value_selector
-                    ))
-                else:
-                    generate_routes.append(TextGenerateRouteChunk(
-                        text=part
-                    ))
-
-        return generate_routes
-
-    @classmethod
-    def _is_variable(cls, part, variable_keys):
-        cleaned_part = part.replace('{{', '').replace('}}', '')
-        return part.startswith('{{') and cleaned_part in variable_keys
-
-    @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls, 
+        graph_config: Mapping[str, Any], 
+        node_id: str,
+        node_data: AnswerNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """
@@ -126,6 +73,6 @@ class AnswerNode(BaseNode):
 
         variable_mapping = {}
         for variable_selector in variable_selectors:
-            variable_mapping[variable_selector.variable] = variable_selector.value_selector
+            variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector
 
         return variable_mapping

+ 169 - 0
api/core/workflow/nodes/answer/answer_stream_generate_router.py

@@ -0,0 +1,169 @@
+
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.nodes.answer.entities import (
+    AnswerNodeData,
+    AnswerStreamGenerateRoute,
+    GenerateRouteChunk,
+    TextGenerateRouteChunk,
+    VarGenerateRouteChunk,
+)
+from core.workflow.utils.variable_template_parser import VariableTemplateParser
+
+
+class AnswerStreamGeneratorRouter:
+
+    @classmethod
+    def init(cls,
+             node_id_config_mapping: dict[str, dict],
+             reverse_edge_mapping: dict[str, list["GraphEdge"]]  # type: ignore[name-defined]
+             ) -> AnswerStreamGenerateRoute:
+        """
+        Get stream generate routes.
+        :return:
+        """
+        # parse stream output node value selectors of answer nodes
+        answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
+        for answer_node_id, node_config in node_id_config_mapping.items():
+            if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
+                continue
+
+            # get generate route for stream output
+            generate_route = cls._extract_generate_route_selectors(node_config)
+            answer_generate_route[answer_node_id] = generate_route
+
+        # fetch answer dependencies
+        answer_node_ids = list(answer_generate_route.keys())
+        answer_dependencies = cls._fetch_answers_dependencies(
+            answer_node_ids=answer_node_ids,
+            reverse_edge_mapping=reverse_edge_mapping,
+            node_id_config_mapping=node_id_config_mapping
+        )
+
+        return AnswerStreamGenerateRoute(
+            answer_generate_route=answer_generate_route,
+            answer_dependencies=answer_dependencies
+        )
+
+    @classmethod
+    def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
+        """
+        Extract generate route from node data
+        :param node_data: node data object
+        :return:
+        """
+        variable_template_parser = VariableTemplateParser(template=node_data.answer)
+        variable_selectors = variable_template_parser.extract_variable_selectors()
+
+        value_selector_mapping = {
+            variable_selector.variable: variable_selector.value_selector
+            for variable_selector in variable_selectors
+        }
+
+        variable_keys = list(value_selector_mapping.keys())
+
+        # format answer template
+        template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
+        template_variable_keys = template_parser.variable_keys
+
+        # Take the intersection of variable_keys and template_variable_keys
+        variable_keys = list(set(variable_keys) & set(template_variable_keys))
+
+        template = node_data.answer
+        for var in variable_keys:
+            template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
+
+        generate_routes: list[GenerateRouteChunk] = []
+        for part in template.split('Ω'):
+            if part:
+                if cls._is_variable(part, variable_keys):
+                    var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
+                    value_selector = value_selector_mapping[var_key]
+                    generate_routes.append(VarGenerateRouteChunk(
+                        value_selector=value_selector
+                    ))
+                else:
+                    generate_routes.append(TextGenerateRouteChunk(
+                        text=part
+                    ))
+
+        return generate_routes
+
+    @classmethod
+    def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
+        """
+        Extract generate route selectors
+        :param config: node config
+        :return:
+        """
+        node_data = AnswerNodeData(**config.get("data", {}))
+        return cls.extract_generate_route_from_node_data(node_data)
+
+    @classmethod
+    def _is_variable(cls, part, variable_keys):
+        cleaned_part = part.replace('{{', '').replace('}}', '')
+        return part.startswith('{{') and cleaned_part in variable_keys
+
+    @classmethod
+    def _fetch_answers_dependencies(cls,
+                                    answer_node_ids: list[str],
+                                    reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
+                                    node_id_config_mapping: dict[str, dict]
+                                    ) -> dict[str, list[str]]:
+        """
+        Fetch answer dependencies
+        :param answer_node_ids: answer node ids
+        :param reverse_edge_mapping: reverse edge mapping
+        :param node_id_config_mapping: node id config mapping
+        :return:
+        """
+        answer_dependencies: dict[str, list[str]] = {}
+        for answer_node_id in answer_node_ids:
+            if answer_dependencies.get(answer_node_id) is None:
+                answer_dependencies[answer_node_id] = []
+
+            cls._recursive_fetch_answer_dependencies(
+                current_node_id=answer_node_id,
+                answer_node_id=answer_node_id,
+                node_id_config_mapping=node_id_config_mapping,
+                reverse_edge_mapping=reverse_edge_mapping,
+                answer_dependencies=answer_dependencies
+            )
+
+        return answer_dependencies
+
+    @classmethod
+    def _recursive_fetch_answer_dependencies(cls,
+                                             current_node_id: str,
+                                             answer_node_id: str,
+                                             node_id_config_mapping: dict[str, dict],
+                                             reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
+                                             answer_dependencies: dict[str, list[str]]
+                                             ) -> None:
+        """
+        Recursive fetch answer dependencies
+        :param current_node_id: current node id
+        :param answer_node_id: answer node id
+        :param node_id_config_mapping: node id config mapping
+        :param reverse_edge_mapping: reverse edge mapping
+        :param answer_dependencies: answer dependencies
+        :return:
+        """
+        reverse_edges = reverse_edge_mapping.get(current_node_id, [])
+        for edge in reverse_edges:
+            source_node_id = edge.source_node_id
+            source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
+            if source_node_type in (
+                    NodeType.ANSWER.value,
+                    NodeType.IF_ELSE.value,
+                    NodeType.QUESTION_CLASSIFIER,
+            ):
+                answer_dependencies[answer_node_id].append(source_node_id)
+            else:
+                cls._recursive_fetch_answer_dependencies(
+                    current_node_id=source_node_id,
+                    answer_node_id=answer_node_id,
+                    node_id_config_mapping=node_id_config_mapping,
+                    reverse_edge_mapping=reverse_edge_mapping,
+                    answer_dependencies=answer_dependencies
+                )

+ 221 - 0
api/core/workflow/nodes/answer/answer_stream_processor.py

@@ -0,0 +1,221 @@
+import logging
+from collections.abc import Generator
+from typing import Optional, cast
+
+from core.file.file_obj import FileVar
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine.entities.event import (
+    GraphEngineEvent,
+    NodeRunStartedEvent,
+    NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
+)
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
+from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
+
+logger = logging.getLogger(__name__)
+
+
+class AnswerStreamProcessor(StreamProcessor):
+
+    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
+        super().__init__(graph, variable_pool)
+        self.generate_routes = graph.answer_stream_generate_routes
+        self.route_position = {}
+        for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
+            self.route_position[answer_node_id] = 0
+        self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
+
+    def process(self,
+                generator: Generator[GraphEngineEvent, None, None]
+                ) -> Generator[GraphEngineEvent, None, None]:
+        for event in generator:
+            if isinstance(event, NodeRunStartedEvent):
+                if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
+                    self.reset()
+
+                yield event
+            elif isinstance(event, NodeRunStreamChunkEvent):
+                if event.in_iteration_id:
+                    yield event
+                    continue
+
+                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
+                    stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
+                        event.route_node_state.node_id
+                    ]
+                else:
+                    stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
+                    self.current_stream_chunk_generating_node_ids[
+                        event.route_node_state.node_id
+                    ] = stream_out_answer_node_ids
+
+                for _ in stream_out_answer_node_ids:
+                    yield event
+            elif isinstance(event, NodeRunSucceededEvent):
+                yield event
+                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
+                    # update self.route_position after all stream event finished
+                    for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
+                        self.route_position[answer_node_id] += 1
+
+                    del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
+
+                # remove unreachable nodes
+                self._remove_unreachable_nodes(event)
+
+                # generate stream outputs
+                yield from self._generate_stream_outputs_when_node_finished(event)
+            else:
+                yield event
+
+    def reset(self) -> None:
+        self.route_position = {}
+        for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
+            self.route_position[answer_node_id] = 0
+        self.rest_node_ids = self.graph.node_ids.copy()
+        self.current_stream_chunk_generating_node_ids = {}
+
+    def _generate_stream_outputs_when_node_finished(self,
+                                                    event: NodeRunSucceededEvent
+                                                    ) -> Generator[GraphEngineEvent, None, None]:
+        """
+        Generate stream outputs.
+        :param event: node run succeeded event
+        :return:
+        """
+        for answer_node_id, position in self.route_position.items():
+            # all depends on answer node id not in rest node ids
+            if (event.route_node_state.node_id != answer_node_id
+                    and (answer_node_id not in self.rest_node_ids
+                         or not all(dep_id not in self.rest_node_ids
+                                    for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
+                continue
+
+            route_position = self.route_position[answer_node_id]
+            route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
+
+            for route_chunk in route_chunks:
+                if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
+                    route_chunk = cast(TextGenerateRouteChunk, route_chunk)
+                    yield NodeRunStreamChunkEvent(
+                        id=event.id,
+                        node_id=event.node_id,
+                        node_type=event.node_type,
+                        node_data=event.node_data,
+                        chunk_content=route_chunk.text,
+                        route_node_state=event.route_node_state,
+                        parallel_id=event.parallel_id,
+                        parallel_start_node_id=event.parallel_start_node_id,
+                    )
+                else:
+                    route_chunk = cast(VarGenerateRouteChunk, route_chunk)
+                    value_selector = route_chunk.value_selector
+                    if not value_selector:
+                        break
+
+                    value = self.variable_pool.get(
+                        value_selector
+                    )
+
+                    if value is None:
+                        break
+
+                    text = value.markdown
+
+                    if text:
+                        yield NodeRunStreamChunkEvent(
+                            id=event.id,
+                            node_id=event.node_id,
+                            node_type=event.node_type,
+                            node_data=event.node_data,
+                            chunk_content=text,
+                            from_variable_selector=value_selector,
+                            route_node_state=event.route_node_state,
+                            parallel_id=event.parallel_id,
+                            parallel_start_node_id=event.parallel_start_node_id,
+                        )
+
+                self.route_position[answer_node_id] += 1
+
+    def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
+        """
+        Is stream out support
+        :param event: queue text chunk event
+        :return:
+        """
+        if not event.from_variable_selector:
+            return []
+
+        stream_output_value_selector = event.from_variable_selector
+        if not stream_output_value_selector:
+            return []
+
+        stream_out_answer_node_ids = []
+        for answer_node_id, route_position in self.route_position.items():
+            if answer_node_id not in self.rest_node_ids:
+                continue
+
+            # all depends on answer node id not in rest node ids
+            if all(dep_id not in self.rest_node_ids
+                   for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
+                if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
+                    continue
+
+                route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
+
+                if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
+                    continue
+
+                route_chunk = cast(VarGenerateRouteChunk, route_chunk)
+                value_selector = route_chunk.value_selector
+
+                # check chunk node id is before current node id or equal to current node id
+                if value_selector != stream_output_value_selector:
+                    continue
+
+                stream_out_answer_node_ids.append(answer_node_id)
+
+        return stream_out_answer_node_ids
+
+    @classmethod
+    def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]:
+        """
+        Fetch files from variable value
+        :param value: variable value
+        :return:
+        """
+        if not value:
+            return []
+
+        files = []
+        if isinstance(value, list):
+            for item in value:
+                file_var = cls._get_file_var_from_value(item)
+                if file_var:
+                    files.append(file_var)
+        elif isinstance(value, dict):
+            file_var = cls._get_file_var_from_value(value)
+            if file_var:
+                files.append(file_var)
+
+        return files
+
+    @classmethod
+    def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]:
+        """
+        Get file var from value
+        :param value: variable value
+        :return:
+        """
+        if not value:
+            return None
+
+        if isinstance(value, dict):
+            if '__variant' in value and value['__variant'] == FileVar.__name__:
+                return value
+        elif isinstance(value, FileVar):
+            return value.to_dict()
+
+        return None

+ 71 - 0
api/core/workflow/nodes/answer/base_stream_processor.py

@@ -0,0 +1,71 @@
+from abc import ABC, abstractmethod
+from collections.abc import Generator
+
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
+from core.workflow.graph_engine.entities.graph import Graph
+
+
+class StreamProcessor(ABC):
+
+    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
+        self.graph = graph
+        self.variable_pool = variable_pool
+        self.rest_node_ids = graph.node_ids.copy()
+
+    @abstractmethod
+    def process(self,
+                generator: Generator[GraphEngineEvent, None, None]
+                ) -> Generator[GraphEngineEvent, None, None]:
+        raise NotImplementedError
+
+    def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
+        finished_node_id = event.route_node_state.node_id
+        if finished_node_id not in self.rest_node_ids:
+            return
+
+        # remove finished node id
+        self.rest_node_ids.remove(finished_node_id)
+
+        run_result = event.route_node_state.node_run_result
+        if not run_result:
+            return
+
+        if run_result.edge_source_handle:
+            reachable_node_ids = []
+            unreachable_first_node_ids = []
+            for edge in self.graph.edge_mapping[finished_node_id]:
+                if (edge.run_condition
+                        and edge.run_condition.branch_identify
+                        and run_result.edge_source_handle == edge.run_condition.branch_identify):
+                    reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
+                    continue
+                else:
+                    unreachable_first_node_ids.append(edge.target_node_id)
+
+            for node_id in unreachable_first_node_ids:
+                self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
+
+    def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
+        node_ids = []
+        for edge in self.graph.edge_mapping.get(node_id, []):
+            if edge.target_node_id == self.graph.root_node_id:
+                continue
+
+            node_ids.append(edge.target_node_id)
+            node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
+        return node_ids
+
+    def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
+        """
+        remove target node ids until merge
+        """
+        if node_id not in self.rest_node_ids:
+            return
+
+        self.rest_node_ids.remove(node_id)
+        for edge in self.graph.edge_mapping.get(node_id, []):
+            if edge.target_node_id in reachable_node_ids:
+                continue
+
+            self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)

+ 35 - 7
api/core/workflow/nodes/answer/entities.py

@@ -1,5 +1,6 @@
+from enum import Enum
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 
@@ -8,27 +9,54 @@ class AnswerNodeData(BaseNodeData):
     """
     Answer Node Data.
     """
-    answer: str
+    answer: str = Field(..., description="answer template string")
 
 
 class GenerateRouteChunk(BaseModel):
     """
     Generate Route Chunk.
     """
-    type: str
+
+    class ChunkType(Enum):
+        VAR = "var"
+        TEXT = "text"
+
+    type: ChunkType = Field(..., description="generate route chunk type")
 
 
 class VarGenerateRouteChunk(GenerateRouteChunk):
     """
     Var Generate Route Chunk.
     """
-    type: str = "var"
-    value_selector: list[str]
+    type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
+    """generate route chunk type"""
+    value_selector: list[str] = Field(..., description="value selector")
 
 
 class TextGenerateRouteChunk(GenerateRouteChunk):
     """
     Text Generate Route Chunk.
     """
-    type: str = "text"
-    text: str
+    type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT
+    """generate route chunk type"""
+    text: str = Field(..., description="text")
+
+
+class AnswerNodeDoubleLink(BaseModel):
+    node_id: str = Field(..., description="node id")
+    source_node_ids: list[str] = Field(..., description="source node ids")
+    target_node_ids: list[str] = Field(..., description="target node ids")
+
+
+class AnswerStreamGenerateRoute(BaseModel):
+    """
+    AnswerStreamGenerateRoute entity
+    """
+    answer_dependencies: dict[str, list[str]] = Field(
+        ...,
+        description="answer dependencies (answer node id -> dependent answer node ids)"
+    )
+    answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field(
+        ...,
+        description="answer generate route (answer node id -> generate route chunks)"
+    )

+ 61 - 135
api/core/workflow/nodes/base_node.py

@@ -1,142 +1,103 @@
 from abc import ABC, abstractmethod
-from collections.abc import Mapping, Sequence
-from enum import Enum
+from collections.abc import Generator, Mapping, Sequence
 from typing import Any, Optional
 
-from core.app.entities.app_invoke_entities import InvokeFrom
-from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
-from models import WorkflowNodeExecutionStatus
-
-
-class UserFrom(Enum):
-    """
-    User from
-    """
-    ACCOUNT = "account"
-    END_USER = "end-user"
-
-    @classmethod
-    def value_of(cls, value: str) -> "UserFrom":
-        """
-        Value of
-        :param value: value
-        :return:
-        """
-        for item in cls:
-            if item.value == value:
-                return item
-        raise ValueError(f"Invalid value: {value}")
+from core.workflow.graph_engine.entities.event import InNodeEvent
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.nodes.event import RunCompletedEvent, RunEvent
 
 
 class BaseNode(ABC):
     _node_data_cls: type[BaseNodeData]
     _node_type: NodeType
 
-    tenant_id: str
-    app_id: str
-    workflow_id: str
-    user_id: str
-    user_from: UserFrom
-    invoke_from: InvokeFrom
-    
-    workflow_call_depth: int
-
-    node_id: str
-    node_data: BaseNodeData
-    node_run_result: Optional[NodeRunResult] = None
-
-    callbacks: Sequence[WorkflowCallback]
-
-    is_answer_previous_node: bool = False
-
-    def __init__(self, tenant_id: str,
-                 app_id: str,
-                 workflow_id: str,
-                 user_id: str,
-                 user_from: UserFrom,
-                 invoke_from: InvokeFrom,
+    def __init__(self,
+                 id: str,
                  config: Mapping[str, Any],
-                 callbacks: Sequence[WorkflowCallback] | None = None,
-                 workflow_call_depth: int = 0) -> None:
-        self.tenant_id = tenant_id
-        self.app_id = app_id
-        self.workflow_id = workflow_id
-        self.user_id = user_id
-        self.user_from = user_from
-        self.invoke_from = invoke_from
-        self.workflow_call_depth = workflow_call_depth
-
-        # TODO: May need to check if key exists.
-        self.node_id = config["id"]
-        if not self.node_id:
+                 graph_init_params: GraphInitParams,
+                 graph: Graph,
+                 graph_runtime_state: GraphRuntimeState,
+                 previous_node_id: Optional[str] = None,
+                 thread_pool_id: Optional[str] = None) -> None:
+        self.id = id
+        self.tenant_id = graph_init_params.tenant_id
+        self.app_id = graph_init_params.app_id
+        self.workflow_type = graph_init_params.workflow_type
+        self.workflow_id = graph_init_params.workflow_id
+        self.graph_config = graph_init_params.graph_config
+        self.user_id = graph_init_params.user_id
+        self.user_from = graph_init_params.user_from
+        self.invoke_from = graph_init_params.invoke_from
+        self.workflow_call_depth = graph_init_params.call_depth
+        self.graph = graph
+        self.graph_runtime_state = graph_runtime_state
+        self.previous_node_id = previous_node_id
+        self.thread_pool_id = thread_pool_id
+
+        node_id = config.get("id")
+        if not node_id:
             raise ValueError("Node ID is required.")
 
+        self.node_id = node_id
         self.node_data = self._node_data_cls(**config.get("data", {}))
-        self.callbacks = callbacks or []
 
     @abstractmethod
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) \
+            -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
         """
         Run node
-        :param variable_pool: variable pool
         :return:
         """
         raise NotImplementedError
 
-    def run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
         """
         Run node entry
-        :param variable_pool: variable pool
         :return:
         """
-        try:
-            result = self._run(
-                variable_pool=variable_pool
-            )
-            self.node_run_result = result
-            return result
-        except Exception as e:
-            return NodeRunResult(
-                status=WorkflowNodeExecutionStatus.FAILED,
-                error=str(e),
-            )
+        result = self._run()
 
-    def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None:
-        """
-        Publish text chunk
-        :param text: chunk text
-        :param value_selector: value selector
-        :return:
-        """
-        if self.callbacks:
-            for callback in self.callbacks:
-                callback.on_node_text_chunk(
-                    node_id=self.node_id,
-                    text=text,
-                    metadata={
-                        "node_type": self.node_type,
-                        "is_answer_previous_node": self.is_answer_previous_node,
-                        "value_selector": value_selector
-                    }
-                )
+        if isinstance(result, NodeRunResult):
+            yield RunCompletedEvent(
+                run_result=result
+            )
+        else:
+            yield from result
 
     @classmethod
-    def extract_variable_selector_to_variable_mapping(cls, config: dict):
+    def extract_variable_selector_to_variable_mapping(cls, graph_config: Mapping[str, Any], config: dict) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
         :param config: node config
         :return:
         """
+        node_id = config.get("id")
+        if not node_id:
+            raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
+
         node_data = cls._node_data_cls(**config.get("data", {}))
-        return cls._extract_variable_selector_to_variable_mapping(node_data)
+        return cls._extract_variable_selector_to_variable_mapping(
+            graph_config=graph_config,
+            node_id=node_id,
+            node_data=node_data
+        )
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> Mapping[str, Sequence[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls, 
+        graph_config: Mapping[str, Any], 
+        node_id: str,
+        node_data: BaseNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """
@@ -158,38 +119,3 @@ class BaseNode(ABC):
         :return:
         """
         return self._node_type
-
-class BaseIterationNode(BaseNode):
-    @abstractmethod
-    def _run(self, variable_pool: VariablePool) -> BaseIterationState:
-        """
-        Run node
-        :param variable_pool: variable pool
-        :return:
-        """
-        raise NotImplementedError
-
-    def run(self, variable_pool: VariablePool) -> BaseIterationState:
-        """
-        Run node entry
-        :param variable_pool: variable pool
-        :return:
-        """
-        return self._run(variable_pool=variable_pool)
-
-    def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
-        """
-        Get next iteration start node id based on the graph.
-        :param graph: graph
-        :return: next node id
-        """
-        return self._get_next_iteration(variable_pool, state)
-    
-    @abstractmethod
-    def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
-        """
-        Get next iteration start node id based on the graph.
-        :param graph: graph
-        :return: next node id
-        """
-        raise NotImplementedError

+ 15 - 9
api/core/workflow/nodes/code/code_node.py

@@ -1,4 +1,5 @@
-from typing import Optional, Union, cast
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional, Union, cast
 
 from configs import dify_config
 from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
@@ -6,7 +7,6 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
 from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
 from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.code.entities import CodeNodeData
 from models.workflow import WorkflowNodeExecutionStatus
@@ -33,13 +33,13 @@ class CodeNode(BaseNode):
 
         return code_provider.get_default_config()
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         """
         Run code
-        :param variable_pool: variable pool
         :return:
         """
-        node_data = cast(CodeNodeData, self.node_data)
+        node_data = self.node_data
+        node_data = cast(CodeNodeData, node_data)
 
         # Get code language
         code_language = node_data.code_language
@@ -49,7 +49,7 @@ class CodeNode(BaseNode):
         variables = {}
         for variable_selector in node_data.variables:
             variable = variable_selector.variable
-            value = variable_pool.get_any(variable_selector.value_selector)
+            value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
 
             variables[variable] = value
         # Run code
@@ -311,13 +311,19 @@ class CodeNode(BaseNode):
         return transformed_result
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: CodeNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """
-
         return {
-            variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
+            node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
         }

+ 12 - 50
api/core/workflow/nodes/end/end_node.py

@@ -1,8 +1,7 @@
-from typing import cast
+from collections.abc import Mapping, Sequence
+from typing import Any, cast
 
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.end.entities import EndNodeData
 from models.workflow import WorkflowNodeExecutionStatus
@@ -12,10 +11,9 @@ class EndNode(BaseNode):
     _node_data_cls = EndNodeData
     _node_type = NodeType.END
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         """
         Run node
-        :param variable_pool: variable pool
         :return:
         """
         node_data = self.node_data
@@ -24,7 +22,7 @@ class EndNode(BaseNode):
 
         outputs = {}
         for variable_selector in output_variables:
-            value = variable_pool.get_any(variable_selector.value_selector)
+            value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
             outputs[variable_selector.variable] = value
 
         return NodeRunResult(
@@ -34,52 +32,16 @@ class EndNode(BaseNode):
         )
 
     @classmethod
-    def extract_generate_nodes(cls, graph: dict, config: dict) -> list[str]:
-        """
-        Extract generate nodes
-        :param graph: graph
-        :param config: node config
-        :return:
-        """
-        node_data = cls._node_data_cls(**config.get("data", {}))
-        node_data = cast(EndNodeData, node_data)
-
-        return cls.extract_generate_nodes_from_node_data(graph, node_data)
-
-    @classmethod
-    def extract_generate_nodes_from_node_data(cls, graph: dict, node_data: EndNodeData) -> list[str]:
-        """
-        Extract generate nodes from node data
-        :param graph: graph
-        :param node_data: node data object
-        :return:
-        """
-        nodes = graph.get('nodes', [])
-        node_mapping = {node.get('id'): node for node in nodes}
-
-        variable_selectors = node_data.outputs
-
-        generate_nodes = []
-        for variable_selector in variable_selectors:
-            if not variable_selector.value_selector:
-                continue
-
-            node_id = variable_selector.value_selector[0]
-            if node_id != 'sys' and node_id in node_mapping:
-                node = node_mapping[node_id]
-                node_type = node.get('data', {}).get('type')
-                if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text':
-                    generate_nodes.append(node_id)
-
-        # remove duplicates
-        generate_nodes = list(set(generate_nodes))
-
-        return generate_nodes
-
-    @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls, 
+        graph_config: Mapping[str, Any], 
+        node_id: str,
+        node_data: EndNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """

+ 148 - 0
api/core/workflow/nodes/end/end_stream_generate_router.py

@@ -0,0 +1,148 @@
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
+
+
+class EndStreamGeneratorRouter:
+
+    @classmethod
+    def init(cls,
+             node_id_config_mapping: dict[str, dict],
+             reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
+             node_parallel_mapping: dict[str, str]
+             ) -> EndStreamParam:
+        """
+        Get stream generate routes.
+        :return:
+        """
+        # parse stream output node value selector of end nodes
+        end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
+        for end_node_id, node_config in node_id_config_mapping.items():
+            if not node_config.get('data', {}).get('type') == NodeType.END.value:
+                continue
+
+            # skip end node in parallel
+            if end_node_id in node_parallel_mapping:
+                continue
+
+            # get generate route for stream output
+            stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config)
+            end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors
+
+        # fetch end dependencies
+        end_node_ids = list(end_stream_variable_selectors_mapping.keys())
+        end_dependencies = cls._fetch_ends_dependencies(
+            end_node_ids=end_node_ids,
+            reverse_edge_mapping=reverse_edge_mapping,
+            node_id_config_mapping=node_id_config_mapping
+        )
+
+        return EndStreamParam(
+            end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping,
+            end_dependencies=end_dependencies
+        )
+
+    @classmethod
+    def extract_stream_variable_selector_from_node_data(cls,
+                                                        node_id_config_mapping: dict[str, dict],
+                                                        node_data: EndNodeData) -> list[list[str]]:
+        """
+        Extract stream variable selector from node data
+        :param node_id_config_mapping: node id config mapping
+        :param node_data: node data object
+        :return:
+        """
+        variable_selectors = node_data.outputs
+
+        value_selectors = []
+        for variable_selector in variable_selectors:
+            if not variable_selector.value_selector:
+                continue
+
+            node_id = variable_selector.value_selector[0]
+            if node_id != 'sys' and node_id in node_id_config_mapping:
+                node = node_id_config_mapping[node_id]
+                node_type = node.get('data', {}).get('type')
+                if (
+                    variable_selector.value_selector not in value_selectors
+                    and node_type == NodeType.LLM.value 
+                    and variable_selector.value_selector[1] == 'text'
+                ):
+                    value_selectors.append(variable_selector.value_selector)
+
+        return value_selectors
+
+    @classmethod
+    def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \
+            -> list[list[str]]:
+        """
+        Extract stream variable selector from node config
+        :param node_id_config_mapping: node id config mapping
+        :param config: node config
+        :return:
+        """
+        node_data = EndNodeData(**config.get("data", {}))
+        return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data)
+
+    @classmethod
+    def _fetch_ends_dependencies(cls,
+                                 end_node_ids: list[str],
+                                 reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
+                                 node_id_config_mapping: dict[str, dict]
+                                 ) -> dict[str, list[str]]:
+        """
+        Fetch end dependencies
+        :param end_node_ids: end node ids
+        :param reverse_edge_mapping: reverse edge mapping
+        :param node_id_config_mapping: node id config mapping
+        :return:
+        """
+        end_dependencies: dict[str, list[str]] = {}
+        for end_node_id in end_node_ids:
+            if end_dependencies.get(end_node_id) is None:
+                end_dependencies[end_node_id] = []
+
+            cls._recursive_fetch_end_dependencies(
+                current_node_id=end_node_id,
+                end_node_id=end_node_id,
+                node_id_config_mapping=node_id_config_mapping,
+                reverse_edge_mapping=reverse_edge_mapping,
+                end_dependencies=end_dependencies
+            )
+
+        return end_dependencies
+
+    @classmethod
+    def _recursive_fetch_end_dependencies(cls,
+                                          current_node_id: str,
+                                          end_node_id: str,
+                                          node_id_config_mapping: dict[str, dict],
+                                          reverse_edge_mapping: dict[str, list["GraphEdge"]],
+                                          # type: ignore[name-defined]
+                                          end_dependencies: dict[str, list[str]]
+                                          ) -> None:
+        """
+        Recursive fetch end dependencies
+        :param current_node_id: current node id
+        :param end_node_id: end node id
+        :param node_id_config_mapping: node id config mapping
+        :param reverse_edge_mapping: reverse edge mapping
+        :param end_dependencies: end dependencies
+        :return:
+        """
+        reverse_edges = reverse_edge_mapping.get(current_node_id, [])
+        for edge in reverse_edges:
+            source_node_id = edge.source_node_id
+            source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
+            if source_node_type in (
+                    NodeType.IF_ELSE.value,
+                    NodeType.QUESTION_CLASSIFIER,
+            ):
+                end_dependencies[end_node_id].append(source_node_id)
+            else:
+                cls._recursive_fetch_end_dependencies(
+                    current_node_id=source_node_id,
+                    end_node_id=end_node_id,
+                    node_id_config_mapping=node_id_config_mapping,
+                    reverse_edge_mapping=reverse_edge_mapping,
+                    end_dependencies=end_dependencies
+                )

+ 191 - 0
api/core/workflow/nodes/end/end_stream_processor.py

@@ -0,0 +1,191 @@
+import logging
+from collections.abc import Generator
+
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine.entities.event import (
+    GraphEngineEvent,
+    NodeRunStartedEvent,
+    NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
+)
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
+
+logger = logging.getLogger(__name__)
+
+
+class EndStreamProcessor(StreamProcessor):
+
+    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
+        super().__init__(graph, variable_pool)
+        self.end_stream_param = graph.end_stream_param
+        self.route_position = {}
+        for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
+            self.route_position[end_node_id] = 0
+        self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
+        self.has_outputed = False
+        self.outputed_node_ids = set()
+
+    def process(self,
+                generator: Generator[GraphEngineEvent, None, None]
+                ) -> Generator[GraphEngineEvent, None, None]:
+        for event in generator:
+            if isinstance(event, NodeRunStartedEvent):
+                if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
+                    self.reset()
+
+                yield event
+            elif isinstance(event, NodeRunStreamChunkEvent):
+                if event.in_iteration_id:
+                    if self.has_outputed and event.node_id not in self.outputed_node_ids:
+                        event.chunk_content = '\n' + event.chunk_content
+
+                    self.outputed_node_ids.add(event.node_id)
+                    self.has_outputed = True
+                    yield event
+                    continue
+
+                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
+                    stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
+                        event.route_node_state.node_id
+                    ]
+                else:
+                    stream_out_end_node_ids = self._get_stream_out_end_node_ids(event)
+                    self.current_stream_chunk_generating_node_ids[
+                        event.route_node_state.node_id
+                    ] = stream_out_end_node_ids
+
+                if stream_out_end_node_ids:
+                    if self.has_outputed and event.node_id not in self.outputed_node_ids:
+                        event.chunk_content = '\n' + event.chunk_content
+
+                    self.outputed_node_ids.add(event.node_id)
+                    self.has_outputed = True
+                    yield event
+            elif isinstance(event, NodeRunSucceededEvent):
+                yield event
+                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
+                    # update self.route_position after all stream event finished
+                    for end_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
+                        self.route_position[end_node_id] += 1
+
+                    del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
+
+                # remove unreachable nodes
+                self._remove_unreachable_nodes(event)
+
+                # generate stream outputs
+                yield from self._generate_stream_outputs_when_node_finished(event)
+            else:
+                yield event
+
+    def reset(self) -> None:
+        self.route_position = {}
+        for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
+            self.route_position[end_node_id] = 0
+        self.rest_node_ids = self.graph.node_ids.copy()
+        self.current_stream_chunk_generating_node_ids = {}
+
+    def _generate_stream_outputs_when_node_finished(self,
+                                                    event: NodeRunSucceededEvent
+                                                    ) -> Generator[GraphEngineEvent, None, None]:
+        """
+        Generate stream outputs.
+        :param event: node run succeeded event
+        :return:
+        """
+        for end_node_id, position in self.route_position.items():
+            # all depends on end node id not in rest node ids
+            if (event.route_node_state.node_id != end_node_id
+                    and (end_node_id not in self.rest_node_ids
+                         or not all(dep_id not in self.rest_node_ids
+                                    for dep_id in self.end_stream_param.end_dependencies[end_node_id]))):
+                continue
+
+            route_position = self.route_position[end_node_id]
+
+            position = 0
+            value_selectors = []
+            for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
+                if position >= route_position:
+                    value_selectors.append(current_value_selectors)
+
+                position += 1
+
+            for value_selector in value_selectors:
+                if not value_selector:
+                    continue
+
+                value = self.variable_pool.get(
+                    value_selector
+                )
+
+                if value is None:
+                    break
+
+                text = value.markdown
+
+                if text:
+                    current_node_id = value_selector[0]
+                    if self.has_outputed and current_node_id not in self.outputed_node_ids:
+                        text = '\n' + text
+
+                    self.outputed_node_ids.add(current_node_id)
+                    self.has_outputed = True
+                    yield NodeRunStreamChunkEvent(
+                        id=event.id,
+                        node_id=event.node_id,
+                        node_type=event.node_type,
+                        node_data=event.node_data,
+                        chunk_content=text,
+                        from_variable_selector=value_selector,
+                        route_node_state=event.route_node_state,
+                        parallel_id=event.parallel_id,
+                        parallel_start_node_id=event.parallel_start_node_id,
+                    )
+
+                self.route_position[end_node_id] += 1
+
+    def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
+        """
+        Is stream out support
+        :param event: queue text chunk event
+        :return:
+        """
+        if not event.from_variable_selector:
+            return []
+
+        stream_output_value_selector = event.from_variable_selector
+        if not stream_output_value_selector:
+            return []
+
+        stream_out_end_node_ids = []
+        for end_node_id, route_position in self.route_position.items():
+            if end_node_id not in self.rest_node_ids:
+                continue
+
+            # all depends on end node id not in rest node ids
+            if all(dep_id not in self.rest_node_ids
+                   for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
+                if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]):
+                    continue
+
+                position = 0
+                value_selector = None
+                for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
+                    if position == route_position:
+                        value_selector = current_value_selectors
+                        break
+
+                    position += 1
+                    
+                if not value_selector:
+                    continue
+
+                # check chunk node id is before current node id or equal to current node id
+                if value_selector != stream_output_value_selector:
+                    continue
+
+                stream_out_end_node_ids.append(end_node_id)
+
+        return stream_out_end_node_ids

+ 16 - 0
api/core/workflow/nodes/end/entities.py

@@ -1,3 +1,5 @@
+from pydantic import BaseModel, Field
+
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.variable_entities import VariableSelector
 
@@ -7,3 +9,17 @@ class EndNodeData(BaseNodeData):
     END Node Data.
     """
     outputs: list[VariableSelector]
+
+
+class EndStreamParam(BaseModel):
+    """
+    EndStreamParam entity
+    """
+    end_dependencies: dict[str, list[str]] = Field(
+        ...,
+        description="end dependencies (end node id -> dependent node ids)"
+    )
+    end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field(
+        ...,
+        description="end stream variable selector mapping (end node id -> stream variable selectors)"
+    )

+ 20 - 0
api/core/workflow/nodes/event.py

@@ -0,0 +1,20 @@
+from pydantic import BaseModel, Field
+
+from core.workflow.entities.node_entities import NodeRunResult
+
+
+class RunCompletedEvent(BaseModel):
+    run_result: NodeRunResult = Field(..., description="run result")
+
+
+class RunStreamChunkEvent(BaseModel):
+    chunk_content: str = Field(..., description="chunk content")
+    from_variable_selector: list[str] = Field(..., description="from variable selector")
+
+
+class RunRetrieverResourceEvent(BaseModel):
+    retriever_resources: list[dict] = Field(..., description="retriever resources")
+    context: str = Field(..., description="context")
+
+
+RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent

+ 19 - 9
api/core/workflow/nodes/http_request/http_request_node.py

@@ -1,15 +1,14 @@
 import logging
+from collections.abc import Mapping, Sequence
 from mimetypes import guess_extension
 from os import path
-from typing import cast
+from typing import Any, cast
 
 from configs import dify_config
 from core.app.segments import parser
 from core.file.file_obj import FileTransferMethod, FileType, FileVar
 from core.tools.tool_file_manager import ToolFileManager
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.http_request.entities import (
     HttpRequestNodeData,
@@ -48,17 +47,22 @@ class HttpRequestNode(BaseNode):
             },
         }
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data)
         # TODO: Switch to use segment directly
         if node_data.authorization.config and node_data.authorization.config.api_key:
-            node_data.authorization.config.api_key = parser.convert_template(template=node_data.authorization.config.api_key, variable_pool=variable_pool).text
+            node_data.authorization.config.api_key = parser.convert_template(
+                template=node_data.authorization.config.api_key, 
+                variable_pool=self.graph_runtime_state.variable_pool
+                ).text
 
         # init http executor
         http_executor = None
         try:
             http_executor = HttpExecutor(
-                node_data=node_data, timeout=self._get_request_timeout(node_data), variable_pool=variable_pool
+                node_data=node_data,
+                timeout=self._get_request_timeout(node_data),
+                variable_pool=self.graph_runtime_state.variable_pool
             )
 
             # invoke http executor
@@ -102,13 +106,19 @@ class HttpRequestNode(BaseNode):
         return timeout
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls, 
+        graph_config: Mapping[str, Any], 
+        node_id: str,
+        node_data: HttpRequestNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """
-        node_data = cast(HttpRequestNodeData, node_data)
         try:
             http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT)
 
@@ -116,7 +126,7 @@ class HttpRequestNode(BaseNode):
 
             variable_mapping = {}
             for variable_selector in variable_selectors:
-                variable_mapping[variable_selector.variable] = variable_selector.value_selector
+                variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector
 
             return variable_mapping
         except Exception as e:

+ 1 - 14
api/core/workflow/nodes/if_else/entities.py

@@ -3,20 +3,7 @@ from typing import Literal, Optional
 from pydantic import BaseModel
 
 from core.workflow.entities.base_node_data_entities import BaseNodeData
-
-
-class Condition(BaseModel):
-    """
-    Condition entity
-    """
-    variable_selector: list[str]
-    comparison_operator: Literal[
-        # for string or array
-        "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "regex match",
-            # for number
-        "=", "≠", ">", "<", "≥", "≤", "null", "not null"
-    ]
-    value: Optional[str] = None
+from core.workflow.utils.condition.entities import Condition
 
 
 class IfElseNodeData(BaseNodeData):

+ 25 - 380
api/core/workflow/nodes/if_else/if_else_node.py

@@ -1,13 +1,10 @@
-import re
-from collections.abc import Sequence
-from typing import Optional, cast
+from collections.abc import Mapping, Sequence
+from typing import Any, cast
 
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.if_else.entities import Condition, IfElseNodeData
-from core.workflow.utils.variable_template_parser import VariableTemplateParser
+from core.workflow.nodes.if_else.entities import IfElseNodeData
+from core.workflow.utils.condition.processor import ConditionProcessor
 from models.workflow import WorkflowNodeExecutionStatus
 
 
@@ -15,31 +12,35 @@ class IfElseNode(BaseNode):
     _node_data_cls = IfElseNodeData
     _node_type = NodeType.IF_ELSE
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         """
         Run node
-        :param variable_pool: variable pool
         :return:
         """
         node_data = self.node_data
         node_data = cast(IfElseNodeData, node_data)
 
-        node_inputs = {
+        node_inputs: dict[str, list] = {
             "conditions": []
         }
 
-        process_datas = {
+        process_datas: dict[str, list] = {
             "condition_results": []
         }
 
         input_conditions = []
         final_result = False
         selected_case_id = None
+        condition_processor = ConditionProcessor()
         try:
             # Check if the new cases structure is used
             if node_data.cases:
                 for case in node_data.cases:
-                    input_conditions, group_result = self.process_conditions(variable_pool, case.conditions)
+                    input_conditions, group_result = condition_processor.process_conditions(
+                        variable_pool=self.graph_runtime_state.variable_pool,
+                        conditions=case.conditions
+                    )
+
                     # Apply the logical operator for the current case
                     final_result = all(group_result) if case.logical_operator == "and" else any(group_result)
 
@@ -58,7 +59,10 @@ class IfElseNode(BaseNode):
 
             else:
                 # Fallback to old structure if cases are not defined
-                input_conditions, group_result = self.process_conditions(variable_pool, node_data.conditions)
+                input_conditions, group_result = condition_processor.process_conditions(
+                    variable_pool=self.graph_runtime_state.variable_pool,
+                    conditions=node_data.conditions
+                )
 
                 final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result)
 
@@ -94,376 +98,17 @@ class IfElseNode(BaseNode):
 
         return data
 
-    def evaluate_condition(
-        self, actual_value: Optional[str | list], expected_value: str, comparison_operator: str
-    ) -> bool:
-        """
-        Evaluate condition
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :param comparison_operator: comparison operator
-
-        :return: bool
-        """
-        if comparison_operator == "contains":
-            return self._assert_contains(actual_value, expected_value)
-        elif comparison_operator == "not contains":
-            return self._assert_not_contains(actual_value, expected_value)
-        elif comparison_operator == "start with":
-            return self._assert_start_with(actual_value, expected_value)
-        elif comparison_operator == "end with":
-            return self._assert_end_with(actual_value, expected_value)
-        elif comparison_operator == "is":
-            return self._assert_is(actual_value, expected_value)
-        elif comparison_operator == "is not":
-            return self._assert_is_not(actual_value, expected_value)
-        elif comparison_operator == "empty":
-            return self._assert_empty(actual_value)
-        elif comparison_operator == "not empty":
-            return self._assert_not_empty(actual_value)
-        elif comparison_operator == "=":
-            return self._assert_equal(actual_value, expected_value)
-        elif comparison_operator == "≠":
-            return self._assert_not_equal(actual_value, expected_value)
-        elif comparison_operator == ">":
-            return self._assert_greater_than(actual_value, expected_value)
-        elif comparison_operator == "<":
-            return self._assert_less_than(actual_value, expected_value)
-        elif comparison_operator == "≥":
-            return self._assert_greater_than_or_equal(actual_value, expected_value)
-        elif comparison_operator == "≤":
-            return self._assert_less_than_or_equal(actual_value, expected_value)
-        elif comparison_operator == "null":
-            return self._assert_null(actual_value)
-        elif comparison_operator == "not null":
-            return self._assert_not_null(actual_value)
-        elif comparison_operator == "regex match":
-            return self._assert_regex_match(actual_value, expected_value)
-        else:
-            raise ValueError(f"Invalid comparison operator: {comparison_operator}")
-
-    def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]):
-        input_conditions = []
-        group_result = []
-
-        for condition in conditions:
-            actual_variable = variable_pool.get_any(condition.variable_selector)
-
-            if condition.value is not None:
-                variable_template_parser = VariableTemplateParser(template=condition.value)
-                expected_value = variable_template_parser.extract_variable_selectors()
-                variable_selectors = variable_template_parser.extract_variable_selectors()
-                if variable_selectors:
-                    for variable_selector in variable_selectors:
-                        value = variable_pool.get_any(variable_selector.value_selector)
-                        expected_value = variable_template_parser.format({variable_selector.variable: value})
-                else:
-                    expected_value = condition.value
-            else:
-                expected_value = None
-
-            comparison_operator = condition.comparison_operator
-            input_conditions.append(
-                {
-                    "actual_value": actual_variable,
-                    "expected_value": expected_value,
-                    "comparison_operator": comparison_operator
-                }
-            )
-
-            result = self.evaluate_condition(actual_variable, expected_value, comparison_operator)
-            group_result.append(result)
-
-        return input_conditions, group_result
-
-    def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
-        """
-        Assert contains
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if not actual_value:
-            return False
-
-        if not isinstance(actual_value, str | list):
-            raise ValueError('Invalid actual value type: string or array')
-
-        if expected_value not in actual_value:
-            return False
-        return True
-
-    def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
-        """
-        Assert not contains
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if not actual_value:
-            return True
-
-        if not isinstance(actual_value, str | list):
-            raise ValueError('Invalid actual value type: string or array')
-
-        if expected_value in actual_value:
-            return False
-        return True
-
-    def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
-        """
-        Assert start with
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if not actual_value:
-            return False
-
-        if not isinstance(actual_value, str):
-            raise ValueError('Invalid actual value type: string')
-
-        if not actual_value.startswith(expected_value):
-            return False
-        return True
-
-    def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
-        """
-        Assert end with
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if not actual_value:
-            return False
-
-        if not isinstance(actual_value, str):
-            raise ValueError('Invalid actual value type: string')
-
-        if not actual_value.endswith(expected_value):
-            return False
-        return True
-
-    def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
-        """
-        Assert is
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if actual_value is None:
-            return False
-
-        if not isinstance(actual_value, str):
-            raise ValueError('Invalid actual value type: string')
-
-        if actual_value != expected_value:
-            return False
-        return True
-
-    def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
-        """
-        Assert is not
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if actual_value is None:
-            return False
-
-        if not isinstance(actual_value, str):
-            raise ValueError('Invalid actual value type: string')
-
-        if actual_value == expected_value:
-            return False
-        return True
-
-    def _assert_empty(self, actual_value: Optional[str]) -> bool:
-        """
-        Assert empty
-        :param actual_value: actual value
-        :return:
-        """
-        if not actual_value:
-            return True
-        return False
-
-    def _assert_regex_match(self, actual_value: Optional[str], expected_value: str) -> bool:
-        """
-        Assert empty
-        :param actual_value: actual value
-        :return:
-        """
-        if actual_value is None:
-            return False
-        return re.search(expected_value, actual_value) is not None
-
-    def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
-        """
-        Assert not empty
-        :param actual_value: actual value
-        :return:
-        """
-        if actual_value:
-            return True
-        return False
-
-    def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
-        """
-        Assert equal
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if actual_value is None:
-            return False
-
-        if not isinstance(actual_value, int | float):
-            raise ValueError('Invalid actual value type: number')
-
-        if isinstance(actual_value, int):
-            expected_value = int(expected_value)
-        else:
-            expected_value = float(expected_value)
-
-        if actual_value != expected_value:
-            return False
-        return True
-
-    def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
-        """
-        Assert not equal
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if actual_value is None:
-            return False
-
-        if not isinstance(actual_value, int | float):
-            raise ValueError('Invalid actual value type: number')
-
-        if isinstance(actual_value, int):
-            expected_value = int(expected_value)
-        else:
-            expected_value = float(expected_value)
-
-        if actual_value == expected_value:
-            return False
-        return True
-
-    def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
-        """
-        Assert greater than
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if actual_value is None:
-            return False
-
-        if not isinstance(actual_value, int | float):
-            raise ValueError('Invalid actual value type: number')
-
-        if isinstance(actual_value, int):
-            expected_value = int(expected_value)
-        else:
-            expected_value = float(expected_value)
-
-        if actual_value <= expected_value:
-            return False
-        return True
-
-    def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
-        """
-        Assert less than
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if actual_value is None:
-            return False
-
-        if not isinstance(actual_value, int | float):
-            raise ValueError('Invalid actual value type: number')
-
-        if isinstance(actual_value, int):
-            expected_value = int(expected_value)
-        else:
-            expected_value = float(expected_value)
-
-        if actual_value >= expected_value:
-            return False
-        return True
-
-    def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
-        """
-        Assert greater than or equal
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if actual_value is None:
-            return False
-
-        if not isinstance(actual_value, int | float):
-            raise ValueError('Invalid actual value type: number')
-
-        if isinstance(actual_value, int):
-            expected_value = int(expected_value)
-        else:
-            expected_value = float(expected_value)
-
-        if actual_value < expected_value:
-            return False
-        return True
-
-    def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
-        """
-        Assert less than or equal
-        :param actual_value: actual value
-        :param expected_value: expected value
-        :return:
-        """
-        if actual_value is None:
-            return False
-
-        if not isinstance(actual_value, int | float):
-            raise ValueError('Invalid actual value type: number')
-
-        if isinstance(actual_value, int):
-            expected_value = int(expected_value)
-        else:
-            expected_value = float(expected_value)
-
-        if actual_value > expected_value:
-            return False
-        return True
-
-    def _assert_null(self, actual_value: Optional[int | float]) -> bool:
-        """
-        Assert null
-        :param actual_value: actual value
-        :return:
-        """
-        if actual_value is None:
-            return True
-        return False
-
-    def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
-        """
-        Assert not null
-        :param actual_value: actual value
-        :return:
-        """
-        if actual_value is not None:
-            return True
-        return False
-
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls, 
+        graph_config: Mapping[str, Any], 
+        node_id: str,
+        node_data: IfElseNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """

+ 8 - 1
api/core/workflow/nodes/iteration/entities.py

@@ -1,6 +1,6 @@
 from typing import Any, Optional
 
-from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
+from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
 
 
 class IterationNodeData(BaseIterationNodeData):
@@ -11,6 +11,13 @@ class IterationNodeData(BaseIterationNodeData):
     iterator_selector: list[str] # variable selector
     output_selector: list[str] # output selector
 
+
+class IterationStartNodeData(BaseNodeData):
+    """
+    Iteration Start Node Data.
+    """
+    pass
+
 class IterationState(BaseIterationState):
     """
     Iteration State.

+ 338 - 91
api/core/workflow/nodes/iteration/iteration_node.py

@@ -1,124 +1,371 @@
-from typing import cast
+import logging
+from collections.abc import Generator, Mapping, Sequence
+from datetime import datetime, timezone
+from typing import Any, cast
 
+from configs import dify_config
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.workflow.entities.base_node_data_entities import BaseIterationState
-from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.base_node import BaseIterationNode
-from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
+from core.workflow.graph_engine.entities.event import (
+    BaseGraphEvent,
+    BaseNodeEvent,
+    BaseParallelBranchEvent,
+    GraphRunFailedEvent,
+    InNodeEvent,
+    IterationRunFailedEvent,
+    IterationRunNextEvent,
+    IterationRunStartedEvent,
+    IterationRunSucceededEvent,
+    NodeRunSucceededEvent,
+)
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.run_condition import RunCondition
+from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.event import RunCompletedEvent, RunEvent
+from core.workflow.nodes.iteration.entities import IterationNodeData
+from core.workflow.utils.condition.entities import Condition
 from models.workflow import WorkflowNodeExecutionStatus
 
+logger = logging.getLogger(__name__)
 
-class IterationNode(BaseIterationNode):
+
+class IterationNode(BaseNode):
     """
     Iteration Node.
     """
     _node_data_cls = IterationNodeData
     _node_type = NodeType.ITERATION
 
-    def _run(self, variable_pool: VariablePool) -> BaseIterationState:
+    def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
         """
         Run the node.
         """
         self.node_data = cast(IterationNodeData, self.node_data)
-        iterator = variable_pool.get_any(self.node_data.iterator_selector)
-
-        if not isinstance(iterator, list):
-            raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.")
+        iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
 
-        state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={
-            'iterator_selector': iterator
-        }, outputs=[], metadata=IterationState.MetaData(
-            iterator_length=len(iterator) if iterator is not None else 0
-        ))
+        if not iterator_list_segment:
+            raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found")
         
-        self._set_current_iteration_variable(variable_pool, state)
-        return state
+        iterator_list_value = iterator_list_segment.to_object()
 
-    def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str:
-        """
-        Get next iteration start node id based on the graph.
-        :param graph: graph
-        :return: next node id
-        """
-        # resolve current output
-        self._resolve_current_output(variable_pool, state)
-        # move to next iteration
-        self._next_iteration(variable_pool, state)
-
-        node_data = cast(IterationNodeData, self.node_data)
-        if self._reached_iteration_limit(variable_pool, state):
-            return NodeRunResult(
-                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+        if not isinstance(iterator_list_value, list):
+            raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
+
+        inputs = {
+            "iterator_selector": iterator_list_value
+        }
+
+        graph_config = self.graph_config
+    
+        if not self.node_data.start_node_id:
+            raise ValueError(f'field start_node_id in iteration {self.node_id} not found')
+
+        root_node_id = self.node_data.start_node_id
+
+        # init graph
+        iteration_graph = Graph.init(
+            graph_config=graph_config,
+            root_node_id=root_node_id
+        )
+
+        if not iteration_graph:
+            raise ValueError('iteration graph not found')
+
+        leaf_node_ids = iteration_graph.get_leaf_node_ids()
+        iteration_leaf_node_ids = []
+        for leaf_node_id in leaf_node_ids:
+            node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id)
+            if not node_config:
+                continue
+
+            leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id")
+            if not leaf_node_iteration_id:
+                continue
+
+            if leaf_node_iteration_id != self.node_id:
+                continue
+
+            iteration_leaf_node_ids.append(leaf_node_id)
+
+            # add condition of end nodes to root node
+            iteration_graph.add_extra_edge(
+                source_node_id=leaf_node_id,
+                target_node_id=root_node_id,
+                run_condition=RunCondition(
+                    type="condition",
+                    conditions=[
+                        Condition(
+                            variable_selector=[self.node_id, "index"],
+                            comparison_operator="<",
+                            value=str(len(iterator_list_value))
+                        )
+                    ]
+                )
+            )
+
+        variable_pool = self.graph_runtime_state.variable_pool
+
+        # append iteration variable (item, index) to variable pool
+        variable_pool.add(
+            [self.node_id, 'index'],
+            0
+        )
+        variable_pool.add(
+            [self.node_id, 'item'],
+            iterator_list_value[0]
+        )
+
+        # init graph engine
+        from core.workflow.graph_engine.graph_engine import GraphEngine
+        graph_engine = GraphEngine(
+            tenant_id=self.tenant_id,
+            app_id=self.app_id,
+            workflow_type=self.workflow_type,
+            workflow_id=self.workflow_id,
+            user_id=self.user_id,
+            user_from=self.user_from,
+            invoke_from=self.invoke_from,
+            call_depth=self.workflow_call_depth,
+            graph=iteration_graph,
+            graph_config=graph_config,
+            variable_pool=variable_pool,
+            max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
+            max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
+        )
+
+        start_at = datetime.now(timezone.utc).replace(tzinfo=None)
+
+        yield IterationRunStartedEvent(
+            iteration_id=self.id,
+            iteration_node_id=self.node_id,
+            iteration_node_type=self.node_type,
+            iteration_node_data=self.node_data,
+            start_at=start_at,
+            inputs=inputs,
+            metadata={
+                "iterator_length": len(iterator_list_value)
+            },
+            predecessor_node_id=self.previous_node_id
+        )
+
+        yield IterationRunNextEvent(
+            iteration_id=self.id,
+            iteration_node_id=self.node_id,
+            iteration_node_type=self.node_type,
+            iteration_node_data=self.node_data,
+            index=0,
+            pre_iteration_output=None
+        )
+
+        outputs: list[Any] = []
+        try:
+            # run workflow
+            rst = graph_engine.run()
+            for event in rst:
+                if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
+                    event.in_iteration_id = self.node_id
+
+                if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START:
+                    continue
+
+                if isinstance(event, NodeRunSucceededEvent):
+                    if event.route_node_state.node_run_result:
+                        metadata = event.route_node_state.node_run_result.metadata
+                        if not metadata:
+                            metadata = {}
+
+                        if NodeRunMetadataKey.ITERATION_ID not in metadata:
+                            metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
+                            metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
+                            event.route_node_state.node_run_result.metadata = metadata
+
+                    yield event
+
+                    # handle iteration run result
+                    if event.route_node_state.node_id in iteration_leaf_node_ids:
+                        # append to iteration output variable list
+                        current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
+                        outputs.append(current_iteration_output)
+
+                        # remove all nodes outputs from variable pool
+                        for node_id in iteration_graph.node_ids:
+                            variable_pool.remove_node(node_id)
+
+                        # move to next iteration
+                        current_index = variable_pool.get([self.node_id, 'index'])
+                        if current_index is None:
+                            raise ValueError(f'iteration {self.node_id} current index not found')
+
+                        next_index = int(current_index.to_object()) + 1
+                        variable_pool.add(
+                            [self.node_id, 'index'],
+                            next_index
+                        )
+
+                        if next_index < len(iterator_list_value):
+                            variable_pool.add(
+                                [self.node_id, 'item'],
+                                iterator_list_value[next_index]
+                            )
+
+                        yield IterationRunNextEvent(
+                            iteration_id=self.id,
+                            iteration_node_id=self.node_id,
+                            iteration_node_type=self.node_type,
+                            iteration_node_data=self.node_data,
+                            index=next_index,
+                            pre_iteration_output=jsonable_encoder(
+                                current_iteration_output) if current_iteration_output else None
+                        )
+                elif isinstance(event, BaseGraphEvent):
+                    if isinstance(event, GraphRunFailedEvent):
+                        # iteration run failed
+                        yield IterationRunFailedEvent(
+                            iteration_id=self.id,
+                            iteration_node_id=self.node_id,
+                            iteration_node_type=self.node_type,
+                            iteration_node_data=self.node_data,
+                            start_at=start_at,
+                            inputs=inputs,
+                            outputs={
+                                "output": jsonable_encoder(outputs)
+                            },
+                            steps=len(iterator_list_value),
+                            metadata={
+                                "total_tokens": graph_engine.graph_runtime_state.total_tokens
+                            },
+                            error=event.error,
+                        )
+
+                        yield RunCompletedEvent(
+                            run_result=NodeRunResult(
+                                status=WorkflowNodeExecutionStatus.FAILED,
+                                error=event.error,
+                            )
+                        )
+                        break
+                else:
+                    event = cast(InNodeEvent, event)
+                    yield event
+
+            yield IterationRunSucceededEvent(
+                iteration_id=self.id,
+                iteration_node_id=self.node_id,
+                iteration_node_type=self.node_type,
+                iteration_node_data=self.node_data,
+                start_at=start_at,
+                inputs=inputs,
                 outputs={
-                    'output': jsonable_encoder(state.outputs)
+                    "output": jsonable_encoder(outputs)
+                },
+                steps=len(iterator_list_value),
+                metadata={
+                    "total_tokens": graph_engine.graph_runtime_state.total_tokens
                 }
             )
+
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                    outputs={
+                        'output': jsonable_encoder(outputs)
+                    }
+                )
+            )
+        except Exception as e:
+            # iteration run failed
+            logger.exception("Iteration run failed")
+            yield IterationRunFailedEvent(
+                iteration_id=self.id,
+                iteration_node_id=self.node_id,
+                iteration_node_type=self.node_type,
+                iteration_node_data=self.node_data,
+                start_at=start_at,
+                inputs=inputs,
+                outputs={
+                    "output": jsonable_encoder(outputs)
+                },
+                steps=len(iterator_list_value),
+                metadata={
+                    "total_tokens": graph_engine.graph_runtime_state.total_tokens
+                },
+                error=str(e),
+            )
         
-        return node_data.start_node_id
+
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    error=str(e),
+                )
+            )
+        finally:
+            # remove iteration variable (item, index) from variable pool after iteration run completed
+            variable_pool.remove([self.node_id, 'index'])
+            variable_pool.remove([self.node_id, 'item'])
     
-    def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState):
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(
+        cls, 
+        graph_config: Mapping[str, Any], 
+        node_id: str,
+        node_data: IterationNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
-        Set current iteration variable.
-        :variable_pool: variable pool
+        Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
+        :return:
         """
-        node_data = cast(IterationNodeData, self.node_data)
+        variable_mapping = {
+            f'{node_id}.input_selector': node_data.iterator_selector,
+        }
 
-        variable_pool.add((self.node_id, 'index'), state.index)
-        # get the iterator value
-        iterator = variable_pool.get_any(node_data.iterator_selector)
+        # init graph
+        iteration_graph = Graph.init(
+            graph_config=graph_config,
+            root_node_id=node_data.start_node_id
+        )
 
-        if iterator is None or not isinstance(iterator, list):
-            return
+        if not iteration_graph:
+            raise ValueError('iteration graph not found')
         
-        if state.index < len(iterator):
-            variable_pool.add((self.node_id, 'item'), iterator[state.index])
+        for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items():
+            if sub_node_config.get('data', {}).get('iteration_id') != node_id:
+                continue
 
-    def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
-        """
-        Move to next iteration.
-        :param variable_pool: variable pool
-        """
-        state.index += 1
-        self._set_current_iteration_variable(variable_pool, state)
+            # variable selector to variable mapping
+            try:
+                # Get node class
+                from core.workflow.nodes.node_mapping import node_classes
+                node_type = NodeType.value_of(sub_node_config.get('data', {}).get('type'))
+                node_cls = node_classes.get(node_type)
+                if not node_cls:
+                    continue
 
-    def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState):
-        """
-        Check if iteration limit is reached.
-        :return: True if iteration limit is reached, False otherwise
-        """
-        node_data = cast(IterationNodeData, self.node_data)
-        iterator =  variable_pool.get_any(node_data.iterator_selector)
+                node_cls = cast(BaseNode, node_cls)
+                
+                sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
+                    graph_config=graph_config, 
+                    config=sub_node_config
+                )
+                sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
+            except NotImplementedError:
+                sub_node_variable_mapping = {}
 
-        if iterator is None or not isinstance(iterator, list):
-            return True
+            # remove iteration variables
+            sub_node_variable_mapping = {
+                sub_node_id + '.' + key: value for key, value in sub_node_variable_mapping.items()
+                if value[0] != node_id
+            }
 
-        return state.index >= len(iterator)
-    
-    def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState):
-        """
-        Resolve current output.
-        :param variable_pool: variable pool
-        """
-        output_selector = cast(IterationNodeData, self.node_data).output_selector
-        output = variable_pool.get_any(output_selector)
-        # clear the output for this iteration
-        variable_pool.remove([self.node_id] + output_selector[1:])
-        state.current_output = output
-        if output is not None:
-            # NOTE: This is a temporary patch to process double nested list (for example, DALL-E output in iteration).
-            if isinstance(output, list):
-                state.outputs.extend(output)
-            else:
-                state.outputs.append(output)
+            variable_mapping.update(sub_node_variable_mapping)
 
-    @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]:
-        """
-        Extract variable selector to variable mapping
-        :param node_data: node data
-        :return:
-        """
-        return {
-            'input_selector': node_data.iterator_selector,
-        }
+        # remove variable out from iteration
+        variable_mapping = {
+            key: value for key, value in variable_mapping.items()
+            if value[0] not in iteration_graph.node_ids
+        }
+        
+        return variable_mapping

+ 39 - 0
api/core/workflow/nodes/iteration/iteration_start_node.py

@@ -0,0 +1,39 @@
+from collections.abc import Mapping, Sequence
+from typing import Any
+
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
+from models.workflow import WorkflowNodeExecutionStatus
+
+
+class IterationStartNode(BaseNode):
+    """
+    Iteration Start Node.
+    """
+    _node_data_cls = IterationStartNodeData
+    _node_type = NodeType.ITERATION_START
+
+    def _run(self) -> NodeRunResult:
+        """
+        Run the node.
+        """
+        return NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCEEDED
+        )
+    
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(
+        cls, 
+        graph_config: Mapping[str, Any], 
+        node_id: str,
+        node_data: IterationNodeData
+    ) -> Mapping[str, Sequence[str]]:
+        """
+        Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
+        :return:
+        """
+        return {}

+ 22 - 10
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

@@ -1,3 +1,5 @@
+import logging
+from collections.abc import Mapping, Sequence
 from typing import Any, cast
 
 from sqlalchemy import func
@@ -12,15 +14,15 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
 from models.workflow import WorkflowNodeExecutionStatus
 
+logger = logging.getLogger(__name__)
+
 default_retrieval_model = {
     'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
     'reranking_enable': False,
@@ -37,11 +39,11 @@ class KnowledgeRetrievalNode(BaseNode):
     _node_data_cls = KnowledgeRetrievalNodeData
     node_type = NodeType.KNOWLEDGE_RETRIEVAL
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
-        node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)
+    def _run(self) -> NodeRunResult:
+        node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
 
         # extract variables
-        variable = variable_pool.get_any(node_data.query_variable_selector)
+        variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector)
         query = variable
         variables = {
             'query': query
@@ -68,7 +70,7 @@ class KnowledgeRetrievalNode(BaseNode):
             )
 
         except Exception as e:
-
+            logger.exception("Error when running knowledge retrieval node")
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 inputs=variables,
@@ -235,11 +237,21 @@ class KnowledgeRetrievalNode(BaseNode):
         return context_list
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
-        node_data = node_data
-        node_data = cast(cls._node_data_cls, node_data)
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: KnowledgeRetrievalNodeData
+    ) -> Mapping[str, Sequence[str]]:
+        """
+        Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
+        :return:
+        """
         variable_mapping = {}
-        variable_mapping['query'] = node_data.query_variable_selector
+        variable_mapping[node_id + '.query'] = node_data.query_variable_selector
         return variable_mapping
 
     def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[

+ 112 - 54
api/core/workflow/nodes/llm/llm_node.py

@@ -1,16 +1,17 @@
 import json
-from collections.abc import Generator
+from collections.abc import Generator, Mapping, Sequence
 from copy import deepcopy
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
+
+from pydantic import BaseModel
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
 from core.entities.model_entities import ModelStatus
 from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
@@ -25,7 +26,9 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.event import InNodeEvent
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
 from core.workflow.nodes.llm.entities import (
     LLMNodeChatModelMessage,
     LLMNodeCompletionModelPromptTemplate,
@@ -43,17 +46,26 @@ if TYPE_CHECKING:
 
 
 
+class ModelInvokeCompleted(BaseModel):
+    """
+    Model invoke completed
+    """
+    text: str
+    usage: LLMUsage
+    finish_reason: Optional[str] = None
+
+
 class LLMNode(BaseNode):
     _node_data_cls = LLMNodeData
     _node_type = NodeType.LLM
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
         """
         Run node
-        :param variable_pool: variable pool
         :return:
         """
         node_data = cast(LLMNodeData, deepcopy(self.node_data))
+        variable_pool = self.graph_runtime_state.variable_pool
 
         node_inputs = None
         process_data = None
@@ -80,10 +92,15 @@ class LLMNode(BaseNode):
                 node_inputs['#files#'] = [file.to_dict() for file in files]
 
             # fetch context value
-            context = self._fetch_context(node_data, variable_pool)
+            generator = self._fetch_context(node_data, variable_pool)
+            context = None
+            for event in generator:
+                if isinstance(event, RunRetrieverResourceEvent):
+                    context = event.context
+                    yield event
 
             if context:
-                node_inputs['#context#'] = context
+                node_inputs['#context#'] = context  # type: ignore
 
             # fetch model config
             model_instance, model_config = self._fetch_model_config(node_data.model)
@@ -115,19 +132,34 @@ class LLMNode(BaseNode):
             }
 
             # handle invoke result
-            result_text, usage, finish_reason = self._invoke_llm(
+            generator = self._invoke_llm(
                 node_data_model=node_data.model,
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
                 stop=stop
             )
+
+            result_text = ''
+            usage = LLMUsage.empty_usage()
+            finish_reason = None
+            for event in generator:
+                if isinstance(event, RunStreamChunkEvent):
+                    yield event
+                elif isinstance(event, ModelInvokeCompleted):
+                    result_text = event.text
+                    usage = event.usage
+                    finish_reason = event.finish_reason
+                    break
         except Exception as e:
-            return NodeRunResult(
-                status=WorkflowNodeExecutionStatus.FAILED,
-                error=str(e),
-                inputs=node_inputs,
-                process_data=process_data
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    error=str(e),
+                    inputs=node_inputs,
+                    process_data=process_data
+                )
             )
+            return
 
         outputs = {
             'text': result_text,
@@ -135,22 +167,26 @@ class LLMNode(BaseNode):
             'finish_reason': finish_reason
         }
 
-        return NodeRunResult(
-            status=WorkflowNodeExecutionStatus.SUCCEEDED,
-            inputs=node_inputs,
-            process_data=process_data,
-            outputs=outputs,
-            metadata={
-                NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
-                NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
-                NodeRunMetadataKey.CURRENCY: usage.currency
-            }
+        yield RunCompletedEvent(
+            run_result=NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                inputs=node_inputs,
+                process_data=process_data,
+                outputs=outputs,
+                metadata={
+                    NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
+                    NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
+                    NodeRunMetadataKey.CURRENCY: usage.currency
+                },
+                llm_usage=usage
+            )
         )
 
     def _invoke_llm(self, node_data_model: ModelConfig,
                     model_instance: ModelInstance,
                     prompt_messages: list[PromptMessage],
-                    stop: list[str]) -> tuple[str, LLMUsage]:
+                    stop: Optional[list[str]] = None) \
+            -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
         """
         Invoke large language model
         :param node_data_model: node data model
@@ -170,23 +206,31 @@ class LLMNode(BaseNode):
         )
 
         # handle invoke result
-        text, usage, finish_reason = self._handle_invoke_result(
+        generator = self._handle_invoke_result(
             invoke_result=invoke_result
         )
 
+        usage = LLMUsage.empty_usage()
+        for event in generator:
+            yield event
+            if isinstance(event, ModelInvokeCompleted):
+                usage = event.usage
+
         # deduct quota
         self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
 
-        return text, usage, finish_reason
-
-    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
+    def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
+            -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
         """
         Handle invoke result
         :param invoke_result: invoke result
         :return:
         """
+        if isinstance(invoke_result, LLMResult):
+            return
+
         model = None
-        prompt_messages = []
+        prompt_messages: list[PromptMessage] = []
         full_text = ''
         usage = None
         finish_reason = None
@@ -194,7 +238,10 @@ class LLMNode(BaseNode):
             text = result.delta.message.content
             full_text += text
 
-            self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
+            yield RunStreamChunkEvent(
+                chunk_content=text,
+                from_variable_selector=[self.node_id, 'text']
+            )
 
             if not model:
                 model = result.model
@@ -211,11 +258,15 @@ class LLMNode(BaseNode):
         if not usage:
             usage = LLMUsage.empty_usage()
 
-        return full_text, usage, finish_reason
+        yield ModelInvokeCompleted(
+            text=full_text,
+            usage=usage,
+            finish_reason=finish_reason
+        )
 
     def _transform_chat_messages(self,
-        messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
-    ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
+                                 messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
+                                 ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
         """
         Transform chat messages
 
@@ -224,13 +275,13 @@ class LLMNode(BaseNode):
         """
 
         if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
-            if messages.edition_type == 'jinja2':
+            if messages.edition_type == 'jinja2' and messages.jinja2_text:
                 messages.text = messages.jinja2_text
 
             return messages
 
         for message in messages:
-            if message.edition_type == 'jinja2':
+            if message.edition_type == 'jinja2' and message.jinja2_text:
                 message.text = message.jinja2_text
 
         return messages
@@ -348,7 +399,7 @@ class LLMNode(BaseNode):
 
         return files
 
-    def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]:
+    def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]:
         """
         Fetch context
         :param node_data: node data
@@ -356,15 +407,18 @@ class LLMNode(BaseNode):
         :return:
         """
         if not node_data.context.enabled:
-            return None
+            return
 
         if not node_data.context.variable_selector:
-            return None
+            return
 
         context_value = variable_pool.get_any(node_data.context.variable_selector)
         if context_value:
             if isinstance(context_value, str):
-                return context_value
+                yield RunRetrieverResourceEvent(
+                    retriever_resources=[],
+                    context=context_value
+                )
             elif isinstance(context_value, list):
                 context_str = ''
                 original_retriever_resource = []
@@ -381,17 +435,10 @@ class LLMNode(BaseNode):
                         if retriever_resource:
                             original_retriever_resource.append(retriever_resource)
 
-                if self.callbacks and original_retriever_resource:
-                    for callback in self.callbacks:
-                        callback.on_event(
-                            event=QueueRetrieverResourcesEvent(
-                                retriever_resources=original_retriever_resource
-                            )
-                        )
-
-                return context_str.strip()
-
-        return None
+                yield RunRetrieverResourceEvent(
+                    retriever_resources=original_retriever_resource,
+                    context=context_str.strip()
+                )
 
     def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
         """
@@ -574,7 +621,8 @@ class LLMNode(BaseNode):
             if not isinstance(prompt_message.content, str):
                 prompt_message_content = []
                 for content_item in prompt_message.content:
-                    if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(content_item, ImagePromptMessageContent):
+                    if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(
+                            content_item, ImagePromptMessageContent):
                         # Override vision config if LLM node has vision config
                         if vision_detail:
                             content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)
@@ -646,13 +694,19 @@ class LLMNode(BaseNode):
             db.session.commit()
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: LLMNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """
-
         prompt_template = node_data.prompt_template
 
         variable_selectors = []
@@ -702,6 +756,10 @@ class LLMNode(BaseNode):
                 for variable_selector in node_data.prompt_config.jinja2_variables or []:
                     variable_mapping[variable_selector.variable] = variable_selector.value_selector
 
+        variable_mapping = {
+            node_id + '.' + key: value for key, value in variable_mapping.items()
+        }
+
         return variable_mapping
 
     @classmethod

+ 22 - 8
api/core/workflow/nodes/loop/loop_node.py

@@ -1,20 +1,34 @@
-from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.base_node import BaseIterationNode
+from typing import Any
+
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
+from core.workflow.utils.condition.entities import Condition
 
 
-class LoopNode(BaseIterationNode):
+class LoopNode(BaseNode):
     """
     Loop Node.
     """
     _node_data_cls = LoopNodeData
     _node_type = NodeType.LOOP
 
-    def _run(self, variable_pool: VariablePool) -> LoopState:
-        return super()._run(variable_pool)
+    def _run(self) -> LoopState:
+        return super()._run()
 
-    def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str:
+    @classmethod
+    def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
         """
-        Get next iteration start node id based on the graph.
+        Get conditions.
         """
+        node_id = node_config.get('id')
+        if not node_id:
+            return []
+
+        # TODO waiting for implementation
+        return [Condition(
+            variable_selector=[node_id, 'index'],
+            comparison_operator="≤",
+            value_type="value_selector",
+            value_selector=[]
+        )]

+ 37 - 0
api/core/workflow/nodes/node_mapping.py

@@ -0,0 +1,37 @@
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.nodes.answer.answer_node import AnswerNode
+from core.workflow.nodes.code.code_node import CodeNode
+from core.workflow.nodes.end.end_node import EndNode
+from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
+from core.workflow.nodes.if_else.if_else_node import IfElseNode
+from core.workflow.nodes.iteration.iteration_node import IterationNode
+from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode
+from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
+from core.workflow.nodes.llm.llm_node import LLMNode
+from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
+from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
+from core.workflow.nodes.start.start_node import StartNode
+from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
+from core.workflow.nodes.tool.tool_node import ToolNode
+from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
+from core.workflow.nodes.variable_assigner import VariableAssignerNode
+
+node_classes = {
+    NodeType.START: StartNode,
+    NodeType.END: EndNode,
+    NodeType.ANSWER: AnswerNode,
+    NodeType.LLM: LLMNode,
+    NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
+    NodeType.IF_ELSE: IfElseNode,
+    NodeType.CODE: CodeNode,
+    NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
+    NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
+    NodeType.HTTP_REQUEST: HttpRequestNode,
+    NodeType.TOOL: ToolNode,
+    NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
+    NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode,  # original name of VARIABLE_AGGREGATOR
+    NodeType.ITERATION: IterationNode,
+    NodeType.ITERATION_START: IterationStartNode,
+    NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
+    NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
+}

+ 24 - 11
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py

@@ -1,6 +1,7 @@
 import json
 import uuid
-from typing import Optional, cast
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional, cast
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -66,12 +67,12 @@ class ParameterExtractorNode(LLMNode):
             }
         }
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         """
         Run the node.
         """
         node_data = cast(ParameterExtractorNodeData, self.node_data)
-        variable = variable_pool.get_any(node_data.query)
+        variable = self.graph_runtime_state.variable_pool.get_any(node_data.query)
         if not variable:
             raise ValueError("Input variable content not found or is empty")
         query = variable
@@ -92,17 +93,20 @@ class ParameterExtractorNode(LLMNode):
             raise ValueError("Model schema not found")
 
         # fetch memory
-        memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
+        memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance)
 
         if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \
             and node_data.reasoning_mode == 'function_call':
             # use function call 
             prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
-                node_data, query, variable_pool, model_config, memory
+                node_data, query, self.graph_runtime_state.variable_pool, model_config, memory
             )
         else:
             # use prompt engineering
-            prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config,
+            prompt_messages = self._generate_prompt_engineering_prompt(node_data,
+                                                                       query,
+                                                                       self.graph_runtime_state.variable_pool,
+                                                                       model_config,
                                                                        memory)
             prompt_message_tools = []
 
@@ -172,7 +176,8 @@ class ParameterExtractorNode(LLMNode):
                 NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                 NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
                 NodeRunMetadataKey.CURRENCY: usage.currency
-            }
+            },
+            llm_usage=usage
         )
 
     def _invoke_llm(self, node_data_model: ModelConfig,
@@ -697,15 +702,19 @@ class ParameterExtractorNode(LLMNode):
         return self._model_instance, self._model_config
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[
-        str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls, 
+        graph_config: Mapping[str, Any], 
+        node_id: str,
+        node_data: ParameterExtractorNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """
-        node_data = node_data
-
         variable_mapping = {
             'query': node_data.query
         }
@@ -715,4 +724,8 @@ class ParameterExtractorNode(LLMNode):
             for selector in variable_template_parser.extract_variable_selectors():
                 variable_mapping[selector.variable] = selector.value_selector
 
+        variable_mapping = {
+            node_id + '.' + key: value for key, value in variable_mapping.items()
+        }
+
         return variable_mapping

+ 40 - 10
api/core/workflow/nodes/question_classifier/question_classifier_node.py

@@ -1,10 +1,12 @@
 import json
 import logging
-from typing import Optional, Union, cast
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional, Union, cast
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
+from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -13,10 +15,9 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
 from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.llm.llm_node import LLMNode
+from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted
 from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
 from core.workflow.nodes.question_classifier.template_prompts import (
     QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
@@ -36,9 +37,10 @@ class QuestionClassifierNode(LLMNode):
     _node_data_cls = QuestionClassifierNodeData
     node_type = NodeType.QUESTION_CLASSIFIER
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
         node_data = cast(QuestionClassifierNodeData, node_data)
+        variable_pool = self.graph_runtime_state.variable_pool
 
         # extract variables
         variable = variable_pool.get(node_data.query_variable_selector)
@@ -63,12 +65,23 @@ class QuestionClassifierNode(LLMNode):
         )
 
         # handle invoke result
-        result_text, usage, finish_reason = self._invoke_llm(
+        generator = self._invoke_llm(
             node_data_model=node_data.model,
             model_instance=model_instance,
             prompt_messages=prompt_messages,
             stop=stop
         )
+
+        result_text = ''
+        usage = LLMUsage.empty_usage()
+        finish_reason = None
+        for event in generator:
+            if isinstance(event, ModelInvokeCompleted):
+                result_text = event.text
+                usage = event.usage
+                finish_reason = event.finish_reason
+                break
+
         category_name = node_data.classes[0].name
         category_id = node_data.classes[0].id
         try:
@@ -109,7 +122,8 @@ class QuestionClassifierNode(LLMNode):
                     NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                     NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
                     NodeRunMetadataKey.CURRENCY: usage.currency
-                }
+                },
+                llm_usage=usage
             )
 
         except ValueError as e:
@@ -121,13 +135,24 @@ class QuestionClassifierNode(LLMNode):
                     NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                     NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
                     NodeRunMetadataKey.CURRENCY: usage.currency
-                }
+                },
+                llm_usage=usage
             )
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
-        node_data = node_data
-        node_data = cast(cls._node_data_cls, node_data)
+    def _extract_variable_selector_to_variable_mapping(
+        cls, 
+        graph_config: Mapping[str, Any], 
+        node_id: str,
+        node_data: QuestionClassifierNodeData
+    ) -> Mapping[str, Sequence[str]]:
+        """
+        Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
+        :return:
+        """
         variable_mapping = {'query': node_data.query_variable_selector}
         variable_selectors = []
         if node_data.instruction:
@@ -135,6 +160,11 @@ class QuestionClassifierNode(LLMNode):
             variable_selectors.extend(variable_template_parser.extract_variable_selectors())
         for variable_selector in variable_selectors:
             variable_mapping[variable_selector.variable] = variable_selector.value_selector
+
+        variable_mapping = {
+            node_id + '.' + key: value for key, value in variable_mapping.items()
+        }
+        
         return variable_mapping
 
     @classmethod

+ 15 - 7
api/core/workflow/nodes/start/start_node.py

@@ -1,7 +1,9 @@
 
-from core.workflow.entities.base_node_data_entities import BaseNodeData
+from collections.abc import Mapping, Sequence
+from typing import Any
+
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool
+from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.start.entities import StartNodeData
 from models.workflow import WorkflowNodeExecutionStatus
@@ -11,14 +13,13 @@ class StartNode(BaseNode):
     _node_data_cls = StartNodeData
     _node_type = NodeType.START
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         """
         Run node
-        :param variable_pool: variable pool
         :return:
         """
-        node_inputs = dict(variable_pool.user_inputs)
-        system_inputs = variable_pool.system_variables
+        node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
+        system_inputs = self.graph_runtime_state.variable_pool.system_variables
 
         for var in system_inputs:
             node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]
@@ -30,9 +31,16 @@ class StartNode(BaseNode):
         )
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: StartNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """

+ 17 - 9
api/core/workflow/nodes/template_transform/template_transform_node.py

@@ -1,15 +1,16 @@
 import os
-from typing import Optional, cast
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional, cast
 
 from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
 from models.workflow import WorkflowNodeExecutionStatus
 
 MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000'))
 
+
 class TemplateTransformNode(BaseNode):
     _node_data_cls = TemplateTransformNodeData
     _node_type = NodeType.TEMPLATE_TRANSFORM
@@ -34,7 +35,7 @@ class TemplateTransformNode(BaseNode):
             }
         }
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         """
         Run node
         """
@@ -45,7 +46,7 @@ class TemplateTransformNode(BaseNode):
         variables = {}
         for variable_selector in node_data.variables:
             variable_name = variable_selector.variable
-            value = variable_pool.get_any(variable_selector.value_selector)
+            value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
             variables[variable_name] = value
         # Run code
         try:
@@ -60,7 +61,7 @@ class TemplateTransformNode(BaseNode):
                 status=WorkflowNodeExecutionStatus.FAILED,
                 error=str(e)
             )
-        
+
         if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
             return NodeRunResult(
                 inputs=variables,
@@ -75,14 +76,21 @@ class TemplateTransformNode(BaseNode):
                 'output': result['result']
             }
         )
-    
+
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls, 
+        graph_config: Mapping[str, Any], 
+        node_id: str,
+        node_data: TemplateTransformNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """
         return {
-            variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
-        }
+            node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
+        }

+ 18 - 5
api/core/workflow/nodes/tool/tool_node.py

@@ -26,7 +26,7 @@ class ToolNode(BaseNode):
     _node_data_cls = ToolNodeData
     _node_type = NodeType.TOOL
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         """
         Run the tool node
         """
@@ -56,8 +56,8 @@ class ToolNode(BaseNode):
 
         # get parameters
         tool_parameters = tool_runtime.get_runtime_parameters() or []
-        parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data)
-        parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data, for_log=True)
+        parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data)
+        parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True)
 
         try:
             messages = ToolEngine.workflow_invoke(
@@ -66,6 +66,7 @@ class ToolNode(BaseNode):
                 user_id=self.user_id,
                 workflow_tool_callback=DifyWorkflowCallbackHandler(),
                 workflow_call_depth=self.workflow_call_depth,
+                thread_pool_id=self.thread_pool_id,
             )
         except Exception as e:
             return NodeRunResult(
@@ -145,7 +146,8 @@ class ToolNode(BaseNode):
         assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
         return list(variable.value) if variable else []
 
-    def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
+    def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\
+            -> tuple[str, list[FileVar], list[dict]]:
         """
         Convert ToolInvokeMessages into tuple[plain_text, files]
         """
@@ -221,9 +223,16 @@ class ToolNode(BaseNode):
         return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls, 
+        graph_config: Mapping[str, Any], 
+        node_id: str,
+        node_data: ToolNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """
@@ -239,4 +248,8 @@ class ToolNode(BaseNode):
             elif input.type == 'constant':
                 pass
 
+        result = {
+            node_id + '.' + key: value for key, value in result.items()
+        }
+
         return result

+ 18 - 7
api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py

@@ -1,8 +1,7 @@
-from typing import cast
+from collections.abc import Mapping, Sequence
+from typing import Any, cast
 
-from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
 from models.workflow import WorkflowNodeExecutionStatus
@@ -12,7 +11,7 @@ class VariableAggregatorNode(BaseNode):
     _node_data_cls = VariableAssignerNodeData
     _node_type = NodeType.VARIABLE_AGGREGATOR
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         node_data = cast(VariableAssignerNodeData, self.node_data)
         # Get variables
         outputs = {}
@@ -20,7 +19,7 @@ class VariableAggregatorNode(BaseNode):
 
         if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled:
             for selector in node_data.variables:
-                variable = variable_pool.get_any(selector)
+                variable = self.graph_runtime_state.variable_pool.get_any(selector)
                 if variable is not None:
                     outputs = {
                         "output": variable
@@ -33,7 +32,7 @@ class VariableAggregatorNode(BaseNode):
         else:
             for group in node_data.advanced_settings.groups:
                 for selector in group.variables:
-                    variable = variable_pool.get_any(selector)
+                    variable = self.graph_runtime_state.variable_pool.get_any(selector)
 
                     if variable is not None:
                         outputs[group.group_name] = {
@@ -49,5 +48,17 @@ class VariableAggregatorNode(BaseNode):
         )
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls, 
+        graph_config: Mapping[str, Any], 
+        node_id: str,
+        node_data: VariableAssignerNodeData
+    ) -> Mapping[str, Sequence[str]]:
+        """
+        Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
+        :return:
+        """
         return {}

+ 6 - 7
api/core/workflow/nodes/variable_assigner/node.py

@@ -6,7 +6,6 @@ from sqlalchemy.orm import Session
 from core.app.segments import SegmentType, Variable, factory
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from extensions.ext_database import db
 from models import ConversationVariable, WorkflowNodeExecutionStatus
@@ -19,23 +18,23 @@ class VariableAssignerNode(BaseNode):
     _node_data_cls: type[BaseNodeData] = VariableAssignerData
     _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         data = cast(VariableAssignerData, self.node_data)
 
         # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
-        original_variable = variable_pool.get(data.assigned_variable_selector)
+        original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector)
         if not isinstance(original_variable, Variable):
             raise VariableAssignerNodeError('assigned variable not found')
 
         match data.write_mode:
             case WriteMode.OVER_WRITE:
-                income_value = variable_pool.get(data.input_variable_selector)
+                income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
                 if not income_value:
                     raise VariableAssignerNodeError('input value not found')
                 updated_variable = original_variable.model_copy(update={'value': income_value.value})
 
             case WriteMode.APPEND:
-                income_value = variable_pool.get(data.input_variable_selector)
+                income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
                 if not income_value:
                     raise VariableAssignerNodeError('input value not found')
                 updated_value = original_variable.value + [income_value.value]
@@ -49,11 +48,11 @@ class VariableAssignerNode(BaseNode):
                 raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
 
         # Over write the variable.
-        variable_pool.add(data.assigned_variable_selector, updated_variable)
+        self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable)
 
         # TODO: Move database operation to the pipeline.
         # Update conversation variable.
-        conversation_id = variable_pool.get(['sys', 'conversation_id'])
+        conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id'])
         if not conversation_id:
             raise VariableAssignerNodeError('conversation_id not found')
         update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)

+ 0 - 0
api/core/workflow/utils/condition/__init__.py


+ 17 - 0
api/core/workflow/utils/condition/entities.py

@@ -0,0 +1,17 @@
+from typing import Literal, Optional
+
+from pydantic import BaseModel
+
+
+class Condition(BaseModel):
+    """
+    Condition entity
+    """
+    variable_selector: list[str]
+    comparison_operator: Literal[
+        # for string or array
+        "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty",
+        # for number
+        "=", "≠", ">", "<", "≥", "≤", "null", "not null"
+    ]
+    value: Optional[str] = None

+ 383 - 0
api/core/workflow/utils/condition/processor.py

@@ -0,0 +1,383 @@
+from collections.abc import Sequence
+from typing import Any, Optional
+
+from core.file.file_obj import FileVar
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.utils.condition.entities import Condition
+from core.workflow.utils.variable_template_parser import VariableTemplateParser
+
+
+class ConditionProcessor:
+    def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]):
+        input_conditions = []
+        group_result = []
+
+        index = 0
+        for condition in conditions:
+            index += 1
+            actual_value = variable_pool.get_any(
+                condition.variable_selector
+            )
+
+            expected_value = None
+            if condition.value is not None:
+                variable_template_parser = VariableTemplateParser(template=condition.value)
+                variable_selectors = variable_template_parser.extract_variable_selectors()
+                if variable_selectors:
+                    for variable_selector in variable_selectors:
+                        value = variable_pool.get_any(
+                            variable_selector.value_selector
+                        )
+                        expected_value = variable_template_parser.format({variable_selector.variable: value})
+
+                    if expected_value is None:
+                        expected_value = condition.value
+                else:
+                    expected_value = condition.value
+
+            comparison_operator = condition.comparison_operator
+            input_conditions.append(
+                {
+                    "actual_value": actual_value,
+                    "expected_value": expected_value,
+                    "comparison_operator": comparison_operator
+                }
+            )
+
+            result = self.evaluate_condition(actual_value, comparison_operator, expected_value)
+            group_result.append(result)
+
+        return input_conditions, group_result
+
+    def evaluate_condition(
+            self,
+            actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar | None],
+            comparison_operator: str,
+            expected_value: Optional[str] = None
+    ) -> bool:
+        """
+        Evaluate condition
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :param comparison_operator: comparison operator
+
+        :return: bool
+        """
+        if comparison_operator == "contains":
+            return self._assert_contains(actual_value, expected_value)
+        elif comparison_operator == "not contains":
+            return self._assert_not_contains(actual_value, expected_value)
+        elif comparison_operator == "start with":
+            return self._assert_start_with(actual_value, expected_value)
+        elif comparison_operator == "end with":
+            return self._assert_end_with(actual_value, expected_value)
+        elif comparison_operator == "is":
+            return self._assert_is(actual_value, expected_value)
+        elif comparison_operator == "is not":
+            return self._assert_is_not(actual_value, expected_value)
+        elif comparison_operator == "empty":
+            return self._assert_empty(actual_value)
+        elif comparison_operator == "not empty":
+            return self._assert_not_empty(actual_value)
+        elif comparison_operator == "=":
+            return self._assert_equal(actual_value, expected_value)
+        elif comparison_operator == "≠":
+            return self._assert_not_equal(actual_value, expected_value)
+        elif comparison_operator == ">":
+            return self._assert_greater_than(actual_value, expected_value)
+        elif comparison_operator == "<":
+            return self._assert_less_than(actual_value, expected_value)
+        elif comparison_operator == "≥":
+            return self._assert_greater_than_or_equal(actual_value, expected_value)
+        elif comparison_operator == "≤":
+            return self._assert_less_than_or_equal(actual_value, expected_value)
+        elif comparison_operator == "null":
+            return self._assert_null(actual_value)
+        elif comparison_operator == "not null":
+            return self._assert_not_null(actual_value)
+        else:
+            raise ValueError(f"Invalid comparison operator: {comparison_operator}")
+
+    def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
+        """
+        Assert contains
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if not actual_value:
+            return False
+
+        if not isinstance(actual_value, str | list):
+            raise ValueError('Invalid actual value type: string or array')
+
+        if expected_value not in actual_value:
+            return False
+        return True
+
+    def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
+        """
+        Assert not contains
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if not actual_value:
+            return True
+
+        if not isinstance(actual_value, str | list):
+            raise ValueError('Invalid actual value type: string or array')
+
+        if expected_value in actual_value:
+            return False
+        return True
+
+    def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
+        """
+        Assert start with
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if not actual_value:
+            return False
+
+        if not isinstance(actual_value, str):
+            raise ValueError('Invalid actual value type: string')
+
+        if not actual_value.startswith(expected_value):
+            return False
+        return True
+
+    def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
+        """
+        Assert end with
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if not actual_value:
+            return False
+
+        if not isinstance(actual_value, str):
+            raise ValueError('Invalid actual value type: string')
+
+        if not actual_value.endswith(expected_value):
+            return False
+        return True
+
+    def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
+        """
+        Assert is
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if actual_value is None:
+            return False
+
+        if not isinstance(actual_value, str):
+            raise ValueError('Invalid actual value type: string')
+
+        if actual_value != expected_value:
+            return False
+        return True
+
+    def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
+        """
+        Assert is not
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if actual_value is None:
+            return False
+
+        if not isinstance(actual_value, str):
+            raise ValueError('Invalid actual value type: string')
+
+        if actual_value == expected_value:
+            return False
+        return True
+
+    def _assert_empty(self, actual_value: Optional[str]) -> bool:
+        """
+        Assert empty
+        :param actual_value: actual value
+        :return:
+        """
+        if not actual_value:
+            return True
+        return False
+
+    def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
+        """
+        Assert not empty
+        :param actual_value: actual value
+        :return:
+        """
+        if actual_value:
+            return True
+        return False
+
+    def _assert_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
+        """
+        Assert equal
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if actual_value is None:
+            return False
+
+        if not isinstance(actual_value, int | float):
+            raise ValueError('Invalid actual value type: number')
+
+        if isinstance(actual_value, int):
+            expected_value = int(expected_value)
+        else:
+            expected_value = float(expected_value)
+
+        if actual_value != expected_value:
+            return False
+        return True
+
+    def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
+        """
+        Assert not equal
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if actual_value is None:
+            return False
+
+        if not isinstance(actual_value, int | float):
+            raise ValueError('Invalid actual value type: number')
+
+        if isinstance(actual_value, int):
+            expected_value = int(expected_value)
+        else:
+            expected_value = float(expected_value)
+
+        if actual_value == expected_value:
+            return False
+        return True
+
+    def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
+        """
+        Assert greater than
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if actual_value is None:
+            return False
+
+        if not isinstance(actual_value, int | float):
+            raise ValueError('Invalid actual value type: number')
+
+        if isinstance(actual_value, int):
+            expected_value = int(expected_value)
+        else:
+            expected_value = float(expected_value)
+
+        if actual_value <= expected_value:
+            return False
+        return True
+
+    def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
+        """
+        Assert less than
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if actual_value is None:
+            return False
+
+        if not isinstance(actual_value, int | float):
+            raise ValueError('Invalid actual value type: number')
+
+        if isinstance(actual_value, int):
+            expected_value = int(expected_value)
+        else:
+            expected_value = float(expected_value)
+
+        if actual_value >= expected_value:
+            return False
+        return True
+
+    def _assert_greater_than_or_equal(self, actual_value: Optional[int | float],
+                                      expected_value: str | int | float) -> bool:
+        """
+        Assert greater than or equal
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if actual_value is None:
+            return False
+
+        if not isinstance(actual_value, int | float):
+            raise ValueError('Invalid actual value type: number')
+
+        if isinstance(actual_value, int):
+            expected_value = int(expected_value)
+        else:
+            expected_value = float(expected_value)
+
+        if actual_value < expected_value:
+            return False
+        return True
+
+    def _assert_less_than_or_equal(self, actual_value: Optional[int | float],
+                                   expected_value: str | int | float) -> bool:
+        """
+        Assert less than or equal
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :return:
+        """
+        if actual_value is None:
+            return False
+
+        if not isinstance(actual_value, int | float):
+            raise ValueError('Invalid actual value type: number')
+
+        if isinstance(actual_value, int):
+            expected_value = int(expected_value)
+        else:
+            expected_value = float(expected_value)
+
+        if actual_value > expected_value:
+            return False
+        return True
+
+    def _assert_null(self, actual_value: Optional[int | float]) -> bool:
+        """
+        Assert null
+        :param actual_value: actual value
+        :return:
+        """
+        if actual_value is None:
+            return True
+        return False
+
+    def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
+        """
+        Assert not null
+        :param actual_value: actual value
+        :return:
+        """
+        if actual_value is not None:
+            return True
+        return False
+
+
+class ConditionAssertionError(Exception):
+    def __init__(self, message: str, conditions: list[dict], sub_condition_compare_results: list[dict]) -> None:
+        self.message = message
+        self.conditions = conditions
+        self.sub_condition_compare_results = sub_condition_compare_results
+        super().__init__(self.message)

+ 0 - 1005
api/core/workflow/workflow_engine_manager.py

@@ -1,1005 +0,0 @@
-import logging
-import time
-from collections.abc import Mapping, Sequence
-from typing import Any, Optional, cast
-
-import contexts
-from configs import dify_config
-from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
-from core.app.entities.app_invoke_entities import InvokeFrom
-from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar
-from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
-from core.workflow.entities.variable_pool import VariablePool, VariableValue
-from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
-from core.workflow.errors import WorkflowNodeRunFailedError
-from core.workflow.nodes.answer.answer_node import AnswerNode
-from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
-from core.workflow.nodes.code.code_node import CodeNode
-from core.workflow.nodes.end.end_node import EndNode
-from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
-from core.workflow.nodes.if_else.if_else_node import IfElseNode
-from core.workflow.nodes.iteration.entities import IterationState
-from core.workflow.nodes.iteration.iteration_node import IterationNode
-from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
-from core.workflow.nodes.llm.entities import LLMNodeData
-from core.workflow.nodes.llm.llm_node import LLMNode
-from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
-from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
-from core.workflow.nodes.start.start_node import StartNode
-from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
-from core.workflow.nodes.tool.tool_node import ToolNode
-from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
-from core.workflow.nodes.variable_assigner import VariableAssignerNode
-from extensions.ext_database import db
-from models.workflow import (
-    Workflow,
-    WorkflowNodeExecutionStatus,
-)
-
-node_classes: Mapping[NodeType, type[BaseNode]] = {
-    NodeType.START: StartNode,
-    NodeType.END: EndNode,
-    NodeType.ANSWER: AnswerNode,
-    NodeType.LLM: LLMNode,
-    NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
-    NodeType.IF_ELSE: IfElseNode,
-    NodeType.CODE: CodeNode,
-    NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
-    NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
-    NodeType.HTTP_REQUEST: HttpRequestNode,
-    NodeType.TOOL: ToolNode,
-    NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
-    NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode,
-    NodeType.ITERATION: IterationNode,
-    NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
-    NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
-}
-
-logger = logging.getLogger(__name__)
-
-
-class WorkflowEngineManager:
-    def get_default_configs(self) -> list[dict]:
-        """
-        Get default block configs
-        """
-        default_block_configs = []
-        for node_type, node_class in node_classes.items():
-            default_config = node_class.get_default_config()
-            if default_config:
-                default_block_configs.append(default_config)
-
-        return default_block_configs
-
-    def get_default_config(self, node_type: NodeType, filters: Optional[dict] = None) -> Optional[dict]:
-        """
-        Get default config of node.
-        :param node_type: node type
-        :param filters: filter by node config parameters.
-        :return:
-        """
-        node_class = node_classes.get(node_type)
-        if not node_class:
-            return None
-
-        default_config = node_class.get_default_config(filters=filters)
-        if not default_config:
-            return None
-
-        return default_config
-
-    def run_workflow(
-        self,
-        *,
-        workflow: Workflow,
-        user_id: str,
-        user_from: UserFrom,
-        invoke_from: InvokeFrom,
-        callbacks: Sequence[WorkflowCallback],
-        call_depth: int = 0,
-        variable_pool: VariablePool | None = None,
-    ) -> None:
-        """
-        :param workflow: Workflow instance
-        :param user_id: user id
-        :param user_from: user from
-        :param invoke_from: invoke from
-        :param callbacks: workflow callbacks
-        :param call_depth: call depth
-        :param variable_pool: variable pool
-        """
-        # fetch workflow graph
-        graph = workflow.graph_dict
-        if not graph:
-            raise ValueError('workflow graph not found')
-
-        if 'nodes' not in graph or 'edges' not in graph:
-            raise ValueError('nodes or edges not found in workflow graph')
-
-        if not isinstance(graph.get('nodes'), list):
-            raise ValueError('nodes in workflow graph must be a list')
-
-        if not isinstance(graph.get('edges'), list):
-            raise ValueError('edges in workflow graph must be a list')
-
-
-        workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
-        if call_depth > workflow_call_max_depth:
-            raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
-
-        # init workflow run state
-        if not variable_pool:
-            variable_pool = contexts.workflow_variable_pool.get()
-        workflow_run_state = WorkflowRunState(
-            workflow=workflow,
-            start_at=time.perf_counter(),
-            variable_pool=variable_pool,
-            user_id=user_id,
-            user_from=user_from,
-            invoke_from=invoke_from,
-            workflow_call_depth=call_depth
-        )
-
-        # init workflow run
-        if callbacks:
-            for callback in callbacks:
-                callback.on_workflow_run_started()
-
-        # run workflow
-        self._run_workflow(
-            workflow=workflow,
-            workflow_run_state=workflow_run_state,
-            callbacks=callbacks,
-        )
-
-    def _run_workflow(self, workflow: Workflow,
-                     workflow_run_state: WorkflowRunState,
-                     callbacks: Sequence[WorkflowCallback],
-                     start_at: Optional[str] = None,
-                     end_at: Optional[str] = None) -> None:
-        """
-        Run workflow
-        :param workflow: Workflow instance
-        :param user_id: user id
-        :param user_from: user from
-        :param user_inputs: user variables inputs
-        :param system_inputs: system inputs, like: query, files
-        :param callbacks: workflow callbacks
-        :param call_depth: call depth
-        :param start_at: force specific start node
-        :param end_at: force specific end node
-        :return:
-        """
-        graph = workflow.graph_dict
-
-        try:
-            answer_prov_node_ids = []
-            for node in graph.get('nodes', []):
-                if node.get('id', '') == 'answer':
-                    try:
-                        answer_prov_node_ids.append(node.get('data', {})
-                                                    .get('answer', '')
-                                                    .replace('#', '')
-                                                    .replace('.text', '')
-                                                    .replace('{{', '')
-                                                    .replace('}}', '').split('.')[0])
-                    except Exception as e:
-                        logger.error(e)
-
-            predecessor_node: BaseNode | None = None
-            current_iteration_node: BaseIterationNode | None = None
-            has_entry_node = False
-            max_execution_steps = dify_config.WORKFLOW_MAX_EXECUTION_STEPS
-            max_execution_time = dify_config.WORKFLOW_MAX_EXECUTION_TIME
-            while True:
-                # get next node, multiple target nodes in the future
-                next_node = self._get_next_overall_node(
-                    workflow_run_state=workflow_run_state,
-                    graph=graph,
-                    predecessor_node=predecessor_node,
-                    callbacks=callbacks,
-                    start_at=start_at,
-                    end_at=end_at
-                )
-
-                if not next_node:
-                    # reached loop/iteration end or overall end
-                    if current_iteration_node and workflow_run_state.current_iteration_state:
-                        # reached loop/iteration end
-                        # get next iteration
-                        next_iteration = current_iteration_node.get_next_iteration(
-                            variable_pool=workflow_run_state.variable_pool,
-                            state=workflow_run_state.current_iteration_state
-                        )
-                        self._workflow_iteration_next(
-                            graph=graph,
-                            current_iteration_node=current_iteration_node,
-                            workflow_run_state=workflow_run_state,
-                            callbacks=callbacks
-                        )
-                        if isinstance(next_iteration, NodeRunResult):
-                            if next_iteration.outputs:
-                                for variable_key, variable_value in next_iteration.outputs.items():
-                                    # append variables to variable pool recursively
-                                    self._append_variables_recursively(
-                                        variable_pool=workflow_run_state.variable_pool,
-                                        node_id=current_iteration_node.node_id,
-                                        variable_key_list=[variable_key],
-                                        variable_value=variable_value
-                                    )
-                            self._workflow_iteration_completed(
-                                current_iteration_node=current_iteration_node,
-                                workflow_run_state=workflow_run_state,
-                                callbacks=callbacks
-                            )
-                            # iteration has ended
-                            next_node = self._get_next_overall_node(
-                                workflow_run_state=workflow_run_state,
-                                graph=graph,
-                                predecessor_node=current_iteration_node,
-                                callbacks=callbacks,
-                                start_at=start_at,
-                                end_at=end_at
-                            )
-                            current_iteration_node = None
-                            workflow_run_state.current_iteration_state = None
-                            # continue overall process
-                        elif isinstance(next_iteration, str):
-                            # move to next iteration
-                            next_node_id = next_iteration
-                            # get next id
-                            next_node = self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks)
-
-                if not next_node:
-                    break
-
-                # check is already ran
-                if self._check_node_has_ran(workflow_run_state, next_node.node_id):
-                    predecessor_node = next_node
-                    continue
-
-                has_entry_node = True
-
-                # max steps reached
-                if workflow_run_state.workflow_node_steps > max_execution_steps:
-                    raise ValueError('Max steps {} reached.'.format(max_execution_steps))
-
-                # or max execution time reached
-                if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time):
-                    raise ValueError('Max execution time {}s reached.'.format(max_execution_time))
-
-                # handle iteration nodes
-                if isinstance(next_node, BaseIterationNode):
-                    current_iteration_node = next_node
-                    workflow_run_state.current_iteration_state = next_node.run(
-                        variable_pool=workflow_run_state.variable_pool
-                    )
-                    self._workflow_iteration_started(
-                        graph=graph,
-                        current_iteration_node=current_iteration_node,
-                        workflow_run_state=workflow_run_state,
-                        predecessor_node_id=predecessor_node.node_id if predecessor_node else None,
-                        callbacks=callbacks
-                    )
-                    predecessor_node = next_node
-                    # move to start node of iteration
-                    next_node_id = next_node.get_next_iteration(
-                        variable_pool=workflow_run_state.variable_pool,
-                        state=workflow_run_state.current_iteration_state
-                    )
-                    self._workflow_iteration_next(
-                        graph=graph,
-                        current_iteration_node=current_iteration_node,
-                        workflow_run_state=workflow_run_state,
-                        callbacks=callbacks
-                    )
-                    if isinstance(next_node_id, NodeRunResult):
-                        # iteration has ended
-                        current_iteration_node.set_output(
-                            variable_pool=workflow_run_state.variable_pool,
-                            state=workflow_run_state.current_iteration_state
-                        )
-                        self._workflow_iteration_completed(
-                            current_iteration_node=current_iteration_node,
-                            workflow_run_state=workflow_run_state,
-                            callbacks=callbacks
-                        )
-                        current_iteration_node = None
-                        workflow_run_state.current_iteration_state = None
-                        continue
-                    else:
-                        next_node = self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks)
-
-                if next_node and next_node.node_id in answer_prov_node_ids:
-                    next_node.is_answer_previous_node = True
-
-                # run workflow, run multiple target nodes in the future
-                self._run_workflow_node(
-                    workflow_run_state=workflow_run_state,
-                    node=next_node,
-                    predecessor_node=predecessor_node,
-                    callbacks=callbacks
-                )
-
-                if next_node.node_type in [NodeType.END]:
-                    break
-
-                predecessor_node = next_node
-
-            if not has_entry_node:
-                self._workflow_run_failed(
-                    error='Start node not found in workflow graph.',
-                    callbacks=callbacks
-                )
-                return
-        except GenerateTaskStoppedException as e:
-            return
-        except Exception as e:
-            self._workflow_run_failed(
-                error=str(e),
-                callbacks=callbacks
-            )
-            return
-
-        # workflow run success
-        self._workflow_run_success(
-            callbacks=callbacks
-        )
-
-    def single_step_run_workflow_node(self, workflow: Workflow,
-                                      node_id: str,
-                                      user_id: str,
-                                      user_inputs: dict) -> tuple[BaseNode, NodeRunResult]:
-        """
-        Single step run workflow node
-        :param workflow: Workflow instance
-        :param node_id: node id
-        :param user_id: user id
-        :param user_inputs: user inputs
-        :return:
-        """
-        # fetch node info from workflow graph
-        graph = workflow.graph_dict
-        if not graph:
-            raise ValueError('workflow graph not found')
-
-        nodes = graph.get('nodes')
-        if not nodes:
-            raise ValueError('nodes not found in workflow graph')
-
-        # fetch node config from node id
-        node_config = None
-        for node in nodes:
-            if node.get('id') == node_id:
-                node_config = node
-                break
-
-        if not node_config:
-            raise ValueError('node id not found in workflow graph')
-
-        # Get node class
-        node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
-        node_cls = node_classes.get(node_type)
-
-        # init workflow run state
-        node_instance = node_cls(
-            tenant_id=workflow.tenant_id,
-            app_id=workflow.app_id,
-            workflow_id=workflow.id,
-            user_id=user_id,
-            user_from=UserFrom.ACCOUNT,
-            invoke_from=InvokeFrom.DEBUGGER,
-            config=node_config,
-            workflow_call_depth=0
-        )
-
-        try:
-            # init variable pool
-            variable_pool = VariablePool(
-                system_variables={},
-                user_inputs={},
-                environment_variables=workflow.environment_variables,
-                conversation_variables=workflow.conversation_variables,
-            )
-
-            if node_cls is None:
-                raise ValueError('Node class not found')
-            # variable selector to variable mapping
-            variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config)
-
-            self._mapping_user_inputs_to_variable_pool(
-                variable_mapping=variable_mapping,
-                user_inputs=user_inputs,
-                variable_pool=variable_pool,
-                tenant_id=workflow.tenant_id,
-                node_instance=node_instance
-            )
-
-            # run node
-            node_run_result = node_instance.run(
-                variable_pool=variable_pool
-            )
-
-            # sign output files
-            node_run_result.outputs = self.handle_special_values(node_run_result.outputs)
-        except Exception as e:
-            raise WorkflowNodeRunFailedError(
-                node_id=node_instance.node_id,
-                node_type=node_instance.node_type,
-                node_title=node_instance.node_data.title,
-                error=str(e)
-            )
-
-        return node_instance, node_run_result
-
-    def single_step_run_iteration_workflow_node(self, workflow: Workflow,
-                                            node_id: str,
-                                            user_id: str,
-                                            user_inputs: dict,
-                                            callbacks: Sequence[WorkflowCallback],
-    ) -> None:
-        """
-        Single iteration run workflow node
-        """
-        # fetch node info from workflow graph
-        graph = workflow.graph_dict
-        if not graph:
-            raise ValueError('workflow graph not found')
-
-        nodes = graph.get('nodes')
-        if not nodes:
-            raise ValueError('nodes not found in workflow graph')
-
-        for node in nodes:
-            if node.get('id') == node_id:
-                if node.get('data', {}).get('type') in [
-                    NodeType.ITERATION.value,
-                    NodeType.LOOP.value,
-                ]:
-                    node_config = node
-                else:
-                    raise ValueError('node id is not an iteration node')
-
-        # init variable pool
-        variable_pool = VariablePool(
-            system_variables={},
-            user_inputs={},
-            environment_variables=workflow.environment_variables,
-            conversation_variables=workflow.conversation_variables,
-        )
-
-        # variable selector to variable mapping
-        iteration_nested_nodes = [
-            node for node in nodes
-            if node.get('data', {}).get('iteration_id') == node_id or node.get('id') == node_id
-        ]
-        iteration_nested_node_ids = [node.get('id') for node in iteration_nested_nodes]
-
-        if not iteration_nested_nodes:
-            raise ValueError('iteration has no nested nodes')
-
-        # init workflow run
-        if callbacks:
-            for callback in callbacks:
-                callback.on_workflow_run_started()
-
-        for node_config in iteration_nested_nodes:
-            # mapping user inputs to variable pool
-            node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type')))
-            if node_cls is None:
-                raise ValueError('Node class not found')
-            variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config)
-
-            # remove iteration variables
-            variable_mapping = {
-                f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items()
-                if value[0] != node_id
-            }
-
-            # remove variable out from iteration
-            variable_mapping = {
-                key: value for key, value in variable_mapping.items()
-                if value[0] not in iteration_nested_node_ids
-            }
-
-            # append variables to variable pool
-            node_instance = node_cls(
-                tenant_id=workflow.tenant_id,
-                app_id=workflow.app_id,
-                workflow_id=workflow.id,
-                user_id=user_id,
-                user_from=UserFrom.ACCOUNT,
-                invoke_from=InvokeFrom.DEBUGGER,
-                config=node_config,
-                callbacks=callbacks,
-                workflow_call_depth=0
-            )
-
-            self._mapping_user_inputs_to_variable_pool(
-                variable_mapping=variable_mapping,
-                user_inputs=user_inputs,
-                variable_pool=variable_pool,
-                tenant_id=workflow.tenant_id,
-                node_instance=node_instance
-            )
-
-        # fetch end node of iteration
-        end_node_id = None
-        for edge in graph.get('edges'):
-            if edge.get('source') == node_id:
-                end_node_id = edge.get('target')
-                break
-
-        if not end_node_id:
-            raise ValueError('end node of iteration not found')
-
-        # init workflow run state
-        workflow_run_state = WorkflowRunState(
-            workflow=workflow,
-            start_at=time.perf_counter(),
-            variable_pool=variable_pool,
-            user_id=user_id,
-            user_from=UserFrom.ACCOUNT,
-            invoke_from=InvokeFrom.DEBUGGER,
-            workflow_call_depth=0
-        )
-
-        # run workflow
-        self._run_workflow(
-            workflow=workflow,
-            workflow_run_state=workflow_run_state,
-            callbacks=callbacks,
-            start_at=node_id,
-            end_at=end_node_id
-        )
-
-    def _workflow_run_success(self, callbacks: Sequence[WorkflowCallback]) -> None:
-        """
-        Workflow run success
-        :param callbacks: workflow callbacks
-        :return:
-        """
-
-        if callbacks:
-            for callback in callbacks:
-                callback.on_workflow_run_succeeded()
-
-    def _workflow_run_failed(self, error: str,
-                             callbacks: Sequence[WorkflowCallback]) -> None:
-        """
-        Workflow run failed
-        :param error: error message
-        :param callbacks: workflow callbacks
-        :return:
-        """
-        if callbacks:
-            for callback in callbacks:
-                callback.on_workflow_run_failed(
-                    error=error
-                )
-
-    def _workflow_iteration_started(self, *, graph: Mapping[str, Any],
-                                    current_iteration_node: BaseIterationNode,
-                                    workflow_run_state: WorkflowRunState,
-                                    predecessor_node_id: Optional[str] = None,
-                                    callbacks: Sequence[WorkflowCallback]) -> None:
-        """
-        Workflow iteration started
-        :param current_iteration_node: current iteration node
-        :param workflow_run_state: workflow run state
-        :param callbacks: workflow callbacks
-        :return:
-        """
-        # get nested nodes
-        iteration_nested_nodes = [
-            node for node in graph.get('nodes')
-            if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id
-        ]
-
-        if not iteration_nested_nodes:
-            raise ValueError('iteration has no nested nodes')
-
-        if callbacks:
-            if isinstance(workflow_run_state.current_iteration_state, IterationState):
-                for callback in callbacks:
-                    callback.on_workflow_iteration_started(
-                        node_id=current_iteration_node.node_id,
-                        node_type=NodeType.ITERATION,
-                        node_run_index=workflow_run_state.workflow_node_steps,
-                        node_data=current_iteration_node.node_data,
-                        inputs=workflow_run_state.current_iteration_state.inputs,
-                        predecessor_node_id=predecessor_node_id,
-                        metadata=workflow_run_state.current_iteration_state.metadata.model_dump()
-                    )
-
-        # add steps
-        workflow_run_state.workflow_node_steps += 1
-
-    def _workflow_iteration_next(self, *, graph: Mapping[str, Any],
-                                 current_iteration_node: BaseIterationNode,
-                                 workflow_run_state: WorkflowRunState,
-                                 callbacks: Sequence[WorkflowCallback]) -> None:
-        """
-        Workflow iteration next
-        :param workflow_run_state: workflow run state
-        :return:
-        """
-        if callbacks:
-            if isinstance(workflow_run_state.current_iteration_state, IterationState):
-                for callback in callbacks:
-                    callback.on_workflow_iteration_next(
-                        node_id=current_iteration_node.node_id,
-                        node_type=NodeType.ITERATION,
-                        index=workflow_run_state.current_iteration_state.index,
-                        node_run_index=workflow_run_state.workflow_node_steps,
-                        output=workflow_run_state.current_iteration_state.get_current_output()
-                    )
-        # clear ran nodes
-        workflow_run_state.workflow_node_runs = [
-            node_run for node_run in workflow_run_state.workflow_node_runs
-            if node_run.iteration_node_id != current_iteration_node.node_id
-        ]
-
-        # clear variables in current iteration
-        nodes = graph.get('nodes')
-        nodes = [node for node in nodes if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id]
-
-        for node in nodes:
-            workflow_run_state.variable_pool.remove((node.get('id'),))
-
-    def _workflow_iteration_completed(self, *, current_iteration_node: BaseIterationNode,
-                                        workflow_run_state: WorkflowRunState,
-                                        callbacks: Sequence[WorkflowCallback]) -> None:
-        if callbacks:
-            if isinstance(workflow_run_state.current_iteration_state, IterationState):
-                for callback in callbacks:
-                    callback.on_workflow_iteration_completed(
-                        node_id=current_iteration_node.node_id,
-                        node_type=NodeType.ITERATION,
-                        node_run_index=workflow_run_state.workflow_node_steps,
-                        outputs={
-                            'output': workflow_run_state.current_iteration_state.outputs
-                        }
-                    )
-
-    def _get_next_overall_node(self, *, workflow_run_state: WorkflowRunState,
-                       graph: Mapping[str, Any],
-                       predecessor_node: Optional[BaseNode] = None,
-                       callbacks: Sequence[WorkflowCallback],
-                       start_at: Optional[str] = None,
-                       end_at: Optional[str] = None) -> Optional[BaseNode]:
-        """
-        Get next node
-        multiple target nodes in the future.
-        :param graph: workflow graph
-        :param predecessor_node: predecessor node
-        :param callbacks: workflow callbacks
-        :return:
-        """
-        nodes = graph.get('nodes')
-        if not nodes:
-            return None
-
-        if not predecessor_node:
-            for node_config in nodes:
-                node_cls = None
-                if start_at:
-                    if node_config.get('id') == start_at:
-                        node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type')))
-                else:
-                    if node_config.get('data', {}).get('type', '') == NodeType.START.value:
-                        node_cls = StartNode
-                if node_cls:
-                    return node_cls(
-                        tenant_id=workflow_run_state.tenant_id,
-                        app_id=workflow_run_state.app_id,
-                        workflow_id=workflow_run_state.workflow_id,
-                        user_id=workflow_run_state.user_id,
-                        user_from=workflow_run_state.user_from,
-                        invoke_from=workflow_run_state.invoke_from,
-                        config=node_config,
-                        callbacks=callbacks,
-                        workflow_call_depth=workflow_run_state.workflow_call_depth
-                    )
-
-        else:
-            edges = graph.get('edges')
-            source_node_id = predecessor_node.node_id
-
-            # fetch all outgoing edges from source node
-            outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id]
-            if not outgoing_edges:
-                return None
-
-            # fetch target node id from outgoing edges
-            outgoing_edge = None
-            source_handle = predecessor_node.node_run_result.edge_source_handle \
-                if predecessor_node.node_run_result else None
-            if source_handle:
-                for edge in outgoing_edges:
-                    if edge.get('sourceHandle') and edge.get('sourceHandle') == source_handle:
-                        outgoing_edge = edge
-                        break
-            else:
-                outgoing_edge = outgoing_edges[0]
-
-            if not outgoing_edge:
-                return None
-
-            target_node_id = outgoing_edge.get('target')
-
-            if end_at and target_node_id == end_at:
-                return None
-
-            # fetch target node from target node id
-            target_node_config = None
-            for node in nodes:
-                if node.get('id') == target_node_id:
-                    target_node_config = node
-                    break
-
-            if not target_node_config:
-                return None
-
-            # get next node
-            target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type')))
-
-            return target_node(
-                tenant_id=workflow_run_state.tenant_id,
-                app_id=workflow_run_state.app_id,
-                workflow_id=workflow_run_state.workflow_id,
-                user_id=workflow_run_state.user_id,
-                user_from=workflow_run_state.user_from,
-                invoke_from=workflow_run_state.invoke_from,
-                config=target_node_config,
-                callbacks=callbacks,
-                workflow_call_depth=workflow_run_state.workflow_call_depth
-            )
-
-    def _get_node(self, workflow_run_state: WorkflowRunState,
-                  graph: Mapping[str, Any],
-                  node_id: str,
-                  callbacks: Sequence[WorkflowCallback]):
-        """
-        Get node from graph by node id
-        """
-        nodes = graph.get('nodes')
-        if not nodes:
-            return None
-
-        for node_config in nodes:
-            if node_config.get('id') == node_id:
-                node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
-                node_cls = node_classes[node_type]
-                return node_cls(
-                    tenant_id=workflow_run_state.tenant_id,
-                    app_id=workflow_run_state.app_id,
-                    workflow_id=workflow_run_state.workflow_id,
-                    user_id=workflow_run_state.user_id,
-                    user_from=workflow_run_state.user_from,
-                    invoke_from=workflow_run_state.invoke_from,
-                    config=node_config,
-                    callbacks=callbacks,
-                    workflow_call_depth=workflow_run_state.workflow_call_depth
-                )
-
-    def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
-        """
-        Check timeout
-        :param start_at: start time
-        :param max_execution_time: max execution time
-        :return:
-        """
-        return time.perf_counter() - start_at > max_execution_time
-
-    def _check_node_has_ran(self, workflow_run_state: WorkflowRunState, node_id: str) -> bool:
-        """
-        Check node has ran
-        """
-        return bool([
-            node_and_result for node_and_result in workflow_run_state.workflow_node_runs
-            if node_and_result.node_id == node_id
-        ])
-
-    def _run_workflow_node(self, *, workflow_run_state: WorkflowRunState,
-                           node: BaseNode,
-                           predecessor_node: Optional[BaseNode] = None,
-                           callbacks: Sequence[WorkflowCallback]) -> None:
-        if callbacks:
-            for callback in callbacks:
-                callback.on_workflow_node_execute_started(
-                    node_id=node.node_id,
-                    node_type=node.node_type,
-                    node_data=node.node_data,
-                    node_run_index=workflow_run_state.workflow_node_steps,
-                    predecessor_node_id=predecessor_node.node_id if predecessor_node else None
-                )
-
-        db.session.close()
-
-        workflow_nodes_and_result = WorkflowNodeAndResult(
-            node=node,
-            result=None
-        )
-
-        # add to workflow_nodes_and_results
-        workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result)
-
-        # add steps
-        workflow_run_state.workflow_node_steps += 1
-
-        # mark node as running
-        if workflow_run_state.current_iteration_state:
-            workflow_run_state.workflow_node_runs.append(WorkflowRunState.NodeRun(
-                node_id=node.node_id,
-                iteration_node_id=workflow_run_state.current_iteration_state.iteration_node_id
-            ))
-
-        try:
-            # run node, result must have inputs, process_data, outputs, execution_metadata
-            node_run_result = node.run(
-                variable_pool=workflow_run_state.variable_pool
-            )
-        except GenerateTaskStoppedException as e:
-            node_run_result = NodeRunResult(
-                status=WorkflowNodeExecutionStatus.FAILED,
-                error='Workflow stopped.'
-            )
-        except Exception as e:
-            logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
-            node_run_result = NodeRunResult(
-                status=WorkflowNodeExecutionStatus.FAILED,
-                error=str(e)
-            )
-
-        if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
-            # node run failed
-            if callbacks:
-                for callback in callbacks:
-                    callback.on_workflow_node_execute_failed(
-                        node_id=node.node_id,
-                        node_type=node.node_type,
-                        node_data=node.node_data,
-                        error=node_run_result.error,
-                        inputs=node_run_result.inputs,
-                        outputs=node_run_result.outputs,
-                        process_data=node_run_result.process_data,
-                    )
-
-            raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}")
-
-        if node.is_answer_previous_node and not isinstance(node, LLMNode):
-            if not node_run_result.metadata:
-                node_run_result.metadata = {}
-            node_run_result.metadata["is_answer_previous_node"]=True
-        workflow_nodes_and_result.result = node_run_result
-
-        # node run success
-        if callbacks:
-            for callback in callbacks:
-                callback.on_workflow_node_execute_succeeded(
-                    node_id=node.node_id,
-                    node_type=node.node_type,
-                    node_data=node.node_data,
-                    inputs=node_run_result.inputs,
-                    process_data=node_run_result.process_data,
-                    outputs=node_run_result.outputs,
-                    execution_metadata=node_run_result.metadata
-                )
-
-        if node_run_result.outputs:
-            for variable_key, variable_value in node_run_result.outputs.items():
-                # append variables to variable pool recursively
-                self._append_variables_recursively(
-                    variable_pool=workflow_run_state.variable_pool,
-                    node_id=node.node_id,
-                    variable_key_list=[variable_key],
-                    variable_value=variable_value
-                )
-
-        if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
-            workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))
-
-        db.session.close()
-
-    def _append_variables_recursively(self, variable_pool: VariablePool,
-                                      node_id: str,
-                                      variable_key_list: list[str],
-                                      variable_value: VariableValue):
-        """
-        Append variables recursively
-        :param variable_pool: variable pool
-        :param node_id: node id
-        :param variable_key_list: variable key list
-        :param variable_value: variable value
-        :return:
-        """
-        variable_pool.add(
-            [node_id] + variable_key_list, variable_value
-        )
-
-        # if variable_value is a dict, then recursively append variables
-        if isinstance(variable_value, dict):
-            for key, value in variable_value.items():
-                # construct new key list
-                new_key_list = variable_key_list + [key]
-                self._append_variables_recursively(
-                    variable_pool=variable_pool,
-                    node_id=node_id,
-                    variable_key_list=new_key_list,
-                    variable_value=value
-                )
-
-    @classmethod
-    def handle_special_values(cls, value: Optional[dict]) -> Optional[dict]:
-        """
-        Handle special values
-        :param value: value
-        :return:
-        """
-        if not value:
-            return None
-
-        new_value = value.copy()
-        if isinstance(new_value, dict):
-            for key, val in new_value.items():
-                if isinstance(val, FileVar):
-                    new_value[key] = val.to_dict()
-                elif isinstance(val, list):
-                    new_val = []
-                    for v in val:
-                        if isinstance(v, FileVar):
-                            new_val.append(v.to_dict())
-                        else:
-                            new_val.append(v)
-
-                    new_value[key] = new_val
-
-        return new_value
-
-    def _mapping_user_inputs_to_variable_pool(self,
-                                              variable_mapping: Mapping[str, Sequence[str]],
-                                              user_inputs: dict,
-                                              variable_pool: VariablePool,
-                                              tenant_id: str,
-                                              node_instance: BaseNode):
-        for variable_key, variable_selector in variable_mapping.items():
-            if variable_key not in user_inputs and not variable_pool.get(variable_selector):
-                raise ValueError(f'Variable key {variable_key} not found in user inputs.')
-
-            # fetch variable node id from variable selector
-            variable_node_id = variable_selector[0]
-            variable_key_list = variable_selector[1:]
-
-            # get value
-            value = user_inputs.get(variable_key)
-
-            # FIXME: temp fix for image type
-            if node_instance.node_type == NodeType.LLM:
-                new_value = []
-                if isinstance(value, list):
-                    node_data = node_instance.node_data
-                    node_data = cast(LLMNodeData, node_data)
-
-                    detail = node_data.vision.configs.detail if node_data.vision.configs else None
-
-                    for item in value:
-                        if isinstance(item, dict) and 'type' in item and item['type'] == 'image':
-                            transfer_method = FileTransferMethod.value_of(item.get('transfer_method'))
-                            file = FileVar(
-                                tenant_id=tenant_id,
-                                type=FileType.IMAGE,
-                                transfer_method=transfer_method,
-                                url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
-                                related_id=item.get(
-                                    'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
-                                extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None),
-                            )
-                            new_value.append(file)
-
-                if new_value:
-                    value = new_value
-
-            # append variable and value to variable pool
-            variable_pool.add([variable_node_id]+variable_key_list, value)

+ 314 - 0
api/core/workflow/workflow_entry.py

@@ -0,0 +1,314 @@
+import logging
+import time
+import uuid
+from collections.abc import Generator, Mapping, Sequence
+from typing import Any, Optional, cast
+
+from configs import dify_config
+from core.app.app_config.entities import FileExtraConfig
+from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.file.file_obj import FileTransferMethod, FileType, FileVar
+from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.node_entities import NodeType, UserFrom
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.errors import WorkflowNodeRunFailedError
+from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.graph_engine.graph_engine import GraphEngine
+from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.event import RunEvent
+from core.workflow.nodes.llm.entities import LLMNodeData
+from core.workflow.nodes.node_mapping import node_classes
+from models.workflow import (
+    Workflow,
+    WorkflowType,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class WorkflowEntry:
+    def __init__(
+            self,
+            tenant_id: str,
+            app_id: str,
+            workflow_id: str,
+            workflow_type: WorkflowType,
+            graph_config: Mapping[str, Any],
+            graph: Graph,
+            user_id: str,
+            user_from: UserFrom,
+            invoke_from: InvokeFrom,
+            call_depth: int,
+            variable_pool: VariablePool,
+            thread_pool_id: Optional[str] = None
+    ) -> None:
+        """
+        Init workflow entry
+        :param tenant_id: tenant id
+        :param app_id: app id
+        :param workflow_id: workflow id
+        :param workflow_type: workflow type
+        :param graph_config: workflow graph config
+        :param graph: workflow graph
+        :param user_id: user id
+        :param user_from: user from
+        :param invoke_from: invoke from
+        :param call_depth: call depth
+        :param variable_pool: variable pool
+        :param thread_pool_id: thread pool id
+        """
+        # check call depth
+        workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
+        if call_depth > workflow_call_max_depth:
+            raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
+
+        # init workflow run state
+        self.graph_engine = GraphEngine(
+            tenant_id=tenant_id,
+            app_id=app_id,
+            workflow_type=workflow_type,
+            workflow_id=workflow_id,
+            user_id=user_id,
+            user_from=user_from,
+            invoke_from=invoke_from,
+            call_depth=call_depth,
+            graph=graph,
+            graph_config=graph_config,
+            variable_pool=variable_pool,
+            max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
+            max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
+            thread_pool_id=thread_pool_id
+        )
+
+    def run(
+            self,
+            *,
+            callbacks: Sequence[WorkflowCallback],
+    ) -> Generator[GraphEngineEvent, None, None]:
+        """
+        :param callbacks: workflow callbacks
+        """
+        graph_engine = self.graph_engine
+
+        try:
+            # run workflow
+            generator = graph_engine.run()
+            for event in generator:
+                if callbacks:
+                    for callback in callbacks:
+                        callback.on_event(
+                            event=event
+                        )
+                yield event
+        except GenerateTaskStoppedException:
+            pass
+        except Exception as e:
+            logger.exception("Unknown Error when workflow entry running")
+            if callbacks:
+                for callback in callbacks:
+                    callback.on_event(
+                        event=GraphRunFailedEvent(
+                            error=str(e)
+                        )
+                    )
+            return
+
+    @classmethod
+    def single_step_run(
+        cls, 
+        workflow: Workflow,
+        node_id: str,
+        user_id: str,
+        user_inputs: dict
+    ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]:
+        """
+        Single step run workflow node
+        :param workflow: Workflow instance
+        :param node_id: node id
+        :param user_id: user id
+        :param user_inputs: user inputs
+        :return:
+        """
+        # fetch node info from workflow graph
+        graph = workflow.graph_dict
+        if not graph:
+            raise ValueError('workflow graph not found')
+
+        nodes = graph.get('nodes')
+        if not nodes:
+            raise ValueError('nodes not found in workflow graph')
+
+        # fetch node config from node id
+        node_config = None
+        for node in nodes:
+            if node.get('id') == node_id:
+                node_config = node
+                break
+
+        if not node_config:
+            raise ValueError('node id not found in workflow graph')
+
+        # Get node class
+        node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
+        node_cls = node_classes.get(node_type)
+        node_cls = cast(type[BaseNode], node_cls)
+
+        if not node_cls:
+            raise ValueError(f'Node class not found for node type {node_type}')
+        
+        # init variable pool
+        variable_pool = VariablePool(
+            system_variables={},
+            user_inputs={},
+            environment_variables=workflow.environment_variables,
+        )
+
+        # init graph
+        graph = Graph.init(
+            graph_config=workflow.graph_dict
+        )
+
+        # init workflow run state
+        node_instance: BaseNode = node_cls(
+            id=str(uuid.uuid4()),
+            config=node_config,
+            graph_init_params=GraphInitParams(
+                tenant_id=workflow.tenant_id,
+                app_id=workflow.app_id,
+                workflow_type=WorkflowType.value_of(workflow.type),
+                workflow_id=workflow.id,
+                graph_config=workflow.graph_dict,
+                user_id=user_id,
+                user_from=UserFrom.ACCOUNT,
+                invoke_from=InvokeFrom.DEBUGGER,
+                call_depth=0
+            ),
+            graph=graph,
+            graph_runtime_state=GraphRuntimeState(
+                variable_pool=variable_pool,
+                start_at=time.perf_counter()
+            )
+        )
+
+        try:
+            # variable selector to variable mapping
+            try:
+                variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
+                    graph_config=workflow.graph_dict, 
+                    config=node_config
+                )
+            except NotImplementedError:
+                variable_mapping = {}
+
+            cls.mapping_user_inputs_to_variable_pool(
+                variable_mapping=variable_mapping,
+                user_inputs=user_inputs,
+                variable_pool=variable_pool,
+                tenant_id=workflow.tenant_id,
+                node_type=node_type,
+                node_data=node_instance.node_data
+            )
+
+            # run node
+            generator = node_instance.run()
+
+            return node_instance, generator
+        except Exception as e:
+            raise WorkflowNodeRunFailedError(
+                node_instance=node_instance,
+                error=str(e)
+            )
+
+    @classmethod
+    def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]:
+        """
+        Handle special values
+        :param value: value
+        :return:
+        """
+        if not value:
+            return None
+
+        new_value = dict(value) if value else {}
+        if isinstance(new_value, dict):
+            for key, val in new_value.items():
+                if isinstance(val, FileVar):
+                    new_value[key] = val.to_dict()
+                elif isinstance(val, list):
+                    new_val = []
+                    for v in val:
+                        if isinstance(v, FileVar):
+                            new_val.append(v.to_dict())
+                        else:
+                            new_val.append(v)
+
+                    new_value[key] = new_val
+
+        return new_value
+
+    @classmethod
+    def mapping_user_inputs_to_variable_pool(
+        cls,
+        variable_mapping: Mapping[str, Sequence[str]],
+        user_inputs: dict,
+        variable_pool: VariablePool,
+        tenant_id: str,
+        node_type: NodeType,
+        node_data: BaseNodeData
+    ) -> None:
+        for node_variable, variable_selector in variable_mapping.items():
+            # fetch node id and variable key from node_variable
+            node_variable_list = node_variable.split('.')
+            if len(node_variable_list) < 1:
+                raise ValueError(f'Invalid node variable {node_variable}')
+            
+            node_variable_key = '.'.join(node_variable_list[1:])
+
+            if (
+                node_variable_key not in user_inputs
+                and node_variable not in user_inputs
+            ) and not variable_pool.get(variable_selector):
+                raise ValueError(f'Variable key {node_variable} not found in user inputs.')
+
+            # fetch variable node id from variable selector
+            variable_node_id = variable_selector[0]
+            variable_key_list = variable_selector[1:]
+            variable_key_list = cast(list[str], variable_key_list)
+
+            # get input value
+            input_value = user_inputs.get(node_variable)
+            if not input_value:
+                input_value = user_inputs.get(node_variable_key)
+
+            # FIXME: temp fix for image type
+            if node_type == NodeType.LLM:
+                new_value = []
+                if isinstance(input_value, list):
+                    node_data = cast(LLMNodeData, node_data)
+
+                    detail = node_data.vision.configs.detail if node_data.vision.configs else None
+
+                    for item in input_value:
+                        if isinstance(item, dict) and 'type' in item and item['type'] == 'image':
+                            transfer_method = FileTransferMethod.value_of(item.get('transfer_method'))
+                            file = FileVar(
+                                tenant_id=tenant_id,
+                                type=FileType.IMAGE,
+                                transfer_method=transfer_method,
+                                url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
+                                related_id=item.get(
+                                    'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
+                                extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None),
+                            )
+                            new_value.append(file)
+
+                if new_value:
+                    value = new_value
+
+            # append variable and value to variable pool
+            variable_pool.add([variable_node_id] + variable_key_list, input_value)

+ 35 - 0
api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py

@@ -0,0 +1,35 @@
+"""add node_execution_id into node_executions
+
+Revision ID: 675b5321501b
+Revises: 030f4915f36a
+Create Date: 2024-08-12 10:54:02.259331
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '675b5321501b'
+down_revision = '030f4915f36a'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('node_execution_id', sa.String(length=255), nullable=True))
+        batch_op.create_index('workflow_node_execution_id_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_execution_id'], unique=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
+        batch_op.drop_index('workflow_node_execution_id_idx')
+        batch_op.drop_column('node_execution_id')
+
+    # ### end Alembic commands ###

+ 3 - 0
api/models/workflow.py

@@ -581,6 +581,8 @@ class WorkflowNodeExecution(db.Model):
                  'triggered_from', 'workflow_run_id'),
         db.Index('workflow_node_execution_node_run_idx', 'tenant_id', 'app_id', 'workflow_id',
                  'triggered_from', 'node_id'),
+        db.Index('workflow_node_execution_id_idx', 'tenant_id', 'app_id', 'workflow_id',
+                 'triggered_from', 'node_execution_id'),
     )
 
     id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
@@ -591,6 +593,7 @@ class WorkflowNodeExecution(db.Model):
     workflow_run_id = db.Column(StringUUID)
     index = db.Column(db.Integer, nullable=False)
     predecessor_node_id = db.Column(db.String(255))
+    node_execution_id = db.Column(db.String(255), nullable=True)
     node_id = db.Column(db.String(255), nullable=False)
     node_type = db.Column(db.String(255), nullable=False)
     title = db.Column(db.String(255), nullable=False)

+ 2 - 1
api/services/app_dsl_service.py

@@ -13,8 +13,9 @@ from services.workflow_service import WorkflowService
 
 logger = logging.getLogger(__name__)
 
-current_dsl_version = "0.1.1"
+current_dsl_version = "0.1.2"
 dsl_to_dify_version_mapping: dict[str, str] = {
+    "0.1.2": "0.8.0",
     "0.1.1": "0.6.0",  # dsl version -> from dify version
 }
 

+ 3 - 4
api/services/app_generate_service.py

@@ -12,6 +12,7 @@ from core.app.apps.workflow.app_generator import WorkflowAppGenerator
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.features.rate_limiting import RateLimit
 from models.model import Account, App, AppMode, EndUser
+from models.workflow import Workflow
 from services.errors.llm import InvokeRateLimitError
 from services.workflow_service import WorkflowService
 
@@ -103,9 +104,7 @@ class AppGenerateService:
         return max_active_requests
 
     @classmethod
-    def generate_single_iteration(
-        cls, app_model: App, user: Union[Account, EndUser], node_id: str, args: Any, streaming: bool = True
-    ):
+    def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
         if app_model.mode == AppMode.ADVANCED_CHAT.value:
             workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
             return AdvancedChatAppGenerator().single_iteration_generate(
@@ -142,7 +141,7 @@ class AppGenerateService:
         )
 
     @classmethod
-    def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Any:
+    def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Workflow:
         """
         Get workflow
         :param app_model: app model

+ 69 - 91
api/services/workflow_service.py

@@ -8,9 +8,11 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
 from core.app.segments import Variable
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.workflow.entities.node_entities import NodeType
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.errors import WorkflowNodeRunFailedError
-from core.workflow.workflow_engine_manager import WorkflowEngineManager
+from core.workflow.nodes.event import RunCompletedEvent
+from core.workflow.nodes.node_mapping import node_classes
+from core.workflow.workflow_entry import WorkflowEntry
 from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
 from extensions.ext_database import db
 from models.account import Account
@@ -172,8 +174,13 @@ class WorkflowService:
         Get default block configs
         """
         # return default block config
-        workflow_engine_manager = WorkflowEngineManager()
-        return workflow_engine_manager.get_default_configs()
+        default_block_configs = []
+        for node_type, node_class in node_classes.items():
+            default_config = node_class.get_default_config()
+            if default_config:
+                default_block_configs.append(default_config)
+
+        return default_block_configs
 
     def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
         """
@@ -182,11 +189,18 @@ class WorkflowService:
         :param filters: filter by node config parameters.
         :return:
         """
-        node_type = NodeType.value_of(node_type)
+        node_type_enum: NodeType = NodeType.value_of(node_type)
 
         # return default block config
-        workflow_engine_manager = WorkflowEngineManager()
-        return workflow_engine_manager.get_default_config(node_type, filters)
+        node_class = node_classes.get(node_type_enum)
+        if not node_class:
+            return None
+
+        default_config = node_class.get_default_config(filters=filters)
+        if not default_config:
+            return None
+
+        return default_config
 
     def run_draft_workflow_node(
         self, app_model: App, node_id: str, user_inputs: dict, account: Account
@@ -200,82 +214,68 @@ class WorkflowService:
             raise ValueError("Workflow not initialized")
 
         # run draft workflow node
-        workflow_engine_manager = WorkflowEngineManager()
         start_at = time.perf_counter()
 
         try:
-            node_instance, node_run_result = workflow_engine_manager.single_step_run_workflow_node(
+            node_instance, generator = WorkflowEntry.single_step_run(
                 workflow=draft_workflow,
                 node_id=node_id,
                 user_inputs=user_inputs,
                 user_id=account.id,
             )
-        except WorkflowNodeRunFailedError as e:
-            workflow_node_execution = WorkflowNodeExecution(
-                tenant_id=app_model.tenant_id,
-                app_id=app_model.id,
-                workflow_id=draft_workflow.id,
-                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
-                index=1,
-                node_id=e.node_id,
-                node_type=e.node_type.value,
-                title=e.node_title,
-                status=WorkflowNodeExecutionStatus.FAILED.value,
-                error=e.error,
-                elapsed_time=time.perf_counter() - start_at,
-                created_by_role=CreatedByRole.ACCOUNT.value,
-                created_by=account.id,
-                created_at=datetime.now(timezone.utc).replace(tzinfo=None),
-                finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
-            )
-            db.session.add(workflow_node_execution)
-            db.session.commit()
 
-            return workflow_node_execution
+            node_run_result: NodeRunResult | None = None
+            for event in generator:
+                if isinstance(event, RunCompletedEvent):
+                    node_run_result = event.run_result
+
+                    # sign output files
+                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
+                    break
+
+            if not node_run_result:
+                raise ValueError("Node run failed with no run result")
 
-        if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
+            run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False
+            error = node_run_result.error if not run_succeeded else None
+        except WorkflowNodeRunFailedError as e:
+            node_instance = e.node_instance
+            run_succeeded = False
+            node_run_result = None
+            error = e.error
+
+        workflow_node_execution = WorkflowNodeExecution()
+        workflow_node_execution.tenant_id = app_model.tenant_id
+        workflow_node_execution.app_id = app_model.id
+        workflow_node_execution.workflow_id = draft_workflow.id
+        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
+        workflow_node_execution.index = 1
+        workflow_node_execution.node_id = node_id
+        workflow_node_execution.node_type = node_instance.node_type.value
+        workflow_node_execution.title = node_instance.node_data.title
+        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
+        workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value
+        workflow_node_execution.created_by = account.id
+        workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+
+        if run_succeeded and node_run_result:
             # create workflow node execution
-            workflow_node_execution = WorkflowNodeExecution(
-                tenant_id=app_model.tenant_id,
-                app_id=app_model.id,
-                workflow_id=draft_workflow.id,
-                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
-                index=1,
-                node_id=node_id,
-                node_type=node_instance.node_type.value,
-                title=node_instance.node_data.title,
-                inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None,
-                process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None,
-                outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None,
-                execution_metadata=(
-                    json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
-                ),
-                status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
-                elapsed_time=time.perf_counter() - start_at,
-                created_by_role=CreatedByRole.ACCOUNT.value,
-                created_by=account.id,
-                created_at=datetime.now(timezone.utc).replace(tzinfo=None),
-                finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
+            workflow_node_execution.inputs = json.dumps(node_run_result.inputs) if node_run_result.inputs else None
+            workflow_node_execution.process_data = (
+                json.dumps(node_run_result.process_data) if node_run_result.process_data else None
+            )
+            workflow_node_execution.outputs = (
+                json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None
+            )
+            workflow_node_execution.execution_metadata = (
+                json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
             )
+            workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
         else:
             # create workflow node execution
-            workflow_node_execution = WorkflowNodeExecution(
-                tenant_id=app_model.tenant_id,
-                app_id=app_model.id,
-                workflow_id=draft_workflow.id,
-                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
-                index=1,
-                node_id=node_id,
-                node_type=node_instance.node_type.value,
-                title=node_instance.node_data.title,
-                status=node_run_result.status.value,
-                error=node_run_result.error,
-                elapsed_time=time.perf_counter() - start_at,
-                created_by_role=CreatedByRole.ACCOUNT.value,
-                created_by=account.id,
-                created_at=datetime.now(timezone.utc).replace(tzinfo=None),
-                finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
-            )
+            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
+            workflow_node_execution.error = error
 
         db.session.add(workflow_node_execution)
         db.session.commit()
@@ -321,25 +321,3 @@ class WorkflowService:
             )
         else:
             raise ValueError(f"Invalid app mode: {app_model.mode}")
-
-    @classmethod
-    def get_elapsed_time(cls, workflow_run_id: str) -> float:
-        """
-        Get elapsed time
-        """
-        elapsed_time = 0.0
-
-        # fetch workflow node execution by workflow_run_id
-        workflow_nodes = (
-            db.session.query(WorkflowNodeExecution)
-            .filter(WorkflowNodeExecution.workflow_run_id == workflow_run_id)
-            .order_by(WorkflowNodeExecution.created_at.asc())
-            .all()
-        )
-        if not workflow_nodes:
-            return elapsed_time
-
-        for node in workflow_nodes:
-            elapsed_time += node.elapsed_time
-
-        return elapsed_time

+ 175 - 144
api/tests/integration_tests/workflow/nodes/test_code.py

@@ -1,17 +1,72 @@
+import time
+import uuid
 from os import getenv
+from typing import cast
 
 import pytest
 
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities.node_entities import NodeRunResult, UserFrom
 from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.base_node import UserFrom
+from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.nodes.code.code_node import CodeNode
-from models.workflow import WorkflowNodeExecutionStatus
+from core.workflow.nodes.code.entities import CodeNodeData
+from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
 from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
 
 CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000"))
 
 
+def init_code_node(code_config: dict):
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-code-target",
+                "source": "start",
+                "target": "code",
+            },
+        ],
+        "nodes": [{"data": {"type": "start"}, "id": "start"}, code_config],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    init_params = GraphInitParams(
+        tenant_id="1",
+        app_id="1",
+        workflow_type=WorkflowType.WORKFLOW,
+        workflow_id="1",
+        graph_config=graph_config,
+        user_id="1",
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.DEBUGGER,
+        call_depth=0,
+    )
+
+    # construct variable pool
+    variable_pool = VariablePool(
+        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+        user_inputs={},
+        environment_variables=[],
+        conversation_variables=[],
+    )
+    variable_pool.add(["code", "123", "args1"], 1)
+    variable_pool.add(["code", "123", "args2"], 2)
+
+    node = CodeNode(
+        id=str(uuid.uuid4()),
+        graph_init_params=init_params,
+        graph=graph,
+        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+        config=code_config,
+    )
+
+    return node
+
+
 @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
 def test_execute_code(setup_code_executor_mock):
     code = """
@@ -22,44 +77,36 @@ def test_execute_code(setup_code_executor_mock):
     """
     # trim first 4 spaces at the beginning of each line
     code = "\n".join([line[4:] for line in code.split("\n")])
-    node = CodeNode(
-        tenant_id="1",
-        app_id="1",
-        workflow_id="1",
-        user_id="1",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.WEB_APP,
-        config={
-            "id": "1",
-            "data": {
-                "outputs": {
-                    "result": {
-                        "type": "number",
-                    },
+
+    code_config = {
+        "id": "code",
+        "data": {
+            "outputs": {
+                "result": {
+                    "type": "number",
                 },
-                "title": "123",
-                "variables": [
-                    {
-                        "variable": "args1",
-                        "value_selector": ["1", "123", "args1"],
-                    },
-                    {"variable": "args2", "value_selector": ["1", "123", "args2"]},
-                ],
-                "answer": "123",
-                "code_language": "python3",
-                "code": code,
             },
+            "title": "123",
+            "variables": [
+                {
+                    "variable": "args1",
+                    "value_selector": ["1", "123", "args1"],
+                },
+                {"variable": "args2", "value_selector": ["1", "123", "args2"]},
+            ],
+            "answer": "123",
+            "code_language": "python3",
+            "code": code,
         },
-    )
+    }
 
-    # construct variable pool
-    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
-    pool.add(["1", "123", "args1"], 1)
-    pool.add(["1", "123", "args2"], 2)
+    node = init_code_node(code_config)
 
     # execute node
-    result = node.run(pool)
+    result = node._run()
+    assert isinstance(result, NodeRunResult)
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert result.outputs is not None
     assert result.outputs["result"] == 3
     assert result.error is None
 
@@ -74,44 +121,34 @@ def test_execute_code_output_validator(setup_code_executor_mock):
     """
     # trim first 4 spaces at the beginning of each line
     code = "\n".join([line[4:] for line in code.split("\n")])
-    node = CodeNode(
-        tenant_id="1",
-        app_id="1",
-        workflow_id="1",
-        user_id="1",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.WEB_APP,
-        config={
-            "id": "1",
-            "data": {
-                "outputs": {
-                    "result": {
-                        "type": "string",
-                    },
+
+    code_config = {
+        "id": "code",
+        "data": {
+            "outputs": {
+                "result": {
+                    "type": "string",
                 },
-                "title": "123",
-                "variables": [
-                    {
-                        "variable": "args1",
-                        "value_selector": ["1", "123", "args1"],
-                    },
-                    {"variable": "args2", "value_selector": ["1", "123", "args2"]},
-                ],
-                "answer": "123",
-                "code_language": "python3",
-                "code": code,
             },
+            "title": "123",
+            "variables": [
+                {
+                    "variable": "args1",
+                    "value_selector": ["1", "123", "args1"],
+                },
+                {"variable": "args2", "value_selector": ["1", "123", "args2"]},
+            ],
+            "answer": "123",
+            "code_language": "python3",
+            "code": code,
         },
-    )
+    }
 
-    # construct variable pool
-    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
-    pool.add(["1", "123", "args1"], 1)
-    pool.add(["1", "123", "args2"], 2)
+    node = init_code_node(code_config)
 
     # execute node
-    result = node.run(pool)
-
+    result = node._run()
+    assert isinstance(result, NodeRunResult)
     assert result.status == WorkflowNodeExecutionStatus.FAILED
     assert result.error == "Output variable `result` must be a string"
 
@@ -127,65 +164,60 @@ def test_execute_code_output_validator_depth():
     """
     # trim first 4 spaces at the beginning of each line
     code = "\n".join([line[4:] for line in code.split("\n")])
-    node = CodeNode(
-        tenant_id="1",
-        app_id="1",
-        workflow_id="1",
-        user_id="1",
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.WEB_APP,
-        config={
-            "id": "1",
-            "data": {
-                "outputs": {
-                    "string_validator": {
-                        "type": "string",
-                    },
-                    "number_validator": {
-                        "type": "number",
-                    },
-                    "number_array_validator": {
-                        "type": "array[number]",
-                    },
-                    "string_array_validator": {
-                        "type": "array[string]",
-                    },
-                    "object_validator": {
-                        "type": "object",
-                        "children": {
-                            "result": {
-                                "type": "number",
-                            },
-                            "depth": {
-                                "type": "object",
-                                "children": {
-                                    "depth": {
-                                        "type": "object",
-                                        "children": {
-                                            "depth": {
-                                                "type": "number",
-                                            }
-                                        },
-                                    }
-                                },
+
+    code_config = {
+        "id": "code",
+        "data": {
+            "outputs": {
+                "string_validator": {
+                    "type": "string",
+                },
+                "number_validator": {
+                    "type": "number",
+                },
+                "number_array_validator": {
+                    "type": "array[number]",
+                },
+                "string_array_validator": {
+                    "type": "array[string]",
+                },
+                "object_validator": {
+                    "type": "object",
+                    "children": {
+                        "result": {
+                            "type": "number",
+                        },
+                        "depth": {
+                            "type": "object",
+                            "children": {
+                                "depth": {
+                                    "type": "object",
+                                    "children": {
+                                        "depth": {
+                                            "type": "number",
+                                        }
+                                    },
+                                }
                             },
                         },
                     },
                 },
-                "title": "123",
-                "variables": [
-                    {
-                        "variable": "args1",
-                        "value_selector": ["1", "123", "args1"],
-                    },
-                    {"variable": "args2", "value_selector": ["1", "123", "args2"]},
-                ],
-                "answer": "123",
-                "code_language": "python3",
-                "code": code,
             },
+            "title": "123",
+            "variables": [
+                {
+                    "variable": "args1",
+                    "value_selector": ["1", "123", "args1"],
+                },
+                {"variable": "args2", "value_selector": ["1", "123", "args2"]},
+            ],
+            "answer": "123",
+            "code_language": "python3",
+            "code": code,
         },
-    )
+    }
+
+    node = init_code_node(code_config)
 
     # construct result
     result = {
@@ -196,6 +228,8 @@ def test_execute_code_output_validator_depth():
         "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
     }
 
+    node.node_data = cast(CodeNodeData, node.node_data)
+
     # validate
     node._transform_result(result, node.node_data.outputs)
 
@@ -250,35 +284,30 @@ def test_execute_code_output_object_list():
     """
     # trim first 4 spaces at the beginning of each line
     code = "\n".join([line[4:] for line in code.split("\n")])
-    node = CodeNode(
-        tenant_id="1",
-        app_id="1",
-        workflow_id="1",
-        user_id="1",
-        invoke_from=InvokeFrom.WEB_APP,
-        user_from=UserFrom.ACCOUNT,
-        config={
-            "id": "1",
-            "data": {
-                "outputs": {
-                    "object_list": {
-                        "type": "array[object]",
-                    },
+
+    code_config = {
+        "id": "code",
+        "data": {
+            "outputs": {
+                "object_list": {
+                    "type": "array[object]",
                 },
-                "title": "123",
-                "variables": [
-                    {
-                        "variable": "args1",
-                        "value_selector": ["1", "123", "args1"],
-                    },
-                    {"variable": "args2", "value_selector": ["1", "123", "args2"]},
-                ],
-                "answer": "123",
-                "code_language": "python3",
-                "code": code,
             },
+            "title": "123",
+            "variables": [
+                {
+                    "variable": "args1",
+                    "value_selector": ["1", "123", "args1"],
+                },
+                {"variable": "args2", "value_selector": ["1", "123", "args2"]},
+            ],
+            "answer": "123",
+            "code_language": "python3",
+            "code": code,
         },
-    )
+    }
+
+    node = init_code_node(code_config)
 
     # construct result
     result = {
@@ -295,6 +324,8 @@ def test_execute_code_output_object_list():
         ]
     }
 
+    node.node_data = cast(CodeNodeData, node.node_data)
+
     # validate
     node._transform_result(result, node.node_data.outputs)
 

+ 96 - 60
api/tests/integration_tests/workflow/nodes/test_http.py

@@ -1,31 +1,69 @@
+import time
+import uuid
 from urllib.parse import urlencode
 
 import pytest
 
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities.node_entities import UserFrom
 from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.base_node import UserFrom
+from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
+from models.workflow import WorkflowType
 from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
 
-BASIC_NODE_DATA = {
-    "tenant_id": "1",
-    "app_id": "1",
-    "workflow_id": "1",
-    "user_id": "1",
-    "user_from": UserFrom.ACCOUNT,
-    "invoke_from": InvokeFrom.WEB_APP,
-}
 
-# construct variable pool
-pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
-pool.add(["a", "b123", "args1"], 1)
-pool.add(["a", "b123", "args2"], 2)
+def init_http_node(config: dict):
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-next-target",
+                "source": "start",
+                "target": "1",
+            },
+        ],
+        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    init_params = GraphInitParams(
+        tenant_id="1",
+        app_id="1",
+        workflow_type=WorkflowType.WORKFLOW,
+        workflow_id="1",
+        graph_config=graph_config,
+        user_id="1",
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.DEBUGGER,
+        call_depth=0,
+    )
+
+    # construct variable pool
+    variable_pool = VariablePool(
+        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+        user_inputs={},
+        environment_variables=[],
+        conversation_variables=[],
+    )
+    variable_pool.add(["a", "b123", "args1"], 1)
+    variable_pool.add(["a", "b123", "args2"], 2)
+
+    return HttpRequestNode(
+        id=str(uuid.uuid4()),
+        graph_init_params=init_params,
+        graph=graph,
+        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+        config=config,
+    )
 
 
 @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
 def test_get(setup_http_mock):
-    node = HttpRequestNode(
+    node = init_http_node(
         config={
             "id": "1",
             "data": {
@@ -45,12 +83,11 @@ def test_get(setup_http_mock):
                 "params": "A:b",
                 "body": None,
             },
-        },
-        **BASIC_NODE_DATA,
+        }
     )
 
-    result = node.run(pool)
-
+    result = node._run()
+    assert result.process_data is not None
     data = result.process_data.get("request", "")
 
     assert "?A=b" in data
@@ -59,7 +96,7 @@ def test_get(setup_http_mock):
 
 @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
 def test_no_auth(setup_http_mock):
-    node = HttpRequestNode(
+    node = init_http_node(
         config={
             "id": "1",
             "data": {
@@ -75,12 +112,11 @@ def test_no_auth(setup_http_mock):
                 "params": "A:b",
                 "body": None,
             },
-        },
-        **BASIC_NODE_DATA,
+        }
     )
 
-    result = node.run(pool)
-
+    result = node._run()
+    assert result.process_data is not None
     data = result.process_data.get("request", "")
 
     assert "?A=b" in data
@@ -89,7 +125,7 @@ def test_no_auth(setup_http_mock):
 
 @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
 def test_custom_authorization_header(setup_http_mock):
-    node = HttpRequestNode(
+    node = init_http_node(
         config={
             "id": "1",
             "data": {
@@ -109,12 +145,11 @@ def test_custom_authorization_header(setup_http_mock):
                 "params": "A:b",
                 "body": None,
             },
-        },
-        **BASIC_NODE_DATA,
+        }
     )
 
-    result = node.run(pool)
-
+    result = node._run()
+    assert result.process_data is not None
     data = result.process_data.get("request", "")
 
     assert "?A=b" in data
@@ -123,7 +158,7 @@ def test_custom_authorization_header(setup_http_mock):
 
 @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
 def test_template(setup_http_mock):
-    node = HttpRequestNode(
+    node = init_http_node(
         config={
             "id": "1",
             "data": {
@@ -143,11 +178,11 @@ def test_template(setup_http_mock):
                 "params": "A:b\nTemplate:{{#a.b123.args2#}}",
                 "body": None,
             },
-        },
-        **BASIC_NODE_DATA,
+        }
     )
 
-    result = node.run(pool)
+    result = node._run()
+    assert result.process_data is not None
     data = result.process_data.get("request", "")
 
     assert "?A=b" in data
@@ -158,7 +193,7 @@ def test_template(setup_http_mock):
 
 @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
 def test_json(setup_http_mock):
-    node = HttpRequestNode(
+    node = init_http_node(
         config={
             "id": "1",
             "data": {
@@ -178,11 +213,11 @@ def test_json(setup_http_mock):
                 "params": "A:b",
                 "body": {"type": "json", "data": '{"a": "{{#a.b123.args1#}}"}'},
             },
-        },
-        **BASIC_NODE_DATA,
+        }
     )
 
-    result = node.run(pool)
+    result = node._run()
+    assert result.process_data is not None
     data = result.process_data.get("request", "")
 
     assert '{"a": "1"}' in data
@@ -190,7 +225,7 @@ def test_json(setup_http_mock):
 
 
 def test_x_www_form_urlencoded(setup_http_mock):
-    node = HttpRequestNode(
+    node = init_http_node(
         config={
             "id": "1",
             "data": {
@@ -210,11 +245,11 @@ def test_x_www_form_urlencoded(setup_http_mock):
                 "params": "A:b",
                 "body": {"type": "x-www-form-urlencoded", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"},
             },
-        },
-        **BASIC_NODE_DATA,
+        }
     )
 
-    result = node.run(pool)
+    result = node._run()
+    assert result.process_data is not None
     data = result.process_data.get("request", "")
 
     assert "a=1&b=2" in data
@@ -222,7 +257,7 @@ def test_x_www_form_urlencoded(setup_http_mock):
 
 
 def test_form_data(setup_http_mock):
-    node = HttpRequestNode(
+    node = init_http_node(
         config={
             "id": "1",
             "data": {
@@ -242,11 +277,11 @@ def test_form_data(setup_http_mock):
                 "params": "A:b",
                 "body": {"type": "form-data", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"},
             },
-        },
-        **BASIC_NODE_DATA,
+        }
     )
 
-    result = node.run(pool)
+    result = node._run()
+    assert result.process_data is not None
     data = result.process_data.get("request", "")
 
     assert 'form-data; name="a"' in data
@@ -257,7 +292,7 @@ def test_form_data(setup_http_mock):
 
 
 def test_none_data(setup_http_mock):
-    node = HttpRequestNode(
+    node = init_http_node(
         config={
             "id": "1",
             "data": {
@@ -277,11 +312,11 @@ def test_none_data(setup_http_mock):
                 "params": "A:b",
                 "body": {"type": "none", "data": "123123123"},
             },
-        },
-        **BASIC_NODE_DATA,
+        }
     )
 
-    result = node.run(pool)
+    result = node._run()
+    assert result.process_data is not None
     data = result.process_data.get("request", "")
 
     assert "X-Header: 123" in data
@@ -289,7 +324,7 @@ def test_none_data(setup_http_mock):
 
 
 def test_mock_404(setup_http_mock):
-    node = HttpRequestNode(
+    node = init_http_node(
         config={
             "id": "1",
             "data": {
@@ -305,19 +340,19 @@ def test_mock_404(setup_http_mock):
                 "params": "",
                 "headers": "X-Header:123",
             },
-        },
-        **BASIC_NODE_DATA,
+        }
     )
 
-    result = node.run(pool)
+    result = node._run()
+    assert result.outputs is not None
     resp = result.outputs
 
     assert 404 == resp.get("status_code")
-    assert "Not Found" in resp.get("body")
+    assert "Not Found" in resp.get("body", "")
 
 
 def test_multi_colons_parse(setup_http_mock):
-    node = HttpRequestNode(
+    node = init_http_node(
         config={
             "id": "1",
             "data": {
@@ -333,13 +368,14 @@ def test_multi_colons_parse(setup_http_mock):
                 "headers": "Referer:http://example3.com\nRedirect:http://example4.com",
                 "body": {"type": "form-data", "data": "Referer:http://example5.com\nRedirect:http://example6.com"},
             },
-        },
-        **BASIC_NODE_DATA,
+        }
     )
 
-    result = node.run(pool)
+    result = node._run()
+    assert result.process_data is not None
+    assert result.outputs is not None
     resp = result.outputs
 
-    assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request")
-    assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request")
-    assert "http://example3.com" == resp.get("headers").get("referer")
+    assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "")
+    assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request", "")
+    assert "http://example3.com" == resp.get("headers", {}).get("referer")

+ 82 - 51
api/tests/integration_tests/workflow/nodes/test_llm.py

@@ -1,5 +1,8 @@
 import json
 import os
+import time
+import uuid
+from collections.abc import Generator
 from unittest.mock import MagicMock
 
 import pytest
@@ -10,28 +13,77 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers import ModelProviderFactory
+from core.workflow.entities.node_entities import UserFrom
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
-from core.workflow.nodes.base_node import UserFrom
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.nodes.event import RunCompletedEvent
 from core.workflow.nodes.llm.llm_node import LLMNode
 from extensions.ext_database import db
 from models.provider import ProviderType
-from models.workflow import WorkflowNodeExecutionStatus
+from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
 
 """FOR MOCK FIXTURES, DO NOT REMOVE"""
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
 from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
 
 
-@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
-def test_execute_llm(setup_openai_mock):
-    node = LLMNode(
+def init_llm_node(config: dict) -> LLMNode:
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-next-target",
+                "source": "start",
+                "target": "llm",
+            },
+        ],
+        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    init_params = GraphInitParams(
         tenant_id="1",
         app_id="1",
+        workflow_type=WorkflowType.WORKFLOW,
         workflow_id="1",
+        graph_config=graph_config,
         user_id="1",
-        invoke_from=InvokeFrom.WEB_APP,
         user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.DEBUGGER,
+        call_depth=0,
+    )
+
+    # construct variable pool
+    variable_pool = VariablePool(
+        system_variables={
+            SystemVariableKey.QUERY: "what's the weather today?",
+            SystemVariableKey.FILES: [],
+            SystemVariableKey.CONVERSATION_ID: "abababa",
+            SystemVariableKey.USER_ID: "aaa",
+        },
+        user_inputs={},
+        environment_variables=[],
+        conversation_variables=[],
+    )
+    variable_pool.add(["abc", "output"], "sunny")
+
+    node = LLMNode(
+        id=str(uuid.uuid4()),
+        graph_init_params=init_params,
+        graph=graph,
+        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+        config=config,
+    )
+
+    return node
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
+def test_execute_llm(setup_openai_mock):
+    node = init_llm_node(
         config={
             "id": "llm",
             "data": {
@@ -49,19 +101,6 @@ def test_execute_llm(setup_openai_mock):
         },
     )
 
-    # construct variable pool
-    pool = VariablePool(
-        system_variables={
-            SystemVariableKey.QUERY: "what's the weather today?",
-            SystemVariableKey.FILES: [],
-            SystemVariableKey.CONVERSATION_ID: "abababa",
-            SystemVariableKey.USER_ID: "aaa",
-        },
-        user_inputs={},
-        environment_variables=[],
-    )
-    pool.add(["abc", "output"], "sunny")
-
     credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
 
     provider_instance = ModelProviderFactory().get_provider_instance("openai")
@@ -80,13 +119,15 @@ def test_execute_llm(setup_openai_mock):
         model_type_instance=model_type_instance,
     )
     model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo")
+    model_schema = model_type_instance.get_model_schema("gpt-3.5-turbo")
+    assert model_schema is not None
     model_config = ModelConfigWithCredentialsEntity(
         model="gpt-3.5-turbo",
         provider="openai",
         mode="chat",
         credentials=credentials,
         parameters={},
-        model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"),
+        model_schema=model_schema,
         provider_model_bundle=provider_model_bundle,
     )
 
@@ -96,11 +137,16 @@ def test_execute_llm(setup_openai_mock):
     node._fetch_model_config = MagicMock(return_value=(model_instance, model_config))
 
     # execute node
-    result = node.run(pool)
+    result = node._run()
+    assert isinstance(result, Generator)
 
-    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
-    assert result.outputs["text"] is not None
-    assert result.outputs["usage"]["total_tokens"] > 0
+    for item in result:
+        if isinstance(item, RunCompletedEvent):
+            assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+            assert item.run_result.process_data is not None
+            assert item.run_result.outputs is not None
+            assert item.run_result.outputs.get("text") is not None
+            assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
 
 
 @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
@@ -109,13 +155,7 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
     """
     Test execute LLM node with jinja2
     """
-    node = LLMNode(
-        tenant_id="1",
-        app_id="1",
-        workflow_id="1",
-        user_id="1",
-        invoke_from=InvokeFrom.WEB_APP,
-        user_from=UserFrom.ACCOUNT,
+    node = init_llm_node(
         config={
             "id": "llm",
             "data": {
@@ -149,19 +189,6 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
         },
     )
 
-    # construct variable pool
-    pool = VariablePool(
-        system_variables={
-            SystemVariableKey.QUERY: "what's the weather today?",
-            SystemVariableKey.FILES: [],
-            SystemVariableKey.CONVERSATION_ID: "abababa",
-            SystemVariableKey.USER_ID: "aaa",
-        },
-        user_inputs={},
-        environment_variables=[],
-    )
-    pool.add(["abc", "output"], "sunny")
-
     credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
 
     provider_instance = ModelProviderFactory().get_provider_instance("openai")
@@ -181,14 +208,15 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
     )
 
     model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo")
-
+    model_schema = model_type_instance.get_model_schema("gpt-3.5-turbo")
+    assert model_schema is not None
     model_config = ModelConfigWithCredentialsEntity(
         model="gpt-3.5-turbo",
         provider="openai",
         mode="chat",
         credentials=credentials,
         parameters={},
-        model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"),
+        model_schema=model_schema,
         provider_model_bundle=provider_model_bundle,
     )
 
@@ -198,8 +226,11 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
     node._fetch_model_config = MagicMock(return_value=(model_instance, model_config))
 
     # execute node
-    result = node.run(pool)
-
-    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
-    assert "sunny" in json.dumps(result.process_data)
-    assert "what's the weather today?" in json.dumps(result.process_data)
+    result = node._run()
+
+    for item in result:
+        if isinstance(item, RunCompletedEvent):
+            assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+            assert item.run_result.process_data is not None
+            assert "sunny" in json.dumps(item.run_result.process_data)
+            assert "what's the weather today?" in json.dumps(item.run_result.process_data)

+ 87 - 105
api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py

@@ -1,5 +1,7 @@
 import json
 import os
+import time
+import uuid
 from typing import Optional
 from unittest.mock import MagicMock
 
@@ -8,19 +10,21 @@ import pytest
 from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
 from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
-from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
+from core.workflow.entities.node_entities import UserFrom
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
-from core.workflow.nodes.base_node import UserFrom
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
 from extensions.ext_database import db
 from models.provider import ProviderType
 
 """FOR MOCK FIXTURES, DO NOT REMOVE"""
-from models.workflow import WorkflowNodeExecutionStatus
+from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
 from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
 
@@ -47,13 +51,15 @@ def get_mocked_fetch_model_config(
         model_type_instance=model_type_instance,
     )
     model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model=model)
+    model_schema = model_type_instance.get_model_schema(model)
+    assert model_schema is not None
     model_config = ModelConfigWithCredentialsEntity(
         model=model,
         provider=provider,
         mode=mode,
         credentials=credentials,
         parameters={},
-        model_schema=model_type_instance.get_model_schema(model),
+        model_schema=model_schema,
         provider_model_bundle=provider_model_bundle,
     )
 
@@ -74,18 +80,62 @@ def get_mocked_fetch_memory(memory_text: str):
     return MagicMock(return_value=MemoryMock())
 
 
-@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
-def test_function_calling_parameter_extractor(setup_openai_mock):
-    """
-    Test function calling for parameter extractor.
-    """
-    node = ParameterExtractorNode(
+def init_parameter_extractor_node(config: dict):
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-next-target",
+                "source": "start",
+                "target": "llm",
+            },
+        ],
+        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    init_params = GraphInitParams(
         tenant_id="1",
         app_id="1",
+        workflow_type=WorkflowType.WORKFLOW,
         workflow_id="1",
+        graph_config=graph_config,
         user_id="1",
-        invoke_from=InvokeFrom.WEB_APP,
         user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.DEBUGGER,
+        call_depth=0,
+    )
+
+    # construct variable pool
+    variable_pool = VariablePool(
+        system_variables={
+            SystemVariableKey.QUERY: "what's the weather in SF",
+            SystemVariableKey.FILES: [],
+            SystemVariableKey.CONVERSATION_ID: "abababa",
+            SystemVariableKey.USER_ID: "aaa",
+        },
+        user_inputs={},
+        environment_variables=[],
+        conversation_variables=[],
+    )
+    variable_pool.add(["a", "b123", "args1"], 1)
+    variable_pool.add(["a", "b123", "args2"], 2)
+
+    return ParameterExtractorNode(
+        id=str(uuid.uuid4()),
+        graph_init_params=init_params,
+        graph=graph,
+        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+        config=config,
+    )
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
+def test_function_calling_parameter_extractor(setup_openai_mock):
+    """
+    Test function calling for parameter extractor.
+    """
+    node = init_parameter_extractor_node(
         config={
             "id": "llm",
             "data": {
@@ -98,7 +148,7 @@ def test_function_calling_parameter_extractor(setup_openai_mock):
                 "reasoning_mode": "function_call",
                 "memory": None,
             },
-        },
+        }
     )
 
     node._fetch_model_config = get_mocked_fetch_model_config(
@@ -121,9 +171,10 @@ def test_function_calling_parameter_extractor(setup_openai_mock):
         environment_variables=[],
     )
 
-    result = node.run(pool)
+    result = node._run()
 
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert result.outputs is not None
     assert result.outputs.get("location") == "kawaii"
     assert result.outputs.get("__reason") == None
 
@@ -133,13 +184,7 @@ def test_instructions(setup_openai_mock):
     """
     Test chat parameter extractor.
     """
-    node = ParameterExtractorNode(
-        tenant_id="1",
-        app_id="1",
-        workflow_id="1",
-        user_id="1",
-        invoke_from=InvokeFrom.WEB_APP,
-        user_from=UserFrom.ACCOUNT,
+    node = init_parameter_extractor_node(
         config={
             "id": "llm",
             "data": {
@@ -163,29 +208,19 @@ def test_instructions(setup_openai_mock):
     )
     db.session.close = MagicMock()
 
-    # construct variable pool
-    pool = VariablePool(
-        system_variables={
-            SystemVariableKey.QUERY: "what's the weather in SF",
-            SystemVariableKey.FILES: [],
-            SystemVariableKey.CONVERSATION_ID: "abababa",
-            SystemVariableKey.USER_ID: "aaa",
-        },
-        user_inputs={},
-        environment_variables=[],
-    )
-
-    result = node.run(pool)
+    result = node._run()
 
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert result.outputs is not None
     assert result.outputs.get("location") == "kawaii"
     assert result.outputs.get("__reason") == None
 
     process_data = result.process_data
 
+    assert process_data is not None
     process_data.get("prompts")
 
-    for prompt in process_data.get("prompts"):
+    for prompt in process_data.get("prompts", []):
         if prompt.get("role") == "system":
             assert "what's the weather in SF" in prompt.get("text")
 
@@ -195,13 +230,7 @@ def test_chat_parameter_extractor(setup_anthropic_mock):
     """
     Test chat parameter extractor.
     """
-    node = ParameterExtractorNode(
-        tenant_id="1",
-        app_id="1",
-        workflow_id="1",
-        user_id="1",
-        invoke_from=InvokeFrom.WEB_APP,
-        user_from=UserFrom.ACCOUNT,
+    node = init_parameter_extractor_node(
         config={
             "id": "llm",
             "data": {
@@ -225,27 +254,17 @@ def test_chat_parameter_extractor(setup_anthropic_mock):
     )
     db.session.close = MagicMock()
 
-    # construct variable pool
-    pool = VariablePool(
-        system_variables={
-            SystemVariableKey.QUERY: "what's the weather in SF",
-            SystemVariableKey.FILES: [],
-            SystemVariableKey.CONVERSATION_ID: "abababa",
-            SystemVariableKey.USER_ID: "aaa",
-        },
-        user_inputs={},
-        environment_variables=[],
-    )
-
-    result = node.run(pool)
+    result = node._run()
 
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert result.outputs is not None
     assert result.outputs.get("location") == ""
     assert (
         result.outputs.get("__reason")
         == "Failed to extract result from function call or text response, using empty result."
     )
-    prompts = result.process_data.get("prompts")
+    assert result.process_data is not None
+    prompts = result.process_data.get("prompts", [])
 
     for prompt in prompts:
         if prompt.get("role") == "user":
@@ -258,13 +277,7 @@ def test_completion_parameter_extractor(setup_openai_mock):
     """
     Test completion parameter extractor.
     """
-    node = ParameterExtractorNode(
-        tenant_id="1",
-        app_id="1",
-        workflow_id="1",
-        user_id="1",
-        invoke_from=InvokeFrom.WEB_APP,
-        user_from=UserFrom.ACCOUNT,
+    node = init_parameter_extractor_node(
         config={
             "id": "llm",
             "data": {
@@ -293,28 +306,18 @@ def test_completion_parameter_extractor(setup_openai_mock):
     )
     db.session.close = MagicMock()
 
-    # construct variable pool
-    pool = VariablePool(
-        system_variables={
-            SystemVariableKey.QUERY: "what's the weather in SF",
-            SystemVariableKey.FILES: [],
-            SystemVariableKey.CONVERSATION_ID: "abababa",
-            SystemVariableKey.USER_ID: "aaa",
-        },
-        user_inputs={},
-        environment_variables=[],
-    )
-
-    result = node.run(pool)
+    result = node._run()
 
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert result.outputs is not None
     assert result.outputs.get("location") == ""
     assert (
         result.outputs.get("__reason")
         == "Failed to extract result from function call or text response, using empty result."
     )
-    assert len(result.process_data.get("prompts")) == 1
-    assert "SF" in result.process_data.get("prompts")[0].get("text")
+    assert result.process_data is not None
+    assert len(result.process_data.get("prompts", [])) == 1
+    assert "SF" in result.process_data.get("prompts", [])[0].get("text")
 
 
 def test_extract_json_response():
@@ -322,13 +325,7 @@ def test_extract_json_response():
     Test extract json response.
     """
 
-    node = ParameterExtractorNode(
-        tenant_id="1",
-        app_id="1",
-        workflow_id="1",
-        user_id="1",
-        invoke_from=InvokeFrom.WEB_APP,
-        user_from=UserFrom.ACCOUNT,
+    node = init_parameter_extractor_node(
         config={
             "id": "llm",
             "data": {
@@ -357,6 +354,7 @@ def test_extract_json_response():
         hello world.
     """)
 
+    assert result is not None
     assert result["location"] == "kawaii"
 
 
@@ -365,13 +363,7 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
     """
     Test chat parameter extractor with memory.
     """
-    node = ParameterExtractorNode(
-        tenant_id="1",
-        app_id="1",
-        workflow_id="1",
-        user_id="1",
-        invoke_from=InvokeFrom.WEB_APP,
-        user_from=UserFrom.ACCOUNT,
+    node = init_parameter_extractor_node(
         config={
             "id": "llm",
             "data": {
@@ -396,27 +388,17 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
     node._fetch_memory = get_mocked_fetch_memory("customized memory")
     db.session.close = MagicMock()
 
-    # construct variable pool
-    pool = VariablePool(
-        system_variables={
-            SystemVariableKey.QUERY: "what's the weather in SF",
-            SystemVariableKey.FILES: [],
-            SystemVariableKey.CONVERSATION_ID: "abababa",
-            SystemVariableKey.USER_ID: "aaa",
-        },
-        user_inputs={},
-        environment_variables=[],
-    )
-
-    result = node.run(pool)
+    result = node._run()
 
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert result.outputs is not None
     assert result.outputs.get("location") == ""
     assert (
         result.outputs.get("__reason")
         == "Failed to extract result from function call or text response, using empty result."
     )
-    prompts = result.process_data.get("prompts")
+    assert result.process_data is not None
+    prompts = result.process_data.get("prompts", [])
 
     latest_role = None
     for prompt in prompts:

+ 61 - 23
api/tests/integration_tests/workflow/nodes/test_template_transform.py

@@ -1,46 +1,84 @@
+import time
+import uuid
+
 import pytest
 
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities.node_entities import UserFrom
 from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.base_node import UserFrom
+from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
-from models.workflow import WorkflowNodeExecutionStatus
+from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
 from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
 
 
 @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
 def test_execute_code(setup_code_executor_mock):
     code = """{{args2}}"""
-    node = TemplateTransformNode(
+    config = {
+        "id": "1",
+        "data": {
+            "title": "123",
+            "variables": [
+                {
+                    "variable": "args1",
+                    "value_selector": ["1", "123", "args1"],
+                },
+                {"variable": "args2", "value_selector": ["1", "123", "args2"]},
+            ],
+            "template": code,
+        },
+    }
+
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-next-target",
+                "source": "start",
+                "target": "1",
+            },
+        ],
+        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    init_params = GraphInitParams(
         tenant_id="1",
         app_id="1",
+        workflow_type=WorkflowType.WORKFLOW,
         workflow_id="1",
+        graph_config=graph_config,
         user_id="1",
-        invoke_from=InvokeFrom.WEB_APP,
-        user_from=UserFrom.END_USER,
-        config={
-            "id": "1",
-            "data": {
-                "title": "123",
-                "variables": [
-                    {
-                        "variable": "args1",
-                        "value_selector": ["1", "123", "args1"],
-                    },
-                    {"variable": "args2", "value_selector": ["1", "123", "args2"]},
-                ],
-                "template": code,
-            },
-        },
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.DEBUGGER,
+        call_depth=0,
     )
 
     # construct variable pool
-    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
-    pool.add(["1", "123", "args1"], 1)
-    pool.add(["1", "123", "args2"], 3)
+    variable_pool = VariablePool(
+        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+        user_inputs={},
+        environment_variables=[],
+        conversation_variables=[],
+    )
+    variable_pool.add(["1", "123", "args1"], 1)
+    variable_pool.add(["1", "123", "args2"], 3)
+
+    node = TemplateTransformNode(
+        id=str(uuid.uuid4()),
+        graph_init_params=init_params,
+        graph=graph,
+        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+        config=config,
+    )
 
     # execute node
-    result = node.run(pool)
+    result = node._run()
 
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert result.outputs is not None
     assert result.outputs["output"] == "3"

+ 61 - 23
api/tests/integration_tests/workflow/nodes/test_tool.py

@@ -1,21 +1,62 @@
+import time
+import uuid
+
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities.node_entities import NodeRunResult, UserFrom
 from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.base_node import UserFrom
+from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.nodes.tool.tool_node import ToolNode
-from models.workflow import WorkflowNodeExecutionStatus
+from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
 
 
-def test_tool_variable_invoke():
-    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
-    pool.add(["1", "123", "args1"], "1+1")
+def init_tool_node(config: dict):
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-next-target",
+                "source": "start",
+                "target": "1",
+            },
+        ],
+        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
+    }
 
-    node = ToolNode(
+    graph = Graph.init(graph_config=graph_config)
+
+    init_params = GraphInitParams(
         tenant_id="1",
         app_id="1",
+        workflow_type=WorkflowType.WORKFLOW,
         workflow_id="1",
+        graph_config=graph_config,
         user_id="1",
-        invoke_from=InvokeFrom.WEB_APP,
         user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.DEBUGGER,
+        call_depth=0,
+    )
+
+    # construct variable pool
+    variable_pool = VariablePool(
+        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+        user_inputs={},
+        environment_variables=[],
+        conversation_variables=[],
+    )
+
+    return ToolNode(
+        id=str(uuid.uuid4()),
+        graph_init_params=init_params,
+        graph=graph,
+        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+        config=config,
+    )
+
+
+def test_tool_variable_invoke():
+    node = init_tool_node(
         config={
             "id": "1",
             "data": {
@@ -34,28 +75,22 @@ def test_tool_variable_invoke():
                     }
                 },
             },
-        },
+        }
     )
 
-    # execute node
-    result = node.run(pool)
+    node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], "1+1")
 
+    # execute node
+    result = node._run()
+    assert isinstance(result, NodeRunResult)
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert result.outputs is not None
     assert "2" in result.outputs["text"]
     assert result.outputs["files"] == []
 
 
 def test_tool_mixed_invoke():
-    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
-    pool.add(["1", "args1"], "1+1")
-
-    node = ToolNode(
-        tenant_id="1",
-        app_id="1",
-        workflow_id="1",
-        user_id="1",
-        invoke_from=InvokeFrom.WEB_APP,
-        user_from=UserFrom.ACCOUNT,
+    node = init_tool_node(
         config={
             "id": "1",
             "data": {
@@ -74,12 +109,15 @@ def test_tool_mixed_invoke():
                     }
                 },
             },
-        },
+        }
     )
 
-    # execute node
-    result = node.run(pool)
+    node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1")
 
+    # execute node
+    result = node._run()
+    assert isinstance(result, NodeRunResult)
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert result.outputs is not None
     assert "2" in result.outputs["text"]
     assert result.outputs["files"] == []

+ 17 - 0
api/tests/unit_tests/conftest.py

@@ -1,7 +1,24 @@
 import os
 
+import pytest
+from flask import Flask
+
 # Getting the absolute path of the current file's directory
 ABS_PATH = os.path.dirname(os.path.abspath(__file__))
 
 # Getting the absolute path of the project's root directory
 PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))
+
+CACHED_APP = Flask(__name__)
+CACHED_APP.config.update({"TESTING": True})
+
+
+@pytest.fixture()
+def app() -> Flask:
+    return CACHED_APP
+
+
+@pytest.fixture(autouse=True)
+def _provide_app_context(app: Flask):
+    with app.app_context():
+        yield

+ 0 - 0
api/tests/unit_tests/core/workflow/graph_engine/__init__.py


+ 791 - 0
api/tests/unit_tests/core/workflow/graph_engine/test_graph.py

@@ -0,0 +1,791 @@
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.run_condition import RunCondition
+from core.workflow.utils.condition.entities import Condition
+
+
+def test_init():
+    graph_config = {
+        "edges": [
+            {
+                "id": "llm-source-answer-target",
+                "source": "llm",
+                "target": "answer",
+            },
+            {
+                "id": "start-source-qc-target",
+                "source": "start",
+                "target": "qc",
+            },
+            {
+                "id": "qc-1-llm-target",
+                "source": "qc",
+                "sourceHandle": "1",
+                "target": "llm",
+            },
+            {
+                "id": "qc-2-http-target",
+                "source": "qc",
+                "sourceHandle": "2",
+                "target": "http",
+            },
+            {
+                "id": "http-source-answer2-target",
+                "source": "http",
+                "target": "answer2",
+            },
+        ],
+        "nodes": [
+            {"data": {"type": "start"}, "id": "start"},
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm",
+            },
+            {
+                "data": {"type": "answer", "title": "answer", "answer": "1"},
+                "id": "answer",
+            },
+            {
+                "data": {"type": "question-classifier"},
+                "id": "qc",
+            },
+            {
+                "data": {
+                    "type": "http-request",
+                },
+                "id": "http",
+            },
+            {
+                "data": {"type": "answer", "title": "answer", "answer": "1"},
+                "id": "answer2",
+            },
+        ],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    start_node_id = "start"
+
+    assert graph.root_node_id == start_node_id
+    assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc"
+    assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")}
+
+
+def test__init_iteration_graph():
+    graph_config = {
+        "edges": [
+            {
+                "id": "llm-answer",
+                "source": "llm",
+                "sourceHandle": "source",
+                "target": "answer",
+            },
+            {
+                "id": "iteration-source-llm-target",
+                "source": "iteration",
+                "sourceHandle": "source",
+                "target": "llm",
+            },
+            {
+                "id": "template-transform-in-iteration-source-llm-in-iteration-target",
+                "source": "template-transform-in-iteration",
+                "sourceHandle": "source",
+                "target": "llm-in-iteration",
+            },
+            {
+                "id": "llm-in-iteration-source-answer-in-iteration-target",
+                "source": "llm-in-iteration",
+                "sourceHandle": "source",
+                "target": "answer-in-iteration",
+            },
+            {
+                "id": "start-source-code-target",
+                "source": "start",
+                "sourceHandle": "source",
+                "target": "code",
+            },
+            {
+                "id": "code-source-iteration-target",
+                "source": "code",
+                "sourceHandle": "source",
+                "target": "iteration",
+            },
+        ],
+        "nodes": [
+            {
+                "data": {
+                    "type": "start",
+                },
+                "id": "start",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm",
+            },
+            {
+                "data": {"type": "answer", "title": "answer", "answer": "1"},
+                "id": "answer",
+            },
+            {
+                "data": {"type": "iteration"},
+                "id": "iteration",
+            },
+            {
+                "data": {
+                    "type": "template-transform",
+                },
+                "id": "template-transform-in-iteration",
+                "parentId": "iteration",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm-in-iteration",
+                "parentId": "iteration",
+            },
+            {
+                "data": {"type": "answer", "title": "answer", "answer": "1"},
+                "id": "answer-in-iteration",
+                "parentId": "iteration",
+            },
+            {
+                "data": {
+                    "type": "code",
+                },
+                "id": "code",
+            },
+        ],
+    }
+
+    graph = Graph.init(graph_config=graph_config, root_node_id="template-transform-in-iteration")
+    graph.add_extra_edge(
+        source_node_id="answer-in-iteration",
+        target_node_id="template-transform-in-iteration",
+        run_condition=RunCondition(
+            type="condition",
+            conditions=[Condition(variable_selector=["iteration", "index"], comparison_operator="≤", value="5")],
+        ),
+    )
+
+    # iteration:
+    #   [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration]
+
+    assert graph.root_node_id == "template-transform-in-iteration"
+    assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration"
+    assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration"
+    assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration"
+
+
+def test_parallels_graph():
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-llm1-target",
+                "source": "start",
+                "target": "llm1",
+            },
+            {
+                "id": "start-source-llm2-target",
+                "source": "start",
+                "target": "llm2",
+            },
+            {
+                "id": "start-source-llm3-target",
+                "source": "start",
+                "target": "llm3",
+            },
+            {
+                "id": "llm1-source-answer-target",
+                "source": "llm1",
+                "target": "answer",
+            },
+            {
+                "id": "llm2-source-answer-target",
+                "source": "llm2",
+                "target": "answer",
+            },
+            {
+                "id": "llm3-source-answer-target",
+                "source": "llm3",
+                "target": "answer",
+            },
+        ],
+        "nodes": [
+            {"data": {"type": "start"}, "id": "start"},
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm1",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm2",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm3",
+            },
+            {
+                "data": {"type": "answer", "title": "answer", "answer": "1"},
+                "id": "answer",
+            },
+        ],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    assert graph.root_node_id == "start"
+    for i in range(3):
+        start_edges = graph.edge_mapping.get("start")
+        assert start_edges is not None
+        assert start_edges[i].target_node_id == f"llm{i+1}"
+
+        llm_edges = graph.edge_mapping.get(f"llm{i+1}")
+        assert llm_edges is not None
+        assert llm_edges[0].target_node_id == "answer"
+
+    assert len(graph.parallel_mapping) == 1
+    assert len(graph.node_parallel_mapping) == 3
+
+    for node_id in ["llm1", "llm2", "llm3"]:
+        assert node_id in graph.node_parallel_mapping
+
+
+def test_parallels_graph2():
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-llm1-target",
+                "source": "start",
+                "target": "llm1",
+            },
+            {
+                "id": "start-source-llm2-target",
+                "source": "start",
+                "target": "llm2",
+            },
+            {
+                "id": "start-source-llm3-target",
+                "source": "start",
+                "target": "llm3",
+            },
+            {
+                "id": "llm1-source-answer-target",
+                "source": "llm1",
+                "target": "answer",
+            },
+            {
+                "id": "llm2-source-answer-target",
+                "source": "llm2",
+                "target": "answer",
+            },
+        ],
+        "nodes": [
+            {"data": {"type": "start"}, "id": "start"},
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm1",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm2",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm3",
+            },
+            {
+                "data": {"type": "answer", "title": "answer", "answer": "1"},
+                "id": "answer",
+            },
+        ],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    assert graph.root_node_id == "start"
+    for i in range(3):
+        assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
+
+        if i < 2:
+            assert graph.edge_mapping.get(f"llm{i + 1}") is not None
+            assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer"
+
+    assert len(graph.parallel_mapping) == 1
+    assert len(graph.node_parallel_mapping) == 3
+
+    for node_id in ["llm1", "llm2", "llm3"]:
+        assert node_id in graph.node_parallel_mapping
+
+
+def test_parallels_graph3():
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-llm1-target",
+                "source": "start",
+                "target": "llm1",
+            },
+            {
+                "id": "start-source-llm2-target",
+                "source": "start",
+                "target": "llm2",
+            },
+            {
+                "id": "start-source-llm3-target",
+                "source": "start",
+                "target": "llm3",
+            },
+        ],
+        "nodes": [
+            {"data": {"type": "start"}, "id": "start"},
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm1",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm2",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm3",
+            },
+            {
+                "data": {"type": "answer", "title": "answer", "answer": "1"},
+                "id": "answer",
+            },
+        ],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    assert graph.root_node_id == "start"
+    for i in range(3):
+        assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
+
+    assert len(graph.parallel_mapping) == 1
+    assert len(graph.node_parallel_mapping) == 3
+
+    for node_id in ["llm1", "llm2", "llm3"]:
+        assert node_id in graph.node_parallel_mapping
+
+
+def test_parallels_graph4():
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-llm1-target",
+                "source": "start",
+                "target": "llm1",
+            },
+            {
+                "id": "start-source-llm2-target",
+                "source": "start",
+                "target": "llm2",
+            },
+            {
+                "id": "start-source-llm3-target",
+                "source": "start",
+                "target": "llm3",
+            },
+            {
+                "id": "llm1-source-answer-target",
+                "source": "llm1",
+                "target": "code1",
+            },
+            {
+                "id": "llm2-source-answer-target",
+                "source": "llm2",
+                "target": "code2",
+            },
+            {
+                "id": "llm3-source-code3-target",
+                "source": "llm3",
+                "target": "code3",
+            },
+            {
+                "id": "code1-source-answer-target",
+                "source": "code1",
+                "target": "answer",
+            },
+            {
+                "id": "code2-source-answer-target",
+                "source": "code2",
+                "target": "answer",
+            },
+            {
+                "id": "code3-source-answer-target",
+                "source": "code3",
+                "target": "answer",
+            },
+        ],
+        "nodes": [
+            {"data": {"type": "start"}, "id": "start"},
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm1",
+            },
+            {
+                "data": {
+                    "type": "code",
+                },
+                "id": "code1",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm2",
+            },
+            {
+                "data": {
+                    "type": "code",
+                },
+                "id": "code2",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm3",
+            },
+            {
+                "data": {
+                    "type": "code",
+                },
+                "id": "code3",
+            },
+            {
+                "data": {"type": "answer", "title": "answer", "answer": "1"},
+                "id": "answer",
+            },
+        ],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    assert graph.root_node_id == "start"
+    for i in range(3):
+        assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
+        assert graph.edge_mapping.get(f"llm{i + 1}") is not None
+        assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}"
+        assert graph.edge_mapping.get(f"code{i + 1}") is not None
+        assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer"
+
+    assert len(graph.parallel_mapping) == 1
+    assert len(graph.node_parallel_mapping) == 6
+
+    for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
+        assert node_id in graph.node_parallel_mapping
+
+
+def test_parallels_graph5():
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-llm1-target",
+                "source": "start",
+                "target": "llm1",
+            },
+            {
+                "id": "start-source-llm2-target",
+                "source": "start",
+                "target": "llm2",
+            },
+            {
+                "id": "start-source-llm3-target",
+                "source": "start",
+                "target": "llm3",
+            },
+            {
+                "id": "start-source-llm3-target",
+                "source": "start",
+                "target": "llm4",
+            },
+            {
+                "id": "start-source-llm3-target",
+                "source": "start",
+                "target": "llm5",
+            },
+            {
+                "id": "llm1-source-code1-target",
+                "source": "llm1",
+                "target": "code1",
+            },
+            {
+                "id": "llm2-source-code1-target",
+                "source": "llm2",
+                "target": "code1",
+            },
+            {
+                "id": "llm3-source-code2-target",
+                "source": "llm3",
+                "target": "code2",
+            },
+            {
+                "id": "llm4-source-code2-target",
+                "source": "llm4",
+                "target": "code2",
+            },
+            {
+                "id": "llm5-source-code3-target",
+                "source": "llm5",
+                "target": "code3",
+            },
+            {
+                "id": "code1-source-answer-target",
+                "source": "code1",
+                "target": "answer",
+            },
+            {
+                "id": "code2-source-answer-target",
+                "source": "code2",
+                "target": "answer",
+            },
+        ],
+        "nodes": [
+            {"data": {"type": "start"}, "id": "start"},
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm1",
+            },
+            {
+                "data": {
+                    "type": "code",
+                },
+                "id": "code1",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm2",
+            },
+            {
+                "data": {
+                    "type": "code",
+                },
+                "id": "code2",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm3",
+            },
+            {
+                "data": {
+                    "type": "code",
+                },
+                "id": "code3",
+            },
+            {
+                "data": {"type": "answer", "title": "answer", "answer": "1"},
+                "id": "answer",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm4",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm5",
+            },
+        ],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    assert graph.root_node_id == "start"
+    for i in range(5):
+        assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
+
+    assert graph.edge_mapping.get("llm1") is not None
+    assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
+    assert graph.edge_mapping.get("llm2") is not None
+    assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1"
+    assert graph.edge_mapping.get("llm3") is not None
+    assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2"
+    assert graph.edge_mapping.get("llm4") is not None
+    assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2"
+    assert graph.edge_mapping.get("llm5") is not None
+    assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3"
+    assert graph.edge_mapping.get("code1") is not None
+    assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
+    assert graph.edge_mapping.get("code2") is not None
+    assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
+
+    assert len(graph.parallel_mapping) == 1
+    assert len(graph.node_parallel_mapping) == 8
+
+    for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]:
+        assert node_id in graph.node_parallel_mapping
+
+
+def test_parallels_graph6():
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-llm1-target",
+                "source": "start",
+                "target": "llm1",
+            },
+            {
+                "id": "start-source-llm2-target",
+                "source": "start",
+                "target": "llm2",
+            },
+            {
+                "id": "start-source-llm3-target",
+                "source": "start",
+                "target": "llm3",
+            },
+            {
+                "id": "llm1-source-code1-target",
+                "source": "llm1",
+                "target": "code1",
+            },
+            {
+                "id": "llm1-source-code2-target",
+                "source": "llm1",
+                "target": "code2",
+            },
+            {
+                "id": "llm2-source-code3-target",
+                "source": "llm2",
+                "target": "code3",
+            },
+            {
+                "id": "code1-source-answer-target",
+                "source": "code1",
+                "target": "answer",
+            },
+            {
+                "id": "code2-source-answer-target",
+                "source": "code2",
+                "target": "answer",
+            },
+            {
+                "id": "code3-source-answer-target",
+                "source": "code3",
+                "target": "answer",
+            },
+            {
+                "id": "llm3-source-answer-target",
+                "source": "llm3",
+                "target": "answer",
+            },
+        ],
+        "nodes": [
+            {"data": {"type": "start"}, "id": "start"},
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm1",
+            },
+            {
+                "data": {
+                    "type": "code",
+                },
+                "id": "code1",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm2",
+            },
+            {
+                "data": {
+                    "type": "code",
+                },
+                "id": "code2",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm3",
+            },
+            {
+                "data": {
+                    "type": "code",
+                },
+                "id": "code3",
+            },
+            {
+                "data": {"type": "answer", "title": "answer", "answer": "1"},
+                "id": "answer",
+            },
+        ],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    assert graph.root_node_id == "start"
+    for i in range(3):
+        assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
+
+    assert graph.edge_mapping.get("llm1") is not None
+    assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
+    assert graph.edge_mapping.get("llm1") is not None
+    assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2"
+    assert graph.edge_mapping.get("llm2") is not None
+    assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3"
+    assert graph.edge_mapping.get("code1") is not None
+    assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
+    assert graph.edge_mapping.get("code2") is not None
+    assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
+    assert graph.edge_mapping.get("code3") is not None
+    assert graph.edge_mapping.get("code3")[0].target_node_id == "answer"
+
+    assert len(graph.parallel_mapping) == 2
+    assert len(graph.node_parallel_mapping) == 6
+
+    for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
+        assert node_id in graph.node_parallel_mapping
+
+    parent_parallel = None
+    child_parallel = None
+    for p_id, parallel in graph.parallel_mapping.items():
+        if parallel.parent_parallel_id is None:
+            parent_parallel = parallel
+        else:
+            child_parallel = parallel
+
+    for node_id in ["llm1", "llm2", "llm3", "code3"]:
+        assert graph.node_parallel_mapping[node_id] == parent_parallel.id
+
+    for node_id in ["code1", "code2"]:
+        assert graph.node_parallel_mapping[node_id] == child_parallel.id

+ 505 - 0
api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py

@@ -0,0 +1,505 @@
+from unittest.mock import patch
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, UserFrom
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.event import (
+    BaseNodeEvent,
+    GraphRunFailedEvent,
+    GraphRunStartedEvent,
+    GraphRunSucceededEvent,
+    NodeRunFailedEvent,
+    NodeRunStartedEvent,
+    NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
+)
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
+from core.workflow.graph_engine.graph_engine import GraphEngine
+from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
+from core.workflow.nodes.llm.llm_node import LLMNode
+from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
+
+
+@patch("extensions.ext_database.db.session.remove")
+@patch("extensions.ext_database.db.session.close")
+def test_run_parallel_in_workflow(mock_close, mock_remove):
+    graph_config = {
+        "edges": [
+            {
+                "id": "1",
+                "source": "start",
+                "target": "llm1",
+            },
+            {
+                "id": "2",
+                "source": "llm1",
+                "target": "llm2",
+            },
+            {
+                "id": "3",
+                "source": "llm1",
+                "target": "llm3",
+            },
+            {
+                "id": "4",
+                "source": "llm2",
+                "target": "end1",
+            },
+            {
+                "id": "5",
+                "source": "llm3",
+                "target": "end2",
+            },
+        ],
+        "nodes": [
+            {
+                "data": {
+                    "type": "start",
+                    "title": "start",
+                    "variables": [
+                        {
+                            "label": "query",
+                            "max_length": 48,
+                            "options": [],
+                            "required": True,
+                            "type": "text-input",
+                            "variable": "query",
+                        }
+                    ],
+                },
+                "id": "start",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                    "title": "llm1",
+                    "context": {"enabled": False, "variable_selector": []},
+                    "model": {
+                        "completion_params": {"temperature": 0.7},
+                        "mode": "chat",
+                        "name": "gpt-4o",
+                        "provider": "openai",
+                    },
+                    "prompt_template": [
+                        {"role": "system", "text": "say hi"},
+                        {"role": "user", "text": "{{#start.query#}}"},
+                    ],
+                    "vision": {"configs": {"detail": "high"}, "enabled": False},
+                },
+                "id": "llm1",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                    "title": "llm2",
+                    "context": {"enabled": False, "variable_selector": []},
+                    "model": {
+                        "completion_params": {"temperature": 0.7},
+                        "mode": "chat",
+                        "name": "gpt-4o",
+                        "provider": "openai",
+                    },
+                    "prompt_template": [
+                        {"role": "system", "text": "say bye"},
+                        {"role": "user", "text": "{{#start.query#}}"},
+                    ],
+                    "vision": {"configs": {"detail": "high"}, "enabled": False},
+                },
+                "id": "llm2",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                    "title": "llm3",
+                    "context": {"enabled": False, "variable_selector": []},
+                    "model": {
+                        "completion_params": {"temperature": 0.7},
+                        "mode": "chat",
+                        "name": "gpt-4o",
+                        "provider": "openai",
+                    },
+                    "prompt_template": [
+                        {"role": "system", "text": "say good morning"},
+                        {"role": "user", "text": "{{#start.query#}}"},
+                    ],
+                    "vision": {"configs": {"detail": "high"}, "enabled": False},
+                },
+                "id": "llm3",
+            },
+            {
+                "data": {
+                    "type": "end",
+                    "title": "end1",
+                    "outputs": [
+                        {"value_selector": ["llm2", "text"], "variable": "result2"},
+                        {"value_selector": ["start", "query"], "variable": "query"},
+                    ],
+                },
+                "id": "end1",
+            },
+            {
+                "data": {
+                    "type": "end",
+                    "title": "end2",
+                    "outputs": [
+                        {"value_selector": ["llm1", "text"], "variable": "result1"},
+                        {"value_selector": ["llm3", "text"], "variable": "result3"},
+                    ],
+                },
+                "id": "end2",
+            },
+        ],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    variable_pool = VariablePool(
+        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
+    )
+
+    graph_engine = GraphEngine(
+        tenant_id="111",
+        app_id="222",
+        workflow_type=WorkflowType.WORKFLOW,
+        workflow_id="333",
+        graph_config=graph_config,
+        user_id="444",
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.WEB_APP,
+        call_depth=0,
+        graph=graph,
+        variable_pool=variable_pool,
+        max_execution_steps=500,
+        max_execution_time=1200,
+    )
+
+    def llm_generator(self):
+        contents = ["hi", "bye", "good morning"]
+
+        yield RunStreamChunkEvent(
+            chunk_content=contents[int(self.node_id[-1]) - 1], from_variable_selector=[self.node_id, "text"]
+        )
+
+        yield RunCompletedEvent(
+            run_result=NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                inputs={},
+                process_data={},
+                outputs={},
+                metadata={
+                    NodeRunMetadataKey.TOTAL_TOKENS: 1,
+                    NodeRunMetadataKey.TOTAL_PRICE: 1,
+                    NodeRunMetadataKey.CURRENCY: "USD",
+                },
+            )
+        )
+
+    # print("")
+
+    with patch.object(LLMNode, "_run", new=llm_generator):
+        items = []
+        generator = graph_engine.run()
+        for item in generator:
+            # print(type(item), item)
+            items.append(item)
+            if isinstance(item, NodeRunSucceededEvent):
+                assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
+
+            assert not isinstance(item, NodeRunFailedEvent)
+            assert not isinstance(item, GraphRunFailedEvent)
+
+            if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ["llm2", "llm3", "end1", "end2"]:
+                assert item.parallel_id is not None
+
+        assert len(items) == 18
+        assert isinstance(items[0], GraphRunStartedEvent)
+        assert isinstance(items[1], NodeRunStartedEvent)
+        assert items[1].route_node_state.node_id == "start"
+        assert isinstance(items[2], NodeRunSucceededEvent)
+        assert items[2].route_node_state.node_id == "start"
+
+
+@patch("extensions.ext_database.db.session.remove")
+@patch("extensions.ext_database.db.session.close")
+def test_run_parallel_in_chatflow(mock_close, mock_remove):
+    graph_config = {
+        "edges": [
+            {
+                "id": "1",
+                "source": "start",
+                "target": "answer1",
+            },
+            {
+                "id": "2",
+                "source": "answer1",
+                "target": "answer2",
+            },
+            {
+                "id": "3",
+                "source": "answer1",
+                "target": "answer3",
+            },
+            {
+                "id": "4",
+                "source": "answer2",
+                "target": "answer4",
+            },
+            {
+                "id": "5",
+                "source": "answer3",
+                "target": "answer5",
+            },
+        ],
+        "nodes": [
+            {"data": {"type": "start", "title": "start"}, "id": "start"},
+            {"data": {"type": "answer", "title": "answer1", "answer": "1"}, "id": "answer1"},
+            {
+                "data": {"type": "answer", "title": "answer2", "answer": "2"},
+                "id": "answer2",
+            },
+            {
+                "data": {"type": "answer", "title": "answer3", "answer": "3"},
+                "id": "answer3",
+            },
+            {
+                "data": {"type": "answer", "title": "answer4", "answer": "4"},
+                "id": "answer4",
+            },
+            {
+                "data": {"type": "answer", "title": "answer5", "answer": "5"},
+                "id": "answer5",
+            },
+        ],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    variable_pool = VariablePool(
+        system_variables={
+            SystemVariableKey.QUERY: "what's the weather in SF",
+            SystemVariableKey.FILES: [],
+            SystemVariableKey.CONVERSATION_ID: "abababa",
+            SystemVariableKey.USER_ID: "aaa",
+        },
+        user_inputs={},
+    )
+
+    graph_engine = GraphEngine(
+        tenant_id="111",
+        app_id="222",
+        workflow_type=WorkflowType.CHAT,
+        workflow_id="333",
+        graph_config=graph_config,
+        user_id="444",
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.WEB_APP,
+        call_depth=0,
+        graph=graph,
+        variable_pool=variable_pool,
+        max_execution_steps=500,
+        max_execution_time=1200,
+    )
+
+    # print("")
+
+    items = []
+    generator = graph_engine.run()
+    for item in generator:
+        # print(type(item), item)
+        items.append(item)
+        if isinstance(item, NodeRunSucceededEvent):
+            assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
+
+        assert not isinstance(item, NodeRunFailedEvent)
+        assert not isinstance(item, GraphRunFailedEvent)
+
+        if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [
+            "answer2",
+            "answer3",
+            "answer4",
+            "answer5",
+        ]:
+            assert item.parallel_id is not None
+
+    assert len(items) == 23
+    assert isinstance(items[0], GraphRunStartedEvent)
+    assert isinstance(items[1], NodeRunStartedEvent)
+    assert items[1].route_node_state.node_id == "start"
+    assert isinstance(items[2], NodeRunSucceededEvent)
+    assert items[2].route_node_state.node_id == "start"
+
+
+@patch("extensions.ext_database.db.session.remove")
+@patch("extensions.ext_database.db.session.close")
+def test_run_branch(mock_close, mock_remove):
+    graph_config = {
+        "edges": [
+            {
+                "id": "1",
+                "source": "start",
+                "target": "if-else-1",
+            },
+            {
+                "id": "2",
+                "source": "if-else-1",
+                "sourceHandle": "true",
+                "target": "answer-1",
+            },
+            {
+                "id": "3",
+                "source": "if-else-1",
+                "sourceHandle": "false",
+                "target": "if-else-2",
+            },
+            {
+                "id": "4",
+                "source": "if-else-2",
+                "sourceHandle": "true",
+                "target": "answer-2",
+            },
+            {
+                "id": "5",
+                "source": "if-else-2",
+                "sourceHandle": "false",
+                "target": "answer-3",
+            },
+        ],
+        "nodes": [
+            {
+                "data": {
+                    "title": "Start",
+                    "type": "start",
+                    "variables": [
+                        {
+                            "label": "uid",
+                            "max_length": 48,
+                            "options": [],
+                            "required": True,
+                            "type": "text-input",
+                            "variable": "uid",
+                        }
+                    ],
+                },
+                "id": "start",
+            },
+            {
+                "data": {"answer": "1 {{#start.uid#}}", "title": "Answer", "type": "answer", "variables": []},
+                "id": "answer-1",
+            },
+            {
+                "data": {
+                    "cases": [
+                        {
+                            "case_id": "true",
+                            "conditions": [
+                                {
+                                    "comparison_operator": "contains",
+                                    "id": "b0f02473-08b6-4a81-af91-15345dcb2ec8",
+                                    "value": "hi",
+                                    "varType": "string",
+                                    "variable_selector": ["sys", "query"],
+                                }
+                            ],
+                            "id": "true",
+                            "logical_operator": "and",
+                        }
+                    ],
+                    "desc": "",
+                    "title": "IF/ELSE",
+                    "type": "if-else",
+                },
+                "id": "if-else-1",
+            },
+            {
+                "data": {
+                    "cases": [
+                        {
+                            "case_id": "true",
+                            "conditions": [
+                                {
+                                    "comparison_operator": "contains",
+                                    "id": "ae895199-5608-433b-b5f0-0997ae1431e4",
+                                    "value": "takatost",
+                                    "varType": "string",
+                                    "variable_selector": ["sys", "query"],
+                                }
+                            ],
+                            "id": "true",
+                            "logical_operator": "and",
+                        }
+                    ],
+                    "title": "IF/ELSE 2",
+                    "type": "if-else",
+                },
+                "id": "if-else-2",
+            },
+            {
+                "data": {
+                    "answer": "2",
+                    "title": "Answer 2",
+                    "type": "answer",
+                },
+                "id": "answer-2",
+            },
+            {
+                "data": {
+                    "answer": "3",
+                    "title": "Answer 3",
+                    "type": "answer",
+                },
+                "id": "answer-3",
+            },
+        ],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    variable_pool = VariablePool(
+        system_variables={
+            SystemVariableKey.QUERY: "hi",
+            SystemVariableKey.FILES: [],
+            SystemVariableKey.CONVERSATION_ID: "abababa",
+            SystemVariableKey.USER_ID: "aaa",
+        },
+        user_inputs={"uid": "takato"},
+    )
+
+    graph_engine = GraphEngine(
+        tenant_id="111",
+        app_id="222",
+        workflow_type=WorkflowType.CHAT,
+        workflow_id="333",
+        graph_config=graph_config,
+        user_id="444",
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.WEB_APP,
+        call_depth=0,
+        graph=graph,
+        variable_pool=variable_pool,
+        max_execution_steps=500,
+        max_execution_time=1200,
+    )
+
+    # print("")
+
+    items = []
+    generator = graph_engine.run()
+    for item in generator:
+        # print(type(item), item)
+        items.append(item)
+
+    assert len(items) == 10
+    assert items[3].route_node_state.node_id == "if-else-1"
+    assert items[4].route_node_state.node_id == "if-else-1"
+    assert isinstance(items[5], NodeRunStreamChunkEvent)
+    assert items[5].chunk_content == "1 "
+    assert isinstance(items[6], NodeRunStreamChunkEvent)
+    assert items[6].chunk_content == "takato"
+    assert items[7].route_node_state.node_id == "answer-1"
+    assert items[8].route_node_state.node_id == "answer-1"
+    assert items[8].route_node_state.node_run_result.outputs["answer"] == "1 takato"
+    assert isinstance(items[9], GraphRunSucceededEvent)
+
+    # print(graph_engine.graph_runtime_state.model_dump_json(indent=2))

+ 0 - 0
api/tests/unit_tests/core/workflow/nodes/answer/__init__.py


+ 82 - 0
api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py

@@ -0,0 +1,82 @@
+import time
+import uuid
+from unittest.mock import MagicMock
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities.node_entities import UserFrom
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.nodes.answer.answer_node import AnswerNode
+from extensions.ext_database import db
+from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
+
+
+def test_execute_answer():
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-llm-target",
+                "source": "start",
+                "target": "llm",
+            },
+        ],
+        "nodes": [
+            {"data": {"type": "start"}, "id": "start"},
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm",
+            },
+        ],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    init_params = GraphInitParams(
+        tenant_id="1",
+        app_id="1",
+        workflow_type=WorkflowType.WORKFLOW,
+        workflow_id="1",
+        graph_config=graph_config,
+        user_id="1",
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.DEBUGGER,
+        call_depth=0,
+    )
+
+    # construct variable pool
+    pool = VariablePool(
+        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+        user_inputs={},
+        environment_variables=[],
+    )
+    pool.add(["start", "weather"], "sunny")
+    pool.add(["llm", "text"], "You are a helpful AI.")
+
+    node = AnswerNode(
+        id=str(uuid.uuid4()),
+        graph_init_params=init_params,
+        graph=graph,
+        graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
+        config={
+            "id": "answer",
+            "data": {
+                "title": "123",
+                "type": "answer",
+                "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
+            },
+        },
+    )
+
+    # Mock db.session.close()
+    db.session.close = MagicMock()
+
+    # execute node
+    result = node._run()
+
+    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."

+ 109 - 0
api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py

@@ -0,0 +1,109 @@
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
+
+
+def test_init():
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-llm1-target",
+                "source": "start",
+                "target": "llm1",
+            },
+            {
+                "id": "start-source-llm2-target",
+                "source": "start",
+                "target": "llm2",
+            },
+            {
+                "id": "start-source-llm3-target",
+                "source": "start",
+                "target": "llm3",
+            },
+            {
+                "id": "llm3-source-llm4-target",
+                "source": "llm3",
+                "target": "llm4",
+            },
+            {
+                "id": "llm3-source-llm5-target",
+                "source": "llm3",
+                "target": "llm5",
+            },
+            {
+                "id": "llm4-source-answer2-target",
+                "source": "llm4",
+                "target": "answer2",
+            },
+            {
+                "id": "llm5-source-answer-target",
+                "source": "llm5",
+                "target": "answer",
+            },
+            {
+                "id": "answer2-source-answer-target",
+                "source": "answer2",
+                "target": "answer",
+            },
+            {
+                "id": "llm2-source-answer-target",
+                "source": "llm2",
+                "target": "answer",
+            },
+            {
+                "id": "llm1-source-answer-target",
+                "source": "llm1",
+                "target": "answer",
+            },
+        ],
+        "nodes": [
+            {"data": {"type": "start"}, "id": "start"},
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm1",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm2",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm3",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm4",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm5",
+            },
+            {
+                "data": {"type": "answer", "title": "answer", "answer": "1{{#llm2.text#}}2"},
+                "id": "answer",
+            },
+            {
+                "data": {"type": "answer", "title": "answer2", "answer": "1{{#llm3.text#}}2"},
+                "id": "answer2",
+            },
+        ],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    answer_stream_generate_route = AnswerStreamGeneratorRouter.init(
+        node_id_config_mapping=graph.node_id_config_mapping, reverse_edge_mapping=graph.reverse_edge_mapping
+    )
+
+    assert answer_stream_generate_route.answer_dependencies["answer"] == ["answer2"]
+    assert answer_stream_generate_route.answer_dependencies["answer2"] == []

+ 216 - 0
api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py

@@ -0,0 +1,216 @@
+import uuid
+from collections.abc import Generator
+from datetime import datetime, timezone
+
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.event import (
+    GraphEngineEvent,
+    NodeRunStartedEvent,
+    NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
+)
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
+from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
+from core.workflow.nodes.start.entities import StartNodeData
+
+
+def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
+    if next_node_id == "start":
+        yield from _publish_events(graph, next_node_id)
+
+    for edge in graph.edge_mapping.get(next_node_id, []):
+        yield from _publish_events(graph, edge.target_node_id)
+
+    for edge in graph.edge_mapping.get(next_node_id, []):
+        yield from _recursive_process(graph, edge.target_node_id)
+
+
+def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
+    route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None))
+
+    parallel_id = graph.node_parallel_mapping.get(next_node_id)
+    parallel_start_node_id = None
+    if parallel_id:
+        parallel = graph.parallel_mapping.get(parallel_id)
+        parallel_start_node_id = parallel.start_from_node_id if parallel else None
+
+    node_execution_id = str(uuid.uuid4())
+    node_config = graph.node_id_config_mapping[next_node_id]
+    node_type = NodeType.value_of(node_config.get("data", {}).get("type"))
+    mock_node_data = StartNodeData(**{"title": "demo", "variables": []})
+
+    yield NodeRunStartedEvent(
+        id=node_execution_id,
+        node_id=next_node_id,
+        node_type=node_type,
+        node_data=mock_node_data,
+        route_node_state=route_node_state,
+        parallel_id=graph.node_parallel_mapping.get(next_node_id),
+        parallel_start_node_id=parallel_start_node_id,
+    )
+
+    if "llm" in next_node_id:
+        length = int(next_node_id[-1])
+        for i in range(0, length):
+            yield NodeRunStreamChunkEvent(
+                id=node_execution_id,
+                node_id=next_node_id,
+                node_type=node_type,
+                node_data=mock_node_data,
+                chunk_content=str(i),
+                route_node_state=route_node_state,
+                from_variable_selector=[next_node_id, "text"],
+                parallel_id=parallel_id,
+                parallel_start_node_id=parallel_start_node_id,
+            )
+
+    route_node_state.status = RouteNodeState.Status.SUCCESS
+    route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+    yield NodeRunSucceededEvent(
+        id=node_execution_id,
+        node_id=next_node_id,
+        node_type=node_type,
+        node_data=mock_node_data,
+        route_node_state=route_node_state,
+        parallel_id=parallel_id,
+        parallel_start_node_id=parallel_start_node_id,
+    )
+
+
+def test_process():
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-llm1-target",
+                "source": "start",
+                "target": "llm1",
+            },
+            {
+                "id": "start-source-llm2-target",
+                "source": "start",
+                "target": "llm2",
+            },
+            {
+                "id": "start-source-llm3-target",
+                "source": "start",
+                "target": "llm3",
+            },
+            {
+                "id": "llm3-source-llm4-target",
+                "source": "llm3",
+                "target": "llm4",
+            },
+            {
+                "id": "llm3-source-llm5-target",
+                "source": "llm3",
+                "target": "llm5",
+            },
+            {
+                "id": "llm4-source-answer2-target",
+                "source": "llm4",
+                "target": "answer2",
+            },
+            {
+                "id": "llm5-source-answer-target",
+                "source": "llm5",
+                "target": "answer",
+            },
+            {
+                "id": "answer2-source-answer-target",
+                "source": "answer2",
+                "target": "answer",
+            },
+            {
+                "id": "llm2-source-answer-target",
+                "source": "llm2",
+                "target": "answer",
+            },
+            {
+                "id": "llm1-source-answer-target",
+                "source": "llm1",
+                "target": "answer",
+            },
+        ],
+        "nodes": [
+            {"data": {"type": "start"}, "id": "start"},
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm1",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm2",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm3",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm4",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm5",
+            },
+            {
+                "data": {"type": "answer", "title": "answer", "answer": "a{{#llm2.text#}}b"},
+                "id": "answer",
+            },
+            {
+                "data": {"type": "answer", "title": "answer2", "answer": "c{{#llm3.text#}}d"},
+                "id": "answer2",
+            },
+        ],
+    }
+
+    graph = Graph.init(graph_config=graph_config)
+
+    variable_pool = VariablePool(
+        system_variables={
+            SystemVariableKey.QUERY: "what's the weather in SF",
+            SystemVariableKey.FILES: [],
+            SystemVariableKey.CONVERSATION_ID: "abababa",
+            SystemVariableKey.USER_ID: "aaa",
+        },
+        user_inputs={},
+    )
+
+    answer_stream_processor = AnswerStreamProcessor(graph=graph, variable_pool=variable_pool)
+
+    def graph_generator() -> Generator[GraphEngineEvent, None, None]:
+        # print("")
+        for event in _recursive_process(graph, "start"):
+            # print("[ORIGIN]", event.__class__.__name__ + ":", event.route_node_state.node_id,
+            #       " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else ""))
+            if isinstance(event, NodeRunSucceededEvent):
+                if "llm" in event.route_node_state.node_id:
+                    variable_pool.add(
+                        [event.route_node_state.node_id, "text"],
+                        "".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1]))),
+                    )
+            yield event
+
+    result_generator = answer_stream_processor.process(graph_generator())
+    stream_contents = ""
+    for event in result_generator:
+        # print("[ANSWER]", event.__class__.__name__ + ":", event.route_node_state.node_id,
+        #       " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else ""))
+        if isinstance(event, NodeRunStreamChunkEvent):
+            stream_contents += event.chunk_content
+        pass
+
+    assert stream_contents == "c012da01b"

Vissa filer visades inte eftersom för många filer har ändrats