diff --git a/tools/find_rtkbase/find_rtkbase.py b/tools/find_rtkbase/find_rtkbase.py index a211d064..57fe308f 100644 --- a/tools/find_rtkbase/find_rtkbase.py +++ b/tools/find_rtkbase/find_rtkbase.py @@ -16,6 +16,14 @@ log.setLevel('ERROR') +def format_base_label(base): + display_host = base.get('server') or base.get('fqdn') or base.get('ip') or 'Unknown host' + ip_address = base.get('ip') + if ip_address and display_host != ip_address: + return f"{display_host} ({ip_address})" + return display_host + + class MyApp: def __init__(self, master, ports=[80, 443], allscan=False): self.master = master @@ -123,16 +131,23 @@ def _after_scan_thread(self): self.base_buttons_list = ["base" + str(i) + "Button" for i, j in enumerate(self.available_base)] if len(self.available_base)>0: for i, base in enumerate(self.available_base): - def browser_fqdn(event, ip = (base.get('server') or base.get('ip')), port = base.get('port')): - self.launch_browser(ip, port) - def browser_ip(event, ip = (base.get('ip')), port = base.get('port')): - self.launch_browser(ip, port) + primary_host = scan_network.preferred_access_address(base) + alternate_host = scan_network.alternate_access_address(base) + + def browser_primary(event, host = primary_host, port = base.get('port')): + if host: + self.launch_browser(host, port) + + def browser_alternate(event, host = alternate_host, port = base.get('port')): + if host: + self.launch_browser(host, port) - self.base_labels_list[i] = ttk.Label(self.top_frame, text=f"{base.get('server') or base.get('fqdn')} ({base.get('ip')})") + self.base_labels_list[i] = ttk.Label(self.top_frame, text=format_base_label(base)) self.base_labels_list[i].grid(column=0, row=i) self.base_buttons_list[i] = ttk.Button(self.top_frame, text='Open') - self.base_buttons_list[i].bind("", browser_fqdn) - self.base_buttons_list[i].bind("", browser_ip) + self.base_buttons_list[i].bind("", browser_primary) + if alternate_host and alternate_host != primary_host: + self.base_buttons_list[i].bind("", browser_alternate) self.base_buttons_list[i].grid(column=3, row=i) else: self.nobase_label.grid() diff --git a/tools/find_rtkbase/scan_network.py b/tools/find_rtkbase/scan_network.py index dd69c9d1..1a272ab9 100644 --- a/tools/find_rtkbase/scan_network.py +++ b/tools/find_rtkbase/scan_network.py @@ -11,6 +11,63 @@ log = logging.getLogger(__name__) log.setLevel('ERROR') + +def _normalize_address(address): + if address in (None, '', 'None'): + return None + return address + + +def iter_access_addresses(result): + ip_address = _normalize_address(result.get('ip') or result.get('IP')) + server_name = _normalize_address(result.get('server') or result.get('SERVER') or result.get('fqdn')) + seen = set() + + for address in (ip_address, server_name): + if address and address not in seen: + seen.add(address) + yield address + + +def preferred_access_address(result): + return next(iter_access_addresses(result), None) + + +def alternate_access_address(result): + addresses = list(iter_access_addresses(result)) + if not addresses: + return None + return addresses[1] if len(addresses) > 1 else addresses[0] + +def sort_hosts(hosts_list): + def first_port(host): + if host.get('port') is not None: + return host.get('port') + ports = host.get('PORTS') or [] + return ports[0] if ports else 0 + + def host_sort_key(host): + return ( + (host.get('server') or host.get('SERVER') or host.get('fqdn') or host.get('NAME') or '').casefold(), + host.get('ip') or host.get('IP') or '', + first_port(host), + ) + + return sorted(hosts_list, key=host_sort_key) + + +def iter_probe_addresses(result): + ip_address = _normalize_address(result.get('IP')) + server_name = _normalize_address(result.get('SERVER')) + + if ip_address: + # Direct IP probes are faster and more reliable than mDNS on some VPN setups. + yield ip_address + yield ip_address + if server_name and server_name != ip_address: + yield server_name + + class MyZeroConfListener: def __init__(self): self.services = [] @@ -27,15 +84,19 @@ def zeroconf_scan(name, prot_type, timeout=5): log.debug("Scanning with zeroconf") service_list = [] zeroconf = Zeroconf() - listener = MyZeroConfListener() - browser = ServiceBrowser(zeroconf, prot_type, listener) - time.sleep(timeout) - for service in listener.services: - if name.lower() in service.name.lower(): - service_list.append({'NAME' : service.name, - 'PORTS' : [service.port], - 'SERVER' : service.server.rstrip('.'), - 'IP' : '.'.join(str(byte) for byte in service.addresses[0])}) + try: + listener = MyZeroConfListener() + browser = ServiceBrowser(zeroconf, prot_type, listener) + time.sleep(timeout) + for service in listener.services: + if name.lower() in service.name.lower(): + service_list.append({'NAME' : service.name, + 'PORTS' : [service.port], + 'SERVER' : service.server.rstrip('.'), + 'IP' : '.'.join(str(byte) for byte in service.addresses[0])}) + finally: + zeroconf.close() + service_list = sort_hosts(service_list) log.debug(f"filtered list for {name}") log.debug(service_list) return service_list @@ -104,10 +165,10 @@ def get_rtkbase_infos(host_list): if result.get('PORTS') and len(result.get('PORTS')) > 0: try: for port in result.get('PORTS'): - #try with mDns server name at first, then with the ip address if it fails - for address in (result.get('SERVER'), result.get('IP')): + ans = None + # Prefer direct IP probes before falling back to the advertised mDNS name. + for address in iter_probe_addresses(result): try: - ans = None if address is None: continue log.debug(f"{address}:{port} Api request") @@ -202,6 +263,7 @@ def main(ports, allscan=False, iprange=None): available_rtkbase = get_rtkbase_infos(scan_results) #remove duplicate available_rtkbase = remove_duplicate_hosts(available_rtkbase) + available_rtkbase = sort_hosts(available_rtkbase) log.debug("RTKBase station found: ") log.debug(available_rtkbase) return available_rtkbase @@ -211,4 +273,4 @@ def main(ports, allscan=False, iprange=None): if args.debug: log.setLevel('DEBUG') log.debug(f"Arguments: {args}") - print(main(args.ports, args.allscan, args.iprange)) \ No newline at end of file + print(main(args.ports, args.allscan, args.iprange)) diff --git a/tools/find_rtkbase/test_scan_network.py b/tools/find_rtkbase/test_scan_network.py new file mode 100644 index 00000000..b5094f9f --- /dev/null +++ b/tools/find_rtkbase/test_scan_network.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +import importlib.util +import unittest +from pathlib import Path +from unittest import mock + + +MODULE_PATH = Path(__file__).with_name("scan_network.py") +SPEC = importlib.util.spec_from_file_location("scan_network_under_test", MODULE_PATH) +scan_network = importlib.util.module_from_spec(SPEC) +assert SPEC.loader is not None +SPEC.loader.exec_module(scan_network) + + +class FakeService: + def __init__(self, name, port, server, address_bytes): + self.name = name + self.port = port + self.server = server + self.addresses = [address_bytes] + + +class FakeZeroconf: + def __init__(self, services): + self.services = services + self.closed = False + + def get_service_info(self, service_type, name): + return self.services[name] + + def close(self): + self.closed = True + + +class FakeResponse: + def __init__(self, status_code, payload): + self.status_code = status_code + self._payload = payload + + def json(self): + return self._payload + + +class ScanNetworkTests(unittest.TestCase): + def test_preferred_access_address_prefers_ip_over_server_name(self): + host = {"ip": "10.0.0.5", "server": "alpha.local", "fqdn": "alpha.local"} + + self.assertEqual("10.0.0.5", scan_network.preferred_access_address(host)) + self.assertEqual("alpha.local", scan_network.alternate_access_address(host)) + + def test_access_address_helpers_ignore_placeholder_server_values(self): + host = {"ip": "10.0.0.5", "server": "None"} + + self.assertEqual("10.0.0.5", scan_network.preferred_access_address(host)) + self.assertEqual("10.0.0.5", scan_network.alternate_access_address(host)) + + def test_zeroconf_scan_returns_results_sorted_by_server_name(self): + fake_services = { + "svc-b": FakeService( + "RTKBase Web Server Beta", + 80, + "beta.local.", + bytes([192, 168, 1, 20]), + ), + "svc-a": FakeService( + "RTKBase Web Server Alpha", + 80, + "alpha.local.", + bytes([192, 168, 1, 10]), + ), + } + fake_zeroconf = FakeZeroconf(fake_services) + + def fake_browser(zeroconf, service_type, listener): + listener.add_service(zeroconf, service_type, "svc-b") + listener.add_service(zeroconf, service_type, "svc-a") + return object() + + with ( + mock.patch.object(scan_network, "Zeroconf", return_value=fake_zeroconf), + mock.patch.object(scan_network, "ServiceBrowser", side_effect=fake_browser), + mock.patch.object(scan_network.time, "sleep", return_value=None), + ): + results = scan_network.zeroconf_scan("RTKBase Web Server", "_http._tcp.local.") + + self.assertEqual(["alpha.local", "beta.local"], [item["SERVER"] for item in results]) + self.assertTrue(fake_zeroconf.closed) + + def test_get_rtkbase_infos_retries_ip_before_server_name(self): + calls = [] + + def fake_get(url, timeout): + calls.append(url) + if len(calls) < 3: + raise scan_network.requests.exceptions.ConnectionError("network") + return FakeResponse( + 200, + {"app": "RTKBase", "app_version": "2.0.0", "fqdn": "alpha.local"}, + ) + + host_list = [{"IP": "10.0.0.5", "SERVER": "alpha.local", "PORTS": [80]}] + with mock.patch.object(scan_network.requests, "get", side_effect=fake_get): + results = scan_network.get_rtkbase_infos(host_list) + + self.assertEqual( + [ + "http://10.0.0.5:80/api/v1/infos", + "http://10.0.0.5:80/api/v1/infos", + "http://alpha.local:80/api/v1/infos", + ], + calls, + ) + self.assertEqual("alpha.local", results[0]["server"]) + self.assertEqual("10.0.0.5", results[0]["ip"]) + + def test_get_rtkbase_infos_stops_after_successful_ip_retry(self): + calls = [] + + def fake_get(url, timeout): + calls.append(url) + if len(calls) == 1: + raise scan_network.requests.exceptions.ConnectionError("network") + return FakeResponse( + 200, + {"app": "RTKBase", "app_version": "2.0.0", "fqdn": "alpha.local"}, + ) + + host_list = [{"IP": "10.0.0.5", "SERVER": "alpha.local", "PORTS": [80]}] + with mock.patch.object(scan_network.requests, "get", side_effect=fake_get): + results = scan_network.get_rtkbase_infos(host_list) + + self.assertEqual( + [ + "http://10.0.0.5:80/api/v1/infos", + "http://10.0.0.5:80/api/v1/infos", + ], + calls, + ) + self.assertEqual(1, len(results)) + + +if __name__ == "__main__": + unittest.main()