dgx-spark-playbooks/nvidia/multi-sparks-through-switch/assets/spark_cluster_setup/spark_cluster_setup.py
2026-03-20 02:59:35 +00:00

902 lines
37 KiB
Python

#!/usr/bin/env python3
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import paramiko
import argparse
import time
import json
from scp import SCPClient
import threading
import sys
from pathlib import Path
import os
import subprocess
import re
from ipaddress import ip_address as ip_addr_obj, ip_network
# Raise paramiko's loggers to CRITICAL so only critical records from the
# library reach the console.
logging.getLogger("paramiko").setLevel(logging.CRITICAL)
logging.getLogger("paramiko.transport").setLevel(logging.CRITICAL)
# Default paths
# Directory that contains this script; bundled assets are resolved against it.
SCRIPT_DIR = Path(__file__).resolve().parent
# NOTE(review): not referenced in the visible code (main() requires --config) - confirm it is still needed.
CONFIG_PATH = SCRIPT_DIR / "spark_config.json"
# Private half of the key pair shared by all cluster nodes; the public half is
# the same path with a ".pub" suffix.
SHARED_KEY = Path.home() / ".ssh" / "id_ed25519_shared"
SSH_DIR = Path.home() / ".ssh"
AUTHORIZED_KEYS = SSH_DIR / "authorized_keys"
SSH_CONFIG = SSH_DIR / "config"
# Line written under "Host *" in the local and remote SSH config files.
IDENTITY_LINE = "IdentityFile ~/.ssh/id_ed25519_shared"
# Script that is copied to every node and executed there with sudo.
NETWORK_SETUP_SCRIPT_NAME = "detect_and_configure_cluster_networking.py"
NETWORK_SETUP_SCRIPT = SCRIPT_DIR / "node_scripts" / NETWORK_SETUP_SCRIPT_NAME
# NOTE(review): the three addressing constants below are not referenced in the
# visible code; presumably left over from static IP assignment - confirm.
IP_PREFIX = "192.168.100."
LAST_OCTET_START = 10
SUBNET_SIZE = 24
# Minimum acceptable "Avg bus bandwidth" values (GB/s) reported by the NCCL test.
MIN_NCCL_TEST_BW = 21.875 # 175 Gbps expressed in GB/s (175 / 8)
MIN_NCCL_TEST_BW_RING = 10 # 80 Gbps expressed in GB/s (80 / 8); used for the ring topology
# Shell snippet exporting the CUDA/MPI/NCCL locations; prepended to the
# nccl-tests build command and to the mpirun command.
NCCL_ENV = """export CUDA_HOME="/usr/local/cuda" && export MPI_HOME="/usr/lib/aarch64-linux-gnu/openmpi" && export NCCL_HOME="$HOME/nccl_spark_cluster/build/" && export LD_LIBRARY_PATH="$NCCL_HOME/lib:$CUDA_HOME/lib64/:$MPI_HOME/lib:$LD_LIBRARY_PATH" """
class ExceptionThread(threading.Thread):
    """Thread that records an exception raised by its target.

    The stored exception is re-raised in the caller's thread when ``join()``
    is called, so worker failures are not silently lost.
    """

    def __init__(self, *args, **kwargs):
        threading.Thread.__init__(self, *args, **kwargs)
        # Exception raised inside run(), or None when the target succeeded.
        self.exc = None

    def run(self):
        """Run the thread target, capturing any ``Exception`` it raises."""
        try:
            threading.Thread.run(self)
        except Exception as caught:
            self.exc = caught

    def join(self, timeout=None):
        """Wait for the thread, then re-raise the captured exception, if any."""
        threading.Thread.join(self, timeout)
        captured = self.exc
        if captured:
            raise captured
def create_ssh_client(server, port, user, password, timeout=10):
    """Create a Paramiko SSH client and connect it to ``server``.

    Unknown host keys are accepted automatically (``AutoAddPolicy``).

    Args:
        server: Host name or IP address of the node.
        port: SSH port.
        user: Login user name.
        password: Login password; callers pass "" when none is configured.
        timeout: Connection timeout in seconds.

    Returns:
        The connected ``paramiko.SSHClient``.

    Raises:
        Whatever ``SSHClient.connect`` raises. The client is closed first so a
        failed attempt does not leave a transport behind.
    """
    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    # Kept from the original implementation; presumably spaces out successive
    # connection attempts - TODO confirm.
    time.sleep(1)
    try:
        client.connect(server, port, user, password, timeout=timeout)
    except Exception:
        # connect() can fail after the transport has been started (for
        # example on authentication errors); release it before propagating.
        client.close()
        raise
    return client
def paramiko_run_command_with_output(ssh_client, cmd):
    """Run ``cmd`` over SSH and collect its output.

    stdout and stderr are drained while the command runs so the remote side
    is not blocked on a full channel window. stderr is read with the
    channel's dedicated stderr calls (``recv()`` only ever returns stdout
    data). The raw bytes are decoded once at the end, so a multi-byte UTF-8
    character split across two reads cannot raise; undecodable bytes are
    replaced.

    Args:
        ssh_client: Connected client exposing ``exec_command``.
        cmd: Shell command line to execute on the remote host.

    Returns:
        Tuple ``(exit_code, output, error)`` with stdout and stderr as text.
    """
    _, stdout, stderr = ssh_client.exec_command(cmd)
    channel = stdout.channel
    out_chunks = []
    err_chunks = []
    while not channel.exit_status_ready():
        while channel.recv_ready():
            chunk = channel.recv(65536)
            if not chunk:
                break
            out_chunks.append(chunk)
        while channel.recv_stderr_ready():
            chunk = channel.recv_stderr(65536)
            if not chunk:
                break
            err_chunks.append(chunk)
        time.sleep(0.1)
    # The command has exited; collect whatever is still buffered.
    out_chunks.append(stdout.read())
    err_chunks.append(stderr.read())
    exit_code = channel.recv_exit_status()
    output = b"".join(out_chunks).decode("utf-8", errors="replace")
    error = b"".join(err_chunks).decode("utf-8", errors="replace")
    return exit_code, output, error
def paramiko_run_command(ssh_client, cmd):
    """Run ``cmd`` over SSH, discard its output and return the exit code.

    Both streams are still drained while the command runs so the remote
    process cannot stall on a full channel window. The discarded bytes are
    never decoded, so non-UTF-8 output or a multi-byte character split across
    reads cannot raise. stderr is drained with the channel's stderr calls
    (``recv()`` only returns stdout data).

    Args:
        ssh_client: Connected client exposing ``exec_command``.
        cmd: Shell command line to execute on the remote host.

    Returns:
        The command's exit status.
    """
    _, stdout, stderr = ssh_client.exec_command(cmd)
    channel = stdout.channel
    while not channel.exit_status_ready():
        while channel.recv_ready():
            if not channel.recv(65536):
                break
        while channel.recv_stderr_ready():
            if not channel.recv_stderr(65536):
                break
        time.sleep(0.1)
    # Consume anything left in the buffers after the command exited.
    stdout.read()
    stderr.read()
    return channel.recv_exit_status()
def _paramiko_run_sudo_impl(ssh_client, password, cmd, capture_output):
    """Run a command with sudo, feeding password via stdin when needed.

    Password is never put on the command line, so it won't appear in ps or /proc.
    With a password the command runs as ``sudo -S`` and the password is
    written to stdin; without one it runs as ``sudo -n`` (non-interactive).

    stderr is drained with the channel's stderr calls (``recv()`` only returns
    stdout data). Bytes are decoded once at the end, so a multi-byte character
    split across reads cannot raise; undecodable bytes are replaced.

    Args:
        ssh_client: Connected client exposing ``exec_command``.
        password: sudo password, or a falsy value for password-less sudo.
        cmd: Command line to run under sudo (without the ``sudo`` prefix).
        capture_output: When False all output is discarded and both returned
            strings are empty.

    Returns:
        Tuple ``(exit_code, output, error)``.
    """
    if password:
        stdin, stdout, stderr = ssh_client.exec_command("sudo -S " + cmd)
        stdin.write(password + "\n")
        # Signal EOF so sudo does not keep waiting for further input.
        stdin.channel.shutdown_write()
    else:
        _, stdout, stderr = ssh_client.exec_command("sudo -n " + cmd)
    channel = stdout.channel
    out_chunks = []
    err_chunks = []
    while not channel.exit_status_ready():
        while channel.recv_ready():
            chunk = channel.recv(65536)
            if not chunk:
                break
            out_chunks.append(chunk)
        while channel.recv_stderr_ready():
            chunk = channel.recv_stderr(65536)
            if not chunk:
                break
            err_chunks.append(chunk)
        time.sleep(0.1)
    # The command has exited; collect whatever is still buffered.
    out_chunks.append(stdout.read())
    err_chunks.append(stderr.read())
    exit_code = channel.recv_exit_status()
    if not capture_output:
        return exit_code, "", ""
    output = b"".join(out_chunks).decode("utf-8", errors="replace")
    error = b"".join(err_chunks).decode("utf-8", errors="replace")
    return exit_code, output, error
def paramiko_run_sudo_command(ssh_client, password, cmd):
    """Run ``cmd`` under sudo (password via stdin if provided) and return only the exit code."""
    result = _paramiko_run_sudo_impl(ssh_client, password, cmd, capture_output=False)
    return result[0]
def paramiko_run_sudo_command_with_output(ssh_client, password, cmd):
    """Run ``cmd`` under sudo (password via stdin if provided).

    Returns:
        Tuple ``(exit_code, output, error)``.
    """
    exit_code, output, error = _paramiko_run_sudo_impl(ssh_client, password, cmd, capture_output=True)
    return exit_code, output, error
def ssh_client_active(ssh):
    """Return True when ``ssh`` is set, has a transport, and that transport is active."""
    if not ssh:
        return False
    if not ssh.get_transport():
        return False
    return bool(ssh.get_transport().is_active())
def close_ssh_session(ssh):
    """Close ``ssh`` when it refers to a client with an active transport; otherwise do nothing."""
    if not ssh_client_active(ssh):
        return
    ssh.close()
def setup_nccl_deps(node, ring_topology):
    """Install, clone and build NCCL and nccl-tests on one node.

    Args:
        node: Node dict with "ip_address", "port", "user" and "password".
        ring_topology: When true the ring-topology NCCL branch is cloned
            instead of the upstream release tag.

    Raises:
        Exception: If the session cannot be established or any step fails;
            the message names the node and carries the underlying error.
    """
    host = node["ip_address"]
    ssh = None
    try:
        ssh = create_ssh_client(host, node["port"], node["user"], node["password"])
        if not ssh.get_transport().is_active():
            raise Exception(f"Could not establish a session to node {host}. Check the credentials and try again.")

        print(f"Updating apt on node {host}...")
        # The result of "apt update" is not checked.
        paramiko_run_sudo_command(ssh, node["password"], "apt update")

        print(f"Installing libopenmpi-dev on node {host}...")
        status, out, err = paramiko_run_sudo_command_with_output(ssh, node["password"], "apt install -y libopenmpi-dev")
        if status:
            raise Exception(f"Failed to install libopenmpi-dev on node {host}: output:{out} error:{err}")

        if ring_topology:
            clone_nccl_cmd = """rm -rf ~/nccl_spark_cluster/ && git clone -b dgxspark-3node-ring https://github.com/zyang-dev/nccl.git ~/nccl_spark_cluster/"""
        else:
            clone_nccl_cmd = """rm -rf ~/nccl_spark_cluster/ && git clone -b v2.28.9-1 https://github.com/NVIDIA/nccl.git ~/nccl_spark_cluster/"""

        # (progress message, failure description, command, include stdout in the error)
        steps = (
            ("Cloning NCCL repo", "clone NCCL repo", clone_nccl_cmd, True),
            (
                "Building NCCL",
                "build NCCL",
                """cd ~/nccl_spark_cluster/ && make -j src.build NVCC_GENCODE="-gencode=arch=compute_121,code=sm_121" """,
                True,
            ),
            (
                "Cloning NCCL tests repo",
                "clone NCCL tests repo",
                """rm -rf ~/nccl-tests_spark_cluster/ && git clone https://github.com/NVIDIA/nccl-tests.git ~/nccl-tests_spark_cluster/""",
                True,
            ),
            (
                "Building NCCL tests",
                "build NCCL tests",
                """cd ~/nccl-tests_spark_cluster/ && %s && make MPI=1 -j8 """ % NCCL_ENV,
                False,
            ),
        )
        for progress, failure, command, with_stdout in steps:
            print(f"{progress} on node {host}...")
            status, out, err = paramiko_run_command_with_output(ssh, command)
            if not status:
                continue
            detail = f"output:{out} error:{err}" if with_stdout else f"{err}"
            raise Exception(f"Failed to {failure} on node {host}: {detail}")

        print(f"Successfully setup NCCL dependencies on node {host}")
    except Exception as e:
        raise Exception(f"Failed to setup NCCL dependencies on node {host}:\n{e}")
    finally:
        close_ssh_session(ssh)
def _parse_avg_bus_bandwidth(output):
    """Return the "# Avg bus bandwidth" value found in NCCL test output, or None."""
    for line in output.splitlines():
        m = re.match(r"# Avg bus bandwidth\s*:\s*([0-9.]+)", line.strip())
        if m:
            return float(m.group(1))
    return None


def run_nccl_test(nodes_info, ring_topology):
    """Set up NCCL on every node in parallel, then run all_gather_perf via mpirun.

    The mpirun command is executed on the first node of ``nodes_info``. The
    reported average bus bandwidth is compared with the expected minimum; a
    low or missing value only produces a warning.

    Args:
        nodes_info: List of node dicts ("ip_address", "port", "user", "password").
        ring_topology: Selects the ring-specific NCCL build, environment and
            bandwidth threshold.

    Returns:
        True when the test command ran successfully, False when setup or the
        test command failed.

    Raises:
        Whatever ``create_ssh_client`` raises when the first node is unreachable.
    """
    if not nodes_info:
        print("ERROR: Nodes information not found.")
        return False

    threads = [
        ExceptionThread(target=setup_nccl_deps, args=(node, ring_topology))
        for node in nodes_info
    ]
    for t in threads:
        t.start()
    # Join every thread before deciding, so no setup keeps running in the
    # background and every node's failure is reported.
    setup_ok = True
    for t in threads:
        try:
            t.join()
        except Exception as e:
            print(f"An error occurred when running NCCL setup on nodes:\n{e}")
            setup_ok = False
    if not setup_ok:
        return False
    print("Successfully setup NCCL dependencies on all nodes...")
    print("Running NCCL test...")
    # Generate the mpirun command
    host_list = ",".join(f"{node['ip_address']}:1" for node in nodes_info)
    ring_topology_specific_env = "-x NCCL_IB_MERGE_NICS=0 -x NCCL_NET_PLUGIN=none " if ring_topology else ""
    mpirun_cmd = (
        f"{NCCL_ENV} && mpirun -np {len(nodes_info)} -H {host_list} "
        '--mca plm_rsh_agent "ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no" '
        "-x LD_LIBRARY_PATH=$LD_LIBRARY_PATH "
        "-x UCX_NET_DEVICES=enP7s7 "
        "-x NCCL_SOCKET_IFNAME=enP7s7 "
        "-x OMPI_MCA_btl_tcp_if_include=enP7s7 "
        "-x NCCL_IB_HCA=rocep1s0f0,rocep1s0f1,roceP2p1s0f0,roceP2p1s0f1 "
        "-x NCCL_IB_SUBNET_AWARE_ROUTING=1 "
        f"{ring_topology_specific_env}"
        "$HOME/nccl-tests_spark_cluster/build/all_gather_perf -b 16G -e 16G -f 2"
    )
    # Run command on the primary node (first node in the list)
    node0 = nodes_info[0]
    ssh = create_ssh_client(node0["ip_address"], node0["port"], node0["user"], node0["password"])
    try:
        if not ssh.get_transport().is_active():
            # Print only the address: the node dict also holds the password.
            print(f"Could not establish a session to node {node0['ip_address']}. Check the credentials and try again.")
            return False
        print(f"NCCL test command: {mpirun_cmd}")
        exit_code, output, error = paramiko_run_command_with_output(ssh, mpirun_cmd)
        if exit_code:
            print(f"Failed to run NCCL test on node {node0['ip_address']}: output:{output} error:{error}")
            return False
    finally:
        # Release the session on every path, including unexpected exceptions.
        close_ssh_session(ssh)

    avg_bus_bw = _parse_avg_bus_bandwidth(output)
    if avg_bus_bw is None:
        print("WARNING: Failed to extract Avg bus bandwidth from NCCL test output.")
        return True
    print(f"Avg bus bandwidth from NCCL test: {avg_bus_bw} GB/s")
    min_expected_bw = MIN_NCCL_TEST_BW_RING if ring_topology else MIN_NCCL_TEST_BW
    if avg_bus_bw < min_expected_bw:
        print("WARNING: NCCL Test bandwidth is less than expected. Stop any GPU workloads on the nodes and try NCCL test again using the NCCL test command above.")
    else:
        print("NCCL test BW is as expected")
    return True
def ensure_ssh_dir():
    """Create ~/.ssh (with parents) if missing and set its mode to 0700."""
    SSH_DIR.mkdir(exist_ok=True, parents=True)
    os.chmod(str(SSH_DIR), 0o700)
def generate_shared_key():
    """Generate shared ed25519 key if it does not exist.

    ssh-keygen is invoked with an argument list instead of a shell string, so
    a key path containing spaces or shell metacharacters is passed through
    unchanged.

    Raises:
        Exception: If ssh-keygen cannot be started or the key file is still
            missing afterwards.
    """
    if SHARED_KEY.exists():
        return
    ensure_ssh_dir()
    print("Generating shared SSH key for all nodes...")
    keygen_cmd = [
        "ssh-keygen",
        "-t", "ed25519",
        "-N", "",
        "-f", str(SHARED_KEY),
        "-q",
        "-C", "shared-cluster-key",
    ]
    try:
        # Output is discarded as before; stdin is detached so an unexpected
        # prompt cannot block the script.
        subprocess.run(
            keygen_cmd,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )
    except OSError as err:
        raise Exception("Failed to generate shared SSH key.") from err
    if not SHARED_KEY.exists():
        raise Exception("Failed to generate shared SSH key.")
def add_pubkey_to_authorized_keys():
    """Add shared public key to local authorized_keys if not present.

    The key is always written on a line of its own: when the existing file
    does not end with a newline one is inserted first, so the new entry is
    not glued onto the previous key. The file mode is set to 0600 after the
    write.
    """
    ensure_ssh_dir()
    pub_key = SHARED_KEY.with_suffix(".pub").read_text().strip()
    existing = ""
    if AUTHORIZED_KEYS.exists():
        existing = AUTHORIZED_KEYS.read_text()
        if pub_key in existing:
            return
    separator = "\n" if existing and not existing.endswith("\n") else ""
    with open(AUTHORIZED_KEYS, "a") as f:
        f.write(f"{separator}{pub_key}\n")
    os.chmod(AUTHORIZED_KEYS, 0o600)
    print("Added shared public key to local authorized_keys")
def update_local_ssh_config():
    """Add IdentityFile for shared key to local SSH config if missing.

    Nothing is written when the config already mentions "id_ed25519_shared".
    When the existing file does not end with a newline one is inserted first,
    so the appended "Host *" stanza starts on its own line. The file mode is
    set to 0600 after the write.
    """
    ensure_ssh_dir()
    existing = ""
    if SSH_CONFIG.exists():
        existing = SSH_CONFIG.read_text()
        if "id_ed25519_shared" in existing:
            return
    separator = "\n" if existing and not existing.endswith("\n") else ""
    with open(SSH_CONFIG, "a") as f:
        f.write(f"{separator}Host *\n")
        f.write(f" {IDENTITY_LINE}\n")
    os.chmod(SSH_CONFIG, 0o600)
    print("Updated local SSH config to use shared key")
def configure_node_ssh_keys(node) -> bool:
    """Install the shared key pair on a node and wire it into its SSH setup.

    Copies both key files into the remote ~/.ssh, fixes their permissions,
    appends the public key to authorized_keys when absent and adds the
    IdentityFile entry to the remote SSH config when absent.

    Returns:
        True on success; False (after printing the error) on any failure.
    """
    host = node["ip_address"]
    public_key_path = SHARED_KEY.with_suffix(".pub")
    ssh = None

    def run_checked(command, failure):
        # Run a remote command and raise with the collected output on a non-zero exit.
        status, out, err = paramiko_run_command_with_output(ssh, command)
        if status:
            raise Exception(f"{failure}: output:{out} error:{err}")

    try:
        ssh = create_ssh_client(host, node["port"], node["user"], node["password"])
        if not ssh.get_transport().is_active():
            raise Exception(f"Could not establish a session to node {host}. Check the credentials and try again.")

        # Resolve remote home (e.g. /home/nvidia or /root)
        _, home_stdout, _ = ssh.exec_command("echo $HOME")
        home = home_stdout.read().decode().strip() or f"/home/{node['user']}"
        remote_ssh = f"{home}/.ssh"
        run_checked(
            f"mkdir -p {remote_ssh} && chmod 700 {remote_ssh}",
            f"Failed to create remote SSH directory on node {host}",
        )

        with SCPClient(ssh.get_transport()) as scp:
            scp.put(str(SHARED_KEY), f"{remote_ssh}/id_ed25519_shared")
            scp.put(str(public_key_path), f"{remote_ssh}/id_ed25519_shared.pub")

        # Set key permissions and add to authorized_keys
        run_checked(
            f"chmod 600 {remote_ssh}/id_ed25519_shared",
            f"Failed to set permissions on {remote_ssh}/id_ed25519_shared",
        )
        run_checked(
            f"chmod 644 {remote_ssh}/id_ed25519_shared.pub",
            f"Failed to set permissions on {remote_ssh}/id_ed25519_shared.pub",
        )

        # Escape single quotes so the key can sit inside a single-quoted shell word.
        pub_escaped = public_key_path.read_text().strip().replace("'", "'\"'\"'")
        run_checked(
            f"grep -qF '{pub_escaped}' {remote_ssh}/authorized_keys 2>/dev/null || "
            f"echo '{pub_escaped}' >> {remote_ssh}/authorized_keys",
            f"Failed to add {public_key_path} to authorized_keys",
        )
        run_checked(
            f"chmod 600 {remote_ssh}/authorized_keys",
            f"Failed to set permissions on {remote_ssh}/authorized_keys",
        )

        run_checked(
            f"grep -q 'IdentityFile.*id_ed25519_shared' {remote_ssh}/config 2>/dev/null || "
            f"(echo 'Host *' >> {remote_ssh}/config && echo ' {IDENTITY_LINE}' >> {remote_ssh}/config)",
            f"Failed to add {IDENTITY_LINE} to config",
        )
        run_checked(
            f"chmod 600 {remote_ssh}/config",
            f"Failed to set permissions on {remote_ssh}/config",
        )

        print(f"Successfully configured {host} with shared key")
        return True
    except Exception as e:
        print(f" ✗ Failed to configure {host}:\n{e}")
        return False
    finally:
        close_ssh_session(ssh)
def check_and_get_up_cx7_interfaces(node_info):
    """Find the CX7 interfaces that are UP and verify all nodes agree.

    The first node decides which interface groups count as UP (a group counts
    only when every interface in it is UP). Every other node must then show
    exactly the same interfaces UP and all others not UP.

    Returns:
        List of UP interface names, or [] when the first node has none.

    Raises:
        Exception: On connection problems or when a node's link state differs
            from the first node's.
    """
    if_groups = [["enp1s0f0np0", "enP2p1s0f0np0"], ["enp1s0f1np1", "enP2p1s0f1np1"]]
    ssh = None
    up_ifaces = []

    def link_is_up(if_name):
        # grep exits with 0 only when "state UP" appears in the link description.
        cmd = r"""ip link show %s | grep -c "state UP" """ % if_name
        return not paramiko_run_command(ssh, cmd)

    try:
        first = node_info[0]
        ssh = create_ssh_client(first["ip_address"], first["port"], first["user"], first["password"])
        if not ssh.get_transport().is_active():
            raise Exception(f"Could not establish a session to node {first['ip_address']}. Check the credentials and try again.")
        for if_group in if_groups:
            # Probe every interface of the group (list, not generator, so none is skipped).
            states = [link_is_up(if_name) for if_name in if_group]
            if all(states):
                # found the if_group which has UP interfaces
                up_ifaces.extend(if_group)
        close_ssh_session(ssh)
        if not up_ifaces:
            print(f"ERROR: CX7 interfaces on {first['ip_address']} are not UP")
            return []
        print(f"Found UP CX7 interfaces {up_ifaces} on {first['ip_address']}. Checking other nodes...")

        for node in node_info[1:]:
            ssh = create_ssh_client(node["ip_address"], node["port"], node["user"], node["password"])
            if not ssh.get_transport().is_active():
                raise Exception(f"Could not establish a session to node {node['ip_address']}. Check the credentials and try again.")
            for if_group in if_groups:
                for if_name in if_group:
                    is_up = link_is_up(if_name)
                    expected_up = if_name in up_ifaces
                    if is_up and not expected_up:
                        raise Exception(f"ERROR: CX7 interface {if_name} on {node['ip_address']} is UP which is not in {up_ifaces}. Make sure the same CX7 port(s) on each node are connected and try again.")
                    if expected_up and not is_up:
                        raise Exception(f"ERROR: CX7 interface {if_name} on {node['ip_address']} is DOWN. {up_ifaces} are expected to be UP. Make sure the same CX7 port(s) on each node are connected and try again.")
            close_ssh_session(ssh)
    except Exception as e:
        close_ssh_session(ssh)
        raise Exception(f"ERROR: An error occurred when checking UP CX7 interfaces:\n{e}")
    return up_ifaces
def check_interface_link_speed(nodes_info, interfaces):
    """Check that every listed interface on every node reports a 200Gbps link.

    Returns:
        True when all links report a speed containing "200000"; False (after
        printing guidance) on the first link that does not or cannot be read.

    Raises:
        Exception: If connecting to a node or running the query fails.
    """
    ssh = None
    try:
        for node in nodes_info:
            host = node["ip_address"]
            ssh = create_ssh_client(host, node["port"], node["user"], node["password"])
            if not ssh.get_transport().is_active():
                raise Exception(f"Could not establish a session to node {host}. Check the credentials and try again.")
            for iface in interfaces:
                speed_cmd = """ethtool %s | grep -i speed | awk '-F: ' '{print $2}' """ % iface
                status, out, err = paramiko_run_sudo_command_with_output(ssh, node["password"], speed_cmd)
                if status:
                    print(f"ERROR: Failed to check link speed on {iface} on node {host}: {err}")
                    return False
                if "200000" in out.strip():
                    continue
                print(f"ERROR: Link speed on {iface} on node {host} is not 200Gbps.")
                print("Check the following:\n"
                      "- QSFP cable should be compatible and rated at least for 200Gbps.\n"
                      "- If running with a switch then check the switch port speed.\n"
                      "- With a switch, sometimes auto-negotiation may not negotiate 200Gbps, in which case set the link speed manually on the switch ports.\n")
                return False
            # Done with this node; release its session before moving on.
            close_ssh_session(ssh)
    except Exception as e:
        raise Exception(f"Failed to check link speed:\n{e}")
    finally:
        # Covers the early returns and the error path; a no-op when already closed.
        close_ssh_session(ssh)
    return True
def scp_put_file_with_ssh(client, local_file, remote_file) -> bool:
    """Upload ``local_file`` to ``remote_file`` over the client's transport.

    Returns:
        True on success; False (after printing the reason) when either path
        is missing or the transfer raises.
    """
    if not (local_file and remote_file):
        print("ERROR: Local file or remote file not specified")
        return False
    try:
        with SCPClient(client.get_transport()) as uploader:
            uploader.put(local_file, remote_file)
        return True
    except Exception as e:
        print(f"scp_put_file: An error occurred:\n{e}")
        return False
