Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions pyiceberg/table/update/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
Snapshot,
SnapshotSummaryCollector,
Summary,
ancestors_of,
update_snapshot_summaries,
)
from pyiceberg.table.update import (
Expand Down Expand Up @@ -985,6 +986,40 @@ def set_current_snapshot(self, snapshot_id: int | None = None, ref_name: str | N
self._transaction._stage(update, requirement)
return self

def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots:
    """Rollback the table to the given snapshot id.

    The snapshot needs to be an ancestor of the current table state.

    Args:
        snapshot_id (int): rollback to this snapshot_id that used to be current.

    Returns:
        This for method chaining

    Raises:
        ValueError: If the snapshot does not exist or is not an ancestor of the current table state.
    """
    metadata = self._transaction.table_metadata

    # Compare against None explicitly instead of relying on truthiness: only a
    # missing lookup result (snapshot_by_id is expected to return None for an id
    # it does not know) should be reported as an unknown snapshot.
    if metadata.snapshot_by_id(snapshot_id) is None:
        raise ValueError(f"Cannot roll back to unknown snapshot id: {snapshot_id}")

    # A snapshot that exists but is not on the current ancestry chain is rejected
    # with a separate message.
    if not self._is_current_ancestor(snapshot_id):
        raise ValueError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}")

    return self.set_current_snapshot(snapshot_id=snapshot_id)

def _is_current_ancestor(self, snapshot_id: int) -> bool:
    """Return whether ``snapshot_id`` is among the ids reported by ``_current_ancestors``."""
    ancestor_ids = self._current_ancestors()
    return snapshot_id in ancestor_ids

def _current_ancestors(self) -> set[int]:
    """Collect the ids of every snapshot that ``ancestors_of`` yields for the current snapshot.

    Returns:
        The set of snapshot ids produced by walking ``ancestors_of`` from the
        transaction's current snapshot.
    """
    metadata = self._transaction.table_metadata
    lineage = ancestors_of(metadata.current_snapshot(), metadata)
    return {snapshot.snapshot_id for snapshot in lineage}


class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
"""Expire snapshots by ID.
Expand Down
108 changes: 108 additions & 0 deletions tests/integration/test_snapshot_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,44 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import uuid
from collections.abc import Generator

import pyarrow as pa
import pytest

from pyiceberg.catalog import Catalog
from pyiceberg.table import Table
from pyiceberg.table.refs import SnapshotRef


@pytest.fixture
def table_with_snapshots(session_catalog: Catalog) -> Generator[Table, None, None]:
    """Yield a uniquely named table in the ``default`` namespace with two batches appended.

    The table is created with an ``id`` (required int64) and ``data`` (optional
    string) schema, receives two appends of two rows each, and is reloaded from
    the catalog before being handed to the test. It is dropped again afterwards.

    Args:
        session_catalog: Catalog in which the table is created and later dropped.

    Yields:
        The reloaded table.
    """
    session_catalog.create_namespace_if_not_exists("default")
    # A random suffix keeps tables from different tests apart.
    identifier = f"default.test_table_snapshot_ops_{uuid.uuid4().hex[:8]}"

    arrow_schema = pa.schema(
        [
            pa.field("id", pa.int64(), nullable=False),
            pa.field("data", pa.string(), nullable=True),
        ]
    )

    tbl = session_catalog.create_table(identifier=identifier, schema=arrow_schema)

    # From here on the table exists in the catalog, so make sure it is dropped
    # even when an append or the reload fails before the test gets to run.
    try:
        data1 = pa.Table.from_pylist([{"id": 1, "data": "a"}, {"id": 2, "data": "b"}], schema=arrow_schema)
        tbl.append(data1)

        data2 = pa.Table.from_pylist([{"id": 3, "data": "c"}, {"id": 4, "data": "d"}], schema=arrow_schema)
        tbl.append(data2)

        tbl = session_catalog.load_table(identifier)

        yield tbl
    finally:
        session_catalog.drop_table(identifier)


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_create_tag(catalog: Catalog) -> None:
Expand Down Expand Up @@ -160,3 +192,79 @@ def test_set_current_snapshot_chained_with_create_tag(catalog: Catalog) -> None:
tbl = catalog.load_table(identifier)
tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()
assert tbl.metadata.refs.get(tag_name, None) is None


@pytest.mark.integration
def test_rollback_to_snapshot(table_with_snapshots: Table) -> None:
    """Rolling back to the second-to-last history entry makes it the current snapshot."""
    log = table_with_snapshots.history()
    assert len(log) >= 2

    # The second-to-last history entry is the rollback target.
    target_id = log[-2].snapshot_id
    manager = table_with_snapshots.manage_snapshots()
    manager.rollback_to_snapshot(snapshot_id=target_id).commit()

    current = table_with_snapshots.current_snapshot()
    assert current is not None
    assert current.snapshot_id == target_id


@pytest.mark.integration
def test_rollback_to_current_snapshot(table_with_snapshots: Table) -> None:
    """Rolling back to the snapshot that is already current leaves it current."""
    before = table_with_snapshots.current_snapshot()
    assert before is not None
    current_id = before.snapshot_id

    table_with_snapshots.manage_snapshots().rollback_to_snapshot(snapshot_id=current_id).commit()

    after = table_with_snapshots.current_snapshot()
    assert after is not None
    assert after.snapshot_id == current_id


@pytest.mark.integration
def test_rollback_to_snapshot_chained_with_tag(table_with_snapshots: Table) -> None:
    """A tag creation and a rollback chained on one manager are both visible after commit."""
    log = table_with_snapshots.history()
    assert len(log) >= 2

    target_id = log[-2].snapshot_id
    tag_name = "my-tag"

    # Stage the tag and the rollback on the same manager, then commit once.
    manager = table_with_snapshots.manage_snapshots()
    manager = manager.create_tag(snapshot_id=target_id, tag_name=tag_name)
    manager = manager.rollback_to_snapshot(snapshot_id=target_id)
    manager.commit()

    current = table_with_snapshots.current_snapshot()
    assert current is not None
    assert current.snapshot_id == target_id

    expected_ref = SnapshotRef(snapshot_id=target_id, snapshot_ref_type="tag")
    assert table_with_snapshots.metadata.refs[tag_name] == expected_ref


@pytest.mark.integration
def test_rollback_to_snapshot_not_ancestor(table_with_snapshots: Table) -> None:
    """Rolling back to the head of a separately appended branch raises ``ValueError``."""
    log = table_with_snapshots.history()
    assert len(log) >= 2
    fork_point_id = log[-2].snapshot_id

    # Create a branch at the second-to-last history entry and append one row to it.
    branch_name = "my-branch"
    table_with_snapshots.manage_snapshots().create_branch(snapshot_id=fork_point_id, branch_name=branch_name).commit()
    extra_rows = pa.Table.from_pylist([{"id": 5, "data": "e"}], schema=table_with_snapshots.schema().as_arrow())
    table_with_snapshots.append(extra_rows, branch=branch_name)

    # The branch head must be a snapshot distinct from the fork point.
    branch_head = table_with_snapshots.metadata.snapshot_by_name(branch_name)
    assert branch_head is not None
    assert branch_head.snapshot_id != fork_point_id

    with pytest.raises(ValueError, match="not an ancestor"):
        table_with_snapshots.manage_snapshots().rollback_to_snapshot(snapshot_id=branch_head.snapshot_id).commit()


@pytest.mark.integration
def test_rollback_to_snapshot_unknown_id(table_with_snapshots: Table) -> None:
    """Rolling back to the hard-coded id below is expected to raise the unknown-id ``ValueError``."""
    missing_id = 1234567890000
    expected_message = "Cannot roll back to unknown snapshot id"

    with pytest.raises(ValueError, match=expected_message):
        table_with_snapshots.manage_snapshots().rollback_to_snapshot(snapshot_id=missing_id).commit()