mirror of
https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-04-23 02:23:53 +00:00
902 lines
37 KiB
Python
902 lines
37 KiB
Python
#!/usr/bin/env python3
|
|
|
|
#
|
|
# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import logging
|
|
import paramiko
|
|
import argparse
|
|
import time
|
|
import json
|
|
from scp import SCPClient
|
|
import threading
|
|
import sys
|
|
from pathlib import Path
|
|
import os
|
|
import subprocess
|
|
import re
|
|
from ipaddress import ip_address as ip_addr_obj, ip_network
|
|
|
|
# Keep paramiko's own loggers quiet: only CRITICAL records are emitted.
for _paramiko_logger_name in ("paramiko", "paramiko.transport"):
    logging.getLogger(_paramiko_logger_name).setLevel(logging.CRITICAL)
del _paramiko_logger_name

# Default paths
SCRIPT_DIR = Path(__file__).resolve().parent
CONFIG_PATH = SCRIPT_DIR / "spark_config.json"

# Local SSH files: the directory, the shared cluster private key (its public
# key is the same path with a ".pub" suffix), authorized_keys and the client
# config file.
SSH_DIR = Path.home() / ".ssh"
SHARED_KEY = SSH_DIR / "id_ed25519_shared"
AUTHORIZED_KEYS = SSH_DIR / "authorized_keys"
SSH_CONFIG = SSH_DIR / "config"
# Entry written (under "Host *") into SSH config files to use the shared key.
IDENTITY_LINE = "IdentityFile ~/.ssh/id_ed25519_shared"

# Node-side networking script: copied into each node's home directory and run
# there with sudo.
NETWORK_SETUP_SCRIPT_NAME = "detect_and_configure_cluster_networking.py"
NETWORK_SETUP_SCRIPT = SCRIPT_DIR / "node_scripts" / NETWORK_SETUP_SCRIPT_NAME

# Cluster addressing parameters.
# NOTE(review): not referenced by the code in this file - confirm still needed.
IP_PREFIX = "192.168.100."
LAST_OCTET_START = 10
SUBNET_SIZE = 24

# Minimum acceptable nccl-tests "Avg bus bandwidth" (GB/s) for the default
# and the ring topology respectively.
MIN_NCCL_TEST_BW = 21.875  # 175 Gbps
MIN_NCCL_TEST_BW_RING = 10  # 80 Gbps

