diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 206e91ed14..58c15d8d8c 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -252,6 +252,8 @@ async def send_transaction( path = _create_v1_path("/send/%s", transaction.transaction_id) + headers: Dict[bytes, List[bytes]] = {b"Accept": [b"application/json"]} + return await self.client.put_json( transaction.destination, path=path, @@ -262,6 +264,7 @@ async def send_transaction( # Sending a transaction should always succeed, if it doesn't # then something is wrong and we should backoff. backoff_on_all_error_codes=True, + headers=headers, ) async def make_query( diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index f6d2536957..2136f2915a 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -532,6 +532,7 @@ async def _send_request_with_optional_trailing_slash( async def _send_request( self, request: MatrixFederationRequest, + *, retry_on_dns_fail: bool = True, timeout: Optional[int] = None, long_retries: bool = False, @@ -539,6 +540,7 @@ async def _send_request( backoff_on_404: bool = False, backoff_on_all_error_codes: bool = False, follow_redirects: bool = False, + headers: Optional[Dict[bytes, List[bytes]]] = None, ) -> IResponse: """ Sends a request to the given server. @@ -584,6 +586,9 @@ async def _send_request( follow_redirects: True to follow the Location header of 307/308 redirect responses. This does not recurse. + headers: Additional Headers to pass to the request. Authorization, + Content-type, User-Agent and Host will be overridden below + Returns: Resolves with the HTTP response object on success. @@ -646,6 +651,8 @@ async def _send_request( # Inject the span into the headers headers_dict: Dict[bytes, List[bytes]] = {} + if headers: + headers_dict.update(headers) opentracing.inject_header_dict(headers_dict, request.destination) headers_dict[b"User-Agent"] = [self.version_string_bytes] @@ -754,14 +761,15 @@ async def _send_request( return await self._send_request( attr.evolve(request, uri=new_uri, generate_uri=False), - retry_on_dns_fail, - timeout, - long_retries, - ignore_backoff, - backoff_on_404, - backoff_on_all_error_codes, + retry_on_dns_fail=retry_on_dns_fail, + timeout=timeout, + long_retries=long_retries, + ignore_backoff=ignore_backoff, + backoff_on_404=backoff_on_404, + backoff_on_all_error_codes=backoff_on_all_error_codes, # Do not continue following redirects. follow_redirects=False, + headers=headers, ) else: logger.info( @@ -942,6 +950,7 @@ async def put_json( try_trailing_slash_on_400: bool = False, parser: Literal[None] = None, backoff_on_all_error_codes: bool = False, + headers: Optional[Dict[bytes, List[bytes]]] = None, ) -> JsonDict: ... @overload @@ -959,6 +968,7 @@ async def put_json( try_trailing_slash_on_400: bool = False, parser: Optional[ByteParser[T]] = None, backoff_on_all_error_codes: bool = False, + headers: Optional[Dict[bytes, List[bytes]]] = None, ) -> T: ... async def put_json( @@ -975,6 +985,7 @@ async def put_json( try_trailing_slash_on_400: bool = False, parser: Optional[ByteParser[T]] = None, backoff_on_all_error_codes: bool = False, + headers: Optional[Dict[bytes, List[bytes]]] = None, ) -> Union[JsonDict, T]: """Sends the specified json data using PUT @@ -1012,6 +1023,9 @@ async def put_json( parsing as JSON. backoff_on_all_error_codes: Back off if we get any error response + headers: Additional Headers to pass to the request. Authorization, + Content-type, User-Agent and Host will be overridden below + Returns: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. @@ -1045,6 +1059,7 @@ async def put_json( long_retries=long_retries, timeout=timeout, backoff_on_all_error_codes=backoff_on_all_error_codes, + headers=headers, ) if timeout is not None: