/ cli / commands / init_cmd / orchestrator_step.py
orchestrator_step.py
import json
import re
from pathlib import Path

import click
import yaml

from config_portal.backend.common import DEFAULT_COMMUNICATION_TIMEOUT

from ...utils import ask_if_not_provided, get_formatted_names, load_template
 10  
 11  ORCHESTRATOR_DEFAULTS = {
 12      "agent_name": "OrchestratorAgent",
 13      "supports_streaming": True,
 14      "artifact_handling_mode": "reference",
 15      "enable_embed_resolution": True,
 16      "enable_artifact_content_instruction": True,
 17      "enable_builtin_artifact_tools": {"enabled": True},
 18      "enable_builtin_data_tools": {"enabled": True},
 19      "artifact_service": {
 20          "type": "filesystem",
 21          "base_path": "/tmp/samv2",
 22          "artifact_scope": "namespace",
 23          "bucket_name": "",
 24          "endpoint_url": "",
 25          "region": "us-east-1",
 26      },
 27      "agent_card": {
 28          "description": "The Orchestrator component. It manages tasks and coordinates multi-agent workflows.",
 29          "defaultInputModes": ["text"],
 30          "defaultOutputModes": ["text", "file"],
 31          "skills": [],
 32      },
 33      "agent_card_publishing": {"interval_seconds": 10},
 34      "agent_discovery": {"enabled": True},
 35      "inter_agent_communication": {
 36          "allow_list": ["*"],
 37          "request_timeout_seconds": DEFAULT_COMMUNICATION_TIMEOUT,
 38      },
 39      "use_orchestrator_db": True,
 40  }
 41  
 42  
 43  def create_orchestrator_config(
 44      project_root: Path, options: dict, skip_interactive: bool
 45  ) -> bool:
 46      """
 47      Creates the main_orchestrator.yaml file with configured values.
 48      Returns True on success, False on failure.
 49      """
 50      click.echo("Configuring main orchestrator...")
 51  
 52      raise_if_not_valid_agent_name = options.get("agent_name") or skip_interactive
 53      ask_if_not_provided(
 54          options,
 55          "agent_name",
 56          "Enter agent name",
 57          ORCHESTRATOR_DEFAULTS["agent_name"],
 58          skip_interactive,
 59      )
 60  
 61      agent_name = options.get("agent_name")
 62      while not re.match(r"^[a-zA-Z0-9_]+$", agent_name):
 63          if raise_if_not_valid_agent_name:
 64              raise click.UsageError(
 65                  "Invalid agent name. Only letters, numbers, and underscores are allowed."
 66              )
 67          else:
 68              click.echo(
 69                  click.style(
 70                      "Invalid agent name. Only letters, numbers, and underscores are allowed.",
 71                      fg="red",
 72                  ),
 73                  err=True,
 74              )
 75              agent_name = click.prompt("Please enter a valid agent name")
 76      options["agent_name"] = agent_name
 77  
 78      ask_if_not_provided(
 79          options,
 80          "supports_streaming",
 81          "Enable streaming support? (true/false)",
 82          ORCHESTRATOR_DEFAULTS["supports_streaming"],
 83          skip_interactive,
 84          is_bool=True,
 85      )
 86  
 87      options["use_orchestrator_db"] = ORCHESTRATOR_DEFAULTS["use_orchestrator_db"]
 88  
 89      artifact_type = ask_if_not_provided(
 90          options,
 91          "artifact_service_type",
 92          "Enter artifact service type",
 93          ORCHESTRATOR_DEFAULTS["artifact_service"]["type"],
 94          skip_interactive,
 95          choices=["memory", "filesystem", "gcs", "s3"],
 96      )
 97  
 98      artifact_base_path = None
 99      s3_bucket_name = None
100      s3_endpoint_url = None
101      s3_region = None
102  
103      if artifact_type == "filesystem":
104          artifact_base_path = ask_if_not_provided(
105              options,
106              "artifact_service_base_path",
107              "Enter artifact service base path",
108              ORCHESTRATOR_DEFAULTS["artifact_service"]["base_path"],
109              skip_interactive,
110          )
111      elif artifact_type == "s3":
112          # Map CLI artifact-service-* parameters to s3_* keys
113          if options.get("artifact_service_bucket_name"):
114              options["s3_bucket_name"] = options["artifact_service_bucket_name"]
115          if options.get("artifact_service_endpoint_url"):
116              options["s3_endpoint_url"] = options["artifact_service_endpoint_url"]
117          if options.get("artifact_service_region"):
118              options["s3_region"] = options["artifact_service_region"]
119  
120          s3_bucket_name = ask_if_not_provided(
121              options,
122              "s3_bucket_name",
123              "Enter S3 bucket name",
124              ORCHESTRATOR_DEFAULTS["artifact_service"]["bucket_name"],
125              skip_interactive,
126          )
127          s3_endpoint_url = ask_if_not_provided(
128              options,
129              "s3_endpoint_url",
130              "Enter S3 endpoint URL (leave empty for AWS S3)",
131              ORCHESTRATOR_DEFAULTS["artifact_service"]["endpoint_url"],
132              skip_interactive,
133          )
134          s3_region = ask_if_not_provided(
135              options,
136              "s3_region",
137              "Enter S3 region",
138              ORCHESTRATOR_DEFAULTS["artifact_service"]["region"],
139              skip_interactive,
140          )
141  
142      artifact_scope = ask_if_not_provided(
143          options,
144          "artifact_service_scope",
145          "Enter artifact service scope",
146          ORCHESTRATOR_DEFAULTS["artifact_service"]["artifact_scope"],
147          skip_interactive,
148          choices=["namespace", "app", "custom"],
149      )
150  
151      artifact_handling_mode = ask_if_not_provided(
152          options,
153          "artifact_handling_mode",
154          "Enter artifact handling mode",
155          ORCHESTRATOR_DEFAULTS["artifact_handling_mode"],
156          skip_interactive,
157          choices=["ignore", "embed", "reference"],
158      )
159  
160      enable_embed_resolution = ask_if_not_provided(
161          options,
162          "enable_embed_resolution",
163          "Enable embed resolution? (true/false)",
164          ORCHESTRATOR_DEFAULTS["enable_embed_resolution"],
165          skip_interactive,
166          is_bool=True,
167      )
168  
169      enable_artifact_content_instruction = ask_if_not_provided(
170          options,
171          "enable_artifact_content_instruction",
172          "Enable artifact content instruction? (true/false)",
173          ORCHESTRATOR_DEFAULTS["enable_artifact_content_instruction"],
174          skip_interactive,
175          is_bool=True,
176      )
177  
178      agent_card_description = ask_if_not_provided(
179          options,
180          "agent_card_description",
181          "Enter agent card description",
182          ORCHESTRATOR_DEFAULTS["agent_card"]["description"],
183          skip_interactive,
184      )
185  
186      if "agent_card_default_input_modes" in options and isinstance(
187          options["agent_card_default_input_modes"], list
188      ):
189          default_input_modes = options["agent_card_default_input_modes"]
190      else:
191          default_input_modes_str = ask_if_not_provided(
192              options,
193              "agent_card_default_input_modes",
194              "Enter agent card default input modes (comma-separated)",
195              ",".join(ORCHESTRATOR_DEFAULTS["agent_card"]["defaultInputModes"]),
196              skip_interactive,
197          )
198          if isinstance(default_input_modes_str, list):
199              default_input_modes = default_input_modes_str
200          else:
201              default_input_modes = [
202                  mode.strip() for mode in default_input_modes_str.split(",")
203              ]
204  
205      if "agent_card_default_output_modes" in options and isinstance(
206          options["agent_card_default_output_modes"], list
207      ):
208          default_output_modes = options["agent_card_default_output_modes"]
209      else:
210          default_output_modes_str = ask_if_not_provided(
211              options,
212              "agent_card_default_output_modes",
213              "Enter agent card default output modes (comma-separated)",
214              ",".join(ORCHESTRATOR_DEFAULTS["agent_card"]["defaultOutputModes"]),
215              skip_interactive,
216          )
217          if isinstance(default_output_modes_str, list):
218              default_output_modes = default_output_modes_str
219          else:
220              default_output_modes = [
221                  mode.strip() for mode in default_output_modes_str.split(",")
222              ]
223  
224      agent_discovery_enabled = ask_if_not_provided(
225          options,
226          "agent_discovery_enabled",
227          "Enable agent discovery? (true/false)",
228          ORCHESTRATOR_DEFAULTS["agent_discovery"]["enabled"],
229          skip_interactive,
230          is_bool=True,
231      )
232  
233      agent_card_publishing_interval = ask_if_not_provided(
234          options,
235          "agent_card_publishing_interval",
236          "Enter agent card publishing interval (seconds)",
237          ORCHESTRATOR_DEFAULTS["agent_card_publishing"]["interval_seconds"],
238          skip_interactive,
239      )
240  
241      if "inter_agent_communication_allow_list" in options and isinstance(
242          options["inter_agent_communication_allow_list"], list
243      ):
244          allow_list = options["inter_agent_communication_allow_list"]
245      else:
246          allow_list_str = ask_if_not_provided(
247              options,
248              "inter_agent_communication_allow_list",
249              "Enter inter-agent communication allow list (comma-separated, use * for all)",
250              ",".join(ORCHESTRATOR_DEFAULTS["inter_agent_communication"]["allow_list"]),
251              skip_interactive,
252          )
253          if isinstance(allow_list_str, list):
254              allow_list = allow_list_str
255          else:
256              allow_list = [item.strip() for item in allow_list_str.split(",")]
257  
258      if "inter_agent_communication_deny_list" in options and isinstance(
259          options["inter_agent_communication_deny_list"], list
260      ):
261          deny_list = options["inter_agent_communication_deny_list"]
262      else:
263          deny_list_str = ask_if_not_provided(
264              options,
265              "inter_agent_communication_deny_list",
266              "Enter inter-agent communication deny list (comma-separated, leave empty for none)",
267              "",
268              skip_interactive,
269          )
270          if isinstance(deny_list_str, list):
271              deny_list = deny_list_str
272          else:
273              deny_list = (
274                  [item.strip() for item in deny_list_str.split(",")]
275                  if deny_list_str.strip()
276                  else []
277              )
278  
279      inter_agent_communication_timeout = ask_if_not_provided(
280          options,
281          "inter_agent_communication_timeout",
282          "Enter inter-agent communication timeout (seconds)",
283          ORCHESTRATOR_DEFAULTS["inter_agent_communication"]["request_timeout_seconds"],
284          skip_interactive,
285      )
286  
287      shared_config_dest_path = project_root / "configs" / "shared_config.yaml"
288  
289      try:
290          # Load the single shared_config.yaml template
291          shared_template_content = load_template("shared_config.yaml")
292  
293          # Check which LLM provider is being used (if any)
294          llm_provider = options.get("llm_provider", "")
295          has_llm_provider = bool(llm_provider)
296  
297          # Configure model sections based on provider
298          if llm_provider == "aws_bedrock":
299              planning_model_config = """# Note: If you want a different model for planning, change it here
300        model: ${BEDROCK_MODEL_NAME}
301        model_id: ${BEDROCK_MODEL_ID}
302        aws_access_key_id: ${AWS_ACCESS_KEY_ID}
303        aws_secret_access_key: ${AWS_SECRET_ACCESS_KEY}
304        aws_session_token: ${AWS_SESSION_TOKEN}
305        temperature: 0.1  # Lower temperature for more focused responses
306        # max_tokens: 2048  # Limit response length"""
307  
308              general_model_config = """# Note: If you want a different model for general, change it here
309        model: ${BEDROCK_MODEL_NAME}
310        model_id: ${BEDROCK_MODEL_ID}
311        aws_access_key_id: ${AWS_ACCESS_KEY_ID}
312        aws_secret_access_key: ${AWS_SECRET_ACCESS_KEY}
313        aws_session_token: ${AWS_SESSION_TOKEN}
314        temperature: 0.1  # Lower temperature for more focused responses
315        # max_tokens: 1536  # Limit response length for general queries"""
316          elif llm_provider:
317              planning_model_config = """# This dictionary structure tells ADK to use the LiteLlm wrapper.
318        # 'model' uses the specific model identifier your endpoint expects.
319        model: ${LLM_SERVICE_PLANNING_MODEL_NAME} # Use env var for model name
320        # 'api_base' tells LiteLLM where to send the request.
321        api_base: ${LLM_SERVICE_ENDPOINT} # Use env var for endpoint URL
322        # 'api_key' provides authentication.
323        api_key: ${LLM_SERVICE_API_KEY} # Use env var for API key
324        # Enable parallel tool calls for planning model
325        parallel_tool_calls: true
326        # Prompt Caching Strategy
327        cache_strategy: "5m" # none, 5m, 1h
328  
329        # max_tokens: ${MAX_TOKENS, 16000} # Set a reasonable max token limit for planning
330        # temperature: 0.1 # Lower temperature for more deterministic planning"""
331  
332              general_model_config = """# This dictionary structure tells ADK to use the LiteLlm wrapper.
333        # 'model' uses the specific model identifier your endpoint expects.
334        model: ${LLM_SERVICE_GENERAL_MODEL_NAME} # Use env var for model name
335        # 'api_base' tells LiteLLM where to send the request.
336        api_base: ${LLM_SERVICE_ENDPOINT} # Use env var for endpoint URL
337        # 'api_key' provides authentication.
338        api_key: ${LLM_SERVICE_API_KEY} # Use env var for API key
339        # Prompt Caching Strategy
340        cache_strategy: "5m" # none, 5m, 1h"""
341  
342          # Configure artifact service section based on provider
343          artifact_base_path_line = ""
344          if artifact_type == "filesystem":
345              artifact_base_path_line = f'base_path: "{artifact_base_path}"'
346          elif artifact_type == "s3":
347              s3_config_lines = ["bucket_name: ${S3_BUCKET_NAME}"]
348              s3_config_lines.append("endpoint_url: ${S3_ENDPOINT_URL}")
349              s3_config_lines.append("region: ${S3_REGION}")
350              artifact_base_path_line = "\n      ".join(s3_config_lines)
351  
352          # Replace all placeholders
353          shared_replacements = {
354              "__DEFAULT_ARTIFACT_SERVICE_TYPE__": artifact_type,
355              "__DEFAULT_ARTIFACT_SERVICE_SCOPE__": artifact_scope,
356          }
357  
358          modified_shared_content = shared_template_content
359          for placeholder, value in shared_replacements.items():
360              modified_shared_content = modified_shared_content.replace(
361                  placeholder, str(value)
362              )
363  
364          if has_llm_provider:
365              # Replace model configuration placeholders
366              modified_shared_content = modified_shared_content.replace(
367                  "      # __PLANNING_MODEL_CONFIG__",
368                  planning_model_config,
369              )
370              modified_shared_content = modified_shared_content.replace(
371                  "      # __GENERAL_MODEL_CONFIG__",
372                  general_model_config,
373              )
374          else:
375              # No provider — strip the entire models section from shared_config
376              modified_shared_content = re.sub(
377                  r"\n  - models:.*?(?=\n  - )",
378                  "",
379                  modified_shared_content,
380                  flags=re.DOTALL,
381              )
382  
383          # Replace artifact base path line
384          if not artifact_base_path_line:
385              modified_shared_content = re.sub(
386                  r"\s*# __DEFAULT_ARTIFACT_SERVICE_BASE_PATH_LINE__.*",
387                  "",
388                  modified_shared_content,
389              )
390          else:
391              modified_shared_content = modified_shared_content.replace(
392                  "      # __DEFAULT_ARTIFACT_SERVICE_BASE_PATH_LINE__",
393                  f"      {artifact_base_path_line}",
394              )
395  
396          shared_config_dest_path.parent.mkdir(parents=True, exist_ok=True)
397          with open(shared_config_dest_path, "w", encoding="utf-8") as f:
398              f.write(modified_shared_content)
399          click.echo(f"  Configured: {shared_config_dest_path.relative_to(project_root)}")
400  
401      except Exception as e:
402          click.echo(
403              click.style(
404                  f"Error configuring file {shared_config_dest_path}: {e}", fg="red"
405              ),
406              err=True,
407          )
408          return False
409  
410      try:
411          logging_config_dest_path = project_root / "configs" / "logging_config.yaml"
412          logging_template_content = load_template("logging_config_template.yaml")
413          with open(logging_config_dest_path, "w", encoding="utf-8") as f:
414              f.write(logging_template_content)
415          click.echo(
416              f"  Configured: {logging_config_dest_path.relative_to(project_root)}"
417          )
418      except Exception as e:
419          error_message = (
420              f"Error configuring file {logging_config_dest_path}: {e}"
421              if logging_config_dest_path
422              else f"Error configuring logging configuration: {e}"
423          )
424          click.echo(
425              click.style(error_message, fg="red"),
426              err=True,
427          )
428          return False
429  
430      main_orchestrator_path = (
431          project_root / "configs" / "agents" / "main_orchestrator.yaml"
432      )
433  
434      try:
435          orchestrator_template_content = load_template("main_orchestrator.yaml")
436  
437          formatted_name = get_formatted_names(options["agent_name"])
438          kebab_case_name = formatted_name.get("KEBAB_CASE_NAME")
439  
440          deny_list_line = ""
441          if deny_list:
442              deny_list_yaml = (
443                  yaml.dump(deny_list, Dumper=yaml.SafeDumper, default_flow_style=True)
444                  .strip()
445                  .replace("'", '"')
446              )
447              deny_list_line = f"deny_list: {deny_list_yaml}"
448  
449          default_instruction = """You are the Orchestrator Agent within an AI agentic system. Your primary responsibilities are to:
450          1. Process tasks received from external sources via the system Gateway.
451          2. Analyze each task to determine the optimal execution strategy:
452             a. Single Agent Delegation: If the task can be fully addressed by a single peer agent (based on their declared capabilities/description), delegate the task to that agent.
453             b. Multi-Agent Coordination: If task completion requires a coordinated effort from multiple peer agents: first, devise a logical execution plan (detailing the sequence of agent invocations and any necessary data handoffs). Then, manage the execution of this plan, invoking each agent in the defined order.
454             c. Direct Execution: If the task is not suitable for delegation (neither to a single agent nor a multi-agent sequence) and falls within your own capabilities, execute the task yourself.
455  
456          Artifact Management Guidelines:
457          - You must review your artifacts and return the ones that are important for the user by using artifact_return embed. You can use list_artifacts to see all available artifacts.
458          - Provide regular progress updates using `status_update` embed directives, especially before initiating any tool call."""
459  
460          session_service_lines = [
461              f'type: "sql"',
462              'database_url: "${ORCHESTRATOR_DATABASE_URL, sqlite:///orchestrator.db}"',
463              f'default_behavior: "PERSISTENT"',
464          ]
465          session_service_block = "\n" + "\n".join(
466              [f"        {line}" for line in session_service_lines]
467          )
468  
469          orchestrator_replacements = {
470              "__NAMESPACE__": "${NAMESPACE}",
471              "__APP_NAME__": f"{kebab_case_name}_app",
472              "__SUPPORTS_STREAMING__": str(options["supports_streaming"]).lower(),
473              "__AGENT_NAME__": options["agent_name"],
474              "__LOG_FILE_NAME__": f"{kebab_case_name}.log",
475              "__INSTRUCTION__": default_instruction,
476              "__SESSION_SERVICE__": session_service_block,
477              "__ARTIFACT_SERVICE__": "*default_artifact_service",
478              "__ARTIFACT_HANDLING_MODE__": artifact_handling_mode,
479              "__ENABLE_EMBED_RESOLUTION__": str(enable_embed_resolution).lower(),
480              "__ENABLE_ARTIFACT_CONTENT_INSTRUCTION__": str(
481                  enable_artifact_content_instruction
482              ).lower(),
483              "__AGENT_CARD_DESCRIPTION__": agent_card_description,
484              "__DEFAULT_INPUT_MODES__": yaml.dump(
485                  default_input_modes, Dumper=yaml.SafeDumper, default_flow_style=True
486              )
487              .strip()
488              .replace("'", '"'),
489              "__DEFAULT_OUTPUT_MODES__": yaml.dump(
490                  default_output_modes, Dumper=yaml.SafeDumper, default_flow_style=True
491              )
492              .strip()
493              .replace("'", '"'),
494              "__AGENT_CARD_PUBLISHING_INTERVAL__": str(agent_card_publishing_interval),
495              "__AGENT_DISCOVERY_ENABLED__": str(agent_discovery_enabled).lower(),
496              "__INTER_AGENT_COMMUNICATION_ALLOW_LIST__": yaml.dump(
497                  allow_list, Dumper=yaml.SafeDumper, default_flow_style=True
498              )
499              .strip()
500              .replace("'", '"'),
501              "__INTER_AGENT_COMMUNICATION_TIMEOUT__": str(
502                  inter_agent_communication_timeout
503              ),
504          }
505  
506          modified_orchestrator_content = orchestrator_template_content
507          for placeholder, value in orchestrator_replacements.items():
508              modified_orchestrator_content = modified_orchestrator_content.replace(
509                  placeholder, str(value)
510              )
511  
512          if deny_list:
513              modified_orchestrator_content = modified_orchestrator_content.replace(
514                  "__INTER_AGENT_COMMUNICATION_DENY_LIST_LINE__",
515                  deny_list_line,
516              )
517          else:
518              modified_orchestrator_content = re.sub(
519                  r"^\s*__INTER_AGENT_COMMUNICATION_DENY_LIST_LINE__\n?$",
520                  "",
521                  modified_orchestrator_content,
522                  flags=re.MULTILINE,
523              )
524  
525          # If no LLM provider, strip the model anchor line (keep model_provider)
526          if not has_llm_provider:
527              modified_orchestrator_content = re.sub(
528                  r"^\s*model: \*planning_model\n",
529                  "",
530                  modified_orchestrator_content,
531                  flags=re.MULTILINE,
532              )
533  
534          main_orchestrator_path.parent.mkdir(parents=True, exist_ok=True)
535          with open(main_orchestrator_path, "w", encoding="utf-8") as f:
536              f.write(modified_orchestrator_content)
537  
538          click.echo(f"  Created: {main_orchestrator_path.relative_to(project_root)}")
539          return True
540      except Exception as e:
541          click.echo(
542              click.style(f"Error creating file {main_orchestrator_path}: {e}", fg="red"),
543              err=True,
544          )
545          return False