diff --git a/cassandra/__init__.py b/cassandra/__init__.py index 88fbb11f88..d772814643 100644 --- a/cassandra/__init__.py +++ b/cassandra/__init__.py @@ -136,13 +136,6 @@ class ProtocolVersion(object): Defines native protocol versions supported by this driver. """ - V3 = 3 - """ - v3, supported in Cassandra 2.1-->3.x+; - added support for protocol-level client-side timestamps (see :attr:`.Session.use_client_timestamp`), - serial consistency levels for :class:`~.BatchStatement`, and an improved connection pool. - """ - V4 = 4 """ v4, supported in Cassandra 2.2-->3.x+; @@ -170,9 +163,9 @@ class ProtocolVersion(object): DSE private protocol v2, supported in DSE 6.0+ """ - SUPPORTED_VERSIONS = (V5, V4, V3) + SUPPORTED_VERSIONS = (V5, V4) """ - A tuple of all supported protocol versions for ScyllaDB, including future v5 version. + A tuple of all supported protocol versions for ScyllaDB. """ BETA_VERSIONS = (V6,) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 099043eae0..3e19979789 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -735,12 +735,7 @@ def auth_provider(self, value): try: self._auth_provider_callable = value.new_authenticator except AttributeError: - if self.protocol_version > 1: - raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider " - "interface when protocol_version >= 2") - elif not callable(value): - raise TypeError("auth_provider must be callable when protocol_version == 1") - self._auth_provider_callable = value + raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider interface") self._auth_provider = value @@ -1557,7 +1552,7 @@ def register_user_type(self, keyspace, user_type, klass): Example:: - cluster = Cluster(protocol_version=3) + cluster = Cluster() session = cluster.connect() session.set_keyspace('mykeyspace') session.execute("CREATE TYPE address (street text, zipcode int)") @@ -1582,11 +1577,6 @@ def __init__(self, street, zipcode): print(row.id, row.location.street, row.location.zipcode) """ - if self.protocol_version < 3: - log.warning("User Type serialization is only supported in native protocol version 3+ (%d in use). " - "CQL encoding for simple statements will still work, but named tuples will " - "be returned when reading type %s.%s.", self.protocol_version, keyspace, user_type) - self._user_types[keyspace][user_type] = klass for session in tuple(self.sessions): session.user_type_registered(keyspace, user_type, klass) @@ -2442,8 +2432,6 @@ def default_serial_consistency_level(self): The default :class:`~ConsistencyLevel` for serial phase of conditional updates executed through this session. This default may be overridden by setting the :attr:`~.Statement.serial_consistency_level` on individual statements. - - Only valid for ``protocol_version >= 2``. """ return self._default_serial_consistency_level @@ -2954,11 +2942,6 @@ def _create_response_future(self, query, parameters, trace, custom_payload, continuous_paging_options=continuous_paging_options, result_metadata_id=prepared_statement.result_metadata_id) elif isinstance(query, BatchStatement): - if self._protocol_version < 2: - raise UnsupportedOperation( - "BatchStatement execution is only supported with protocol version " - "2 or higher (supported in Cassandra 2.0 and higher). 
Consider " - "setting Cluster.protocol_version to 2 to support this operation.") statement_keyspace = query.keyspace if ProtocolVersion.uses_keyspace_flag(self._protocol_version) else None message = BatchMessage( query.batch_type, query._statements_and_parameters, cl, @@ -3097,7 +3080,7 @@ def prepare(self, query, custom_payload=None, keyspace=None): prepared_keyspace = keyspace if keyspace else None prepared_statement = PreparedStatement.from_message( response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace, - self._protocol_version, response.column_metadata, response.result_metadata_id, response.is_lwt, self.cluster.column_encryption_policy) + response.column_metadata, response.result_metadata_id, response.is_lwt, self.cluster.column_encryption_policy) prepared_statement.custom_payload = future.custom_payload self.cluster.add_prepared(response.query_id, prepared_statement) @@ -4637,10 +4620,9 @@ def _set_result(self, host, connection, pool, response): self._custom_payload = getattr(response, 'custom_payload', None) if self._custom_payload and self.session.cluster.control_connection._tablets_routing_v1 and 'tablets-routing-v1' in self._custom_payload: - protocol = self.session.cluster.protocol_version info = self._custom_payload.get('tablets-routing-v1') ctype = types.lookup_casstype('TupleType(LongType, LongType, ListType(TupleType(UUIDType, Int32Type)))') - tablet_routing_info = ctype.from_binary(info, protocol) + tablet_routing_info = ctype.from_binary(info) first_token = tablet_routing_info[0] last_token = tablet_routing_info[1] tablet_replicas = tablet_routing_info[2] diff --git a/cassandra/cqlengine/models.py b/cassandra/cqlengine/models.py index bc00001666..7cdb5ddd3f 100644 --- a/cassandra/cqlengine/models.py +++ b/cassandra/cqlengine/models.py @@ -422,8 +422,8 @@ def __str__(self): ', '.join('{0}={1}'.format(k, getattr(self, k)) for k in self._primary_keys.keys())) @classmethod - def _routing_key_from_values(cls, pk_values, protocol_version): - return cls._key_serializer(pk_values, protocol_version) + def _routing_key_from_values(cls, pk_values): + return cls._key_serializer(pk_values) @classmethod def _discover_polymorphic_submodels(cls): @@ -948,10 +948,10 @@ def _transform_column(col_name, col_obj): key_cols = [c for c in partition_keys.values()] partition_key_index = dict((col.db_field_name, col._partition_key_index) for col in key_cols) key_cql_types = [c.cql_type for c in key_cols] - key_serializer = staticmethod(lambda parts, proto_version: [t.to_binary(p, proto_version) for t, p in zip(key_cql_types, parts)]) + key_serializer = staticmethod(lambda parts: [t.to_binary(p) for t, p in zip(key_cql_types, parts)]) else: partition_key_index = {} - key_serializer = staticmethod(lambda parts, proto_version: None) + key_serializer = staticmethod(lambda parts: None) # setup partition key shortcut if len(partition_keys) == 0: diff --git a/cassandra/cqlengine/query.py b/cassandra/cqlengine/query.py index afc7ceeef6..5eb4be9166 100644 --- a/cassandra/cqlengine/query.py +++ b/cassandra/cqlengine/query.py @@ -1523,7 +1523,7 @@ def _execute_statement(model, statement, consistency_level, timeout, connection= if model._partition_key_index: key_values = statement.partition_key_values(model._partition_key_index) if not any(v is None for v in key_values): - parts = model._routing_key_from_values(key_values, conn.get_cluster(connection).protocol_version) + parts = model._routing_key_from_values(key_values) s.routing_key = parts s.keyspace 
= model._get_keyspace() connection = connection or model._get_connection() diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index e36c48563c..91d5ec19af 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -291,7 +291,7 @@ def __repr__(self): return '<%s>' % (self.cql_parameterized_type()) @classmethod - def from_binary(cls, byts, protocol_version): + def from_binary(cls, byts): """ Deserialize a bytestring into a value. See the deserialize() method for more information. This method differs in that if None or the empty @@ -301,19 +301,19 @@ def from_binary(cls, byts, protocol_version): return None elif len(byts) == 0 and not cls.empty_binary_ok: return EMPTY if cls.support_empty_values else None - return cls.deserialize(byts, protocol_version) + return cls.deserialize(byts) @classmethod - def to_binary(cls, val, protocol_version): + def to_binary(cls, val): """ Serialize a value into a bytestring. See the serialize() method for more information. This method differs in that if None is passed in, the result is the empty string. """ - return b'' if val is None else cls.serialize(val, protocol_version) + return b'' if val is None else cls.serialize(val) @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): """ Given a bytestring, deserialize into a value according to the protocol for this type. Note that this does not create a new instance of this @@ -323,7 +323,7 @@ def deserialize(byts, protocol_version): return byts @staticmethod - def serialize(val, protocol_version): + def serialize(val): """ Given a value appropriate for this class, serialize it according to the protocol for this type and return the corresponding bytestring. @@ -416,7 +416,7 @@ class BytesType(_CassandraType): empty_binary_ok = True @staticmethod - def serialize(val, protocol_version): + def serialize(val): return bytes(val) @@ -424,13 +424,13 @@ class DecimalType(_CassandraType): typename = 'decimal' @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): scale = int32_unpack(byts[:4]) unscaled = varint_unpack(byts[4:]) return Decimal('%de%d' % (unscaled, -scale)) @staticmethod - def serialize(dec, protocol_version): + def serialize(dec): try: sign, digits, exponent = dec.as_tuple() except AttributeError: @@ -450,11 +450,11 @@ class UUIDType(_CassandraType): typename = 'uuid' @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): return UUID(bytes=byts) @staticmethod - def serialize(uuid, protocol_version): + def serialize(uuid): try: return uuid.bytes except AttributeError: @@ -468,11 +468,11 @@ class BooleanType(_CassandraType): typename = 'boolean' @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): return bool(int8_unpack(byts)) @staticmethod - def serialize(truth, protocol_version): + def serialize(truth): return int8_pack(truth) @classmethod @@ -483,11 +483,11 @@ class ByteType(_CassandraType): typename = 'tinyint' @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): return int8_unpack(byts) @staticmethod - def serialize(byts, protocol_version): + def serialize(byts): return int8_pack(byts) @@ -496,11 +496,11 @@ class AsciiType(_CassandraType): empty_binary_ok = True @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): return byts.decode('ascii') @staticmethod - def serialize(var, protocol_version): + def serialize(var): try: return var.encode('ascii') except UnicodeDecodeError: @@ -511,11 +511,11 @@ class 
FloatType(_CassandraType): typename = 'float' @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): return float_unpack(byts) @staticmethod - def serialize(byts, protocol_version): + def serialize(byts): return float_pack(byts) @classmethod @@ -526,11 +526,11 @@ class DoubleType(_CassandraType): typename = 'double' @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): return double_unpack(byts) @staticmethod - def serialize(byts, protocol_version): + def serialize(byts): return double_pack(byts) @classmethod @@ -541,11 +541,11 @@ class LongType(_CassandraType): typename = 'bigint' @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): return int64_unpack(byts) @staticmethod - def serialize(byts, protocol_version): + def serialize(byts): return int64_pack(byts) @classmethod @@ -556,11 +556,11 @@ class Int32Type(_CassandraType): typename = 'int' @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): return int32_unpack(byts) @staticmethod - def serialize(byts, protocol_version): + def serialize(byts): return int32_pack(byts) @classmethod @@ -571,11 +571,11 @@ class IntegerType(_CassandraType): typename = 'varint' @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): return varint_unpack(byts) @staticmethod - def serialize(byts, protocol_version): + def serialize(byts): return varint_pack(byts) @@ -583,7 +583,7 @@ class InetAddressType(_CassandraType): typename = 'inet' @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): if len(byts) == 16: return util.inet_ntop(socket.AF_INET6, byts) else: @@ -592,7 +592,7 @@ def deserialize(byts, protocol_version): return socket.inet_ntoa(byts) @staticmethod - def serialize(addr, protocol_version): + def serialize(addr): try: if ':' in addr: return util.inet_pton(socket.AF_INET6, addr) @@ -641,12 +641,12 @@ def interpret_datestring(val): raise ValueError("can't interpret %r as a date" % (val,)) @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): timestamp = int64_unpack(byts) / 1000.0 return util.datetime_from_timestamp(timestamp) @staticmethod - def serialize(v, protocol_version): + def serialize(v): try: # v is datetime timestamp_seconds = calendar.timegm(v.utctimetuple()) @@ -677,11 +677,11 @@ def my_timestamp(self): return util.unix_time_from_uuid1(self.val) @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): return UUID(bytes=byts) @staticmethod - def serialize(timeuuid, protocol_version): + def serialize(timeuuid): try: return timeuuid.bytes except AttributeError: @@ -701,12 +701,12 @@ class SimpleDateType(_CassandraType): EPOCH_OFFSET_DAYS = 2 ** 31 @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): days = uint32_unpack(byts) - SimpleDateType.EPOCH_OFFSET_DAYS return util.Date(days) @staticmethod - def serialize(val, protocol_version): + def serialize(val): try: days = val.days_from_epoch except AttributeError: @@ -723,11 +723,11 @@ class ShortType(_CassandraType): typename = 'smallint' @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): return int16_unpack(byts) @staticmethod - def serialize(byts, protocol_version): + def serialize(byts): return int16_pack(byts) class TimeType(_CassandraType): @@ -740,11 +740,11 @@ class TimeType(_CassandraType): # return 8 @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): return 
util.Time(int64_unpack(byts)) @staticmethod - def serialize(val, protocol_version): + def serialize(val): try: nano = val.nanosecond_time except AttributeError: @@ -756,12 +756,12 @@ class DurationType(_CassandraType): typename = 'duration' @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): months, days, nanoseconds = vints_unpack(byts) return util.Duration(months, days, nanoseconds) @staticmethod - def serialize(duration, protocol_version): + def serialize(duration): try: m, d, n = duration.months, duration.days, duration.nanoseconds except AttributeError: @@ -774,11 +774,11 @@ class UTF8Type(_CassandraType): empty_binary_ok = True @staticmethod - def deserialize(byts, protocol_version): + def deserialize(byts): return byts.decode('utf8') @staticmethod - def serialize(ustr, protocol_version): + def serialize(ustr): try: return ustr.encode('utf-8') except UnicodeDecodeError: @@ -794,29 +794,28 @@ class _ParameterizedType(_CassandraType): num_subtypes = 'UNKNOWN' @classmethod - def deserialize(cls, byts, protocol_version): + def deserialize(cls, byts): if not cls.subtypes: raise NotImplementedError("can't deserialize unparameterized %s" % cls.typename) - return cls.deserialize_safe(byts, protocol_version) + return cls.deserialize_safe(byts) @classmethod - def serialize(cls, val, protocol_version): + def serialize(cls, val): if not cls.subtypes: raise NotImplementedError("can't serialize unparameterized %s" % cls.typename) - return cls.serialize_safe(val, protocol_version) + return cls.serialize_safe(val) class _SimpleParameterizedType(_ParameterizedType): @classmethod - def deserialize_safe(cls, byts, protocol_version): + def deserialize_safe(cls, byts): subtype, = cls.subtypes length = 4 numelements = int32_unpack(byts[:length]) p = length result = [] - inner_proto = max(3, protocol_version) for _ in range(numelements): itemlen = int32_unpack(byts[p:p + length]) p += length @@ -825,23 +824,22 @@ def deserialize_safe(cls, byts, protocol_version): else: item = byts[p:p + itemlen] p += itemlen - result.append(subtype.from_binary(item, inner_proto)) + result.append(subtype.from_binary(item)) return cls.adapter(result) @classmethod - def serialize_safe(cls, items, protocol_version): + def serialize_safe(cls, items): if isinstance(items, str): raise TypeError("Received a string for a type that expects a sequence") subtype, = cls.subtypes buf = io.BytesIO() buf.write(int32_pack(len(items))) - inner_proto = max(3, protocol_version) for item in items: if item is None: buf.write(int32_pack(-1)) else: - itembytes = subtype.to_binary(item, inner_proto) + itembytes = subtype.to_binary(item) buf.write(int32_pack(len(itembytes))) buf.write(itembytes) return buf.getvalue() @@ -864,13 +862,12 @@ class MapType(_ParameterizedType): num_subtypes = 2 @classmethod - def deserialize_safe(cls, byts, protocol_version): + def deserialize_safe(cls, byts): key_type, value_type = cls.subtypes length = 4 numelements = int32_unpack(byts[:length]) p = length - themap = util.OrderedMapSerializedKey(key_type, protocol_version) - inner_proto = max(3, protocol_version) + themap = util.OrderedMapSerializedKey(key_type) for _ in range(numelements): key_len = int32_unpack(byts[p:p + length]) p += length @@ -880,7 +877,7 @@ def deserialize_safe(cls, byts, protocol_version): else: keybytes = byts[p:p + key_len] p += key_len - key = key_type.from_binary(keybytes, inner_proto) + key = key_type.from_binary(keybytes) val_len = int32_unpack(byts[p:p + length]) p += length @@ -889,13 +886,13 @@ def 
deserialize_safe(cls, byts, protocol_version): else: valbytes = byts[p:p + val_len] p += val_len - val = value_type.from_binary(valbytes, inner_proto) + val = value_type.from_binary(valbytes) themap._insert_unchecked(key, keybytes, val) return themap @classmethod - def serialize_safe(cls, themap, protocol_version): + def serialize_safe(cls, themap): key_type, value_type = cls.subtypes buf = io.BytesIO() buf.write(int32_pack(len(themap))) @@ -903,16 +900,15 @@ def serialize_safe(cls, themap, protocol_version): items = themap.items() except AttributeError: raise TypeError("Got a non-map object for a map value") - inner_proto = max(3, protocol_version) for key, val in items: if key is not None: - keybytes = key_type.to_binary(key, inner_proto) + keybytes = key_type.to_binary(key) buf.write(int32_pack(len(keybytes))) buf.write(keybytes) else: buf.write(int32_pack(-1)) if val is not None: - valbytes = value_type.to_binary(val, inner_proto) + valbytes = value_type.to_binary(val) buf.write(int32_pack(len(valbytes))) buf.write(valbytes) else: @@ -924,8 +920,7 @@ class TupleType(_ParameterizedType): typename = 'tuple' @classmethod - def deserialize_safe(cls, byts, protocol_version): - proto_version = max(3, protocol_version) + def deserialize_safe(cls, byts): p = 0 values = [] for col_type in cls.subtypes: @@ -938,9 +933,7 @@ def deserialize_safe(cls, byts, protocol_version): p += itemlen else: item = None - # collections inside UDTs are always encoded with at least the - # version 3 format - values.append(col_type.from_binary(item, proto_version)) + values.append(col_type.from_binary(item)) if len(values) < len(cls.subtypes): nones = [None] * (len(cls.subtypes) - len(values)) @@ -949,16 +942,15 @@ def deserialize_safe(cls, byts, protocol_version): return tuple(values) @classmethod - def serialize_safe(cls, val, protocol_version): + def serialize_safe(cls, val): if len(val) > len(cls.subtypes): raise ValueError("Expected %d items in a tuple, but got %d: %s" % (len(cls.subtypes), len(val), val)) - proto_version = max(3, protocol_version) buf = io.BytesIO() for item, subtype in zip(val, cls.subtypes): if item is not None: - packed_item = subtype.to_binary(item, proto_version) + packed_item = subtype.to_binary(item) buf.write(int32_pack(len(packed_item))) buf.write(packed_item) else: @@ -1012,8 +1004,8 @@ def cql_parameterized_type(cls): return "frozen<%s>" % (cls.typename,) @classmethod - def deserialize_safe(cls, byts, protocol_version): - values = super(UserType, cls).deserialize_safe(byts, protocol_version) + def deserialize_safe(cls, byts): + values = super(UserType, cls).deserialize_safe(byts) if cls.mapped_class: return cls.mapped_class(**dict(zip(cls.fieldnames, values))) elif cls.tuple_type: @@ -1022,8 +1014,7 @@ def deserialize_safe(cls, byts, protocol_version): return tuple(values) @classmethod - def serialize_safe(cls, val, protocol_version): - proto_version = max(3, protocol_version) + def serialize_safe(cls, val): buf = io.BytesIO() for i, (fieldname, subtype) in enumerate(zip(cls.fieldnames, cls.subtypes)): # first treat as a tuple, else by custom type @@ -1035,7 +1026,7 @@ def serialize_safe(cls, val, protocol_version): log.warning(f"field {fieldname} is part of the UDT {cls.typename} but is not present in the value {val}") if item is not None: - packed_item = subtype.to_binary(item, proto_version) + packed_item = subtype.to_binary(item) buf.write(int32_pack(len(packed_item))) buf.write(packed_item) else: @@ -1085,7 +1076,7 @@ def cql_parameterized_type(cls): return "'%s'" % 
(typestring,) @classmethod - def deserialize_safe(cls, byts, protocol_version): + def deserialize_safe(cls, byts): result = [] for subtype in cls.subtypes: if not byts: @@ -1097,7 +1088,7 @@ def deserialize_safe(cls, byts, protocol_version): # skip element length, element, and the EOC (one byte) byts = byts[2 + element_length + 1:] - result.append(subtype.from_binary(element, protocol_version)) + result.append(subtype.from_binary(element)) return tuple(result) @@ -1125,14 +1116,14 @@ class ReversedType(_ParameterizedType): num_subtypes = 1 @classmethod - def deserialize_safe(cls, byts, protocol_version): + def deserialize_safe(cls, byts): subtype, = cls.subtypes - return subtype.from_binary(byts, protocol_version) + return subtype.from_binary(byts) @classmethod - def serialize_safe(cls, val, protocol_version): + def serialize_safe(cls, val): subtype, = cls.subtypes - return subtype.to_binary(val, protocol_version) + return subtype.to_binary(val) class FrozenType(_ParameterizedType): @@ -1140,14 +1131,14 @@ class FrozenType(_ParameterizedType): num_subtypes = 1 @classmethod - def deserialize_safe(cls, byts, protocol_version): + def deserialize_safe(cls, byts): subtype, = cls.subtypes - return subtype.from_binary(byts, protocol_version) + return subtype.from_binary(byts) @classmethod - def serialize_safe(cls, val, protocol_version): + def serialize_safe(cls, val): subtype, = cls.subtypes - return subtype.to_binary(val, protocol_version) + return subtype.to_binary(val) def is_counter_type(t): @@ -1182,11 +1173,11 @@ class PointType(CassandraType): _type = struct.pack('[[]] type_ = int8_unpack(byts[0:1]) @@ -1376,7 +1367,7 @@ def deserialize(cls, byts, protocol_version): raise ValueError('Could not deserialize %r' % (byts,)) @classmethod - def serialize(cls, v, protocol_version): + def serialize(cls, v): buf = io.BytesIO() bound_kind, bounds = None, () @@ -1444,7 +1435,7 @@ def apply_parameters(cls, params, names): return type('%s(%s)' % (cls.cass_parameterized_type_with([]), vsize), (cls,), {'vector_size': vsize, 'subtype': subtype}) @classmethod - def deserialize(cls, byts, protocol_version): + def deserialize(cls, byts): serialized_size = cls.subtype.serial_size() if serialized_size is not None: expected_byte_size = serialized_size * cls.vector_size @@ -1453,7 +1444,7 @@ def deserialize(cls, byts, protocol_version): "Expected vector of type {0} and dimension {1} to have serialized size {2}; observed serialized size of {3} instead"\ .format(cls.subtype.typename, cls.vector_size, expected_byte_size, len(byts))) indexes = (serialized_size * x for x in range(0, cls.vector_size)) - return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes] + return [cls.subtype.deserialize(byts[idx:idx + serialized_size]) for idx in indexes] idx = 0 rv = [] @@ -1461,7 +1452,7 @@ def deserialize(cls, byts, protocol_version): try: size, bytes_read = uvint_unpack(byts[idx:]) idx += bytes_read - rv.append(cls.subtype.deserialize(byts[idx:idx + size], protocol_version)) + rv.append(cls.subtype.deserialize(byts[idx:idx + size])) idx += size except: raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"\ @@ -1473,7 +1464,7 @@ def deserialize(cls, byts, protocol_version): return rv @classmethod - def serialize(cls, v, protocol_version): + def serialize(cls, v): v_length = len(v) if cls.vector_size != v_length: raise ValueError( @@ -1483,7 +1474,7 @@ def serialize(cls, v, protocol_version): serialized_size = 
cls.subtype.serial_size() buf = io.BytesIO() for item in v: - item_bytes = cls.subtype.serialize(item, protocol_version) + item_bytes = cls.subtype.serialize(item) if serialized_size is None: buf.write(uvint_pack(len(item_bytes))) buf.write(item_bytes) diff --git a/cassandra/deserializers.pxd b/cassandra/deserializers.pxd index 7b307226ad..e7b04d26e5 100644 --- a/cassandra/deserializers.pxd +++ b/cassandra/deserializers.pxd @@ -26,18 +26,16 @@ cdef class Deserializer: # paragraph 6) cdef bint empty_binary_ok - cdef deserialize(self, Buffer *buf, int protocol_version) - # cdef deserialize(self, CString byts, protocol_version) + cdef deserialize(self, Buffer *buf) cdef inline object from_binary(Deserializer deserializer, - Buffer *buf, - int protocol_version): + Buffer *buf): if buf.size < 0: return None elif buf.size == 0 and not deserializer.empty_binary_ok: return _ret_empty(deserializer, buf.size) else: - return deserializer.deserialize(buf, protocol_version) + return deserializer.deserialize(buf) cdef _ret_empty(Deserializer deserializer, Py_ssize_t buf_size) diff --git a/cassandra/deserializers.pyx b/cassandra/deserializers.pyx index 97d249d02f..e477e47a14 100644 --- a/cassandra/deserializers.pyx +++ b/cassandra/deserializers.pyx @@ -36,12 +36,12 @@ cdef class Deserializer: self.cqltype = cqltype self.empty_binary_ok = cqltype.empty_binary_ok - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): raise NotImplementedError cdef class DesBytesType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): if buf.size == 0: return b"" return to_bytes(buf) @@ -50,14 +50,14 @@ cdef class DesBytesType(Deserializer): # It is switched in by simply overwriting DesBytesType: # deserializers.DesBytesType = deserializers.DesBytesTypeByteArray cdef class DesBytesTypeByteArray(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): if buf.size == 0: return bytearray() return bytearray(buf.ptr[:buf.size]) # TODO: Use libmpdec: http://www.bytereef.org/mpdecimal/index.html cdef class DesDecimalType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): cdef Buffer varint_buf slice_buffer(buf, &varint_buf, 4, buf.size - 4) @@ -68,56 +68,56 @@ cdef class DesDecimalType(Deserializer): cdef class DesUUIDType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return UUID(bytes=to_bytes(buf)) cdef class DesBooleanType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): if unpack_num[int8_t](buf): return True return False cdef class DesByteType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return unpack_num[int8_t](buf) cdef class DesAsciiType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): if buf.size == 0: return "" return to_bytes(buf).decode('ascii') cdef class DesFloatType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return unpack_num[float](buf) cdef class DesDoubleType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return unpack_num[double](buf) cdef class DesLongType(Deserializer): - cdef 
deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return unpack_num[int64_t](buf) cdef class DesInt32Type(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return unpack_num[int32_t](buf) cdef class DesIntegerType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return varint_unpack(buf) cdef class DesInetAddressType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): cdef bytes byts = to_bytes(buf) # TODO: optimize inet_ntop, inet_ntoa @@ -134,7 +134,7 @@ cdef class DesCounterColumnType(DesLongType): cdef class DesDateType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): cdef double timestamp = unpack_num[int64_t](buf) / 1000.0 return datetime_from_timestamp(timestamp) @@ -144,7 +144,7 @@ cdef class TimestampType(DesDateType): cdef class TimeUUIDType(DesDateType): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return UUID(bytes=to_bytes(buf)) @@ -154,23 +154,23 @@ cdef class TimeUUIDType(DesDateType): EPOCH_OFFSET_DAYS = 2 ** 31 cdef class DesSimpleDateType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): days = unpack_num[uint32_t](buf) - EPOCH_OFFSET_DAYS return util.Date(days) cdef class DesShortType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return unpack_num[int16_t](buf) cdef class DesTimeType(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): return util.Time(unpack_num[int64_t](buf)) cdef class DesUTF8Type(Deserializer): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): if buf.size == 0: return "" cdef val = to_bytes(buf) @@ -207,19 +207,19 @@ cdef class _DesSingleParamType(_DesParameterizedType): # List and set deserialization cdef class DesListType(_DesSingleParamType): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): result = _deserialize_list_or_set( - buf, protocol_version, self.deserializer) + buf, self.deserializer) return result cdef class DesSetType(DesListType): - cdef deserialize(self, Buffer *buf, int protocol_version): - return util.sortedset(DesListType.deserialize(self, buf, protocol_version)) + cdef deserialize(self, Buffer *buf): + return util.sortedset(DesListType.deserialize(self, buf)) -cdef list _deserialize_list_or_set(Buffer *buf, int protocol_version, +cdef list _deserialize_list_or_set(Buffer *buf, Deserializer deserializer): """ Deserialize a list or set. 
@@ -233,10 +233,9 @@ cdef list _deserialize_list_or_set(Buffer *buf, int protocol_version, _unpack_len(buf, 0, &numelements) offset = sizeof(int32_t) - protocol_version = max(3, protocol_version) for _ in range(numelements): subelem(buf, &elem_buf, &offset) - result.append(from_binary(deserializer, &elem_buf, protocol_version)) + result.append(from_binary(deserializer, &elem_buf)) return result @@ -277,20 +276,20 @@ cdef class DesMapType(_DesParameterizedType): self.key_deserializer = self.deserializers[0] self.val_deserializer = self.deserializers[1] - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): key_type, val_type = self.cqltype.subtypes result = _deserialize_map( - buf, protocol_version, + buf, self.key_deserializer, self.val_deserializer, - key_type, val_type) + key_type) return result -cdef _deserialize_map(Buffer *buf, int protocol_version, +cdef _deserialize_map(Buffer *buf, Deserializer key_deserializer, Deserializer val_deserializer, - key_type, val_type): + key_type): cdef Buffer key_buf, val_buf cdef Buffer itemlen_buf @@ -300,13 +299,12 @@ cdef _deserialize_map(Buffer *buf, int protocol_version, _unpack_len(buf, 0, &numelements) offset = sizeof(int32_t) - themap = util.OrderedMapSerializedKey(key_type, protocol_version) - protocol_version = max(3, protocol_version) + themap = util.OrderedMapSerializedKey(key_type) for _ in range(numelements): subelem(buf, &key_buf, &offset) subelem(buf, &val_buf, &offset) - key = from_binary(key_deserializer, &key_buf, protocol_version) - val = from_binary(val_deserializer, &val_buf, protocol_version) + key = from_binary(key_deserializer, &key_buf) + val = from_binary(val_deserializer, &val_buf) themap._insert_unchecked(key, to_bytes(&key_buf), val) return themap @@ -317,7 +315,7 @@ cdef class DesTupleType(_DesParameterizedType): # TODO: Use TupleRowParser to parse these tuples - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): cdef Py_ssize_t i, p cdef int32_t itemlen cdef tuple res = tuple_new(self.subtypes_len) @@ -325,10 +323,6 @@ cdef class DesTupleType(_DesParameterizedType): cdef Buffer itemlen_buf cdef Deserializer deserializer - # collections inside UDTs are always encoded with at least the - # version 3 format - protocol_version = max(3, protocol_version) - p = 0 values = [] for i in range(self.subtypes_len): @@ -342,7 +336,7 @@ cdef class DesTupleType(_DesParameterizedType): p += itemlen deserializer = self.deserializers[i] - item = from_binary(deserializer, &item_buf, protocol_version) + item = from_binary(deserializer, &item_buf) tuple_set(res, i, item) @@ -350,9 +344,9 @@ cdef class DesTupleType(_DesParameterizedType): cdef class DesUserType(DesTupleType): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): typ = self.cqltype - values = DesTupleType.deserialize(self, buf, protocol_version) + values = DesTupleType.deserialize(self, buf) if typ.mapped_class: return typ.mapped_class(**dict(zip(typ.fieldnames, values))) elif typ.tuple_type: @@ -362,7 +356,7 @@ cdef class DesUserType(DesTupleType): cdef class DesCompositeType(_DesParameterizedType): - cdef deserialize(self, Buffer *buf, int protocol_version): + cdef deserialize(self, Buffer *buf): cdef Py_ssize_t i, idx, start cdef Buffer elem_buf cdef int16_t element_length @@ -387,7 +381,7 @@ cdef class DesCompositeType(_DesParameterizedType): slice_buffer(buf, &elem_buf, 2, element_length) deserializer = self.deserializers[i] - item = 
from_binary(deserializer, &elem_buf, protocol_version) + item = from_binary(deserializer, &elem_buf) tuple_set(res, i, item) # skip element length, element, and the EOC (one byte) @@ -401,13 +395,13 @@ DesDynamicCompositeType = DesCompositeType cdef class DesReversedType(_DesSingleParamType): - cdef deserialize(self, Buffer *buf, int protocol_version): - return from_binary(self.deserializer, buf, protocol_version) + cdef deserialize(self, Buffer *buf): + return from_binary(self.deserializer, buf) cdef class DesFrozenType(_DesSingleParamType): - cdef deserialize(self, Buffer *buf, int protocol_version): - return from_binary(self.deserializer, buf, protocol_version) + cdef deserialize(self, Buffer *buf): + return from_binary(self.deserializer, buf) #-------------------------------------------------------------------------- @@ -431,8 +425,8 @@ cdef class GenericDeserializer(Deserializer): Wrap a generic datatype for deserialization """ - cdef deserialize(self, Buffer *buf, int protocol_version): - return self.cqltype.deserialize(to_bytes(buf), protocol_version) + cdef deserialize(self, Buffer *buf): + return self.cqltype.deserialize(to_bytes(buf)) def __repr__(self): return "GenericDeserializer(%s)" % (self.cqltype,) diff --git a/cassandra/marshal.py b/cassandra/marshal.py index 413e1831d4..8aebd60370 100644 --- a/cassandra/marshal.py +++ b/cassandra/marshal.py @@ -33,7 +33,6 @@ def _make_packer(format_string): float_pack, float_unpack = _make_packer('>f') double_pack, double_unpack = _make_packer('>d') -# in protocol version 3 and higher, the stream ID is two bytes v3_header_struct = struct.Struct('>BBhB') v3_header_pack = v3_header_struct.pack v3_header_unpack = v3_header_struct.unpack diff --git a/cassandra/metadata.py b/cassandra/metadata.py index b85308449e..79f1d08287 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -2205,7 +2205,7 @@ def _build_aggregate(cls, aggregate_row): cass_state_type = types.lookup_casstype(aggregate_row['state_type']) initial_condition = aggregate_row['initcond'] if initial_condition is not None: - initial_condition = _encoder.cql_encode_all_types(cass_state_type.deserialize(initial_condition, 3)) + initial_condition = _encoder.cql_encode_all_types(cass_state_type.deserialize(initial_condition)) state_type = _cql_from_cass_type(cass_state_type) return_type = cls._schema_type_to_cql(aggregate_row['return_type']) return Aggregate(aggregate_row['keyspace_name'], aggregate_row['aggregate_name'], @@ -3495,7 +3495,7 @@ def group_keys_by_replica(session, keyspace, table, keys): distance = cluster._default_load_balancing_policy.distance for key in keys: - serialized_key = [serializer.serialize(pk, cluster.protocol_version) + serialized_key = [serializer.serialize(pk) for serializer, pk in zip(serializers, key)] if len(serialized_key) == 1: routing_key = serialized_key[0] diff --git a/cassandra/numpy_parser.pyx b/cassandra/numpy_parser.pyx index 0ad34f66e2..8687305ae8 100644 --- a/cassandra/numpy_parser.pyx +++ b/cassandra/numpy_parser.pyx @@ -156,7 +156,7 @@ cdef inline int unpack_row( if arr.is_object: deserializer = desc.deserializers[i] - val = from_binary(deserializer, &buf, desc.protocol_version) + val = from_binary(deserializer, &buf) Py_INCREF(val) ( arr.buf_ptr)[0] = val elif buf.size >= 0: diff --git a/cassandra/obj_parser.pyx b/cassandra/obj_parser.pyx index cf43771dd7..819507b137 100644 --- a/cassandra/obj_parser.pyx +++ b/cassandra/obj_parser.pyx @@ -31,7 +31,10 @@ cdef class ListParser(ColumnParser): cdef Py_ssize_t i, rowcount rowcount = 
read_int(reader) cdef RowParser rowparser = TupleRowParser() - return [rowparser.unpack_row(reader, desc) for i in range(rowcount)] + if desc.column_encryption_policy: + return [rowparser.unpack_col_encrypted_row(reader, desc) for i in range(rowcount)] + else: + return [rowparser.unpack_plain_row(reader, desc) for i in range(rowcount)] cdef class LazyParser(ColumnParser): @@ -47,7 +50,10 @@ def parse_rows_lazy(BytesIOReader reader, ParseDesc desc): cdef Py_ssize_t i, rowcount rowcount = read_int(reader) cdef RowParser rowparser = TupleRowParser() - return (rowparser.unpack_row(reader, desc) for i in range(rowcount)) + if desc.column_encryption_policy: + return (rowparser.unpack_col_encrypted_row(reader, desc) for i in range(rowcount)) + else: + return (rowparser.unpack_plain_row(reader, desc) for i in range(rowcount)) cdef class TupleRowParser(RowParser): @@ -55,9 +61,11 @@ cdef class TupleRowParser(RowParser): Parse a single returned row into a tuple of objects: (obj1, ..., objN) + If a column encryption (CE) policy is enabled, use unpack_col_encrypted_row(); + otherwise use unpack_plain_row() """ - cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc): + cpdef unpack_col_encrypted_row(self, BytesIOReader reader, ParseDesc desc): assert desc.rowsize >= 0 cdef Buffer buf @@ -67,28 +75,53 @@ cdef class TupleRowParser(RowParser): cdef tuple res = tuple_new(desc.rowsize) ce_policy = desc.column_encryption_policy - for i in range(rowsize): - # Read the next few bytes - get_buf(reader, &buf) - - # Deserialize bytes to python object - deserializer = desc.deserializers[i] - coldesc = desc.coldescs[i] - uses_ce = ce_policy and ce_policy.contains_column(coldesc) - try: + try: + for i in range(rowsize): + # Read the next few bytes + get_buf(reader, &buf) + + # Deserialize bytes to python object + deserializer = desc.deserializers[i] + coldesc = desc.coldescs[i] + uses_ce = ce_policy.contains_column(coldesc) if uses_ce: col_type = ce_policy.column_type(coldesc) decrypted_bytes = ce_policy.decrypt(coldesc, to_bytes(&buf)) PyBytes_AsStringAndSize(decrypted_bytes, &newbuf.ptr, &newbuf.size) deserializer = find_deserializer(ce_policy.column_type(coldesc)) - val = from_binary(deserializer, &newbuf, desc.protocol_version) + val = from_binary(deserializer, &newbuf) else: - val = from_binary(deserializer, &buf, desc.protocol_version) - except Exception as e: - raise DriverException('Failed decoding result column "%s" of type %s: %s' % (desc.colnames[i], - desc.coltypes[i].cql_parameterized_type(), - str(e))) - # Insert new object into tuple - tuple_set(res, i, val) + val = from_binary(deserializer, &buf) + # Insert new object into tuple + tuple_set(res, i, val) + except Exception as e: + raise DriverException('Failed decoding result column "%s" of type %s: %s' % (desc.colnames[i], + desc.coltypes[i].cql_parameterized_type(), + str(e))) + + return res + + cpdef unpack_plain_row(self, BytesIOReader reader, ParseDesc desc): + assert desc.rowsize >= 0 + + cdef Buffer buf + cdef Py_ssize_t i, rowsize = desc.rowsize + cdef Deserializer deserializer + cdef tuple res = tuple_new(desc.rowsize) + + try: + for i in range(rowsize): + # Read the next few bytes + get_buf(reader, &buf) + + # Deserialize bytes to python object + deserializer = desc.deserializers[i] + val = from_binary(deserializer, &buf) + # Insert new object into tuple + tuple_set(res, i, val) + except Exception as e: + raise DriverException('Failed decoding result column "%s" of type %s: %s' % (desc.colnames[i],
desc.coltypes[i].cql_parameterized_type(), + str(e))) return res diff --git a/cassandra/parsing.pxd b/cassandra/parsing.pxd index 27dc368b07..8fcee2c45e 100644 --- a/cassandra/parsing.pxd +++ b/cassandra/parsing.pxd @@ -21,7 +21,6 @@ cdef class ParseDesc: cdef public object column_encryption_policy cdef public list coldescs cdef Deserializer[::1] deserializers - cdef public int protocol_version cdef Py_ssize_t rowsize cdef class ColumnParser: diff --git a/cassandra/parsing.pyx b/cassandra/parsing.pyx index 954767d227..c67f18fd88 100644 --- a/cassandra/parsing.pyx +++ b/cassandra/parsing.pyx @@ -19,13 +19,12 @@ Module containing the definitions and declarations (parsing.pxd) for parsers. cdef class ParseDesc: """Description of what structure to parse""" - def __init__(self, colnames, coltypes, column_encryption_policy, coldescs, deserializers, protocol_version): + def __init__(self, colnames, coltypes, column_encryption_policy, coldescs, deserializers): self.colnames = colnames self.coltypes = coltypes self.column_encryption_policy = column_encryption_policy self.coldescs = coldescs self.deserializers = deserializers - self.protocol_version = protocol_version self.rowsize = len(colnames) diff --git a/cassandra/protocol.py b/cassandra/protocol.py index f37633a756..b37095db78 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -457,15 +457,9 @@ def __init__(self, creds): self.creds = creds def send_body(self, f, protocol_version): - if protocol_version > 1: - raise UnsupportedOperation( - "Credentials-based authentication is not supported with " - "protocol version 2 or higher. Use the SASL authentication " - "mechanism instead.") - write_short(f, len(self.creds)) - for credkey, credval in self.creds.items(): - write_string(f, credkey) - write_string(f, credval) + raise UnsupportedOperation( + "Credentials-based authentication is not supported. 
" + "Use the SASL authentication mechanism instead.") class AuthChallengeMessage(_MessageType): @@ -695,7 +689,7 @@ def recv(self, f, protocol_version, protocol_features, user_type_map, result_met if self.kind == RESULT_KIND_VOID: return elif self.kind == RESULT_KIND_ROWS: - self.recv_results_rows(f, protocol_version, user_type_map, result_metadata, column_encryption_policy) + self.recv_results_rows(f, user_type_map, result_metadata, column_encryption_policy) elif self.kind == RESULT_KIND_SET_KEYSPACE: self.new_keyspace = read_string(f) elif self.kind == RESULT_KIND_PREPARED: @@ -712,7 +706,7 @@ def recv_body(cls, f, protocol_version, protocol_features, user_type_map, result msg.recv(f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy) return msg - def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): + def recv_results_rows(self, f, user_type_map, result_metadata, column_encryption_policy): self.recv_results_metadata(f, user_type_map) column_metadata = self.column_metadata or result_metadata rowcount = read_int(f) @@ -725,7 +719,7 @@ def decode_val(val, col_md, col_desc): uses_ce = column_encryption_policy and column_encryption_policy.contains_column(col_desc) col_type = column_encryption_policy.column_type(col_desc) if uses_ce else col_md[3] raw_bytes = column_encryption_policy.decrypt(col_desc, val) if uses_ce else val - return col_type.from_binary(raw_bytes, protocol_version) + return col_type.from_binary(raw_bytes) def decode_row(row): return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs)) @@ -790,10 +784,8 @@ def recv_prepared_metadata(self, f, protocol_version, protocol_features, user_ty flags = read_int(f) self.is_lwt = protocol_features.lwt_info.get_lwt_flag(flags) if protocol_features.lwt_info is not None else False colcount = read_int(f) - pk_indexes = None - if protocol_version >= 4: - num_pk_indexes = read_int(f) - pk_indexes = [read_short(f) for _ in range(num_pk_indexes)] + num_pk_indexes = read_int(f) + pk_indexes = [read_short(f) for _ in range(num_pk_indexes)] glob_tblspec = bool(flags & self._FLAGS_GLOBAL_TABLES_SPEC) if glob_tblspec: @@ -817,7 +809,7 @@ def recv_prepared_metadata(self, f, protocol_version, protocol_features, user_ty self.pk_indexes = pk_indexes def recv_results_schema_change(self, f, protocol_version): - self.schema_change_event = EventMessage.recv_schema_change(f, protocol_version) + self.schema_change_event = EventMessage.recv_schema_change(f) @classmethod def read_type(cls, f, user_type_map): @@ -985,11 +977,11 @@ def recv_body(cls, f, protocol_version, *args): event_type = read_string(f).upper() if event_type in known_event_types: read_method = getattr(cls, 'recv_' + event_type.lower()) - return cls(event_type=event_type, event_args=read_method(f, protocol_version)) + return cls(event_type=event_type, event_args=read_method(f)) raise NotSupportedError('Unknown event type %r' % event_type) @classmethod - def recv_client_routes_change(cls, f, protocol_version): + def recv_client_routes_change(cls, f): # "UPDATE_NODES" change_type = read_string(f) connection_ids = read_stringlist(f) @@ -997,21 +989,21 @@ def recv_client_routes_change(cls, f, protocol_version): return dict(change_type=change_type, connection_ids=connection_ids, host_ids=host_ids) @classmethod - def recv_topology_change(cls, f, protocol_version): + def recv_topology_change(cls, f): # "NEW_NODE" or "REMOVED_NODE" change_type = 
read_string(f) address = read_inet(f) return dict(change_type=change_type, address=address) @classmethod - def recv_status_change(cls, f, protocol_version): + def recv_status_change(cls, f): # "UP" or "DOWN" change_type = read_string(f) address = read_inet(f) return dict(change_type=change_type, address=address) @classmethod - def recv_schema_change(cls, f, protocol_version): + def recv_schema_change(cls, f): # "CREATED", "DROPPED", or "UPDATED" change_type = read_string(f) target = read_string(f) @@ -1086,8 +1078,6 @@ def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta """ flags = 0 if msg.custom_payload: - if protocol_version < 4: - raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher") flags |= CUSTOM_PAYLOAD_FLAG if msg.tracing: diff --git a/cassandra/query.py b/cassandra/query.py index 6c6878fdb4..d152738ca6 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -235,9 +235,6 @@ class Statement(object): How many rows will be fetched at a time. This overrides the default of :attr:`.Session.default_fetch_size` - This only takes effect when protocol version 2 or higher is used. - See :attr:`.Cluster.protocol_version` for details. - .. versionadded:: 2.0.0 """ @@ -448,7 +445,6 @@ class PreparedStatement(object): custom_payload = None fetch_size = FETCH_SIZE_UNSET keyspace = None # change to prepared_keyspace in major release - protocol_version = None query_id = None query_string = None result_metadata = None @@ -460,14 +456,13 @@ class PreparedStatement(object): _is_lwt = False def __init__(self, column_metadata, query_id, routing_key_indexes, query, - keyspace, protocol_version, result_metadata, result_metadata_id, + keyspace, result_metadata, result_metadata_id, is_lwt=False, column_encryption_policy=None): self.column_metadata = column_metadata self.query_id = query_id self.routing_key_indexes = routing_key_indexes self.query_string = query self.keyspace = keyspace - self.protocol_version = protocol_version self.result_metadata = result_metadata self.result_metadata_id = result_metadata_id self.column_encryption_policy = column_encryption_policy @@ -476,11 +471,11 @@ def __init__(self, column_metadata, query_id, routing_key_indexes, query, @classmethod def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, - query, prepared_keyspace, protocol_version, result_metadata, + query, prepared_keyspace, result_metadata, result_metadata_id, is_lwt, column_encryption_policy=None): if not column_metadata: return PreparedStatement(column_metadata, query_id, None, - query, prepared_keyspace, protocol_version, result_metadata, + query, prepared_keyspace, result_metadata, result_metadata_id, is_lwt, column_encryption_policy) if pk_indexes: @@ -506,7 +501,7 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, pass # statement; just leave routing_key_indexes as None return PreparedStatement(column_metadata, query_id, routing_key_indexes, - query, prepared_keyspace, protocol_version, result_metadata, + query, prepared_keyspace, result_metadata, result_metadata_id, is_lwt, column_encryption_policy) def bind(self, values): @@ -597,7 +592,6 @@ def bind(self, values): """ if values is None: values = () - proto_version = self.prepared_statement.protocol_version col_meta = self.prepared_statement.column_metadata ce_policy = self.prepared_statement.column_encryption_policy @@ -611,12 +605,7 @@ def bind(self, values): try: values.append(values_dict[col.name]) except 
KeyError: - if proto_version >= 4: - values.append(UNSET_VALUE) - else: - raise KeyError( - 'Column name `%s` not found in bound dict.' % - (col.name)) + values.append(UNSET_VALUE) value_len = len(values) col_meta_len = len(col_meta) @@ -626,30 +615,19 @@ def bind(self, values): "Too many arguments provided to bind() (got %d, expected %d)" % (len(values), len(col_meta))) - # this is fail-fast for clarity pre-v4. When v4 can be assumed, - # the error will be better reported when UNSET_VALUE is implicitly added. - if proto_version < 4 and self.prepared_statement.routing_key_indexes and \ - value_len < len(self.prepared_statement.routing_key_indexes): - raise ValueError( - "Too few arguments provided to bind() (got %d, required %d for routing key)" % - (value_len, len(self.prepared_statement.routing_key_indexes))) - self.raw_values = values self.values = [] for value, col_spec in zip(values, col_meta): if value is None: self.values.append(None) elif value is UNSET_VALUE: - if proto_version >= 4: - self._append_unset_value() - else: - raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) + self._append_unset_value() else: try: col_desc = ColDesc(col_spec.keyspace_name, col_spec.table_name, col_spec.name) uses_ce = ce_policy and ce_policy.contains_column(col_desc) col_type = ce_policy.column_type(col_desc) if uses_ce else col_spec.type - col_bytes = col_type.serialize(value, proto_version) + col_bytes = col_type.serialize(value) if uses_ce: col_bytes = ce_policy.encrypt(col_desc, col_bytes) self.values.append(col_bytes) @@ -659,11 +637,10 @@ def bind(self, values): 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) raise TypeError(message) - if proto_version >= 4: - diff = col_meta_len - len(self.values) - if diff: - for _ in range(diff): - self._append_unset_value() + diff = col_meta_len - len(self.values) + if diff: + for _ in range(diff): + self._append_unset_value() return self diff --git a/cassandra/row_parser.pyx b/cassandra/row_parser.pyx index 88277a4593..80a66c3291 100644 --- a/cassandra/row_parser.pyx +++ b/cassandra/row_parser.pyx @@ -20,7 +20,7 @@ from cassandra.deserializers import make_deserializers include "ioutils.pyx" def make_recv_results_rows(ColumnParser colparser): - def recv_results_rows(self, f, int protocol_version, user_type_map, result_metadata, column_encryption_policy): + def recv_results_rows(self, f, user_type_map, result_metadata, column_encryption_policy): """ Parse protocol data given as a BytesIO f into a set of columns (e.g. 
list of tuples) This is used as the recv_results_rows method of (Fast)ResultMessage @@ -34,7 +34,7 @@ def make_recv_results_rows(ColumnParser colparser): desc = ParseDesc(self.column_names, self.column_types, column_encryption_policy, [ColDesc(md[0], md[1], md[2]) for md in column_metadata], - make_deserializers(self.column_types), protocol_version) + make_deserializers(self.column_types)) reader = BytesIOReader(f.read()) try: self.parsed_rows = colparser.parse_rows(reader, desc) diff --git a/cassandra/util.py b/cassandra/util.py index 12886d05ab..5cef7f61af 100644 --- a/cassandra/util.py +++ b/cassandra/util.py @@ -767,17 +767,16 @@ def _serialize_key(self, key): class OrderedMapSerializedKey(OrderedMap): - def __init__(self, cass_type, protocol_version): + def __init__(self, cass_type): super(OrderedMapSerializedKey, self).__init__() self.cass_key_type = cass_type - self.protocol_version = protocol_version def _insert_unchecked(self, key, flat_key, value): self._items.append((key, value)) self._index[flat_key] = len(self._items) - 1 def _serialize_key(self, key): - return self.cass_key_type.serialize(key, self.protocol_version) + return self.cass_key_type.serialize(key) @total_ordering diff --git a/docs/cqlengine/models.rst b/docs/cqlengine/models.rst index 719513f4a9..a9f6b657c7 100644 --- a/docs/cqlengine/models.rst +++ b/docs/cqlengine/models.rst @@ -212,7 +212,4 @@ synchronize any types contained in the table. Alternatively :func:`~.management. explicitly. Upon declaration, types are automatically registered with the driver, so query results return instances of your ``UserType`` -class*. - -***Note**: UDTs were not added to the native protocol until v3. When setting up the cqlengine connection, be sure to specify -``protocol_version=3``. If using an earlier version, UDT queries will still work, but the returned type will be a namedtuple. +class. diff --git a/docs/object-mapper.rst b/docs/object-mapper.rst index 5eb78f57b6..f62b6cb9d9 100644 --- a/docs/object-mapper.rst +++ b/docs/object-mapper.rst @@ -65,7 +65,7 @@ Getting Started #next, setup the connection to your cassandra server(s)... # see http://datastax.github.io/python-driver/api/cassandra/cluster.html for options # the list of hosts will be passed to create a Cluster() instance - connection.setup(['127.0.0.1'], "cqlengine", protocol_version=3) + connection.setup(['127.0.0.1'], "cqlengine") #...and create your CQL table >>> sync_table(ExampleModel) diff --git a/docs/user-defined-types.rst b/docs/user-defined-types.rst index 32c03e37e8..a6dd933b52 100644 --- a/docs/user-defined-types.rst +++ b/docs/user-defined-types.rst @@ -50,7 +50,7 @@ Map a dict to a UDT .. code-block:: python - cluster = Cluster(protocol_version=3) + cluster = Cluster() session = cluster.connect() session.set_keyspace('mykeyspace') session.execute("CREATE TYPE address (street text, zipcode int)") @@ -80,7 +80,7 @@ for the UDT: .. 
code-block:: python - cluster = Cluster(protocol_version=3) + cluster = Cluster() session = cluster.connect() session.set_keyspace('mykeyspace') session.execute("CREATE TYPE address (street text, zipcode int)") diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index b4eab35875..3fdce16210 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -268,7 +268,6 @@ def xfail_scylla_version(filter: Callable[[Version], bool], reason: str, *args, local = local_decorator_creator() notprotocolv1 = unittest.skipUnless(PROTOCOL_VERSION > 1, 'Protocol v1 not supported') -greaterthanprotocolv3 = unittest.skipUnless(PROTOCOL_VERSION >= 4, 'Protocol versions less than 4 are not supported') greaterthancass20 = unittest.skipUnless(CASSANDRA_VERSION >= Version('2.1'), 'Cassandra version 2.1 or greater required') greaterthancass21 = unittest.skipUnless(CASSANDRA_VERSION >= Version('2.2'), 'Cassandra version 2.2 or greater required') @@ -295,8 +294,8 @@ def xfail_scylla_version(filter: Callable[[Version], bool], reason: str, *args, reason='Scylla does not support UDFs written in Java') requires_composite_type = pytest.mark.skipif(SCYLLA_VERSION is not None, reason='Scylla does not support composite types') -requires_custom_payload = pytest.mark.skipif(SCYLLA_VERSION is not None or PROTOCOL_VERSION < 4, - reason='Scylla does not support custom payloads. Cassandra requires native protocol v4.0+') +requires_custom_payload = pytest.mark.skipif(SCYLLA_VERSION is not None, + reason='Scylla does not support custom payloads') xfail_scylla = lambda reason, *args, **kwargs: pytest.mark.xfail(SCYLLA_VERSION is not None, reason=reason, *args, **kwargs) incorrect_test = lambda reason='This test seems to be incorrect and should be fixed', *args, **kwargs: pytest.mark.xfail(reason=reason, *args, **kwargs) @@ -508,7 +507,7 @@ def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=None, if 'graph' in workloads: jvm_args += ['-Xms1500M', '-Xmx1500M'] else: - if PROTOCOL_VERSION >= 4 and not SCYLLA_VERSION: + if not SCYLLA_VERSION: jvm_args = [" -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler"] if len(workloads) > 0: for node in CCM_CLUSTER.nodes.values(): diff --git a/tests/integration/cqlengine/columns/test_static_column.py b/tests/integration/cqlengine/columns/test_static_column.py index e5d63c29bb..8a8c0a91be 100644 --- a/tests/integration/cqlengine/columns/test_static_column.py +++ b/tests/integration/cqlengine/columns/test_static_column.py @@ -21,11 +21,7 @@ from cassandra.cqlengine.models import Model from tests.integration.cqlengine.base import BaseCassEngTestCase -from tests.integration import PROTOCOL_VERSION -# TODO: is this really a protocol limitation, or is it just C* version? 
-# good enough proxy for now -STATIC_SUPPORTED = PROTOCOL_VERSION >= 2 class TestStaticModel(Model): __test__ = False @@ -38,16 +34,13 @@ class TestStaticModel(Model): class TestStaticColumn(BaseCassEngTestCase): - def setUp(cls): - if not STATIC_SUPPORTED: - raise unittest.SkipTest("only runs against the cql3 protocol v2.0") - super(TestStaticColumn, cls).setUp() + def setUp(self): + super(TestStaticColumn, self).setUp() @classmethod def setUpClass(cls): drop_table(TestStaticModel) - if STATIC_SUPPORTED: # setup and teardown run regardless of skip - sync_table(TestStaticModel) + sync_table(TestStaticModel) @classmethod def tearDownClass(cls): diff --git a/tests/integration/cqlengine/columns/test_validation.py b/tests/integration/cqlengine/columns/test_validation.py index ebffc0666c..ddea4afad3 100644 --- a/tests/integration/cqlengine/columns/test_validation.py +++ b/tests/integration/cqlengine/columns/test_validation.py @@ -31,7 +31,7 @@ from cassandra.cqlengine.usertype import UserType from cassandra import util -from tests.integration import PROTOCOL_VERSION, CASSANDRA_VERSION, greaterthanorequalcass30, greaterthanorequalcass3_11 +from tests.integration import CASSANDRA_VERSION, greaterthanorequalcass30, greaterthanorequalcass3_11 from tests.integration.cqlengine.base import BaseCassEngTestCase import pytest @@ -193,7 +193,7 @@ def test_varint_io(self): class DataType(): @classmethod def setUpClass(cls): - if PROTOCOL_VERSION < 4 or CASSANDRA_VERSION < Version("3.0"): + if CASSANDRA_VERSION < Version("3.0"): return class DataTypeTest(Model): @@ -205,16 +205,16 @@ class DataTypeTest(Model): @classmethod def tearDownClass(cls): - if PROTOCOL_VERSION < 4 or CASSANDRA_VERSION < Version("3.0"): + if CASSANDRA_VERSION < Version("3.0"): return drop_table(cls.model_class) def setUp(self): - if PROTOCOL_VERSION < 4 or CASSANDRA_VERSION < Version("3.0"): + if CASSANDRA_VERSION < Version("3.0"): raise unittest.SkipTest("Protocol v4 datatypes " - "require native protocol 4+ and C* version >=3.0, " - "currently using protocol {0} and C* version {1}". - format(PROTOCOL_VERSION, CASSANDRA_VERSION)) + "require C* version >=3.0, " + "currently using C* version {0}". 
+ format(CASSANDRA_VERSION)) def _check_value_is_correct_in_db(self, value): """ @@ -385,7 +385,7 @@ class UserModel(Model): class TestUDT(DataType, BaseCassEngTestCase): @classmethod def setUpClass(cls): - if PROTOCOL_VERSION < 4 or CASSANDRA_VERSION < Version("3.0"): + if CASSANDRA_VERSION < Version("3.0"): return cls.db_klass, cls.python_klass = UserDefinedType, User diff --git a/tests/integration/cqlengine/columns/test_value_io.py b/tests/integration/cqlengine/columns/test_value_io.py index 758ca714a6..98a3d746e4 100644 --- a/tests/integration/cqlengine/columns/test_value_io.py +++ b/tests/integration/cqlengine/columns/test_value_io.py @@ -24,7 +24,6 @@ from cassandra.util import Date, Time -from tests.integration import PROTOCOL_VERSION from tests.integration.cqlengine.base import BaseCassEngTestCase @@ -201,20 +200,15 @@ class ProtocolV4Test(BaseColumnIOTest): @classmethod def setUpClass(cls): - if PROTOCOL_VERSION >= 4: - super(ProtocolV4Test, cls).setUpClass() + super(ProtocolV4Test, cls).setUpClass() @classmethod def tearDownClass(cls): - if PROTOCOL_VERSION >= 4: - super(ProtocolV4Test, cls).tearDownClass() + super(ProtocolV4Test, cls).tearDownClass() class TestDate(ProtocolV4Test): def setUp(self): - if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION)) - super(TestDate, self).setUp() column = columns.Date @@ -227,9 +221,6 @@ def setUp(self): class TestTime(ProtocolV4Test): def setUp(self): - if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION)) - super(TestTime, self).setUp() column = columns.Time @@ -241,9 +232,6 @@ def setUp(self): class TestSmallInt(ProtocolV4Test): def setUp(self): - if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION)) - super(TestSmallInt, self).setUp() column = columns.SmallInt @@ -255,9 +243,6 @@ def setUp(self): class TestTinyInt(ProtocolV4Test): def setUp(self): - if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION)) - super(TestTinyInt, self).setUp() column = columns.TinyInt diff --git a/tests/integration/cqlengine/connections/test_connection.py b/tests/integration/cqlengine/connections/test_connection.py index 78d5133e63..9160d43fce 100644 --- a/tests/integration/cqlengine/connections/test_connection.py +++ b/tests/integration/cqlengine/connections/test_connection.py @@ -23,7 +23,7 @@ from cassandra.policies import RoundRobinPolicy from cassandra.query import dict_factory -from tests.integration import CASSANDRA_IP, PROTOCOL_VERSION, execute_with_long_wait_retry, local, TestCluster +from tests.integration import CASSANDRA_IP, execute_with_long_wait_retry, local, TestCluster from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration.cqlengine import DEFAULT_KEYSPACE, setup_connection diff --git a/tests/integration/cqlengine/management/test_management.py b/tests/integration/cqlengine/management/test_management.py index 1332680cef..2116ef4907 100644 --- a/tests/integration/cqlengine/management/test_management.py +++ b/tests/integration/cqlengine/management/test_management.py @@ -23,7 +23,7 @@ from cassandra.cqlengine.models import Model from cassandra.cqlengine import columns -from tests.integration import PROTOCOL_VERSION, greaterthancass20, 
requires_collection_indexes, \ +from tests.integration import greaterthancass20, requires_collection_indexes, \ MockLoggingHandler, CASSANDRA_VERSION, SCYLLA_VERSION, xfail_scylla from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration.cqlengine.query.test_queryset import TestModel @@ -464,9 +464,6 @@ def test_failure(self): class StaticColumnTests(BaseCassEngTestCase): def test_static_columns(self): - if PROTOCOL_VERSION < 2: - raise unittest.SkipTest("Native protocol 2+ required, currently using: {0}".format(PROTOCOL_VERSION)) - class StaticModel(Model): id = columns.Integer(primary_key=True) c = columns.Integer(primary_key=True) diff --git a/tests/integration/cqlengine/model/test_model_io.py b/tests/integration/cqlengine/model/test_model_io.py index f55815310a..a1ce2d6984 100644 --- a/tests/integration/cqlengine/model/test_model_io.py +++ b/tests/integration/cqlengine/model/test_model_io.py @@ -30,7 +30,7 @@ from cassandra.cqlengine.statements import SelectStatement, DeleteStatement, WhereClause from cassandra.cqlengine.operators import EqualsOperator -from tests.integration import PROTOCOL_VERSION, greaterthanorequalcass3_10 +from tests.integration import greaterthanorequalcass3_10 from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration.cqlengine import DEFAULT_KEYSPACE from tests.util import assertSetEqual @@ -248,9 +248,6 @@ def test_can_insert_model_with_all_protocol_v4_column_types(self): @test_category data_types:primitive """ - if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Protocol v4 datatypes require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION)) - class v4DatatypesModel(Model): id = columns.Integer(primary_key=True) a = columns.Date() @@ -672,25 +669,15 @@ class TestQuerying(BaseCassEngTestCase): @classmethod def setUpClass(cls): - if PROTOCOL_VERSION < 4: - return - super(TestQuerying, cls).setUpClass() drop_table(TestQueryModel) sync_table(TestQueryModel) @classmethod def tearDownClass(cls): - if PROTOCOL_VERSION < 4: - return - super(TestQuerying, cls).tearDownClass() drop_table(TestQueryModel) - def setUp(self): - if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Date query tests require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION)) - def test_query_with_date(self): uid = uuid4() day = date(2013, 11, 26) @@ -770,7 +757,7 @@ def test_routing_key_is_ignored(self): """.format(DEFAULT_KEYSPACE)) bound = prepared.bind((1, 2)) - mrk = BasicModelNoRouting._routing_key_from_values([1], self.session.cluster.protocol_version) + mrk = BasicModelNoRouting._routing_key_from_values([1]) simple = SimpleStatement("") simple.routing_key = mrk assert bound.routing_key != simple.routing_key @@ -802,7 +789,7 @@ def test_routing_key_generation_basic(self): """.format(DEFAULT_KEYSPACE)) bound = prepared.bind((1, 2)) - mrk = BasicModel._routing_key_from_values([1], self.session.cluster.protocol_version) + mrk = BasicModel._routing_key_from_values([1]) simple = SimpleStatement("") simple.routing_key = mrk assert bound.routing_key == simple.routing_key @@ -823,7 +810,7 @@ def test_routing_key_generation_multi(self): INSERT INTO {0}.basic_model_routing_multi (k, v) VALUES (?, ?) 
""".format(DEFAULT_KEYSPACE)) bound = prepared.bind((1, 2)) - mrk = BasicModelMulti._routing_key_from_values([1, 2], self.session.cluster.protocol_version) + mrk = BasicModelMulti._routing_key_from_values([1, 2]) simple = SimpleStatement("") simple.routing_key = mrk assert bound.routing_key == simple.routing_key @@ -849,7 +836,7 @@ def test_routing_key_generation_complex(self): float = 1.2 text_2 = "text_2" bound = prepared.bind((partition, cluster, count, text, float, text_2)) - mrk = ComplexModelRouting._routing_key_from_values([partition, cluster, text, float], self.session.cluster.protocol_version) + mrk = ComplexModelRouting._routing_key_from_values([partition, cluster, text, float]) simple = SimpleStatement("") simple.routing_key = mrk assert bound.routing_key == simple.routing_key diff --git a/tests/integration/cqlengine/model/test_udts.py b/tests/integration/cqlengine/model/test_udts.py index 80f1b9693f..1647856e42 100644 --- a/tests/integration/cqlengine/model/test_udts.py +++ b/tests/integration/cqlengine/model/test_udts.py @@ -25,7 +25,6 @@ from cassandra.cqlengine import ValidationError from cassandra.util import Date, Time -from tests.integration import PROTOCOL_VERSION from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration.cqlengine import DEFAULT_KEYSPACE import pytest @@ -65,10 +64,6 @@ class AllDatatypesModel(Model): class UserDefinedTypeTests(BaseCassEngTestCase): - def setUp(self): - if PROTOCOL_VERSION < 3: - raise unittest.SkipTest("UDTs require native protocol 3+, currently using: {0}".format(PROTOCOL_VERSION)) - def test_can_create_udts(self): class User(UserType): age = columns.Integer() @@ -302,9 +297,6 @@ def test_can_insert_udts_protocol_v4_datatypes(self): @test_category data_types:udt """ - if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Protocol v4 datatypes in UDTs require native protocol 4+, currently using: {0}".format(PROTOCOL_VERSION)) - class Allv4Datatypes(UserType): a = columns.Date() b = columns.SmallInt() diff --git a/tests/integration/cqlengine/query/test_queryset.py b/tests/integration/cqlengine/query/test_queryset.py index 34b4ab5964..2cae49eae5 100644 --- a/tests/integration/cqlengine/query/test_queryset.py +++ b/tests/integration/cqlengine/query/test_queryset.py @@ -38,7 +38,7 @@ from cassandra.cqlengine import operators from cassandra.util import uuid_from_time from cassandra.cqlengine.connection import get_session -from tests.integration import PROTOCOL_VERSION, CASSANDRA_VERSION, greaterthancass20, greaterthancass21, \ +from tests.integration import CASSANDRA_VERSION, greaterthancass20, greaterthancass21, \ greaterthanorequalcass30, TestCluster, requires_collection_indexes from tests.integration.cqlengine import execute_count, DEFAULT_KEYSPACE import pytest @@ -1097,9 +1097,6 @@ def test_objects_property_returns_fresh_queryset(self): class PageQueryTests(BaseCassEngTestCase): @execute_count(3) def test_paged_result_handling(self): - if PROTOCOL_VERSION < 2: - raise unittest.SkipTest("Paging requires native protocol 2+, currently using: {0}".format(PROTOCOL_VERSION)) - # addresses #225 class PagingTest(Model): id = columns.Integer(primary_key=True) diff --git a/tests/integration/cqlengine/test_ifexists.py b/tests/integration/cqlengine/test_ifexists.py index 6c2ff437ab..6313aa3000 100644 --- a/tests/integration/cqlengine/test_ifexists.py +++ b/tests/integration/cqlengine/test_ifexists.py @@ -21,7 +21,6 @@ from cassandra.cqlengine.query import BatchQuery, BatchType, LWTException, IfExistsWithCounterColumn 
from tests.integration.cqlengine.base import BaseCassEngTestCase -from tests.integration import PROTOCOL_VERSION import pytest @@ -78,7 +77,6 @@ def tearDownClass(cls): class IfExistsUpdateTests(BaseIfExistsTest): - @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_update_if_exists(self): """ Tests that update with if_exists work as expected @@ -116,7 +114,6 @@ def test_update_if_exists(self): assert assertion.value.existing.get('[applied]') == False - @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_batch_update_if_exists_success(self): """ Tests that batch update with if_exists work as expected @@ -150,7 +147,6 @@ def test_batch_update_if_exists_success(self): assert tm.count == 8 assert tm.text == '111111111' - @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_batch_mixed_update_if_exists_success(self): """ Tests that batch update with with one bad query will still fail with LWTException @@ -172,7 +168,6 @@ def test_batch_mixed_update_if_exists_success(self): assert assertion.value.existing.get('[applied]') == False - @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_delete_if_exists(self): """ Tests that delete with if_exists work, and throw proper LWT exception when they are are not applied @@ -203,7 +198,6 @@ def test_delete_if_exists(self): assert assertion.value.existing.get('[applied]') == False - @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_batch_delete_if_exists_success(self): """ Tests that batch deletes with if_exists work, and throw proper LWTException when they are are not applied @@ -232,7 +226,6 @@ def test_batch_delete_if_exists_success(self): assert assertion.value.existing.get('[applied]') == False - @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_batch_delete_mixed(self): """ Tests that batch deletes with multiple queries and throw proper LWTException when they are are not all applicable diff --git a/tests/integration/cqlengine/test_ifnotexists.py b/tests/integration/cqlengine/test_ifnotexists.py index 6a1dd9d4bc..4159a3f00c 100644 --- a/tests/integration/cqlengine/test_ifnotexists.py +++ b/tests/integration/cqlengine/test_ifnotexists.py @@ -21,7 +21,6 @@ from cassandra.cqlengine.query import BatchQuery, LWTException, IfNotExistsWithCounterColumn from tests.integration.cqlengine.base import BaseCassEngTestCase -from tests.integration import PROTOCOL_VERSION import pytest class TestIfNotExistsModel(Model): @@ -73,7 +72,6 @@ def tearDownClass(cls): class IfNotExistsInsertTests(BaseIfNotExistsTest): - @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_insert_if_not_exists(self): """ tests that insertion with if_not_exists work as expected """ @@ -101,7 +99,6 @@ def test_insert_if_not_exists(self): assert tm.count == 8 assert tm.text == '123456789' - @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_batch_insert_if_not_exists(self): """ tests that batch insertion with if_not_exists work as expected """ diff --git a/tests/integration/long/test_failure_types.py b/tests/integration/long/test_failure_types.py index beb10f02c0..8e590cb355 100644 --- a/tests/integration/long/test_failure_types.py +++ b/tests/integration/long/test_failure_types.py @@ -43,23 +43,18 @@ @local def setup_module(): """ - We 
need some custom setup for this module. All unit tests in this module - require protocol >=4. We won't bother going through the setup required unless that is the - protocol version we are using. + We need some custom setup for this module. """ - - # If we aren't at protocol v 4 or greater don't waste time setting anything up, all tests will be skipped - if PROTOCOL_VERSION >= 4: - use_singledc(start=False) - ccm_cluster = get_cluster() - ccm_cluster.stop() - config_options = { - 'tombstone_failure_threshold': 2000, - 'tombstone_warn_threshold': 1000, - } - ccm_cluster.set_configuration_options(config_options) - start_cluster_wait_for_up(ccm_cluster) - setup_keyspace() + use_singledc(start=False) + ccm_cluster = get_cluster() + ccm_cluster.stop() + config_options = { + 'tombstone_failure_threshold': 2000, + 'tombstone_warn_threshold': 1000, + } + ccm_cluster.set_configuration_options(config_options) + start_cluster_wait_for_up(ccm_cluster) + setup_keyspace() def teardown_module(): @@ -67,20 +62,12 @@ def teardown_module(): The rest of the tests don't need custom tombstones remove the cluster so as to not interfere with other tests. """ - if PROTOCOL_VERSION >= 4: - remove_cluster() + remove_cluster() class ClientExceptionTests(unittest.TestCase): def setUp(self): - """ - Test is skipped if run with native protocol version <4 - """ - if PROTOCOL_VERSION < 4: - raise unittest.SkipTest( - "Native protocol 4,0+ is required for custom payloads, currently using %r" - % (PROTOCOL_VERSION,)) self.cluster = TestCluster() self.session = self.cluster.connect() self.nodes_currently_failing = [] diff --git a/tests/integration/long/test_large_data.py b/tests/integration/long/test_large_data.py index 0a1b368bf0..c6ddaea709 100644 --- a/tests/integration/long/test_large_data.py +++ b/tests/integration/long/test_large_data.py @@ -21,7 +21,7 @@ from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT from cassandra.query import dict_factory from cassandra.query import SimpleStatement -from tests.integration import use_singledc, PROTOCOL_VERSION, TestCluster +from tests.integration import use_singledc, TestCluster from tests.integration.long.utils import create_schema import unittest diff --git a/tests/integration/long/test_loadbalancingpolicies.py b/tests/integration/long/test_loadbalancingpolicies.py index fd8edde14c..619e729b47 100644 --- a/tests/integration/long/test_loadbalancingpolicies.py +++ b/tests/integration/long/test_loadbalancingpolicies.py @@ -633,7 +633,7 @@ def test_token_aware_with_transient_replication(self): query = session.prepare("SELECT * FROM test_tr.users WHERE id = ?") for i in range(100): f = session.execute_async(query, (i,), trace=True) - full_dc1_replicas = [h for h in cluster.metadata.get_replicas('test_tr', cqltypes.Int32Type.serialize(i, cluster.protocol_version)) + full_dc1_replicas = [h for h in cluster.metadata.get_replicas('test_tr', cqltypes.Int32Type.serialize(i)) if h.datacenter == 'dc1'] assert len(full_dc1_replicas) == 2 diff --git a/tests/integration/standard/test_authentication.py b/tests/integration/standard/test_authentication.py index eb8019bf65..903c6de0d7 100644 --- a/tests/integration/standard/test_authentication.py +++ b/tests/integration/standard/test_authentication.py @@ -19,7 +19,7 @@ from cassandra.cluster import NoHostAvailable from cassandra.auth import PlainTextAuthProvider, SASLClient, SaslAuthProvider -from tests.integration import use_singledc, get_cluster, remove_cluster, PROTOCOL_VERSION, \ +from tests.integration import use_singledc, 
get_cluster, remove_cluster, \ CASSANDRA_IP, CASSANDRA_VERSION, USE_CASS_EXTERNAL, start_cluster_wait_for_up, TestCluster from tests.integration.util import assert_quiescent_pool_state @@ -62,19 +62,12 @@ class AuthenticationTests(unittest.TestCase): """ def get_authentication_provider(self, username, password): """ - Return correct authentication provider based on protocol version. - There is a difference in the semantics of authentication provider argument with protocol versions 1 and 2 - For protocol version 2 and higher it should be a PlainTextAuthProvider object. - For protocol version 1 it should be a function taking hostname as an argument and returning a dictionary - containing username and password. + Return authentication provider for connecting to the cluster. :param username: authentication username :param password: authentication password - :return: authentication object suitable for Cluster.connect() + :return: PlainTextAuthProvider for Cluster.connect() """ - if PROTOCOL_VERSION < 2: - return lambda hostname: dict(username=username, password=password) - else: - return PlainTextAuthProvider(username=username, password=password) + return PlainTextAuthProvider(username=username, password=password) def cluster_as(self, usr, pwd): # test we can connect at least once with creds @@ -161,8 +154,6 @@ class SaslAuthenticatorTests(AuthenticationTests): Test SaslAuthProvider as PlainText """ def setUp(self): - if PROTOCOL_VERSION < 2: - raise unittest.SkipTest('Sasl authentication not available for protocol v1') if SASLClient is None: raise unittest.SkipTest('pure-sasl is not installed') diff --git a/tests/integration/standard/test_client_warnings.py b/tests/integration/standard/test_client_warnings.py index 781b5b7860..83572af2df 100644 --- a/tests/integration/standard/test_client_warnings.py +++ b/tests/integration/standard/test_client_warnings.py @@ -17,7 +17,7 @@ from cassandra.query import BatchStatement -from tests.integration import (use_singledc, PROTOCOL_VERSION, local, TestCluster, +from tests.integration import (use_singledc, local, TestCluster, requires_custom_payload, xfail_scylla) from tests.util import assertRegex, assertDictEqual @@ -30,9 +30,6 @@ class ClientWarningTests(unittest.TestCase): @classmethod def setUpClass(cls): - if PROTOCOL_VERSION < 4: - return - cls.cluster = TestCluster() cls.session = cls.cluster.connect() @@ -47,17 +44,8 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - if PROTOCOL_VERSION < 4: - return - cls.cluster.shutdown() - def setUp(self): - if PROTOCOL_VERSION < 4: - raise unittest.SkipTest( - "Native protocol 4,0+ is required for client warnings, currently using %r" - % (PROTOCOL_VERSION,)) - def test_warning_basic(self): """ Test to validate that client warnings can be surfaced diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index 1208edb9d2..d9633f01dc 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -462,10 +462,6 @@ def test_refresh_schema_type(self): if get_server_versions()[0] < (2, 1, 0): raise unittest.SkipTest('UDTs were introduced in Cassandra 2.1') - if PROTOCOL_VERSION < 3: - raise unittest.SkipTest('UDTs are not specified in change events for protocol v2') - # We may want to refresh types on keyspace change events in that case(?) 
- cluster = TestCluster() session = cluster.connect() @@ -1093,21 +1089,21 @@ def test_stale_connections_after_shutdown(self): Originates from https://github.com/scylladb/python-driver/issues/120 """ for _ in range(10): - with TestCluster(protocol_version=3) as cluster: + with TestCluster(protocol_version=4) as cluster: cluster.connect().execute("SELECT * FROM system_schema.keyspaces") time.sleep(1) - with TestCluster(protocol_version=3) as cluster: + with TestCluster(protocol_version=4) as cluster: session = cluster.connect() for _ in range(5): session.execute("SELECT * FROM system_schema.keyspaces") for _ in range(10): - with TestCluster(protocol_version=3) as cluster: + with TestCluster(protocol_version=4) as cluster: cluster.connect().execute("SELECT * FROM system_schema.keyspaces") for _ in range(10): - with TestCluster(protocol_version=3) as cluster: + with TestCluster(protocol_version=4) as cluster: cluster.connect() result = subprocess.run(["lsof -nP | awk '$3 ~ \":9042\" {print $0}' | grep ''"], shell=True, capture_output=True) diff --git a/tests/integration/standard/test_concurrent.py b/tests/integration/standard/test_concurrent.py index 5e6b1ffd59..5d464b9b23 100644 --- a/tests/integration/standard/test_concurrent.py +++ b/tests/integration/standard/test_concurrent.py @@ -22,7 +22,7 @@ from cassandra.policies import HostDistance from cassandra.query import dict_factory, tuple_factory, SimpleStatement -from tests.integration import use_singledc, PROTOCOL_VERSION, TestCluster +from tests.integration import use_singledc, TestCluster import unittest import pytest @@ -178,11 +178,6 @@ def test_execute_concurrent_with_args_generator(self): next(results) def test_execute_concurrent_paged_result(self): - if PROTOCOL_VERSION < 2: - raise unittest.SkipTest( - "Protocol 2+ is required for Paging, currently testing against %r" - % (PROTOCOL_VERSION,)) - num_statements = 201 statement = SimpleStatement( "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", @@ -221,11 +216,6 @@ def test_execute_concurrent_paged_result_generator(self): @test_category paging """ - if PROTOCOL_VERSION < 2: - raise unittest.SkipTest( - "Protocol 2+ is required for Paging, currently testing against %r" - % (PROTOCOL_VERSION,)) - num_statements = 201 statement = SimpleStatement( "INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", diff --git a/tests/integration/standard/test_control_connection.py b/tests/integration/standard/test_control_connection.py index 2788a1d837..52d3f08cc4 100644 --- a/tests/integration/standard/test_control_connection.py +++ b/tests/integration/standard/test_control_connection.py @@ -23,7 +23,7 @@ from cassandra.protocol import ConfigurationException -from tests.integration import use_singledc, PROTOCOL_VERSION, TestCluster, greaterthanorequalcass40, \ +from tests.integration import use_singledc, TestCluster, greaterthanorequalcass40, \ xfail_scylla_version_lt from tests.integration.datatype_utils import update_datatypes @@ -35,10 +35,6 @@ def setup_module(): class ControlConnectionTests(unittest.TestCase): def setUp(self): - if PROTOCOL_VERSION < 3: - raise unittest.SkipTest( - "Native protocol 3,0+ is required for UDTs using %r" - % (PROTOCOL_VERSION,)) self.cluster = TestCluster() def tearDown(self): diff --git a/tests/integration/standard/test_custom_payload.py b/tests/integration/standard/test_custom_payload.py index fc58081070..2179c4225d 100644 --- a/tests/integration/standard/test_custom_payload.py +++ b/tests/integration/standard/test_custom_payload.py @@ -17,7 +17,7 @@ from 
cassandra.query import (SimpleStatement, BatchStatement, BatchType) -from tests.integration import (use_singledc, PROTOCOL_VERSION, local, TestCluster, +from tests.integration import (use_singledc, local, TestCluster, requires_custom_payload) import pytest diff --git a/tests/integration/standard/test_custom_protocol_handler.py b/tests/integration/standard/test_custom_protocol_handler.py index 239f7e7336..1273f834ad 100644 --- a/tests/integration/standard/test_custom_protocol_handler.py +++ b/tests/integration/standard/test_custom_protocol_handler.py @@ -203,7 +203,7 @@ class CustomResultMessageRaw(ResultMessage): my_type_codes[0xc] = UUIDType type_codes = my_type_codes - def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): + def recv_results_rows(self, f, user_type_map, result_metadata, column_encryption_policy): self.recv_results_metadata(f, user_type_map) column_metadata = self.column_metadata or result_metadata rowcount = read_int(f) @@ -232,7 +232,7 @@ class CustomResultMessageTracked(ResultMessage): type_codes = my_type_codes checked_rev_row_set = set() - def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, column_encryption_policy): + def recv_results_rows(self, f, user_type_map, result_metadata, column_encryption_policy): self.recv_results_metadata(f, user_type_map) column_metadata = self.column_metadata or result_metadata rowcount = read_int(f) @@ -241,7 +241,7 @@ def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, self.column_types = [c[3] for c in column_metadata] self.checked_rev_row_set.update(self.column_types) self.parsed_rows = [ - tuple(ctype.from_binary(val, protocol_version) + tuple(ctype.from_binary(val) for ctype, val in zip(self.column_types, row)) for row in rows] diff --git a/tests/integration/standard/test_metadata.py b/tests/integration/standard/test_metadata.py index c30e369d83..46e70c500a 100644 --- a/tests/integration/standard/test_metadata.py +++ b/tests/integration/standard/test_metadata.py @@ -36,7 +36,7 @@ from cassandra.protocol import QueryMessage, ProtocolHandler from cassandra.util import SortedSet -from tests.integration import (get_cluster, use_singledc, PROTOCOL_VERSION, execute_until_pass, +from tests.integration import (get_cluster, use_singledc, execute_until_pass, BasicSegregatedKeyspaceUnitTestCase, BasicSharedKeyspaceUnitTestCase, BasicExistingKeyspaceUnitTestCase, drop_keyspace_shutdown_cluster, CASSANDRA_VERSION, greaterthanorequalcass30, lessthancass30, local, @@ -623,34 +623,32 @@ def test_refresh_schema_metadata(self): cluster2.refresh_schema_metadata() assert "c" in cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns - if PROTOCOL_VERSION >= 3: - # UDT metadata modification - self.session.execute("CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name)) - assert cluster2.metadata.keyspaces[self.keyspace_name].user_types == {} - cluster2.refresh_schema_metadata() - assert "user" in cluster2.metadata.keyspaces[self.keyspace_name].user_types + # UDT metadata modification + self.session.execute("CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name)) + assert cluster2.metadata.keyspaces[self.keyspace_name].user_types == {} + cluster2.refresh_schema_metadata() + assert "user" in cluster2.metadata.keyspaces[self.keyspace_name].user_types - if PROTOCOL_VERSION >= 4: - # UDF metadata modification - self.session.execute("""CREATE FUNCTION {0}.sum_int(key int, val int) - RETURNS NULL ON 
NULL INPUT - RETURNS int - LANGUAGE java AS 'return key+val;';""".format(self.keyspace_name)) + # UDF metadata modification + self.session.execute("""CREATE FUNCTION {0}.sum_int(key int, val int) + RETURNS NULL ON NULL INPUT + RETURNS int + LANGUAGE java AS 'return key+val;';""".format(self.keyspace_name)) - assert cluster2.metadata.keyspaces[self.keyspace_name].functions == {} - cluster2.refresh_schema_metadata() - assert "sum_int(int,int)" in cluster2.metadata.keyspaces[self.keyspace_name].functions + assert cluster2.metadata.keyspaces[self.keyspace_name].functions == {} + cluster2.refresh_schema_metadata() + assert "sum_int(int,int)" in cluster2.metadata.keyspaces[self.keyspace_name].functions - # UDA metadata modification - self.session.execute("""CREATE AGGREGATE {0}.sum_agg(int) - SFUNC sum_int - STYPE int - INITCOND 0""" - .format(self.keyspace_name)) + # UDA metadata modification + self.session.execute("""CREATE AGGREGATE {0}.sum_agg(int) + SFUNC sum_int + STYPE int + INITCOND 0""" + .format(self.keyspace_name)) - assert cluster2.metadata.keyspaces[self.keyspace_name].aggregates == {} - cluster2.refresh_schema_metadata() - assert "sum_agg(int)" in cluster2.metadata.keyspaces[self.keyspace_name].aggregates + assert cluster2.metadata.keyspaces[self.keyspace_name].aggregates == {} + cluster2.refresh_schema_metadata() + assert "sum_agg(int)" in cluster2.metadata.keyspaces[self.keyspace_name].aggregates # Cluster metadata modification self.session.execute("DROP KEYSPACE new_keyspace") @@ -799,9 +797,6 @@ def test_refresh_user_type_metadata(self): @test_category metadata """ - if PROTOCOL_VERSION < 3: - raise unittest.SkipTest("Protocol 3+ is required for UDTs, currently testing against {0}".format(PROTOCOL_VERSION)) - cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() @@ -868,9 +863,6 @@ def test_refresh_user_function_metadata(self): @test_category metadata """ - if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Protocol 4+ is required for UDFs, currently testing against {0}".format(PROTOCOL_VERSION)) - cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() @@ -905,9 +897,6 @@ def test_refresh_user_aggregate_metadata(self): @test_category metadata """ - if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Protocol 4+ is required for UDAs, currently testing against {0}".format(PROTOCOL_VERSION)) - cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() @@ -1125,11 +1114,6 @@ def test_export_keyspace_schema_udts(self): Test udt exports """ - if PROTOCOL_VERSION < 3: - raise unittest.SkipTest( - "Protocol 3.0+ is required for UDT change events, currently testing against %r" - % (PROTOCOL_VERSION,)) - if sys.version_info[0:2] != (2, 7): raise unittest.SkipTest('This test compares static strings generated from dict items, which may change orders. 
Test with 2.7.') @@ -1523,12 +1507,7 @@ class FunctionTest(unittest.TestCase): """ def setUp(self): - """ - Tests are skipped if run with native protocol version < 4 - """ - - if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Function metadata requires native protocol version 4+") + pass @property def function_name(self): @@ -1536,20 +1515,18 @@ def function_name(self): @classmethod def setup_class(cls): - if PROTOCOL_VERSION >= 4: - cls.cluster = TestCluster() - cls.keyspace_name = cls.__name__.lower() - cls.session = cls.cluster.connect() - cls.session.execute("CREATE KEYSPACE IF NOT EXISTS %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" % cls.keyspace_name) - cls.session.set_keyspace(cls.keyspace_name) - cls.keyspace_function_meta = cls.cluster.metadata.keyspaces[cls.keyspace_name].functions - cls.keyspace_aggregate_meta = cls.cluster.metadata.keyspaces[cls.keyspace_name].aggregates + cls.cluster = TestCluster() + cls.keyspace_name = cls.__name__.lower() + cls.session = cls.cluster.connect() + cls.session.execute("CREATE KEYSPACE IF NOT EXISTS %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}" % cls.keyspace_name) + cls.session.set_keyspace(cls.keyspace_name) + cls.keyspace_function_meta = cls.cluster.metadata.keyspaces[cls.keyspace_name].functions + cls.keyspace_aggregate_meta = cls.cluster.metadata.keyspaces[cls.keyspace_name].aggregates @classmethod def teardown_class(cls): - if PROTOCOL_VERSION >= 4: - cls.session.execute("DROP KEYSPACE IF EXISTS %s" % cls.keyspace_name) - cls.cluster.shutdown() + cls.session.execute("DROP KEYSPACE IF EXISTS %s" % cls.keyspace_name) + cls.cluster.shutdown() class Verified(object): @@ -1749,34 +1726,33 @@ class AggregateMetadata(FunctionTest): @classmethod def setup_class(cls): - if PROTOCOL_VERSION >= 4: - super(AggregateMetadata, cls).setup_class() + super(AggregateMetadata, cls).setup_class() - cls.session.execute("""CREATE OR REPLACE FUNCTION sum_int(s int, i int) + cls.session.execute("""CREATE OR REPLACE FUNCTION sum_int(s int, i int) RETURNS NULL ON NULL INPUT RETURNS int LANGUAGE java AS 'return s + i;';""") - cls.session.execute("""CREATE OR REPLACE FUNCTION sum_int_two(s int, i int, j int) + cls.session.execute("""CREATE OR REPLACE FUNCTION sum_int_two(s int, i int, j int) RETURNS NULL ON NULL INPUT RETURNS int LANGUAGE java AS 'return s + i + j;';""") - cls.session.execute("""CREATE OR REPLACE FUNCTION "List_As_String"(l list<text>) + cls.session.execute("""CREATE OR REPLACE FUNCTION "List_As_String"(l list<text>) RETURNS NULL ON NULL INPUT RETURNS int LANGUAGE java AS 'return l.size();';""") - cls.session.execute("""CREATE OR REPLACE FUNCTION extend_list(s list<text>, i int) + cls.session.execute("""CREATE OR REPLACE FUNCTION extend_list(s list<text>, i int) CALLED ON NULL INPUT RETURNS list<text> LANGUAGE java AS 'if (i != null) s.add(i.toString()); return s;';""") - cls.session.execute("""CREATE OR REPLACE FUNCTION update_map(s map<int, int>, i int) + cls.session.execute("""CREATE OR REPLACE FUNCTION update_map(s map<int, int>, i int) RETURNS NULL ON NULL INPUT RETURNS map<int, int> LANGUAGE java AS 's.put(new Integer(i), new Integer(i)); return s;';""") - cls.session.execute("""CREATE TABLE IF NOT EXISTS t + cls.session.execute("""CREATE TABLE IF NOT EXISTS t (k int PRIMARY KEY, v int)""") - for x in range(4): - cls.session.execute("INSERT INTO t (k,v) VALUES (%s, %s)", (x, x)) - cls.session.execute("INSERT INTO t (k) VALUES (%s)", (4,)) + for x in range(4): + cls.session.execute("INSERT INTO t (k,v) VALUES (%s, %s)", (x, x)) + 
cls.session.execute("INSERT INTO t (k) VALUES (%s)", (4,)) def make_aggregate_kwargs(self, state_func, state_type, final_func=None, init_cond=None): return {'keyspace': self.keyspace_name, diff --git a/tests/integration/standard/test_prepared_statements.py b/tests/integration/standard/test_prepared_statements.py index 3f63b881ef..c825b1c6bc 100644 --- a/tests/integration/standard/test_prepared_statements.py +++ b/tests/integration/standard/test_prepared_statements.py @@ -199,23 +199,14 @@ def test_imprecise_bind_values_dicts(self): prepared.bind({'k': 1, 'v': 2, 'v2': 3}) # right number, but one does not belong - if PROTOCOL_VERSION < 4: - # pre v4, the driver bails with key error when 'v' is found missing - with pytest.raises(KeyError): - prepared.bind({'k': 1, 'v2': 3}) - else: - # post v4, the driver uses UNSET_VALUE for 'v' and 'v2' is ignored - prepared.bind({'k': 1, 'v2': 3}) + # the driver uses UNSET_VALUE for 'v' and 'v2' is ignored + prepared.bind({'k': 1, 'v2': 3}) # also catch too few variables with dicts assert isinstance(prepared, PreparedStatement) - if PROTOCOL_VERSION < 4: - with pytest.raises(KeyError): - prepared.bind({}) - else: - # post v4, the driver attempts to use UNSET_VALUE for unspecified keys - with pytest.raises(ValueError): - prepared.bind({}) + # the driver attempts to use UNSET_VALUE for unspecified keys + with pytest.raises(ValueError): + prepared.bind({}) def test_none_values(self): """ @@ -255,8 +246,6 @@ def test_unset_values(self): @test_category prepared_statements:binding """ - if PROTOCOL_VERSION < 4: - raise unittest.SkipTest("Binding UNSET values is not supported in protocol version < 4") # table with at least two values so one can be used as a marker self.session.execute("CREATE TABLE IF NOT EXISTS test1rf.test_unset_values (k int PRIMARY KEY, v0 int, v1 int)") diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index 9cebc22b05..049adf2a58 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -24,8 +24,8 @@ BatchStatement, BatchType, dict_factory, TraceUnavailable) from cassandra.cluster import NoHostAvailable, ExecutionProfile, EXEC_PROFILE_DEFAULT, Cluster from cassandra.policies import HostDistance, RoundRobinPolicy, WhiteListRoundRobinPolicy -from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, \ - greaterthanprotocolv3, MockLoggingHandler, get_supported_protocol_versions, local, get_cluster, setup_keyspace, \ +from tests.integration import use_singledc, BasicSharedKeyspaceUnitTestCase, \ + MockLoggingHandler, get_supported_protocol_versions, local, get_cluster, setup_keyspace, \ USE_CASS_EXTERNAL, greaterthanorequalcass40, TestCluster, xfail_scylla from tests import notwindows from tests.integration import greaterthanorequalcass30, get_node @@ -136,7 +136,6 @@ def test_trace_ignores_row_factory(self): str(event) @local - @greaterthanprotocolv3 def test_client_ip_in_trace(self): """ Test to validate that client trace contains client ip information. 
@@ -659,11 +658,6 @@ def test_prepared_statement(self): class BatchStatementTests(BasicSharedKeyspaceUnitTestCase): def setUp(self): - if PROTOCOL_VERSION < 2: - raise unittest.SkipTest( - "Protocol 2.0+ is required for BATCH operations, currently testing against %r" - % (PROTOCOL_VERSION,)) - self.cluster = TestCluster() self.session = self.cluster.connect(wait_for_all_pools=True) @@ -793,11 +787,6 @@ def test_too_many_statements(self): class SerialConsistencyTests(unittest.TestCase): def setUp(self): - if PROTOCOL_VERSION < 2: - raise unittest.SkipTest( - "Protocol 2.0+ is required for Serial Consistency, currently testing against %r" - % (PROTOCOL_VERSION,)) - self.cluster = TestCluster() self.session = self.cluster.connect() @@ -880,15 +869,6 @@ def test_bad_consistency_level(self): class LightweightTransactionTests(unittest.TestCase): def setUp(self): - """ - Test is skipped if run with cql version < 2 - - """ - if PROTOCOL_VERSION < 2: - raise unittest.SkipTest( - "Protocol 2.0+ is required for Lightweight transactions, currently testing against %r" - % (PROTOCOL_VERSION,)) - serial_profile = ExecutionProfile(consistency_level=ConsistencyLevel.SERIAL) self.cluster = TestCluster(execution_profiles={'serial': serial_profile}) self.session = self.cluster.connect() @@ -1072,10 +1052,6 @@ class BatchStatementDefaultRoutingKeyTests(unittest.TestCase): # Test for PYTHON-126: BatchStatement.add() should set the routing key of the first added prepared statement def setUp(self): - if PROTOCOL_VERSION < 2: - raise unittest.SkipTest( - "Protocol 2.0+ is required for BATCH operations, currently testing against %r" - % (PROTOCOL_VERSION,)) self.cluster = TestCluster() self.session = self.cluster.connect() query = """ diff --git a/tests/integration/standard/test_query_paging.py b/tests/integration/standard/test_query_paging.py index e0c67cd309..6694c0c284 100644 --- a/tests/integration/standard/test_query_paging.py +++ b/tests/integration/standard/test_query_paging.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from tests.integration import use_singledc, PROTOCOL_VERSION, TestCluster +from tests.integration import use_singledc, TestCluster import logging log = logging.getLogger(__name__) @@ -38,11 +38,6 @@ def setup_module(): class QueryPagingTests(unittest.TestCase): def setUp(self): - if PROTOCOL_VERSION < 2: - raise unittest.SkipTest( - "Protocol 2.0+ is required for Paging state, currently testing against %r" - % (PROTOCOL_VERSION,)) - self.cluster = TestCluster( execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(consistency_level=ConsistencyLevel.LOCAL_QUORUM)} ) diff --git a/tests/integration/standard/test_single_interface.py b/tests/integration/standard/test_single_interface.py index 5fd9ef45d3..22776f4232 100644 --- a/tests/integration/standard/test_single_interface.py +++ b/tests/integration/standard/test_single_interface.py @@ -19,7 +19,7 @@ from cassandra.query import SimpleStatement from packaging.version import Version -from tests.integration import use_singledc, PROTOCOL_VERSION, \ +from tests.integration import use_singledc, \ remove_cluster, greaterthanorequalcass40, \ CASSANDRA_VERSION, TestCluster, DEFAULT_SINGLE_INTERFACE_PORT diff --git a/tests/integration/util.py b/tests/integration/util.py index 7cbdfdb22d..5f091d9508 100644 --- a/tests/integration/util.py +++ b/tests/integration/util.py @@ -14,7 +14,6 @@ from itertools import chain -from tests.integration import PROTOCOL_VERSION import time @@ -53,5 +52,3 @@ def assert_quiescent_pool_state(cluster, wait=None): assert len(req_ids) == len(set(req_ids)) assert connection.highest_request_id == len(req_ids) + len(orphan_ids) - 1 assert connection.highest_request_id == max(chain(req_ids, orphan_ids)) - if PROTOCOL_VERSION < 3: - assert connection.highest_request_id == connection.max_request_id diff --git a/tests/unit/advanced/test_geometry.py b/tests/unit/advanced/test_geometry.py index 1927b51da7..b1b6efc918 100644 --- a/tests/unit/advanced/test_geometry.py +++ b/tests/unit/advanced/test_geometry.py @@ -33,15 +33,13 @@ class GeoTypes(unittest.TestCase): samples = (Point(1, 2), LineString(((1, 2), (3, 4), (5, 6))), Polygon([(10.1, 10.0), (110.0, 10.0), (110., 110.0), (10., 110.0), (10., 10.0)], [[(20., 20.0), (20., 30.0), (30., 30.0), (30., 20.0), (20., 20.0)], [(40., 20.0), (40., 30.0), (50., 30.0), (50., 20.0), (40., 20.0)]])) def test_marshal_platform(self): - for proto_ver in protocol_versions: - for geo in self.samples: - cql_type = lookup_casstype(geo.__class__.__name__ + 'Type') - assert cql_type.from_binary(cql_type.to_binary(geo, proto_ver), proto_ver) == geo + for geo in self.samples: + cql_type = lookup_casstype(geo.__class__.__name__ + 'Type') + assert cql_type.from_binary(cql_type.to_binary(geo)) == geo def _verify_both_endian(self, typ, body_fmt, params, expected): - for proto_ver in protocol_versions: - assert typ.from_binary(struct.pack(">BI" + body_fmt, wkb_be, *params), proto_ver) == expected - assert typ.from_binary(struct.pack("<BI" + body_fmt, wkb_le, *params), proto_ver) == expected + assert typ.from_binary(struct.pack(">BI" + body_fmt, wkb_be, *params)) == expected + assert typ.from_binary(struct.pack("<BI" + body_fmt, wkb_le, *params)) == expected o\xff\x00', b'\x01\x00\x00\x00\xdcm\x03-\xd1\x06\x00\x00\x01v\xbb>o\xff\x00'] for serialized in vals: - assert serialized == DateRangeType.serialize(DateRangeType.deserialize(serialized, 0), 0) + assert serialized == DateRangeType.serialize(DateRangeType.deserialize(serialized)) def test_serialize_zero_datetime(self): """ @@ -734,7 +734,7 @@ def test_serialize_zero_datetime(self): DateRangeType.serialize(util.DateRange( lower_bound=(datetime.datetime(1970, 1, 1), 'YEAR'), 
upper_bound=(datetime.datetime(1970, 1, 1), 'YEAR') - ), 5) + )) def test_deserialize_zero_datetime(self): """ @@ -751,8 +751,7 @@ def test_deserialize_zero_datetime(self): DateRangeType.deserialize( (int8_pack(1) + int64_pack(0) + int8_pack(0) + - int64_pack(0) + int8_pack(0)), - 5 + int64_pack(0) + int8_pack(0)) )
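For reference, the unit-test hunks above all converge on the same call shape: with the protocol_version parameter gone, serialize/deserialize and to_binary/from_binary take only the value or buffer. A minimal sketch of that round trip, assuming the import paths used in the touched tests (cassandra.cqltypes.Int32Type and lookup_casstype) are unchanged::

    from cassandra.cqltypes import Int32Type, lookup_casstype

    # serialize/deserialize no longer take a protocol version argument
    blob = Int32Type.serialize(42)
    assert Int32Type.deserialize(blob) == 42

    # the same single-argument shape applies to classes resolved via lookup_casstype,
    # mirroring the updated round-trip assertions in test_geometry.py
    ctype = lookup_casstype('Int32Type')
    assert ctype.from_binary(ctype.to_binary(7)) == 7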
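Similarly, the reworked authentication tests assume credentials always arrive as an AuthProvider object rather than a per-host callable. A minimal connection sketch under that assumption (the contact point and credentials below are placeholders)::

    from cassandra.cluster import Cluster
    from cassandra.auth import PlainTextAuthProvider

    # credentials are supplied through a provider instance such as PlainTextAuthProvider;
    # no protocol_version argument is needed when constructing the Cluster
    auth_provider = PlainTextAuthProvider(username='cassandra', password='cassandra')
    cluster = Cluster(['127.0.0.1'], auth_provider=auth_provider)
    session = cluster.connect()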