preprocessing.to_model
format_xarray_for_rnn
def format_xarray_for_rnn(
    ds: xr.Dataset,
    read_from_variable: str = "position_processed",
    keypoints: list[str] | None = None,
    extra_features: list[str] | None = None
) -> Tuple[np.ndarray, Dict[str, Any]]
Formats the xarray dataset for use in VAME's RNN model:
- The x and y coordinates of the centered_reference_keypoint are excluded.
- The x coordinate of the orientation_reference_keypoint is excluded.
- The remaining data is flattened and transposed.
- Optionally, pre-computed scalar features (speed, head direction, ...) are appended as additional rows after the pose-derived features.
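The steps above can be illustrated with a small NumPy sketch. It is not VAME's implementation: the helper name `sketch_format_for_rnn`, the input layout `(n_samples, n_keypoints, 2)` for a single individual, the keypoint names, and the `"<keypoint>_<axis>"` column naming are all assumptions made for illustration.

```python
import numpy as np


def sketch_format_for_rnn(pose, keypoints, centered_kp, orientation_kp, extra=None):
    """Illustrative only. `pose` is assumed to have shape
    (n_samples, n_keypoints, 2), with (x, y) on the last axis."""
    n_samples = pose.shape[0]
    # Flatten to (n_samples, n_keypoints * 2); columns run kp0_x, kp0_y, kp1_x, ...
    flat = pose.reshape(n_samples, -1)
    columns = [f"{kp}_{axis}" for kp in keypoints for axis in ("x", "y")]

    # Exclude x and y of the centered keypoint and x of the orientation keypoint.
    excluded = {f"{centered_kp}_x", f"{centered_kp}_y", f"{orientation_kp}_x"}
    keep = [i for i, name in enumerate(columns) if name not in excluded]

    # Transpose so that rows are features and columns are samples.
    features = flat[:, keep].T
    names = [columns[i] for i in keep]

    # Append each extra scalar feature as one additional row.
    for name, values in (extra or {}).items():
        row = np.asarray(values, dtype=float)[None, :]
        features = np.vstack([features, row])
        names.append(name)
    return features, names


# Toy data: 10 samples, 4 hypothetical keypoints, one extra feature.
keypoints = ["nose", "neck", "tail", "paw"]
pose = np.arange(10 * 4 * 2, dtype=float).reshape(10, 4, 2)
speed = np.linspace(0.0, 1.0, 10)
features, names = sketch_format_for_rnn(
    pose, keypoints, centered_kp="neck", orientation_kp="nose", extra={"speed": speed}
)
print(features.shape, names)
```

In this toy layout the extra feature ends up as the last row, after the pose-derived rows, which mirrors the ordering described in the list above.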
Parameters
- ds (xr.Dataset): The xarray dataset to format.
- read_from_variable (str, default="position_processed"): The variable to read from the dataset.
- keypoints (list[str] | None, optional): A list of keypoints to include in the output. If None, all keypoints are included. If provided, only the specified keypoints will be included in the output.
- extra_features (list[str] | None, optional): Names of pre-computed scalar features to append after the pose-derived features. Each must exist as a data variable in ds with dims ("time",) or ("time", "individuals"). VAME applies no preprocessing to these; the user is responsible for alignment, scaling, and NaN handling. See vame.validate_extra_features.
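Because extra features are passed through without preprocessing, one way to prepare a feature beforehand might look like the sketch below. The helper `prepare_extra_feature` is hypothetical and not part of VAME; linear interpolation and z-scoring are just one possible choice for NaN handling and scaling, and the length check stands in for alignment with the time dimension.

```python
import numpy as np


def prepare_extra_feature(values, n_samples):
    """Hypothetical helper: check length, fill NaNs by linear
    interpolation, then z-score the result."""
    values = np.asarray(values, dtype=float)
    # Alignment check: one value per time step is expected.
    if values.shape != (n_samples,):
        raise ValueError(f"expected shape ({n_samples},), got {values.shape}")

    missing = np.isnan(values)
    if missing.all():
        raise ValueError("feature contains only NaN values")

    idx = np.arange(n_samples)
    # np.interp holds the nearest valid value constant beyond the edges.
    filled = np.interp(idx, idx[~missing], values[~missing])

    # Scale to zero mean and unit variance (skip division for constant input).
    centered = filled - filled.mean()
    std = filled.std()
    return centered / std if std > 0 else centered


speed = np.array([np.nan, 1.0, 2.0, np.nan, 4.0, 5.0])
speed_clean = prepare_extra_feature(speed, n_samples=6)
```

The cleaned array could then be stored on the dataset as a data variable with dims ("time",) and its name passed in extra_features.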
Returns
Tuple[np.ndarray, Dict[str, Any]]: A tuple containing:
- The formatted array in the shape (n_features, n_samples)
- A dictionary with feature provenance and processing information
Where n_features = 2 * n_keypoints * n_spaces - 3 + len(extra_features or []).