diff --git a/TAP/datadictionary.py b/TAP/datadictionary.py index a9471ac..07de80f 100644 --- a/TAP/datadictionary.py +++ b/TAP/datadictionary.py @@ -90,15 +90,25 @@ def __init__(self, conn, table, **kwargs): logging.debug('-------------------------------------------------') - sql = "select * from TAP_SCHEMA.columns where lower(table_name) = " + \ - "'" + self.dbtable + "'" + # Detect placeholder style from connection type + conn_type = type(self.conn).__module__ + if 'cx_Oracle' in conn_type or 'oracledb' in conn_type: + placeholder = ':1' + elif 'psycopg2' in conn_type: + placeholder = '%s' + else: + placeholder = '?' + + sql = "select * from TAP_SCHEMA.columns where lower(table_name) = " \ + + placeholder if self.debug: logging.debug('') logging.debug(f'TAP_SCHEMA sql = {sql:s}') + logging.debug(f' param = {self.dbtable:s}') try: - cursor.execute(sql) + cursor.execute(sql, (self.dbtable,)) except Exception as e: diff --git a/TAP/propfilter.py b/TAP/propfilter.py index 8bce2b8..5da9519 100644 --- a/TAP/propfilter.py +++ b/TAP/propfilter.py @@ -11,6 +11,7 @@ from TAP.writeresult import writeResult from TAP.datadictionary import dataDictionary from TAP.tablenames import TableNames +from TAP.tablevalidator import TableValidator class propFilter: @@ -460,6 +461,23 @@ def __init__(self, **kwargs): logging.debug('') logging.debug(f'dbtable = [{self.dbtable:s}]') + # + # Defense-in-depth: validate tables against TAP_SCHEMA + # + + try: + validator = TableValidator(self.conn, debug=self.debug) + validator.validate(tables) + + except Exception as e: + + if self.debug: + logging.debug('') + logging.debug(f'Table validation exception: {str(e):s}') + + self.msg = str(e) + raise Exception(self.msg) + # # Parse query: to extract query pieces for propfilter # diff --git a/TAP/runquery.py b/TAP/runquery.py index 50af633..a75e06a 100644 --- a/TAP/runquery.py +++ b/TAP/runquery.py @@ -15,6 +15,7 @@ from TAP.datadictionary import dataDictionary from TAP.writeresult import writeResult from TAP.tablenames import TableNames +from TAP.tablevalidator import TableValidator class runQuery: @@ -317,6 +318,26 @@ def __init__(self, **kwargs): raise Exception(self.msg) + # + # Defense-in-depth: validate tables against TAP_SCHEMA + # + + try: + tn = TableNames() + query_tables = tn.extract_tables(self.sql) + + validator = TableValidator(self.conn, debug=self.debug) + validator.validate(query_tables) + + except Exception as e: + + if self.debug: + logging.debug('') + logging.debug(f'Table validation exception: {str(e):s}') + + self.msg = str(e) + raise Exception(self.msg) + # # Retrieve dd table # diff --git a/TAP/tablenames.py b/TAP/tablenames.py index 6a9f9ca..6576d0b 100644 --- a/TAP/tablenames.py +++ b/TAP/tablenames.py @@ -45,7 +45,8 @@ def extract_from_part(self, parsed): for x in self.extract_from_part(item): yield x elif item.ttype is Keyword and item.value.upper() in \ - ['ORDER', 'GROUP', 'BY', 'HAVING', 'GROUP BY']: + ['ORDER', 'ORDER BY', 'GROUP', 'GROUP BY', + 'BY', 'HAVING', 'LIMIT', 'OFFSET']: from_seen = False StopIteration else: diff --git a/TAP/tablevalidator.py b/TAP/tablevalidator.py new file mode 100644 index 0000000..59c6fbb --- /dev/null +++ b/TAP/tablevalidator.py @@ -0,0 +1,86 @@ +# Copyright (c) 2020, Caltech IPAC. +# This code is released with a BSD 3-clause license. License information is at +# https://github.com/Caltech-IPAC/nexsciTAP/blob/master/LICENSE + + +import logging + + +class TableValidator: + """ + Validates that table names in an ADQL query are registered in + TAP_SCHEMA.tables, preventing access to unauthorized database objects. + """ + + def __init__(self, conn, debug=0): + + self.conn = conn + self.debug = debug + + self.allowed_tables = set() + self.allowed_bare = set() + self.allowed_schemas = set() + + self._load_allowed_tables() + + def _load_allowed_tables(self): + + cursor = self.conn.cursor() + cursor.execute('SELECT table_name FROM TAP_SCHEMA.tables') + rows = cursor.fetchall() + cursor.close() + + for row in rows: + full_name = row[0].strip().lower() + self.allowed_tables.add(full_name) + + if '.' in full_name: + schema, bare = full_name.split('.', 1) + self.allowed_bare.add(bare) + self.allowed_schemas.add(schema) + else: + self.allowed_bare.add(full_name) + + if self.debug: + logging.debug('') + logging.debug( + f'TableValidator: loaded {len(self.allowed_tables)} ' + f'allowed tables: {self.allowed_tables}') + + def validate(self, table_names): + + if not table_names: + raise Exception('No table names to validate.') + + for tname in table_names: + tname_lower = tname.strip().lower() + + # Exact match against full table names (e.g. "tap_schema.columns") + if tname_lower in self.allowed_tables: + continue + + if '.' in tname_lower: + schema, bare = tname_lower.split('.', 1) + + # Schema-qualified query table: only match if the schema + # is one we know about AND the bare name is allowed. + # This prevents "information_schema.tables" from matching + # just because "tables" is a bare name in TAP_SCHEMA. + if schema in self.allowed_schemas and \ + bare in self.allowed_bare: + continue + else: + # Unqualified query table: bare-name match is fine + # (e.g. query says "columns", whitelist has + # "tap_schema.columns") + if tname_lower in self.allowed_bare: + continue + + raise Exception( + f'Table \'{tname}\' is not available for querying. ' + f'Use TAP_SCHEMA.tables to see available tables.') + + if self.debug: + logging.debug('') + logging.debug( + f'TableValidator: all tables validated: {table_names}') diff --git a/TAP/tap.py b/TAP/tap.py index 3a65c95..cc20103 100755 --- a/TAP/tap.py +++ b/TAP/tap.py @@ -1533,9 +1533,9 @@ def __init__(self, **kwargs): if len(self.dbtable) == 0: self.msg = 'No table name found in ADQL query.' - + if(self.tapcontext == 'async'): - + self.__writeAsyncError__(self.msg, self.statuspath, self.statdict, self.param) else: diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_tablevalidator.py b/tests/test_tablevalidator.py new file mode 100644 index 0000000..dd2ac36 --- /dev/null +++ b/tests/test_tablevalidator.py @@ -0,0 +1,110 @@ +import os +import sqlite3 +import tempfile +import unittest + +from TAP.tablevalidator import TableValidator + + +def _make_db(table_names): + """Create SQLite DBs with TAP_SCHEMA attached, mimicking real setup. + + Returns (conn, tap_schema_path) — caller is responsible for cleanup. + """ + fd, tap_schema_path = tempfile.mkstemp(suffix='.db') + os.close(fd) + + schema_conn = sqlite3.connect(tap_schema_path) + schema_conn.execute('CREATE TABLE tables (table_name TEXT)') + for name in table_names: + schema_conn.execute('INSERT INTO tables VALUES (?)', (name,)) + schema_conn.commit() + schema_conn.close() + + conn = sqlite3.connect(':memory:') + conn.execute('ATTACH DATABASE ? AS TAP_SCHEMA', (tap_schema_path,)) + + return conn, tap_schema_path + + +class TestTableValidator(unittest.TestCase): + + def setUp(self): + self.conn, self._tap_schema_path = _make_db([ + 'ps', + 'pscomppars', + 'stellarhosts', + 'TAP_SCHEMA.tables', + 'TAP_SCHEMA.columns', + 'TAP_SCHEMA.schemas', + 'cumulative', + ]) + + def tearDown(self): + self.conn.close() + if os.path.exists(self._tap_schema_path): + os.unlink(self._tap_schema_path) + + def test_exact_match(self): + v = TableValidator(self.conn) + v.validate(['ps']) # should not raise + + def test_case_insensitive(self): + v = TableValidator(self.conn) + v.validate(['PS']) + v.validate(['Ps']) + v.validate(['TAP_SCHEMA.Tables']) + + def test_schema_prefix_in_whitelist_bare_in_query(self): + """TAP_SCHEMA.columns is whitelisted; query says just 'columns'.""" + v = TableValidator(self.conn) + v.validate(['columns']) + + def test_bare_in_whitelist_unknown_schema_in_query(self): + """'ps' is whitelisted bare; 'public.ps' has unknown schema — rejected.""" + v = TableValidator(self.conn) + with self.assertRaises(Exception): + v.validate(['public.ps']) + + def test_known_schema_bare_table_in_query(self): + """'ps' is whitelisted bare; 'tap_schema.ps' uses known schema — allowed.""" + v = TableValidator(self.conn) + v.validate(['tap_schema.ps']) + + def test_disallowed_table_raises(self): + v = TableValidator(self.conn) + with self.assertRaises(Exception) as ctx: + v.validate(['pg_catalog.pg_tables']) + self.assertIn('not available', str(ctx.exception)) + + def test_multi_table_all_valid(self): + v = TableValidator(self.conn) + v.validate(['ps', 'pscomppars', 'stellarhosts']) + + def test_multi_table_one_invalid(self): + v = TableValidator(self.conn) + with self.assertRaises(Exception) as ctx: + v.validate(['ps', 'information_schema.tables', 'stellarhosts']) + self.assertIn('information_schema.tables', str(ctx.exception)) + + def test_empty_table_list_raises(self): + v = TableValidator(self.conn) + with self.assertRaises(Exception): + v.validate([]) + + def test_system_catalog_blocked(self): + v = TableValidator(self.conn) + for bad_table in [ + 'ALL_TABLES', + 'DBA_USERS', + 'V$SESSION', + 'information_schema.tables', + 'pg_catalog.pg_class', + 'EXOFOP.FILES', + ]: + with self.assertRaises(Exception, msg=f'{bad_table} should be blocked'): + v.validate([bad_table]) + + +if __name__ == '__main__': + unittest.main()