# Shell prefix exporting the CUDA / MPI / NCCL locations used to build and run
# nccl-tests (meant to be chained with "&&").
NCCL_ENV = """export CUDA_HOME="/usr/local/cuda" && export MPI_HOME="/usr/lib/aarch64-linux-gnu/openmpi" && export NCCL_HOME="$HOME/nccl_spark_cluster/build/" && export LD_LIBRARY_PATH="$NCCL_HOME/lib:$CUDA_HOME/lib64/:$MPI_HOME/lib:$LD_LIBRARY_PATH" """
|
|
|
|
class ExceptionThread(threading.Thread):
    """Thread that hands an exception raised by its target back to join().

    Only ``Exception`` subclasses are captured; the captured exception is
    re-raised from every call to join() made after it was stored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Exception raised by the target, or None while none has been raised.
        self.exc = None

    def run(self):
        """Run the target, storing any Exception instead of letting it escape."""
        try:
            super().run()
        except Exception as failure:
            self.exc = failure

    def join(self, timeout=None):
        """Wait for the thread, then re-raise the stored exception, if any."""
        super().join(timeout)
        captured = self.exc
        if captured:
            raise captured
|
|
|
|
def create_ssh_client(server, port, user, password, timeout=10):
    """Create a Paramiko SSH client and connect it to *server*.

    Args:
        server: host name or IP address to connect to.
        port: SSH port.
        user: login user name.
        password: login password (may be empty; paramiko also tries keys/agent).
        timeout: TCP connect timeout in seconds.

    Returns:
        A connected ``paramiko.SSHClient``.

    Raises:
        Whatever ``SSHClient.connect`` raises. The half-initialised client is
        closed before the error propagates, so nothing is leaked.
    """
    client = paramiko.SSHClient()
    # NOTE(review): AutoAddPolicy accepts any unknown host key, i.e. there is
    # no protection against a spoofed node - acceptable only on a trusted LAN.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    # Short pause before every connection. NOTE(review): presumably to avoid
    # bursts of new sessions against sshd - confirm before removing.
    time.sleep(1)
    try:
        client.connect(hostname=server, port=port, username=user, password=password, timeout=timeout)
    except Exception:
        # Callers never receive the client on failure, so release it here.
        client.close()
        raise
    return client
|
|
|
|
def paramiko_run_command_with_output(ssh_client, cmd):
    """Run *cmd* on the remote host and collect its exit status, stdout and stderr.

    Args:
        ssh_client: connected ``paramiko.SSHClient``.
        cmd: command line executed by the remote SSH server.

    Returns:
        ``(exit_code, output, error)``. *exit_code* is the remote exit status;
        *output* / *error* are the complete stdout / stderr, decoded as UTF-8
        with undecodable bytes replaced by U+FFFD.
    """
    _, stdout, stderr = ssh_client.exec_command(cmd)
    # The stdout and stderr files returned by exec_command share one Channel:
    # Channel.recv() yields stdout data, Channel.recv_stderr() yields stderr.
    # (Calling stderr.channel.recv() would read stdout into the error text.)
    channel = stdout.channel
    out_chunks = []
    err_chunks = []

    while not channel.exit_status_ready():
        got_data = False
        # Drain everything currently buffered on both streams each pass.
        while channel.recv_ready():
            out_chunks.append(channel.recv(65536))
            got_data = True
        while channel.recv_stderr_ready():
            err_chunks.append(channel.recv_stderr(65536))
            got_data = True
        if not got_data:
            time.sleep(0.1)

    # The command has exited; read whatever is still in flight up to EOF.
    out_chunks.append(stdout.read())
    err_chunks.append(stderr.read())

    exit_code = channel.recv_exit_status()
    # Decode once at the end so multi-byte characters split across chunks
    # are reassembled instead of raising UnicodeDecodeError.
    output = b"".join(out_chunks).decode("utf-8", errors="replace")
    error = b"".join(err_chunks).decode("utf-8", errors="replace")
    return exit_code, output, error
|
|
|
|
def paramiko_run_command(ssh_client, cmd):
    """Run *cmd* on the remote host and return only its exit status.

    Delegates to paramiko_run_command_with_output, which drains stdout and
    stderr through the correct channel calls while the command runs; the
    collected text is discarded here.

    Args:
        ssh_client: connected ``paramiko.SSHClient``.
        cmd: command line executed by the remote SSH server.

    Returns:
        The remote exit status.
    """
    exit_code, _, _ = paramiko_run_command_with_output(ssh_client, cmd)
    return exit_code
|
|
|
|
|
|
def _paramiko_run_sudo_impl(ssh_client, password, cmd, capture_output):
    """Run *cmd* under sudo on the remote host.

    With a non-empty *password*, ``sudo -S`` is used and the password is written
    to the command's stdin, so it never appears on the command line (ps, /proc).
    Without one, ``sudo -n`` is used so sudo fails instead of prompting.
    stdin is closed in both cases so the remote side sees EOF rather than
    waiting for input.
    NOTE(review): if sudo does not prompt (cached credentials / NOPASSWD), the
    password line presumably reaches the command's own stdin - confirm the
    commands run through here never read stdin.

    Args:
        ssh_client: connected ``paramiko.SSHClient``.
        password: sudo password, or "" / None for passwordless sudo.
        cmd: command line to run, without the leading ``sudo``.
        capture_output: when False the output is drained but discarded and
            empty strings are returned for output and error.

    Returns:
        ``(exit_code, output, error)`` with output/error decoded as UTF-8
        (undecodable bytes replaced by U+FFFD).
    """
    if password:
        stdin, stdout, stderr = ssh_client.exec_command("sudo -S " + cmd)
        stdin.write(password + "\n")
    else:
        stdin, stdout, stderr = ssh_client.exec_command("sudo -n " + cmd)
    stdin.channel.shutdown_write()

    # stdout and stderr share one Channel: recv() reads stdout, recv_stderr()
    # reads stderr.
    channel = stdout.channel
    out_chunks = []
    err_chunks = []

    while not channel.exit_status_ready():
        got_data = False
        while channel.recv_ready():
            data = channel.recv(65536)
            if capture_output:
                out_chunks.append(data)
            got_data = True
        while channel.recv_stderr_ready():
            data = channel.recv_stderr(65536)
            if capture_output:
                err_chunks.append(data)
            got_data = True
        if not got_data:
            time.sleep(0.1)

    # Always read both streams to EOF so the channel is fully consumed.
    tail_out = stdout.read()
    tail_err = stderr.read()
    exit_code = stdout.channel.recv_exit_status()

    if not capture_output:
        return exit_code, "", ""

    out_chunks.append(tail_out)
    err_chunks.append(tail_err)
    output = b"".join(out_chunks).decode("utf-8", errors="replace")
    error = b"".join(err_chunks).decode("utf-8", errors="replace")
    return exit_code, output, error
|
|
|
|
|
|
def paramiko_run_sudo_command(ssh_client, password, cmd):
    """Run *cmd* through sudo on the remote host and report only the exit status.

    A non-empty *password* is given to sudo via stdin; the command's output
    is not kept.
    """
    result = _paramiko_run_sudo_impl(ssh_client, password, cmd, capture_output=False)
    return result[0]
|
|
|
|
|
|
def paramiko_run_sudo_command_with_output(ssh_client, password, cmd):
    """Run *cmd* through sudo on the remote host, keeping its output.

    A non-empty *password* is given to sudo via stdin.
    Returns the tuple ``(exit_code, output, error)``.
    """
    exit_code, output, error = _paramiko_run_sudo_impl(ssh_client, password, cmd, True)
    return exit_code, output, error
|
|
|
|
|
|
def ssh_client_active(ssh):
    """Return True only when *ssh* exists and its transport exists and is active."""
    if not ssh:
        return False
    transport = ssh.get_transport()
    if not transport:
        return False
    return bool(transport.is_active())
|
|
|
|
def close_ssh_session(ssh):
    """Close the SSH client *ssh*; a no-op for None.

    ``paramiko.SSHClient.close()`` is safe on a client whose transport is
    missing, inactive or already closed. It is therefore called whenever a
    client exists rather than only while its transport is active, so the
    client still releases what it holds (transport reference, agent) after
    the connection has dropped.
    """
    if ssh is not None:
        ssh.close()
|
|
|
|
def setup_nccl_deps(node, ring_topology):
    """Install MPI and build NCCL plus nccl-tests from source on one node.

    Steps, all over SSH:
    1. ``apt update`` and ``apt install -y libopenmpi-dev``, both via sudo.
    2. A fresh clone and build of NCCL into ~/nccl_spark_cluster/.
    3. A fresh clone and MPI build of nccl-tests into ~/nccl-tests_spark_cluster/.

    Args:
        node: node dict with "ip_address", "port", "user" and "password" keys.
        ring_topology: when True, clone the ``dgxspark-3node-ring`` branch of
            zyang-dev/nccl; otherwise the ``v2.28.9-1`` tag of NVIDIA/nccl.

    Raises:
        Exception: if any step fails. The message names the node and carries
            the failing step's output and error text.
    """
    ip = node['ip_address']
    ssh = None
    try:
        ssh = create_ssh_client(ip, node['port'], node['user'], node['password'])
        if not ssh.get_transport().is_active():
            raise Exception(f"Could not establish a session to node {ip}. Check the credentials and try again.")

        def run_step(what, cmd):
            # Run a (non-sudo) shell step; raise with its full output on failure.
            exit_code, output, error = paramiko_run_command_with_output(ssh, cmd)
            if exit_code:
                raise Exception(f"Failed to {what} on node {ip}: output:{output} error:{error}")

        print(f"Updating apt on node {ip}...")
        # The exit status of `apt update` is not checked; the install below is.
        paramiko_run_sudo_command(ssh, node['password'], "apt update")

        print(f"Installing libopenmpi-dev on node {ip}...")
        exit_code, output, error = paramiko_run_sudo_command_with_output(ssh, node['password'], "apt install -y libopenmpi-dev")
        if exit_code:
            raise Exception(f"Failed to install libopenmpi-dev on node {ip}: output:{output} error:{error}")

        print(f"Cloning NCCL repo on node {ip}...")
        if ring_topology:
            nccl_clone = "git clone -b dgxspark-3node-ring https://github.com/zyang-dev/nccl.git"
        else:
            nccl_clone = "git clone -b v2.28.9-1 https://github.com/NVIDIA/nccl.git"
        run_step("clone NCCL repo", f"rm -rf ~/nccl_spark_cluster/ && {nccl_clone} ~/nccl_spark_cluster/")

        print(f"Building NCCL on node {ip}...")
        run_step(
            "build NCCL",
            'cd ~/nccl_spark_cluster/ && make -j src.build NVCC_GENCODE="-gencode=arch=compute_121,code=sm_121" ',
        )

        print(f"Cloning NCCL tests repo on node {ip}...")
        run_step(
            "clone NCCL tests repo",
            "rm -rf ~/nccl-tests_spark_cluster/ && git clone https://github.com/NVIDIA/nccl-tests.git ~/nccl-tests_spark_cluster/",
        )

        print(f"Building NCCL tests on node {ip}...")
        run_step("build NCCL tests", f"cd ~/nccl-tests_spark_cluster/ && {NCCL_ENV} && make MPI=1 -j8 ")

        print(f"Successfully setup NCCL dependencies on node {ip}")
    except Exception as e:
        raise Exception(f"Failed to setup NCCL dependencies on node {ip}:\n{e}") from e
    finally:
        close_ssh_session(ssh)
|
|
|
|
def _run_nccl_setup_on_all_nodes(nodes_info, ring_topology):
    """Run setup_nccl_deps on every node concurrently; return True only if all succeed.

    Every thread is joined, even after a failure, and each failure is printed.
    """
    threads = [ExceptionThread(target=setup_nccl_deps, args=(node, ring_topology)) for node in nodes_info]
    for t in threads:
        t.start()

    all_ok = True
    for t in threads:
        try:
            t.join()
        except Exception as e:
            print(f"An error occurred when running NCCL setup on nodes:\n{e}")
            all_ok = False
    return all_ok


def _build_nccl_mpirun_cmd(nodes_info, ring_topology):
    """Return the shell command running all_gather_perf on every node (one slot per host) via mpirun."""
    host_list = ",".join(f"{node['ip_address']}:1" for node in nodes_info)
    # Extra NCCL environment passed only for the ring topology.
    ring_topology_specific_env = "-x NCCL_IB_MERGE_NICS=0 -x NCCL_NET_PLUGIN=none " if ring_topology else ""
    return (
        f"{NCCL_ENV} && mpirun -np {len(nodes_info)} -H {host_list} "
        '--mca plm_rsh_agent "ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no" '
        "-x LD_LIBRARY_PATH=$LD_LIBRARY_PATH "
        "-x UCX_NET_DEVICES=enP7s7 "
        "-x NCCL_SOCKET_IFNAME=enP7s7 "
        "-x OMPI_MCA_btl_tcp_if_include=enP7s7 "
        "-x NCCL_IB_HCA=rocep1s0f0,rocep1s0f1,roceP2p1s0f0,roceP2p1s0f1 "
        "-x NCCL_IB_SUBNET_AWARE_ROUTING=1 "
        f"{ring_topology_specific_env}"
        "$HOME/nccl-tests_spark_cluster/build/all_gather_perf -b 16G -e 16G -f 2"
    )


def _parse_avg_bus_bandwidth(output):
    """Return the number on the first "# Avg bus bandwidth : <value>" line of *output*, or None."""
    pattern = re.compile(r"# Avg bus bandwidth\s*:\s*([0-9]*\.?[0-9]+)")
    for line in output.splitlines():
        m = pattern.match(line.strip())
        if m:
            return float(m.group(1))
    return None


def run_nccl_test(nodes_info, ring_topology):
    """Build NCCL/nccl-tests on all nodes, then run and evaluate an all_gather bandwidth test.

    The test is launched with mpirun from the first node in *nodes_info*. The
    reported average bus bandwidth is compared against MIN_NCCL_TEST_BW_RING
    (ring topology) or MIN_NCCL_TEST_BW.

    Args:
        nodes_info: list of node dicts ("ip_address", "port", "user", "password").
        ring_topology: whether the ring-specific NCCL build/settings are used.

    Returns:
        False if there are no nodes, setup fails on any node, the SSH session
        cannot be established, or mpirun fails. True otherwise; a low or
        unparsable bandwidth only prints a WARNING.
    """
    if not nodes_info:
        print("ERROR: No nodes to run the NCCL test on.")
        return False

    if not _run_nccl_setup_on_all_nodes(nodes_info, ring_topology):
        return False
    print("Successfully setup NCCL dependencies on all nodes...")

    print("Running NCCL test...")
    mpirun_cmd = _build_nccl_mpirun_cmd(nodes_info, ring_topology)

    # Run command on the primary node (first node in the list)
    node0 = nodes_info[0]
    ssh = create_ssh_client(node0["ip_address"], node0["port"], node0["user"], node0["password"])
    try:
        if not ssh.get_transport().is_active():
            # Print only the address: the node dict also contains the SSH password.
            print(f"Could not establish a session to node {node0['ip_address']}. Check the credentials and try again.")
            return False

        print(f"NCCL test command: {mpirun_cmd}")
        exit_code, output, error = paramiko_run_command_with_output(ssh, mpirun_cmd)
        if exit_code:
            print(f"Failed to run NCCL test on node {node0['ip_address']}: output:{output} error:{error}")
            return False
    finally:
        close_ssh_session(ssh)

    avg_bus_bw = _parse_avg_bus_bandwidth(output)
    if avg_bus_bw is None:
        print("WARNING: Failed to extract Avg bus bandwidth from NCCL test output.")
        return True

    print(f"Avg bus bandwidth from NCCL test: {avg_bus_bw} GB/s")
    min_bw = MIN_NCCL_TEST_BW_RING if ring_topology else MIN_NCCL_TEST_BW
    if avg_bus_bw < min_bw:
        print("WARNING: NCCL Test bandwidth is less than expected. Stop any GPU workloads on the nodes and try NCCL test again using the NCCL test command above.")
    else:
        print("NCCL test BW is as expected")
    return True
|
|
|
|
def ensure_ssh_dir():
    """Create SSH_DIR (with any missing parents) and restrict it to mode 0700."""
    os.makedirs(SSH_DIR, exist_ok=True)
    SSH_DIR.chmod(0o700)
|
|
|
|
def generate_shared_key():
    """Generate the shared ed25519 key pair at SHARED_KEY unless the private key already exists.

    The key has no passphrase and the comment "shared-cluster-key".

    Raises:
        Exception: if ssh-keygen cannot be started or the key file is still
            missing afterwards.
    """
    if SHARED_KEY.exists():
        return
    ensure_ssh_dir()
    print("Generating shared SSH key for all nodes...")
    # Argument list, no shell: a home directory containing spaces or shell
    # metacharacters cannot break (or inject into) the command.
    cmd = ["ssh-keygen", "-t", "ed25519", "-N", "", "-f", str(SHARED_KEY), "-q", "-C", "shared-cluster-key"]
    try:
        subprocess.run(
            cmd,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )
    except OSError as e:
        # e.g. ssh-keygen is not installed
        raise Exception(f"Failed to generate shared SSH key: {e}") from e
    if not SHARED_KEY.exists():
        raise Exception("Failed to generate shared SSH key.")
|
|
|
|
def add_pubkey_to_authorized_keys():
    """Append the shared public key to the local authorized_keys unless it is already listed.

    The key is always written on its own line. If the existing file does not
    end with a newline, one is inserted first so the key is not glued onto
    the previous entry (which would corrupt both). The file is set to 0600
    after writing.
    """
    ensure_ssh_dir()
    pub_key = SHARED_KEY.with_suffix(".pub").read_text().strip()
    existing = AUTHORIZED_KEYS.read_text() if AUTHORIZED_KEYS.exists() else ""
    if pub_key in existing:
        return
    separator = "\n" if existing and not existing.endswith("\n") else ""
    with open(AUTHORIZED_KEYS, "a") as f:
        f.write(f"{separator}{pub_key}\n")
    os.chmod(AUTHORIZED_KEYS, 0o600)
    print("Added shared public key to local authorized_keys")
|
|
|
|
def update_local_ssh_config():
    """Add a ``Host *`` block with the shared key's IdentityFile to the local SSH config if missing.

    Nothing is written when the config already mentions "id_ed25519_shared".
    If the existing file does not end with a newline, one is inserted first
    so the new ``Host *`` line does not merge with the file's last line. The
    file is set to 0600 after writing.
    """
    ensure_ssh_dir()
    existing = SSH_CONFIG.read_text() if SSH_CONFIG.exists() else ""
    if "id_ed25519_shared" in existing:
        return
    separator = "\n" if existing and not existing.endswith("\n") else ""
    with open(SSH_CONFIG, "a") as f:
        f.write(f"{separator}Host *\n")
        f.write(f" {IDENTITY_LINE}\n")
    os.chmod(SSH_CONFIG, 0o600)
    print("Updated local SSH config to use shared key")
|
|
|
|
def configure_node_ssh_keys(node) -> bool:
    """Install the shared SSH key pair on *node* and make the node trust and use it.

    On the node, as the SSH user:
    - create ~/.ssh (0700);
    - upload the shared private/public key (0600/0644);
    - append the public key to authorized_keys when it is not already there;
    - append a ``Host *`` / IdentityFile entry to ~/.ssh/config when no
      IdentityFile for the shared key is present.
    Both authorized_keys and config end up with mode 0600. Appends first add
    a newline if the file's last line is unterminated, and all remote paths
    are shell-quoted.

    Args:
        node: node dict with "ip_address", "port", "user" and "password" keys.

    Returns:
        True on success; False (after printing the error) on any failure.
    """
    import shlex  # only this function builds quoted remote shell commands

    q = shlex.quote
    ssh = None
    try:
        ssh = create_ssh_client(node['ip_address'], node['port'], node['user'], node['password'])
        if not ssh.get_transport().is_active():
            raise Exception(f"Could not establish a session to node {node['ip_address']}. Check the credentials and try again.")

        def run_or_raise(cmd, failure):
            # Run a remote shell command; raise with its output if it exits non-zero.
            exit_code, output, error = paramiko_run_command_with_output(ssh, cmd)
            if exit_code:
                raise Exception(f"{failure}: output:{output} error:{error}")

        def newline_if_unterminated(path_q):
            # Shell snippet: add "\n" to a non-empty file whose last byte is not
            # a newline, so the next appended line starts on a line of its own.
            return f'if [ -s {path_q} ] && [ -n "$(tail -c 1 {path_q})" ]; then echo >> {path_q}; fi'

        # Resolve remote home (e.g. /home/nvidia or /root)
        _, stdout, _ = ssh.exec_command("echo $HOME")
        home = stdout.read().decode().strip() or f"/home/{node['user']}"
        remote_ssh = f"{home}/.ssh"
        priv_path = f"{remote_ssh}/id_ed25519_shared"
        pub_path = f"{remote_ssh}/id_ed25519_shared.pub"
        auth_path = f"{remote_ssh}/authorized_keys"
        config_path = f"{remote_ssh}/config"

        run_or_raise(
            f"mkdir -p {q(remote_ssh)} && chmod 700 {q(remote_ssh)}",
            f"Failed to create remote SSH directory on node {node['ip_address']}",
        )

        with SCPClient(ssh.get_transport()) as scp:
            scp.put(str(SHARED_KEY), priv_path)
            scp.put(str(SHARED_KEY.with_suffix(".pub")), pub_path)

        # Set key permissions and add to authorized_keys
        run_or_raise(f"chmod 600 {q(priv_path)}", f"Failed to set permissions on {priv_path}")
        run_or_raise(f"chmod 644 {q(pub_path)}", f"Failed to set permissions on {pub_path}")

        pub_q = q(SHARED_KEY.with_suffix(".pub").read_text().strip())
        auth_q = q(auth_path)
        run_or_raise(
            f"grep -qF {pub_q} {auth_q} 2>/dev/null || "
            f"{{ {newline_if_unterminated(auth_q)}; echo {pub_q} >> {auth_q}; }}",
            f"Failed to add {SHARED_KEY.with_suffix('.pub')} to authorized_keys",
        )
        run_or_raise(f"chmod 600 {auth_q}", f"Failed to set permissions on {auth_path}")

        config_q = q(config_path)
        run_or_raise(
            f"grep -q 'IdentityFile.*id_ed25519_shared' {config_q} 2>/dev/null || "
            f"{{ {newline_if_unterminated(config_q)}; "
            f"echo 'Host *' >> {config_q} && echo {q(' ' + IDENTITY_LINE)} >> {config_q}; }}",
            f"Failed to add {IDENTITY_LINE} to config",
        )
        run_or_raise(f"chmod 600 {config_q}", f"Failed to set permissions on {config_path}")

        print(f"Successfully configured {node['ip_address']} with shared key")
        return True
    except Exception as e:
        print(f" ✗ Failed to configure {node['ip_address']}:\n{e}")
        return False
    finally:
        close_ssh_session(ssh)
|
|
|
|
def check_and_get_up_cx7_interfaces(node_info):
    """Find the UP CX7 interface pair(s) on the first node and require the same on all others.

    On the first node, an interface pair counts only when both of its
    interfaces show "state UP" in ``ip link show``. Every other node must then
    have exactly those interfaces UP and the remaining CX7 interfaces not UP.

    Args:
        node_info: list of node dicts ("ip_address", "port", "user", "password").

    Returns:
        List of UP interface names, or [] (after printing an error) when the
        first node has no fully UP pair.

    Raises:
        Exception: on connection problems or a link-state mismatch on another
            node, prefixed with "ERROR: An error occurred when checking UP CX7
            interfaces:".
    """
    cx7_pairs = [["enp1s0f0np0", "enP2p1s0f0np0"], ["enp1s0f1np1", "enP2p1s0f1np1"]]

    def link_up(client, if_name):
        # grep -c exits 0 only when it counted at least one "state UP" line.
        probe = r"""ip link show %s | grep -c "state UP" """ % if_name
        return not paramiko_run_command(client, probe)

    ssh = None
    up_ifaces = []
    try:
        first = node_info[0]
        ssh = create_ssh_client(first["ip_address"], first["port"], first["user"], first["password"])
        if not ssh.get_transport().is_active():
            raise Exception(f"Could not establish a session to node {first['ip_address']}. Check the credentials and try again.")

        for pair in cx7_pairs:
            # Probe every interface of the pair (no short-circuit).
            states = [link_up(ssh, name) for name in pair]
            if all(states):
                up_ifaces.extend(pair)

        close_ssh_session(ssh)
        if not up_ifaces:
            print(f"ERROR: CX7 interfaces on {first['ip_address']} are not UP")
            return []

        print(f"Found UP CX7 interfaces {up_ifaces} on {first['ip_address']}. Checking other nodes...")
        for peer in node_info[1:]:
            ssh = create_ssh_client(peer["ip_address"], peer["port"], peer["user"], peer["password"])
            if not ssh.get_transport().is_active():
                raise Exception(f"Could not establish a session to node {peer['ip_address']}. Check the credentials and try again.")

            for pair in cx7_pairs:
                for name in pair:
                    is_up = link_up(ssh, name)
                    expected_up = name in up_ifaces
                    if is_up and not expected_up:
                        raise Exception(f"ERROR: CX7 interface {name} on {peer['ip_address']} is UP which is not in {up_ifaces}. Make sure the same CX7 port(s) on each node are connected and try again.")
                    if expected_up and not is_up:
                        raise Exception(f"ERROR: CX7 interface {name} on {peer['ip_address']} is DOWN. {up_ifaces} are expected to be UP. Make sure the same CX7 port(s) on each node are connected and try again.")

            close_ssh_session(ssh)
    except Exception as e:
        close_ssh_session(ssh)
        raise Exception(f"ERROR: An error occurred when checking UP CX7 interfaces:\n{e}")

    return up_ifaces
|
|
|
|
def _check_single_link_speed(ssh, node, iface):
    """Return True if ``ethtool <iface>`` reports a 200000Mb/s speed; otherwise print why and return False."""
    # ethtool runs alone (no grep/awk pipeline), so a sudo or ethtool failure
    # shows up in the exit status instead of being masked by the last stage.
    exit_code, output, error = paramiko_run_sudo_command_with_output(ssh, node["password"], f"ethtool {iface}")
    if exit_code:
        print(f"ERROR: Failed to check link speed on {iface} on node {node['ip_address']}: {error}")
        return False

    match = re.search(r"^\s*speed:\s*(\S+)", output, re.IGNORECASE | re.MULTILINE)
    speed = match.group(1) if match else ""
    if "200000" not in speed:
        print(f"ERROR: Link speed on {iface} on node {node['ip_address']} is not 200Gbps.")
        print("Check the following:\n"
              "- QSFP cable should be compatible and rated at least for 200Gbps.\n"
              "- If running with a switch then check the switch port speed.\n"
              "- With a switch, sometimes auto-negotiation may not negotiate 200Gbps, in which case set the link speed manually on the switch ports.\n")
        return False
    return True


def check_interface_link_speed(nodes_info, interfaces):
    """Check that every interface in *interfaces* runs at 200Gbps on every node.

    Args:
        nodes_info: list of node dicts ("ip_address", "port", "user", "password").
        interfaces: interface names to check on each node.

    Returns:
        True if all links report 200000Mb/s; False (after printing the
        reason) at the first link that fails or cannot be read.

    Raises:
        Exception: on connection problems, wrapped as "Failed to check link speed".
    """
    try:
        for node in nodes_info:
            ssh = create_ssh_client(node["ip_address"], node["port"], node["user"], node["password"])
            try:
                if not ssh.get_transport().is_active():
                    raise Exception(f"Could not establish a session to node {node['ip_address']}. Check the credentials and try again.")
                for iface in interfaces:
                    if not _check_single_link_speed(ssh, node, iface):
                        return False
            finally:
                close_ssh_session(ssh)
    except Exception as e:
        raise Exception(f"Failed to check link speed:\n{e}") from e

    return True
|
|
|
|
def scp_put_file_with_ssh(client, local_file, remote_file) -> bool:
    """Upload *local_file* to *remote_file* over the SSH transport of *client*.

    Returns:
        True when the copy succeeds; False (after printing why) when either
        path is empty/None or the copy raises.
    """
    if not (local_file and remote_file):
        print("ERROR: Local file or remote file not specified")
        return False

    try:
        transport = client.get_transport()
        with SCPClient(transport) as scp:
            scp.put(local_file, remote_file)
        return True
    except Exception as exc:
        print(f"scp_put_file: An error occurred:\n{exc}")
        return False
|
|
|
|
def copy_network_setup_script_to_nodes(nodes_info) -> bool:
    """Copy NETWORK_SETUP_SCRIPT into the home directory (~/) of every node.

    This only copies the script, one node at a time; it does not run it.

    Args:
        nodes_info: list of node dicts ("ip_address", "port", "user", "password").

    Returns:
        True when the script was copied to every node; False (after printing
        why) at the first failure.
    """
    for node in nodes_info:
        ssh = None
        try:
            ssh = create_ssh_client(node["ip_address"], node["port"], node["user"], node["password"])
            if not ssh.get_transport().is_active():
                print(f"Could not establish a session to node {node['ip_address']}. Check the credentials and try again.")
                return False

            if not scp_put_file_with_ssh(ssh, str(NETWORK_SETUP_SCRIPT), f"~/{NETWORK_SETUP_SCRIPT_NAME}"):
                raise Exception(f"Failed to copy {NETWORK_SETUP_SCRIPT_NAME} to node {node['ip_address']}")
        except Exception as e:
            print(f"Failed to copy {NETWORK_SETUP_SCRIPT_NAME}:\n{e}")
            return False
        finally:
            # Runs on every path, including the early return above.
            close_ssh_session(ssh)

    return True
|
|
|
|
def run_network_setup_script(node, cmd):
    """Execute *cmd* as root on *node*; raise if it cannot be run or exits non-zero.

    Args:
        node: node dict ("ip_address", "port", "user", "password").
        cmd: command line to run under sudo, WITHOUT a leading "sudo"
            (the sudo helper adds it).

    Raises:
        Exception: naming the node and wrapping the connection or command failure.
    """
    client = None
    try:
        client = create_ssh_client(node["ip_address"], node["port"], node["user"], node["password"])
        session_up = client.get_transport().is_active()
        if not session_up:
            raise Exception(f"Could not establish a session to node {node['ip_address']}.")

        status, out, err = paramiko_run_sudo_command_with_output(client, node["password"], cmd)
        if status:
            raise Exception(f"Failed to run network setup script on node {node['ip_address']}: output:{out} error:{err}\n")

        close_ssh_session(client)
    except Exception as exc:
        close_ssh_session(client)
        raise Exception(f"Failed to run network setup script on node {node['ip_address']}:\n{exc}")
|
|
|
|
def run_network_setup_scripts_on_nodes(nodes_info):
    """Run the copied network setup script on all nodes in parallel threads.

    The first node gets the extra ``--primary`` flag. Every thread is joined
    and each failure is printed.

    Returns:
        True only if the script succeeded on every node.
    """
    base_cmd = f"python3 ~/{NETWORK_SETUP_SCRIPT_NAME} --apply-netplan-yaml"

    workers = []
    for idx, node in enumerate(nodes_info):
        node_cmd = base_cmd + " --primary" if idx == 0 else base_cmd
        worker = ExceptionThread(target=run_network_setup_script, args=(node, node_cmd))
        workers.append(worker)
        worker.start()

    all_ok = True
    for worker in workers:
        try:
            worker.join()
        except Exception as exc:
            print(f"An error occurred when running network setup on nodes:\n{exc}")
            all_ok = False

    return all_ok
|
|
|
|
def verify_ip_addresses(nodes_info, up_interfaces) -> bool:
    """Check interface addressing on every node, then ping-test the cluster.

    Phase 1 runs on each node. Every interface in *up_interfaces* must show
    exactly one ``inet`` (IPv4) address in ``ip addr show``, and no address
    may already be in use on another node.
    Phase 2 runs from each node. It pings every collected address that falls
    inside one of that node's own subnets.

    Returns:
        True when both phases pass; otherwise prints the reason and returns False.
    """
    client = None
    try:
        cidrs_by_node = {}
        seen_ips = []

        for node in nodes_info:
            host = node["ip_address"]
            client = create_ssh_client(host, node["port"], node["user"], node["password"])
            if not client.get_transport().is_active():
                raise Exception(f"Could not establish a session to node {host}. Check the credentials and try again.")

            cidrs_by_node[host] = []
            for iface in up_interfaces:
                query = """ip addr show %s | grep -w inet | awk '{print $2}'""" % iface
                status, out, _err = paramiko_run_command_with_output(client, query)
                if status:
                    raise Exception(f"ERROR: Failed to verify IP address on {iface} on node {host}")

                cidrs = out.strip().split("\n")
                if len(cidrs) != 1:
                    raise Exception(f"ERROR: Zero or multiple IP addresses found on node {host}, {iface}: {cidrs}")
                if not cidrs[0]:
                    raise Exception(f"ERROR: No IP address found on node {host}, {iface}")

                # Compare bare addresses (CIDR suffix dropped) for uniqueness.
                bare_ips = [cidr.split("/")[0] for cidr in cidrs]
                if set(seen_ips) & set(bare_ips):
                    raise Exception(f"ERROR: IP address {cidrs} on node {host}, {iface} is already assigned to another node.")

                print(f"IP address on node {host}, {iface}: {cidrs}")
                seen_ips.extend(bare_ips)
                cidrs_by_node[host].extend(cidrs)

            close_ssh_session(client)

        print(f"Running cluster connectivity test...")

        for node in nodes_info:
            host = node["ip_address"]
            client = create_ssh_client(host, node["port"], node["user"], node["password"])
            if not client.get_transport().is_active():
                raise Exception(f"Could not establish a session to node {host}. Check the credentials and try again.")

            subnets = list({ip_network(cidr, strict=False) for cidr in cidrs_by_node[host]})
            for target in seen_ips:
                # Only addresses reachable on one of this node's own subnets are pinged.
                if not any(ip_addr_obj(target) in net for net in subnets):
                    continue
                if paramiko_run_command(client, f"ping -c 1 {target} > /dev/null 2>&1"):
                    raise Exception(f"Failed to run ping test from node {host} to node {target}")

            close_ssh_session(client)

        print(f"Cluster connectivity test completed successfully.")
    except Exception as exc:
        close_ssh_session(client)
        print(f"Failed to verify IP addresses:\n{exc}")
        return False

    return True
|
|
|
|
def configure_ssh_keys_on_nodes(nodes_info) -> bool:
    """Set up the shared SSH key locally and on every node.

    Order of steps:
    1. Generate the shared key if missing.
    2. Trust it in the local authorized_keys.
    3. Install it on each node in turn, stopping at the first node that fails.
    4. Reference it from the local SSH config.

    Returns:
        True when every node was configured, False otherwise.
    """
    print("Generating shared SSH key for all nodes...")
    generate_shared_key()

    print("Setting up shared SSH access across all nodes...")
    add_pubkey_to_authorized_keys()

    def _configure(node):
        print(f"Configuring shared SSH key on node {node['ip_address']}...")
        return configure_node_ssh_keys(node)

    # all() stops at the first node that fails, so later nodes are not touched.
    if not all(_configure(node) for node in nodes_info):
        return False

    update_local_ssh_config()

    print("Shared SSH keys configured successfully.")
    return True
|
|
|
|
def pre_validate_cluster(config) -> tuple[bool, bool, list[str]]:
    """Run the checks that must pass before any cluster setup.

    The checks are:
    - node information is present;
    - matching CX7 interfaces are UP on all nodes;
    - those links run at 200Gbps.

    Returns:
        ``(ok, ring_topology, up_interfaces)``. *ring_topology* is True for
        3 nodes with 4 UP interfaces. Any failure yields ``(False, False, [])``.
    """
    failure = (False, False, [])
    try:
        nodes_info = config.get("nodes_info", None)
        if not nodes_info:
            print("ERROR: Nodes information not found.")
            return failure

        print(f"Checking UP CX7 interfaces...")
        up_interfaces = check_and_get_up_cx7_interfaces(nodes_info)
        if not up_interfaces:
            print("ERROR: Failed to check UP CX7 interfaces. Check the QSFP cable connection and try again.")
            return failure

        print(f"Checking CX7 interface link speed...")
        if not check_interface_link_speed(nodes_info, up_interfaces):
            return failure

        is_ring = len(nodes_info) == 3 and len(up_interfaces) == 4
    except Exception as exc:
        print(f"ERROR: An error occurred when pre-validating the cluster:\n{exc}")
        return failure

    return True, is_ring, up_interfaces
|
|
|
|
def handle_cluster_setup(config, up_interfaces) -> bool:
    """Configure cluster networking and SSH trust on all nodes.

    Steps:
    1. Copy the network setup script to the nodes and run it.
    2. Wait for the addresses to appear, retrying verification up to 5 times
       with waits of 10, 20, 30, 40 and 50 seconds.
    3. Distribute the shared SSH key.

    Returns:
        True on success; False (after printing the reason) otherwise.
    """
    try:
        nodes_info = config.get("nodes_info", None)
        if not nodes_info:
            print("ERROR: Nodes information not found.")
            return False

        print(f"Copying network setup scripts on nodes...")
        if not copy_network_setup_script_to_nodes(nodes_info):
            return False

        print(f"Running network setup scripts on nodes...")
        if not run_network_setup_scripts_on_nodes(nodes_info):
            print("ERROR: Failed to run network setup scripts on nodes. Check the QSFP cable connections and the nodes config in the json file and try again.")
            return False

        # Verify that the IP addresses are assigned to the interfaces,
        # waiting a little longer before each attempt.
        max_retries = 5
        for attempt in range(max_retries):
            wait_secs = (attempt + 1) * 10
            print(f"Waiting for {wait_secs} seconds before checking IP addresses")
            time.sleep(wait_secs)
            if verify_ip_addresses(nodes_info, up_interfaces):
                break
            print(f"ERROR: Failed to verify IP addresses on nodes. ({max_retries - attempt - 1} retries left)...")
        else:
            print("ERROR: Failed to verify IP addresses on nodes. Check the QSFP cable connections and the nodes config in the json file and try again.")
            return False

        # Configure ssh keys across nodes
        if not configure_ssh_keys_on_nodes(nodes_info):
            print("ERROR: Failed to configure ssh keys on nodes. Please check the configuration and try again.")
            return False
    except Exception as exc:
        print(f"ERROR: An error occurred when handling cluster setup:\n{exc}")
        return False

    return True
|
|
|
|
def validate_config(config):
    """Validate the cluster configuration and check SSH and sudo access to every node.

    Each node dict is updated in place with defaults: ``port`` becomes 22 and
    ``password`` becomes "" when missing or empty. One of the configured node
    addresses must belong to the machine running this script.

    Args:
        config: parsed configuration dict with a "nodes_info" list.

    Returns:
        True when the configuration is usable; False after printing the problem.
    """
    nodes_info = config.get("nodes_info", None)
    if not nodes_info:
        print("ERROR: Nodes information not found.")
        return False

    if not 2 <= len(nodes_info) <= 4:
        print("ERROR: Cluster can not contain less than 2 or more than 4 nodes. Please check the configuration and try again.")
        return False

    # IPv4 addresses of this machine, to confirm it is one of the cluster nodes.
    cmd = """ip a | grep -w inet | awk -F"inet |/" '{print $2}' """
    result = subprocess.run(cmd, capture_output=True, text=True, shell=True)
    if result.returncode:
        print(f"ERROR: Failed to check IP addresses on current machine: {result.stderr}")
        return False
    local_ips = result.stdout.strip().split("\n")

    print(f"Checking connectivity and permissions...")
    current_node_in_cluster = False
    for node in nodes_info:
        if not node.get("ip_address", None):
            print("ERROR: IP address not found for node.")
            return False
        if not node.get("user", None):
            print("ERROR: User not found for node.")
            return False
        if not node.get("port", None):
            # Default port is 22
            node["port"] = 22
        if not node.get("password", None):
            # Default password is empty
            node["password"] = ""

        ssh = None
        try:
            try:
                ssh = create_ssh_client(node["ip_address"], node["port"], node["user"], node["password"])
                if not ssh.get_transport().is_active():
                    print(f"ERROR: Could not establish a session to node {node['ip_address']}. Check the credentials and try again.")
                    return False
            except Exception as e:
                print(f"ERROR: Could not establish a session to node {node['ip_address']}, check the credentials in the config file and try again: {e}")
                return False

            # With an empty password the sudo helper uses non-interactive sudo,
            # so this also verifies passwordless sudo access.
            if paramiko_run_sudo_command(ssh, node["password"], "true"):
                print(f"ERROR: Failed to check sudo access on node {node['ip_address']}. If password is not specified then make sure that user has sudo access without password.")
                return False
        finally:
            # Closes the session on every path, including the early returns.
            close_ssh_session(ssh)

        if node["ip_address"] in local_ips:
            current_node_in_cluster = True

    if not current_node_in_cluster:
        print("ERROR: This script must be run on a node in the cluster.")
        return False

    return True
|
|
|
|
def validate_environment():
    """Check that the script runs via its wrapper, inside a venv, on a DGX Spark.

    The checks are:
    - SPARK_CLUSTER_SETUP_WRAPPER must equal "1";
    - a virtual environment must be active;
    - /etc/dgx-release must contain both DGX Spark name markers.

    Returns:
        True if every check passes; otherwise prints the failing check and returns False.
    """
    # The shell wrapper (spark_cluster_setup.sh) is expected to set this variable.
    launched_by_wrapper = os.environ.get("SPARK_CLUSTER_SETUP_WRAPPER") == "1"
    if not launched_by_wrapper:
        print("ERROR: Please run this script via the spark_cluster_setup.sh shell script, not directly.")
        return False

    # Outside a virtual environment sys.prefix and sys.base_prefix are equal.
    in_venv = sys.prefix != sys.base_prefix
    if not in_venv:
        print("ERROR: Please run this script inside a Python virtual environment (venv) with requirements installed.")
        return False

    try:
        with open("/etc/dgx-release", "r") as release_file:
            release_info = release_file.read()
    except FileNotFoundError:
        print("ERROR: /etc/dgx-release not found. This is not a DGX Spark environment.")
        return False

    required_markers = ('DGX_NAME="DGX Spark"', 'DGX_PRETTY_NAME="NVIDIA DGX Spark"')
    if not all(marker in release_info for marker in required_markers):
        print("ERROR: This script must be run on a DGX Spark.")
        return False

    return True
|
|
|
|
|
|
class _HelpHintParser(argparse.ArgumentParser):
    """ArgumentParser whose usage errors end with a pointer to --help.

    Errors exit with status 2 and print only the error line plus the hint
    (no usage block).
    """

    def error(self, message):
        hint = "Run with --help for usage."
        self.exit(status=2, message=f"{self.prog}: error: {message}\n{hint}\n")
|
|
|
|
|
|
def main():
    """Parse the command line and run the requested Spark cluster workflow.

    The environment and configuration are always validated and the cluster
    pre-validated (CX7 link state and speed). After that:
    - --pre-validate-only stops;
    - --run-setup configures networking and SSH, then runs the NCCL test;
    - --run-nccl-test runs only the NCCL test.

    Returns:
        Process exit status: 0 on success, 1 on any failure. argparse itself
        exits with status 2 on command-line errors.
    """
    parser = _HelpHintParser(
        description="Setup the Spark cluster.",
        epilog="One of --pre-validate-only, --run-setup, or --run-nccl-test is required.",
    )
    parser.add_argument("-c", "--config", type=str, required=True, help="Path to the configuration file.")
    parser.add_argument("-v", "--pre-validate-only", action="store_true", help="Only run pre-setup validations.")
    parser.add_argument("-s", "--run-setup", action="store_true", help="Run the cluster setup and run NCCL bandwidth test.")
    parser.add_argument("-n", "--run-nccl-test", action="store_true", help="Run the NCCL bandwidth test.")

    args = parser.parse_args()

    if not (args.pre_validate_only or args.run_setup or args.run_nccl_test):
        parser.error("One of -v/--pre-validate-only, -s/--run-setup, or -n/--run-nccl-test is required.")

    config_path = args.config
    if not os.path.exists(config_path):
        print(f"ERROR: Configuration file not found: {config_path}")
        return 1

    try:
        with open(config_path, "r") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and undecodable file contents.
        print(f"ERROR: Failed to read configuration file {config_path}: {e}")
        return 1

    if not config or not isinstance(config, dict):
        print("ERROR: Configuration file is empty or is not a JSON object.")
        return 1

    try:
        # Validate env
        print(f"Validating environment...")
        if not validate_environment():
            return 1

        # Validate the config
        print(f"Validating configuration...")
        if not validate_config(config):
            return 1

        print(f"Pre-validating cluster setup...")
        ret, ring_topology, up_interfaces = pre_validate_cluster(config)
        if not ret:
            return 1

        if args.pre_validate_only:
            print("Pre-setup validations completed successfully.")
            return 0

        if args.run_setup:
            print("Setting up Spark cluster...")
            if not handle_cluster_setup(config, up_interfaces):
                return 1
            print("Spark cluster setup completed successfully.")

        if args.run_nccl_test or args.run_setup:
            print("Running NCCL test...")
            if ring_topology:
                print("Detected ring topology...")
            if not run_nccl_test(config.get("nodes_info", []), ring_topology):
                return 1
            print("NCCL test completed.")
    except Exception as e:
        print(f"ERROR: An error occurred when running Spark cluster setup:\n{e}")
        return 1

    return 0
|
|
if __name__ == "__main__":
    # Pass main()'s status to the OS so the wrapper script can detect failures
    # (sys.exit(None) exits with status 0).
    sys.exit(main())
|