from ibd2sql.innodb_page.page import PAGE
from ibd2sql.innodb_page.lob import FIRST_BLOB
import struct
import zlib
# Record status values: the low 3 bits of a COMPACT-style record header.
(
	REC_STATUS_ORDINARY, # 0: user record on a leaf page
	REC_STATUS_NODE_PTR, # 1: node-pointer record on a non-leaf page
	REC_STATUS_INFIMUM,  # 2: page infimum record
	REC_STATUS_SUPREMUM, # 3: page supremum record
) = range(4)
# Largest value of a 1-byte variable-length field; a first length byte above this means a 2-byte length.
REC_N_FIELDS_ONE_BYTE_MAX = 127

class INDEX(PAGE):
	"""
	Reader for InnoDB index pages (clustered or secondary, leaf or non-leaf).

		init_index(...): bind a page image and table definition, pick the readers
		new_data(data):  swap in another page image and reset the cursors
		read_all_rows(): return all records as [{'data': row, 'extra': extra}, ...]
		get_sql():       render the records as INSERT/REPLACE statements

	Cursor convention (read()/read_reverse() come from PAGE): header, null-bitmap
	and length bytes are consumed backwards from self._offset with read_reverse().
	Column payload is consumed with read() starting at self.offset; presumably
	read() moves forward.
	"""
	def __init__(self,*args,**kwargs):
		# Arguments are accepted for call compatibility but ignored; setup happens in init_index().
		super().__init__()

	def init_index(self,*args,**kwargs):
		"""
		Bind a page image and table metadata and select the row-format readers.

		Keyword arguments:
			data     -- raw page bytes, passed to PAGE.init()
			table    -- table definition (index, row_format, pk, pkmr, column, ...)
			idxid    -- key into table.index (0 for clustered-index pages)
			pg       -- page reader used to fetch overflow pages
			pgt      -- "PK_LEAF", "PK_NON_LEAF", "KEY_LEAF" or "KEY_NON_LEAF"
			dep      -- optional; True disables reading overflow pages (default False)
			decode   -- optional; decode values with col['decode'] (default True)
			replace  -- optional; emit REPLACE instead of INSERT
			complete -- optional; include the column list in the statement
			multi    -- optional; get_sql() returns one multi-row statement

		Returns "" early, leaving the object partly initialised, for an unknown pgt.
		"""
		self.init(kwargs['data'])
		self.table = kwargs['table']
		self.idxid = kwargs['idxid'] # 0 for pk_leaf / pk_non_leaf
		self.colid_list = self.table.index[self.idxid]['colid_list']
		self.null_count = self.table.index[self.idxid]['null_count']
		self.pg = kwargs['pg'] # page reader (for overflow pages)
		self.pgt = kwargs['pgt'] # page type
		self.dep = kwargs['dep'] if 'dep' in kwargs else False # disable-extra-pages
		self.decode = kwargs['decode'] if 'decode' in kwargs else True # will run col[decode](data)
		self.row_format = self.table.row_format
		# Origin of the first (infimum) record: 99 for the COMPACT family, 101 for REDUNDANT.
		self.foffset = 99
		self.f_offset = 99
		if self.row_format == "REDUNDANT":
			self.read_extra_column = self._read_extra_column_with_768
			self.read_rec_header = self._read_rec_header_old
			self.read_nullbitmask_varsize = self._read_nullbitmask_varsize_old
			self.foffset = 101
			self.f_offset = 101
		elif self.row_format == "COMPACT":
			self.read_extra_column = self._read_extra_column_with_768
			self.read_rec_header = self._read_rec_header_new
			self.read_nullbitmask_varsize = self._read_nullbitmask_varsize_new
		else: # DYNAMIC, COMPRESSED
			self.read_extra_column = self._read_extra_column
			self.read_rec_header = self._read_rec_header_new
			self.read_nullbitmask_varsize = self._read_nullbitmask_varsize_new

		if self.row_format == "COMPRESSED":
			self.read_all_rows = self.read_all_rows_compressed
			# BLOB pointers of compressed pages are kept at the page tail, not inline.
			self._read_extra_column_20_new = self._read_extra_column_20_compressed

		if self.pgt == "PK_LEAF":
			self.read_row = self._read_row_pk_leaf
		elif self.pgt == "PK_NON_LEAF":
			self.read_row = self._read_row_pk_non_leaf
		elif self.pgt == "KEY_LEAF":
			self.read_row = self._read_row_key_leaf
		elif self.pgt == "KEY_NON_LEAF":
			self.read_row = self._read_row_key_non_leaf
		else:
			return "" # exception

		self.offset = self.foffset
		self._offset = self.f_offset
		self.rec_header = {}

		# SQL_PRE
		self.sqlpre = ""
		if 'replace' in kwargs and kwargs['replace']:
			self.sqlpre = "REPLACE"
		else:
			self.sqlpre = "INSERT"
		if 'complete' in kwargs and kwargs['complete']:
			# NOTE(review): 'complete' overrides 'replace' and always emits INSERT -- confirm intended.
			self.sqlpre = "INSERT"
		self.sqlpre += f" INTO {self.table._enclosed}{self.table.schema}{self.table._enclosed}.{self.table._enclosed}{self.table.name}{self.table._enclosed}"
		if 'complete' in kwargs and kwargs['complete']:
			enclosed = self.table._enclosed
			# Explicit column list: (`c1`,`c2`,...) in table.column_order.
			self.sqlpre += "(" + ",".join(enclosed+colname+enclosed for colname,coldefault in self.table.column_order)
			self.sqlpre += ")"
		self.sqlpre += " VALUES"

		if "multi" in kwargs and kwargs['multi']:
			self.get_sql = self._get_sql_multi
		else:
			self.get_sql = self._get_sql

	def new_data(self,data):
		"""Replace the page image and rewind both cursors to the first record."""
		self.data = data
		self.offset = self.foffset
		self._offset = self.f_offset


	def read_all_rows(self):
		"""
		Follow the REC_NEXT chain from the infimum to the supremum.

		Returns a list of {'data': row, 'extra': extra}, where extra is the
		second value returned by self.read_row() (0 on leaf pages).
		"""
		all_row = []
		while True:
			self.read_rec_header()
			if self.rec_header['REC_TYPE'] == REC_STATUS_SUPREMUM:
				break
			elif self.rec_header['REC_TYPE'] == REC_STATUS_INFIMUM:
				self.offset = self._offset = self.rec_header['REC_NEXT']
				continue
			row,extra = self.read_row()
			self.offset = self._offset = self.rec_header['REC_NEXT']
			all_row.append({'data':row,'extra':extra})
		return all_row

	def read_all_rows_compressed(self):
		"""
		read_all_rows() replacement for ROW_FORMAT=COMPRESSED pages.

		Decompresses the page body and rebuilds the record order from the
		dense page directory stored at the end of the page. Each record is
		then read with self.read_row(). Slots beyond n_recs are returned too,
		with rec_header['REC_INFO_DELETED'] set to True.
		"""
		all_row = []
		n_dense = struct.unpack('>H',self.data[42:44])[0] & 32767 # 15-bit heap count (top bit masked)
		n_recs = struct.unpack('>H',self.data[54:56])[0]
		d = zlib.decompressobj()
		c = d.decompress(self.data[94:]) # zlib stream starts right after the 38+56 byte headers
		# NOTE(review): skips the stream prefix up to the first 0x01 byte -- heuristic, confirm.
		toffset = c.find(b'\x01') + 1
		data = c[toffset:]
		# Record origins beyond this offset are treated as not compressed (see loop below).
		compressed_offset = len(data) + 120
		data += d.unused_data # uncompressed tail: presumably modification log, trx/roll_ptr, BLOB pointers, directory
		self.data = data
		self.offset = 120
		self.end_offset = 120
		self._offset = len(data)
		page_dir = []
		for i in range(n_recs):
			slot = struct.unpack('>H',self.read_reverse(2))[0] & 16383 # ignore owned/deleted flag bits
			page_dir.append([slot,False])
		for j in range(n_dense-n_recs-2): # user record deleted
			slot = struct.unpack('>H',self.read_reverse(2))[0] & 16383
			page_dir.append([slot,True])
		page_dir.sort()
		if self.pgt == "PK_LEAF":
			# 13 bytes per record (presumably DB_TRX_ID + DB_ROLL_PTR). The read also moves
			# self._offset so that end_offset points at the BLOB-pointer area.
			trxid_rollptr = [ self.read_reverse(13) for x in range(n_dense-2) ]
		self.end_offset = self._offset
		self.last_offset = self.offset
		self.c_offset = 0 # compressed offset
		for x in range(n_dense-2):
			r = page_dir[x]
			self.last_offset = self.offset
			self.c_offset += 13 if self.c_offset > 0 else 0
			self.c_offset += 5
			self.offset = r[0] - self.c_offset
			is_compressed = True
			if self.offset > compressed_offset:
				self.offset += 1 if x < 64 else 2
				self.c_offset += 1 if x < 64 else 2
				self.last_offset += 1 if x < 64 else 2
				is_compressed = False
			self._offset = self.offset
			self.rec_header = {
				"REC_INFO_INSTANT":False,
				"REC_INFO_VERSION":False,
				"REC_INFO_DELETED":r[1],
				"REC_INFO_MIN_REC":True if self.c_offset == 5 else False,
				"REC_N_OWNED":False,
				"REC_HEAP_NO":0,
				"REC_TYPE": 0 if self.pgt in ["PK_LEAF","KEY_LEAF"] else 1,
				"REC_NEXT":self.offset,
				"is_compressed":is_compressed
			}
			row,extra = self.read_row()
			all_row.append({'data':row,'extra':extra})
		return all_row

	def _row_values(self,data):
		"""Join one row's values with ',' in table.column_order; absent columns use their default."""
		return ','.join(f"{coldefault if colname not in data else data[colname]['data']}" for colname,coldefault in self.table.column_order)

	def _get_sql(self):
		"""Return a list with one '<sqlpre>(v1,v2,...)' statement per row."""
		return [f"{self.sqlpre}({self._row_values(row['data'])})" for row in self.read_all_rows()]

	def _get_sql_multi(self):
		"""
		Return every row as one statement: '<sqlpre>(v1,...),(v1,...)'.
		Returns "" when the page yields no rows.
		"""
		tuples = [f"({self._row_values(row['data'])})" for row in self.read_all_rows()]
		if not tuples:
			return ""
		return self.sqlpre + ",".join(tuples)

	def get_sql(self):
		# Placeholder; init_index() binds _get_sql or _get_sql_multi on the instance.
		pass

	def read_row(self):
		# Placeholder; init_index() binds the page-type specific reader.
		return None,None

	def _read_row(self,colid_list,null_count):
		"""
		Read the values of the columns in colid_list for the current record.

		Returns {colname: {'data': value, 'start_offset': int, 'size': int}}.
		NULL columns get the string "null" and are not decoded. A size of
		16384 or more carries the 0x4000 extern flag, meaning the value is stored
		off-page. Such a value is fetched with self.read_extra_column(). When
		extra pages are disabled (self.dep), only the in-page part
		(size & 16383 bytes) is read so the cursor stays aligned.
		"""
		null_list,size_list = self.read_nullbitmask_varsize(colid_list,null_count)
		row = {}
		for colid,vsize,nullable in zip(colid_list,size_list,null_list):
			col = self.table.column[colid]
			colname = col['name']
			if colname in ['DB_TRX_ID','DB_ROLL_PTR']:
				# NOTE(review): no bytes are consumed here. That fits COMPRESSED pages (stored
				# in the trailer) but would desync self.offset if these columns appear in
				# colid_list for other row formats -- confirm.
				continue
			start_offset = self.offset
			if nullable:
				data = "null"
			elif vsize >= 16384:
				data = self.read(vsize & 16383) if self.dep else self.read_extra_column()
			else:
				data = self.read(vsize)
			row[colname] = {
				'data':col['decode'](data,*col['args']) if self.decode and not nullable else data,
				'start_offset':start_offset,
				'size':vsize
			}
		return row

	def _read_row_pk_leaf(self):
		"""Clustered leaf record: PK columns plus the columns of the record's row version."""
		row_version = self.read_row_version()
		colid_list = self.table.pk + self.table.pkmr[row_version]['colid']
		null_count = self.table.pkmr[row_version]['null_count']
		row = self._read_row(colid_list, null_count)
		return row,0

	def _read_row_pk_non_leaf(self):
		"""Clustered node pointer: PK columns followed by a 4-byte big-endian child page number."""
		row = self._read_row(self.table.pk,0)
		return row,struct.unpack('>L',self.read(4))[0]

	def _read_row_key_leaf(self):
		"""Secondary-index leaf record: index columns followed by the table's PK columns."""
		row = self._read_row(self.colid_list+self.table.pk,self.null_count)
		return row,0

	def _read_row_key_non_leaf(self):
		"""Secondary-index node pointer: index columns followed by a 4-byte child page number."""
		row = self._read_row(self.colid_list,self.null_count)
		return row,struct.unpack('>L',self.read(4))[0]

	def read_row_version(self):
		"""Return the 1-byte row version when the INSTANT or VERSION info bit is set, else 0."""
		return struct.unpack('>B',self.read_reverse(1))[0] if self.rec_header['REC_INFO_INSTANT'] or self.rec_header['REC_INFO_VERSION'] else 0

	def _read_nullbitmask_varsize_old(self,colid_list,null_count=0):
		"""
		REDUNDANT format: decode the column end-offset array read backwards
		from self._offset. The array uses 1 byte per column when REC_SHORT is
		set, otherwise 2 bytes.

		Returns (null_list, size_list) aligned with colid_list. A 2-byte entry
		carrying the 0x4000 extern flag gets 16384 added to its size. That is
		the same marker the COMPACT/DYNAMIC reader leaves in its 2-byte lengths,
		so _read_row() detects off-page values the same way for both. The flag
		is masked out of the running end offset so later columns are not skewed.
		null_count is unused.
		"""
		null_list = []
		size_list = []
		if self.rec_header['REC_SHORT']:
			size_null_format, size_null_size = '>B', 1
			nmask, extern_mask = 128, 0
		else:
			size_null_format, size_null_size = '>H', 2
			nmask, extern_mask = 32768, 16384
		offset_mask = (nmask-1) & ~extern_mask
		lastoffset = 0
		for colid in colid_list:
			size_null = struct.unpack(size_null_format,self.read_reverse(size_null_size))[0]
			endoffset = size_null & offset_mask
			vsize = endoffset - lastoffset
			lastoffset = endoffset
			if size_null & extern_mask:
				vsize += 16384
			size_list.append(vsize)
			null_list.append(True if nmask&size_null else False)
		return null_list,size_list

	def _read_nullbitmask_varsize_new(self,colid_list,null_count):
		"""
		COMPACT/DYNAMIC/COMPRESSED: read the NULL bitmap ((null_count+7)//8 bytes)
		and then the variable-length sizes, both backwards from self._offset.
		For records marked is_compressed=False, the bitmap is read forward from
		self.last_offset instead.

		Returns (null_list, size_list) aligned with colid_list. Fixed-size columns
		report col['size']. A 2-byte length keeps its 0x4000 extern bit, so those
		sizes are >= 16384.
		"""
		null_list = []
		size_list = []
		nbytes = (null_count+7)//8
		if 'is_compressed' in self.rec_header and not self.rec_header['is_compressed']:
			nullvalue = int.from_bytes(self.data[self.last_offset:self.last_offset+nbytes],'big') if null_count > 0 else 0
		else:
			nullvalue = int.from_bytes(self.read_reverse(nbytes),'big') if null_count > 0 else 0
		n = 0
		for colid in colid_list:
			col = self.table.column[colid]
			if col['is_nullable']:
				null_list.append(True if nullvalue&(1<<n) else False)
				n += 1
			else:
				null_list.append(False)
			vsize = col['size']
			if col['is_var']:
				if col['is_big']: # maxsize ge 255
					tsize = struct.unpack('>B',self.read_reverse(1))[0]
					if tsize > REC_N_FIELDS_ONE_BYTE_MAX:
						# 2-byte length: drop the 0x80 marker, keep the 0x40 extern bit
						vsize = struct.unpack('>B',self.read_reverse(1))[0] + (tsize-128)*256
					else:
						vsize = tsize
				else:
					vsize = struct.unpack('>B',self.read_reverse(1))[0]
			size_list.append(vsize)
		return null_list,size_list

	def read_nullbitmask_varsize(self,row_version,rec_header):
		# Placeholder; init_index() binds the row-format specific reader.
		pass

	def _read_rec_header_old(self):
		"""
		Parse the 6-byte REDUNDANT record header into self.rec_header.

		REC_TYPE is derived from the page level (bytes 64:66) and the record
		position, because REDUNDANT headers have no status bits. REC_NEXT is
		used as stored, with 0 marking the supremum.
		"""
		data = self.read_reverse(6)
		rec,rec_next = struct.unpack('>LH',data)
		REC_TYPE = REC_STATUS_ORDINARY if self.data[64:66] == b'\x00\x00' else REC_STATUS_NODE_PTR
		if self.offset == 101:
			REC_TYPE = REC_STATUS_INFIMUM
		if rec_next == 0:
			REC_TYPE = REC_STATUS_SUPREMUM
		self.rec_header = {
			"REC_INFO_INSTANT": True if rec&2147483648 > 0 else False,
			"REC_INFO_VERSION": True if rec&1073741824 > 0 else False,
			"REC_INFO_DELETED": True if rec&536870912  > 0 else False,
			"REC_INFO_MIN_REC": True if rec&268435456  > 0 else False,
			"REC_N_OWNED" : (rec&251658240)>>24,
			"REC_HEAP_NO" : (rec&16775168)>>11,
			"REC_N_FIELDS": (rec&2046)>>1,
			"REC_SHORT"   : True if rec&1 == 1 else False,
			"REC_TYPE"    : REC_TYPE,
			"REC_NEXT"    : rec_next
		}

	def _read_rec_header_new(self):
		"""
		Parse the 5-byte COMPACT-style record header into self.rec_header.
		REC_NEXT is made absolute by adding the signed relative offset to self.offset.
		"""
		data = self.read_reverse(5)
		rec1,rec2,rec_next = struct.unpack('>HBh',data)
		rec = (rec1<<8)+rec2
		self.rec_header = {
			"REC_INFO_INSTANT": True if rec&8388608 > 0 else False,
			"REC_INFO_VERSION": True if rec&4194304 > 0 else False,
			"REC_INFO_DELETED": True if rec&2097152 > 0 else False,
			"REC_INFO_MIN_REC": True if rec&1048576 > 0 else False,
			"REC_N_OWNED" : (rec&983040)>>16,
			"REC_HEAP_NO" : (rec&65528)>>3,
			"REC_TYPE"    : rec&7,
			"REC_NEXT"    : rec_next + self.offset
		}

	def read_rec_header(self):
		# Placeholder; init_index() binds the row-format specific parser.
		pass

	def _read_extra_column_with_768(self,):
		"""REDUNDANT/COMPACT off-page value: 768-byte inline prefix + the externally stored data."""
		return self.read(768)+self._read_extra_column()

	def _read_extra_column_20_new(self):
		"""Return the 20-byte external-field reference stored inline in the record."""
		return self.read(20)

	def _read_extra_column_20_compressed(self,):
		"""COMPRESSED pages: take the next 20-byte reference from the page tail, moving end_offset back."""
		data = self.data[self.end_offset-20:self.end_offset]
		self.end_offset -= 20
		return data

	def _read_extra_column(self,):
		"""
		Fetch an externally stored value from its 20-byte reference
		(space id, page no, offset, 8-byte length). For mysqld_version_id > 50744
		this delegates to FIRST_BLOB. Otherwise it walks the chain of overflow
		pages through self.pg until the next-page number is 0xFFFFFFFF.
		"""
		SPACE_ID,PAGENO,BLOB_HEADER,REAL_SIZE = struct.unpack('>3LQ',self._read_extra_column_20_new())
		data = b''
		if self.table.mysqld_version_id > 50744:
			# NOTE(review): self.f is not set anywhere in this class (the overflow page
			# reader is self.pg) -- confirm PAGE provides it.
			data = FIRST_BLOB(self.f,PAGENO)
		else:
			while True:
				_ndata = self.pg.read(PAGENO)
				# After the 38-byte FIL header: part length (4) and next page number (4).
				REAL_SIZE,PAGENO = struct.unpack('>LL',_ndata[38:46])
				data += _ndata[46:46+REAL_SIZE]
				if PAGENO == 4294967295:
					break
		return data

	def read_extra_column(self,):
		# Placeholder; init_index() binds the row-format specific reader.
		pass

