diff --git a/livekit-rtc/livekit/rtc/_utils.py b/livekit-rtc/livekit/rtc/_utils.py index 2342b21a..59eaa0c5 100644 --- a/livekit-rtc/livekit/rtc/_utils.py +++ b/livekit-rtc/livekit/rtc/_utils.py @@ -40,8 +40,18 @@ def task_done_logger(task: asyncio.Task) -> None: return -def get_address(mv: memoryview) -> int: - return ctypes.addressof(ctypes.c_char.from_buffer(mv)) +def get_address(data) -> int: + if isinstance(data, memoryview): + if not data.readonly: + return ctypes.addressof(ctypes.c_char.from_buffer(data)) + data = data.obj + if isinstance(data, bytearray): + return ctypes.addressof(ctypes.c_char.from_buffer(data)) + if isinstance(data, bytes): + addr = ctypes.cast(ctypes.c_char_p(data), ctypes.c_void_p).value + assert addr is not None + return addr + raise TypeError(f"expected bytes, bytearray, or memoryview, got {type(data)}") T = TypeVar("T") diff --git a/livekit-rtc/livekit/rtc/audio_frame.py b/livekit-rtc/livekit/rtc/audio_frame.py index c4f1943c..1461d978 100644 --- a/livekit-rtc/livekit/rtc/audio_frame.py +++ b/livekit-rtc/livekit/rtc/audio_frame.py @@ -49,19 +49,22 @@ def __init__( Raises: ValueError: If the length of `data` is smaller than the required size. """ - data = memoryview(data).cast("B") + if isinstance(data, memoryview): + data = data.obj # type: ignore[assignment] - if len(data) < num_channels * samples_per_channel * ctypes.sizeof(ctypes.c_int16): + min_size = num_channels * samples_per_channel * ctypes.sizeof(ctypes.c_int16) + data_len = len(data) + + if data_len < min_size: raise ValueError( "data length must be >= num_channels * samples_per_channel * sizeof(int16)" ) - if len(data) % ctypes.sizeof(ctypes.c_int16) != 0: + if data_len % ctypes.sizeof(ctypes.c_int16) != 0: # can happen if data is bigger than needed raise ValueError("data length must be a multiple of sizeof(int16)") - n = len(data) // ctypes.sizeof(ctypes.c_int16) - self._data = (ctypes.c_int16 * n).from_buffer_copy(data) + self._data = data self._sample_rate = sample_rate self._num_channels = num_channels @@ -97,7 +100,7 @@ def _from_owned_info(owned_info: proto_audio.OwnedAudioFrameBuffer) -> "AudioFra def _proto_info(self) -> proto_audio.AudioFrameBufferInfo: audio_info = proto_audio.AudioFrameBufferInfo() - audio_info.data_ptr = get_address(memoryview(self._data)) + audio_info.data_ptr = get_address(self._data) audio_info.sample_rate = self.sample_rate audio_info.num_channels = self.num_channels audio_info.samples_per_channel = self.samples_per_channel diff --git a/livekit-rtc/livekit/rtc/video_frame.py b/livekit-rtc/livekit/rtc/video_frame.py index 76b66625..3d9f64ba 100644 --- a/livekit-rtc/livekit/rtc/video_frame.py +++ b/livekit-rtc/livekit/rtc/video_frame.py @@ -51,7 +51,7 @@ def __init__( self._width = width self._height = height self._type = type - self._data = bytearray(data) + self._data = data @property def width(self) -> int: