◐ Shell
clean mode source ↗

feat: add reconfirm_record per RFC 6762 §10.4 by bluetoothbot · Pull Request #1772 · python-zeroconf/python-zeroconf

Expand Up @@ -22,16 +22,26 @@
from __future__ import annotations
import asyncio from functools import partial from typing import TYPE_CHECKING, cast
from .._cache import _UniqueRecordsType from .._dns import DNSQuestion, DNSRecord from .._logger import log from .._protocol.incoming import DNSIncoming from .._protocol.outgoing import DNSOutgoing from .._record_update import RecordUpdate from .._updates import RecordUpdateListener from .._utils.time import current_time_millis from ..const import _ADDRESS_RECORD_TYPES, _DNS_PTR_MIN_TTL, _TYPE_PTR from .._utils.time import current_time_millis, millis_to_seconds from ..const import ( _ADDRESS_RECORD_TYPES, _DNS_PTR_MIN_TTL, _FLAGS_QR_QUERY, _RECONFIRM_QUERY_INTERVALS_MS, _RECONFIRM_TIMEOUT_MS, _TYPE_PTR, )
if TYPE_CHECKING: from .._core import Zeroconf Expand All @@ -42,13 +52,17 @@ class RecordManager: """Process records into the cache and notify listeners."""
__slots__ = ("cache", "listeners", "zc") __slots__ = ("_reconfirm_tasks", "cache", "listeners", "zc")
def __init__(self, zeroconf: Zeroconf) -> None: """Init the record manager.""" self.zc = zeroconf self.cache = zeroconf.cache self.listeners: set[RecordUpdateListener] = set() # Active per-record reconfirmations. Keyed by the cache entry # so that repeated calls for the same record while one is in # flight are no-ops (RFC 6762 §10.4). self._reconfirm_tasks: dict[DNSRecord, asyncio.Task] = {}
def async_updates(self, now: _float, records: list[RecordUpdate]) -> None: """Used to notify listeners of new information that has updated Expand Down Expand Up @@ -219,3 +233,83 @@ def async_remove_listener(self, listener: RecordUpdateListener) -> None: self.zc.async_notify_all() except ValueError as e: log.exception("Failed to remove listener: %r", e)
def async_reconfirm_record(self, record: DNSRecord) -> bool: """Schedule RFC 6762 §10.4 reconfirmation for ``record``.""" cached = self.cache.get(record) if cached is None: return False if cached in self._reconfirm_tasks: return False loop = self.zc.loop if loop is None: return False task = loop.create_task(self._async_reconfirm(cached)) self._reconfirm_tasks[cached] = task task.add_done_callback(partial(self._reconfirm_done, cached)) return True
def _reconfirm_done(self, record: DNSRecord, _task: asyncio.Task) -> None: """Drop ``record`` from the active reconfirmation set.""" self._reconfirm_tasks.pop(record, None)
async def _async_reconfirm(self, record: DNSRecord) -> None: """Re-query ``record`` and flush from cache if not refreshed.
RFC 6762 §10.4: send two or more queries, then flush the record if no response arrives within ten seconds. """ start = current_time_millis() original_created = record.created zc = self.zc question = DNSQuestion(record.name, record.type, record.class_)
prev_delay_ms = 0 for delay_ms in _RECONFIRM_QUERY_INTERVALS_MS: wait_ms = delay_ms - prev_delay_ms if wait_ms > 0: await asyncio.sleep(millis_to_seconds(wait_ms)) prev_delay_ms = delay_ms if zc.done: return if self._record_refreshed_since(record, original_created): return out = DNSOutgoing(_FLAGS_QR_QUERY) out.add_question(question) zc.async_send(out)
remaining_ms = _RECONFIRM_TIMEOUT_MS - prev_delay_ms if remaining_ms > 0: await asyncio.sleep(millis_to_seconds(remaining_ms)) if zc.done: return if self._record_refreshed_since(record, original_created): return
now = current_time_millis() elapsed_secs = max(0, int((now - start) / 1000)) log.debug( "Reconfirmation of %s timed out after %ds; flushing from cache", record, elapsed_secs, ) cached = self.cache.get(record) if cached is None: return # Mark expired so listeners interpret this as a goodbye when # they re-check ``is_expired(now)`` from inside # ``async_update_records``. Mirrors the goodbye path in # ``async_updates_from_response``. cached._set_created_ttl(now - 1000, 0) update = RecordUpdate.__new__(RecordUpdate) update._fast_init(cached, cached) self.async_updates(now, [update]) self.cache.async_remove_records([cached]) self.async_updates_complete(True)
def _record_refreshed_since(self, record: DNSRecord, original_created: float) -> bool: """Return True if the cache holds a newer copy of ``record``.""" cached = self.cache.get(record) if cached is None: return True return cached.created > original_created