diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index c974d83b5..8366f89a0 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -315,28 +315,29 @@ class LlamaBenchData: class LlamaBenchDataSQLite3(LlamaBenchData): - connection: sqlite3.Connection + connection: Optional[sqlite3.Connection] = None cursor: sqlite3.Cursor table_name: str def __init__(self, tool: str = "llama-bench"): super().__init__(tool) - self.connection = sqlite3.connect(":memory:") - self.cursor = self.connection.cursor() + if self.connection is None: + self.connection = sqlite3.connect(":memory:") + self.cursor = self.connection.cursor() - # Set table name and schema based on tool - if self.tool == "llama-bench": - self.table_name = "llama_bench" - db_fields = LLAMA_BENCH_DB_FIELDS - db_types = LLAMA_BENCH_DB_TYPES - elif self.tool == "test-backend-ops": - self.table_name = "test_backend_ops" - db_fields = TEST_BACKEND_OPS_DB_FIELDS - db_types = TEST_BACKEND_OPS_DB_TYPES - else: - assert False + # Set table name and schema based on tool + if self.tool == "llama-bench": + self.table_name = "llama_bench" + db_fields = LLAMA_BENCH_DB_FIELDS + db_types = LLAMA_BENCH_DB_TYPES + elif self.tool == "test-backend-ops": + self.table_name = "test_backend_ops" + db_fields = TEST_BACKEND_OPS_DB_FIELDS + db_types = TEST_BACKEND_OPS_DB_TYPES + else: + assert False - self.cursor.execute(f"CREATE TABLE {self.table_name}({', '.join(' '.join(x) for x in zip(db_fields, db_types))});") + self.cursor.execute(f"CREATE TABLE {self.table_name}({', '.join(' '.join(x) for x in zip(db_fields, db_types))});") def _builds_init(self): if self.connection: @@ -397,9 +398,6 @@ class LlamaBenchDataSQLite3(LlamaBenchData): class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3): def __init__(self, data_file: str, tool: Any): - super().__init__(tool) - - self.connection.close() self.connection = sqlite3.connect(data_file) self.cursor = self.connection.cursor() @@ -411,27 +409,28 @@ class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3): if tool is None: if "llama_bench" in table_names: self.table_name = "llama_bench" - self.tool = "llama-bench" + tool = "llama-bench" elif "test_backend_ops" in table_names: self.table_name = "test_backend_ops" - self.tool = "test-backend-ops" + tool = "test-backend-ops" else: raise RuntimeError(f"No suitable table found in database. Available tables: {table_names}") elif tool == "llama-bench": if "llama_bench" in table_names: self.table_name = "llama_bench" - self.tool = "llama-bench" + tool = "llama-bench" else: raise RuntimeError(f"Table 'test' not found for tool 'llama-bench'. Available tables: {table_names}") elif tool == "test-backend-ops": if "test_backend_ops" in table_names: self.table_name = "test_backend_ops" - self.tool = "test-backend-ops" + tool = "test-backend-ops" else: raise RuntimeError(f"Table 'test_backend_ops' not found for tool 'test-backend-ops'. Available tables: {table_names}") else: raise RuntimeError(f"Unknown tool: {tool}") + super().__init__(tool) self._builds_init() @staticmethod @@ -653,6 +652,8 @@ if not bench_data: if not bench_data.builds: raise RuntimeError(f"{input_file} does not contain any builds.") +tool = bench_data.tool # May have chosen a default if tool was None. + hexsha8_baseline = name_baseline = None