def copy_network_setup_script_to_nodes(nodes_info) -> bool:
    """Copy the network setup script into the home directory of every node.

    Returns:
        True when the script was copied to all nodes; False (after printing
        the reason) on the first failure.
    """
    remote_path = f"~/{NETWORK_SETUP_SCRIPT_NAME}"
    ssh = None
    try:
        for node in nodes_info:
            host = node["ip_address"]
            ssh = create_ssh_client(host, node["port"], node["user"], node["password"])
            if not ssh.get_transport().is_active():
                print(f"Could not establish a session to node {host}. Check the credentials and try again.")
                return False
            copied = scp_put_file_with_ssh(ssh, NETWORK_SETUP_SCRIPT, remote_path)
            if not copied:
                raise Exception(f"Failed to copy {NETWORK_SETUP_SCRIPT_NAME} to node {host}")
            close_ssh_session(ssh)
        return True
    except Exception as e:
        close_ssh_session(ssh)
        print(f"Failed to copy {NETWORK_SETUP_SCRIPT_NAME}:\n{e}")
        return False
def run_network_setup_script(node, cmd):
    """Run the network setup command on one node under sudo.

    Args:
        node: Node dict with "ip_address", "port", "user" and "password".
        cmd: Command to run under sudo (no 'sudo' prefix).

    Raises:
        Exception: If the session cannot be established or the command exits
            with a non-zero status.
    """
    host = node["ip_address"]
    ssh = None
    try:
        ssh = create_ssh_client(host, node["port"], node["user"], node["password"])
        if not ssh.get_transport().is_active():
            raise Exception(f"Could not establish a session to node {host}.")
        status, out, err = paramiko_run_sudo_command_with_output(ssh, node["password"], cmd)
        if status:
            raise Exception(f"Failed to run network setup script on node {host}: output:{out} error:{err}\n")
    except Exception as e:
        raise Exception(f"Failed to run network setup script on node {host}:\n{e}")
    finally:
        close_ssh_session(ssh)
def run_network_setup_scripts_on_nodes(nodes_info):
    """Run the network setup script on all nodes concurrently, one thread per node.

    The first node in ``nodes_info`` is started with the extra "--primary" flag.

    Returns:
        True when every node finished without error; False otherwise (each
        failure is printed).
    """
    base_cmd = f"python3 ~/{NETWORK_SETUP_SCRIPT_NAME} --apply-netplan-yaml"
    workers = []
    for index, node in enumerate(nodes_info):
        node_cmd = (base_cmd + " --primary") if index == 0 else base_cmd
        worker = ExceptionThread(target=run_network_setup_script, args=(node, node_cmd))
        workers.append(worker)
        worker.start()

    all_ok = True
    for worker in workers:
        try:
            worker.join()
        except Exception as e:
            print(f"An error occurred when running network setup on nodes:\n{e}")
            all_ok = False
    return all_ok
def verify_ip_addresses(nodes_info, up_interfaces) -> bool:
    """Verify interface addressing and connectivity across the cluster.

    First checks that each UP interface on each node carries exactly one IPv4
    address and that no address appears twice. Then, from every node, pings
    each collected address that falls inside one of that node's subnets.

    Returns:
        True when all checks pass; False (after printing the reason) otherwise.
    """
    ssh = None
    try:
        cidrs_by_node = {}
        seen_ips = []
        for node in nodes_info:
            host = node["ip_address"]
            ssh = create_ssh_client(host, node["port"], node["user"], node["password"])
            if not ssh.get_transport().is_active():
                raise Exception(f"Could not establish a session to node {host}. Check the credentials and try again.")
            cidrs_by_node[host] = []
            for iface in up_interfaces:
                addr_cmd = """ip addr show %s | grep -w inet | awk '{print $2}'""" % iface
                status, out, _ = paramiko_run_command_with_output(ssh, addr_cmd)
                if status:
                    raise Exception(f"ERROR: Failed to verify IP address on {iface} on node {host}")
                cidrs = out.strip().split("\n")
                if len(cidrs) != 1:
                    raise Exception(f"ERROR: Zero or multiple IP addresses found on node {host}, {iface}: {cidrs}")
                if not cidrs[0]:
                    raise Exception(f"ERROR: No IP address found on node {host}, {iface}")
                # Parse CIDR (e.g. 192.168.1.1/24) for uniqueness check by IP only
                bare_ips = [cidr.split("/")[0] for cidr in cidrs]
                if set(seen_ips).intersection(bare_ips):
                    raise Exception(f"ERROR: IP address {cidrs} on node {host}, {iface} is already assigned to another node.")
                print(f"IP address on node {host}, {iface}: {cidrs}")
                seen_ips.extend(bare_ips)
                cidrs_by_node[host].extend(cidrs)
            close_ssh_session(ssh)

        print("Running cluster connectivity test...")
        for node in nodes_info:
            # Run mesh ping test between all nodes in the cluster
            host = node["ip_address"]
            ssh = create_ssh_client(host, node["port"], node["user"], node["password"])
            if not ssh.get_transport().is_active():
                raise Exception(f"Could not establish a session to node {host}. Check the credentials and try again.")
            node_subnets = {ip_network(cidr, strict=False) for cidr in cidrs_by_node[host]}
            for target_ip in seen_ips:
                # Only ping addresses that live in one of this node's subnets.
                reachable_subnet = any(ip_addr_obj(target_ip) in subnet for subnet in node_subnets)
                if not reachable_subnet:
                    continue
                status = paramiko_run_command(ssh, f"ping -c 1 {target_ip} > /dev/null 2>&1")
                if status:
                    raise Exception(f"Failed to run ping test from node {host} to node {target_ip}")
            close_ssh_session(ssh)
        print("Cluster connectivity test completed successfully.")
        return True
    except Exception as e:
        close_ssh_session(ssh)
        print(f"Failed to verify IP addresses:\n{e}")
        return False
def configure_ssh_keys_on_nodes(nodes_info) -> bool:
    """Create the shared key locally and distribute it to every node.

    Returns:
        True when all nodes were configured; False as soon as one node fails
        (the local SSH config is then left untouched).
    """
    print("Generating shared SSH key for all nodes...")
    generate_shared_key()

    print("Setting up shared SSH access across all nodes...")
    add_pubkey_to_authorized_keys()

    for node in nodes_info:
        print(f"Configuring shared SSH key on node {node['ip_address']}...")
        configured = configure_node_ssh_keys(node)
        if not configured:
            return False

    update_local_ssh_config()
    print("Shared SSH keys configured successfully.")
    return True
def pre_validate_cluster(config) -> tuple[bool, bool, list[str]]:
    """Run the pre-setup checks: UP CX7 interfaces and their link speed.

    Returns:
        Tuple ``(ok, ring_topology, up_interfaces)``. ``ring_topology`` is
        True for exactly 3 nodes with 4 UP interfaces. On any failure the
        tuple is ``(False, False, [])``.
    """
    try:
        nodes_info = config.get("nodes_info", None)
        if not nodes_info:
            print("ERROR: Nodes information not found.")
            return False, False, []

        print("Checking UP CX7 interfaces...")
        up_interfaces = check_and_get_up_cx7_interfaces(nodes_info)
        if not up_interfaces:
            print("ERROR: Failed to check UP CX7 interfaces. Check the QSFP cable connection and try again.")
            return False, False, []

        print("Checking CX7 interface link speed...")
        speed_ok = check_interface_link_speed(nodes_info, up_interfaces)
        if not speed_ok:
            return False, False, []

        node_count = len(nodes_info)
        iface_count = len(up_interfaces)
        ring_topology = node_count == 3 and iface_count == 4
    except Exception as e:
        print(f"ERROR: An error occurred when pre-validating the cluster:\n{e}")
        return False, False, []
    return True, ring_topology, up_interfaces
def handle_cluster_setup(config, up_interfaces) -> bool:
    """Configure cluster networking and shared SSH keys on all nodes.

    Copies and runs the network setup script, waits for the interfaces to get
    their addresses (up to 5 checks with a wait that grows by 10 seconds each
    time), then distributes the shared SSH key.

    Returns:
        True when every stage succeeded; False (after printing the reason)
        otherwise.
    """
    try:
        nodes_info = config.get("nodes_info", None)
        if not nodes_info:
            print("ERROR: Nodes information not found.")
            return False

        print("Copying network setup scripts on nodes...")
        # Copy the detect_and_configure_cluster_networking.py script to the nodes and run it in threads
        if not copy_network_setup_script_to_nodes(nodes_info):
            return False

        print("Running network setup scripts on nodes...")
        if not run_network_setup_scripts_on_nodes(nodes_info):
            print("ERROR: Failed to run network setup scripts on nodes. Check the QSFP cable connections and the nodes config in the json file and try again.")
            return False

        # Verify that the IP addresses are assigned to the interfaces
        max_retries = 5
        for attempt in range(1, max_retries + 1):
            wait_secs = attempt * 10
            print(f"Waiting for {wait_secs} seconds before checking IP addresses")
            time.sleep(wait_secs)
            if verify_ip_addresses(nodes_info, up_interfaces):
                break
            print(f"ERROR: Failed to verify IP addresses on nodes. ({max_retries - attempt} retries left)...")
        else:
            # Every attempt failed.
            print("ERROR: Failed to verify IP addresses on nodes. Check the QSFP cable connections and the nodes config in the json file and try again.")
            return False

        # Configure ssh keys across nodes
        if not configure_ssh_keys_on_nodes(nodes_info):
            print("ERROR: Failed to configure ssh keys on nodes. Please check the configuration and try again.")
            return False
    except Exception as e:
        print(f"ERROR: An error occurred when handling cluster setup:\n{e}")
        return False
    return True
def validate_config(config):
    """Validate the cluster configuration and each node's reachability.

    Checks that 2 to 4 nodes are configured, fills in default "port" (22) and
    "password" ("") values, verifies SSH login and sudo access on every node,
    and requires that one of the node addresses belongs to the local machine.

    Returns:
        True when the configuration is usable; False (after printing the
        reason) otherwise.
    """
    nodes = config.get("nodes_info", None)
    if not nodes:
        print("ERROR: Nodes information not found.")
        return False
    node_count = len(config.get("nodes_info"))
    if node_count < 2 or node_count > 4:
        print("ERROR: Cluster can not contain less than 2 or more than 4 nodes. Please check the configuration and try again.")
        return False

    # Collect the IPv4 addresses of the machine running this script.
    local_ip_cmd = """ip a | grep -w inet | awk -F"inet |/" '{print $2}' """
    result = subprocess.run(local_ip_cmd, capture_output=True, text=True, shell=True)
    if result.returncode:
        print(f"ERROR: Failed to check IP addresses on current machine: {result.stderr}")
        return False
    local_ips = result.stdout.strip().split("\n")

    print("Checking connectivity and permissions...")
    current_node_in_cluster = False
    for node in config.get("nodes_info", []):
        if not node.get("ip_address", None):
            print("ERROR: IP address not found for node.")
            return False
        if not node.get("user", None):
            print("ERROR: User not found for node.")
            return False
        if not node.get("port", None):
            # Default port is 22
            node["port"] = 22
        if not node.get("password", None):
            # Default password is empty
            node["password"] = ""

        host = node["ip_address"]
        try:
            ssh = create_ssh_client(host, node["port"], node["user"], node["password"])
            if not ssh.get_transport().is_active():
                print(f"ERROR: Could not establish a session to node {host}. Check the credentials and try again.")
                return False
        except Exception as e:
            print(f"ERROR: Could not establish a session to node {host}, check the credentials in the config file and try again: {e}")
            return False

        # An empty password makes the helper use non-interactive sudo, which
        # validates password-less sudo access.
        sudo_status = paramiko_run_sudo_command(ssh, node["password"], "true")
        if sudo_status:
            print(f"ERROR: Failed to check sudo access on node {host}. If password is not specified then make sure that user has sudo access without password.")
            close_ssh_session(ssh)
            return False
        if host in local_ips:
            current_node_in_cluster = True
        close_ssh_session(ssh)

    if not current_node_in_cluster:
        print("ERROR: This script must be run on a node in the cluster.")
        return False
    return True
