import socket
from collections import defaultdict, deque
from datetime import datetime as Datetime
from decimal import Decimal
from distutils.version import LooseVersion
from hashlib import md5
from itertools import count
from struct import Struct
from pg8000 import converters
from pg8000.exceptions import DatabaseError, InterfaceError
import scramp
def pack_funcs(fmt):
struc = Struct('!' + fmt)
return struc.pack, struc.unpack_from
i_pack, i_unpack = pack_funcs('i')
h_pack, h_unpack = pack_funcs('h')
ii_pack, ii_unpack = pack_funcs('ii')
ihihih_pack, ihihih_unpack = pack_funcs('ihihih')
ci_pack, ci_unpack = pack_funcs('ci')
bh_pack, bh_unpack = pack_funcs('bh')
cccc_pack, cccc_unpack = pack_funcs('cccc')
# Copyright (c) 2007-2009, Mathieu Fenniak
# Copyright (c) The Contributors
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * The name of the author may not be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
__author__ = "Mathieu Fenniak"
min_int2, max_int2 = -2 ** 15, 2 ** 15
min_int4, max_int4 = -2 ** 31, 2 ** 31
min_int8, max_int8 = -2 ** 63, 2 ** 63
NULL_BYTE = b'\x00'
# Message codes
NOTICE_RESPONSE = b"N"
AUTHENTICATION_REQUEST = b"R"
PARAMETER_STATUS = b"S"
BACKEND_KEY_DATA = b"K"
READY_FOR_QUERY = b"Z"
ROW_DESCRIPTION = b"T"
ERROR_RESPONSE = b"E"
DATA_ROW = b"D"
COMMAND_COMPLETE = b"C"
PARSE_COMPLETE = b"1"
BIND_COMPLETE = b"2"
CLOSE_COMPLETE = b"3"
PORTAL_SUSPENDED = b"s"
NO_DATA = b"n"
PARAMETER_DESCRIPTION = b"t"
NOTIFICATION_RESPONSE = b"A"
COPY_DONE = b"c"
COPY_DATA = b"d"
COPY_IN_RESPONSE = b"G"
COPY_OUT_RESPONSE = b"H"
EMPTY_QUERY_RESPONSE = b"I"
BIND = b"B"
PARSE = b"P"
QUERY = b"Q"
EXECUTE = b"E"
FLUSH = b'H'
SYNC = b'S'
PASSWORD = b'p'
DESCRIBE = b'D'
TERMINATE = b'X'
CLOSE = b'C'
def create_message(code, data=b''):
return code + i_pack(len(data) + 4) + data
FLUSH_MSG = create_message(FLUSH)
SYNC_MSG = create_message(SYNC)
TERMINATE_MSG = create_message(TERMINATE)
COPY_DONE_MSG = create_message(COPY_DONE)
EXECUTE_MSG = create_message(EXECUTE, NULL_BYTE + i_pack(0))
# DESCRIBE constants
STATEMENT = b'S'
PORTAL = b'P'
# ErrorResponse codes
RESPONSE_SEVERITY = "S" # always present
RESPONSE_SEVERITY = "V" # always present
RESPONSE_CODE = "C" # always present
RESPONSE_MSG = "M" # always present
RESPONSE_DETAIL = "D"
RESPONSE_HINT = "H"
RESPONSE_POSITION = "P"
RESPONSE__POSITION = "p"
RESPONSE__QUERY = "q"
RESPONSE_WHERE = "W"
RESPONSE_FILE = "F"
RESPONSE_LINE = "L"
RESPONSE_ROUTINE = "R"
IDLE = b"I"
IDLE_IN_TRANSACTION = b"T"
IDLE_IN_FAILED_TRANSACTION = b"E"
class CoreConnection():
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def __init__(
self, user, host='localhost', database=None, port=5432,
password=None, source_address=None, unix_sock=None,
ssl_context=None, timeout=None, tcp_keepalive=True,
application_name=None, replication=None):
self._client_encoding = "utf8"
self._commands_with_count = (
b"INSERT", b"DELETE", b"UPDATE", b"MOVE", b"FETCH", b"COPY",
b"SELECT")
self.notifications = deque(maxlen=100)
self.notices = deque(maxlen=100)
self.parameter_statuses = deque(maxlen=100)
if user is None:
raise InterfaceError(
"The 'user' connection parameter cannot be None")
init_params = {
'user': user,
'database': database,
'application_name': application_name,
'replication': replication
}
for k, v in tuple(init_params.items()):
if isinstance(v, str):
init_params[k] = v.encode('utf8')
elif v is None:
del init_params[k]
elif not isinstance(v, (bytes, bytearray)):
raise InterfaceError(
"The parameter " + k + " can't be of type " +
str(type(v)) + ".")
self.user = init_params['user']
if isinstance(password, str):
self.password = password.encode('utf8')
else:
self.password = password
self.autocommit = False
self._xid = None
self._statement_nums = set()
self._caches = {}
if unix_sock is None and host is not None:
try:
self._usock = socket.create_connection(
(host, port), timeout, source_address)
except socket.error as e:
raise InterfaceError(
"Can't create a connection to host " + str(host) +
" and port " + str(port) + " (timeout is " + str(timeout) +
" and source_address is " +
str(source_address) + ").") from e
elif unix_sock is not None:
try:
if not hasattr(socket, "AF_UNIX"):
raise InterfaceError(
"attempt to connect to unix socket on unsupported "
"platform")
self._usock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self._usock.settimeout(timeout)
self._usock.connect(unix_sock)
except socket.error as e:
if self._usock is not None:
self._usock.close()
raise InterfaceError("communication error") from e
else:
raise InterfaceError("one of host or unix_sock must be provided")
if tcp_keepalive:
self._usock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
self.channel_binding = None
if ssl_context is not None:
try:
import ssl
if ssl_context is True:
ssl_context = ssl.create_default_context()
request_ssl = getattr(ssl_context, 'request_ssl', True)
if request_ssl:
# Int32(8) - Message length, including self.
# Int32(80877103) - The SSL request code.
self._usock.sendall(ii_pack(8, 80877103))
resp = self._usock.recv(1)
if resp != b'S':
raise InterfaceError("Server refuses SSL")
self._usock = ssl_context.wrap_socket(
self._usock, server_hostname=host)
if request_ssl:
self.channel_binding = scramp.make_channel_binding(
'tls-server-end-point', self._usock)
except ImportError:
raise InterfaceError(
"SSL required but ssl module not available in this python "
"installation.")
self._sock = self._usock.makefile(mode="rwb")
self._flush = self._sock.flush
self._read = self._sock.read
self._write = self._sock.write
self._backend_key_data = None
self.pg_types = defaultdict(
lambda: converters.text_in, converters.PG_TYPES)
self.py_types = dict(converters.PY_TYPES)
self.inspect_funcs = {
Datetime: self.inspect_datetime,
list: self.array_inspect,
tuple: self.array_inspect,
int: self.inspect_int
}
self.message_types = {
NOTICE_RESPONSE: self.handle_NOTICE_RESPONSE,
AUTHENTICATION_REQUEST: self.handle_AUTHENTICATION_REQUEST,
PARAMETER_STATUS: self.handle_PARAMETER_STATUS,
BACKEND_KEY_DATA: self.handle_BACKEND_KEY_DATA,
READY_FOR_QUERY: self.handle_READY_FOR_QUERY,
ROW_DESCRIPTION: self.handle_ROW_DESCRIPTION,
ERROR_RESPONSE: self.handle_ERROR_RESPONSE,
EMPTY_QUERY_RESPONSE: self.handle_EMPTY_QUERY_RESPONSE,
DATA_ROW: self.handle_DATA_ROW,
COMMAND_COMPLETE: self.handle_COMMAND_COMPLETE,
PARSE_COMPLETE: self.handle_PARSE_COMPLETE,
BIND_COMPLETE: self.handle_BIND_COMPLETE,
CLOSE_COMPLETE: self.handle_CLOSE_COMPLETE,
PORTAL_SUSPENDED: self.handle_PORTAL_SUSPENDED,
NO_DATA: self.handle_NO_DATA,
PARAMETER_DESCRIPTION: self.handle_PARAMETER_DESCRIPTION,
NOTIFICATION_RESPONSE: self.handle_NOTIFICATION_RESPONSE,
COPY_DONE: self.handle_COPY_DONE,
COPY_DATA: self.handle_COPY_DATA,
COPY_IN_RESPONSE: self.handle_COPY_IN_RESPONSE,
COPY_OUT_RESPONSE: self.handle_COPY_OUT_RESPONSE}
# Int32 - Message length, including self.
# Int32(196608) - Protocol version number. Version 3.0.
# Any number of key/value pairs, terminated by a zero byte:
# String - A parameter name (user, database, or options)
# String - Parameter value
protocol = 196608
val = bytearray(i_pack(protocol))
for k, v in init_params.items():
val.extend(k.encode('ascii') + NULL_BYTE + v + NULL_BYTE)
val.append(0)
self._write(i_pack(len(val) + 4))
self._write(val)
self._flush()
code = self.error = None
while code not in (READY_FOR_QUERY, ERROR_RESPONSE):
code, data_len = ci_unpack(self._read(5))
self.message_types[code](self._read(data_len - 4), None)
if self.error is not None:
raise self.error
self.in_transaction = False
def register_out_adapter(self, typ, oid, out_func):
self.py_types[typ] = (oid, out_func)
def register_in_adapter(self, oid, in_func):
self.pg_types[oid] = in_func
def handle_ERROR_RESPONSE(self, data, ps):
msg = dict(
(
s[:1].decode(self._client_encoding),
s[1:].decode(self._client_encoding)) for s in
data.split(NULL_BYTE) if s != b'')
self.error = DatabaseError(msg)
def handle_EMPTY_QUERY_RESPONSE(self, data, ps):
self.error = DatabaseError("query was empty")
def handle_CLOSE_COMPLETE(self, data, ps):
pass
def handle_PARSE_COMPLETE(self, data, ps):
# Byte1('1') - Identifier.
# Int32(4) - Message length, including self.
pass
def handle_BIND_COMPLETE(self, data, ps):
pass
def handle_PORTAL_SUSPENDED(self, data, context):
pass
def handle_PARAMETER_DESCRIPTION(self, data, context):
'''
ParameterDescription (B)
Byte1('t')
Identifies the message as a parameter description.
Int32
Length of message contents in bytes, including self.
Int16
The number of parameters used by the statement (can be zero).
Then, for each parameter, there is the following:
Int32
Specifies the object ID of the parameter data type.
'''
# count = h_unpack(data)[0]
# context.parameter_oids = unpack_from("!" + "i" * count, data, 2)
def handle_COPY_DONE(self, data, ps):
self._copy_done = True
def handle_COPY_OUT_RESPONSE(self, data, ps):
# Int8(1) - 0 textual, 1 binary
# Int16(2) - Number of columns
# Int16(N) - Format codes for each column (0 text, 1 binary)
is_binary, num_cols = bh_unpack(data)
# column_formats = unpack_from('!' + 'h' * num_cols, data, 3)
if ps.stream is None:
raise InterfaceError(
"An output stream is required for the COPY OUT response.")
def handle_COPY_DATA(self, data, results):
results.stream.write(data)
def handle_COPY_IN_RESPONSE(self, data, results):
# Int16(2) - Number of columns
# Int16(N) - Format codes for each column (0 text, 1 binary)
is_binary, num_cols = bh_unpack(data)
# column_formats = unpack_from('!' + 'h' * num_cols, data, 3)
if results.stream is None:
raise InterfaceError(
"An input stream is required for the COPY IN response.")
bffr = bytearray(8192)
while True:
bytes_read = results.stream.readinto(bffr)
if bytes_read == 0:
break
self._write(COPY_DATA + i_pack(bytes_read + 4))
self._write(bffr[:bytes_read])
self._flush()
# Send CopyDone
# Byte1('c') - Identifier.
# Int32(4) - Message length, including self.
self._write(COPY_DONE_MSG)
self._write(SYNC_MSG)
self._flush()
def handle_NOTIFICATION_RESPONSE(self, data, results):
##
# A message sent if this connection receives a NOTIFY that it was
# LISTENing for.
#
# Stability: Added in pg8000 v1.03. When limited to accessing
# properties from a notification event dispatch, stability is
# guaranteed for v1.xx.
backend_pid = i_unpack(data)[0]
idx = 4
null_idx = data.find(NULL_BYTE, idx)
channel = data[idx:null_idx].decode("ascii")
payload = data[null_idx+1:-1].decode('ascii')
self.notifications.append((backend_pid, channel, payload))
def close(self):
"""Closes the database connection.
This function is part of the `DBAPI 2.0 specification
`_.
"""
try:
# Byte1('X') - Identifies the message as a terminate message.
# Int32(4) - Message length, including self.
self._write(TERMINATE_MSG)
self._flush()
self._sock.close()
except AttributeError:
raise InterfaceError("connection is closed")
except ValueError:
raise InterfaceError("connection is closed")
except socket.error:
pass
finally:
self._usock.close()
self._sock = None
def handle_AUTHENTICATION_REQUEST(self, data, context):
# Int32 - An authentication code that represents different
# authentication messages:
# 0 = AuthenticationOk
# 5 = MD5 pwd
# 2 = Kerberos v5 (not supported by pg8000)
# 3 = Cleartext pwd
# 4 = crypt() pwd (not supported by pg8000)
# 6 = SCM credential (not supported by pg8000)
# 7 = GSSAPI (not supported by pg8000)
# 8 = GSSAPI data (not supported by pg8000)
# 9 = SSPI (not supported by pg8000)
# Some authentication messages have additional data following the
# authentication code. That data is documented in the appropriate
# class.
auth_code = i_unpack(data)[0]
if auth_code == 0:
pass
elif auth_code == 3:
if self.password is None:
raise InterfaceError(
"server requesting password authentication, but no "
"password was provided")
self._send_message(PASSWORD, self.password + NULL_BYTE)
self._flush()
elif auth_code == 5:
##
# A message representing the backend requesting an MD5 hashed
# password response. The response will be sent as
# md5(md5(pwd + login) + salt).
# Additional message data:
# Byte4 - Hash salt.
salt = b"".join(cccc_unpack(data, 4))
if self.password is None:
raise InterfaceError(
"server requesting MD5 password authentication, but no "
"password was provided")
pwd = b"md5" + md5(
md5(self.password + self.user).hexdigest().encode("ascii") +
salt).hexdigest().encode("ascii")
# Byte1('p') - Identifies the message as a password message.
# Int32 - Message length including self.
# String - The password. Password may be encrypted.
self._send_message(PASSWORD, pwd + NULL_BYTE)
self._flush()
elif auth_code == 10:
# AuthenticationSASL
mechanisms = [
m.decode('ascii') for m in data[4:-2].split(NULL_BYTE)]
self.auth = scramp.ScramClient(
mechanisms, self.user.decode('utf8'),
self.password.decode('utf8'),
channel_binding=self.channel_binding)
init = self.auth.get_client_first().encode('utf8')
mech = self.auth.mechanism_name.encode('ascii') + NULL_BYTE
# SASLInitialResponse
self._write(
create_message(PASSWORD, mech + i_pack(len(init)) + init))
self._flush()
elif auth_code == 11:
# AuthenticationSASLContinue
self.auth.set_server_first(data[4:].decode('utf8'))
# SASLResponse
msg = self.auth.get_client_final().encode('utf8')
self._write(create_message(PASSWORD, msg))
self._flush()
elif auth_code == 12:
# AuthenticationSASLFinal
self.auth.set_server_final(data[4:].decode('utf8'))
elif auth_code in (2, 4, 6, 7, 8, 9):
raise InterfaceError(
f"Authentication method {auth_code} not supported by pg8000.")
else:
raise InterfaceError(
f"Authentication method {auth_code} not recognized by pg8000.")
def handle_READY_FOR_QUERY(self, data, ps):
# Byte1 - Status indicator.
self.in_transaction = data != IDLE
def handle_BACKEND_KEY_DATA(self, data, ps):
self._backend_key_data = data
def array_inspect(self, value):
# Check if array has any values. If empty, we can just assume it's an
# array of strings
first_element = converters.array_find_first_element(value)
if first_element is None:
oid = 25
# Use binary ARRAY format to avoid having to properly
# escape text in the array literals
array_oid = converters.PG_ARRAY_TYPES[oid]
else:
# supported array output
typ = type(first_element)
if typ == bool:
oid = 16
array_oid = converters.PG_ARRAY_TYPES[oid]
elif issubclass(typ, int):
# special int array support -- send as smallest possible array
# type
typ = int
int2_ok, int4_ok, int8_ok = True, True, True
for v in converters.array_flatten(value):
if v is None:
continue
if min_int2 < v < max_int2:
continue
int2_ok = False
if min_int4 < v < max_int4:
continue
int4_ok = False
if min_int8 < v < max_int8:
continue
int8_ok = False
if int2_ok:
array_oid = 1005 # INT2[]
elif int4_ok:
array_oid = 1007 # INT4[]
elif int8_ok:
array_oid = 1016 # INT8[]
else:
raise InterfaceError(
"numeric not supported as array contents")
else:
try:
oid, _ = self.make_param(first_element)
# If unknown or string, assume it's a string array
if oid in (705, 1043, 25):
oid = 25
# Use binary ARRAY format to avoid having to properly
# escape text in the array literals
array_oid = converters.PG_ARRAY_TYPES[oid]
except KeyError:
raise InterfaceError(
"oid " + str(oid) + " not supported as array contents")
return (array_oid, self.array_out)
def array_out(self, ar):
result = []
for v in ar:
if isinstance(v, list):
val = self.array_out(v)
elif v is None:
val = 'NULL'
elif isinstance(v, str):
val = converters.array_string_escape(v)
else:
_, val = self.make_param(v)
result.append(val)
return '{' + ','.join(result) + '}'
def make_param(self, value):
typ = type(value)
try:
oid, func = self.py_types[typ]
except KeyError:
try:
oid, func = self.inspect_funcs[typ](value)
except KeyError as e:
oid, func = None, None
for k, v in self.py_types.items():
try:
if isinstance(value, k):
oid, func = v
break
except TypeError:
pass
if oid is None:
for k, v in self.inspect_funcs.items():
try:
if isinstance(value, k):
oid, func = v(value)
break
except TypeError:
pass
except KeyError:
pass
if oid is None:
raise InterfaceError(
"type " + str(e) + " not mapped to pg type")
return oid, func(value)
def make_params(self, values):
if len(values) == 0:
return (), ()
else:
oids, params = zip(*map(self.make_param, values))
return tuple(oids), tuple(params)
def inspect_datetime(self, value):
if value.tzinfo is None:
return self.py_types[1114] # timestamp
else:
return self.py_types[1184] # send as timestamptz
def inspect_int(self, value):
if min_int2 < value < max_int2:
return self.py_types[21]
if min_int4 < value < max_int4:
return self.py_types[23]
if min_int8 < value < max_int8:
return self.py_types[20]
return self.py_types[Decimal]
def handle_ROW_DESCRIPTION(self, data, context):
count = h_unpack(data)[0]
idx = 2
columns = []
input_funcs = []
for i in range(count):
name = data[idx:data.find(NULL_BYTE, idx)]
idx += len(name) + 1
field = dict(
zip((
"table_oid", "column_attrnum", "type_oid", "type_size",
"type_modifier", "format"), ihihih_unpack(data, idx)))
field['name'] = name.decode(self._client_encoding)
idx += 18
columns.append(field)
input_funcs.append(self.pg_types[field['type_oid']])
context.columns = columns
context.input_funcs = input_funcs
def send_PARSE(self, statement_name_bin, statement, oids):
val = bytearray(statement_name_bin)
val.extend(statement.encode(self._client_encoding) + NULL_BYTE)
val.extend(h_pack(len(oids)))
for oid in oids:
val.extend(i_pack(0 if oid == -1 else oid))
self._send_message(PARSE, val)
def send_DESCRIBE_STATEMENT(self, statement_name_bin):
self._send_message(DESCRIBE, STATEMENT + statement_name_bin)
def send_QUERY(self, sql):
data = sql.encode(self._client_encoding) + NULL_BYTE
try:
self._write(QUERY)
self._write(i_pack(len(data) + 4))
self._write(data)
except ValueError as e:
if str(e) == "write to closed file":
raise InterfaceError("connection is closed")
else:
raise e
except AttributeError:
raise InterfaceError("connection is closed")
def execute_unnamed(
self, statement, vals=(), input_oids=None, stream=None):
context = Context(stream=stream)
if len(vals) == 0 and stream is None:
self.send_QUERY(statement)
self._flush()
self.handle_messages(context)
else:
param_oids, params = self.make_params(vals)
if input_oids is None:
oids = param_oids
else:
oids = [
(p if i is None else i) for p, i in zip(
param_oids, input_oids)]
self.send_PARSE(NULL_BYTE, statement, oids)
self._write(SYNC_MSG)
self._flush()
self.handle_messages(context)
self.send_DESCRIBE_STATEMENT(NULL_BYTE)
self._write(SYNC_MSG)
try:
self._flush()
except AttributeError as e:
if self._sock is None:
raise InterfaceError("connection is closed")
else:
raise e
self.send_BIND(NULL_BYTE, params)
self.handle_messages(context)
self.send_EXECUTE()
self._write(SYNC_MSG)
self._flush()
self.handle_messages(context)
return context
def prepare_statement(self, statement, oids):
for i in count():
statement_name = '_'.join(("pg8000", "statement", str(i)))
statement_name_bin = statement_name.encode('ascii') + NULL_BYTE
if statement_name_bin not in self._statement_nums:
self._statement_nums.add(statement_name_bin)
break
self.send_PARSE(statement_name_bin, statement, oids)
self.send_DESCRIBE_STATEMENT(statement_name_bin)
self._write(SYNC_MSG)
try:
self._flush()
except AttributeError as e:
if self._sock is None:
raise InterfaceError("connection is closed")
else:
raise e
context = Context()
self.handle_messages(context)
return statement_name_bin, context.columns, context.input_funcs
def execute_named(self, statement_name_bin, params, columns, input_funcs):
context = Context(columns=columns, input_funcs=input_funcs)
self.send_BIND(statement_name_bin, params)
self.send_EXECUTE()
self._write(SYNC_MSG)
self._flush()
self.handle_messages(context)
return context
def _send_message(self, code, data):
try:
self._write(code)
self._write(i_pack(len(data) + 4))
self._write(data)
self._write(FLUSH_MSG)
except ValueError as e:
if str(e) == "write to closed file":
raise InterfaceError("connection is closed")
else:
raise e
except AttributeError:
raise InterfaceError("connection is closed")
def send_BIND(self, statement_name_bin, params):
# Byte1('B') - Identifies the Bind command.
# Int32 - Message length, including self.
# String - Name of the destination portal.
# String - Name of the source prepared statement.
# Int16 - Number of parameter format codes.
# For each parameter format code:
# Int16 - The parameter format code.
# Int16 - Number of parameter values.
# For each parameter value:
# Int32 - The length of the parameter value, in bytes, not
# including this length. -1 indicates a NULL parameter
# value, in which no value bytes follow.
# Byte[n] - Value of the parameter.
# Int16 - The number of result-column format codes.
# For each result-column format code:
# Int16 - The format code.
retval = bytearray(
NULL_BYTE + statement_name_bin + h_pack(0) + h_pack(len(params)))
for value in params:
if value is None:
retval.extend(i_pack(-1))
else:
val = value.encode(self._client_encoding)
retval.extend(i_pack(len(val)))
retval.extend(val)
retval.extend(h_pack(0))
self._send_message(BIND, retval)
def send_EXECUTE(self):
# Byte1('E') - Identifies the message as an execute message.
# Int32 - Message length, including self.
# String - The name of the portal to execute.
# Int32 - Maximum number of rows to return, if portal
# contains a query # that returns rows.
# 0 = no limit.
self._write(EXECUTE_MSG)
self._write(FLUSH_MSG)
def handle_NO_DATA(self, msg, context):
context.columns = []
def handle_COMMAND_COMPLETE(self, data, context):
values = data[:-1].split(b' ')
command = values[0]
if command in self._commands_with_count:
row_count = int(values[-1])
if context.row_count == -1:
context.row_count = row_count
else:
context.row_count += row_count
def handle_DATA_ROW(self, data, results):
idx = 2
row = []
for func in results.input_funcs:
vlen = i_unpack(data, idx)[0]
idx += 4
if vlen == -1:
v = None
else:
v = func(
str(data[idx:idx + vlen], encoding=self._client_encoding))
idx += vlen
row.append(v)
results.rows.append(row)
def handle_messages(self, context):
code = self.error = None
while code != READY_FOR_QUERY:
code, data_len = ci_unpack(self._read(5))
self.message_types[code](self._read(data_len - 4), context)
if self.error is not None:
raise self.error
# Byte1('C') - Identifies the message as a close command.
# Int32 - Message length, including self.
# Byte1 - 'S' for prepared statement, 'P' for portal.
# String - The name of the item to close.
def close_prepared_statement(self, statement_name_bin):
self._send_message(CLOSE, STATEMENT + statement_name_bin)
self._write(SYNC_MSG)
self._flush()
context = Context()
self.handle_messages(context)
self._statement_nums.remove(statement_name_bin)
# Byte1('N') - Identifier
# Int32 - Message length
# Any number of these, followed by a zero byte:
# Byte1 - code identifying the field type (see responseKeys)
# String - field value
def handle_NOTICE_RESPONSE(self, data, ps):
self.notices.append(
dict((s[0:1], s[1:]) for s in data.split(NULL_BYTE)))
def handle_PARAMETER_STATUS(self, data, ps):
pos = data.find(NULL_BYTE)
key, value = data[:pos], data[pos + 1:-1]
self.parameter_statuses.append((key, value))
if key == b"client_encoding":
encoding = value.decode("ascii").lower()
self._client_encoding = converters.pg_to_py_encodings.get(
encoding, encoding)
elif key == b"integer_datetimes":
if value == b'on':
pass
else:
pass
elif key == b"server_version":
self._server_version = LooseVersion(value.decode('ascii'))
if self._server_version < LooseVersion('8.2.0'):
self._commands_with_count = (
b"INSERT", b"DELETE", b"UPDATE", b"MOVE")
elif self._server_version < LooseVersion('9.0.0'):
self._commands_with_count = (
b"INSERT", b"DELETE", b"UPDATE", b"MOVE", b"FETCH",
b"COPY")
class Context():
def __init__(self, stream=None, columns=None, input_funcs=None):
self.rows = []
self.row_count = -1
self.columns = [] if columns is None else columns
self.stream = stream
self.input_funcs = [] if input_funcs is None else input_funcs