def validate_environment():
    """Check that the script runs in the expected environment.

    Requires, in this order: the SPARK_CLUSTER_SETUP_WRAPPER=1 marker set by
    the shell wrapper, an active Python virtual environment, and an
    /etc/dgx-release file carrying the DGX Spark markers.

    Returns:
        True when all checks pass; False (after printing the reason) otherwise.
    """
    # The wrapper script is expected to export this marker; a direct run lacks it.
    launched_by_wrapper = os.environ.get("SPARK_CLUSTER_SETUP_WRAPPER") == "1"
    if not launched_by_wrapper:
        print("ERROR: Please run this script via the spark_cluster_setup.sh shell script, not directly.")
        return False

    # Inside a venv sys.prefix differs from sys.base_prefix.
    in_virtualenv = sys.prefix != sys.base_prefix
    if not in_virtualenv:
        print("ERROR: Please run this script inside a Python virtual environment (venv) with requirements installed.")
        return False

    # Both markers must be present in /etc/dgx-release.
    required_markers = ('DGX_NAME="DGX Spark"', 'DGX_PRETTY_NAME="NVIDIA DGX Spark"')
    try:
        with open("/etc/dgx-release", "r") as release_file:
            release_info = release_file.read()
    except FileNotFoundError:
        print("ERROR: /etc/dgx-release not found. This is not a DGX Spark environment.")
        return False
    if not all(marker in release_info for marker in required_markers):
        print("ERROR: This script must be run on a DGX Spark.")
        return False
    return True
class _HelpHintParser(argparse.ArgumentParser):
    """ArgumentParser whose error output ends with a hint to run --help."""

    def error(self, message):
        """Print the error followed by the --help hint and exit with status 2."""
        hint = "Run with --help for usage."
        text = f"{self.prog}: error: {message}\n{hint}\n"
        self.exit(2, text)
def main():
    """Main function to setup the Spark cluster.

    Parses the command line, loads the JSON configuration, validates the
    environment and configuration, and then runs the requested stages
    (pre-validation, cluster setup, NCCL test).

    Returns:
        0 when every requested stage completed, 1 on any failure. The entry
        guard passes this to ``sys.exit`` so callers such as the shell wrapper
        can detect failures. Argument errors exit with status 2 via the parser.
    """
    parser = _HelpHintParser(
        description="Setup the Spark cluster.",
        epilog="One of --pre-validate-only, --run-setup, or --run-nccl-test is required.",
    )
    parser.add_argument("-c", "--config", type=str, required=True, help="Path to the configuration file.")
    parser.add_argument("-v", "--pre-validate-only", action="store_true", help="Only run pre-setup validations.")
    parser.add_argument("-s", "--run-setup", action="store_true", help="Run the cluster setup and run NCCL bandwidth test.")
    parser.add_argument("-n", "--run-nccl-test", action="store_true", help="Run the NCCL bandwidth test.")
    args = parser.parse_args()
    if not (args.pre_validate_only or args.run_setup or args.run_nccl_test):
        parser.error("One of -v/--pre-validate-only, -s/--run-setup, or -n/--run-nccl-test is required.")

    config_path = args.config
    if not os.path.exists(config_path):
        print(f"ERROR: Configuration file not found: {config_path}")
        return 1
    try:
        with open(config_path, "r") as f:
            config = json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        # Report unreadable or malformed files instead of crashing with a traceback.
        print(f"ERROR: Failed to read configuration file {config_path}: {e}")
        return 1
    if not config:
        print(f"ERROR: Configuration file is empty: {config_path}")
        return 1
    if not isinstance(config, dict):
        # The validators look up keys with config.get(), so a JSON object is required.
        print(f"ERROR: Configuration file must contain a JSON object: {config_path}")
        return 1

    try:
        # Validate env
        print("Validating environment...")
        if not validate_environment():
            return 1
        # Validate the config
        print("Validating configuration...")
        if not validate_config(config):
            return 1
        print("Pre-validating cluster setup...")
        ret, ring_topology, up_interfaces = pre_validate_cluster(config)
        if not ret:
            return 1
        if args.pre_validate_only:
            print("Pre-setup validations completed successfully.")
            return 0
        if args.run_setup:
            print("Setting up Spark cluster...")
            if not handle_cluster_setup(config, up_interfaces):
                return 1
            print("Spark cluster setup completed successfully.")
        if args.run_nccl_test or args.run_setup:
            print("Running NCCL test...")
            if ring_topology:
                print("Detected ring topology...")
            if not run_nccl_test(config.get("nodes_info", []), ring_topology):
                return 1
            print("NCCL test completed.")
    except Exception as e:
        print(f"ERROR: An error occurred when running Spark cluster setup:\n{e}")
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())