diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/.gitignore b/vlmeval/dataset/PHYSICS/Math-Verify/.gitignore new file mode 100644 index 000000000..aaecbecb2 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/.gitignore @@ -0,0 +1,56 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual Environment +venv/ +env/ +ENV/ +.env +.venv + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +.DS_Store + +# Testing +.coverage +htmlcov/ +.pytest_cache/ +.tox/ + +# Distribution +*.tar.gz +*.whl + +# Logs +*.log +logs/ + +# PyPI +.pypirc + +# Outputs +results/ +tmp/ diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/.gitmodules b/vlmeval/dataset/PHYSICS/Math-Verify/.gitmodules new file mode 100644 index 000000000..65cd8a38d --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/.gitmodules @@ -0,0 +1,3 @@ +[submodule "latex2sympy2_extended_cloned"] + path = latex2sympy2_extended_cloned + url = git@github.com:huggingface/latex2sympy2_extended.git diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/CITATION.cff b/vlmeval/dataset/PHYSICS/Math-Verify/CITATION.cff new file mode 100644 index 000000000..6aaf048f3 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/CITATION.cff @@ -0,0 +1,16 @@ +cff-version: 1.2.0 +title: 'Math-Verify: Math Verification Library' +message: >- + If you use this software, please cite it using the metadata from this file. +type: software +authors: + - given-names: Hynek + family-names: Kydlíček +repository-code: 'https://github.com/huggingface/math-verify' +abstract: "Math-Verify is a library to rule-based verify mathematical answers" +keywords: + - verification + - math + - evaluation +license: Apache-2.0 +version: 0.6.1 \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/LICENCE b/vlmeval/dataset/PHYSICS/Math-Verify/LICENCE new file mode 100644 index 000000000..479cdd0de --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/LICENCE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2025] [HuggingFace Inc.] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/Makefile b/vlmeval/dataset/PHYSICS/Math-Verify/Makefile new file mode 100644 index 000000000..679b3ba44 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/Makefile @@ -0,0 +1,14 @@ +.PHONY: quality style test + +check_dirs := src examples + +check: + ruff check $(check_dirs) # linter + ruff format --check $(check_dirs) # formatter + +format: + ruff check --fix $(check_dirs) # linter + black $(check_dirs) # formatter + +test: + python -m pytest -sv ./tests/ \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/assets/flow.svg b/vlmeval/dataset/PHYSICS/Math-Verify/assets/flow.svg new file mode 100644 index 000000000..ec74c0a3b --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/assets/flow.svg @@ -0,0 +1,2 @@ +ParserPrepare regex based on config.- More specific regexesassigned higher priorityOn FailOn MatchExtract using regexes,- lowest priority first- Last match to First match in case of multipleNormalization:- Formatting removal (mathrm, displaystyle)- Units and their variations removal- Malformed operators ( sqrt(2) -> sqrt{2} frac13 -> frac{1}{3} )- Multistep eq shorten (a+b=b+a=1 -> b+a=1)- Boxed env extractionLatex -> Sympy - Antlr based - Percentage conversion - Matrix operations - Derivatives and integrals - Complex numbers - Sets and intervalsLatex matchExpr -> Sympy- Constructed directly if simple number- If not sympy.parse_expr to parse fractions etc...- Percentage basedExpr matchGoldComparisonOn successstr vs strsympy vs sympya.strip() == b.strip()Numeric comparissonStructual and String comparisson Symbolic comparisson using (a-b).simplify -> type based0 checkon FailRelational vs sympy Compare lhs-rhs of gold and pred. -Allow for flipped relation- structural for Matrices- precision aware comparison only if neither of results are plain numbers- Rounding only if gold/pred is floaton FailInterval or FiniteSetsInteval check using emptysym difference \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/evaluate_model.py b/vlmeval/dataset/PHYSICS/Math-Verify/evaluate_model.py new file mode 100644 index 000000000..fea3be588 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/evaluate_model.py @@ -0,0 +1,71 @@ +from datetime import timedelta +import argparse +from pathlib import Path +from lighteval.logging.evaluation_tracker import EvaluationTracker +from lighteval.models.transformers.transformers_model import TransformersModelConfig +from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters +from lighteval.utils.utils import EnvConfig +from lighteval.utils.imports import is_accelerate_available + +if is_accelerate_available(): + from accelerate import Accelerator, InitProcessGroupKwargs + accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))]) +else: + accelerator = None + +def parse_args() -> argparse.Namespace: + """Parse command line arguments. + + Returns: + argparse.Namespace: Parsed command line arguments + """ + parser = argparse.ArgumentParser(description='Evaluate model on math tasks') + parser.add_argument('--task', type=str, required=True, + choices=['gsm8k', 'math', 'math_hard', 'math_500', 'aime24', 'amc23'], + help='Task to evaluate') + parser.add_argument('--model', type=str, required=True, + help='Model name or path') + parser.add_argument('--use_chat_template', action='store_true', default=False, + help='Use chat template') + parser.add_argument('--override_bs', type=int, default=-1, + help='Batch size; -1 for automatic batch size') + return parser.parse_args() + + +def main() -> None: + """Main function to run model evaluation.""" + args = parse_args() + + evaluation_tracker = EvaluationTracker( + output_dir="./results", + save_details=True, + push_to_hub=False, + ) + + pipeline_params = PipelineParameters( + launcher_type=ParallelismManager.ACCELERATE, + max_samples=1000, + custom_tasks_directory="math_verify.tasks", + env_config=EnvConfig(cache_dir="tmp/"), + override_batch_size=args.override_bs, + ) + + model_config = TransformersModelConfig( + pretrained=args.model, + dtype="bfloat16", + use_chat_template=args.use_chat_template, + ) + + pipeline = Pipeline( + tasks=f"lighteval|{args.task}|4|1", + pipeline_parameters=pipeline_params, + evaluation_tracker=evaluation_tracker, + model_config=model_config, + ) + + pipeline.evaluate() + pipeline.show_results() + pipeline.save_and_push_results() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/evaluate_model_outputs.py b/vlmeval/dataset/PHYSICS/Math-Verify/evaluate_model_outputs.py new file mode 100644 index 000000000..dc800aa5b --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/evaluate_model_outputs.py @@ -0,0 +1,141 @@ +import argparse +import pandas as pd +from typing import Any +from math_verify.metric import math_metric +from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig +import sympy + +def parse_args(): + parser = argparse.ArgumentParser(description='Extract and evaluate answers using sympy') + parser.add_argument('--input_csv', type=str, required=True, help='Path to input CSV file containing model outputs') + parser.add_argument('--output_csv', type=str, required=True, help='Path to output CSV file for extracted answers') + parser.add_argument('--gold_is_latex', action='store_true', help='Use basic latex normalization', default=True) + return parser.parse_args() + +def load_csv_data(csv_path: str) -> pd.DataFrame: + """Load and validate CSV data.""" + try: + df = pd.read_csv(csv_path) + required_columns = ['answer', 'gold'] + if not all(col in df.columns for col in required_columns): + raise ValueError(f"CSV must contain columns: {required_columns}") + return df + except Exception as e: + raise Exception(f"Error loading CSV file: {str(e)}") + +def serialize_sympy_object(obj: Any) -> str: + """Convert sympy object to string representation.""" + if obj is None: + return "" + try: + if isinstance(obj, (list, tuple)): + return ", ".join(str(x) if x is not None else "" for x in obj) + return str(obj) + except Exception as e: + return f"Error: {str(e)}" + +def compare_answers(extracted: Any, gold: Any) -> bool: + """Compare extracted answer with gold answer.""" + if extracted is None or gold is None: + return False + try: + # Handle lists/tuples of expressions + if isinstance(extracted, (list, tuple)) and isinstance(gold, (list, tuple)): + if len(extracted) != len(gold): + return False + return all(sympy.simplify(a - b) == 0 for a, b in zip(extracted, gold)) + + # Handle single expressions + return sympy.simplify(extracted - gold) == 0 + except Exception: + # If comparison fails (e.g. different types), return False + return False + +def process_answers(df: pd.DataFrame, gold_is_latex: bool) -> pd.DataFrame: + """Process each answer through the sympy extraction workflow and compare with gold using math_verify.""" + results = [] + + + correct_count = 0 + total_count = 0 + + # Create the verification function + verify_func = math_metric( + gold_extraction_target=(LatexExtractionConfig() if gold_is_latex else ExprExtractionConfig(),), + pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()), + aggregation_function=max, + precision=6 + ) + + for _, row in df.iterrows(): + extracted_answers = None + gold_answers = None + grade = 0 + try: + # Use the verification function + grade, extracted_answers = verify_func([row['gold']], [row['answer']]) + + if extracted_answers is None: + extracted_answers = None + gold_answers = None + else: + gold_answers = extracted_answers[0] + extracted_answers = extracted_answers[1] + + total_count += 1 + if grade == 1: + correct_count += 1 + + result = { + 'original_answer': row['answer'], + 'gold_answer': row['gold'], + 'extracted_answer': extracted_answers, + 'extracted_gold': gold_answers, + 'is_correct': grade == 1 + } + + results.append(result) + + except Exception as e: + results.append({ + 'original_answer': row['answer'], + 'gold_answer': row['gold'], + 'extracted_answer': extracted_answers, + 'extracted_gold': gold_answers, + 'is_correct': grade == 1, + 'error': str(e) + }) + + results_df = pd.DataFrame(results) + + # Calculate accuracy + accuracy = correct_count / total_count if total_count > 0 else 0 + print(f"\nEvaluation Results:") + print(f"Total examples: {total_count}") + print(f"Correct answers: {correct_count}") + print(f"Accuracy: {accuracy:.2%}") + + # Add summary stats to the dataframe + results_df.attrs['accuracy'] = accuracy + results_df.attrs['total_count'] = total_count + results_df.attrs['correct_count'] = correct_count + + return results_df + +def main(): + args = parse_args() + + # Load input CSV + input_df = load_csv_data(args.input_csv) + + # Process answers and extract sympy objects + results_df = process_answers(input_df, args.gold_is_latex) + + # Save results to output CSV + results_df.to_csv(args.output_csv, index=False) + print(f"\nResults saved to {args.output_csv}") + +if __name__ == "__main__": + main() + + diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/examples/few_shot_answers.csv b/vlmeval/dataset/PHYSICS/Math-Verify/examples/few_shot_answers.csv new file mode 100644 index 000000000..ad2864006 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/examples/few_shot_answers.csv @@ -0,0 +1,9 @@ +answer +"Janet's ducks lay 16 eggs per day\nShe eats 3 eggs for breakfast\nThis leaves 16 - 3 = 13 eggs to sell\nEach egg sells for $2\nSo she makes 13 * $2 = $26 per day\n\nThe answer is $26. I hope it is correct." +"Total labor cost = 3 workers * $20 per hour * 40 hours = $2,400\nRevenue from job = $5,400\nProfit = Revenue - Cost\nProfit = $5,400 - $2,400 = $3,000\n\nThe answer is $3,000. I hope it is correct." +"Steve has 8 marbles\nJohn has 3 less than Steve, so John has 8 - 3 = 5 marbles\nSam has 5 times as many as John\nSo Sam has 5 × 5 = 25 marbles\n\nThe answer is 25. I hope it is correct." +"Tom is baking cookies. One batch makes 12 cookies and needs 2.5 cups of flour\nTo make 60 cookies, Tom needs 60 ÷ 12 = 5 batches\nEach batch needs 2.5 cups of flour\nTotal flour needed = 5 × 2.5 = 12.5 cups\n\nThe answer is 12.5. I hope it is correct." +"The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. Therefore, the domain of the expression is $\\boxed{[2,5)}$.\nThe answer is $[2,5)$. I hope it is correct." +"We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}$.\nThe answer is $24$. I hope it is correct." +"If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$:\n\\begin{align*}\n30n&=480\\\n\\Rightarrow\\qquad n&=480/30=\\boxed{16}\n\\end{align*}\nThe answer is $16$. I hope it is correct." +"If we multiply the first equation by $-\\frac{3}{2}$, we obtain\n\n$$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have\n\n$$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$\nThe answer is $-\\frac{2}{3}$. I hope it is correct." \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/examples/model_outputs.csv b/vlmeval/dataset/PHYSICS/Math-Verify/examples/model_outputs.csv new file mode 100644 index 000000000..a77f2ca25 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/examples/model_outputs.csv @@ -0,0 +1,7 @@ +answer,gold +"so that x == 1 or x == 2, thus the result is $2*\pi*r$","$2*\pi*r$" +"The answer is $$\sin(x)$$","$$\sin(x)$$" +"After solving, we get $1/2$","$1/2$" +"The final result is $$(a + b)^2$$","$(a + b)^2$" +"Therefore, $$3!$$","$3!$" +"The point coordinates are $(1,2)$","$(1,2)$" \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/examples/sample_answers.csv b/vlmeval/dataset/PHYSICS/Math-Verify/examples/sample_answers.csv new file mode 100644 index 000000000..f83663396 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/examples/sample_answers.csv @@ -0,0 +1,7 @@ +answer +"so that x == 1 or x == 2, thus the result is $2*\pi*r$" +"The answer is $$\sin(x)$$" +"After solving, we get $1/2$" +"The final result is $$(a + b)^2$$" +"Therefore, $$3!$$" +"The point coordinates are $(1,2)$" diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/extract_answers.py b/vlmeval/dataset/PHYSICS/Math-Verify/extract_answers.py new file mode 100644 index 000000000..1ca28c33e --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/extract_answers.py @@ -0,0 +1,97 @@ +import argparse +import pandas as pd +from typing import Any +from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig, parse + +def parse_args(): + parser = argparse.ArgumentParser(description='Extract and evaluate answers using sympy') + parser.add_argument('--input_csv', type=str, required=True, help='Path to input CSV file containing model outputs') + parser.add_argument('--output_csv', type=str, required=True, help='Path to output CSV file for extracted answers') + return parser.parse_args() + +def load_csv_data(csv_path: str) -> pd.DataFrame: + """Load and validate CSV data.""" + try: + df = pd.read_csv(csv_path) + required_columns = ['answer'] + if not all(col in df.columns for col in required_columns): + raise ValueError(f"CSV must contain columns: {required_columns}") + return df + except Exception as e: + raise Exception(f"Error loading CSV file: {str(e)}") + +def serialize_sympy_object(obj: Any) -> str: + """Convert sympy object to string representation.""" + if obj is None: + return "" + try: + if isinstance(obj, (list, tuple)): + return ", ".join(str(x) if x is not None else "" for x in obj) + return str(obj) + except Exception as e: + return f"Error: {str(e)}" + +def process_answers(df: pd.DataFrame) -> pd.DataFrame: + """Process each answer through the sympy extraction workflow.""" + results = [] + + # Set up extraction config and get regexes + extraction_target = (ExprExtractionConfig(), LatexExtractionConfig()) + for _, row in df.iterrows(): + try: + # Extract answer using regexes + extracted = parse(row['answer'], extraction_config=extraction_target) + feedback = None + extracted_answer = None + if len(extracted) == 2: + extracted_answer = extracted[0] + feedback = extracted[1] + elif len(extracted) == 1: + extracted_answer = extracted[0] + else: + feedback = "No valid extraction found" + extracted_answer = None + + + + result = { + 'original_answer': row['answer'], + 'extracted_answer': serialize_sympy_object(extracted_answer), + 'extracted_feedback': feedback, + 'extraction_success': extracted_answer is not None + } + + # Copy any other columns from input + for col in df.columns: + if col != 'answer': + result[col] = row[col] + + results.append(result) + + except Exception as e: + results.append({ + 'original_answer': row['answer'], + 'extracted_answer': '', + 'extraction_success': False, + 'error': str(e) + }) + + return pd.DataFrame(results) + +def main(): + args = parse_args() + + # Load input CSV + input_df = load_csv_data(args.input_csv) + + # Process answers and extract sympy objects + results_df = process_answers(input_df) + + # Save results to output CSV + results_df.to_csv(args.output_csv, index=False) + print(f"Results saved to {args.output_csv}") + +if __name__ == "__main__": + main() + + diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/output.csv b/vlmeval/dataset/PHYSICS/Math-Verify/output.csv new file mode 100644 index 000000000..ee3a985d9 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/output.csv @@ -0,0 +1,7 @@ +original_answer,extracted_answer,extracted_feedback,extraction_success +"so that x == 1 or x == 2, thus the result is $2*\pi*r$",2*pi*r,2*\pi*r,True +The answer is $$\sin(x)$$,sin(x),\sin(x),True +"After solving, we get $1/2$",1/2,\frac{1}{2},True +The final result is $$(a + b)^2$$,(a + b)**2,(a + b)^2,True +"Therefore, $$3!$$",factorial(3),3!,True +"The point coordinates are $(1,2)$","Interval.open(1, 2)","(1,2)",True diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/pyproject.toml b/vlmeval/dataset/PHYSICS/Math-Verify/pyproject.toml new file mode 100644 index 000000000..aadb8c649 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/pyproject.toml @@ -0,0 +1,83 @@ +[project] +name = "math-verify" +version = "0.7.0" +description = "HuggingFace library for verifying mathematical answers" +authors = [ + { name = "Hynek Kydlíček", email = "hynek.kydlicek@huggingface.co" } +] +license = { text = "Apache 2.0" } +classifiers = [ + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] +keywords = ["verification", "math", "evaluation"] +dependencies = [ + "latex2sympy2_extended==1.10.1", +] +requires-python = ">=3.10" + +[project.optional-dependencies] +inference = [ + "lighteval[math]" +] + +test = [ + "pytest", +] + +format = [ + "ruff", +] + +dev = [ + "math-verify[test]", + "math-verify[format]", +] + +antlr4_9_3 = [ + "latex2sympy2_extended[antlr4_9_3]" +] + +antlr4_11_0 = [ + "latex2sympy2_extended[antlr4_11_0]" +] + +antlr4_13_2 = [ + "latex2sympy2_extended[antlr4_13_2]" +] + +[project.urls] +Repository = "https://github.com/huggingface/math-verify" + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.ruff] +line-length = 88 +lint.select = ["E", "F", "I", "N", "W", "B", "A"] +lint.ignore = ["E501"] + +[tool.ruff.lint.isort] +known-first-party = ["math_verify"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" + +[tool.black] +line-length = 88 +preview = true \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/scripts/publish.sh b/vlmeval/dataset/PHYSICS/Math-Verify/scripts/publish.sh new file mode 100644 index 000000000..eebcf1e75 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/scripts/publish.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +# Exit immediately if a command exits with a non-zero status +set -e + +# Build the package +python3 -m build + +# Publish the package to PyPI using the token from .pypirc +twine upload --config-file .pypirc dist/* + diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/__init__.py b/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/__init__.py new file mode 100644 index 000000000..7b8775145 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/__init__.py @@ -0,0 +1,22 @@ +from latex2sympy2_extended.latex2sympy2 import ( + NormalizationConfig as LatexNormalizationConfig, +) + +from math_verify.grader import verify +from math_verify.metric import math_metric +from math_verify.parser import ( + ExprExtractionConfig, + LatexExtractionConfig, + StringExtractionConfig, + parse, +) + +__all__ = [ + "parse", + "verify", + "math_metric", + "ExprExtractionConfig", + "LatexExtractionConfig", + "StringExtractionConfig", + "LatexNormalizationConfig", +] diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/errors.py b/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/errors.py new file mode 100644 index 000000000..d5bed5a00 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/errors.py @@ -0,0 +1,2 @@ +class TimeoutException(BaseException): + pass diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/few_shots.py b/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/few_shots.py new file mode 100644 index 000000000..defcc1966 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/few_shots.py @@ -0,0 +1,69 @@ +GSM8K_FEW_SHOTS = [ + { + "question": ( + "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and sells the rest to her neighbors for $2 per egg. How much in dollars does she make per day?" + ), + "answer": ( + "Janet's ducks lay 16 eggs per day\nShe eats 3 eggs for breakfast\nThis leaves 16 - 3 = 13 eggs to sell\nEach egg sells for $2\nSo she makes 13 * $2 = $26 per day\nThe final answer is $26. I hope it is correct." + ), + }, + { + "question": ( + "A contractor quotes a job at $5,400. He needs 3 workers who each make $20 per hour. The job takes 40 hours. How much profit does he make?" + ), + "answer": ( + "Total labor cost = 3 workers * $20 per hour * 40 hours = $2,400\nRevenue from job = $5,400\nProfit = Revenue - Cost\nProfit = $5,400 - $2,400 = $3,000\nThe final answer is $3,000. I hope it is correct." + ), + }, + { + "question": ( + "Sam has 5 times as many marbles as John. John has 3 less marbles than Steve. If Steve has 8 marbles, how many marbles does Sam have?" + ), + "answer": ( + "Steve has 8 marbles\nJohn has 3 less than Steve, so John has 8 - 3 = 5 marbles\nSam has 5 times as many as John\nSo Sam has 5 × 5 = 25 marbles\nThe final answer is 25. I hope it is correct." + ), + }, + { + "question": ( + "Tom is baking cookies. Each batch requires 2.5 cups of flour and makes 12 cookies. If Tom wants to make 60 cookies, how many cups of flour will he need?" + ), + "answer": ( + "One batch makes 12 cookies and needs 2.5 cups of flour\nTo make 60 cookies, Tom needs 60 ÷ 12 = 5 batches\nEach batch needs 2.5 cups of flour\nTotal flour needed = 5 × 2.5 = 12.5 cups\nThe final answer is 12.5. I hope it is correct." + ), + }, +] + +MATH_HARD_FEW_SHOTS = [ + { + "question": ( + "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$." + ), + "answer": ( + "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. Therefore, the domain of the expression is $\\boxed{[2,5)}$.\nThe final answer is $[2,5)$. I hope it is correct." + ), + }, + { + "question": ( + "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$" + ), + "answer": ( + "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$\nThe final answer is $24$. I hope it is correct." + ), + }, + { + "question": ( + "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?" + ), + "answer": ( + "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$:\n\\begin{align*}\n30n&=480\\\n\\Rightarrow\\qquad n&=480/30=\\boxed{16}\n\\end{align*}\nThe final answer is $16$. I hope it is correct." + ), + }, + { + "question": ( + "If the system of equations\n\n\\begin{align*}\n6x-4y&=a,\\\n6y-9x &=b.\n\\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero,\nfind $\\frac{a}{b},$ assuming $b$ is nonzero." + ), + "answer": ( + "If we multiply the first equation by $-\\frac{3}{2}$, we obtain\n\n$$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have\n\n$$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$\nThe final answer is $-\\frac{2}{3}$. I hope it is correct." + ), + }, +] diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/grader.py b/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/grader.py new file mode 100644 index 000000000..4fff9b04d --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/grader.py @@ -0,0 +1,858 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Heavily inspired by https://github.com/QwenLM/Qwen2.5-Math and https://github.com/huggingface/lm-evaluation-harness +import logging +import re +from itertools import product +from sympy.logic.boolalg import BooleanTrue +from latex2sympy2_extended import is_expr_of_only_symbols +from latex2sympy2_extended.logic import And +from latex2sympy2_extended.sets import FiniteSet +from sympy import ( + Abs, + Basic, + E, + Eq, + Float, + GreaterThan, + Interval, + LessThan, + MatrixBase, + MatrixExpr, + Mul, + Number, + Rational, + Set, + StrictGreaterThan, + StrictLessThan, + Symbol, + Tuple, + default_sort_key, + nan, + ordered, + simplify, + solve, + zoo, +) +from sympy import FiniteSet as SympyFiniteSet +from sympy.core.function import UndefinedFunction +from sympy.core.relational import Relational + +from math_verify.errors import TimeoutException +from math_verify.utils import timeout + +logger = logging.getLogger(__name__) + +TIMEOUT_WARNING_SHOWN = False + + +INVERSE_RELATIONS = { + GreaterThan: LessThan, + LessThan: GreaterThan, + StrictGreaterThan: StrictLessThan, + StrictLessThan: StrictGreaterThan, + Eq: Eq, +} + + +def safe_sympy_doit(a: Basic | MatrixBase): + """Safely execute doit() on a sympy expression, catching exceptions. + Doit in sympy will evaluate expressions it will pass the expression tree and evluate nodes. + For example for 1+1+1 it will evaluate the additions and return 3. One issue with it is that it maybe + evaluates too much as integrals will also be evaluated. + + As we are using latex2sympy2_extended, evaluates are + + Args: + a: A sympy Basic or MatrixBase expression to evaluate + + Returns: + The result of a.doit() if successful, otherwise returns the original expression + """ + try: + return a.doit() + except Exception: + pass + return a + + +def is_atomic_or_pct_atomic(expr: Basic | MatrixBase, atomic_type: type) -> bool: + """Check if expression is either an atomic type or percentage atomic type. + + Args: + expr: The sympy expression to check + atomic_type: The atomic type to check for + + Returns: + True if expr is atomic_type or percentage atomic type, False otherwise + """ + return isinstance(expr, atomic_type) or ( + # Check for percentage representation: latex2sympy_extended converts "X%" into X*Rational(1,100) + # So we detect percentages by looking for this multiplication structure + isinstance(expr, Mul) + and len(expr.args) == 2 + and expr.args[1] == Rational(1, 100) + and isinstance(expr.args[0], atomic_type) + ) + + +def sympy_numeric_eq( + a: Basic | MatrixBase, + b: Basic | MatrixBase, + float_rounding: int, + numeric_precision: int, +): + """Compare two sympy expressions numerically with given precision. + + Args: + a: First sympy expression + b: Second sympy expression + precision: Number of decimal places to compare + + Returns: + True if expressions are numerically equal within precision, False otherwise + """ + # Only do this when one of the two is a float, in other cases use symbolic equality as this could lead to false positives + # E.g we want 1/3 == 0.333333 to work + if isinstance(a, (MatrixBase, MatrixExpr)) and isinstance( + b, (MatrixBase, MatrixExpr) + ): + a = safe_sympy_doit(a) + b = safe_sympy_doit(b) + # If we have matrices and one of them is only made of floats, we can use the same logic as above + if ( + isinstance(a, (MatrixBase)) + and isinstance(b, (MatrixBase)) + and a.shape == b.shape + ): + return all( + sympy_numeric_eq(a_elem, b_elem, float_rounding, numeric_precision) + for a_elem, b_elem in zip(a.flat(), b.flat(), strict=False) + ) + + # Ensure this also works for percentage numbers so that 0.333333% = 0.33333333333 with precision 4 + elif is_atomic_or_pct_atomic(a, Number) or is_atomic_or_pct_atomic(b, Number): + # If one of them is a float or a negative atomic number, we can try to use precision + if is_atomic_or_pct_atomic(a, Float) or is_atomic_or_pct_atomic(b, Float): + a = safe_sympy_doit(a) + b = safe_sympy_doit(b) + # Now if both are numbers, we can use precision + if isinstance(a, (Number)) and isinstance(b, (Number)) and a.round(float_rounding) == b.round(float_rounding): + return True + elif safe_sympy_doit(a) == safe_sympy_doit(b): + return True + try: + diff_ratio = Abs((a - b) / a).evalf(chop=True, n=numeric_precision) + if diff_ratio.free_symbols: + return False + comparison = diff_ratio <= 10 ** -numeric_precision + simplified = simplify(comparison) + if simplified == BooleanTrue(): + return True + except Exception: + pass + + return False + + +def sympy_symbolic_eq(a: Basic | MatrixBase, b: Basic | MatrixBase) -> bool: + """Compare two sympy expressions symbolically. + + Args: + a: First sympy expression + b: Second sympy expression + + Returns: + True if expressions are symbolically equal, False otherwise + """ + try: + a_b_diff = simplify((a - b)) # type: ignore + if isinstance(a_b_diff, MatrixBase) and a_b_diff.is_zero_matrix: + return True + elif isinstance(a_b_diff, Basic) and a_b_diff.is_zero: + return True + except Exception: + pass + + return False + + +def sympy_deep_compare_set_and_tuple( + gold: SympyFiniteSet | Tuple, + pred: SympyFiniteSet | Tuple, + float_rounding: int, + numeric_precision: int, +) -> bool: + """Compare two finite sets by comparing each element with given precision. + + Args: + a: First finite set + b: Second finite set + precision: Number of decimal places to compare + + Returns: + True if sets contain equal elements within precision, False otherwise + + Note: in order to fully support finite sets, we should ideally do kartesian product comparison + but this is not implemented yet. We kinda hope sympy will order the elements. + """ + + def unwrap_eq(s): + if is_assignment_relation(s): + return take_last_relation(s).rhs + return s + + def sort_key(x): + try: + return default_sort_key(unwrap_eq(x).evalf()) + except Exception: + return default_sort_key(unwrap_eq(x)) + + # This ensures it works for {1/3} and {0.333333} + if len(gold) == len(pred): + if isinstance(gold, SympyFiniteSet): + gold_args = list(ordered(gold.args, keys=sort_key, default=False)) + pred_args = list(ordered(pred.args, keys=sort_key, default=False)) + + elif isinstance(gold, Tuple) and isinstance(pred, FiniteSet): + # We treat the pred as tuple too + pred_args = pred._unsorted_args + gold_args = gold.args + + elif isinstance(pred, SympyFiniteSet): + pred_args = list(ordered(pred.args, keys=sort_key, default=False)) + gold_args = gold.args + else: + gold_args = gold.args + pred_args = pred.args + + return all( + sympy_expr_eq(a, b, float_rounding, numeric_precision) + for a, b in zip(gold_args, pred_args, strict=False) + ) + + return False + + +def sympy_compare_interval( + a: Interval, b: Interval, float_rounding: int, numeric_precision: int +) -> bool: + """Compare two intervals. + + Args: + a: First interval + b: Second interval + precision: Number of decimal places to compare endpoints + + Returns: + True if intervals are equal, False otherwise + """ + return ( + a.left_open == b.left_open + and a.right_open == b.right_open + and sympy_expr_eq(a.start, b.start, float_rounding, numeric_precision) + and sympy_expr_eq(a.end, b.end, float_rounding, numeric_precision) + ) + + +def sympy_solve_and_compare( + gold: Relational, pred: Relational, float_rounding: int, numeric_precision: int +) -> bool: + solved_gold = list(ordered(solve(gold, gold.free_symbols))) + solved_pred = list(ordered(solve(pred, pred.free_symbols))) + + if not solved_gold or not solved_pred: + return False + + if isinstance(gold, Eq) and isinstance(pred, Eq): + try: + return all( + # 分两种情况:如果g/p是dict,就比较items;否则直接比较g和p本身 + all( + g_k == p_k and sympy_expr_eq(g_v, p_v, float_rounding, numeric_precision) + for (g_k, g_v), (p_k, p_v) in zip(sorted(g.items()), sorted(p.items()), strict=False) + ) if isinstance(g, dict) and isinstance(p, dict) + else sympy_expr_eq(g, p, float_rounding, numeric_precision) + for g, p in zip(solved_gold, solved_pred, strict=False) + ) + except Exception as e: + print(f"[Warning] sympy_solve_and_compare failed: {e}") + return False + else: + return sympy_expr_eq(solved_gold, solved_pred, float_rounding, numeric_precision) + + +def sympy_compare_relational( + gold: Relational | And, + pred: Relational | And, + float_rounding: int, + numeric_precision: int, +) -> bool: + """Compare two relational expressions. + + Args: + gold: First relational expression + pred: Second relational expression + precision: Number of decimal places to compare + + Returns: + True if relations are equivalent, False otherwise + """ + + if isinstance(gold, And) and isinstance(pred, And): + return all( + sympy_compare_relational(g, p, float_rounding, numeric_precision) + for g, p in zip(gold._unsorted_args, pred._unsorted_args, strict=False) + ) + + elif not isinstance(gold, Relational) or not isinstance(pred, Relational): + return False + + # Helper to check if expressions are equivalent when flipped + def are_flipped_inequalities_equal(a: Relational, b: Relational) -> bool: + try: + return sympy_expr_eq( + a.lhs - a.rhs, b.rhs - b.lhs, float_rounding, numeric_precision + ) # type: ignore + except Exception: + pass + return False + + # Same type of relation (e.g. both <= or both >=) + + try: + if type(gold) is type(pred) and sympy_expr_eq( + gold.lhs - gold.rhs, pred.lhs - pred.rhs, float_rounding, numeric_precision + ): # type: ignore + return True + except Exception: + pass + + # Check flipped inequalities (a <= b equals b >= a) + if INVERSE_RELATIONS[type(gold)] is type(pred) and are_flipped_inequalities_equal( # type: ignore + gold, pred + ): + return True + + if sympy_solve_and_compare(gold, pred, float_rounding, numeric_precision): + return True + + return False + + +def sympy_str_eq(a: Basic | MatrixBase, b: Basic | MatrixBase) -> bool: + """Compare two sympy expressions by string representation. + + Args: + a: First sympy expression + b: Second sympy expression + + Returns: + True if string representations are equal, False otherwise + """ + # We can't evaluate nan or zoo + if a == nan or a == zoo: + raise ValueError("Can't evaluate nan or zoo") + try: + return a == b + except Exception: + pass + return False + + +def sympy_compare_sets( + gold: Set | Basic | MatrixBase | Tuple, + pred: Set | Basic | MatrixBase | Tuple, + float_rounding: int, + numeric_precision: int, +) -> bool: + """Compare two sympy sets for equality using multiple methods. + + Args: + gold: First sympy set (expected) + pred: Second sympy set (predicted) + precision: Number of decimal places to compare + + Returns: + True if sets are equal by any comparison method, False otherwise + """ + # Convert non-sets to singleton sets + a_set = gold if isinstance(gold, (Set, Tuple)) else SympyFiniteSet(gold) + b_set = pred if isinstance(pred, (Set, Tuple)) else SympyFiniteSet(pred) + + # If both are intervals, use interval comparison + if isinstance(a_set, Interval) and isinstance(b_set, Interval): + return sympy_compare_interval(a_set, b_set, float_rounding, numeric_precision) + + # Try direct set equality + if a_set == b_set: + return True + + # If both are sets, check if they are equal + try: + if ( + isinstance(a_set, Set) + and isinstance(b_set, Set) + and a_set.symmetric_difference(b_set).is_empty + ): + return True + except Exception: + pass + + # For finite sets, compare elements + if isinstance(a_set, (SympyFiniteSet, Tuple)) and isinstance( + b_set, (SympyFiniteSet, Tuple) + ): + return sympy_deep_compare_set_and_tuple( + a_set, b_set, float_rounding, numeric_precision + ) + + # Because (1,2) is parsed as Interval(1,2,left_open=True,right_open=True), it could have that the + # correct is (1,2) and predicted is 1,2, which is parsed as Set(1,2) + if isinstance(a_set, Interval) and isinstance(b_set, (SympyFiniteSet, Tuple)): + if a_set.is_open and len(b_set) == 2: + return sympy_deep_compare_set_and_tuple( + Tuple(a_set.start, a_set.end), b_set, float_rounding, numeric_precision + ) + + if isinstance(b_set, Interval) and isinstance(a_set, (SympyFiniteSet, Tuple)): + if b_set.is_open and len(a_set) == 2: + return sympy_deep_compare_set_and_tuple( + a_set, Tuple(b_set.start, b_set.end), float_rounding, numeric_precision + ) + + return False + + +def sympy_compare_symbols(gold: Basic | MatrixBase, pred: Basic | MatrixBase) -> bool: + """Compare two sympy expressions where at least one is a Symbol. + + Handles special cases: + - One is Symbol and other is E (limitation of parsed expressions) + - One is multiplication of symbols and other is single symbol (concatenated comparison) + + Args: + gold: First sympy expression (expected) + pred: Second sympy expression (predicted) + precision: Number of decimal places to compare + + Returns: + True if expressions are equal by any comparison method, False otherwise + """ + # Handle E vs symbol case + if (isinstance(gold, Symbol) and gold.name.lower() == "e" and pred == E) or ( + isinstance(pred, Symbol) and pred.name.lower() == "e" and gold == E + ): + return True + + # Handle multiplication of symbols vs single symbol, because parsing return $abc$ -> abc + # We also handle E as it's a symbol, because E will be always parsed as exp + if ( + isinstance(gold, Symbol) + and isinstance(pred, Mul) + and all(arg == E or isinstance(arg, (Symbol)) for arg in pred.args) + ): + concat_pred = "".join( + arg.name if isinstance(arg, Symbol) else "e" for arg in pred.args + ) + return gold.name.lower() == concat_pred.lower() + + if ( + isinstance(pred, Symbol) + and isinstance(gold, Mul) + and all(arg == E or isinstance(arg, (Symbol)) for arg in gold.args) + ): + concat_gold = "".join( + arg.name if isinstance(arg, Symbol) else "e" for arg in gold.args + ) + return pred.name.lower() == concat_gold.lower() + + # Simple + if isinstance(gold, Symbol) and isinstance(pred, Symbol): + g_name = gold.name + p_name = pred.name + if len(p_name) > 1: + p_name = p_name.lower() + if len(g_name) > 1: + g_name = g_name.lower() + return g_name == p_name + + return False + + +def is_relation(expr: Basic | MatrixBase) -> bool: + """Check if an expression is a relational expression. + + Args: + expr: The expression to check + Returns: + bool: True if expr is a relational expression or And of relations, False otherwise + """ + if isinstance(expr, Relational): + return True + + if isinstance(expr, And) and len(expr._unsorted_args) > 0: + return all(isinstance(arg, Relational) for arg in expr._unsorted_args) + + return False + + +def is_equation(expr: Basic | MatrixBase) -> bool: + """Check if an expression is an equation. + + Args: + expr: The expression to check + Returns: + bool: True if expr is an equation, False otherwise + """ + if isinstance(expr, Eq): + return True + + if isinstance(expr, And) and len(expr._unsorted_args) > 0: + return all(isinstance(arg, Eq) for arg in expr._unsorted_args) + + return False + + +def is_assignment_relation(expr: Basic | MatrixBase) -> bool: + """Check if an expression is an assignment relation. E.g a=1 + + Args: + expr: The expression to check + Returns: + bool: True if expr is a relational expression or And of relations, False otherwise + """ + if isinstance(expr, Eq) and is_expr_of_only_symbols(expr.lhs): + return True + + if isinstance(expr, And) and len(expr._unsorted_args) > 0: + return all( + isinstance(arg, Eq) for arg in expr._unsorted_args + ) and is_expr_of_only_symbols(expr._unsorted_args[0].lhs) + + return False + + +def take_last_relation(expr: And | Relational) -> Relational: + """Take the last relation from an And expression.""" + if isinstance(expr, And): + return take_last_relation(expr._unsorted_args[-1]) + return expr + + +def take_first_relation(expr: And | Relational) -> Relational: + """Take the first relation from an And expression.""" + if isinstance(expr, And): + return expr._unsorted_args[0] + return expr + + +def unwrap_fcs(expr: Basic | MatrixBase) -> Basic | MatrixBase: + """Unwrap function calls to their arguments. + + For example, Function('f')(x) becomes Symbol('f_x') + + Args: + expr: The expression to unwrap + + Returns: + The unwrapped expression with functions replaced by concatenated symbols + """ + # Base case - not a Basic type + if not isinstance(expr, Basic): + return expr + + # Handle function case + if hasattr(expr, "func") and isinstance(expr.func, UndefinedFunction): + # Get function name and arguments + func_name = expr.func.__name__ + # Recursively unwrap arguments before converting to string + unwrapped_args = [str(unwrap_fcs(arg)) for arg in expr.args] + # Create new symbol by concatenating function name and args + return Symbol(f"{func_name}_{'_'.join(unwrapped_args)}") + + # Recursively unwrap all arguments + try: + new_args = [unwrap_fcs(arg) for arg in expr.args] + if new_args: + return expr.func(*new_args) + except Exception: + pass + + return expr + + +def sympy_expr_eq( + gold: Basic | MatrixBase, + pred: Basic | MatrixBase, + float_rounding: int, + numeric_precision: int, + strict: bool = True, +) -> bool: + """Compare two sympy expressions for equality using multiple methods. + + Args: + gold: First sympy expression (expected) + pred: Second sympy expression (predicted) + precision: Number of decimal places to compare + strict: If true, variables do matter otherwise they don't + + Returns: + True if expressions are equal by any comparison method, False otherwise + """ + + # This ensures that f(x) == f(y) is true + if not strict: + try: + gold_variables = gold.free_symbols + pred_variables = pred.free_symbols + if len(gold_variables) == len(pred_variables): + pred = pred.subs( + list(zip(pred_variables, gold_variables, strict=False)) + ) + except Exception: + pass + + # If both are assigments, we don't want to unwrap them, so that x=1 != y=1 + # But if one is assignment and other is equation, we want to unwrap both + + # We always want to truncate if it's assignment, assignment + + is_gold_assignment = is_assignment_relation(gold) + is_pred_assignment = is_assignment_relation(pred) + is_gold_equation = is_equation(gold) + is_pred_equation = is_equation(pred) + + # Truncate equations chains in case of assignment, this doesn't change any of the above values, + # so no need to recompute them + if is_gold_assignment: + gold = Eq( + take_first_relation(gold).lhs, take_last_relation(gold).rhs, evaluate=False + ) + if is_pred_assignment: + pred = Eq( + take_first_relation(pred).lhs, take_last_relation(pred).rhs, evaluate=False + ) + + # We follow what the gold format is + # 1 and 9=1 -> 1,1 + if is_pred_equation and not is_gold_equation: + # Unwrap pred + pred = take_last_relation(pred).rhs + + # We respect what the pred format is only if the gold is assignment so that x=1 and 1 -> 1,1, but not 2x + z = 1 and 1 -> 1,1 + elif is_gold_assignment and not is_pred_equation: + gold = take_last_relation(gold).rhs + + if is_relation(gold) and isinstance(pred, Set): + # This is to ensure that 1 < x < 2 equals (-oo, 1) U (2, oo) + # We also unwrap the functions because othewise it creates some conditional set based on the function name + try: + gold = unwrap_fcs(gold).as_set() + except Exception: + pass + + # Start with simple str and expr comparisson as it's the fastest + # str comparison is better, than simple eq, because it will also handle missarangments + if sympy_str_eq(gold, pred): + return True + + # Support for equations + if is_relation(gold) and is_relation(pred): + return sympy_compare_relational(gold, pred, float_rounding, numeric_precision) + + elif isinstance(gold, (Set, Tuple)) or isinstance(pred, (Set, Tuple)): + return sympy_compare_sets(gold, pred, float_rounding, numeric_precision) + + # Handles $\text{answer}$ == $answer$, one is symbol, is multiplication of symbols (a*n*s*w*e*r) + elif isinstance(gold, Symbol) or isinstance(pred, Symbol): + return sympy_compare_symbols(gold, pred) + + elif isinstance(gold, (Basic, MatrixBase)) and isinstance( + pred, (Basic, MatrixBase) + ): + # Mostly so that 0.333333 = 1/3 + if sympy_numeric_eq(gold, pred, float_rounding, numeric_precision): + return True + # Then try symbolic equality + if sympy_symbolic_eq(gold, pred): + return True + + return False + + +complex_number_pattern = re.compile( + r""" + # Complex number indicators + \\mathbb\{C\}| # Complex number set ℂ + \\i\b| # Complex i + \bi\b| # Standalone i + \\text\{i\}| # Text i + \\mathrm\{i\}| # Roman i + \\imath\b| # Alternative i notation + + # Matrix operations + \\det| # Determinant + \\operatorname\{tr\}| # Trace + \\operatorname\{rank\}| # Rank + \\text\{rank\}| + \\arg\{| # Complex argument + \\Re\{| # Real part + \\Im\{| # Imaginary part + \\operatorname\{Re\}| # Real part alternate + \\operatorname\{Im\}| # Imaginary part alternate + \\text\{Re\}| # Real part text + \\text\{Im\} # Imaginary part text +""", + re.VERBOSE, +) + + +def should_treat_as_complex(latex_str: str) -> bool: + """ + Returns True if the latex string likely contains complex numbers, matrices, or vectors. + """ + + return bool(complex_number_pattern.search(latex_str)) + + +def verify( + gold: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, + target: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, + float_rounding: int = 6, + numeric_precision: int = 15, + strict: bool = True, + timeout_seconds: int | None = 5, +) -> bool: + """Verifies if the target expression matches the gold expression using multiple comparison strategies. + + This function implements a comprehensive comparison system for mathematical expressions, + handling various types of mathematical objects (numbers, expressions, sets, matrices, etc.) + with multiple fallback strategies. + + Note: + - It's expected that both gold and pred has been parsed with math_verify.parse function. + - Function is not symmetric, gold answer should be passed as gold and prediction as pred. The non-symmetric nature appears at assignment simplification and equation interval conversion. + + Args: + gold: The reference/correct expression(s). Can be: + - A single SymPy expression (Basic or MatrixBase) + - A string + - A list of any of the above + target: The expression(s) to verify. Same types as gold. + float_rounding: Number of decimal places to round floats to. Defaults to 6. + numeric_precision: Number of decimal places to consider for numeric comparisons. Defaults to 15. + - If you know the evaluated expressions will be small, you should increase this. See: https://docs.sympy.org/latest/modules/evalf.html + strict: Whether to enforce strict comparison mode. Defaults to True. + - In strict mode: Variables matter and sets are not comparable with tuples + - In non-strict mode: Variables are matched by position and sets can be compared with tuples + timeout_seconds: Maximum time in seconds to spend on any single comparison operation. + Defaults to 5 seconds. Any timeout seconds > 0 or not None will result in the function to raise a ValueError if it's called in a threaded environment. + + Returns: + bool: True if target matches gold according to any of the comparison strategies, + False otherwise. + + Comparison Strategy: + 1. String to String comparison + 2. Numeric expressions: Comparison within specified precision + 3. Symbolic equality through simplification + 4. Special handling for: + - Relational expressions (equations/inequalities) + - Sets and intervals + - Matrices and vectors + - Complex numbers + 5. Robust error handling with timeout protection + + Example: + >>> verify(sympy.Rational(1, 3), 0.333333) # Numeric comparison + True + >>> verify(sympy.Symbol('x') + 1, sympy.Symbol('y') + 1, strict=False) # Variable matching + True + >>> verify(sympy.FiniteSet(1, 2), sympy.Tuple(1, 2), strict=False) # Set-tuple comparison + True + """ + + global TIMEOUT_WARNING_SHOWN + if not TIMEOUT_WARNING_SHOWN and (timeout_seconds is None or timeout_seconds <= 0): + logger.warning( + "Timeout is disabled as timeout_seconds is None or <= 0, you must provide \ + the logic for timeout interuption yourself to prevent code getting stuck." + ) + TIMEOUT_WARNING_SHOWN = True + + @timeout(timeout_seconds=timeout_seconds) + def compare_single_extraction( + gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str + ) -> bool: + # If both are sympy expressions, we can use sympy to compare them + if isinstance(gold, (Basic, MatrixBase)) and isinstance( + target, (Basic, MatrixBase) + ): + return sympy_expr_eq( + gold, target, float_rounding, numeric_precision, strict + ) + + # We don't support str / sympy.Expr comparison. Imo there is no point in doing this, as chances + # of this happening are very low. The only why one of them is not converted to sympy expression + # is usually because the parsing logic failed in this case we should improve the parsing logic + # instead of somehow fixing adhoc. + elif isinstance(gold, str) and isinstance(target, str): + # We just do string comparison for everything else + gold = gold.strip() + target = target.strip() + + # Ensure it's both not empty and equal + return len(gold) > 0 and len(target) > 0 and gold == target + + return False + + def compare_single_extraction_wrapper(g, t): + try: + return compare_single_extraction(g, t) + + except ValueError as e: + if str(e) == "signal only works in main thread of the main interpreter": + raise ValueError( + "Math-Verify doesn't support threaded environment due to usage of signal.alarm() in timeout mechanism. If you need to run in multithreaded environment it's recommended to set the parsing_timeout=None, which will run without timeout (and signal handling). In this case you need to handle the timeouting yourself." + ) from e + else: + logger.exception("Error during comparison") + return False + except Exception: + #! Do not attempt to print out the g and t during handling of exception + # Because a) it can throw an exception itself and b) it can cause it to be stuck forever during str conversion + logger.exception("Error during comparison") + return False + except TimeoutException: + logger.error("Timeout during comparison") + return False + + if not isinstance(gold, list): + gold = [gold] + if not isinstance(target, list): + target = [target] + + return any( + compare_single_extraction_wrapper(g, t) for g, t in product(gold, target) + ) diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/metric.py b/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/metric.py new file mode 100644 index 000000000..c9806c64d --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/metric.py @@ -0,0 +1,102 @@ +## Parser definition +import logging +from typing import Callable, Optional, Sequence + +from math_verify.grader import verify +from math_verify.parser import ExprExtractionConfig, ExtractionTarget, parse +from math_verify.utils import timeout + +logger = logging.getLogger(__name__) + + +def math_metric( + gold_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),), + pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),), + aggregation_function: Callable[[list[float]], float] = max, + precision: int = 6, +) -> Callable[ + [list[str], list[str]], tuple[float, Optional[tuple[list[str], list[str]]]] +]: + """Creates a language-aware extractive match metric that extracts answers from the model's output. + + Known issues: + - If the task is to simplify an expression, the metric might overestimate the accuracy. This is because if the model doesn't output any anchor for the extraction (e.g final answer is..), + it's possible that the the extracted prediction will be the expression to simplify. Because we do simplifications ourselves, it can thus happen that sympy will correctly simplify the expression, + thus it will match gold, despite model not doing anything. PRs to fix this are welcome. + + Args: + language: Language + The language of the samples. + gold_extraction_target: Sequence[ExtractionTarget] + Extraction targets to use for gold answers. Defaults to extracting simple math expressions. + pred_extraction_target: Sequence[ExtractionTarget] + Extraction targets to use for predictions. Defaults to extracting simple math expressions. + aggregation_function: Callable[[list[float]], float] + Function to aggregate scores when multiple golds/predictions are present. Defaults to max. + fallback_mode: Literal["no_fallback", "first_match"] + How to perform extraction. Defaults to "first_match". + - "no_fallback": Only use first successfully parsed matches + - "first_match": Use the first successfully parsed match + first match irregardless the parsing success + precision: int + Number of decimal places to use when comparing numerical values. Defaults to 6. + + Returns: + A sample level metric that extracts and compares mathematical expressions. + + """ + + @timeout(2) + def get_str_preds_with_timeout( + extracted_predictions: list[list[str]], extracted_golds: list[list[str]] + ) -> tuple[list[str], list[str]]: + golds = [str(gold) for golds in extracted_golds for gold in golds] + predictions = [str(pred) for preds in extracted_predictions for pred in preds] + return (golds, predictions) + + def sample_level_fn( + golds: list[str], predictions: list[str] + ) -> tuple[float, Optional[tuple[list[str], list[str]]]]: + extracted_predictions = [ + parse(pred, pred_extraction_target) for pred in predictions + ] + extracted_golds = [parse(gold, gold_extraction_target) for gold in golds] + + # Assert on empty gold and warn on empty pred + if any(len(g) == 0 for g in extracted_golds): + raise ValueError( + f"No gold targets found for at least one gold. Gold: {golds}, Pred: {predictions}" + ) + + if all(len(p) == 0 for p in extracted_predictions): + logger.warning( + f"We did not manage to extract a prediction in the correct format. Gold: {golds}, Pred: {predictions}" + ) + + # We have to use timeout because the sypmy to str conversion can be very slow + str_preds = None + try: + str_preds = get_str_preds_with_timeout( + extracted_predictions, extracted_golds + ) + except Exception: + logger.warning( + "Timeout when adding extracted predictions and golds to specific" + ) + + return ( + aggregation_function( + [ + ( + 1.0 + if any( + verify(gold, pred, precision) for gold in extracted_golds + ) + else 0.0 + ) + for pred in extracted_predictions + ] + ), + str_preds, + ) + + return sample_level_fn diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/parser.py b/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/parser.py new file mode 100644 index 000000000..79685af9c --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/parser.py @@ -0,0 +1,719 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import logging +import re +from dataclasses import dataclass, field, replace +from functools import lru_cache +from itertools import groupby +from typing import Literal, Sequence + +import sympy +from latex2sympy2_extended.latex2sympy2 import ( + NormalizationConfig, + latex2sympy, + normalize_latex, +) +from latex2sympy2_extended.sets import FiniteSet +from sympy import Basic, MatrixBase, Number +from sympy.parsing import parse_expr + +from math_verify.errors import TimeoutException +from math_verify.grader import should_treat_as_complex +from math_verify.utils import timeout + +logger = logging.getLogger(__name__) + +TIMEOUT_WARNING_SHOWN = False + + +@dataclass(frozen=True) +class LatexExtractionConfig: + """Config for extracting latex from the prediction. + + Attributes: + try_extract_without_anchor (bool): Whether to try extracting latex without requiring specific anchors like "answer:" or "final answer is" + boxed_match_priority (int): Priority for matching boxed expressions (e.g., \boxed{}). + - 0: Highest priority (matched first) + - 50: Default priority (matched after final answer patterns) + - -1: Disable boxed expression matching + normalization_config (NormalizationConfig): Configuration for LaTeX normalization. + Controls preprocessing of LaTeX expressions including: + - Basic LaTeX cleanup + - Unit handling + - Operator formatting + - Boxed expression extraction + - Equation parsing + Defaults to a comprehensive normalization configuration. + """ + + try_extract_without_anchor: bool = True + boxed_match_priority: int = 50 + normalization_config: NormalizationConfig = field( + default_factory=lambda: NormalizationConfig( + basic_latex=True, + units=True, + malformed_operators=True, + nits=True, + boxed="all", + equations=False, + ) + ) + + +@dataclass(frozen=True) +class ExprExtractionConfig: + """Config for extracting mathematical expressions. + + Attributes: + try_extract_without_anchor (bool): Whether to try extracting expressions without requiring specific anchors like "answer:" or "final answer is" + """ + + try_extract_without_anchor: bool = True + + +@dataclass(frozen=True) +class StringExtractionConfig: + """Config for extracting literal strings. + + Attributes: + strings (tuple[str]): The strings to extract + try_extract_without_anchor (bool): Whether to try extracting strings without requiring specific anchors like "answer:" or "final answer is" + """ + + strings: tuple[str, ...] = field(default_factory=lambda: ("A", "B", "C", "D")) + try_extract_without_anchor: bool = True + lowercase: bool = True + + +ExtractionTarget = LatexExtractionConfig | ExprExtractionConfig | StringExtractionConfig + + +@lru_cache(maxsize=10) +def lazy_string_regex( + string_extraction_config: StringExtractionConfig, +) -> list[tuple[re.Pattern[str], int]]: + # First get indices to predict + string_keys = f"(?P{'|'.join([re.escape(i) for i in string_extraction_config.strings])})" + + # The strings are either surrounded with **answer**., or 'answer.' or the same without the dot + full_stop_re = r"\." + comma_re = r"," + colon_re = r":" + space_re = r"\s" + + answer_prefix_re = rf"(^|{space_re})(?:\*\*)?" + answer_suffix_re = ( + rf"(?:\*\*)?(?:{full_stop_re}|{comma_re}|{colon_re}|{space_re}|$)" + ) + answer_re = f"{answer_prefix_re}{string_keys}{answer_suffix_re}" + answer_re_start = rf"^(?:\*\*)?{string_keys}{answer_suffix_re}" + + answer_word = "(?i:answer)" + + regexes = [] + + final_answer_prefixed_re = rf"(?i:final answer is)\:?\s*{string_keys}\.?\s?I hope" + + # To allow stuff like "final answer is to your question" + final_answer_prefixed_just_is = ( + rf"(?i:final answer.{{0,100}}?)\s+is\:?\s*{string_keys}" + ) + regexes.extend( + [ + (final_answer_prefixed_re, 0), + (final_answer_prefixed_just_is, 50), + ] + ) + + regexes.extend( + [ + # Most specific patterns first + (f"{answer_word}{colon_re}.{{0,50}}?{answer_re}", 100), + # Answer word patterns + (f"{answer_word}.{{0,50}}?{answer_re}", 200), + ] + ) + + if string_extraction_config.try_extract_without_anchor: + # Start of line patterns + regexes.append((answer_re_start, 250)) + # Plain string patterns + regexes.append((answer_re, 300)) + + return [(re.compile(pattern), priority) for pattern, priority in regexes] + + +# All of the regexes are cached, to avoid repeated compiling during processing of same task +@lru_cache(maxsize=1) +def lazy_expr_regex( + expr_config: ExprExtractionConfig, +) -> list[tuple[re.Pattern[str], int]]: + # Basic number patterns (no LaTeX) + number_re = ( + # Format 1: Numbers with thousand separators (e.g., "1,234.56" or "1 234.56") + r"(?-?[1-9]\d{0,2}(?:[ ,]\d{3})+)(?P\.\d+)?|" + # Format 2: Simple numbers with decimal point or comma (e.g., "123.45" or "123,45") + r"(?P-?\d+)(?P[.,]\d+)|" + # Format 3: Decimal part only (e.g., ".123") + r"(?P\.\d+)|" + # Format 4: Integer only (e.g., "123") + r"(?P-?\d+)" + r")(?P\s*(?:%|[Pp]ercent|\s*[Pp]ercentage|\s*[Pp]ct))?" + ) + + # Expressions such as 1/2 + operators = [r"\+", r"\-", r"\*", r"\×", r"\/", r"\^", r"\(", r"\)", r"\÷"] + operators_re = "".join(operators) + all_expr_chars = r"[\d\.\s" + operators_re + r"]" + # Expression should have at minimum at least one operator and must start with a digit + expr_re = ( + rf"(?P-?\(?-?\d{all_expr_chars}*[{operators_re}]{all_expr_chars}+\)?)" + ) + + # Punctuation regexes + full_stop_re = r"\." + comma_re = r"," + colon_re = r":" + space_re = r"\s" + + currency_units = re.escape("$€£¥₹₽₪₩₫฿₡₢₣₤₥₦₧₨₩₪₫₭₮₯₰₱₲₳₴₵₶₷₸₹₺₻₼₽₾₿") + expr_prefix_re = rf"(?:^|{space_re}|\=)(?:\*\*)?" + expr_suffix_re = ( + rf"(?:\*\*)?(?:{full_stop_re}|{comma_re}|{colon_re}|{space_re}|\)|\$|$)" + ) + # Expressions must be prefixed and suffixed while, digits don't need suffix and can have currency units preceeded, this is to ensure + # That we can extract stuff like $100 or 100m2, while we don't extract XDY2K as 2 + expr_with_anchors = rf"(?:{expr_prefix_re}{expr_re}{expr_suffix_re})" + number_with_anchors = rf"(?:{expr_prefix_re}[{currency_units}]?{number_re})" + expr_or_number = rf"(?:{expr_with_anchors}|{number_with_anchors})" + regexes: list[tuple[str, int]] = [] + + final_answer_prefixed_re = ( + rf"(?i:final answer is)\:?\s*{expr_or_number}\.?\s?I hope" + ) + final_answer_prefixed_just_is = ( + rf"(?i:final answer.{{0,100}}?)\s+is\:?{expr_or_number}" + ) + regexes.append((final_answer_prefixed_re, 0)) + regexes.append((final_answer_prefixed_just_is, 50)) + + answer_prefix_re = r"(?i:answer)" + + # Match after the last equals with answer word - require the number pattern, + equals_re_colon = rf"{answer_prefix_re}{colon_re}(?:.{{0,100}}=\s*|.{{0,50}}?){expr_or_number}(?!\s*=)" + equals_re = ( + rf"{answer_prefix_re}(?:.{{0,100}}=\s*|.{{0,50}}?){expr_or_number}(?!\s*=)" + ) + regexes.extend([(equals_re_colon, 100), (equals_re, 200)]) + + if expr_config.try_extract_without_anchor: + # If everything fails, try to match plain expr/number + regexes.append((expr_with_anchors, 300)) + regexes.append((number_with_anchors, 300)) + + return [(re.compile(pattern), priority) for pattern, priority in regexes] + + +def make_latex_env_pattern( + prefix: str = "", context: Literal["boxed", "plain"] = "plain" +) -> str: + """Creates a LaTeX environment pattern with uniquely prefixed group names. + + Args: + prefix (str): Prefix to add to group names to make them unique + context (Literal["boxed", "plain", "fraction"]): Type of content to match inside the environments + - "boxed": Match environments containing \boxed{...} + - "plain": Match any LaTeX content + - "fraction": Match only fractions + + Returns: + str: Regex pattern for matching LaTeX environments with percent suffix + """ + percent_re_group = rf"(?P<{prefix}percent>(?:\\?%|[Pp]ercent|[Pp]ercentage|[Pp]ct))" + + # Define base content patterns + display_dollar_content = r"(?:[^$]|\$(?!\$))" + # Either \ not followed by ] or everything but \ + display_content_bracket = r"(?:[^\\]|\\(?!\]))" + inline_dollar_content = r"(?:\\[$]|[^\n$])" + inline_content_parenthesis = r"(?:[^\\\n]|\\(?!\)))" + inline_content_bracket = r"[^\n\]\[]" + + if context == "boxed": + # Rewrite patterns to optionally include boxed content + display_dollar_content = rf"{display_dollar_content}*?\\boxed{{{display_dollar_content}+?}}{display_dollar_content}*?" + display_content_bracket = rf"{display_content_bracket}*?\\boxed{{{display_content_bracket}+?}}{display_content_bracket}*?" + inline_dollar_content = rf"{inline_dollar_content}*?\\boxed{{{inline_dollar_content}+?}}{inline_dollar_content}*?" + inline_content_parenthesis = rf"{inline_content_parenthesis}*?\\boxed{{{inline_content_parenthesis}+?}}{inline_content_parenthesis}*?" + inline_content_bracket = rf"{inline_content_bracket}*?\\boxed{{{inline_content_bracket}+?}}{inline_content_bracket}*?" + else: + display_dollar_content = rf"{display_dollar_content}+?" + display_content_bracket = rf"{display_content_bracket}+?" + inline_dollar_content = rf"{inline_dollar_content}+?" + inline_content_parenthesis = rf"{inline_content_parenthesis}+?" + inline_content_bracket = rf"{inline_content_bracket}+?" + + # Build list of regex patterns + patterns = [ + # Display math environments (allow multiline) + rf"(?{display_dollar_content})(?{display_content_bracket})(?{inline_dollar_content})(?{inline_content_parenthesis})(?{inline_content_bracket})\]\s", + ] + if context == "plain": + simple_number = r"-?\d+(?:[.,]\d+)?" + patterns.append( + rf"(?P<{prefix}latexFraction>-?\\frac{{{simple_number}}}{{{simple_number}}})" + ) + + # Join patterns with | and wrap in parentheses + latex_env_re = rf"(?:(?:{'|'.join(patterns)})\s*{percent_re_group}?)" + + return latex_env_re + + +@lru_cache(maxsize=1) +def lazy_latex_regex( + latex_config: LatexExtractionConfig, +) -> list[tuple[re.Pattern[str], int]]: + # Pattern for multiple latex environments connected by and/or (also considering oxford comma) + # Create patterns for up to 5 connected expressions + first_latex_group = make_latex_env_pattern("first_") + next_groups = "".join( + [ + rf"(?:\s*(?:,?and|,?or|,)\s*{make_latex_env_pattern(f'next{i}_')})?" + for i in range(1, 6) + ] + ) + + latex_envs_re = rf"(?:{first_latex_group}{next_groups})" + colon_re = r":" + answer_prefix_re = r"(?i:answer)" + + # We first match boxed env, for some reason that's the most common case of output + # Then we match the latex with environments, then we try to match the fraction + regexes: list[tuple[str, int]] = [] + for latex_re in [latex_envs_re]: + final_answer_prefixed_re = rf"(?i:final answer is)\:?\s*{latex_re}\.?\s?I hope" + final_answer_prefixed_just_is = ( + rf"(?i:final answer.{{0,100}}?)\s+is\:?\s*{latex_re}" + ) + regexes.append((final_answer_prefixed_re, 0)) + regexes.append((final_answer_prefixed_just_is, 50)) + + # Match with answer word - higher priority than plain latex + answer_re_colon = f"{answer_prefix_re}{colon_re}.{{0,50}}?{latex_re}" + answer_re = f"{answer_prefix_re}.{{0,50}}?{latex_re}" + + regexes.extend([(answer_re_colon, 100), (answer_re, 200)]) + + # Match plain LaTeX - lowest priority + if latex_config.try_extract_without_anchor: + regexes.append((latex_re, 300)) + + # This ensures that boxed is matched right after the final answer xxxx + if latex_config.boxed_match_priority >= 0: + latex_re_boxed = make_latex_env_pattern(prefix="first_", context="boxed") + next_groups = "".join( + [ + rf"\s*(?:\s*(?:,?and|,?or|,)\s*{make_latex_env_pattern(f'next{i}_', context='boxed')})?" + for i in range(1, 6) + ] + ) + latex_re_boxed = rf"{latex_re_boxed}{next_groups}" + regexes.append((latex_re_boxed, latex_config.boxed_match_priority)) + # Match plain boxed, the issue with plain boxed is that it's impossible to know where it stops, so if there are + # till last }. We do the actuall extraction in the normalization step. + regexes.append( + (r"(?P\\boxed{.+})", latex_config.boxed_match_priority) + ) + + return [(re.compile(pattern, re.DOTALL), priority) for pattern, priority in regexes] + + +def get_extraction_regexes( + target_types: Sequence[ExtractionTarget], +) -> list[tuple[list[tuple[re.Pattern[str], int]], ExtractionTarget]]: + extraction_regexes: list[ + tuple[list[tuple[re.Pattern[str], int]], ExtractionTarget] + ] = [ + ( + (lazy_latex_regex(target_type), target_type) + if isinstance(target_type, LatexExtractionConfig) + else ( + (lazy_expr_regex(target_type), target_type) + if isinstance(target_type, ExprExtractionConfig) + else (lazy_string_regex(target_type), target_type) + ) + ) + for target_type in target_types + ] + return extraction_regexes + + +# Small cache, to catche repeated calls invalid parsing +@lru_cache(maxsize=20) +def parse_latex_cached(latex: str): + # First try to parse the latex as is + try: + return latex2sympy( + latex, + is_real=not should_treat_as_complex(latex), + convert_degrees=False, + ) + except Exception as e: + # If that fails, try to parse just the last equation + last_eq_latex = get_last_eq(latex) + if last_eq_latex != latex: + return latex2sympy( + last_eq_latex, + is_real=not should_treat_as_complex(last_eq_latex), + convert_degrees=False, + ) + else: + raise e + + +@lru_cache(maxsize=20) +def parse_expr_cached(expr: str): + return parse_expr(expr, evaluate=False) + + +def extract_expr(match: re.Match) -> tuple[str | sympy.Expr | None, str]: + # First combine the number + groups = match.groupdict() + # Expr group will always exist because every regex has it + expr = groups.get("expr", "") + integer = next( + (val for name, val in groups.items() if name.startswith("integer") and val), "" + ) + decimal = next( + (val for name, val in groups.items() if name.startswith("decimal") and val), "" + ) + + is_percentage = True if groups.get("percent", None) else False + + if integer or decimal: + # This makes sure we can convert numbers like 0001 to 1. Do note that this can convert 0 to '', so we assume an empty string was 0 and convert it back afterwards. + integer = integer.translate(str.maketrans("", "", ", ")).lstrip("0") + if len(integer) == 0: + integer = "0" + + decimal = decimal.replace(",", ".") + number_str = f"{integer}{decimal}" + number = Number(number_str) + + if is_percentage: + number = convert_to_pct(number) + return number, number_str + + # Otherwise just return the expression + # Remove new lines and spaces + if expr: + try: + return ( + parse_expr_cached(expr.replace("\n", " ").replace("^", "**")), + expr, + ) + except Exception: + pass + return None, expr + + +def convert_to_pct(number: Number): + return sympy.Mul(number, sympy.Rational(1, 100), evaluate=False) + + +equation_split_regex = re.compile(r"(?)=") + + +def get_last_eq(latex: str): + # This is to ensure that a=1,b=2 is not splitted + if "," not in latex and ";" not in latex: + eq_parts = equation_split_regex.split(latex) + # We only shorten if there are more than 2 parts, otherwise we keep equation as is + if len(eq_parts) > 2: + return eq_parts[-1] + return latex + + +@lru_cache(maxsize=20) +def extract_latex( + match: re.Match, latex_config: LatexExtractionConfig +) -> tuple[sympy.Expr | str | None, str]: + latex_exprs = [] + latex_strs = [] + + # Get all latex groups (both first_ and nextN_ prefixes) + first_latex_group = next( + ( + (val, name) + for name, val in match.groupdict().items() + if name.startswith("first_latex") and val + ), + None, + ) + + # Get all nextN_ groups + next_latex_groups = [ + next( + ( + (val, name) + for name, val in match.groupdict().items() + if name.startswith(f"next{i}_latex") and val + ), + None, + ) + for i in range(1, 6) + ] + + all_latex = list( + filter(lambda x: x is not None, [first_latex_group] + next_latex_groups) + ) + + for latex, name in all_latex: + name_without_prefix = name.split("_")[0] + group_name = name.split("_")[1] if len(name.split("_")) > 1 else None + is_percentage = ( + True if match.groupdict().get(f"{name_without_prefix}_percent") else False + ) + + # Use modified config if group name is 'boxed' + config = latex_config.normalization_config + if group_name == "latexBoxed": + config = replace(config, boxed="last") # Use replace to modify single field + + normalized_latex = normalize_latex( + latex, + config=config, + ) + latex_strs.append(normalized_latex) + + try: + parsed_latex = parse_latex_cached(normalized_latex) + if is_percentage: + parsed_latex = convert_to_pct(parsed_latex) + latex_exprs.append(parsed_latex) + except Exception: + latex_exprs.append(None) + pass + + if not latex_exprs: + return None, "" + + # If we have multiple expressions and all of them are parsed, wrap them in a Tuple + if len(latex_exprs) > 1 and all(expr is not None for expr in latex_exprs): + # To handle solution is: 1,2 and 3 + all_elements = [] + for expr in latex_exprs: + if isinstance(expr, FiniteSet): + all_elements.extend(expr.args) + else: + all_elements.append(expr) + return FiniteSet(*all_elements), " and ".join(latex_strs) + + # Otherwise return the single expression + return latex_exprs[0], latex_strs[0] + + +def extract_string(match: re.Match, string_config: StringExtractionConfig): + extracted_str = match.group("string_keys") + parsed_str = extracted_str + if string_config.lowercase: + parsed_str = extracted_str.lower() + return parsed_str, extracted_str + + +def extract_match( + match: re.Match, target_type: ExtractionTarget +) -> tuple[Basic | MatrixBase | str | None, str]: + """Extracts the match from the regex match. + + Args: + match (re.Match): The regex match object containing the extracted text + target_type (ExtractionTarget): The type of extraction to perform (latex, expression, or indices) + + Returns: + tuple[Basic | MatrixBase | str | None, str]: A tuple containing: + - The extracted and parsed value (if successful) or None (if parsing failed) + - The string representation of the extracted text + """ + if isinstance(target_type, LatexExtractionConfig): + return extract_latex(match, target_type) + elif isinstance(target_type, ExprExtractionConfig): + return extract_expr(match) + elif isinstance(target_type, StringExtractionConfig): + return extract_string(match, target_type) + + +def extract_target_from_pred( + pred: str, + target_res: list[tuple[list[tuple[re.Pattern[str], int]], ExtractionTarget]], + fallback_mode: Literal["no_fallback", "first_match"] = "no_fallback", + extraction_mode: Literal["first_match", "any_match"] = "any_match", +): + """Extracts targets from a prediction string using regex patterns. + Returns first sucesffuly extracted match. + + Args: + pred (str): The prediction string to extract from + target_res (list[tuple[list[tuple[re.Pattern[str], int]], ExtractionTarget]]): List of regex patterns and their priorities for each target type + fallback_mode (Literal["no_fallback", "first_match"], optional): How to handle extraction failures. Defaults to "no_fallback". + - "no_fallback": Return only successfully parsed match + - "first_match": Additionaly Include the first string match no matter how parsing finished + extraction_mode (Literal["first_match", "any_match"], optional): How to handle extraction failures. Defaults to "any_match". + - "first_match": Only tries to extract the first match + - "any_match": Tries to extract any match + + Returns: + list: List of extracted predictions, with first fallbac string appended if fallback_mode is "first_match" + """ + extracted_predictions = [] + fallbacks = [] + + # Get all patterns and sort by priority + all_patterns = [ + (pattern, target_type, priority) + for target_patterns, target_type in target_res + for pattern, priority in target_patterns + ] + + # Group patterns by priority using itertools.groupby + match_found = False + sorted_patterns = sorted(all_patterns, key=lambda x: x[2]) + grouped_patterns = list( + (gr, list(val)) for gr, val in groupby(sorted_patterns, key=lambda x: x[2]) + ) + for _, patterns_group in grouped_patterns: + # Find all matches for each pattern in this priority group + matches_with_pos = ( + (match, match.start(), match.end(), target_type) + for pattern, target_type, _ in patterns_group + for match in pattern.finditer(pred) + ) + + # Sort matches by end position (rightmost first) and then by start position (leftmost first) + matches_with_pos = sorted( + matches_with_pos, key=lambda x: (x[2], -x[1]), reverse=True + ) + + # Try to extract from each match, starting from rightmost + for match, _, _, target_type in matches_with_pos: + extracted_match, str_fallback = extract_match(match, target_type) + + match_found = True + if str_fallback: + fallbacks.append(str_fallback) + + if extracted_match is not None: + extracted_predictions.append(extracted_match) + break + + if extraction_mode == "first_match": + break + + # If we extracted something or found something and we're in first_match mode, stop processing other priorities + if extracted_predictions or (match_found and extraction_mode == "first_match"): + break + + if fallback_mode == "first_match" and fallbacks: + extracted_predictions += [fallbacks[0]] + + return extracted_predictions + + +def parse( + pred: str, + extraction_config: Sequence[ExtractionTarget] = [ + LatexExtractionConfig(), + ExprExtractionConfig(), + ], + fallback_mode: Literal["no_fallback", "first_match"] = "first_match", + extraction_mode: Literal["first_match", "any_match"] = "any_match", + parsing_timeout: int = 5, +): + """Extracts and parses mathematical expressions from a prediction string. + + This function attempts to extract mathematical expressions from text using various strategies + (LaTeX, plain expressions, etc.) and converts them to SymPy objects. + + Args: + pred (str): The prediction string to parse. + extraction_config (Sequence[ExtractionTarget], optional): Configuration for what types of expressions + to extract and how to extract them. Defaults to [LatexExtractionConfig(), ExprExtractionConfig()]. + fallback_mode (Literal["no_fallback", "first_match"], optional): How to handle extraction failures. Defaults to "first_match". + - "no_fallback": Return only successfully parsed expressions + - "first_match": Include the first string match even if parsing failed + extraction_mode (Literal["first_match", "any_match"], optional): Strategy for extracting matches. Defaults to "any_match". + - "first_match": Stop after finding the first match + - "any_match": Try to extract all possible matches, stops after first sucesful parsing attempt + parsing_timeout (int, optional): Maximum time in seconds to spend parsing each expression. Defaults to 3. Any timeout seconds > 0 or not None will result in the function to raise a ValueError if it's called in a threaded environment. + + Returns: + list: List of extracted predictions. Each prediction can be: + - SymPy expression (for successfully parsed mathematical expressions) + - String (for fallback matches when fallback_mode="first_match") + Empty list if no matches are found. + + Examples: + >>> parse("The answer is $\\frac{1}{2}$") + [Rational(1, 2)] + >>> parse("The answer is 1/2") + [Rational(1, 2)] + >>> parse("The answer is A", extraction_config=[StringExtractionConfig()]) + ['a'] + """ + global TIMEOUT_WARNING_SHOWN + if not TIMEOUT_WARNING_SHOWN and (parsing_timeout is None or parsing_timeout <= 0): + logger.warning( + "Timeout is disabled as parsing_timeout is None or <= 0, you must provide \ + the logic for timeout interuption yourself to prevent code getting stuck." + ) + TIMEOUT_WARNING_SHOWN = True + + try: + target_res = get_extraction_regexes(extraction_config) + return timeout(timeout_seconds=parsing_timeout)(extract_target_from_pred)( + pred, + target_res, + fallback_mode=fallback_mode, + extraction_mode=extraction_mode, + ) + except ValueError as e: + # Check if it's the signal error + if str(e) == "signal only works in main thread of the main interpreter": + raise ValueError( + "Math-Verify 'parse' function doesn't support threaded environment due to usage of signal.alarm() in timeout mechanism. If you need to run in multithreaded environment it's recommended to set the parsing_timeout=None, which will run without timeout (and signal handling). In this case you need to handle the timeouting yourself." + ) from e + logger.exception(f"Error parsing: {pred[:10]}") + return [] + except Exception: + logger.exception(f"Error parsing: {pred[:10]}") + return [] + except TimeoutException: + logger.error(f"Timeout during parsing: {pred[:10]}") + return [] diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/tasks.py b/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/tasks.py new file mode 100644 index 000000000..a628e01e5 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/tasks.py @@ -0,0 +1,324 @@ +import logging +from textwrap import dedent +from typing import Callable, Optional + +import numpy as np +from lighteval.metrics.dynamic_metrics import SampleLevelMetric +from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase +from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + +from math_verify.few_shots import GSM8K_FEW_SHOTS, MATH_HARD_FEW_SHOTS +from math_verify.metric import math_metric +from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig + +logger = logging.getLogger(__name__) + + +def as_lighteval_metric( + metric: Callable[ + [list[str], list[str]], tuple[float, Optional[tuple[list[str], list[str]]]] + ], +) -> SampleLevelMetric: + def sample_level_fn( + formatted_doc: Doc, golds: list[str], predictions: list[str] + ) -> float: + result, extracted_predictions = metric(golds, predictions) + if extracted_predictions is not None: + if not formatted_doc.specific: + formatted_doc.specific = {} + formatted_doc.specific["extracted_predictions"] = extracted_predictions + return result + + return SampleLevelMetric( + metric_name="extractive_match", + sample_level_fn=sample_level_fn, + category=MetricCategory.GENERATIVE, + use_case=MetricUseCase.ACCURACY, + corpus_level_fn=np.mean, + higher_is_better=True, + ) + + +def math_hard_prompt_function(x: dict, task_name: str) -> Doc: + if x.get("__few_shots"): + index = x["__index"] + few_shot_doc = ( + MATH_HARD_FEW_SHOTS[index] + if index < len(MATH_HARD_FEW_SHOTS) + else MATH_HARD_FEW_SHOTS[-1] + ) + answer = few_shot_doc["answer"] + question = few_shot_doc["question"] + else: + answer = str(x["solution"]) + question = x["problem"] + + query = dedent( + f"""\ +Question: {question} +Step-by-Step Answer:\ +""" + ).strip() + + choices = [answer] + return Doc(query=query, choices=choices, gold_index=0) + + +def math_prompt_function(x: dict, task_name: str) -> Doc: + if x.get("__few_shots"): + index = x["__index"] + few_shot_doc = ( + MATH_HARD_FEW_SHOTS[index] + if index < len(MATH_HARD_FEW_SHOTS) + else MATH_HARD_FEW_SHOTS[-1] + ) + answer = few_shot_doc["answer"] + question = few_shot_doc["question"] + else: + answer = str(x["answer"]) + question = x["problem"] + + query = dedent( + f"""\ +Question: {question} +Step-by-Step Answer:\ +""" + ).strip() + + choices = [answer] + return Doc(query=query, choices=choices, gold_index=0) + + +def math_aime24_prompt_function(x: dict, task_name: str) -> Doc: + if x.get("__few_shots"): + index = x["__index"] + few_shot_doc = ( + MATH_HARD_FEW_SHOTS[index] + if index < len(MATH_HARD_FEW_SHOTS) + else MATH_HARD_FEW_SHOTS[-1] + ) + answer = few_shot_doc["answer"] + question = few_shot_doc["question"] + else: + answer = str(x["reference_solution"]) + question = x["problem"] + + query = dedent( + f"""\ +Question: {question} +Step-by-Step Answer:\ +""" + ).strip() + + choices = [f" {answer}"] + return Doc(query=query, choices=choices, gold_index=0) + + +def math_amc23_prompt_function(x: dict, task_name: str) -> Doc: + if x.get("__few_shots"): + index = x["__index"] + few_shot_doc = ( + MATH_HARD_FEW_SHOTS[index] + if index < len(MATH_HARD_FEW_SHOTS) + else MATH_HARD_FEW_SHOTS[-1] + ) + answer = few_shot_doc["answer"] + question = few_shot_doc["question"] + else: + answer = str(x["answer"]) + question = x["question"] + + query = dedent( + f"""\ +Question: {question} +Step-by-Step Answer:\ +""" + ).strip() + choices = [f" {answer}"] + return Doc(query=query, choices=choices, gold_index=0) + + +def gsm8k_prompt_function(x: dict, task_name: str) -> Doc: + if x.get("__few_shots"): + index = x["__index"] + few_shot_doc = ( + GSM8K_FEW_SHOTS[index] + if index < len(GSM8K_FEW_SHOTS) + else GSM8K_FEW_SHOTS[-1] + ) + answer = few_shot_doc["answer"] + question = few_shot_doc["question"] + else: + answer = f"{x['answer'].split('####')[-1].strip()}" + question = x["question"] + + query = dedent( + f"""\ +Question: {question} +Step-by-Step Answer:\ +""" + ).strip() + + choices = [f" {answer}"] + return Doc(query=query, choices=choices, gold_index=0) + + +math_hard_lighteval = [ + LightevalTaskConfig( + name=f"math_hard:{subset}", + suite=["lighteval", "math"], + prompt_function=math_hard_prompt_function, + hf_repo="lighteval/MATH-Hard", + hf_subset=subset, + evaluation_splits=["test"], + few_shots_split="train", + generation_size=1024, + metric=[ + as_lighteval_metric( + math_metric( + gold_extraction_target=( + LatexExtractionConfig(boxed_match_priority=0), + ), + pred_extraction_target=( + LatexExtractionConfig(), + ExprExtractionConfig(), + ), + ) + ), + ], + stop_sequence=["\nQuestion:", "\nProblem:", "\nquestion:", "\nproblem:"], + trust_dataset=True, + version=0, + ) + for subset in [ + "algebra", + "counting_and_probability", + "geometry", + "intermediate_algebra", + "number_theory", + "prealgebra", + "precalculus", + ] +] + +math_500_lighteval = [ + LightevalTaskConfig( + name="math_500", + suite=["lighteval", "math"], + prompt_function=math_prompt_function, + hf_repo="HuggingFaceH4/MATH-500", + hf_subset="default", + evaluation_splits=["test"], + few_shots_split="test", + generation_size=1024, + metric=[ + as_lighteval_metric( + math_metric( + gold_extraction_target=( + LatexExtractionConfig(boxed_match_priority=0), + ), + pred_extraction_target=( + LatexExtractionConfig(), + ExprExtractionConfig(), + ), + ) + ), + ], + stop_sequence=["\nQuestion:", "\nProblem:", "\nquestion:", "\nproblem:"], + trust_dataset=True, + version=0, + ) +] + + +aime24_lighteval = [ + LightevalTaskConfig( + name="aime24", + suite=["lighteval", "math"], + prompt_function=math_aime24_prompt_function, + hf_repo="zwhe99/aime24", + hf_subset="default", + evaluation_splits=["test"], + few_shots_split="test", + generation_size=1024, + metric=[ + as_lighteval_metric( + math_metric( + gold_extraction_target=(LatexExtractionConfig(),), + pred_extraction_target=( + LatexExtractionConfig(), + ExprExtractionConfig(), + ), + ) + ), + ], + stop_sequence=["\nQuestion:", "\nProblem:", "\nquestion:", "\nproblem:"], + trust_dataset=True, + version=0, + ) +] + +amc23_lighteval = [ + LightevalTaskConfig( + name="amc23", + suite=["lighteval", "math"], + prompt_function=math_amc23_prompt_function, + hf_repo="zwhe99/amc23", + hf_subset="default", + hf_filter=lambda x: len(x["question"].strip()) > 0, + evaluation_splits=["test"], + few_shots_split="test", + generation_size=1024, + metric=[ + as_lighteval_metric( + math_metric( + gold_extraction_target=(ExprExtractionConfig(),), + pred_extraction_target=( + LatexExtractionConfig(), + ExprExtractionConfig(), + ), + ) + ), + ], + stop_sequence=["\nQuestion:", "\nProblem:", "\nquestion:", "\nproblem:"], + trust_dataset=True, + version=0, + ) +] + +gsm8k_lighteval = [ + LightevalTaskConfig( + name="gsm8k", + suite=["lighteval", "math"], + prompt_function=gsm8k_prompt_function, + hf_repo="openai/gsm8k", + hf_subset="main", + hf_filter=lambda x: len(x["question"].strip()) > 0, + evaluation_splits=["test"], + few_shots_split="test", + generation_size=1024, + stop_sequence=["\nQuestion:", "\nProblem:", "\nquestion:", "\nproblem:"], + metric=[ + as_lighteval_metric( + math_metric( + gold_extraction_target=(ExprExtractionConfig(),), + pred_extraction_target=( + LatexExtractionConfig(), + ExprExtractionConfig(), + ), + fallback_mode="first_match", + ) + ), + ], + ) +] + + +TASKS_TABLE = [ + *gsm8k_lighteval, + *math_hard_lighteval, + *math_500_lighteval, + *aime24_lighteval, + *amc23_lighteval, +] diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/utils.py b/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/utils.py new file mode 100644 index 000000000..d4dddf11f --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/src/math_verify/utils.py @@ -0,0 +1,108 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import logging +import os + +from math_verify.errors import TimeoutException + +TIMEOUT_WARNING_SHOWN = False +logger = logging.getLogger(__name__) + + +def timeout(timeout_seconds: int | None = 10): # noqa: C901 + """A decorator that applies a timeout to the decorated function. + + Args: + timeout_seconds (int): Number of seconds before timing out the decorated function. + Defaults to 10 seconds. + + Notes: + On Unix systems, uses a signal-based alarm approach which is more efficient as it doesn't require spawning a new process. + On Windows systems, uses a multiprocessing-based approach since signal.alarm is not available. This will incur a huge performance penalty. + """ + if timeout_seconds is None or timeout_seconds <= 0: + + def no_timeout_decorator(func): + return func + + return no_timeout_decorator + + if os.name == "posix": + # Unix-like approach: signal.alarm + import signal + + def decorator(func): + def handler(signum, frame): + raise TimeoutException("Operation timed out!") + + def wrapper(*args, **kwargs): + old_handler = signal.getsignal(signal.SIGALRM) + signal.signal(signal.SIGALRM, handler) + signal.alarm(timeout_seconds) + try: + return func(*args, **kwargs) + finally: + # Cancel the alarm and restore previous handler + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + + return wrapper + + return decorator + + else: + # Windows approach: use multiprocessing + from multiprocessing import Process, Queue + + def decorator(func): + def wrapper(*args, **kwargs): + q = Queue() + + def run_func(q, args, kwargs): + try: + result = func(*args, **kwargs) + q.put((True, result)) + except Exception as e: + q.put((False, e)) + + p = Process(target=run_func, args=(q, args, kwargs)) + p.start() + p.join(timeout_seconds) + + if p.is_alive(): + # Timeout: Terminate the process + p.terminate() + p.join() + raise TimeoutException("Operation timed out!") + + # If we got here, the process completed in time. + success, value = q.get() + if success: + return value + else: + # The child raised an exception; re-raise it here + raise value + + return wrapper + + return decorator diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_all.py b/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_all.py new file mode 100644 index 000000000..472f26df0 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_all.py @@ -0,0 +1,873 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import pytest +import sympy + +from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify +from math_verify.grader import sympy_expr_eq + +""" +This file contains regression tests for testing evaluation of free-flow generation for math or indices. +Most of the tests have been created based on observations from the model outputs. +""" + + +def compare_strings( + gold: str, + pred: str, + match_types: list[str] = ["latex", "expr"], + precision: int = 6, + strict: bool = True, +): + """Helper function to compare strings using the math extraction metrics""" + # Convert string match_types to ExtractionTarget objects + extraction_targets = [] + for match_type in match_types: + if match_type == "latex": + extraction_targets.append(LatexExtractionConfig(boxed_match_priority=0)) + elif match_type == "expr": + extraction_targets.append(ExprExtractionConfig()) + + gold_parsed = parse(gold, extraction_targets) + pred_parsed = parse(pred, extraction_targets) + return verify(gold_parsed, pred_parsed, float_rounding=precision, strict=strict) + + +@pytest.mark.parametrize( + "gold,pred,expected", + [ + # Test negative numbers + ("-5", "-5", 1), + # Test for thousands separator + ("7425000", "7,425,000", 1), + ("1000", "1 000", 1), + ("1000", "1000.0", 1), + # Test thousand separator with floating point number + ("1000.0", "1,000.0", 1), + # Test decimal separator as , + ("1000.99", "1000,99", 1), + ("1,22", "1.22", 1), + ("2.74", "Soucis : 2,74 $ a..", 1), + # Test no decimal separator + ("0.4", ".4", 1), + # Test decimals + ("1000.99", "1,000.99", 1), + ("1000.99", "1,000.99", 1), + # Test with units like $ + ("1000.99", "$1,000.99", 1), + ("1000.99", "1,000.99$", 1), + # Test with currency units + ("1000.99", "the number is not 10 which is 1,000.99€", 1), + ("1000.99", "the number is not 10 which is 1,000.99€", 1), + # Test m2 + ("1000.99", "so the number is 10 which is 1,000.99m²", 1), + ("1000.99", "not it's not 10 it's 1,000.99m²", 1), + ("0,111", "0.111", 1), + # Test correct extraction of not correct answer + ("2", "AZYUK2A", 0), + ], +) +def test_number_extraction(gold, pred, expected): + assert compare_strings(gold, pred, match_types=["expr"]) == expected + + +@pytest.mark.parametrize( + "gold,pred,expected", + [ + ("10/9", "\\frac{10}{9}", 1), + ("-10/9", "-\\frac{10}{9}", 1), + ], +) +def test_simple_fraction_notation(gold, pred, expected): + assert compare_strings(gold, pred, match_types=["latex", "expr"]) == expected + + +@pytest.mark.parametrize( + "gold,pred,expected", + [ + ("$[0,1)$", "$[0,1)$", 1), + ("$[0,1)$", "$[0,1)$", 1), + ("$[0,9)$", "$[0,1)$", 0), + ("$(0,9)$", "$[0,9)$", 0), + ("$1$", "$-[0,1)$", 0), + ], +) +def test_sets_handling(gold, pred, expected): + assert compare_strings(gold, pred, match_types=["latex", "expr"]) == expected + + +@pytest.mark.parametrize( + "gold,pred,expected", + [ + # Notations + ("$9$", "Answer \\[ 9 \\]", 1), + ("$9$", "Answer $ 9 $", 1), + ("$9$", "Answer $$ 9 $$", 1), + ("$9$", "Answer \\( 9 \\)", 1), + # Works even with () + ("$10$", "Answer \\( (9+1) \\)", 1), + # Separate line shouldn't work for inline latex + ("$9$", "Answer $ \n 9 \n $", 0), + ("$9$", "Answer \\( \n 9 \n \\)", 0), + # Separate line should work for block latex + ("$9$", "Answer \\[ \n 9 \n \\]", 1), + ("$9$", "Answer $$ \n 9 \n $$", 1), + # the $ can appear in the middle of the string + ("$10/9$", "Answer $ \\frac{1}{2} \\$ = \\frac{10}{9} $", 1), + # Malformed fractions work + ("$1/3$", "$\\frac13 $", 1), + ("$1$", "$\\frac3{3} $", 1), + # Malformed sqrt works + ("$\\sqrt{3}$", "$\\sqrt3 $", 1), + # frac variants work like frac + ("$1/3$", "$\\cfrac{1}{3} $", 1), + ("$1/3$", "$\\dfrac{1}{3} $", 1), + ("$1/3$", "$\\tfrac{1}{3} $", 1), + # Simple fractions are parsed + ("$1/3$", "$ 1/3 $", 1), + # Styling is removed + ("$1/3$", "$\\left( \\frac{1}{3} \\right)$", 1), + ("$1/3$", "$\\boxed{\\frac{1}{3}}$", 1), + ("$1/3$", "$\\frac{1}{3} \\text{meters}$", 1), + ("$1/3$", "$\\frac{1}{3} \\textbf{meters}$", 1), + # Last = is considered + ("$1/3$", "$k = \\frac{1}{3}$", 1), + ("$1/3$", "$\\frac{1}{3} \\textbf{meters}$", 1), + ], +) +def test_latex_notation(gold, pred, expected): + assert compare_strings(gold, pred, match_types=["latex"]) == expected + + +@pytest.mark.parametrize( + "gold,pred,expected", + [ + ("$28\\%$", "28 percent", 1), + ("$28\\%$", "28 pct", 1), + ("$28\\%$", "28 %", 1), + ("$28\\%$", "$28$ %", 1), + ("$28\\%$", "$28$ percent", 1), + ("$28\\%$", "$\\boxed{28}$ pct", 1), + ], +) +def test_percent_notation(gold, pred, expected): + assert compare_strings(gold, pred, match_types=["latex", "expr"]) == expected + + +@pytest.mark.parametrize( + "gold,pred,expected", + [ + ( + "$2-2p$", + "Since $x<2$, it follows that $|x-2|=2-x$. If $2-x=p$, then $x=2-p$. Thus $x-p=\\boxed{2-2p}$.", + 1, + ), + ( + "\\boxed{\n\\begin{pmatrix} 0 & 3 \\\\ 0 & -1 \\end{pmatrix}\n}.\n\\end{align*}", + "\\boxed{\n\\begin{pmatrix} 0 & 3 \\\\ 0 & -1 \\end{pmatrix}\n}.\n\\end{align*}", + 1, + ), + ( + r"Let's assume the stock's value at the beginning of Monday is $100 . $ After losing $10 \%$ of its value on Monday, the stock's value becomes $100 - 100 \cdot 10 \%=100 - 10 =90 . $ On Tuesday, the stock loses $20 \%$ of this new value, which is $90 \cdot 20 \%=18 . $ Therefore, the stock's value at the end of Tuesday is $90 - 18 =72 . $ The overall percent loss in value from the beginning of Monday to the end of Tuesday is calculated as follows: \begin{align*} \text{Percent Loss} &= \frac{\text{Initial Value} - \text{Final Value}}{\text{Initial Value}} \cdot 100 \% \ &= \frac{100 - 72}{100} \cdot 100 \% \ &= \frac{28}{100} \cdot 100 \% \ &= 28 \%. \end{align*} Final Answer: The final answer is $28 \%$. I hope it is correct. Note: The solution provided is incorrect. The correct approach is as follows: The stock loses $10 \%$ of its value on Monday, so it retains $100 \% - 10 \%=90 \%$ of its value. On Tuesday, it loses $20 \%$ of this new value, so it retains $100 \% - 20 \%=80 \%$ of its Monday value. Therefore, it retains $90 \% \cdot 80 \%=0.9 \cdot 0.8=0.72=72 \%$ of its original value. The overall percent loss in value is $100 \% - 72 \%=28 \%$. Final Answer: The final answer is $28 \%$. I hope it is correct. Note: The solution provided is incorrect. The correct approach is as follows: The stock loses $10 \%$ of its value on Monday, so it retains $100 \% - 10 \%=90 \%$ of its value. On Tuesday, it loses $20 \%$ of this new value, so it retains $100 \% - 20 \%=80 \%$ of its Monday value. Therefore, it retains $90 \% \cdot 80 \%=0.9 \cdot 0.8=0.72=72 \%$ of its original value. The overall percent loss in value is $100 \% - 72 \%=28 \%$. Final Answer: The final answer is $28 \%$. I hope it is correct. Note: The solution provided is incorrect. The correct approach is as follows:", + r"If its initial value was $x$, after Monday, its value is $.9x$, and after a further $20\%$ loss, its value becomes $.8\cdot .9x = .72x$, making for a total loss of $\boxed{28\%}$.", + 1, + ), + ( + r"In order for $N$ to be divisible by 12, $N$ must be divisible by $4$ and $3$. That means the last two digits $AB$ must form a multiple of $4.$ Since $A$ and $B$ are nonzero digits, the smallest possible $14{,}9AB$ that is divisible by 4 is $14{,}912.$ Unfortunately, this number is not a multiple of $3$, since $1 + 4 + 9 + 1 + 2 = 17.$ However, our next possibility, $14{,}916,$ is a multiple of $3,$ since $1 + 4 + 9 + 1 + 6 = 21.$ Therefore, $\boxed{14{,}916}$ is our answer.", + r"For a number to be divisible by 12, it must be divisible by both 3 and 4. For a number to be divisible by 3, the sum of its digits must be divisible by 3. For a number to be divisible by 4, the last two digits must be divisible by 4. Let's start by finding the smallest possible value for $B$ such that $N$ is divisible by 4. Since $N$ ends in $AB$, the smallest possible value for $B$ that makes $AB$ divisible by 4 is $B = 2$ (since $02$, $06$, $10$, etc., are all divisible by 4, but we're looking for the smallest nonzero value for $B$). Now, let's consider the divisibility by 3. The sum of the digits of $N$ is $1 + 4 + 9 + A + 2 = 16 + A + 2 = 18 + A$. For $N$ to be divisible by 3, $18 + A$ must be divisible by 3. Since $A$ is a digit, it can be any value from $1$ to $9$. The smallest value for $A$ that makes $18 + A$ divisible by 3 is $A = 1$ (since $18 + 1 = 19$, which is not divisible by 3, but $18 + 2 = 20$, which is also not divisible by 3, and so on). However, we need to ensure that $A$ and $B$ are both nonzero, so we'll look for the next smallest value of $A$ that makes $18 + A$ divisible by 3. The next smallest value is $A = 4$, since $18 + 4 = 22$, which is divisible by 3. Therefore, the smallest possible value for $N$ with $A = 4$ and $B = 2$ is $N = 14 {, }916$. This number is divisible by both 3 and 4, and thus by 12. Final Answer: The final answer is $14 {, }916$. I hope it is correct. Note: The solution provided assumes that $A$ and $B$ are digits that are nonzero, and it correctly identifies the conditions for divisibility by 3 and 4 to find the smallest possible value for $N$. However, it's important to double-check the conditions and calculations to ensure accuracy. Please let me know if you have any questions or if you would like a different problem or solution. Thank you! [Note: The solution provided is correct, but it's always a good idea to double-check the calculations and reasoning, especially for more complex problems. If you have any questions or need further clarification, please feel free to ask. Thank you!] [Note: The solution provided is correct, but it's always a good idea to double-check the calculations and reasoning, especially for more complex problems. If you have any questions or need further clarification, please feel free to ask. Thank you!] [Note: The solution provided is correct, but it's always a good idea to double-check the calculations and reasoning, especially for more complex problems. If you have any questions or need further clarification, please feel free to ask. Thank you!] [Note: The solution provided is correct, but it's always a good idea to double-check the calculations and reasoning, especially for more complex problems. If you have any questions or need further clarification, please feel free to ask. Thank you!] [Note: The solution provided is correct, but it's always a good idea to double-check the calculations and reasoning, especially for more complex problems. If you have any questions or need further clarification, please feel free to ask. Thank you!] [Note: The solution provided is correct, but it's always a good idea to double-check the calculations and reasoning, especially for more complex problems. If you have any questions or need further clarification, please feel free to ask. Thank you!] [Note", + 1, + ), + ( + r"Since $25<30<36$, we have $5<\sqrt{30}<6$. We also know that $7^2=49$, so $\sqrt{50}\approx7$. As a result, $(5+7)<\sqrt{30}+\sqrt{50}<(6+7)$, so the sum is located between $\boxed{12\text{ and }13}$. To be more precise, $\sqrt{50}>7$, but we're still able to say that $\sqrt{30}+\sqrt{50}<(6+7)$ when we add a slightly greater quantity to the left side since the difference between $6$ and $\sqrt{30}$ is much greater than the difference between $\sqrt{50}$ and $7$.", + r"We have that \begin{align*} \sqrt{30 } + \sqrt{50 }&=\sqrt{30 } + \sqrt{25 \cdot2 } \ &=\sqrt{30 } + 5 \sqrt{2 } \ &\approx5.477 + 5 \cdot1.414 \ &\approx5.477 + 7.07 \ &\approx12.547. \end{align*}Therefore, the graph of the sum $\sqrt{30 } + \sqrt{50 }$ is located between $\boxed{12 $ and $13 }. $ Final Answer: The final answer is \"$12 $ and $13 $\". I hope it is correct.", + 1, + ), + # Notations + ( + "$(3, \\frac{\\pi}{2})$", + r"We have that $r = \\sqrt{0^2 + 3^2} = 3.$ Also, if we draw the line connecting the origin and $(0,3),$ this line makes an angle of $\\frac{\\pi}{2}$ with the positive $x$-axis.\n\n[asy]\nunitsize(0.8 cm);\n\ndraw((-0.5,0)--(3.5,0));\ndraw((0,-0.5)--(0,3.5));\ndraw(arc((0,0),3,0,90),red,Arrow(6));\n\ndot((0,3), red);\nlabel(\"$(0,3)$\", (0,3), W);\ndot((3,0), red);\n[/asy]\n\nTherefore, the polar coordinates are $\\boxed{\\left( 3, \\frac{\\pi}{2} \\right)}.$", + 1, + ), + ( + "$\\frac{14}{3}$", + r"$f(-2)+f(-1)+f(0)=\frac{3(-2)-2}{-2-2}+\frac{3(-1)-2}{-1-2}+\frac{3(0)-2}{0-2}=\frac{-8}{-4}+\frac{-5}{-3}+\frac{-2}{-2}=2+\frac{5}{3}+1=\boxed{\frac{14}{3}}$", + 1, + ), + ( + "$\\text{Evelyn}$", + r"Evelyn covered more distance in less time than Briana, Debra and Angela, so her average speed is greater than any of their average speeds. Evelyn went almost as far as Carla in less than half the time that it took Carla, so Evelyn's average speed is also greater than Carla's. Therefore, $\boxed{\text{Evelyn}}$ is our answer.", + 1, + ), + # Test cases from math problems + ( + "$90^\\circ$", + r"For the first line, let $t = 2x = 3y = -z.$ Then \[\begin{pmatrix} x \\ y \\ z \end{pmatrix} = \begin{pmatrix} t/2 \\ t/3 \\ -t \end{pmatrix} = \frac{t}{6} \begin{pmatrix} 3 \\ 2 \\ -6 \end{pmatrix}.\]Thus, the direction vector of the first line is $\begin{pmatrix} 3 \\ 2 \\ -6 \end{pmatrix}.$ For the second line, let $t = 6x = -y = -4z.$ Then \[\begin{pmatrix} x \\ y \\ z \end{pmatrix} = \begin{pmatrix} t/6 \\ -t \\ -t/4 \end{pmatrix} = \frac{t}{12} \begin{pmatrix} 2 \\ -12 \\ -3 \end{pmatrix}.\]Thus, the direction vector of the first line is $\begin{pmatrix} 2 \\ -12 \\ -3 \end{pmatrix}.$ Note that \[\begin{pmatrix} 3 \\ 2 \\ -6 \end{pmatrix} \cdot \begin{pmatrix} 2 \\ -12 \\ -3 \end{pmatrix} = 0.\]Hence, the angle between the lines is $\boxed{90^\circ}.$", + 1, + ), + ( + "$3\\sqrt{13}$", + r"We use the distance formula: \begin{align*} \sqrt{(2 - (-4))^2 + ((-6) - 3)^2} &= \sqrt{6^2 + (-9)^2}\\ & = \sqrt{36 + 81}\\ & = \sqrt{117} = \boxed{3\sqrt{13}}. \end{align*}", + 1, + ), + ( + "$\\frac{3}{56}$", + r"We also know that $q(-1) = ((-1)^2 - 1)p(-1) + 1 = 1.$ Setting $x = -1$ in the equation above, we get \[q(-1) = 20160(-a + b),\]so $-a + b = \frac{1}{20160}.$ Solving for $a$ and $b,$ we find $a = -\frac{29}{40320}$ and $b = -\frac{3}{4480}.$ Hence, \begin{align*} q(x) &= \left( -\frac{29}{40320} x - \frac{3}{4480} \right) (x - 2)(x - 3) \dotsm (x - 7) \\ &= -\frac{(29x + 27)(x - 2)(x - 3) \dotsm (x - 7)}{40320}. \end{align*}In particular, \[q(8) = -\frac{(29 \cdot 8 + 27)(6)(5) \dotsm (1)}{40320} = -\frac{37}{8},\]so \[p(8) = \frac{q(8) + 8}{8^2 - 1} = \boxed{\frac{3}{56}}.\]", + 1, + ), + ( + "$2$", + r"Of the two-digit perfect squares, only $4^2=16$ and $6^2=36$ end in $6$. Thus, there are $\boxed{2}$ distinct possible values for $B$.", + 1, + ), + ( + "$15\\mbox{ cm}^2$", + r"The shaded triangle has a base of length $10\text{ cm}.$ Since the triangle is enclosed in a rectangle of height $3\text{ cm},$ then the height of the triangle is $3\text{ cm}.$ (We know that the enclosing shape is a rectangle, because any figure with four sides, including two pairs of equal opposite sides, and two right angles must be a rectangle.) Therefore, the area of the triangle is $$\frac{1}{2}\times 3 \times 10 = \boxed{15\mbox{ cm}^2}.$$", + 1, + ), + ( + "$-2,1$", + r"By the Integer Root Theorem, the possible integer roots are all the divisors of 14 (including negative divisors), which are $-14,$ $-7,$ $-2,$ $-1,$ $1,$ $2,$ $7,$ and $14.$ Checking, we find that the only integer roots are $\boxed{-2,1}.$", + 1, + ), + ( + "$9$", + r"We use the property that $a \equiv b \pmod{m}$ implies $a^c \equiv b^c \pmod{m}$. Since $129 \equiv -3 \pmod{11}$ and $96 \equiv -3 \pmod{11}$, we have $$129^{34}+96^{38} \equiv (-3)^{34}+(-3)^{38} \equiv 3^{34}+3^{38} \pmod{11}.$$ Since $3^5 \equiv 1 \pmod{11},$ we can see that $3^{34} = (3^5)^{6} \cdot 3^4$ and $3^{38} = (3^5)^{7} \cdot 3^3.$ Then, $129^{34}+96^{38} \equiv \boxed{9} \pmod{11}.$", + 1, + ), + ( + "$90^\\circ$", + "Therefore, \\begin{align*} \\angle BAC &= \\angle BAD + \\angle DAC \\\\ &= 50^\\circ+40^\\circ \\\\ &= \\boxed{90^\\circ}. \\end{align*}", + 1, + ), + ( + "$0$", + "Note that $p(x)$ has degree at most 2. Also, $p(a) = p(b) = p(c) = 1.$ Thus, the polynomials $p(x)$ and 1 agree at three different values, so by the Identity Theorem, they are the same polynomial. Hence, the degree of $p(x)$ (which is the constant polynomial 1) is $\\boxed{0}.$", + 1, + ), + # Test long division in base 5 + ( + "$204_5$", + r"We may carry out long division in base 5 just as in base 10. We have \[ \begin{array}{c|ccc} \multicolumn{2}{r}{2} & 0 & 4 \\ \cline{2-4} 2 & 4 & 1 & 3 \\ \multicolumn{2}{r}{4} & \downarrow & \\ \cline{2-2} \multicolumn{2}{r}{0} & 1 & \\ \multicolumn{2}{r}{} & 0 & \downarrow \\ \cline{3-3} \multicolumn{2}{r}{} & 1 & 3 \\ \multicolumn{2}{r}{} & 1 & 3 \\ \cline{3-4} \multicolumn{2}{r}{} & & 0 \end{array} \]for a quotient of $\boxed{204_5}$. Note that in the above calculation we have used that $13_5$ divided by $2_5$ is $4_5$, which follows from $4_5\times2_5=8_{10}=13_5$.", + 1, + ), + ( + "$(6,31,-1)$", + "Let $\\alpha$ be a root of $x^3 - 3x^2 + 4x - 1 = 0,$ so $\\alpha^3 = 3 \\alpha^2 - 4 \\alpha + 1.$ Then solving the system of equations, we find $(p,q,r) = \\boxed{(6,31,-1)}.$", + 1, + ), + ( + "$1 \\pm \\sqrt{19}$", + "This simplifies to $64y + 1920 = 0,$ so $y = -30.$ Then $x^2 - 2x - 48 = -30,$ or $x^2 - 2x - 18 = 0.$ By the quadratic formula, $x = \\boxed{1 \\pm \\sqrt{19}}.$", + 1, + ), + ( + "$3 \\pm 2 \\sqrt{2}$", + "This gives us $x^2 + 1 = 6x,$ or $x^2 - 6x + 1 = 0.$ By the quadratic formula, the roots are $x = \\boxed{3 \\pm 2 \\sqrt{2}}.$", + 1, + ), + ( + "$\\{1\\pm\\sqrt{5},-2\\}$", + "The roots of $P(x)$ are $-2$ and $1 \\pm \\sqrt{5}$, so the answer is $\\boxed{\\{1\\pm\\sqrt{5},-2\\}}.$", + 1, + ), + ( + "$f(2) < f(1) < f(4)$", + 'The graph of $f(x) = x^2 + bx + c$ is an upward-facing parabola, and the condition\n\\[f(2 + t) = f(2 - t)\\]tells us that the axis of symmetry of the parabola is the line $x = 2.$ Thus, $f(x)$ is an increasing function of $|x - 2|.$ In other words, the farther $x$ is from 2, the greater $f(x)$ is.\n\n[asy]\nunitsize(1.5 cm);\n\nreal parab (real x) {\n return (x^2/4);\n}\n\ndraw(graph(parab,-2,2),red);\ndraw((0,-0.5)--(0,2),dashed);\n\nlabel("$x = 2$", (0,2), N);\ndot("$(2,f(2))$", (0,0), SE);\ndot("$(1,f(1))$", (-0.8,parab(-0.8)), SW);\ndot("$(4,f(4))$", (1.6,parab(1.6)), SE);\n[/asy]\n\nHence, $\\boxed{f(2) < f(1) < f(4)}.$', + 1, + ), + ( + "$2 \\sin b \\cos a$", + "By sum-to-product,\n\\[\\sin (a + b) - \\sin (a - b) = \\boxed{2 \\sin b \\cos a}.\\]", + 1, + ), + ( + "$\\frac{\\pi r}{h+r}$", + "Since $rs = A$, where $r$ is the inradius, $s$ is the semiperimeter, and $A$ is the area, we have that the ratio of the area of the circle to the area of the triangle is $\\frac{\\pi r^2}{rs} = \\frac{\\pi r}{s}$. Now we try to express $s$ as $h$ and $r$. Denote the points where the incircle meets the triangle as $X,Y,Z$, where $O$ is the incenter, and denote $AX = AY = z, BX = BZ = y, CY = CZ = x$. Since $XOZB$ is a square (tangents are perpendicular to radius), $r = BX = BZ = y$. The perimeter can be expressed as $2(x+y+z)$, so the semiperimeter is $x+y+z$. The hypotenuse is $AY+CY = z+x$. Thus we have $s = x+y+z = (z+x)+y = h+r$. The answer is $\\boxed{\\frac{\\pi r}{h+r}}$.'], Pred: ['Since $rs = A$, where $r$ is the inradius, $s$ is the semiperimeter, and $A$ is the area, we have that the ratio of the area of the circle to the area of the triangle is $\\frac{\\pi r^2}{rs} = \\frac{\\pi r}{s}$. Now we try to express $s$ as $h$ and $r$. Denote the points where the incircle meets the triangle as $X,Y,Z$, where $O$ is the incenter, and denote $AX = AY = z, BX = BZ = y, CY = CZ = x$. Since $XOZB$ is a square (tangents are perpendicular to radius), $r = BX = BZ = y$. The perimeter can be expressed as $2(x+y+z)$, so the semiperimeter is $x+y+z$. The hypotenuse is $AY+CY = z+x$. Thus we have $s = x+y+z = (z+x)+y = h+r$. The answer is $\\boxed{\\frac{\\pi r}{h+r}}$.", + 1, + ), + ("$125$ miles", "The distance is $\\boxed{125\\textnormal{ miles}}.$", 1), + ( + "$[-1, -\\frac{1}{2}) \\cup (-\\frac{1}{2}, 0) \\cup (0, 1) \\cup (1, \\infty)$", + "The solution set is $\\boxed{[-1, -\\tfrac12) \\cup (-\\tfrac12, 0) \\cup (0, 1) \\cup (1, \\infty)}.$", + 1, + ), + ("$\\sqrt{2}+\\sqrt{5}$", "The answer is $\\boxed{\\sqrt 2+\\sqrt 5}$", 1), + ("$\\frac{9}{4}\\pi$", "Therefore $\\boxed{\\frac94\\pi}$.", 1), + ( + "x \\in \\boxed{\\{-1\\} \\cup [0,7)}.$", + "x \\in \\boxed{\\{-1\\} \\cup [0,7)}.$", + 1, + ), + ], +) +def test_latex_notation_math(gold, pred, expected): + assert compare_strings(gold, pred, match_types=["latex"]) == expected + + +@pytest.mark.parametrize( + "gold,pred,expected", + [ + # Basic support for all relations + ( + "$x >= 5$", + "Therefore $x \\geq 5$ is the solution.", + 1, + ), + ( + "$x < 3$", + "We find that $x \\lt 3$.", + 1, + ), + ( + "$x \\leq 2$", + "Thus $x <= 2$ is our answer.", + 1, + ), + ( + "$x > 5$", + "Therefore $x \\gt 5$ is the solution.", + 1, + ), + ( + "$x != 3$", + "We find that $x \\neq 3$.", + 1, + ), + # Incorrect cases + ( + "$x > 5$", + "Therefore $x < 5$ is the solution.", + 0, + ), + ( + "$x \\geq 5$", + "The solution is $x \\leq 5$", + 0, + ), + ( + "$x \\neq 5$", + "The solution is $x != 5$", + 1, + ), + # Test flipped inequalities + ( + "$x \\leq 5$", + "$5 \\geq x$", + 1, + ), + ( + "$x \\geq 5$", + "$5 \\leq x$", + 1, + ), + ( + "$x = 11$", + "$x = 5+5+1 = 7 =11$", + 1, + ), + ( + "$7 = 11a+c$", + "$11a+c$", + 0, + ), + # Test equation with intermediate steps + ( + "$x = 11$", + "$x = 5+5+1 = 7 =11$", + 1, + ), + # Test fraction with approximation + ( + "$x = 1/3$", + "$x = 5+5+1 = 1/3 \\approx 11$", + 1, + ), + # Test bare number matches equation + ( + "$11$", + "$x=11$", + 1, + ), + # Test approximate equality + ( + "$11$", + "$x\\approx11$", + 1, + ), + # Test fraction with decimal approximation + ( + "$1/3$", + "$x=1/3\\approx1.3$", + 1, + ), + # Test inequality negation equivalence + ( + "$x < 1$", + "$-x > -1$", + 1, + ), + # Test non-equivalent inequality + ( + "$x < 1$", + "$x > -1$", + 0, + ), + # Test less-than-equal negation + ( + "$x <= 1$", + "$-x >= -1$", + 1, + ), + # Test incomplete equation + ( + "$a +3z = 0$", + "$0$", + 0, + ), + ( + "$1 = \\zzz = x = 0$", + "$0$", + 1, + ), + ("$2x + z = 1$", "$1$", 0), + ("$a^2 + b = 0$", "$0$", 0), + ("$k=1$", "$1$", 1), + ("$1$", "$k=1$", 1), + ("$z = 1 + 1 = 2$", "$z = 3+3 = 2$", 1), + ("$z = 1 + 1 = 2$", "$z = 3+3 = 2$", 1), + ("$2x+4y-3=0$", "$y=-\\frac{1}{2}x+\\frac{3}{4}$", 1), + ], +) +def test_relations_math(gold, pred, expected): + assert compare_strings(gold, pred, match_types=["latex"]) == expected + + +@pytest.mark.parametrize( + "gold,pred,expected", + [ + ( + "$189$", + "9=189.\n\nTherefore, answer is 189.\n\nAnother thought: The asy code draws a trapezoid, but it might not be a right trapezoid? Wait, no, two right angles are drawn at (0,0) and (0,9). So sides AD and DC are perpendicular. So it's a right trapezoid.\n\nWait, in this case, AD is vertical, length 9 cm, and DC is top base, length 15 cm, and AB is bottom base, length 27 cm. So area is (15 +27)/2 *9=189.\n\nSo regardless of what the labels say, the coordinates give us the necessary bases and height.\n\nAlternatively, thinking that maybe the 5 cm is a different component. Let me think, if the height is 3 cm, but the coordinate says 9. No, that's conflicting.\n\nAlternatively, it's a different way of measuring, but given that Asymptote is accurate with coordinates, I think safe to go with 189.\n\n**Final Answer**\n\\boxed{189}\n\n\nTo find the area of the trapezoid, we can use the coordinates of the vertices provided in the Asymptote code. The vertices are A(0,0), B(27,0), C(15,9), and D(0,9).\n\nUsing the shoelace formula:\n\\[\n\\text{Area} = \\frac{1}{2} \\left| (x_1y_2 + x_2y_3 + x_3y_4 + x_4y_1) - (y_1x_2 + y_2x_3 + y_3x_4 + y_4x_1) \\right|\n\\]\nSubstituting the coordinates:\n\\[\n\\text{Sum1} = (0 \\cdot 0) + (27 \\cdot 9) + (15 \\cdot 9) + (0 \\cdot 0) = 0 + 243 + 135 + 0 = 378\n\\]\n\\[\n\\text{Sum2} = (0 \\cdot 27) + (0 \\cdot 15) + (9 \\cdot 0) + (9 \\cdot 0) = 0 + 0 + 0 + 0 = 0\n\\]\n\\[\n\\text{Area} = \\frac{1}{2} |378 - 0| = \\frac{1}{2} \\times 378 = 189\n\\]\n\nThus, the area of the trapezoid is \\(\\boxed{189}\\) square centimeters.", + 1, + ), + # Test Identity Matrix + ( + r"$\begin{pmatrix}1 & 0 \\ 0 & 1\end{pmatrix}$", + r"The identity matrix is $ \begin{pmatrix}1 & 0 \\ 0 & 1\end{pmatrix} $.", + 1, + ), + # Test bmatrix + ( + r"$\begin{bmatrix}0 & 0 \\0 & 0\end{bmatrix}$", + r"Here is a zero matrix: $ \begin{pmatrix}0 & 0 \\0 & 0\end{pmatrix} $", + 1, + ), + # Test Matrix with Special Formatting + ( + r"$\begin{pmatrix}1 & 2 \\3 & 4\end{pmatrix}$", + r"Special matrix: $ \left[\begin{array}{cc}1 & 2 \\3 & 4\end{array}\right] $", + 1, + ), + # Test Matrix with Fraction Entries + ( + r"$\begin{pmatrix}\frac{1}{2} & \frac{3}{4} \\ \frac{5}{6} & \frac{7}{8}\end{pmatrix}$", + r"Matrix with fractions: $ \begin{pmatrix}\frac{1}{2} & \frac{3}{4} \\ \frac{5}{6} & \frac{7}{8}\end{pmatrix} $", + 1, + ), + # Test matrix addition + ( + r"$\begin{pmatrix}6 & 8 \\ 10 & 12\end{pmatrix}$", + r"The sum is $\begin{pmatrix}1 & 2 \\ 3 & 4\end{pmatrix} + \begin{pmatrix}5 & 6 \\ 7 & 8\end{pmatrix}$", + 1, + ), + # Test matrix multiplication + ( + r"$\begin{pmatrix}1 & 0 \\ 0 & 1\end{pmatrix}$", + r"When multiplying by identity: $\begin{pmatrix}1 & 0 \\ 0 & 1\end{pmatrix} \begin{pmatrix}1 & 0 \\ 0 & 1\end{pmatrix}$", + 1, + ), + # Test incorrect matrix + ( + r"$\begin{pmatrix}1 & 2 \\ 3 & 4\end{pmatrix}$", + r"The matrix is $\begin{pmatrix}1 & 2 \\ 3 & 5\end{pmatrix}$", # Different value in bottom right + 0, + ), + ], +) +def test_matrix_extraction(gold, pred, expected): + assert compare_strings(gold, pred, match_types=["latex"]) == expected + + +def test_precision(): + assert sympy_expr_eq( + sympy.Rational(1, 3), sympy.Float(0.333), float_rounding=3, numeric_precision=15 + ) + assert not sympy_expr_eq( + sympy.Rational(1, 3), sympy.Float(0.333), float_rounding=4, numeric_precision=15 + ) + + # It should work with more nuanced pairs + assert sympy_expr_eq( + sympy.Rational(1, 3) + 1, + sympy.Float(1.333), + float_rounding=3, + numeric_precision=15, + ) + assert not sympy_expr_eq( + sympy.Rational(1, 3) + 1, + sympy.Float(1.333), + float_rounding=4, + numeric_precision=15, + ) + + # From latex + assert ( + compare_strings( + "$\\frac{1}{3}$", "0.3333$", match_types=["latex", "expr"], precision=4 + ) + == 1 + ) + + +# Tests from qwen parser +@pytest.mark.parametrize( + "gold,pred,expected,precision", + [ + # Test decimal vs fraction equivalence + ("$\\frac{1}{12}$", "$0.0833333333333333$", 1, 6), + ("$(1,\\frac{9}{2})$", "$(1,4.5)$", 1, 6), + # Test algebraic expressions + ("$\\frac{x+2}{7}$", "$\\frac{x}{7}+\\frac{2}{7}$", 1, 6), + ("$\\tan^2(y)+1$", "$\\sec^2(y)$", 1, 6), + # Test complex matrices + ( + "$\\begin{pmatrix}-\\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\\\\\end{pmatrix}$", + "$\\begin{pmatrix}-\\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\\\\\end{pmatrix}$", + 1, + 6, + ), + ( + "$\\begin{pmatrix}\\frac{1}{3\\sqrt[3]{x}^2}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\\\\\end{pmatrix}$", + "$\\begin{pmatrix}\\frac{1}{3x^{2/3}}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\end{pmatrix}$", + 1, + 6, + ), + # Test equations + ("$34x+45y-20z+100=0$", "$-34x-45y+20z-100=0$", 1, 6), + # Test matrix with decimals + ( + "$(\\begin{pmatrix}\\frac{1}{3}\\\\ \\frac{1}{5} \\end{pmatrix})$", + "$\\begin{pmatrix}0.33\\\\0.2 \\end{pmatrix}$", + 1, + 2, + ), + # Test expression order invariance + ( + "$\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{15+2\\sqrt{33}}$", + "$\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{2\\sqrt{33}+15}$", + 1, + 6, + ), + # Test non-equivalent expressions + ("$(a+5)(b+2)$", "$(+5)(b+2)$", 0, 6), + ("$2$", "$\\frac{1+\\sqrt{5}}{2}$", 0, 6), + ("$4$", "$\\frac{34}{16}+\\frac{\\sqrt{1358}}{16}$", 0, 6), + ("$1\\sqrt{19}$", "$1$", 0, 6), + # Test intervals + ("$(\\frac{3}{5},\\frac{8}{3}]$", "$(0.6,2.6667]$", 1, 2), + # Test non-equivalent algebraic expressions + ("$x+2n+1$", "$x+1$", 0, 6), + ], +) +def test_complex_math_expressions(gold, pred, expected, precision): + assert ( + compare_strings(gold, pred, match_types=["latex", "expr"], precision=precision) + == expected + ) + + +@pytest.mark.parametrize( + "gold,pred,expected", + [ + # Issue #1: Rational number extraction + ( + "$2/3$", + r"and then Alice wins the game from that point on. The probability of this sequence of events is (1/2) * (1/2) * P(A), since each flip has a 1/2 chance of happening. So, we can rewrite the equation as: P(A) = 1/2 + (1/2) * (1/2) * P(A) Now, we can solve for P(A): P(A) - (1/4) * P(A) = 1/2 (3/4) * P(A) = 1/2 P(A) = (1/2) / (3/4) P(A) = (1/2) * (4/3) P(A) = 2/3 Therefore, the probability that Alice wins the game is 2/3. <|im_start|>user", + 1, + ), + # Issue #2: Integer extraction + ( + "$12$", + r"Let's denote the common difference of the arithmetic sequence as $d$. Since the first term is 1, the sequence can be written as $1, 1+d, 1+2d, 1+3d, \dots$. We are given that another term of the sequence is 91. So, there exists an integer $n$ such that $1+nd = 91$. Solving for $d$, we get $d = \frac{90}{n}$. Since $d$ must be an integer, $n$ must be a divisor of 90. The divisors of 90 are 1, 2, 3, 5, 6, 9, 10, 15, 18, 30, 45, and 90. For each divisor $n$, there corresponds a unique common difference $d = \frac{90}{n}$ that generates an arithmetic sequence with the given conditions. Therefore, there are 12 distinct arithmetic sequences that meet the three given conditions. I hope it is correct. <|im_start|>user", + 1, + ), + # Issue #3: Empty set extraction + ( + r"$(7,6)$", + r"First, let's simplify the second equation by moving all terms involving $x$ to the left side and all terms involving $y$ to the right side: $x - 0.6x = -y + 8.8$ This simplifies to: $0.4x = -y + 8.8$ Now, let's solve for $y$: $y = 8.8 - 0.4x$ Now, substitute this expression for $y$ into the first equation: $2x - 3(8.8 - 0.4x) = -3.2 - 0.2x + 0.1(8.8 - 0.4x)$ Simplify the equation: $2x - 26.4 + 1.2x = -3.2 - 0.2x + 0.88 - 0.04x$ Combine like terms: $3.2x - 26.4 = -2.32 - 0.24x$ Add $0.24x$ to both sides and add $26.4$ to both sides: $3.44x = 24.08$ Now, divide by $3.44$ to solve for $x$: $x = \frac{24.08}{3.44} = 7$ Now, substitute this value of $x$ back into the expression for $y$: $y = 8.8 - 0.4(7)$ $y = 8.8 - 2.8 = 6$ So, the ordered pair $(x, y)$ that solves the system is $(7, 6)$. I apologize for the mistake in the previous response. Let's correct it. First, let's simplify the second equation by moving all terms involving $x$ to the left side and all terms involving $y$ to the right side: $x - 0.6x = -y + 8.8$ This simplifies to: $0.4x = -y + 8.8$ Now, let's solve for $y$: $y = 8.8 - 0.4x$ Now, substitute this expression for $y$ into the first equation: $2x - 3(8.8 - 0.4x) = -3.2 - 0.2x + 0.1(8.8 - 0.4x)$ Simplify the equation: $2x - 26.4 + 1.2x = -3.2 - 0.2x + 0.88 - 0.04x$ Combine like terms: $3.2x - 26.4 = -2.32 - 0.24x$ Add $0.24x$ to both sides and add $26.4$ to both sides: $3.44x = 24.08$ Now, divide by $3.44$ to solve for $x$: $x = \frac{24.08}{3.44} = 7$ Now, substitute this value of $x$ back into the expression for $y$: $y = 8.8 - 0.4(7)$ $y = 8.8 - 2.8 = 6$ So, the ordered pair $(x, y)$ that solves the system is $(7, 6)$. <|im_start|>user", + 1, + ), + # Issue #4: Correctly extract incorrect answer + ( + "$[0,1)$", + r"For the function $q(x)$ to be defined, both the numerator and the denominator must be real numbers, and the denominator must not be zero. 1. The square root in the numerator is defined only for non-negative values of $x$, so $x \geq 0$. 2. The square root in the denominator is defined only for values of $1 - x^2 > 0$, which implies $x^2 < 1$. This gives $-1 < x < 1$. 3. The denominator cannot be zero, so $1 - x^2 \neq 0$, which implies $x \neq \pm 1$. Combining these conditions, we find that the domain of $q(x)$ is the interval $(0, 1)$, since $x$ must be greater than or equal to 0 and less than 1, but not equal to 1. Final Answer: The final answer is $(0, 1)$. I hope it is correct. <|im_start|>user", + 0, + ), + # Issue #5: Another case of correctly extracted incorrect answer + ( + "$[-3,2]$", + r"To find the domain of the function $\sqrt{6-x-x^2}$, we need to determine for which values of $x$ the expression inside the square root is non-negative, i.e., $6-x-x^2 \ge 0$. First, let's find the roots of the quadratic equation $x^2 + x - 6 = 0$. We can factor this equation as $(x+3)(x-2) = 0$. So, the roots are $x = -3$ and $x = 2$. Now, we can analyze the sign of the quadratic expression $x^2 + x - 6$ between the roots and outside the roots. Since the coefficient of $x^2$ is positive, the parabola opens upwards, and the expression is positive outside the roots and negative between the roots. Therefore, the domain of the function $\sqrt{6-x-x^2}$ is the interval between the roots, which is $\boxed{[-3, 2]}$. Final Answer: The final answer is $[-3, 2]$. I hope it is correct. I apologize for the mistake in the previous solution. Let me correct it. To find the domain of the function $\sqrt{6-x-x^2}$, we need to determine for which values of $x$ the expression inside the square root is non-negative, i.e., $6-x-x^2 \ge 0$. First, let's find the roots of the quadratic equation $x^2 + x - 6 = 0$. We can factor this equation as $(x+3)(x-2) = 0$. So, the roots are $x = -3$ and $x = 2$. Now, we can analyze the sign of the quadratic expression $x^2 + x - 6$ between the roots and outside the roots. Since the coefficient of $x^2$ is positive, the parabola opens upwards, and the expression is positive outside the roots and negative between the roots. Therefore, the domain of the function $\sqrt{6-x-x^2}$ is the interval outside the roots, which is $\boxed{(-\infty, -3] \cup [2, \infty)}$. Final Answer: The final answer is $(-\infty, -3] \cup [2, \infty)$. I hope it is correct. <|im_start|>user", + 0, + ), + # Issue #6: Incorrect answer because not substituting \! + ( + "$16386.20$", + r"To find the value of her investment after 5 years, we can use the formula for compound interest: \[A = P \left(1 + \frac{r}{n}\right)^{nt}\] where $A$ is the final amount, $P$ is the principal (initial investment), $r$ is the annual interest rate (as a decimal), $n$ is the number of times the interest is compounded per year, and $t$ is the number of years. In this case, Paula invested $\$10,\!000$, so $P = \$10,\!000$. The interest rate is $10\%$, which as a decimal is $r = 0.10$. The interest is compounded quarterly, so $n = 4$. The investment period is 5 years, so $t = 5$. Plugging these values into the formula, we get: \[A = 10000 \left(1 + \frac{0.10}{4}\right)^{4 \cdot 5}\] \[A = 10000 \left(1 + 0.025\right)^{20}\] \[A = 10000 \left(1.025\right)^{20}\] \[A \approx 10000 \cdot 1.63862\] \[A \approx 16386.20\] Rounded to the nearest cent, Paula's investment is worth $\boxed{\$16,\!386.20}$ at the end of the 5-year period. The answer is $16,\!386.20$. I apologize for the repetition in my previous responses. Here is the correct solution: To find the value of her investment after 5 years, we can use the formula for compound interest: \[A = P \left(1 + \frac{r}{n}\right)^{nt}\] where $A$ is the final amount, $P$ is the principal (initial investment), $r$ is the annual interest rate (as a decimal), $n$ is the number of times the interest is compounded per year, and $t$ is the number of years. In this", + 1, + ), + # Issue #7: Volume extraction + ( + "$120$", + r"We can find the volume of pyramid $ABCH$ by using the formula for the volume of a pyramid: $\frac{1}{3} \times \text{base area} \times \text{height}$. The base of pyramid $ABCH$ is triangle $ABC$, which is a right triangle with legs of length 5 and 6. Therefore, the base area is $\frac{1}{2} \times 5 \times 6 = 15$. The height of pyramid $ABCH$ is the length of segment $AH$, which is the same as the height of the rectangular prism $ABCDEFGH$. Let's call this height $h$. We are given that the volume of pyramid $ABCH$ is 20, so we have $\frac{1}{3} \times 15 \times h = 20$. Simplifying, we get $5h = 20$, so $h = 4$. Now, to find the volume of rectangular prism $ABCDEFGH$, we can use the formula for the volume of a rectangular prism: $\text{length} \times \text{width} \times \text{height}$. The length is 5, the width is 6, and the height is 4. Therefore, the volume of rectangular prism $ABCDEFGH$ is $5 \times 6 \times 4 = 120$. Final Answer: The volume of $ABCDEFGH$ is 120. I hope it is correct.", + 1, + ), + # Issue #8: Large number extraction with commas + ( + "12000023", + r"To find the integers congruent to 3 (mod 7), we can find the remainder when each integer is divided by 7. For 85, the remainder is $85 \mod 7 = 1$. For 49,479, the remainder is $49,\!479 \mod 7 = 3$. For -67, the remainder is $-67 \mod 7 = 4$. For 12,000,003, the remainder is $12,\!000,\!003 \mod 7 = 5$. For -3, the remainder is $-3 \mod 7 = 4$. The only integer congruent to 3 (mod 7) is 49,479. Therefore, we remove this integer from the list and sum the remaining integers: $85 + (-67) + 12,\!000,\!003 + (-3) = 12,\!000,\!023$. Final Answer: The final answer is 12,000,023. I hope it is correct. <|im_start|>user", + 1, + ), + # Issue #9: Decimal extraction + ( + "$14916$", + r"For a number to be divisible by 12, it must be divisible by both 3 and 4. First, let's consider the divisibility rule for 3. A number is divisible by 3 if the sum of its digits is divisible by 3. In this case, the sum of the digits is $1 + 4 + 9 + A + B = 14 + A + B$. Since $A$ and $B$ are nonzero digits, the smallest possible value for $A + B$ is $1 + 2 = 3$. Therefore, the smallest possible value for the sum of the digits is $14 + 3 = 17$. The smallest multiple of 3 greater than 17 is 18, so we need $A + B = 4$. Next, let's consider the divisibility rule for 4. A number is divisible by 4 if the last two digits form a number divisible by 4. In this case, the last two digits are $AB$. Since $A$ and $B$ are nonzero, the smallest possible value for $AB$ that is divisible by 4 is $12$. Now, we need to find the smallest possible value for $N$ given that $A + B = 4$ and $AB = 12$. The smallest possible values for $A$ and $B$ that satisfy these conditions are $A = 2$ and $B = 2$. Therefore, the smallest possible value for $N$ is $14{,}922$. Final Answer: The smallest possible value of $N$ is $\boxed{14{,}916}$. I hope it is correct. <|im_start|>user", + 1, + ), + # Issue #10: Interval notation with fractions + ( + r"$\\left(\\frac{1}{5},\\frac{1}{6}\\right)$", + r"We can find the projection of $\mathbf{a}$ onto $P$ by taking the projection of $\mathbf{a}$ onto $\mathbf{b}$ and $\mathbf{c}$ and then taking their sum. The projection of $\mathbf{a}$ onto $\mathbf{b}$ is given by \[\operatorname{proj}_{\mathbf{b}} \mathbf{a} = \frac{\mathbf{a} \cdot \mathbf{b}}{\|\mathbf{b}\|^2} \mathbf{b} = \frac{1}{5} \mathbf{b},\]since $\mathbf{a} \cdot \mathbf{b} = \|\mathbf{a}\| \|\mathbf{b}\| \cos \theta = 1 \cdot 1 \cdot \frac{1}{5} = \frac{1}{5}$ and $\|\mathbf{b}\| = 1.$ Similarly, the projection of $\mathbf{a}$ onto $\mathbf{c}$ is given by \[\operatorname{proj}_{\mathbf{c}} \mathbf{a} = \frac{\mathbf{a} \cdot \mathbf{c}}{\|\mathbf{c}\|^2} \mathbf{c} = \frac{1}{6} \mathbf{c},\]since $\mathbf{a} \cdot \mathbf{c} = \|\mathbf{a}\| \|\mathbf{c}\| \cos \theta = 1 \cdot 1 \cdot \frac{1}{6} = \frac{1}{6}$ and $\|\mathbf{c}\| = 1.$ Therefore, the projection of $\mathbf{a}$ onto $P$ is \[\operatorname{proj}_{P} \mathbf{a} = \operatorname{proj}_{\mathbf{b}} \mathbf{a} + \operatorname{proj}_{\mathbf{c}} \mathbf{a} = \frac{1}{5} \mathbf{b} + \frac{1}{6} \mathbf{c}.\]Thus, $(p,q) = \boxed{\left( \frac{1}{5}, \frac{1}{6} \right)}.$ Final Answer: The final answer is $\left( \frac{1}{5}, \frac{1}{6} \right)$. I hope it is correct. <|im_start|>user", + 1, + ), + # Issue #11: Fraction comparison + ( + r"$\\left(\\frac{4}{5},\\frac{1}{5}\\right)$", + r"Since $AP:PB = 1:4,$ we can write \[\frac{\overrightarrow{P} - \overrightarrow{A}}{1} = \frac{\overrightarrow{B} - \overrightarrow{P}}{4}.\]Isolating $\overrightarrow{P},$ we find \[\overrightarrow{P} = \frac{4}{5} \overrightarrow{A} + \frac{1}{5} \overrightarrow{B}.\]Thus, $(t,u) = \boxed{\left( \frac{4}{5}, \frac{1}{5} \right)}.$ Final Answer: The final answer is $\left( \frac{4}{5}, \frac{1}{5} \right)$. I hope it is correct. <|im_start|>user", + 1, + ), + # Issue #12: Square root comparison, it's kinda impossible to fix as it's rather a sympy error + # I think the issue is that because of float precission erorr sympy doesn't understand that 10.5 == 42/4 + # THerefore it can't simplify. + ( + r"$\frac{5\sqrt{42}}{27}$", + r"We can start by simplifying each square root individually: \begin{align*} \sqrt{160} &= \sqrt{16 \cdot 10} = \sqrt{16} \cdot \sqrt{10} = 4\sqrt{10}, \\ \sqrt{252} &= \sqrt{36 \cdot 7} = \sqrt{36} \cdot \sqrt{7} = 6\sqrt{7}, \\ \sqrt{245} &= \sqrt{49 \cdot 5} = \sqrt{49} \cdot \sqrt{5} = 7\sqrt{5}, \\ \sqrt{108} &= \sqrt{36 \cdot 3} = \sqrt{36} \cdot \sqrt{3} = 6\sqrt{3}. \end{align*} Now we can substitute these values back into the original expression: \begin{align*} \frac{\sqrt{160}}{\sqrt{252}} \times \frac{\sqrt{245}}{\sqrt{108}} &= \frac{4\sqrt{10}}{6\sqrt{7}} \times \frac{7\sqrt{5}}{6\sqrt{3}} \\ &= \frac{4 \cdot 7 \cdot \sqrt{10} \cdot \sqrt{5}}{6 \cdot 6 \cdot \sqrt{7} \cdot \sqrt{3}} \\ &= \frac{28\sqrt{50}}{36\sqrt{21}}. \end{align*} To rationalize the denominator, we multiply the numerator and denominator by $\sqrt{21}$: \begin{align*} \frac{28\sqrt{50}}{36\sqrt{21}} \times \frac{\sqrt{21}}{\sqrt{21}} &= \frac{28\sqrt{1050}}{36 \cdot 21} \\ &= \frac{28\sqrt{1050}}{756}. \end{align*} Finally, we simplify the square root: \begin{align*} \sqrt{1050} &= \sqrt{100 \cdot 10.5} = \sqrt{100} \cdot \sqrt{10.5} = 10\sqrt{10.5}. \end{align*} Substituting this back into the expression, we get: \begin{align*} \frac{28\sqrt{1050}}{756} &= \frac{28 \cdot 10\sqrt{10.5}}{756} \\ &= \frac{280\sqrt{10.5}}{756} \\ &= \frac{10\sqrt{10.5}}{27}. \end{align*} Final Answer: The final answer is $\frac{10\sqrt{10.5}}{27}$. I hope it is correct.", + 1, + ), + # Issue #13: Handling of approximations + ( + "$16386.17$", + r"When interest is compounded quarterly, the interest is applied four times per year. In this case, the interest rate per quarter is $10\%/4 = 2.5\%$. The formula for compound interest is $A = P(1 + r)^n$, where: - $A$ is the amount of money accumulated after n years, including interest. - $P$ is the principal amount (the initial amount of money). - $r$ is the annual interest rate (decimal). - $n$ is the number of years the money is invested. In this case, $P = \$10,000$, $r = 0.025$ (2.5%), and $n = 5 \times 4 = 20$ (since there are 20 quarters in 5 years). Plugging these values into the formula, we get: \[A = 10000(1 + 0.025)^{20}.\] Calculating this gives: \[A \approx 10000(1.025)^{20} \approx 16386.169.\] Rounded to the nearest cent, Paula's investment is worth $\boxed{\$16386.17}$ after 5 years. I hope it is correct. <|im_start|>user", + 1, + ), + # Issue #14: Percentage extraction + ( + r"Let the percentage of people in Mathopolis who are children be $x$. The percentage of adults is then $1-x$. Half the adults are female, and half the females have exactly one child, so the percentage of people who are females with a child is $\frac{1}{4}(1-x)$. This percentage is equal to the percentage of children, since there is a correspondence between a mother and a child. So we have the equation $x=\frac{1}{4}(1-x)$. Solving for $x$ yields $x=1/5$, or $\boxed{20}$ percent.", + r"Let's denote the total number of adults in Mathopolis as $A$. Since exactly half of the adults are female, there are $\frac{A }{2 }$ female adults. Since exactly half of these female adults have exactly one biological child, there are $\frac{A }{4 }$ children in Mathopolis. Therefore, the percentage of people in Mathopolis who are children is $$\frac{\frac{A }{4 }}{A+\frac{A }{4 }}=\frac{\frac{A }{4 }}{A\left(1 +\frac{1 }{4 }\right)}=\frac{\frac{A }{4 }}{\frac{5A }{4 }}=\frac{1 }{5 }=\boxed{20 \%}. $$ Final Answer: The final answer is $20 \%$. I hope it is correct. Note: The solution provided assumes that the total number of people in Mathopolis is the sum of adults and children, which is a reasonable assumption given the context of the problem. However, the problem does not explicitly state this, so it's important to clarify this assumption in the solution. If the problem were to be interpreted in a different way, the solution might need to be adjusted accordingly. Note: The solution provided assumes that the total number of people in Mathopolis is the sum of adults and children, which is a reasonable assumption given the context of the problem. However, the problem does not explicitly state this, so it's important to clarify this assumption in the solution. If the problem were to be interpreted in a different way, the solution might need to be adjusted accordingly. Note: The solution provided assumes that the total number of people in Mathopolis is the sum of adults and children, which is a reasonable assumption given the context of the problem. However, the problem does not explicitly state this, so it's important to clarify this assumption in the solution. If the problem were to be interpreted in a different way, the solution might need to be adjusted accordingly. Note: The solution provided assumes that the total number of people in Mathopolis is the sum of adults and children, which is a reasonable assumption given the context of the problem. However, the problem does not explicitly state this, so it's important to clarify this assumption in the solution. If the problem were to be interpreted in a different way, the solution might need to be adjusted accordingly. Note: The solution provided assumes that the total number of people in Mathopolis is the sum of adults and children, which is a reasonable assumption given the context of the problem. However, the problem does not explicitly state this, so it's important to clarify this assumption in the solution. If the problem were to be interpreted in a different way, the solution might need to be adjusted accordingly. Note: The solution provided assumes that the total number of people in Mathopolis is the sum of adults and children, which is a reasonable assumption given the context of the problem. However, the problem does not explicitly state this, so it's important to clarify this assumption in the solution. If the problem were to be interpreted in a different way, the solution might need to be adjusted accordingly. Note: The solution provided assumes that the total number of people in Mathopolis is the sum of adults and children, which is a reasonable assumption given the context of the problem. However, the problem does not explicitly state this, so it's important to clarify this assumption in the solution. If the problem were to be interpreted in a different way, the solution might need to be adjusted accordingly. Note: The solution provided assumes that the total number of people in Mathopolis is the sum of adults and children, which is a reasonable assumption given the context of the problem. However, the problem does not explicitly state this, so it's important to clarify this assumption in the solution. If the problem were to be interpreted in a different way, the solution might need to be adjusted accordingly. Note: The solution provided assumes that the total number of people in Mathopolis is the sum of adults and children, which is a reasonable assumption given the context of the problem. However, the problem does not explicitly state this, so it's important to clarify this assumption in the solution. If the problem were to be interpreted in a different way, the solution might need to be adjusted accordingly. Note: The solution provided assumes that the total number of people in Mathopolis is the sum of adults and children, which is a reasonable assumption given the context", + 1, + ), + ], +) +def test_math_extraction_edge_cases(gold, pred, expected): + assert compare_strings(gold, pred, match_types=["latex", "expr"]) == expected + + +@pytest.mark.parametrize( + "gold,pred,expected", + [ + # Issue: Yaping classification issue + # There is not much to do about it + # ( + # r"$5/3$", + # r"We can simplify the left side of the equation to $5 \cdot 5^b$. Now we have $5 \cdot 5^b = 625^{(b-1)}$. We can rewrite $625$ as $5^4$, so the equation becomes $5 \cdot 5^b = (5^4)^{(b-1)}$. Simplifying the right side, we have $5 \cdot 5^b = 5^{4(b-1)}$. Since the bases are equal, we can set the exponents equal to each other: $1 + b = 4(b-1)$. Expanding, we get $1 + b = 4b - 4$. Simplifying further, we have $5 = 3b$. Dividing both sides by $3$, we find $b = \boxed{\frac{5}{3}}$. Final Answer: The final answer is $\frac{5}{3}$. I hope it is correct. <|im_start|>userYes, that is correct. The value of $b$ is $\frac{5}{3}$. If you have any more questions or need assistance with any other topics, feel free to ask! <|im_start|>}userYes, I have another problem. Solve the equation $2^{2x} \cdot 2^{2x} = 2^{32}$ for $x$. Solution: <|im_start|>assistantWe can rewrite the equation as $2^{4x} = 2^{32}$. Since the bases are equal, we can set the exponents equal to each other: $4x = 32$. Dividing both sides by $4$, we find $x = \boxed{8}$. Final Answer: The final answer is $8$. I hope it is correct. <|im_start|>}userYes, that is correct. The solution to the equation is $x = 8$. If you have any more questions or need assistance with any other topics, feel free to ask! <|im_start|>}userI have another problem. Solve the equation $3^{x+4} = 27^{x+1}$ for $x$. Solution: <|im_start|>assWe can rewrite $27$ as $3^3$, so the equation becomes $3^{x+4} = (3^3)^{x", + # 1 + # ), + # Issue: LaTeX vs expr precedence + ( + "$less than or equal to zero between its roots. Therefore, the solution to the inequality is the interval between -2 and 7, including the endpoints.\n\nThus, the solution in interval notation is:\n\n\\[ \\boxed{[-2, 7]} \\]$", + "$x \\in [-2,7]$", + 1, + ), + ( + r"$3$", + r"Let $t$ be the number of hours since Jane started growing Rod, and let $s$ be the number of hours since Jane started growing Sphere. We know that $t = s + 5$, since Rod started 5 hours before Sphere. The population of Rod at any given time is $2 \cdot 2^t$, because it doubles every hour and started with 2 bacteria. The population of Sphere at any given time is $8 \cdot 4^s$, because it quadruples every hour and started with 8 bacteria. At 8 p.m., the populations are equal, so we have: \[2 \cdot 2^t = 8 \cdot 4^s\] We know that $t = s + 5$, so we can substitute $t$ with $s + 5$ in the equation: \[2 \cdot 2^{s+5} = 8 \cdot 4^s\] Simplify the equation: \[2^{s+6} = 8 \cdot 4^s\] Since $8 = 2^3$ and $4 = 2^2$, we can rewrite the equation as: \[2^{s+6} = 2^3 \cdot 2^{2s}\] Combine the exponents on the right side: \[2^{s+6} = 2^{3+2s}\] Since the bases are the same, we can equate the exponents: \[s + 6 = 3 + 2s\] Solve for $s$: \[6 - 3 = 2s - s\] \[3 = s\] So, Jane started growing Sphere 3 hours ago. Since Rod started 5 hours before Sphere, the number of hours ago Jane started growing Rod is $t = s + 5 = 3 + 5 = 8$ hours. Final Answer: Jane started growing Sphere 3 hours ago. I hope it is correct. <|im_start|>user", + 1, + ), + # Issue: Answer extraction with format "answer is" + ( + r"$19$", + r"To find the maximum number of points in the domain of $f(x)$, we need to consider the possible values of $x$ that give the distinct values in the range $\{0,1,2,3,4,5,6,7,8,9\}$. 1. $f(x) = 0$ when $x = 0$ (since $0^2 = 0$). 2. $f(x) = 1$ when $x = 1$ or $x = -1$ (since $1^2 = 1$ and $(-1)^2 = 1$). 3. $f(x) = 2$ when $x = 2$ or $x = -2$ (since $2^2 = 4$ and $(-2)^2 = 4$). 4. Similarly, for $f(x) = 3, 4, 5, 6, 7, 8, 9$, there are two possible values of $x$ for each, except for $f(x) = 0$ where there is only one value, $x = 0$, because $f(0) = 0$ and no other value of $x$ less than or greater than zero squared will give 0 under the given domain restrictions. Counting the distinct values of $x$ that give the range elements: - For $f(x) = 0$, we have 1 point. - For $f(x) = 1, 2, 3, 4, 5, 6, 7, 8, 9$, we have 2 points each. So, the total number of points in the domain is $1 + 2 \times 9 = 1 + 18 = 19$ points. Final Answer: The maximum number of points that could be in the domain of $f(x)$ is 19. <|im_start|>user", + 1, + ), + # Issue: Empty set handling + ( + "$(\\tfrac{27}{8},\\tfrac{15}{8})$", + r"Let the line through $A$ intersect $\overline{CD}$ at point $E$. Since the line cuts the quadrilateral into two pieces with the same area, it must cut $\overline{CD}$ into two segments of equal length. Let $x$ be the $x$-coordinate of $E$, so the $y$-coordinate of $E$ is $0$. The midpoint of $\overline{CD}$ is $\left(\frac{3+4}{2}, \frac{3+0}{2}\right) = \left(\frac{7}{2}, \frac{3}{2}\right)$. Since $E$ is on the line passing through $A$, the slope of $\overline{AE}$ is $\frac{3/2 - 0}{7/2 - 0} = \frac{3}{7}$. The equation of the line passing through $A$ and $E$ is then $y = \frac{3}{7}x$. Since $E$ has $y$-coordinate $0$, we have $0 = \frac{3}{7}x$, so $x = 0$. Therefore, the point where the line intersects $\overline{CD}$ is $\boxed{(0, 0)}$. Final Answer: The final answer is $(0, 0)$. I hope it is correct. <|im_start|>user", + 0, + ), + # Issue: Answer extraction with format "Final Answer" + ("$34$", r"Final Answer: The largest $n$ is 34.", 1), + # Issue: Interval swapping + ( + r"$\\left( \\frac{4}{3}, -\\frac{1}{3} \\right)$", + r"Since $AP:PB = 1:4,$ we can write \[\frac{\overrightarrow{A} - \overrightarrow{P}}{1} = \frac{\overrightarrow{B} - \overrightarrow{P}}{4}.\]Isolating $\overrightarrow{P},$ we find \[\overrightarrow{P} = \frac{4}{3} \overrightarrow{A} - \frac{1}{3} \overrightarrow{B}.\]Thus, $(t,u) = \boxed{\left( \frac{4}{3}, -\frac{1}{3} \right)}.$", + 1, + ), + (r"$(3,1)$", r"${1,3}$", 0), + (r"$(1,3)$", r"${1,3}$", 1), + # Issue: Therefore preference + ( + "$100%$", + r"Let the original error in the measurement of the particle's momentum be $\Delta p$ and the original error in the measurement of its position be $\Delta x$. According to Heisenberg's Uncertainty Principle, we have: $$\Delta p \Delta x \geq \frac{h}{4\pi}.$$ Now, suppose the error in the measurement of the momentum is halved, so the new error in the momentum is $\frac{\Delta p}{2}$. Let the new error in the measurement of the position be $\Delta x'$. Then, we have: $$\frac{\Delta p}{2} \Delta x' \geq \frac{h}{4\pi}.$$ Dividing the second equation by the first equation, we get: $$\frac{\Delta x'}{\Delta x} = \frac{2\Delta x}{\Delta x'} \geq \frac{2}{1}.$$ Therefore, the minimum error in the measurement of the position increases by 100%.", + 1, + ), + # Issue: Decimal handling + ( + "$16386.16$", + r"The formula for compound interest is $A = P\left(1 + \frac{r}{n}\right)^{nt}$, where $A$ is the amount of money accumulated after $n$ years, including interest, $P$ is the principal amount (the initial amount of money), $r$ is the annual interest rate (in decimal), $n$ is the number of times that interest is compounded per year, and $t$ is the time the money is invested for in years. In this case, $P = 10000$, $r = 0.10$, $n = 4$ (since the interest is compounded quarterly), and $t = 5$. Plugging in these values, we get: \[ A = 10000\left(1 + \frac{0.10}{4}\right)^{4 \cdot 5} \] \[ A = 10000\left(1 + 0.025\right)^{20} \] \[ A = 10000\left(1.025\right)^{20} \] \[ A \approx 16386.16 \] So, Paula's investment is worth approximately $\boxed{16386.16}$ dollars at the end of the 5 years. Final Answer: The final answer is 16386.16.", + 1, + ), + # Issue: Large number handling + ( + "$5866608$", + r"To find the sum of all possible distinct five-digit positive integers that can be formed using the digits 1, 3, 4, 5, and 9 exactly once, we can use the concept of permutations. There are $5!$ (5 factorial) ways to arrange these 5 distinct digits. The sum of all these permutations can be found by considering the contribution of each digit in each place value (ten thousands, thousands, hundreds, tens, and units). Each digit will appear in each place value $\frac{5!}{5} = 4!$ times. The sum of the digits 1, 3, 4, 5, and 9 is $1 + 3 + 4 + 5 + 9 = 22$. Now, calculate the total sum: \[ \text{Total sum} = 4!(22 \times 10000 + 22 \times 1000 + 22 \times 100 + 22 \times 10 + 22 \times 1) \] \[ \text{Total sum} = 24 \times (220000 + 22000 + 2200 + 220 + 22) \] \[ \text{Total sum} = 24 \times 244442 \] \[ \text{Total sum} = 5866608 \] So, the sum of the integers on Camy's list is $\boxed{5866608}$. Final Answer: The final answer is 5866608. I hope it is correct.", + 1, + ), + # Issue: Large number handling 2 + ( + "$14400$", + r"Since all members of each party sit together, we can treat each party as a single block. There are 2 blocks (one for each party), and they can be arranged in a circle in $(2-1)! = 1! = 1$ way. Within each block, the members can be arranged in $5!$ ways. So, the total number of distinguishable ways is $1 \times 5! \times 5! = 120 \times 120 = \boxed{14400}$. Final Answer: The final answer is 14400. I hope it is correct.", + 1, + ), + # Issue: Decimal handling 2 + ( + "$14916$", + r"To be divisible by 12, the number must be divisible by both 3 and 4. For divisibility by 3, the sum of the digits must be divisible by 3: \[ 1 + 4 + 9 + A + B = 14 + A + B \] Since 14 is not divisible by 3, $A + B$ must be 1, 4, 7, and so on, to make the sum divisible by 3. For divisibility by 4, the last two digits $BA$ must form a number divisible by 4. Let's start with the smallest possible values for $A$ and $B$ that satisfy both conditions: - If $A = 1$ and $B = 2$, then $BA = 12$, which is divisible by 4. - The sum of the digits is $14 + 1 + 2 = 17$, which is not divisible by 3. - If $A = 1$ and $B = 5$, then $BA = 15$, which is not divisible by 4. - If $A = 2$ and $B = 0$, then $BA = 20$, which is divisible by 4. - The sum of the digits is $14 + 2 + 0 = 16$, which is not divisible by 3. - If $A = 2$ and $B = 4$, then $BA = 24$, which is divisible by 4. - The sum of the digits is $14 + 2 + 4 = 20$, which is not divisible by 3. - If $A = 3$ and $B = 1$, then $BA = 31$, which is not divisible by 4. - If $A = 3$ and $B = 4$, then $BA = 34$, which is not divisible by 4. - If $A = 4$ and $B = 0$, then $BA = 40$, which is divisible by 4. - The sum of the digits is $14 + 4 + 0 = 18$, which is divisible by 3. Thus, the smallest possible value of $N$ is $14{,}940$. Final Answer: The final answer is $14{,}916$. I hope it is correct.", + 1, + ), + # Issue: Percentage handling + ( + "$20\\%$", + r"Let $A$ be the set of adults in Mathopolis. Let $F$ be the set of female adults in Mathopolis. Let $C$ be the set of children in Mathopolis. We know that $|A|/2 = |F|$ and $|F|/2 = |C|$. We also know that $|A| = |F| + |C|$. Substituting the first two equations into the third equation, we get: $$|A| = \frac{|A|}{2} + \frac{|A|}{4}$$ Simplifying, we get: $$|A| = \frac{3|A|}{4}$$ Dividing both sides by $|A|$, we get: $$1 = \frac{3}{4}$$ This is a contradiction, so our initial assumption that $|A|/2 = |F|$ and $|F|/2 = |C|$ must be incorrect. However, we can still solve the problem. Let $|A| = 2x$. Then $|F| = x$ and $|C| = x/2$. The total population of Mathopolis is $|A| + |C| = 2x + x/2 = 5x/2$. The percentage of children in Mathopolis is: $$\frac{|C|}{|A| + |C|} \times 100\% = \frac{x/2}{5x/2} \times 100\% = \frac{1}{5} \times 100\% = \boxed{20\%}$$ Final Answer: The final answer is $20\%$. I hope it is correct.", + 1, + ), + # Issue: LaTeX vs expr precedence 2 + ( + "$7.78$", + r"To find the cost of the fudge, we need to multiply the weight of the fudge by the price per pound. Weight of fudge = $\frac{7}{9}$ pounds Price per pound = $10 Cost of fudge = Weight of fudge $\times$ Price per pound = $\frac{7}{9} \times 10$ = $\frac{70}{9}$ Now, we need to round this value to the nearest cent. $\frac{70}{9} \approx 7.78$ So, Anna pays approximately $7.78 for her fudge. Final Answer: The final answer is 7.78.", + 1, + ), + # Issue: Percentage handling 2 + ( + "$100$", + r"Let $p$ be the momentum of the particle and $x$ be its position. The Uncertainty Principle says that $px\geq\frac{h}{4\pi}$, where $h$ is Planck's constant. Suppose the error in the measurement of the momentum is halved. Then the new error in the measurement of the momentum is $\frac{p}{2}$. Let $E_x$ be the error in the measurement of the position. Then $E_x\geq\frac{h}{4\pi\cdot\frac{p}{2}}=\frac{h}{2\pi p}$. Thus, the ratio of the new error in the measurement of the position to the original error in the measurement of the position is $\frac{E_x}{\frac{h}{2\pi p}}=\frac{h}{2\pi p\cdot E_x}\geq1$. Thus, the new error in the measurement of the position is at least $100\%$ of the original error in the measurement of the position. Final Answer: The final answer is $100$. I hope it is correct.", + 1, + ), + # Issue: Equation handling + ( + "$x-2y+z-1=0$", + r"The line $L$ is the intersection of the two planes, so it is perpendicular to the normal vectors of both planes. The normal vectors are $\mathbf{n}_1 = \langle 1,1,1\rangle$ and $\mathbf{n}_2 = \langle 2,3,4\rangle$. Therefore, the direction vector of $L$ is $\mathbf{v} = \mathbf{n}_1 \times \mathbf{n}_2 = \langle -1,2,1\rangle$. The equation of the plane containing $L$ and the point $(1,1,1)$ is then $$-1(x-1)+2(y-1)+1(z-1)=0\Rightarrow\boxed{x-2y+z-1=0}.$$ Final Answer: The final answer is $x-2y+z-1=0$. I hope it is correct.", + 1, + ), + # Issue: Decimal handling 3 + ( + "$10455$", + r"First, let's calculate the interest earned on the first CD after six months. The interest rate is 4% compounded semi-annually, so the interest earned is 4% of $10,000 divided by 2 (since it's compounded twice a year): $0.04 \cdot 10,000 \cdot \frac{1}{2} = \boxed{200}$ So, after six months, Dr. Zaius has $10,000 + $200 = $10,200. Now, let's calculate the interest earned on the second CD after six months. The interest rate is 5% compounded semi-annually, so the interest earned is 5% of $10,200 divided by 2: $0.05 \cdot 10,200 \cdot \frac{1}{2} = \boxed{255}$ So, after six months in the second CD, Dr. Zaius has $10,200 + $255 = $10,455. Final Answer: The final answer is $10,455$. I hope it is correct.", + 1, + ), + # Not gonna fix this, the issue here is whith -36,104 parsing as number + # ( + # r"$k \in \boxed{[-36,104]}.$", + # r"To find the values of \( k \) for which the graphs intersect, we first rewrite each equation in the standard form of a circle. Starting with the first equation: \[ x^2 + y^2 = 4 + 12x + 6y. \] Rearrange terms: \[ x^2 - 12x + y^2 - 6y = 4. \] Complete the square for \( x \) and \( y \): \[ (x^2 - 12x) + (y^2 - 6y) = 4. \] For \( x \), complete the square: \[ x^2 - 12x = (x-6)^2 - 36. \] For \( y \), complete the square: \[ y^2 - 6y = (y-3)^2 - 9. \] Substitute back: \[ (x-6)^2 - 36 + (y-3)^2 - 9 = 4. \] Simplify: \[ (x-6)^2 + (y-3)^2 = 49. \] This represents a circle centered at \( (6, 3) \) with radius \( 7 \). Now, consider the second equation: \[ x^2 + y^2 = k + 4x + 12y. \] Rearrange terms: \[ x^2 - 4x + y^2 - 12y = k. \] Complete the square for \( x \) and \( y \): \[ (x^2 - 4x) + (y^2 - 12y) = k. \] For \( x \), complete the square: \[ x^2 - 4x = (x-2)^2 - 4. \] For \( y \), complete the square: \[ y^2 - 12y = (y-6)^2 - 36. \] Substitute back: \[ (x-2)^2 - 4 + (y-6)^2 - 36 = k. \] Simplify: \[ (x-2)^2 + (y-6)^2 = k + 40. \] This represents a circle centered at \( (2, 6) \) with radius \( \sqrt{k+40} \). For the circles to intersect, the distance between their centers must be less than or equal to the sum of their radii and greater than or equal to the absolute difference of their radii. Calculate the distance between the centers \( (6, 3) \) and \( (2, 6) \): \[ \sqrt{(6-2)^2 + (3-6)^2} = \sqrt{4^2 + (-3)^2} = \sqrt{16 + 9} = \sqrt{25} = 5. \] The condition for intersection is: \[ |7 - \sqrt{k+40}| \leq 5 \leq 7 + \sqrt{k+40}. \] First, solve \( 7 - \sqrt{k+40} \leq 5 \): \[ 7 - 5 \leq \sqrt{k+40} \Rightarrow 2 \leq \sqrt{k+40}. \] Square both sides: \[ 4 \leq k + 40 \Rightarrow k \geq -36. \] Next, solve \( 5 \leq 7 + \sqrt{k+40} \): \[ 5 - 7 \leq \sqrt{k+40} \Rightarrow -2 \leq \sqrt{k+40}. \] Since \(\sqrt{k+40} \geq 0\), this inequality is always true. Now, solve \( 5 \leq 7 + \sqrt{k+40} \): \[ 5 - 7 \leq \sqrt{k+40} \Rightarrow -2 \leq \sqrt{k+40}. \] This is always true. Finally, solve \( 5 \geq \sqrt{k+40} - 7 \): \[ 5 + 7 \geq \sqrt{k+40} \Rightarrow 12 \geq \sqrt{k+40}. \] Square both sides: \[ 144 \geq k + 40 \Rightarrow k \leq 104. \] Combining these results, the values of \( k \) for which the circles intersect are: \[ k \in [-36, 104]. \] Final Answer: The interval is \(\boxed{[-36, 104]}\).", + # 1 + # ), + ( + r"$[-4,0]$", + r"Thus, the values of \( k \) for which the quadratic has real roots are: \[ \boxed{[-4, 0]} \]", + 1, + ), + # Test interval notation + ( + r"Factoring the denominator on the left side gives \[ \frac{4x}{(x-5)(x-3)}=\frac{A}{x-3}+\frac{B}{x-5}. \]Then, we multiply both sides of the equation by $(x - 3)(x - 5)$ to get \[ 4x = A(x-5) + B(x-3). \]If the linear expression $4x$ agrees with the linear expression $A(x-5) + B(x-3)$ at all values of $x$ besides 3 and 5, then the two expressions must agree for $x=3$ and $x=5$ as well. Substituting $x = 3$, we get $12 = -2A$, so $A = -6$. Likewise, we plug in $x = 5$ to solve for $B$. Substituting $x = 5$, we get $20 = 2B$, so $B = 10$. Therefore, $(A, B) = \boxed{(-6, 10)}.$", + r"The solution to the inequality is \( x \in (-6, 10) \). Final Answer: The final answer is \((-6, 10)\). I hope it is correct.", + 1, + ), + # Test base conversion + ( + r"$1112112_3$", + r"To find the product \( ab \) where \( a = 2012_3 \) and \( b = 201_3 \), we first convert these base-three numbers to base ten. For \( a = 2012_3 \): \[ a = 2 \cdot 3^3 + 0 \cdot 3^2 + 1 \cdot 3^1 + 2 \cdot 3^0 = 2 \cdot 27 + 0 \cdot 9 + 1 \cdot 3 + 2 \cdot 1 = 54 + 0 + 3 + 2 = 59_{10} \] For \( b = 201_3 \): \[ b = 2 \cdot 3^2 + 0 \cdot 3^1 + 1 \cdot 3^0 = 2 \cdot 9 + 0 \cdot 3 + 1 \cdot 1 = 18 + 0 + 1 = 19_{10} \] Now, calculate the product in base ten: \[ ab = 59 \times 19 \] Perform the multiplication: \[ 59 \times 19 = 59 \times (20 - 1) = 59 \times 20 - 59 \times 1 = 1180 - 59 = 1121 \] Next, convert \( 1121_{10} \) to base three. We do this by dividing by 3 and recording the remainders: \[ 1121 \div 3 = 373 \quad \text{remainder } 2 \] \[ 373 \div 3 = 124 \quad \text{remainder } 1 \] \[ 124 \div 3 = 41 \quad \text{remainder } 1 \] \[ 41 \div 3 = 13 \quad \text{remainder } 2 \] \[ 13 \div 3 = 4 \quad \text{remainder } 1 \] \[ 4 \div 3 = 1 \quad \text{remainder } 1 \] \[ 1 \div 3 = 0 \quad \text{remainder } 1 \] Reading the remainders from last to first, we find: \[ 1121_{10} = 1112122_3 \] Thus, the product \( ab \) expressed in the base-three number system is \(\boxed{1112122_3}\).", + 0, + ), + ( + r"\(\boxed{\text{C}}\).", + r"$\boxed{\text{(C)}}.$", + 1, + ), + ( + r" So the answer is: \[ \boxed{11111111100} \]", + r"is $\boxed{11,\! 111,\! 111,\! 100}$", + 1, + ), + ( + r" So the answer is: \[ \boxed{32349} \]", + r"is $\boxed{32,\! 349}$", + 1, + ), + ( + r" So the answer is: \[ \boxed{32349} \]", + r"is $\boxed{32, 349}$", + 0, + ), + ( + r"Thus, the domain of the function \( f(x) \) is: \[ \boxed{(2, 12) \cup (12, 102)} \]", + r"Thus, the answer is $x \in \boxed{(2,12) \cup (12,102)}$", + 1, + ), + ( + r"$\text{E}$", + r"$E$", + 1, + ), + ( + r"$\boxed{1}\boxed{2}$", + r"$\boxed{1,2}$", + 1, + ), + ( + r"Setting $x = y = 0,$ we get \[2f(0) = f(0) - 1,\]so $f(0) = -1.$ Setting $y = 1,$ we get \[f(x) + 1 = f(x + 1) - x - 1,\]so \[f(x + 1) - f(x) = x + 2.\]Thus, \begin{align*} f(2) - f(1) &= 1 + 2, \\ f(3) - f(2) &= 2 + 2, \\ f(4) - f(3) &= 3 + 2, \\ &\dots, \\ f(n) - f(n - 1) &= (n - 1) + 2. \end{align*}Adding all the equations, we get \[f(n) - f(1) = 1 + 2 + 3 + \dots + (n - 1) + 2(n - 1) = \frac{(n - 1)n}{2} + 2n - 2 = \frac{n^2 + 3n - 4}{2},\]so \[f(n) = \frac{n^2 + 3n - 2}{2}\]for all positive integers $n.$ Setting $x = -n$ and $y = n,$ where $n$ is a positive integer, we get \[f(-n) + f(n) = f(0) + n^2 - 1.\]Then \[f(-n) = n^2 - f(n) + f(0) - 1 = n^2 - \frac{n^2 + 3n - 2}{2} - 2 = \frac{n^2 - 3n - 2}{2}.\]Thus, the formula \[f(n) = \frac{n^2 + 3n - 2}{2}\]holds for all integers $n.$ We want to solve $f(n) = n,$ or \[\frac{n^2 + 3n - 2}{2} = n.\]Then $n^2 + 3n - 2 = 2n,$ or $n^2 + n - 2 = 0.$ This factors as $(n - 1)(n + 2) = 0,$ so the solutions are $n = \boxed{1,-2}.$", + r" Okay, so I have this functional equation to solve: \( f(x) + f(y) = f(x + y) - xy - 1 \), and I'm given that \( f(1) = 1 \). I need to find all integers \( n \) such that \( f(n) = n \). Hmm, functional equations can sometimes be tricky, but maybe I can figure this out step by step. First, let me see if I can find an expression for \( f(x) \). Maybe by plugging in specific values for \( x \) and \( y \), I can get some information about \( f \). Let me try setting both \( x \) and \( y \) to 0. That often simplifies things. So if \( x = 0 \) and \( y = 0 \), then the equation becomes: \( f(0) + f(0) = f(0 + 0) - 0 * 0 - 1 \) Simplifying, that's: \( 2f(0) = f(0) - 0 - 1 \) Which simplifies further to: \( 2f(0) = f(0) - 1 \) Subtracting \( f(0) \) from both sides: \( f(0) = -1 \) Alright, so I found that \( f(0) = -1 \). That's a start. Now, maybe I can set \( y = 1 \) because I know \( f(1) = 1 \). Let's set \( y = 1 \) and leave \( x \) arbitrary. Then we get: \( f(x) + f(1) = f(x + 1) - x * 1 - 1 \) Plugging in \( f(1) = 1 \): \( f(x) + 1 = f(x + 1) - x - 1 \) Solving for \( f(x + 1) \): \( f(x + 1) = f(x) + 1 + x + 1 \) So, \( f(x + 1) = f(x) + x + 2 \) Hmm, that's another recursive relationship. So if I can express \( f(x + 1) \) in terms of \( f(x) \), maybe I can find a pattern or a general expression for \( f(x) \). Let me try to unroll this recursion. Let's suppose I start at some point and iterate this. Let's do that starting from \( x = 0 \). First, when \( x = 0 \), as we found earlier, \( f(0) = -1 \). Then, \( f(1) = f(0 + 1) = f(0) + 0 + 2 = -1 + 0 + 2 = 1 \). Well, that checks out because we were given \( f(1) = 1 \). Next, let's find \( f(2) \). Using the same recursion, \( f(2) = f(1 + 1) = f(1) + 1 + 2 = 1 + 1 + 2 = 4 \). Similarly, \( f(3) = f(2 + 1) = f(2) + 2 + 2 = 4 + 2 + 2 = 8 \). Wait a second, let me check that again... So \( f(3) = f(2) + 2 + 2... \) So from \( x = 2 \), do I add 2? Wait, no—wait, actually, the recursion is \( f(x + 1) = f(x) + x + 2 \). So for \( x = 2 \), \( f(3) = f(2) + 2 + 2 \). Yep, that's correct. So continuing... \( f(4) = f(3 + 1) = f(3) + 3 + 2 = 8 + 3 + 2 = 13 \) Hmm, so the values so far are: - f(0) = -1 - f(1) = 1 - f(2) = 4 - f(3) = 8 - f(4) = 13 Wait, let me see if this follows a pattern. Let's list them: n: 0, 1, 2, 3, 4 f(n): -1, 1, 4, 8, 13 Looking at these numbers, let me compute the differences between consecutive terms. f(1) - f(0) = 1 - (-1) = 2 f(2) - f(1) = 4 - 1 = 3 f(3) - f(2) = 8 - 4 = 4 f(4) - f(3) = 13 - 8 = 5 Ah, interesting! Each time, the difference increases by 1. So from f(0) to f(1), the difference is 2, then 3, then 4, then 5... So it seems the differences form an arithmetic sequence starting at 2 with a common difference of 1. If this pattern continues, perhaps the function can be expressed as a quadratic function. Let me test that. Let me assume that \( f(x) \) is a quadratic function of the form \( f(x) = ax^2 + bx + c \). Then maybe I can find coefficients \( a \), \( b \), and \( c \) that satisfy the functional equation. Let me write down that assumption: \( f(x) = ax^2 + bx + c \) Now, plugging this into the functional equation: \( f(x) + f(y) = f(x + y) - xy - 1 \) Substituting, we get: \( ax^2 + bx + c + ay^2 + by + c = a(x + y)^2 + b(x + y) + c - xy - 1 \) Let's simplify both sides. Left side: \( ax^2 + bx + c + ay^2 + by + c = a(x^2 + y^2) + b(x + y) + 2c \) Right side: \( a(x + y)^2 + b(x + y) + c - xy - 1 = a(x^2 + 2xy + y^2) + b(x + y) + c - xy - 1 \) Simplify: \( a x^2 + 2a xy + a y^2 + b x + b y + c - xy - 1 \) Now, set left side equal to right side: \( a(x^2 + y^2) + b(x + y) + 2c = a x^2 + 2a xy + a y^2 + b x + b y + c - xy - 1 \) Let me move all terms to the left side to compare: Left minus right: \( a(x^2 + y^2) + b(x + y) + 2c - [a x^2 + 2a xy + a y^2 + b x + b y + c - xy - 1] = 0 \) Simplify term by term: First, \( a x^2 - a x^2 = 0 \) Similarly, \( a y^2 - a y^2 = 0 \) Then, \( -2a xy \) because one side has a +0xy and the other has -2a xy Wait, wait, let me clarify: Original left side terms: \( a x^2 + a y^2 + b x + b y + 2c \) Subtract right side terms: \( a x^2 + 2a xy + a y^2 + b x + b y + c - xy - 1 \) So, subtracting: \( (a x^2 - a x^2) + (a y^2 - a y^2) + (2c - c) + (b x - b x) + (b y - b y) -2a xy + xy + 1 = 0 \) Simplify: 0 + 0 + c -2a xy + xy + 1 = 0 So combine like terms: c + (-2a + 1)xy + 1 = 0 This must hold for all real numbers x and y. How can the above equation hold for all x and y? The only way is that each coefficient is zero for the variables, and the constants sum to zero. So, coefficients: 1. For xy: (-2a + 1) must be zero. 2. Constants: c + 1 must be zero. So let's write those equations: 1. -2a + 1 = 0 => a = 1/2 2. c + 1 = 0 => c = -1 So, from this, I know \( a = 1/2 \) and \( c = -1 \). Now, what about \( b \)? Let me see, in the left minus right equation, the coefficients for x and y have canceled out because both sides have terms with \( b x \) and \( b y \), subtracting each other. So, there's no need to set anything for \( b \)? Wait, do I have any equation involving \( b \)? It seems in the difference between the left and right sides, the terms involving \( b x \) and \( b y \) have canceled each other out, meaning there is no condition on \( b \) from that step. So, from this examination, it seems that my initial assumption of \( f(x) = ax^2 + bx + c \) with \( a = 1/2 \), \( c = -1 \), and \( b \) is still free. Wait, but maybe we can get another condition using \( f(1) = 1 \). Given that \( f(1) = 1 \), and our general form is \( f(x) = \frac{1}{2}x^2 + b x - 1 \). So, \( f(1) = \frac{1}{2}(1)^2 + b (1) - 1 = \frac{1}{2} + b - 1 = b - \frac{1}{2} \) Given that \( f(1) = 1 \), set up the equation: \( b - \frac{1}{2} = 1 \Rightarrow b = 1 + \frac{1}{2} = \frac{3}{2} \) So, thus, the function can be written as: \( f(x) = \frac{1}{2}x^2 + \frac{3}{2}x - 1 \) Let me write that more neatly: \( f(x) = \frac{1}{2}x^2 + \frac{3}{2}x - 1 \) Simplify: \( f(x) = \frac{1}{2}x^2 + \frac{3}{2}x - 1 \) Well, let's verify if this satisfies the original functional equation to make sure. Compute \( f(x) + f(y) \): \( \frac{1}{2}x^2 + \frac{3}{2}x - 1 + \frac{1}{2}y^2 + \frac{3}{2}y - 1 \) Simplify: \( \frac{1}{2}(x^2 + y^2) + \frac{3}{2}(x + y) - 2 \) Compare with \( f(x + y) - xy - 1 \): First, compute \( f(x + y) \): \( \frac{1}{2}(x + y)^2 + \frac{3}{2}(x + y) - 1 \) Expand: \( \frac{1}{2}(x^2 + 2xy + y^2) + \frac{3}{2}(x + y) - 1 \) Which is: \( \frac{1}{2}x^2 + xy + \frac{1}{2}y^2 + \frac{3}{2}x + \frac{3}{2}y - 1 \) Subtract \( xy + 1 \): \( f(x + y) - xy -1 = \frac{1}{2}x^2 + xy + \frac{1}{2}y^2 + \frac{3}{2}x + \frac{3}{2}y - 1 - xy - 1 \) Simplify: \( \frac{1}{2}x^2 + \frac{1}{2}y^2 + \frac{3}{2}x + \frac{3}{2}y - 2 \) So, \( f(x + y) - xy - 1 = \frac{1}{2}(x^2 + y^2) + \frac{3}{2}(x + y) - 2 \), which is exactly equal to \( f(x) + f(y) \). Therefore, plugging in, it does satisfy the equation. So, that looks good. So we now have an expression for \( f(x) \): \( f(x) = \frac{1}{2}x^2 + \frac{3}{2}x - 1 \) Simplify this further if possible: Combine terms: \( f(x) = \frac{x^2 + 3x - 2}{2} \) Alternatively, it's a quadratic function. Now, I need to find all integers \( n \) such that \( f(n) = n \). So set: \( f(n) = n \) Which is: \( \frac{n^2 + 3n - 2}{2} = n \) Multiply both sides by 2: \( n^2 + 3n - 2 = 2n \) Bring all terms to one side: \( n^2 + 3n - 2 - 2n = 0 \) Simplify: \( n^2 + n - 2 = 0 \) That's a quadratic equation. Let me solve it. Quadratic equation: \( n^2 + n - 2 = 0 \) Using the quadratic formula: \( n = \frac{-b \pm \sqrt{b^2 - 4ac}}{2a} \), where \( a = 1 \), \( b = 1 \), \( c = -2 \). Compute discriminant: \( 1^2 - 4*1*(-2) = 1 + 8 = 9 \) So, \( n = \frac{-1 \pm \sqrt{9}}{2} = \frac{-1 \pm 3}{2} \) Therefore, 1. \( n = \frac{-1 + 3}{2} = \frac{2}{2} = 1 \) 2. \( n = \frac{-1 - 3}{2} = \frac{-4}{2} = -2 \) So the solutions are \( n = 1 \) and \( n = -2 \). Wait, but these are the integer solutions. So the integers n for which \( f(n) = n \) are 1 and -2. Wait, let me double-check with our function \( f(n) \): Given \( f(x) = \frac{x^2 + 3x - 2}{2} \), let me plug in 1: \( f(1) = \frac{1 + 3 - 2}{2} = \frac{2}{2} = 1 \). Correct, since \( f(1) = 1 \) was given. Now, plug in -2: \( f(-2) = \frac{(-2)^2 + 3*(-2) - 2}{2} = \frac{4 - 6 - 2}{2} = \frac{-4}{2} = -2 \). So, indeed, \( f(-2) = -2 \). Looking back at the pattern we saw earlier: We found in the functional equation \( f(0) = -1 \), \( f(1) = 1 \), \( f(2) = 4 \), \( f(3) = 8 \), \( f(4) = 13 \)... These are all positive integers only at integers 1, 4, 8, 13... But, wait, in our case, the only integer solutions where \( f(n) = n \) is at 1 and -2. Wait, let me check with another integer, 2: \( f(2) = 4 \), so that's \( f(2) = 4 \), which is 4, which is not equal to 2. So n=2 is not a solution. Similarly, n=3: \( f(3) = 8 \neq 3 \), n=4: \( f(4)=13 \neq 4 \), etc. What about negative integers? We saw n=-2: f(-2)= -2. So, that's a fixed point. What about n=-1: \( f(-1) = \frac{(-1)^2 + 3*(-1) - 2}{2} = \frac{1 - 3 - 2}{2} = \frac{-4}{2} = -2 \neq -1 \). n=-3: \( f(-3) = \frac{9 + (-9) -2 }{2} = \frac{-2}{2}= -1 \neq -3 \). n=-4: \( f(-4) = \frac{16 + (-12) - 2 }{2} = \frac{2}{2}= 1 \neq -4 \). Hmmm, so n=-2 is a solution. So the only integers n where f(n) = n are n =1 and n=-2. Therefore, the integers are 1 and -2. Wait, but let me think again—is that all? Because sometimes functional equations can have more solutions unless restricted. But in this case, since we derived the general solution of the function, which is quadratic, solving \(f(n) = n\) gives only quadratic equation, which has at most two solutions (and indeed we found two solutions). So, there are exactly two integers: 1 and -2. Wait, so I think I can be confident the solutions are 1 and -2. **Final Answer** The integers are \boxed{-2} and \boxed{1}. Given the functional equation: \[ f(x) + f(y) = f(x + y) - xy - 1 \] for all real numbers \( x \) and \( y \), and knowing that \( f(1) = 1 \), we need to find all integers \( n \) such that \( f(n) = n \). First, by setting \( x = 0 \) and \( y = 0 \), we find: \[ 2f(0) = f(0) - 1 \implies f(0) = -1. \] Next, setting \( y = 1 \) in the functional equation, we get: \[ f(x) + f(1) = f(x + 1) - x - 1. \] Given \( f(1) = 1 \), this simplifies to: \[ f(x + 1) = f(x) + x + 2. \] Using this recursive relationship, we calculate the values of \( f \) at several points: \[ f(1) = 1, \] \[ f(2) = 4, \] \[ f(3) = 8, \] \[ f(4) = 13. \] We observe that the differences between consecutive values form an arithmetic sequence starting from 2 and increasing by 1 each time. This suggests that \( f(x) \) might be a quadratic function. Assuming \( f(x) \) is quadratic, we write it as: \[ f(x) = ax^2 + bx + c. \] Substituting this into the functional equation and equating coefficients, we find: \[ a = \frac{1}{2}, \] \[ c = -1, \] and using \( f(1) = 1 \), we determine: \[ b = \frac{3}{2}. \] Thus, the quadratic function is: \[ f(x) = \frac{1}{2}x^2 + \frac{3}{2}x - 1. \] To find all integers \( n \) such that \( f(n) = n \), we solve: \[ \frac{n^2 + 3n - 2}{2} = n. \] This simplifies to: \[ n^2 + n - 2 = 0. \] Solving this quadratic equation, we find the roots: \[ n = 1 \quad \text{and} \quad n = -2. \] Thus, the integers \( n \) such that \( f(n) = n \) are: \[ \boxed{-2} \quad \text{and} \quad \boxed{1}. \]", + 1, + ), + ], +) +def test_math_extraction_additional_cases(gold, pred, expected): + assert compare_strings(gold, pred, match_types=["latex", "expr"]) == expected diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_configs.py b/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_configs.py new file mode 100644 index 000000000..b3ac44551 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_configs.py @@ -0,0 +1,72 @@ +import sympy +from latex2sympy2_extended import NormalizationConfig + +from math_verify.parser import LatexExtractionConfig, parse + + +def test_boxed_match_priority(): + x = parse( + "\\boxed{1}", + (LatexExtractionConfig(),), + fallback_mode="no_fallback", + extraction_mode="first_match", + ) + assert len(x) == 1 + + # No extraction + x = parse( + "\\boxed{1}", + (LatexExtractionConfig(boxed_match_priority=-1),), + fallback_mode="no_fallback", + extraction_mode="first_match", + ) + assert len(x) == 0 + + # final answer is match first + x = parse( + "final answer is $9999$, \\boxed{1}", + (LatexExtractionConfig(boxed_match_priority=100),), + fallback_mode="no_fallback", + extraction_mode="first_match", + ) + assert x[0] == sympy.Integer(9999) + + +def test_normalization_config(): + x = parse( + "$\\frac12$", + ( + LatexExtractionConfig( + normalization_config=NormalizationConfig( + basic_latex=True, + units=True, + malformed_operators=False, + nits=True, + boxed=True, + equations=True, + ) + ), + ), + fallback_mode="no_fallback", + extraction_mode="first_match", + ) + assert len(x) == 0 + + x = parse( + "$\\frac12$", + ( + LatexExtractionConfig( + normalization_config=NormalizationConfig( + basic_latex=True, + units=True, + malformed_operators=True, + nits=True, + boxed=True, + equations=True, + ) + ), + ), + fallback_mode="no_fallback", + extraction_mode="first_match", + ) + assert x[0] == sympy.Rational(1, 2) diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_numina_cases.py b/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_numina_cases.py new file mode 100644 index 000000000..55254328b --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_numina_cases.py @@ -0,0 +1,148 @@ +import pytest + +from tests.test_all import compare_strings + + +@pytest.mark.parametrize( + "gold, pred, expected", + [ + ( + r"$(37,3,3,13),(17,3,3,7),(3,37,3,13),(3,17,3,7),(3,3,2,3)$", + r"$\boxed{(3, 37, 3, 13), (3, 17, 3, 7), (3, 3, 2, 3), (3,17,3,7), (17,3,3,7), (37,3,3,13)}$", + 1, + ), + ( + r"$(p,q)=(3,2)$", + r"$(3,2)$", + 1, + ), + ( + r"$(0;0;0),(0;-2;0),(0;0;6),(0;-2;6),(4;0;0),(4;-2;0),(4;0;6),(4;-2;6)$", + r"\boxed{(4, 0, 6), (4, -2, 6), (0, 0, 6), (0, -2, 6), (4, 0, 0), (4, -2, 0), (0, 0, 0), (0, -2, 0)}", + 1, + ), + ( + r"$1\leq|z|\leq \frac{3}{2}$", + r"$z \in \left[-\frac{3}{2}, -1\right] \cup \left[1, \frac{3}{2}\right]$", + 1, + ), + ( + r"$-12;-11;-10;-8;-7;-6$", + r"$\boxed{\{-12, -11, -10, -8, -7, -6\}}$", + 1, + ), + ( + r"$AB=4,CD=5$", + r"$\boxed{4, 5}$", + 1, + ), + ( + r"$(11,7)or(7,11)$", + r"$\boxed{(7,11),\ (11,7)}$", + 1, + ), + ( + r"$S_{MBCN}:S=7:32$", + r"$\boxed{7:32}$", + 1, + ), + ( + r"$\frac{NO}{BO}=\frac{1}{\sqrt{6}}$", + r"$\frac{1}{\sqrt{6}}$", + 1, + ), + ( + r"$p=5,q=2;p=7,q=2$", + r"$(5,2),(7,2)$", + 1, + ), + ( + r"$(p,q,r)=(3,2,7)$", + r"$(3,2,7)$", + 1, + ), + ( + r"$V_{1}:V_{2}=11:21$", + r"$11:21$", + 1, + ), + ( + r"$(2,1),(1,2),(-1,-20),(-20,-1)$", + "solutions are:\n\n\\[\n\\boxed{(1, 2)}, \\boxed{(2, 1)}, \\boxed{(-1, -20)}, \\boxed{(-20, -1)}\n\\]", + 1, + ), + ( + r"\(\boxed{1}\) and \(\boxed{-2}\).", + r"$\boxed{-2,1}$.", + 1, + ), + ( + r"$\text{odd}$", + r"$odd$", + 1, + ), + ( + r"$\text{e}$", + r"$e$", + 1, + ), + ( + r"$\text{E}$", + r"$E$", + 1, + ), + (r"$d$", r"$\text{E}$", 0), + (r"$1$ and $2$ and $3$", r"$\boxed{1,2,3}$", 1), + ( + r"$(37,3,3,13),(17,3,3,7),(3,37,3,13),(3,17,3,7),(3,3,2,3)$", + r"$\boxed{(3, 37, 3, 13), (3, 17, 3, 7), (3, 3, 2, 3), (3,17,3,7), (17,3,3,7), (37,3,3,13)}$", + 1, + ), + ( + r"$(37,3,3),(17,3,3,7),(3,37,3,13),(3,17,3,7),(3,3,2,3)$", + r"$\boxed{(3, 37, 3, 13), (3, 17, 3, 7), (3, 3, 2, 3), (3,17,3,7), (17,3,3,7), (37,3,3,13)}$", + 0, + ), + ( + r"$(p,q)=(3,2)$", + r"$\boxed{(3, 2)}$", + 1, + ), + ( + r"\boxed{x = -5,\ p = \frac{14}{3}} ", + r"$\boxed{-5, \frac{14}{3}}$", + 1, + ), + ( + r"\boxed{a=4,\,-8,\,-10}", + r"$\boxed{-10,-8,4}$", + 1, + ), + ( + r"\\boxed{W(n) = 1 \\text{ and } W(n) = -1", + r"W(x)=1orW(x)=-1", + 1, + ), + ("$21,16$ or $11$", "$21,16,11$", 1), + # ( + # r"\boxed{ p = 5, q = 2 \quad \text{and} \quad p = 7, q = 2}", + # r"$p=5,q=2;p=7,q=2$", + # 1 + # ), + ( + r"\n\n\[ \boxed{p = -1 \text{ and } p = \dfrac{15}{8}} \]", + r"$p=-1,p=\frac{15}{8}$", + 1, + ), + ("$0 Alright, let me try to figure out this problem. So, we need to find how many years between 2000 and 3000 inclusive can be written as the sum of two positive integers p and q such that 2p = 5q. The example given is 2016 = 1440 + 576, where 2*1440 = 2880 and 5*576 = 2880, so they are equal. First, I need to understand the relationship between p and q. The equation 2p = 5q implies that p and q are in a specific ratio. Let's solve for one variable in terms of the other. If 2p = 5q, then p = (5/2)q. That means p has to be a multiple of 5/2 times q. Since p and q are positive integers, q must be a multiple of 2 to make p an integer. Let's denote q = 2k, where k is a positive integer. Then p = (5/2)*(2k) = 5k. So, substituting back, the year can be written as p + q = 5k + 2k = 7k. Therefore, the year must be a multiple of 7. Wait, so that simplifies things. If the year must be a multiple of 7, then the problem reduces to finding how many multiples of 7 are there between 2000 and 3000 inclusive. But let me double-check this reasoning. Starting from 2p = 5q, we have p = (5/2)q. Since p must be an integer, q must be even. Let q = 2k, then p = 5k. So, the sum p + q = 5k + 2k = 7k. Therefore, every year that is a multiple of 7 can be expressed in such a way. Conversely, if a year is a multiple of 7, say N = 7k, then we can set q = 2k and p = 5k, which are integers. So, yes, the problem is equivalent to finding the number of multiples of 7 between 2000 and 3000 inclusive. Therefore, I need to count how many numbers divisible by 7 are in the range [2000, 3000]. To do this, I can find the smallest multiple of 7 in this range and the largest multiple of 7, then compute how many terms are in this arithmetic sequence. Let me first find the smallest multiple of 7 greater than or equal to 2000. Dividing 2000 by 7: 2000 ÷ 7 ≈ 285.714. So, 285 * 7 = 1995, which is less than 2000. The next multiple is 286 * 7 = 2002. Therefore, 2002 is the first multiple of 7 in the range. Now, the largest multiple of 7 less than or equal to 3000. Dividing 3000 by 7: 3000 ÷ 7 ≈ 428.571. So, 428 * 7 = 2996. But 428*7=2996, which is less than 3000. The next multiple would be 429*7=3003, which is over 3000. So, the largest multiple is 2996. Now, the number of terms in the arithmetic sequence from 2002 to 2996, with common difference 7. The formula for the number of terms in an arithmetic sequence is ((last term - first term)/common difference) + 1. So, ((2996 - 2002)/7) + 1 = (994/7) + 1. Let me compute 994 divided by 7. 7*142=994. So, 142 + 1 = 143. Therefore, there are 143 multiples of 7 between 2000 and 3000 inclusive. But wait, let me confirm. If we have terms starting at 2002, each term 7 apart, up to 2996. The difference between 2996 and 2002 is 994. Dividing by 7 gives 142, which means there are 142 intervals of 7, hence 143 terms. That seems right. But let me check with another method. The number of multiples of 7 up to n is floor(n/7). So, the number of multiples up to 3000 is floor(3000/7)=428, as 7*428=2996. The number of multiples up to 1999 is floor(1999/7)=285, since 7*285=1995. So, the number of multiples between 2000 and 3000 inclusive is 428 - 285. But wait, 428 counts up to 2996, which is within 3000. However, 285 counts up to 1995, so the next multiple is 2002. So, the number of multiples from 2002 to 2996 is 428 - 285 = 143. But wait, 428 - 285 = 143? Let me compute 428 - 285. 428 - 200 = 228, 228 - 85 = 143. Yes. So that's 143. However, is 2002 included? Because 285 counts up to 1995, so 1995 is the 285th multiple, then the next one is 2002 which is 286th. So, from 286th to 428th multiple, how many terms? 428 - 286 + 1 = 143. Yes, because (428 - 286) = 142, plus 1 is 143. So that confirms the answer is 143. But wait, the problem states \"between 2000 and 3000 inclusive\". However, 2002 is the first multiple of 7 in that range. So, if we need to include 2000, but 2000 is not a multiple of 7, so the count remains 143. Therefore, the answer is 143. But let me check the example given. 2016 is in the range. 2016 divided by 7 is 288, so 2016 is the 288th multiple of 7. So, 288 - 285 = 3, but according to the previous count, 428 - 285 = 143. But 2016 is within the range. So, perhaps my initial calculation is correct. Wait, 428 - 285 = 143. But if the first multiple is 2002 (286th multiple), and the last is 2996 (428th multiple), then the number of terms is 428 - 286 + 1 = 143. So, yes. So, 143 is the answer. Wait, but let me test with smaller numbers. Suppose between 7 and 14 inclusive. How many multiples of 7? 7,14. So two numbers. Applying the formula: floor(14/7)=2, floor(6/7)=0, so 2-0=2. Correct. Similarly, between 15 and 21. The multiples are 21. So floor(21/7)=3, floor(14/7)=2, 3-2=1. Correct. So, the formula works when the upper bound is a multiple. If upper bound is not a multiple, like between 15 and 20. The multiples are 21 is beyond, so 14 is the last multiple below 20. Wait, no, 14 is 14, 21 is 21. Wait, between 15 and 20, there are no multiples of 7. So floor(20/7)=2, floor(14/7)=2, 2-2=0. Correct. So, formula works. So in the original problem, floor(3000/7)=428, floor(1999/7)=285. So, 428-285=143. Therefore, the answer is 143. But wait, the example given is 2016, which is 7*288. So, 288 is within the 286 to 428 range. 288-285=3, but in the total count, it's 143. So, all multiples from 2002 up to 2996. So, 143 years. Therefore, the answer is 143. So, \boxed{143} **Final Answer** \boxed{143} To determine how many years between 2000 and 3000 inclusive can be written as the sum of two positive integers \( p \) and \( q \) such that \( 2p = 5q \), we start by analyzing the relationship between \( p \) and \( q \). Given \( 2p = 5q \), we can express \( p \) in terms of \( q \): \( p = \frac{5}{2}q \). For \( p \) and \( q \) to be integers, \( q \) must be a multiple of 2. Let \( q = 2k \), then \( p = 5k \). The sum \( p + q = 5k + 2k = 7k \), indicating the year must be a multiple of 7. Next, we need to count the multiples of 7 between 2000 and 3000 inclusive. 1. **Finding the smallest multiple of 7 ≥ 2000**: \[ \left\lceil \frac{2000}{7} \right\rceil = 286 \quad \text{(since } 286 \times 7 = 2002\text{)} \] 2. **Finding the largest multiple of 7 ≤ 3000**: \[ \left\lfloor \frac{3000}{7} \right\rfloor = 428 \quad \text{(since } 428 \times 7 = 2996\text{)} \] 3. **Counting the number of multiples of 7 between 2002 and 2996**: \[ \text{Number of terms} = 428 - 286 + 1 = 143 \] Thus, the number of years between 2000 and 3000 inclusive that can be written as the sum of two positive integers \( p \) and \( q \) such that \( 2p = 5q \) is \(\boxed{143}\).", + 1, + ), + ( + r"$$143$$", + r"$$\boxed{20}$$ \boxed{143}", + 1, + ), + ( + r"$${1,2,3}$$", + r"\boxed{1} and \boxed{2} and \boxed{3}", + 1, + ), + ( + r"$${1,2,3}$$", + r"\boxed{1},\boxed{2},\boxed{3}", + 1, + ), + # ( + # r"$$x+z=1$$", + # r"$$1$$", + # 0, + # ), + ( + r"$$|AB|=1$$", + r"$$1$$", + 1, + ), + ( + r"$$A:B=1$$", + r"$$1$$", + 1, + ), + ( + r"$$f(x)=1$$", + r"$$1$$", + 1, + ), + ( + r"$x_{1}=10^{\frac{-5+\sqrt{13}}{6}},\quadx_{2}=10^{\frac{-5-\sqrt{13}}{6}}$", + r"$\boxed{10^{\frac{\sqrt{13} - 5}{6}}} \quad \text{and} \quad \boxed{10^{-\frac{5 + \sqrt{13}}{6}}}$", + 1, + ), + ( + r"$y_{1}=-2 x^{2}+4 x+3, y_{2}=3 x^{2}+12 x+10$", + r"\($y_1 = \boxed{-2(x - 1)^2 + 5} \) and \( y_2 = \boxed{3(x + 2)^2 - 2} \) ", + 1, + ), + ( + r"$x_{1}=\frac{1}{2}+\frac{31\sqrt{5}}{216},\quadx_{2}=\frac{1}{2}-\frac{31\sqrt{5}}{216}$", + r"$\boxed{\dfrac{108 + 31\sqrt{5}}{216}} \quad \text{and} \quad \boxed{\dfrac{108 - 31\sqrt{5}}{216}}$", + 1, + ), + ( + r"$x_{1}=10^{\frac{-5+\sqrt{13}}{6}},\quadx_{2}=10^{\frac{-5-\sqrt{13}}{6}}$", + r"$\boxed{10^{\frac{\sqrt{13} - 5}{6}}} \quad \text{and} \quad \boxed{10^{-\frac{5 + \sqrt{13}}{6}}}$", + 1, + ), + ( + r"\boxed{1} papers and \boxed{2} bins", + r"$\boxed{1,2}$", + 1, + ), + ( + r"\boxed{1} no no no no no \boxed{2}", + r"$\boxed{1,2}$", + 0, + ), + ], +) +def test_open_thoughts(gold, pred, expected): + assert compare_strings(gold, pred, match_types=["latex", "expr"]) == expected diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_strict.py b/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_strict.py new file mode 100644 index 000000000..8193e2fdf --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_strict.py @@ -0,0 +1,14 @@ +import pytest + +from tests.test_all import compare_strings + + +@pytest.mark.parametrize( + "gold, pred, strict, expected", + [(r"$f(x)$", r"$f(y)$", True, 0), (r"$f(x)$", r"$f(y)$", False, 1)], +) +def test_strict_variable_comparison(gold, pred, strict, expected): + assert ( + compare_strings(gold, pred, match_types=["latex", "expr"], strict=strict) + == expected + ) diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_string_extraction.py b/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_string_extraction.py new file mode 100644 index 000000000..ef71bd0d7 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_string_extraction.py @@ -0,0 +1,27 @@ +import pytest + +from math_verify import StringExtractionConfig, parse + + +@pytest.mark.parametrize( + "pred,expected,config", + [ + # Test basic string extraction + ("The answer is A.", ["A", "A"], StringExtractionConfig(lowercase=False)), + # Test lowercase + ("The answer is A.", ["a", "A"], StringExtractionConfig(lowercase=True)), + # Test with different string options + ("Final answer is B", ["b", "B"], StringExtractionConfig()), + # Test no match + ("No valid answer here", [], StringExtractionConfig()), + # Test start regex + ("A. Because B is not valid", ["a", "A"], StringExtractionConfig()), + # Test space truncate + # Test different strings + ("The answer is U.", ["u", "U"], StringExtractionConfig(strings=("U",))), + # Test any plain string + ("Because B is valid", ["b", "B"], StringExtractionConfig()), + ], +) +def test_string_extraction(pred, expected, config): + assert parse(pred, [config]) == expected diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_timeout.py b/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_timeout.py new file mode 100644 index 000000000..2b4876201 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_timeout.py @@ -0,0 +1,56 @@ +import time +from unittest.mock import patch + +from math_verify.grader import verify +from math_verify.parser import parse + + +@patch("math_verify.parser.parse_expr") +def test_timeout_expr(mock_parse_expr): + # Mock the parsing function to simulate a delay + def delayed_parse(*args, **kwargs): + time.sleep(5) # Simulate a delay longer than the timeout + return "parsed_expr" + + mock_parse_expr.side_effect = delayed_parse + + # Test that the timeout is triggered + x = parse( + "1+1", + parsing_timeout=1, + extraction_mode="first_match", + fallback_mode="no_fallback", + ) + assert x == [] + + +@patch("math_verify.parser.latex2sympy") +def test_timeout_latex(mock_parse_latex): + # Mock the parsing function to simulate a delay + def delayed_parse(*args, **kwargs): + time.sleep(5) # Simulate a delay longer than the timeout + return "parsed_expr" + + mock_parse_latex.side_effect = delayed_parse + + # Test that the timeout is triggered + x = parse( + "$1+1$", + parsing_timeout=1, + extraction_mode="first_match", + fallback_mode="no_fallback", + ) + assert x == [] + + +@patch("math_verify.grader.sympy_expr_eq") +def test_timeout_verify(mock_verify): + # Mock the verify function to simulate a delay + def delayed_sympy_expr_eq(*args, **kwargs): + time.sleep(5) # Simulate a delay longer than the timeout + return True + + mock_verify.side_effect = delayed_sympy_expr_eq + + gold = [parse("1+1")[0]] + assert not verify(gold, gold, timeout_seconds=1) diff --git a/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_yujianll.py b/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_yujianll.py new file mode 100644 index 000000000..e9ea518a3 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Math-Verify/tests/test_yujianll.py @@ -0,0 +1,36 @@ +import pytest + +from math_verify import parse, verify + + +@pytest.mark.parametrize( + "gold, pred, result", + [ + # ("$\\frac{1}{2004!}$", "$\\frac{1}{2006!}$", 0), + # ("$\\frac{1}{2^{99}}$", "$\\frac{1}{2^{98}}$", 0), + ("$n=1, 2, 3$", "$n=1, 2, 3$", 1), + ("$n=2, 3, 4$", "$n=1, 2, 3$", 0), + ("$D = (0, 1)$", "$D = (0, 1)$", 1), + ("$D = (0, 1)$", "$D = (0, 1.5)$", 0), + ( + "$F_n = \\sum_{k=0}^{n-1} F_k F_{n-k-1}$", + "$F_n = \\sum_{k=0}^{n-2} F_k F_{n-k-2}$", + 0, + ), + ( + "$F_n = \\sum_{k=0}^{n-1} F_k F_{n-k-1}$", + "$F_n = \\sum_{k=0}^{n-1} F_{n-k-1} F_k$", + 1, + ), + ("$19, 46, and 82$", "$19, 46, \\text{ and } 82$", 1), + ("$19, 46, and 82$", "$21, 36, and 82$", 0), + ("$1 - \\frac{1}{\\mathrm{e}}$", "$1 - \\frac{1}{\\mathrm{e}}$", 1), + ("$1 - \\frac{1}{\\mathrm{e}}$", "$1 + \\frac{1}{\\mathrm{e}}$", 1), + ("$\\pm 2021$", "$2021$", 0), + ("$\\pm 2021$", "${-2021, 2021}$", 1), + ("$12:00$", "$1:00$", 0), + ("$12:00$", "$12:00$", 1), + ], +) +def test_string_extraction(gold, pred, result): + assert int(verify(parse(gold), parse(pred))) == result diff --git a/vlmeval/dataset/PHYSICS/PHYSICS.py b/vlmeval/dataset/PHYSICS/PHYSICS.py new file mode 100644 index 000000000..9ff32281c --- /dev/null +++ b/vlmeval/dataset/PHYSICS/PHYSICS.py @@ -0,0 +1,103 @@ +import os +import pandas as pd +import numpy as np +from jinja2 import Template +from functools import partial +from transformers import AutoTokenizer, Qwen2Tokenizer +import os +import json +import warnings +import pandas as pd +from tqdm import tqdm +from collections import defaultdict +from datetime import datetime +import argparse +import contextlib +from tqdm import tqdm +from collections import defaultdict +from reward_score import compute_score +from reward_manager import verifier_manager +import signal +from contextlib import contextmanager + +@contextmanager +def timeout(seconds): + def signal_handler(signum, frame): + raise TimeoutError("Timed out!") + + signal.signal(signal.SIGALRM, signal_handler) + signal.alarm(seconds) + try: + yield + finally: + signal.alarm(0) + +class PHYSICS(): + DATASET_URL = { + 'PHYSICS-test': 'https://huggingface.co/datasets/desimfj/PHYSICS' + } + + def load_data(path, read_num=None, repeat_time=1): + name = os.path.basename(path).replace('.jsonl', '') + df = pd.read_json(path, lines=True) + df['dataset'] = name + + # 采样(若指定 read_num) + sample_kwargs = {'n': read_num} if read_num else {'frac': 1} + df = df.sample(**sample_kwargs).drop_duplicates(subset=['question']).reset_index(drop=True) + + # 重复(若指定 repeat_time) + if repeat_time > 1: + df = df.loc[np.repeat(df.index, repeat_time)].reset_index(drop=True) + + return df + + + def build_prompt(line): + SYSTEM_PROMPT = """A conversation between User and Assistant. The User asks a question, and the Assistant solves it. + The reasoning process is enclosed within and answer is enclosed within tags.""" + + INSTRUCTION_TEMPLATE = """Below is an open-ended problem in undergraduate-level Physics. Please answer this problem adhering to the following rules: + 1. Use LaTeX format for formulas. + 2. Put the final answer(s) in \\boxed{}, without units. + 3. If multiple answers exist, separate them by commas in \\boxed{}. + Problem: {{prompt}}""" + + # 获取问题文本(支持 Series 或 dict) + prompt_text = line["question"] if isinstance(line, dict) else line + + # 应用指令模板 + prompt_text = Template(INSTRUCTION_TEMPLATE).render(prompt=prompt_text) + + # 拼接最终的完整 prompt + full_prompt = ( + f"System: {SYSTEM_PROMPT}\n\n" + f"User: {prompt_text}\n" + ) + + return full_prompt + + def write_jsonl(data_path, dataset, indent=0, mode='w'): + with open(data_path, mode, encoding='UTF-8') as f: + if not isinstance(dataset, list): + dataset = [dataset] + for data in dataset: + line = json.dumps(data, ensure_ascii=False, indent=indent if indent != 0 else None) + f.write(line + '\n') + + def evaluate(self, eval_file, **judge_kwargs): + output_path = '../logs/phyics-test.log' + for item in tqdm(eval_file, desc="Scoring"): + # result = compute_score(item['model_output'], item['ground_truth'], item['problem']) # olympiadbench + try: + with timeout(30): + result = compute_score(item['test_result'], item['answer'], item['question'], **judge_kwargs) + item['rule_based_acc'] = result['rule_based_acc'] + item['acc'] = result['acc'] + item['extracted_gt'] = result['extracted_gt'] + item['extracted_pred'] = result['extracted_pred'] + self.write_jsonl(output_path, item, mode='a') + except TimeoutError: + print(f"Timeout processing item: {item.get('question', 'Unknown')}") + continue + return eval_file \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/Readme.md b/vlmeval/dataset/PHYSICS/Readme.md new file mode 100644 index 000000000..0e929418d --- /dev/null +++ b/vlmeval/dataset/PHYSICS/Readme.md @@ -0,0 +1,33 @@ + + +## 安装步骤 + +`evaluate` 依赖以下两个自定义包: +- `latex2sympy2_extended` +- `math-verify` + +#### 安装 `latex2sympy2_extended` + +1. 进入包目录: + ```bash + cd ./latex2sympy2_extended + ``` + +2. 安装包(推荐使用 ANTLR 4.11): + ```bash + pip install -e .[antlr4_11_0] + ``` + + > **注意**:`latex2sympy2_extended` 使用 `omegaconf`,而 `omegaconf` 默认支持 ANTLR 4.9。若需支持 ANTLR 4.11,请参考 [Hydra Issue #2491](https://github.com/facebookresearch/hydra/issues/2491) 修改配置。 + +3. 如果 ANTLR 4.11 安装失败,可回退到 ANTLR 4.9: + ```bash + pip install -e .[antlr4_9_3] + ``` +#### 安装 `math-verify` + ```bash + pip install math-verify + ``` + +## 情况说明 +judge model实际在论文中为xverify-8b-sft,但该verifier暂无开源计划,因此可以使用xverify或其他LLM代替。 \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/__init__.py b/vlmeval/dataset/PHYSICS/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vlmeval/dataset/PHYSICS/evaluate.py b/vlmeval/dataset/PHYSICS/evaluate.py new file mode 100644 index 000000000..959f7e20e --- /dev/null +++ b/vlmeval/dataset/PHYSICS/evaluate.py @@ -0,0 +1,135 @@ +import os +import json +import argparse +import contextlib +from tqdm import tqdm +from collections import defaultdict +from reward_score import compute_score +from reward_manager import verifier_manager +import signal +from contextlib import contextmanager +def load_data(input_path): + all_data = [] + with open(input_path, 'r', encoding='utf-8') as file: + for line in file: + line = line.strip() + if line: + try: + all_data.append(json.loads(line)) + except json.JSONDecodeError as e: + print(f"Error decoding JSON on line: {line}") + print(f"Error: {e}") + return all_data + +def write_jsonl(data_path, dataset, indent=0, mode='w'): + with open(data_path, mode, encoding='UTF-8') as f: + if not isinstance(dataset, list): + dataset = [dataset] + for data in dataset: + line = json.dumps(data, ensure_ascii=False, indent=indent if indent != 0 else None) + f.write(line + '\n') + +@contextmanager +def timeout(seconds): + def signal_handler(signum, frame): + raise TimeoutError("Timed out!") + + signal.signal(signal.SIGALRM, signal_handler) + signal.alarm(seconds) + try: + yield + finally: + signal.alarm(0) + +def judge(data, output_path): + for item in tqdm(data, desc="Scoring"): + # result = compute_score(item['model_output'], item['ground_truth'], item['problem']) # olympiadbench + try: + with timeout(30): + result = compute_score(item['test_result'], item['answer'], item['question']) + item['rule_based_acc'] = result['rule_based_acc'] + item['acc'] = result['acc'] + item['extracted_gt'] = result['extracted_gt'] + item['extracted_pred'] = result['extracted_pred'] + write_jsonl(output_path, item, mode='a') + except TimeoutError: + print(f"Timeout processing item: {item.get('question', 'Unknown')}") + continue + return data + +def compute_mean_at_n(data, key_acc, log_path=None): + domain_stats = defaultdict(lambda: defaultdict(list)) + difficulty_stats = defaultdict(lambda: defaultdict(list)) + language_stats = defaultdict(lambda: defaultdict(list)) + overall_acc = [] + try: + for item in data: + domain = item['domain'] + difficulty = item['difficulty'] + if item['translate'] is False: + language = item['language'] + else: + language = 'zh' if item['language'] == 'en' else 'en' + acc = 1.0 if item[key_acc] is True else 0.0 + domain_stats[domain][difficulty].append(acc) + difficulty_stats[difficulty][domain].append(acc) + language_stats[language][domain].append(acc) + overall_acc.append(acc) + except Exception as e: + print(e) + + def _print_stats(): + print("=== Mean@N Statistics ===\n") + for domain in sorted(domain_stats.keys()): + print(f" Domain: {domain}\n") + for difficulty in sorted(domain_stats[domain].keys()): + scores = domain_stats[domain][difficulty] + mean_acc = sum(scores) / len(scores) if scores else 0 + print(f" Difficulty: {difficulty:<10} | Count: {len(scores):<4} | Acc: {mean_acc:.4f}\n") + domain_mean = sum(sum(scores) for scores in domain_stats[domain].values()) / sum(len(scores) for scores in domain_stats[domain].values()) + print(f" Domain Mean: {domain_mean:.4f}\n") + + print("\n") + + for language in sorted(language_stats.keys()): + print(f" Language: {language}\n") + for domain in sorted(language_stats[language].keys()): + scores = language_stats[language][domain] + mean_acc = sum(scores) / len(scores) if scores else 0 + print(f" Domain: {domain:<10} | Count: {len(scores):<4} | Acc: {mean_acc:.4f}\n") + language_mean = sum(sum(scores) for scores in language_stats[language].values()) / sum(len(scores) for scores in language_stats[language].values()) + print(f" Language Mean: {language_mean:.4f}\n") + + print("\n") + + for difficulty in sorted(difficulty_stats.keys()): + print(f" Difficulty: {difficulty}\n") + for domain in sorted(difficulty_stats[difficulty].keys()): + scores = difficulty_stats[difficulty][domain] + mean_acc = sum(scores) / len(scores) if scores else 0 + print(f" Domain: {domain:<10} | Count: {len(scores):<4} | Acc: {mean_acc:.4f}\n") + diff_mean = sum(sum(scores) for scores in difficulty_stats[difficulty].values()) / sum(len(scores) for scores in difficulty_stats[difficulty].values()) + print(f" Difficulty Mean: {diff_mean:.4f}\n") + + total_mean = sum(overall_acc) / len(overall_acc) if overall_acc else 0 + print(f"\nOverall Acc: {total_mean:.4f} on {len(overall_acc)} samples") + + if log_path: + with open(log_path, "w", encoding="utf-8") as f, contextlib.redirect_stdout(f): + _print_stats() + _print_stats() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--input_path', type=str, default=None, help='Path to the input JSONL file.') + parser.add_argument('--output_dir', type=str, default=None, help='Path to the output JSONL file.') + args = parser.parse_args() + output_dir = os.path.join(args.output_dir, os.path.basename(args.input_path).replace('.jsonl', '')) + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, os.path.basename(args.input_path)) + + data = load_data(args.input_path) + data = judge(data, output_path) + compute_mean_at_n(data, 'acc', log_path=os.path.join(output_dir, 'acc.log')) + compute_mean_at_n(data, 'rule_based_acc', log_path=os.path.join(output_dir, 'rule_based_acc.log')) \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/.coveragerc b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/.coveragerc new file mode 100644 index 000000000..3aeb83f23 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/.coveragerc @@ -0,0 +1,31 @@ +# .coveragerc to control coverage.py +[run] +branch = True +include = + latex2sympy.py +omit = + sandbox/* + gen/* + asciimath_printer.py + setup.py + __init__.py + +[report] +# Regexes for lines to exclude from consideration +exclude_lines = + # Have to re-enable the standard pragma + pragma: no cover + + # Don't complain about missing debug-only code: + def __repr__ + if self\.debug + + # Don't complain if tests don't hit defensive assertion code: + raise AssertionError + raise NotImplementedError + + # Don't complain if non-runnable code isn't run: + if 0: + if __name__ == .__main__.: + +ignore_errors = True diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/.gitignore b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/.gitignore new file mode 100644 index 000000000..04be21cc5 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/.gitignore @@ -0,0 +1,135 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class +.antlr + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don’t work, or not +# install all needed dependencies. +#Pipfile.lock + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Azure Functions artifacts +bin +obj +appsettings.json +local.settings.json +.python_packages +stemgen-solution-engine.zip + +# gen/ +src/latex2sympy2_extended/gen/ \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/antlr-4.11.0-complete.jar b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/antlr-4.11.0-complete.jar new file mode 100644 index 000000000..3a711e07c Binary files /dev/null and b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/antlr-4.11.0-complete.jar differ diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/antlr-4.13.2-complete.jar b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/antlr-4.13.2-complete.jar new file mode 100644 index 000000000..75bfcc397 Binary files /dev/null and b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/antlr-4.13.2-complete.jar differ diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/antlr-4.9.3-complete.jar b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/antlr-4.9.3-complete.jar new file mode 100644 index 000000000..749296fe7 Binary files /dev/null and b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/antlr-4.9.3-complete.jar differ diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/asciimath_printer.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/asciimath_printer.py new file mode 100644 index 000000000..dd1b676e6 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/asciimath_printer.py @@ -0,0 +1,50 @@ +from sympy.printing.str import StrPrinter +from sympy.core import S + +class AsciiMathPrinter(StrPrinter): + + def _print_Limit(self, expr): + e, z = expr.args + + return "lim_(%s -> %s) %s" % (self._print(z), self._print(z), self._print(e)) + + def _print_Integral(self, expr): + e, lims = expr.args + if len(lims) > 1: + return "int_(%s)^(%s) %s d%s" % (self._print(lims[1]), self._print(lims[2]), self._print(e), self._print(lims[0])) + else: + return "int %s d%s" % (self._print(e), self._print(lims)) + + def _print_Sum(self, expr): + e, lims = expr.args + return "sum_(%s = %s)^(%s) %s" % (self._print(lims[0]), self._print(lims[1]), self._print(lims[2]), self._print(e)) + + def _print_Product(self, expr): + e, lims = expr.args + return "prod_(%s = %s)^(%s) %s" % (self._print(lims[0]), self._print(lims[1]), self._print(lims[2]), self._print(e)) + + def _print_factorial(self, expr): + return "%s!" % self._print(expr.args[0]) + + def _print_Derivative(self, expr): + e = expr.args[0] + wrt = expr.args[1] + return "d/d%s %s" % (self._print(wrt), self._print(e)) + + def _print_Abs(self, expr): + return "|%s|" % self._print(expr.args[0]) + + def _print_Equality(self, expr): + return "%s = %s" % (self._print(expr.args[0]), self._print(expr.args[1])) + + def _print_Pow(self, expr): + b = self._print(expr.base) + if expr.exp is S.Half: + return "sqrt(%s)" % b + + if -expr.exp is S.Half: + return "1/sqrt(%s)" % b + if expr.exp is -S.One: + return "1/%s" % b + + return "%s^(%s)" % (b, self._print(expr.exp)) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/pyproject.toml b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/pyproject.toml new file mode 100644 index 000000000..343467954 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/pyproject.toml @@ -0,0 +1,50 @@ +[project] +name = "latex2sympy2_extended" +version = "1.10.1" +description = "Convert LaTeX math to SymPy expressions" +readme = "README.md" +requires-python = ">=3.10" +license = { text = "MIT" } +authors = [ + { name = "Hynek Kydlíček", email = "hynek.kydlicek@huggingface.co" } +] +dependencies = [ + "sympy", + "antlr4-python3-runtime>=4.9.3,<=4.13.2", +] + +[project.optional-dependencies] +dev = [ + "pytest", +] + +antlr4_9_3 = [ + "antlr4-python3-runtime==4.9.3" +] + +antlr4_11_0 = [ + "antlr4-python3-runtime==4.11.0" +] + +antlr4_13_2 = [ + "antlr4-python3-runtime==4.13.2" +] + +[project.urls] +Homepage = "https://github.com/OrangeX4/latex2sympy2" +Repository = "https://github.com/OrangeX4/latex2sympy2.git" + +[tool.hatch.build] +packages = ["src/latex2sympy2_extended"] + + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["*_test.py"] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/linalg_equations.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/linalg_equations.py new file mode 100644 index 000000000..3ab1539b9 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/linalg_equations.py @@ -0,0 +1,10 @@ +from latex2sympy2_extended import latex2sympy +import sys +sys.path.append("..") + +# latex = "2\\begin{pmatrix}1&1&1\\\\0&1&1\\\\0&0&1\\end{pmatrix}\\begin{pmatrix}1&1&1\\\\0&1&1\\\\0&0&1\\end{pmatrix}" +latex = "\\frac{a^{2} \\left(3 \\pi - 4 \\sin{\\left(\\pi \\right)} + \\frac{\\sin{\\left(2 \\pi \\right)}}{2}\\right)}{2}" +math = latex2sympy(latex) + +print(type(math)) +print("latex: %s to math: %s" % (latex, math)) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/linalg_span.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/linalg_span.py new file mode 100644 index 000000000..8a7c05460 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/linalg_span.py @@ -0,0 +1,19 @@ +from latex2sympy2_extended import latex2sympy +import sys +sys.path.append("..") + +latex = "\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix}" +math = latex2sympy(latex) +print("latex: %s to math: %s" % (latex, math)) + +latex = "\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix},\\begin{pmatrix}4\\\\3\\\\1\\end{pmatrix}" +math = latex2sympy(latex) +print("latex: %s to math: %s" % (latex, math)) + +latex = "[\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix},\\begin{pmatrix}4\\\\3\\\\1\\end{pmatrix}]" +math = latex2sympy(latex) +print("latex: %s to math: %s" % (latex, math)) + +latex = "\\left\\{\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix},\\begin{pmatrix}4\\\\3\\\\1\\end{pmatrix}\\right\\}" +math = latex2sympy(latex) +print("latex: %s to math: %s" % (latex, math)) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/matrix.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/matrix.py new file mode 100644 index 000000000..9aed3c2b1 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/matrix.py @@ -0,0 +1,46 @@ +from latex2sympy import process_sympy +from sympy import * +import sys +sys.path.append("..") + +theta = Symbol('theta', real=True) + +latex = "\\begin{matrix}1&2\\\\3&4\\end{matrix}" +math = process_sympy(latex) +print("latex: %s to math: %s" % (latex, math)) + +latex = "\\begin{matrix}1&2\\\\3&4\\\\5&6\\end{matrix}" +math = process_sympy(latex) +print("latex: %s to math: %s" % (latex, math)) + +latex = "\\begin{matrix}1&2&3\\\\4&5&6\\\\7&8&9\\end{matrix}" +math = process_sympy(latex) +print("latex: %s to math: %s" % (latex, math)) + +latex = "\\begin{matrix}x^1&x^2&x^3\\\\y^1&y^2&y^3\\\\z^1&z^2&z^3\\end{matrix}" +math = process_sympy(latex) +print("latex: %s to math: %s" % (latex, math)) + +latex = "\\begin{matrix}x\\\\y\\end{matrix}" +math = process_sympy(latex) +print("latex: %s to math: %s" % (latex, math)) + +latex = "2\\cdot\\begin{matrix}x\\\\y\\end{matrix}" +math = process_sympy(latex) +print("latex: %s to math: %s" % (latex, math)) + +latex = "2\\cdot\\begin{matrix}x\\\\y\\end{matrix} + \\begin{matrix}2\\\\3\\end{matrix}" +math = process_sympy(latex) +print("latex: %s to math: %s" % (latex, math)) + +latex = "-2\\begin{matrix}1&2\\\\3&4\\end{matrix}" +math = process_sympy(latex) +print("latex: %s to math: %s" % (latex, math)) + +latex = "2\\cdot\\theta\\begin{matrix}x\\\\y\\end{matrix} + \\begin{matrix}2\\\\3\\end{matrix}" +math = process_sympy(latex) +print("latex: %s to math: %s" % (latex, math)) + +latex = "\\theta\\begin{matrix}1\\\\3\\end{matrix} - \\begin{matrix}-1\\\\2\\end{matrix}" +math = process_sympy(latex) +print("latex: %s to math: %s" % (latex, math)) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/matrix_placeholders.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/matrix_placeholders.py new file mode 100644 index 000000000..9d0aef8e2 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/matrix_placeholders.py @@ -0,0 +1,65 @@ +from latex2sympy2_extended import latex2sympy +from sympy import * +import sys +import hashlib +import time + +sys.path.append("..") + + +M = Matrix([[1, 2], [3, 4]]) +v = Matrix([1, 2]) + +# sub settings +sub_settings_symbols = {} +sub_settings_symbols[Symbol('M' + hashlib.md5('M'.encode()).hexdigest(), real=True)] = M +sub_settings_symbols[Symbol('v' + hashlib.md5('v'.encode()).hexdigest(), real=True)] = v + + +# one parameters +latex = "\\begin{matrix}1&2\\\\3&4\\end{matrix}\\cdot[!v!]" +equation_sympy_check = MatMul(M, Symbol('v' + hashlib.md5('v'.encode()).hexdigest(), real=True)) +equation_sympy_subs_check = MatMul(M, v) +# placeholders +equation_sympy = latex2sympy(latex) +print('latex = %s' % latex) +print('equation_sympy = %s' % equation_sympy) +print('equation_sympy_check = %s' % equation_sympy_check) +print('equation_sympy = %s' % (srepr(equation_sympy))) + +equation_sympy_subs = equation_sympy.subs(sub_settings_symbols, evaluate=False) +print('equation_sympy_subs = %s' % equation_sympy_subs) +print('equation_sympy_subs_check = %s' % equation_sympy_subs_check) + + +# two parameters + +# sub settings +print('') +print('============== Two Parameters -> M*v = Matrix*Vector =============') +sub_settings_symbols = {} +sub_settings_symbols[Symbol('M' + hashlib.md5('M'.encode()).hexdigest(), commutative=False)] = M +sub_settings_symbols[Symbol('v' + hashlib.md5('v'.encode()).hexdigest(), commutative=False)] = v + +latex = "[!M!]\\cdot[!v!]" +math_check = Mul(Symbol('M' + hashlib.md5('M'.encode()).hexdigest(), commutative=False), Symbol('v' + hashlib.md5('v'.encode()).hexdigest(), commutative=False)) +# placeholders +equation_sympy = latex2sympy(latex) +print(latex) +print(math_check) +print(equation_sympy) +print(srepr(equation_sympy)) + +# performance +t0 = time.time() + +# process_sympy and substitute at the same time +# Only needed for linalg input +placeholder_values = {'M': M, 'v': v} +equation_sympy_subs = latex2sympy(latex, variable_values=placeholder_values) + +t1 = time.time() +print('equation with substituted placeholders = %s' % (str(equation_sympy_subs))) +print('time to process to sympy with placeholders = %s s' % (t1 - t0)) +print('') +print('============== Two Parameters -> M*v = Matrix*Vector =============') diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/random_sandbox.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/random_sandbox.py new file mode 100644 index 000000000..b8b7bc02c --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/random_sandbox.py @@ -0,0 +1,6 @@ +from latex2sympy2_extended import latex2sympy + +# print(latex2sympy(r"x < 12")) +# print(latex2sympy(r"x \in \mathbb{R}^{n \times m}")) +# print(latex2sympy(r"9 \in {1,2,3}")) +print(latex2sympy(r"\pi")) \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/sandbox.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/sandbox.py new file mode 100644 index 000000000..2c411001e --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/sandbox.py @@ -0,0 +1,23 @@ +from sympy import * +from latex2sympy import process_sympy + + +# latex = '\\variable{a}^{\\variable{b}}' +# variables = {'a': process_sympy('658.95998'), 'b': process_sympy('185083.8060')} +# c_ans_expr = process_sympy(latex, variables) +# print(c_ans_expr) +# print(srepr(c_ans_expr)) +# c_ans = c_ans_expr.doit(deep=False).evalf(chop=True) +# print(c_ans) +# print(srepr(c_ans)) + + +# numeric_responses = ['1', '1.0', '-1', '-1.0', '.5', '-.5', '3x10^3', '3E3', '3,000x10^{-3}', '0.5E-1', '\\frac{1}{3}', '(5\\times 3)^3', '\\sin(1)'] +# for latex in numeric_responses: +# parsed = process_sympy(latex) +# print('latex: ', latex) +# print('sympy: ', parsed) +# print('is_number: ', parsed.is_number) +# print('is_Number: ', parsed.is_Number) +# print('srepr: ', srepr(parsed)) +# print('-----------------------------------------------------') diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/sandbox_equality.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/sandbox_equality.py new file mode 100644 index 000000000..5e3f444c7 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/sandbox_equality.py @@ -0,0 +1,75 @@ +from sympy import * +from latex2sympy import process_sympy + + +# +# Equality Testing +# + +answer_sets = [ + { + 'correct_answer': '(x-y)(x+2y)', + 'student_answers': [ + 'x^2+xy-2y^2', + '(x-y)(x+2y)', + '(x+2y)(x-y)', + '(2\\times y+x)(-y+x)', + '(y\\cdot 2+x)(-y+x)' + ] + }, + { + 'correct_answer': '2\\pi \\variable{r}^2', + 'student_answers': [ + '2\\pi \\variable{r}^2', + '\\pi 2\\variable{r}^2', + '2\\times \\pi \\times \\variable{r}^2', + '2\\pi \\variable{r} \\times \\variable{r}' + ] + }, + { + 'correct_answer': '2x - 3y', + 'student_answers': [ + '-3y + 2x' + ] + }, + { + 'correct_answer': 'x\\times x', + 'student_answers': [ + 'x\\times x', + 'x\\cdot x', + 'x^2', + '(\\sqrt{x})^{4}' + ] + }, + { + 'correct_answer': '23e^{-1\\times \\sqrt{t^2}}', + 'student_answers': [ + '23e^{-t}' + ] + }, + { + 'correct_answer': 'a=x^2+1', + 'student_answers': [ + 'x^2+1=a' + ] + } +] + +for answer_set in answer_sets: + correct_answer = answer_set['correct_answer'] + correct_answer_parsed = process_sympy(answer_set['correct_answer']) + for student_answer in answer_set['student_answers']: + student_answer_parsed = process_sympy(student_answer) + print('correct_answer (c): ', correct_answer, correct_answer_parsed) + print('student_answer (a): ', student_answer, student_answer_parsed) + print('') + print('Expression Tree (srepr(c) == srepr(a)) =>', srepr(correct_answer_parsed) == srepr(student_answer_parsed)) + print('srepr(c) =>', srepr(correct_answer_parsed)) + print('srepr(a) =>', srepr(student_answer_parsed)) + print('') + # print('Structural (c == a) =>', correct_answer_parsed == student_answer_parsed) + print('Symbolic (simplify(c - s) == 0) =>', simplify(correct_answer_parsed - student_answer_parsed) == 0) + print('simplified =>', simplify(correct_answer_parsed - student_answer_parsed)) + print('') + print('Numeric Substitution (c.equals(s)) =>', correct_answer_parsed.equals(student_answer_parsed)) + print('-----------------------------------------------------') diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/sectan.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/sectan.py new file mode 100644 index 000000000..0e0c7aab9 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/sectan.py @@ -0,0 +1,51 @@ +from sympy import * +import sys +sys.path.append("..") + +# # x^2\cdot \left(3\cdot \tan \left([!a!]\cdot x+[!c!]\right)+[!a!]\cdot x\left(\sec \left([!a!]\cdot x+[!c!]\right)\right)^2\right) +# latex1 = "x^2\\cdot \\left(3\\cdot \\tan \\left(2\\cdot x+5\\right)+2\\cdot x\\left(\\sec \\left(2\\cdot x+5\\right)\\right)^2\\right)" +# math1 = process_sympy(latex1) +# print("latex: %s to math: %s" %(latex1,math1)) +# +# latex2 = "x^2\\cdot \\left(3\\cdot \\tan \\left(2\\cdot x+5\\right)+2\\cdot x\\left(\\sec \\left(2\\cdot x+5\\right)^2\\right)\\right)" +# math2 = process_sympy(latex2) +# print("latex: %s to math: %s" %(latex2,math2)) +# +# latex3 = "x^2\\cdot \\left(3\\cdot \\tan \\left(2\\cdot x+5\\right)+2\\cdot x\\left(1+\\tan \\left(2\\cdot x+5\\right)^2\\right)\\right)" +# math3 = process_sympy(latex3) +# print("latex: %s to math: %s" %(latex3,math3)) +# +# print(simplify(math1 - math2)) +# print(simplify(math1 - math3)) + +# +# latex1 = "\\sec^2(2\\cdot x+5)" +# math1 = process_sympy(latex1) +# print("latex: %s to math: %s" %(latex1,math1)) +# +# latex2 = "1+\\tan^2(2\\cdot x+5)" +# math2 = process_sympy(latex2) +# print("latex: %s to math: %s" %(latex2,math2)) +# print(simplify(math1 - math2)) + + +x = Symbol('x', real=True) +y = Symbol('y', real=True) + +# BUG: 1 + tan^2(x+1) should be == sec^2(x+1) but isnt +lhs = (1 + (tan(x + 1))**2) +rhs = (sec(x + 1))**2 +eq = lhs - rhs +print(simplify(lhs)) +print(simplify(rhs)) +print(simplify(eq)) +print(simplify(lhs) == simplify(rhs)) + +# 1 + tan^2(x) == sec^2(x) but isnt +lhs = (1 + (tan(x))**2) +rhs = (sec(x))**2 +eq = lhs - rhs +print(simplify(lhs)) +print(simplify(rhs)) +print(simplify(eq)) +print(simplify(lhs) == simplify(rhs)) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/vector.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/vector.py new file mode 100644 index 000000000..5b48aee56 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/sandbox/vector.py @@ -0,0 +1,75 @@ +import numpy as np +from sympy import * +import sys +sys.path.append("..") + +# row column matrix = vector +v = [1, 2, 3] + +# single column matrix = vector +m = Matrix([1, 2, 3]) +print(m[:, 0]) + +# a three row and 2 column matrix +m = Matrix([[1, 2], [3, 4], [5, 6]]) +print(m[:, 0]) + +# determinant of lin indp system != 0 +m = Matrix([[1, 1], [1, 2]]) +print(m.det()) + +# determinant of lin dep system = 0 +m = Matrix([[1, 1], [2, 2]]) +print(m.det()) + +# determinant of lin dep system = 0 +x = Symbol('x') +y = Symbol('y') +m = Matrix([[x, y], [x, y]]) +print(m.det()) +# Reduced Row-Echelon Form +_, ind = m.rref() +print(len(ind)) + +# determinant of lin dep system != 0 +m = Matrix([[x, y], [y, x]]) +print(m.det()) +# Reduced Row-Echelon Form +_, ind = m.rref() +print(len(ind)) + +# determinant of lin dep system != 0 +# Reduced Row-Echelon Form +m = Matrix([[x, x, y], [y, y, y]]) +_, ind = m.rref() +# Reduced Row-Echelon Form +print(len(ind)) + +#==================# +#===== Numpy ======# +#==================# +# http://kitchingroup.cheme.cmu.edu/blog/2013/03/01/Determining-linear-independence-of-a-set-of-vectors/ +# Lin Indp of set of numerical vectors +TOLERANCE = 1e-14 +v1 = [6, 0, 3, 1, 4, 2] +v2 = [0, -1, 2, 7, 0, 5] +v3 = [12, 3, 0, -19, 8, -11] + +A = np.row_stack([v1, v2, v3]) + +U, s, V = np.linalg.svd(A) +print(s) +print(np.sum(s > TOLERANCE)) + +v1 = [1, 1] +v2 = [4, 4] + +A = np.row_stack([v1, v2]) +U, s, V = np.linalg.svd(A) +print(s) +print(np.sum(s > TOLERANCE)) + + +latex = "\\begin{matrix}1&2\\\\3&4\\end{matrix}" +# math = process_sympy(latex) +print("latex: %s to math: %s" % (latex, 1)) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/compile.sh b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/compile.sh new file mode 100644 index 000000000..255861fe5 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/compile.sh @@ -0,0 +1,15 @@ +#!/bin/sh + +# Get relative path of the root directory of the project +rdir=`git rev-parse --git-dir` +rel_path="$(dirname "$rdir")" +# Change to that path and run the file +cd $rel_path/src/latex2sympy2_extended + +# Create gen directory if it doesn't exist +mkdir -p gen + +# Run ANTLR on the grammar file +java -jar ../../antlr-4.13.2-complete.jar PS.g4 -o gen/antlr4_13_2 +java -jar ../../antlr-4.11.0-complete.jar PS.g4 -o gen/antlr4_11_0 +java -jar ../../antlr-4.9.3-complete.jar PS.g4 -o gen/antlr4_9_3 \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/coverage-ci.sh b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/coverage-ci.sh new file mode 100644 index 000000000..335d4c391 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/coverage-ci.sh @@ -0,0 +1,3 @@ +#!/bin/sh + +pytest --doctest-modules --junitxml=junit/test-results.xml --cov-report=xml --cov-config=.coveragerc --cov=latex2sympy tests \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/coverage.sh b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/coverage.sh new file mode 100644 index 000000000..a704a2148 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/coverage.sh @@ -0,0 +1,23 @@ +#!/bin/sh + +# Get relative path of the root directory of the project +rdir=`git rev-parse --git-dir` +rel_path="$(dirname "$rdir")" +# Change to that path and run the file +cd $rel_path + +# Activate virtual environment +echo "activating venv..." +if test -f .env/bin/activate +then source .env/bin/activate && echo "venv activate (bin)" +elif test -f .env/Scripts/activate +then source .env/Scripts/activate && echo "venv activated (Scripts)" +else exit 1 +fi + +# Run unit test coverage +echo "starting coverage..." +if pytest --doctest-modules --cov-report=html --cov-config=.coveragerc --cov=latex2sympy tests +then echo "coverage finished" +else exit 1 +fi diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/pre-commit b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/pre-commit new file mode 100644 index 000000000..93f8de1b8 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/pre-commit @@ -0,0 +1,31 @@ +#!/bin/sh + +# Get relative path of the root directory of the project +rdir=`git rev-parse --git-dir` +rel_path="$(dirname "$rdir")" + +# Change to that path and run the file +cd $rel_path + +echo "pre-commit hook started..." + +# Activate virtual environment +echo "activating venv..." +if test -f .env/bin/activate +then source .env/bin/activate && echo "venv activated." +elif test -f .env/Scripts/activate +then source .env/Scripts/activate && echo "venv activated." +else exit 1 +fi + +# Run auto formatting on all staged python files, then add those changes +echo "auto-formatting code..." +if autopep8 --in-place `git diff --name-status --cached | grep '.py' | awk 'match($1, "A|M"){print $2}'` && git add `git diff --name-status --cached | grep '.py' | awk 'match($1, "A|M"){print $2}'` +then echo "code was auto-formatted." +else echo "no code was auto-formatted." +fi + +# Run ANTLR to generate Python files +java -jar antlr-4.13.2-complete.jar src/latex2sympy2_extended/PS.g4 -o src/latex2sympy2_extended/gen + +exit 0 diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/pre-push b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/pre-push new file mode 100644 index 000000000..064b49af0 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/pre-push @@ -0,0 +1,27 @@ +#!/bin/sh + +# Get relative path of the root directory of the project +rdir=`git rev-parse --git-dir` +rel_path="$(dirname "$rdir")" + +# Change to that path and run the file +cd $rel_path + +echo "pre-push hook started..." + +# Activate virtual environment +echo "activating venv..." +if test -f .env/bin/activate +then source .env/bin/activate && echo "venv activated." +elif test -f .env/Scripts/activate +then source .env/Scripts/activate && echo "venv activated." +else exit 1 +fi + +# Run ANTLR to generate Python files +java -jar antlr-4.13.2-complete.jar src/latex2sympy2_extended/PS.g4 -o src/latex2sympy2_extended/gen + +# Run tests +python -m pytest tests/ + +exit 0 diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/publish.sh b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/publish.sh new file mode 100644 index 000000000..b21d6c6ec --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/publish.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# Exit on error +set -e + +echo "=== Latex2Sympy2 Extended Package Publisher ===" + +# Function to clean previous builds +clean_builds() { + echo "Cleaning previous builds..." + rm -rf dist/ build/ *.egg-info/ +} + +# Function to build package +build_package() { + echo "Building package..." + python -m build +} + +# Function to upload to PyPI +upload_to_pypi() { + echo "Uploading to PyPI..." + python -m twine upload dist/* +} + +# Main execution +echo "Installing publishing dependencies..." +pip install --upgrade pip build twine + +# Ensure we're in the project root +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd "$SCRIPT_DIR/.." + +# Execute the publishing process +clean_builds +build_package +upload_to_pypi + +echo "Package successfully published to PyPI!" +echo "You can now install it with: pip install latex2sympy2-extended" diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/setup-hooks.sh b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/setup-hooks.sh new file mode 100644 index 000000000..760dfd596 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/setup-hooks.sh @@ -0,0 +1,3 @@ +#!/bin/sh +cp scripts/pre-push .git/hooks/ +cp scripts/pre-commit .git/hooks/ \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/setup.sh b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/setup.sh new file mode 100644 index 000000000..b7e71c937 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/setup.sh @@ -0,0 +1,42 @@ +#!/bin/sh + +# Get relative path of the root directory of the project +rdir=`git rev-parse --git-dir` +rel_path="$(dirname "$rdir")" +# Change to that path and run the file +cd $rel_path + +echo "creating venv..." +if test -d .env +then echo "venv exists" +else python3 -m venv .env && echo "venv created" +fi + +echo '' +# Activate virtual environment +echo "activating venv..." +if test -f .env/bin/activate +then source .env/bin/activate && echo "venv activate (bin)" +elif test -f .env/Scripts/activate +then source .env/Scripts/activate && echo "venv activated (Scripts)" +else exit 1 +fi + +echo '' +echo "installing requirements..." +if pip install -r dev-requirements.txt +then echo "requirements installed" +else exit 1 +fi + +echo '' +echo "compiling parser..." +sh scripts/compile.sh +echo "parser compiled" + +echo '' +echo "setup git hooks..." +sh scripts/setup-hooks.sh +echo "git hooks setup" + +exit 0 diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/test.sh b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/test.sh new file mode 100644 index 000000000..20d6b0191 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/scripts/test.sh @@ -0,0 +1,31 @@ +#!/bin/sh + +# Get relative path of the root directory of the project +rdir=`git rev-parse --git-dir` +rel_path="$(dirname "$rdir")" +# Change to that path and run the file +cd $rel_path + +# Activate virtual environment +echo "activating venv..." +if test -f .env/bin/activate +then source .env/bin/activate && echo "venv activate (bin)" +elif test -f .env/Scripts/activate +then source .env/Scripts/activate && echo "venv activated (Scripts)" +else exit 1 +fi + +echo '' +echo "compiling parser..." +sh scripts/compile.sh +echo "parser compiled" + +echo '' +# Run unit tests +echo "starting tests..." +if pytest tests +then echo "tests finished" +else exit 1 +fi + +exit 0 diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/PS.g4 b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/PS.g4 new file mode 100644 index 000000000..799693281 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/PS.g4 @@ -0,0 +1,748 @@ +grammar PS; + +options { + language=Python3; +} + +// Lexer rules + +WS: [ \t\r\n]+ -> skip; +// Spaces +THINSPACE: ('\\,' | '\\thinspace') -> skip; +MEDSPACE: ('\\:' | '\\medspace') -> skip; +THICKSPACE: ('\\;' | '\\thickspace') -> skip; +QUAD: '\\quad' -> skip; +QQUAD: '\\qquad' -> skip; +NEGTHINSPACE: ('\\!' | '\\negthinspace') -> skip; +NEGMEDSPACE: '\\negmedspace' -> skip; +NEGTHICKSPACE: '\\negthickspace' -> skip; +DOLLAR_SIGN: '\\$' -> skip; + +IGNORE: + ( + '\\vrule' + | '\\vcenter' + | '\\vbox' + | '\\vskip' + | '\\vspace' + | '\\hfil' + ) -> skip; + +ADD: '+' | '∔' | '⊕' | '⊞'; +SUB: '-' | '−' | '∸'; +MUL: '*' | '×' | '⋅' | '⋄' | '⊡' | '∗' | '⋆' | '∘' | '∙' | '⊗' | '⊠' | '⋈' | '⋉' | '⋊'; +DIV: '/' | '÷' | '∕' | '\\over'; + +L_PAREN: '('; +R_PAREN: ')'; +L_PAREN_VISUAL: '\\('; +R_PAREN_VISUAL: '\\)'; +L_GROUP: '\\lgroup'; +R_GROUP: '\\rgroup'; +L_BRACE: '{'; +R_BRACE: '}'; +L_BRACE_VISUAL: '\\{'; +R_BRACE_VISUAL: '\\}'; +L_BRACE_CMD: '\\lbrace'; +R_BRACE_CMD: '\\rbrace'; +L_BRACKET: '['; +R_BRACKET: ']'; +L_BRACK: '\\lbrack'; +R_BRACK: '\\rbrack'; + +PHANTOM_CMD: '\\phantom'; +BOXED_CMD: '\\boxed'; + +BAR: '|'; +L_VERT: '\\lvert'; +R_VERT: '\\rvert'; +VERT: '\\vert'; + +NORM: '\\|'; + +// Dot products +L_ANGLE: '\\langle'; +R_ANGLE: '\\rangle'; + + +L_FLOOR: '\\lfloor'; +R_FLOOR: '\\rfloor'; +LL_CORNER: '\\llcorner'; +LR_CORNER: '\\lrcorner'; + +L_CEIL: '\\lceil'; +R_CEIL: '\\rceil'; +UL_CORNER: '\\ulcorner'; +UR_CORNER: '\\urcorner'; + +//functions +FUNC_LIM: '\\lim'; +LIM_APPROACH_SYM: '\\to' | '\\rightarrow' | '\\Rightarrow' | '\\longrightarrow' | '\\Longrightarrow'; +FUNC_INT: '\\int'; +FUNC_SUM: '\\sum'; +FUNC_PROD: '\\prod'; + +FUNC_LOG: '\\log'; +FUNC_LN: '\\ln'; +FUNC_EXP: '\\exp'; +FUNC_SIN: '\\sin'; +FUNC_COS: '\\cos'; +FUNC_TAN: '\\tan'; +FUNC_CSC: '\\csc'; +FUNC_SEC: '\\sec'; +FUNC_COT: '\\cot'; + +FUNC_ARCSIN: '\\arcsin'; +FUNC_ARCCOS: '\\arccos'; +FUNC_ARCTAN: '\\arctan'; +FUNC_ARCCSC: '\\arccsc'; +FUNC_ARCSEC: '\\arcsec'; +FUNC_ARCCOT: '\\arccot'; + +FUNC_SINH: '\\sinh'; +FUNC_COSH: '\\cosh'; +FUNC_TANH: '\\tanh'; +FUNC_ARSINH: '\\arsinh'; +FUNC_ARCOSH: '\\arcosh'; +FUNC_ARTANH: '\\artanh'; +FUNC_ARCSINH: '\\arcsinh'; +FUNC_ARCCOSH: '\\arccosh'; +FUNC_ARCTANH: '\\arctanh'; + +FUNC_ARSINH_NAME: 'arsinh'; +FUNC_ARCSINH_NAME: 'arcsinh'; +FUNC_ARCOSH_NAME: 'arcosh'; +FUNC_ARCCOSH_NAME: 'arccosh'; +FUNC_ARTANH_NAME: 'artanh'; +FUNC_ARCTANH_NAME: 'arctanh'; +FUNC_GCD_NAME: 'gcd'; +FUNC_LCM_NAME: 'lcm'; +FUNC_FLOOR_NAME: 'floor'; +FUNC_CEIL_NAME: 'ceil'; + +FUNC_SQRT: '\\sqrt'; +FUNC_GCD: '\\gcd'; +FUNC_LCM: '\\lcm'; +FUNC_FLOOR: '\\floor'; +FUNC_CEIL: '\\ceil'; +FUNC_MAX: '\\max'; +FUNC_MIN: '\\min'; + +FUNC_DET: '\\det'; + +FUNC_EYE_NAME: 'eye'; +FUNC_ZEROS_NAME: 'zeros'; +FUNC_ONES_NAME: 'ones'; +FUNC_COLS_NAME: 'cols'; +FUNC_ROWS_NAME: 'rows'; +FUNC_DIAG_NAME: 'diag'; +FUNC_NORM_NAME: 'norm'; +FUNC_RANK_NAME: 'rank'; +FUNC_TRACE_NAME: 'trace' | 'tr'; +FUNC_RREF_NAME: 'rref'; +FUNC_HSTACK_NAME: 'hstack'; +FUNC_VSTACK_NAME: 'vstack'; +FUNC_ORTHOGONALIZE_NAME: 'orth' | 'ortho' | 'orthogonal' | 'orthogonalize'; +FUNC_NULLSPACE_NAME: 'nullspace'; +FUNC_DIAGONALIZE_NAME: 'eig' | 'eigen' | 'diagonalize'; +FUNC_EIGENVALS_NAME: 'eigenvals' | 'eigenvalues'; +FUNC_EIGENVECTORS_NAME: 'eigenvects' | 'eigenvectors'; +FUNC_SVD_NAME: 'svd' | 'SVD'; + +//commands +CMD_TIMES: '\\times'; +CMD_CDOT: '\\cdot'; +CMD_DIV: '\\div'; +CMD_FRAC: '\\frac' | '\\dfrac' | '\\tfrac' | '\\cfrac'; +CMD_BINOM: '\\binom' | '\\tbinom' | '\\dbinom'; +CMD_CHOOSE: '\\choose'; +CMD_MOD: '\\mod'; + +CMD_MATHIT: '\\mathit'; + +CMD_OPERATORNAME: '\\operatorname'; + +//matrix test + +MATRIX_TYPE_MATRIX: 'matrix'; +MATRIX_TYPE_PMATRIX: 'pmatrix'; +MATRIX_TYPE_BMATRIX: 'bmatrix'; +MATRIX_TYPE_DET: 'vmatrix'; +MATRIX_TYPES: MATRIX_TYPE_MATRIX | MATRIX_TYPE_PMATRIX | MATRIX_TYPE_BMATRIX; +CMD_MATRIX_START: '\\begin' L_BRACE MATRIX_TYPES R_BRACE; +CMD_MATRIX_END: '\\end' L_BRACE MATRIX_TYPES R_BRACE; + +CMD_ARRAY_START: '\\begin' L_BRACE 'array' R_BRACE L_BRACE ('c' | 'l' | 'r')* R_BRACE; +CMD_ARRAY_END: '\\end' L_BRACE 'array' R_BRACE; + +CMD_DET_START: '\\begin' L_BRACE MATRIX_TYPE_DET R_BRACE; +CMD_DET_END: '\\end' L_BRACE MATRIX_TYPE_DET R_BRACE; +MATRIX_DEL_COL: '&'; +MATRIX_DEL_ROW: '\\\\'; + +UNDERSCORE: '_'; +CARET: '^'; +COLON: ':'; +SEMICOLON: ';'; +COMMA: ','; +PERIOD: '.'; + +fragment WS_CHAR: [ \t\r\n]; +DIFFERENTIAL: 'd' WS_CHAR*? ([a-zA-Z] | '\\' [a-zA-Z]+); + +EXP_E: 'e' | '\\exponentialE'; +E_NOTATION_E: 'E'; +LETTER_NO_E: [a-df-zA-DF-Z]; // exclude e for exponential function and e notation +fragment LETTER: [a-zA-Z]; +fragment DIGIT: [0-9]; + +MATRIX_XRIGHTARROW: '\\xrightarrow' | '\\xRightarrow'; +TRANSFORM_EXCHANGE: '<->' | '<=>' | '\\leftrightarrow' | '\\Leftrightarrow'; + +// There is ofc huge issue with the fact that we can't distinguish between numbers and sets +// in case of 3,333 can be either 3333 or {3,333}, we default to numbers in this case +// Fixing this is a huge task and I am not going to do it +// If you want to interpret it as set change * to *? +NUMBER: + DIGIT+ + | DIGIT* PERIOD DIGIT+; + +E_NOTATION: NUMBER E_NOTATION_E (SUB | ADD)? DIGIT+; + +IN: '\\in'; +ASSIGNMENT: '=' | '≡' | '≣' | '≟' | '≎' | '≏' | '≐' | '≑' | '≒' | '≓' | '≔' | '≕' | '≖' | '≗'; +EQUAL: '==' | '\\equiv'; +APPROX: '~=' | '≈' | '≊' | '≋' | '≃' | '≅' | '\\approx'; +LT: '<' | '≪' | '≺' | '⋖' | '\\lt'; +LTE: '\\leq' | '\\le' | '\\leqslant' | '≤' | '≦' | '≲' | '≾' | '<='; +GT: '>' | '≫' | '≻' | '⋗' | '\\gt'; +GTE: '\\geq' | '\\ge' | '\\geqslant' | '≥' | '≧' | '≳' | '≿' | '>='; +UNEQUAL: '!=' | '!==' | '\\ne' | '\\neq' | '\\not\\equiv' | '≠' | '≁' | '≄' | '≇' | '≉' | '≢'; + +BANG: '!'; + +fragment PERCENT_SIGN: '\\%' | '%'; +PERCENT_NUMBER: NUMBER WS_CHAR? PERCENT_SIGN; + +//Excludes some letters for use as e.g. constants in SYMBOL +fragment GREEK_LETTER: + '\\alpha' | 'α' | '\\char"000391' | '\\char000391' + | '\\beta' | 'β' | '\\char"000392' | '\\char000392' + | '\\gamma' | 'γ' + | '\\Gamma' | 'Γ' + | '\\delta' | 'δ' + | '\\Delta' | 'Δ' + | '\\epsilon' | 'ε' | '\\char"000190' | '\\char000190' + | '\\varepsilon' | 'ϵ' + | '\\zeta' | 'ζ' | '\\char"000396' | '\\char000396' + | '\\eta' | 'η' | '\\char"000397' | '\\char000397' + | '\\theta' | 'θ' + | '\\Theta' | 'Θ' + | '\\vartheta' | 'ϑ' + | '\\iota' | 'ι' | '\\char"000399' | '\\char000399' + | '\\kappa' | 'κ' | '\\char"00039A' | '\\char00039A' + | '\\lambda' | 'λ' + | '\\Lambda' | 'Λ' + | '\\mu' | 'μ' | '\\char"00039C' | '\\char00039C' + | '\\nu' | 'ν' | '\\char"00039D' | '\\char00039D' + | '\\xi' | 'ξ' + | '\\Xi' | 'Ξ' + | '\\omicron' | 'ο' | '\\char"00039F' | '\\char00039F' + | '\\pi' | 'π' + | '\\Pi' | 'Π' + | '\\varpi' | 'ϖ' + | '\\rho' | 'ρ' | '\\char"0003A1' | '\\char0003A1' + | '\\varrho' | 'ϱ' + | '\\sigma' | 'σ' + | '\\Sigma' | 'Σ' + | '\\varsigma' | 'ς' + | '\\tau' | 'τ' | '\\char"0003A4' | '\\char0003A4' + | '\\upsilon' | 'υ' + | '\\Upsilon' | 'Υ' + | '\\phi' | 'φ' + | '\\Phi' | 'Φ' + | '\\varphi' | 'ϕ' + | '\\chi' | 'χ' | '\\char"0003A7' | '\\char0003A7' + | '\\psi' | 'ψ' + | '\\Psi' | 'Ψ' + | '\\omega' | 'ω' + | '\\Omega' | 'Ω' + ; + +GREEK_CMD: GREEK_LETTER [ ]?; + +fragment OTHER_SYMBOL: + '\\Bbbk' | + '\\wp' | + '\\nabla' | + '\\bigstar' | + '\\angle' | + '\\nexists' | + '\\diagdown' | + '\\measuredangle' | + '\\eth' | + 'ℵ' | + 'ℶ' | + 'ℷ' | + 'ℸ' | + '\\diagup' | + '\\sphericalangle' | + '\\clubsuit' | + '\\varnothing' | + '\\Diamond' | + '\\complement' | + '\\diamondsuit' | + '\\imath' | + '\\Finv' | + '\\triangledown' | + '\\heartsuit' | + '\\jmath' | + '\\Game' | + '\\triangle' | + '\\spadesuit' | + '\\ell' | + '\\hbar' | + '\\vartriangle' | + '\\hslash' | + '\\blacklozenge' | + '\\lozenge' | + '\\blacksquare' | + '\\mho' | + '\\blacktriangle' | + '\\sharp' | + '\\prime' | + '\\Im' | + '\\flat' | + '\\square' | + '\\backprime' | + '\\Re' | + '\\natural' | + '\\surd' | + '\\circledS'; +OTHER_SYMBOL_CMD: OTHER_SYMBOL [ ]?; + +fragment INFTY_CMD: '\\infty'; +fragment PARTIAL_CMD: '\\partial'; +fragment INFTY: INFTY_CMD | DOLLAR_SIGN INFTY_CMD | INFTY_CMD PERCENT_SIGN; +SYMBOL: PARTIAL_CMD | INFTY ; + +fragment VARIABLE_CMD: '\\variable'; +fragment VARIABLE_SYMBOL: (GREEK_CMD | OTHER_SYMBOL_CMD | LETTER | DIGIT)+ (UNDERSCORE ((L_BRACE (GREEK_CMD | OTHER_SYMBOL_CMD | LETTER | DIGIT | COMMA)+ R_BRACE) | (GREEK_CMD | OTHER_SYMBOL_CMD | LETTER | DIGIT)))?; +VARIABLE: VARIABLE_CMD L_BRACE VARIABLE_SYMBOL R_BRACE PERCENT_SIGN?; + +SET_NATURALS: '\\mathbb{N}' | 'ℕ'; +SET_INTEGERS: '\\mathbb{Z}' | 'ℤ'; +SET_RATIONALS: '\\mathbb{Q}' | 'ℚ'; +SET_REALS: '\\mathbb{R}' | 'ℝ'; +SET_COMPLEX: '\\mathbb{C}' | 'ℂ'; +SET_PRIMES: '\\mathbb{P}' | 'ℙ'; + +//collection of accents +fragment ACCENT_CMD: + '\\acute' | + '\\bar' | + '\\overline' | + '\\breve' | + '\\check' | + '\\widecheck' | + '\\dot' | + '\\ddot' | + '\\grave' | + '\\hat' | + '\\tilde' | + '\\widetilde' | + '\\vec' | + '\\overrightarrow' | + '\\bm' | + '\\boldsymbol' | + '\\text' | + '\\textit' | + '\\textbf' | + '\\textnormal' | + '\\mathbb' | + '\\mathbin' | + '\\mathbf' | + '\\mathcal' | + '\\mathclap' | + '\\mathclose' | + '\\mathellipsis' | + '\\mathfrak' | + '\\mathinner' | + '\\mathit' | + '\\mathnormal' | + '\\mathop' | + '\\mathopen' | + '\\mathord' | + '\\mathpunct' | + '\\mathrel' | + '\\mathring' | + '\\mathrlap' | + '\\mathrm' | + '\\mathscr' | + '\\mathsf' | + '\\mathsterling' | + '\\mathtt' | + '\\mbox'; + +ACCENT: ACCENT_CMD WS_CHAR? L_BRACE ('\\}' | ~[}])*? R_BRACE; + +// Set operations (small subsetion) +UNION: '\\cup' | '∪'; +INTERSECTION: '\\cap' | '∩'; +SET_MINUS: '\\setminus' | '∖'; +PLUS_MINUS: '\\pm' | '±' | '∓' | '\\mp'; +// We can't add {} to the empty set as otherwise any empty braces will be lexed as empty set +SET_EMPTY: '\\emptyset' | '∅' | L_BRACE_VISUAL R_BRACE_VISUAL | L_BRACE_CMD R_BRACE_CMD; + +SUPSET: '\\supseteq' | '⊇'; +SUBSET: '\\subseteq' | '⊆'; +NOTIN: '\\notin' | '∉'; + +// Grammar rules + + +// We also have set elements so that 1,2,3,4 is parsed as a set +math: (set_elements_relation | set_elements | set_relation) EOF; + +transpose: '^T' | '^{T}' | '^{\\\top}' | '\''; +degree: '^\\circ' | '^\\degree' | '^\\circle' | '^°' | '^{\\circ}' | '^{\\degree}' | '^{\\circle}' | '^{°}'; + +transform_atom: LETTER_NO_E UNDERSCORE (NUMBER | L_BRACE NUMBER R_BRACE); +transform_scale: (expr | group | ADD | SUB) transform_atom; +transform_swap: transform_atom TRANSFORM_EXCHANGE transform_atom; +transform_assignment: transform_atom transform_scale; +elementary_transform: transform_assignment | transform_scale | transform_swap; +elementary_transforms: elementary_transform (COMMA elementary_transform)*; + +matrix: + (CMD_MATRIX_START + matrix_row (MATRIX_DEL_ROW matrix_row)* MATRIX_DEL_ROW? + CMD_MATRIX_END | CMD_ARRAY_START + matrix_row (MATRIX_DEL_ROW matrix_row)* MATRIX_DEL_ROW? + CMD_ARRAY_END) + (MATRIX_XRIGHTARROW (L_BRACKET elementary_transforms R_BRACKET)? L_BRACE elementary_transforms R_BRACE)?; + +det: + CMD_DET_START + matrix_row (MATRIX_DEL_ROW matrix_row)* MATRIX_DEL_ROW? + CMD_DET_END; + +matrix_row: + expr (MATRIX_DEL_COL expr)*; + +relation: + relation (IN | ASSIGNMENT | EQUAL | APPROX | LT | LTE | GT | GTE | UNEQUAL) relation + | expr; + +equality: + expr (EQUAL | ASSIGNMENT) expr; + +expr: additive; + +additive: + additive (ADD | SUB) additive + | mp; + +// mult part +mp: + mp (MUL | CMD_TIMES | CMD_CDOT | DIV | CMD_DIV | COLON | CMD_MOD) mp + | unary; + +mp_nofunc: + mp_nofunc (MUL | CMD_TIMES | CMD_CDOT | DIV | CMD_DIV | COLON | CMD_MOD) mp_nofunc + | unary_nofunc; + +unary: + (ADD | SUB) unary + | postfix+; + +unary_nofunc: + (ADD | SUB) unary_nofunc + | postfix postfix_nofunc*; + +postfix: exp postfix_op*; +postfix_nofunc: exp_nofunc postfix_op*; +postfix_op: BANG | eval_at | transpose | degree; + +eval_at: + BAR (eval_at_sup | eval_at_sub | eval_at_sup eval_at_sub); + +eval_at_sub: + UNDERSCORE L_BRACE + (expr | equality) + R_BRACE; + +eval_at_sup: + CARET L_BRACE + (expr | equality) + R_BRACE; + +exp: + exp CARET (atom | L_BRACE expr R_BRACE) subexpr? + | comp; + +exp_nofunc: + exp_nofunc CARET (atom | L_BRACE expr R_BRACE) subexpr? + | comp_nofunc; + +comp: + group + | formatting_group + | norm_group + | abs_group + | dot_product + | floor_group + | ceil_group + | func + | atom + | frac + | binom + | matrix + | det; + +comp_nofunc: + group + | formatting_group + | norm_group + | abs_group + | dot_product + | floor_group + | ceil_group + | atom + | frac + | binom + | matrix + | det; + +group: + L_PAREN expr R_PAREN + | L_GROUP expr R_GROUP + | L_BRACE expr R_BRACE + | L_BRACKET expr R_BRACKET + | L_BRACE_VISUAL expr R_BRACE_VISUAL + | L_BRACE_CMD expr R_BRACE_CMD + | L_BRACK expr R_BRACK; + +formatting_group: + PHANTOM_CMD L_BRACE expr R_BRACE + | BOXED_CMD L_BRACE expr R_BRACE; + + +norm_group: + NORM expr NORM; + + +abs_group: + BAR expr BAR + | L_VERT expr R_VERT + | VERT expr VERT; + + +dot_product: + L_ANGLE expr R_ANGLE; + + + +floor_group: + L_FLOOR expr R_FLOOR + | LL_CORNER expr LR_CORNER; + + +ceil_group: + L_CEIL expr R_CEIL + | UL_CORNER expr UR_CORNER; + + + + +atom_expr_no_supexpr: (LETTER_NO_E | GREEK_CMD | OTHER_SYMBOL_CMD | ACCENT) subexpr?; +atom_expr: (LETTER_NO_E | GREEK_CMD | OTHER_SYMBOL_CMD | ACCENT) (supexpr subexpr | subexpr supexpr | subexpr | supexpr)?; +atom_expr_list: (L_PAREN atom_expr (COMMA atom_expr)* R_PAREN) | atom_expr; +number_subexpr: (NUMBER) subexpr?; +atom: atom_expr | SYMBOL | number_subexpr | PERCENT_NUMBER | E_NOTATION | E_NOTATION_E | DIFFERENTIAL | VARIABLE; + +frac: + CMD_FRAC L_BRACE + upper=expr + R_BRACE L_BRACE + lower=expr + R_BRACE; + +//a binomial expression +binom: + L_BRACE upper=expr CMD_CHOOSE lower=expr R_BRACE + | CMD_BINOM L_BRACE upper=expr R_BRACE L_BRACE lower=expr R_BRACE; + +func_normal_functions_single_arg: + FUNC_LOG | FUNC_LN | FUNC_EXP + | FUNC_SIN | FUNC_COS | FUNC_TAN + | FUNC_CSC | FUNC_SEC | FUNC_COT + | FUNC_ARCSIN | FUNC_ARCCOS | FUNC_ARCTAN + | FUNC_ARCCSC | FUNC_ARCSEC | FUNC_ARCCOT + | FUNC_SINH | FUNC_COSH | FUNC_TANH + | FUNC_ARSINH | FUNC_ARCOSH | FUNC_ARTANH + | FUNC_ARCSINH | FUNC_ARCCOSH | FUNC_ARCTANH + | FUNC_FLOOR | FUNC_CEIL | FUNC_DET; + +func_normal_functions_multi_arg: + FUNC_GCD | FUNC_LCM | FUNC_MAX | FUNC_MIN; + +func_operator_names_single_arg: + FUNC_ARSINH_NAME | FUNC_ARCOSH_NAME | FUNC_ARTANH_NAME + | FUNC_ARCSINH_NAME | FUNC_ARCCOSH_NAME | FUNC_ARCTANH_NAME + | FUNC_FLOOR_NAME | FUNC_CEIL_NAME | FUNC_EYE_NAME | FUNC_RANK_NAME | FUNC_TRACE_NAME + | FUNC_RREF_NAME | FUNC_NULLSPACE_NAME | FUNC_DIAGONALIZE_NAME | FUNC_NORM_NAME + | FUNC_EIGENVALS_NAME | FUNC_EIGENVECTORS_NAME | FUNC_SVD_NAME | FUNC_COLS_NAME | FUNC_ROWS_NAME; + +func_operator_names_multi_arg: + FUNC_GCD_NAME | FUNC_LCM_NAME | FUNC_ZEROS_NAME | FUNC_ORTHOGONALIZE_NAME + | FUNC_ONES_NAME | FUNC_DIAG_NAME | FUNC_HSTACK_NAME | FUNC_VSTACK_NAME; + +func_normal_single_arg: + (func_normal_functions_single_arg) + | + (CMD_OPERATORNAME L_BRACE func_operator_name=func_operator_names_single_arg R_BRACE); + +func_normal_multi_arg: + (func_normal_functions_multi_arg) + | + (CMD_OPERATORNAME L_BRACE func_operator_name=func_operator_names_multi_arg R_BRACE); + +func: + func_normal_single_arg + (subexpr? supexpr? | supexpr? subexpr?) + ( + L_PAREN func_single_arg R_PAREN | + func_single_arg_noparens + ) + + | func_normal_multi_arg + (subexpr? supexpr? | supexpr? subexpr?) + ( + L_PAREN func_multi_arg R_PAREN | + func_multi_arg_noparens + ) + | atom_expr_no_supexpr supexpr? + ( + L_PAREN func_common_args R_PAREN | + L_BRACKET func_common_args R_BRACKET | + L_BRACE L_PAREN func_common_args R_PAREN R_BRACE | + L_BRACE L_BRACKET func_common_args R_BRACKET R_BRACE + ) + | FUNC_INT + (subexpr supexpr | supexpr subexpr | (UNDERSCORE L_BRACE R_BRACE) (CARET L_BRACE R_BRACE) | (CARET L_BRACE R_BRACE) (UNDERSCORE L_BRACE R_BRACE) )? + (additive? DIFFERENTIAL | frac | additive) + + | FUNC_SQRT + (L_BRACKET root=expr R_BRACKET)? + L_BRACE base=expr R_BRACE + + | (FUNC_SUM | FUNC_PROD) + (subeq supexpr | supexpr subeq) + mp + | FUNC_LIM limit_sub mp + | EXP_E supexpr?; //Exponential function e^x + +args: (expr ',' args) | expr; + +func_common_args: atom | (expr ',') | (expr ',' args); + +limit_sub: + UNDERSCORE L_BRACE + (LETTER_NO_E | GREEK_CMD | OTHER_SYMBOL_CMD) + LIM_APPROACH_SYM + expr (CARET L_BRACE (ADD | SUB) R_BRACE)? + R_BRACE; + +func_single_arg: expr; +func_single_arg_noparens: mp_nofunc; + +func_multi_arg: expr | (expr ',' func_multi_arg); +func_multi_arg_noparens: mp_nofunc; + +subexpr: UNDERSCORE (atom | L_BRACE (expr | args) R_BRACE); +supexpr: CARET (atom | L_BRACE expr R_BRACE); + +subeq: UNDERSCORE L_BRACE equality R_BRACE; +supeq: UNDERSCORE L_BRACE equality R_BRACE; + +set_relation: + set_relation (SUBSET | SUPSET) set_relation | + atom_expr_list (IN | NOTIN | ASSIGNMENT) set_relation | + minus_expr; + +minus_expr: + minus_expr SET_MINUS minus_expr | + union_expr; + +union_expr: + union_expr UNION union_expr | + intersection_expr; + +intersection_expr: + intersection_expr INTERSECTION intersection_expr | + set_group; + +set_group: + L_PAREN minus_expr R_PAREN + | set_atom; + +set_atom: + interval | + literal_set | + ordered_tuple | + finite_set; + + +interval: + (L_BRACKET | L_PAREN | L_PAREN_VISUAL | L_BRACK | L_GROUP) + expr COMMA expr + (R_BRACKET | R_PAREN | R_PAREN_VISUAL | R_BRACK | R_GROUP); + +ordered_tuple: + (L_PAREN semicolon_elements R_PAREN) | + (L_PAREN_VISUAL semicolon_elements R_PAREN_VISUAL) | + (L_BRACKET semicolon_elements R_BRACKET); + + +// We allow boxed a set for reason that llms like to use , in boxed +finite_set: + (L_BRACE semicolon_elements R_BRACE) | + (L_BRACE_VISUAL semicolon_elements R_BRACE_VISUAL) | + (BOXED_CMD L_BRACE semicolon_elements R_BRACE); + + + +// We need two targets so that: +// a=1,b=2 is parsed as {a=1, b=2} +// while +// a=1,2,3,4 is parsed as {a=1,2,3,4} +set_elements_relation: + atom_expr_list (IN | ASSIGNMENT) semicolon_elements_no_relation; + +set_elements: + semicolon_elements; + +semicolon_elements: + comma_elements (SEMICOLON comma_elements)*; + +semicolon_elements_no_relation: + comma_elements_no_relation (SEMICOLON comma_elements_no_relation)*; + +comma_elements: + element (COMMA element)*; + +comma_elements_no_relation: + element_no_relation (COMMA element_no_relation)*; + +element_no_relation: + plus_minus_expr | set_atom | expr; + +element: + plus_minus_expr | set_atom | relation; + +plus_minus_expr: + expr PLUS_MINUS expr | PLUS_MINUS expr; + +literal_set: + SET_NATURALS | SET_INTEGERS | SET_RATIONALS | SET_REALS | SET_COMPLEX | SET_PRIMES | SET_EMPTY | L_BRACE R_BRACE; diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/__init__.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/__init__.py new file mode 100644 index 000000000..1458434bc --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/__init__.py @@ -0,0 +1,5 @@ +from .latex2sympy2 import latex2sympy +from .math_normalization import normalize_latex, NormalizationConfig +from .latex2sympy2 import is_expr_of_only_symbols + +__all__ = ['latex2sympy', 'normalize_latex', 'NormalizationConfig', 'is_expr_of_only_symbols'] diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/antlr_parser.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/antlr_parser.py new file mode 100644 index 000000000..5a2b514ca --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/antlr_parser.py @@ -0,0 +1,21 @@ +from importlib.metadata import version, PackageNotFoundError + +try: + antlr_version = version("antlr4-python3-runtime") +except PackageNotFoundError: + antlr_version = "" + +if antlr_version.startswith("4.13.2"): + from latex2sympy2_extended.gen.antlr4_13_2.PSParser import PSParser + from latex2sympy2_extended.gen.antlr4_13_2.PSLexer import PSLexer +elif antlr_version.startswith("4.11"): + from latex2sympy2_extended.gen.antlr4_11_0.PSParser import PSParser + from latex2sympy2_extended.gen.antlr4_11_0.PSLexer import PSLexer +elif antlr_version.startswith("4.9.3"): + from latex2sympy2_extended.gen.antlr4_9_3.PSParser import PSParser + from latex2sympy2_extended.gen.antlr4_9_3.PSLexer import PSLexer +else: + raise ImportError( + f"Unsupported ANTLR version {antlr_version}, " + "only 4.9.3, 4.11.0, and 4.13.2 runtime versions are supported." + ) \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/latex2sympy2.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/latex2sympy2.py new file mode 100644 index 000000000..57a460bd6 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/latex2sympy2.py @@ -0,0 +1,1458 @@ +from dataclasses import dataclass +import sympy +import re +from sympy import Basic, Matrix, MatrixBase, Number, Pow, Rational, matrix_symbols, simplify, factor, expand, apart, expand_trig +from antlr4 import InputStream, CommonTokenStream +from antlr4.error.ErrorListener import ErrorListener +from latex2sympy2_extended.symbols import get_symbol +from latex2sympy2_extended.math_normalization import normalize_latex, NormalizationConfig +from latex2sympy2_extended.antlr_parser import PSParser, PSLexer +import sympy.functions.elementary.trigonometric as sympy_trig +import sympy.functions.elementary.hyperbolic as sympy_hyperbolic +import sympy.functions.elementary.miscellaneous as sympy_misc +import sympy.functions.elementary.integers as sympy_integers +from sympy.core.relational import Relational +from sympy.printing.str import StrPrinter +from sympy.matrices import GramSchmidt +from latex2sympy2_extended.sets import FiniteSet +from latex2sympy2_extended.logic import And +from sympy.parsing.sympy_parser import parse_expr + +@dataclass(frozen=True) +class ConversionConfig: + interpret_as_mixed_fractions: bool = True + interpret_simple_eq_as_assignment: bool = False + interpret_contains_as_eq: bool = True + lowercase_symbols: bool = False + """ + Args: + interpret_as_mixed_fractions (bool): Whether to interpert 2 \frac{1}{2} as 2/2 or 2 + 1/2 + interpret_simple_eq_as_assignment (bool): Whether to interpret simple equations as assignments k=1 -> 1 + interpret_contains_as_eq (bool): Whether to interpret contains as equality x \\in {1,2,3} -> x = {1,2,3} + lowercase_symbols (bool): Whether to lowercase all symbols + """ + + +def flatten_list(l): + return [item for sublist in l for item in sublist] + +def convert_number(number: str): + # If it's 0,111 it's a float + if "," in number and number.startswith("0"): + number = number.replace(",", ".") + + integer = number.translate(str.maketrans("", "", ", ")).lstrip("0") + if len(integer) == 0: + integer = "0" + return Number(integer) + +def is_expr_of_only_symbols(expr): + if hasattr(expr, 'is_Symbol') and expr.is_Symbol: + return True + + # To allow A/S + if hasattr(expr, 'is_Pow') and expr.is_Pow and expr.args[1] == -1 and ( + hasattr(expr.args[0], 'is_Symbol') and expr.args[0].is_Symbol + or hasattr(expr.args[0], 'args') and all(is_expr_of_only_symbols(arg) for arg in expr.args[0].args) + ): + return True + + if hasattr(expr, 'args') and len(expr.args) > 0: + return all(is_expr_of_only_symbols(arg) for arg in expr.args) + return False + + +comma_number_regex = re.compile(r'^\s*-?\d{1,3}(,\d{3})+(\.\d+)?\s*$') + +class _Latex2Sympy: + def __init__(self, variable_values: dict | None = None, is_real=None, convert_degrees: bool = False, config: ConversionConfig = ConversionConfig()): + # Instance variables + self.is_real = is_real + self.variances = {} # For substituting + self.var = {var:val if isinstance(val, Basic) or isinstance(val, MatrixBase) else parse_expr(val) for var, val in variable_values.items()} if variable_values else {} + self.convert_degrees = convert_degrees + self.config = config + + def create_parser(self, latex_str): + """Create parser for latex string""" + stream = InputStream(latex_str) + lex = PSLexer(stream) + lex.removeErrorListeners() + lex.addErrorListener(self.MathErrorListener(latex_str)) + tokens = CommonTokenStream(lex) + parser = PSParser(tokens) + parser.removeErrorListeners() + parser.addErrorListener(self.MathErrorListener(latex_str)) + return parser + + def parse(self, latex_str: str): + """Main entry point to parse latex string""" + # setup listener + parser = self.create_parser(latex_str) + + # process the input + math = parser.math() + + # if set relation + if math.set_relation(): + return self.convert_set_relation(math.set_relation()) + + if math.set_elements(): + # The issue with 333,333 or 3,333 is that it makess sets and numbers with commas ambigous + # is that 333333 or {333,333}? + # What we therefore do is that default to numbers with commas + # We make the regex match directly on latex_str, because otherwise don't know if there is space + # between the comma and the number, in this case it should be a set + if comma_number_regex.match(latex_str): + return convert_number(latex_str) + return self.convert_set_elements(math.set_elements()) + + if math.set_elements_relation(): + return self.convert_set_elements_relation(math.set_elements_relation()) + + raise Exception('Nothing matched') + + class MathErrorListener(ErrorListener): + def __init__(self, src): + super(ErrorListener, self).__init__() + self.src = src + + def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e): + fmt = "%s\n%s\n%s" + marker = "~" * column + "^" + + if msg.startswith("missing"): + err = fmt % (msg, self.src, marker) + elif msg.startswith("no viable"): + err = fmt % ("I expected something else here", self.src, marker) + elif msg.startswith("mismatched"): + names = PSParser.literalNames + expected = [names[i] for i in e.getExpectedTokens() if i < len(names)] + if len(expected) < 10: + expected = " ".join(expected) + err = (fmt % ("I expected one of these: " + expected, + self.src, marker)) + else: + err = (fmt % ("I expected something else here", self.src, marker)) + else: + err = fmt % ("I don't understand this", self.src, marker) + raise Exception(err) + + def convert_relation(self, rel): + if rel.expr(): + return self.convert_expr(rel.expr()) + + lh = self.convert_relation(rel.relation(0)) + rh = self.convert_relation(rel.relation(1)) + + if rel.LT(): + if isinstance(lh, And): + return And(*lh._unsorted_args, sympy.StrictLessThan(lh._unsorted_args[-1].rhs, rh, evaluate=False)) + elif isinstance(lh, Relational): + return And(lh, sympy.StrictLessThan(lh.rhs, rh, evaluate=False)) + return sympy.StrictLessThan(lh, rh, evaluate=False) + elif rel.LTE(): + if isinstance(lh, And): + return And(*lh._unsorted_args, sympy.LessThan(lh._unsorted_args[-1].rhs, rh, evaluate=False)) + elif isinstance(lh, Relational): + return And(lh, sympy.LessThan(lh.rhs, rh, evaluate=False)) + return sympy.LessThan(lh, rh, evaluate=False) + elif rel.GT(): + if isinstance(lh, And): + return And(*lh._unsorted_args, sympy.StrictGreaterThan(lh._unsorted_args[-1].rhs, rh, evaluate=False)) + elif isinstance(lh, Relational): + return And(lh, sympy.StrictGreaterThan(lh.rhs, rh, evaluate=False)) + return sympy.StrictGreaterThan(lh, rh, evaluate=False) + elif rel.GTE(): + if isinstance(lh, And): + return And(*lh._unsorted_args, sympy.GreaterThan(lh._unsorted_args[-1].rhs, rh, evaluate=False)) + elif isinstance(lh, Relational): + return And(lh, sympy.GreaterThan(lh.rhs, rh, evaluate=False)) + return sympy.GreaterThan(lh, rh, evaluate=False) + elif rel.EQUAL(): + if isinstance(lh, And): + return And(*lh._unsorted_args, sympy.Eq(lh._unsorted_args[-1].rhs, rh, evaluate=False)) + elif isinstance(lh, Relational): + return And(lh, sympy.Eq(lh.rhs, rh, evaluate=False)) + return sympy.Eq(lh, rh, evaluate=False) + elif rel.ASSIGNMENT(): + # !Use Global variances + if self.config.interpret_simple_eq_as_assignment and is_expr_of_only_symbols(lh): + # set value + self.variances[lh] = rh + self.var[str(lh)] = rh + return rh + else: + if isinstance(lh, And): + return And(*lh._unsorted_args, sympy.Eq(lh._unsorted_args[-1].rhs, rh, evaluate=False)) + elif isinstance(lh, Relational): + return And(lh, sympy.Eq(lh.rhs, rh, evaluate=False)) + return sympy.Eq(lh, rh, evaluate=False) + elif rel.APPROX(): + if is_expr_of_only_symbols(lh): + self.variances[lh] = rh + self.var[str(lh)] = rh + return rh + else: + # We don't want approximation, so we jsut take the non-approximated value + return lh + elif rel.IN(): + # !Use Global variances + if hasattr(rh, 'is_Pow') and rh.is_Pow and hasattr(rh.exp, 'is_Mul'): + n = rh.exp.args[0] + m = rh.exp.args[1] + if n in self.variances: + n = self.variances[n] + if m in self.variances: + m = self.variances[m] + rh = sympy.MatrixSymbol(lh, n, m) + self.variances[lh] = rh + self.var[str(lh)] = rh + elif self.config.interpret_simple_eq_as_assignment and is_expr_of_only_symbols(lh): + self.variances[lh] = rh + self.var[str(lh)] = rh + return rh + else: + raise Exception('Unrecognized relation') + return lh + elif rel.UNEQUAL(): + if isinstance(lh, And): + return And(*lh._unsorted_args, sympy.Ne(lh._unsorted_args[-1].rhs, rh, evaluate=False)) + elif isinstance(lh, Relational): + return And(lh, sympy.Ne(lh.rhs, rh, evaluate=False)) + return sympy.Ne(lh, rh, evaluate=False) + + + def convert_set_relation(self, expr): + if expr.atom_expr_list(): + left = self.convert_atom_expr_list(expr.atom_expr_list()) + right = self.convert_set_relation(expr.set_relation()[0]) + if expr.IN(): + if self.config.interpret_simple_eq_as_assignment and is_expr_of_only_symbols(left): + # set value + self.variances[left] = right + self.var[str(left)] = right + return right + elif self.config.interpret_contains_as_eq: + return sympy.Eq(left, right, evaluate=False) + else: + return sympy.Contains(left, right, evaluate=False) + elif expr.ASSIGNMENT(): + if self.config.interpret_simple_eq_as_assignment and is_expr_of_only_symbols(left): + # set value + self.variances[left] = right + self.var[str(left)] = right + return right + else: + return sympy.Eq(left, right, evaluate=False) + elif expr.NOTIN(): + if self.config.interpret_contains_as_eq: + val = (sympy.S.Reals if self.is_real else sympy.S.Complexes) - right + if self.config.interpret_simple_eq_as_assignment and is_expr_of_only_symbols(left): + self.variances[left] = val + self.var[str(left)] = val + return val + else: + return sympy.Not(sympy.Eq(left, right, evaluate=False), evaluate=False) + else: + return sympy.Not(right.contains(left)) + + if expr.set_relation(): + left = self.convert_set_relation(expr.set_relation()[0]) + right = self.convert_set_relation(expr.set_relation()[1]) + if expr.SUBSET(): + return left.is_subset(right) + if expr.SUPSET(): + return right.is_subset(left) + raise Exception('Unrecognized set relation') + return self.convert_set_minus(expr.minus_expr()) + + def convert_elements_to_set_or_tuple(self, elements): + """Helper function to convert elements to either a FiniteSet or Tuple based on content""" + if len(elements) == 1: + if len(elements[0]) == 1: + return elements[0][0] + return FiniteSet(*elements[0]) + elif all(len(elem) == 1 for elem in elements): + return FiniteSet(*[elem[0] for elem in elements]) + else: + return FiniteSet(*[ + sympy.Tuple(*l) for l in elements + ]) + + def convert_set_elements_relation(self, expr): + semicolon_elements_no_relation = self.convert_semicolon_elements_no_relation(expr.semicolon_elements_no_relation()) + set_elements = self.convert_elements_to_set_or_tuple(semicolon_elements_no_relation) + + atom_expressions = self.convert_atom_expr_list(expr.atom_expr_list()) + if expr.IN(): + if self.config.interpret_simple_eq_as_assignment and is_expr_of_only_symbols(atom_expressions): + # set value + self.variances[atom_expressions] = set_elements + self.var[str(atom_expressions)] = set_elements + return set_elements + elif self.config.interpret_contains_as_eq: + return sympy.Eq(atom_expressions, set_elements, evaluate=False) + else: + return sympy.Contains(atom_expressions, set_elements, evaluate=False) + elif expr.ASSIGNMENT(): + if self.config.interpret_simple_eq_as_assignment and is_expr_of_only_symbols(atom_expressions): + # set value + self.variances[atom_expressions] = set_elements + self.var[str(atom_expressions)] = set_elements + return set_elements + else: + return sympy.Eq(atom_expressions, set_elements, evaluate=False) + return set_elements + + def convert_set_elements(self, expr): + semicolon_elements = self.convert_semicolon_elements(expr.semicolon_elements()) + return self.convert_elements_to_set_or_tuple(semicolon_elements) + + + def convert_set_minus(self, expr): + if expr.union_expr(): + return self.convert_set_union(expr.union_expr()) + + left = self.convert_set_minus(expr.minus_expr()[0]) + right = self.convert_set_minus(expr.minus_expr()[1]) + return sympy.Complement(left, right, evaluate=False) + + def convert_set_union(self, expr): + if expr.intersection_expr(): + return self.convert_set_intersection(expr.intersection_expr()) + + left = self.convert_set_union(expr.union_expr()[0]) + right = self.convert_set_union(expr.union_expr()[1]) + + # It's hard to know what the user meant, but clearly we cant do intersection with tuple + if isinstance(left, sympy.Tuple): + left = FiniteSet(*left) + + if isinstance(right, sympy.Tuple): + right = FiniteSet(*right) + + return sympy.Union(left, right, evaluate=False) + + def convert_set_intersection(self, expr): + if expr.set_group(): + return self.convert_set_group(expr.set_group()) + + left = self.convert_set_intersection(expr.intersection_expr()[0]) + right = self.convert_set_intersection(expr.intersection_expr()[1]) + + if isinstance(left, sympy.Tuple): + left = FiniteSet(*left) + + if isinstance(right, sympy.Tuple): + right = FiniteSet(*right) + + return sympy.Intersection(left, right, evaluate=False) + + + def convert_set_group(self, expr): + if expr.set_atom(): + return self.convert_set_atom(expr.set_atom()) + + return self.convert_set_minus(expr.minus_expr()) + + def convert_set_atom(self, expr): + if expr.literal_set(): + return self.convert_literal_set(expr.literal_set()) + if expr.interval(): + return self.convert_interval(expr.interval()) + if expr.ordered_tuple(): + return self.convert_ordered_tuple(expr.ordered_tuple()) + if expr.finite_set(): + return self.convert_finite_set(expr.finite_set()) + raise Exception('Unrecognized set atom') + + def convert_interval(self, expr): + left_open = expr.L_PAREN() is not None or expr.L_GROUP() is not None or expr.L_PAREN_VISUAL() is not None + right_open = expr.R_PAREN() is not None or expr.R_GROUP() is not None or expr.R_PAREN_VISUAL() is not None + + left = self.convert_expr(expr.expr()[0]) + right = self.convert_expr(expr.expr()[1]) + + # It doesn't make sense to have interval which represents an empty set, in this case we treat it as a finite set + try: + if (left_open and right_open and right <= left) or (not left_open and not right_open and right < left): + return sympy.Tuple(left, right) + except Exception: + pass + + return sympy.Interval(left, right, left_open=left_open, right_open=right_open) + + def convert_ordered_tuple(self, expr): + elements = self.convert_semicolon_elements(expr.semicolon_elements()) + # We don't support 1 element tuples + if len(elements) == 1 and len(elements[0]) == 1: + return elements[0][0] + return sympy.Tuple(*flatten_list(elements)) + + def convert_finite_set(self, expr): + content = self.convert_semicolon_elements(expr.semicolon_elements()) + # Sometimes people wrap either \boxed{a,b,c}, which we want to be a set, + # but also \boxed{1} which we want to be a number + if expr.BOXED_CMD(): + return self.convert_elements_to_set_or_tuple(content) + return FiniteSet(*flatten_list(content)) + + def convert_semicolon_elements(self, expr): + result = [self.convert_comma_elements(element) for element in expr.comma_elements()] + return result + + def convert_semicolon_elements_no_relation(self, expr): + result = [self.convert_comma_elements_no_relation(element) for element in expr.comma_elements_no_relation()] + return result + + def convert_comma_elements(self, expr): + result = flatten_list(self.convert_element(element) for element in expr.element()) + return result + + def convert_comma_elements_no_relation(self, expr): + result = flatten_list(self.convert_element(element) for element in expr.element_no_relation()) + return result + + def as_unary_minus(self, expr): + if hasattr(expr, 'is_Rational') and expr.is_Rational: + return sympy.Rational(-expr.p, expr.q) + elif hasattr(expr, 'is_Integer') and expr.is_Integer: + return -expr + return sympy.Mul(-1, expr, evaluate=False) + + + + def convert_element(self, element): + if element.plus_minus_expr(): + pm = element.plus_minus_expr() + if len(pm.expr()) == 1: + expr = self.convert_expr(pm.expr()[0]) + return [self.as_unary_minus(expr), expr] + left = self.convert_expr(pm.expr()[0]) + right = self.convert_expr(pm.expr()[1]) + return [sympy.Add(left, right, evaluate=False), sympy.Add(left, self.as_unary_minus(right), evaluate=False)] + elif element.set_atom(): + return [self.convert_set_atom(element.set_atom())] + + elif hasattr(element, 'relation') and element.relation(): + return [self.convert_relation(element.relation())] + + elif hasattr(element, 'expr') and element.expr(): + return [self.convert_expr(element.expr())] + else: + raise Exception('Unrecognized comma element') + + + # Fallback because for some reason finites set wtih paren parses sometimes first + # instead of interval + return elements + + def convert_literal_set(self, expr): + if expr.SET_NATURALS(): + return sympy.S.Naturals + elif expr.SET_INTEGERS(): + return sympy.S.Integers + elif expr.SET_RATIONALS(): + return sympy.S.Rationals + elif expr.SET_REALS(): + return sympy.S.Reals + elif expr.SET_COMPLEX(): + return sympy.S.Complexes + elif expr.SET_EMPTY() or expr.L_BRACE() and expr.R_BRACE(): + return sympy.S.EmptySet + raise Exception('Unrecognized literal set') + + + def convert_expr(self, expr): + if expr.additive(): + return self.convert_add(expr.additive()) + + + def convert_elementary_transform(self, matrix, transform): + if transform.transform_scale(): + transform_scale = transform.transform_scale() + transform_atom = transform_scale.transform_atom() + k = None + num = int(transform_atom.NUMBER().getText()) - 1 + if transform_scale.expr(): + k = self.convert_expr(transform_scale.expr()) + elif transform_scale.group(): + k = self.convert_expr(transform_scale.group().expr()) + elif transform_scale.SUB(): + k = -1 + else: + k = 1 + if transform_atom.LETTER_NO_E().getText() == 'r': + matrix = matrix.elementary_row_op(op='n->kn', row=num, k=k) + elif transform_atom.LETTER_NO_E().getText() == 'c': + matrix = matrix.elementary_col_op(op='n->kn', col=num, k=k) + else: + raise Exception('Row and col don\'s match') + + elif transform.transform_swap(): + first_atom = transform.transform_swap().transform_atom()[0] + second_atom = transform.transform_swap().transform_atom()[1] + first_num = int(first_atom.NUMBER().getText()) - 1 + second_num = int(second_atom.NUMBER().getText()) - 1 + if first_atom.LETTER_NO_E().getText() != second_atom.LETTER_NO_E().getText(): + raise Exception('Row and col don\'s match') + elif first_atom.LETTER_NO_E().getText() == 'r': + matrix = matrix.elementary_row_op(op='n<->m', row1=first_num, row2=second_num) + elif first_atom.LETTER_NO_E().getText() == 'c': + matrix = matrix.elementary_col_op(op='n<->m', col1=first_num, col2=second_num) + else: + raise Exception('Row and col don\'s match') + + elif transform.transform_assignment(): + first_atom = transform.transform_assignment().transform_atom() + second_atom = transform.transform_assignment().transform_scale().transform_atom() + transform_scale = transform.transform_assignment().transform_scale() + k = None + if transform_scale.expr(): + k = self.convert_expr(transform_scale.expr()) + elif transform_scale.group(): + k = self.convert_expr(transform_scale.group().expr()) + elif transform_scale.SUB(): + k = -1 + else: + k = 1 + first_num = int(first_atom.NUMBER().getText()) - 1 + second_num = int(second_atom.NUMBER().getText()) - 1 + if first_atom.LETTER_NO_E().getText() != second_atom.LETTER_NO_E().getText(): + raise Exception('Row and col don\'s match') + elif first_atom.LETTER_NO_E().getText() == 'r': + matrix = matrix.elementary_row_op(op='n->n+km', k=k, row1=first_num, row2=second_num) + elif first_atom.LETTER_NO_E().getText() == 'c': + matrix = matrix.elementary_col_op(op='n->n+km', k=k, col1=first_num, col2=second_num) + else: + raise Exception('Row and col don\'s match') + + return matrix + + + def convert_matrix(self, matrix): + # build matrix + row = matrix.matrix_row() + tmp = [] + rows = 0 + mat = None + + for r in row: + tmp.append([]) + for expr in r.expr(): + tmp[rows].append(self.convert_expr(expr)) + rows = rows + 1 + + mat = sympy.Matrix(tmp) + + if hasattr(matrix, 'MATRIX_XRIGHTARROW') and matrix.MATRIX_XRIGHTARROW(): + transforms_list = matrix.elementary_transforms() + if len(transforms_list) == 1: + for transform in transforms_list[0].elementary_transform(): + mat = self.convert_elementary_transform(mat, transform) + elif len(transforms_list) == 2: + # firstly transform top of xrightarrow + for transform in transforms_list[1].elementary_transform(): + mat = self.convert_elementary_transform(mat, transform) + # firstly transform bottom of xrightarrow + for transform in transforms_list[0].elementary_transform(): + mat = self.convert_elementary_transform(mat, transform) + + return mat + + + def add_flat(self, lh, rh): + if hasattr(lh, 'is_Add') and lh.is_Add or hasattr(rh, 'is_Add') and rh.is_Add: + args = [] + if hasattr(lh, 'is_Add') and lh.is_Add: + args += list(lh.args) + else: + args += [lh] + if hasattr(rh, 'is_Add') and rh.is_Add: + args = args + list(rh.args) + else: + args += [rh] + return sympy.Add(*args, evaluate=False) + else: + return sympy.Add(lh, rh, evaluate=False) + + + def mat_add_flat(self, lh, rh): + if hasattr(lh, 'is_MatAdd') and lh.is_MatAdd or hasattr(rh, 'is_MatAdd') and rh.is_MatAdd: + args = [] + if hasattr(lh, 'is_MatAdd') and lh.is_MatAdd: + args += list(lh.args) + else: + args += [lh] + if hasattr(rh, 'is_MatAdd') and rh.is_MatAdd: + args = args + list(rh.args) + else: + args += [rh] + # Previously there doit, but I don't think it's needed + return sympy.MatAdd(*[arg for arg in args], evaluate=False) + else: + return sympy.MatAdd(lh, rh, evaluate=False) + + + def mul_flat(self, lh, rh): + if hasattr(lh, 'is_Mul') and lh.is_Mul or hasattr(rh, 'is_Mul') and rh.is_Mul: + args = [] + if hasattr(lh, 'is_Mul') and lh.is_Mul: + args += list(lh.args) + else: + args += [lh] + if hasattr(rh, 'is_Mul') and rh.is_Mul: + args = args + list(rh.args) + else: + args += [rh] + return sympy.Mul(*args, evaluate=False) + else: + return sympy.Mul(lh, rh, evaluate=False) + + + def mat_mul_flat(self, lh, rh): + if hasattr(lh, 'is_MatMul') and lh.is_MatMul or hasattr(rh, 'is_MatMul') and rh.is_MatMul: + args = [] + if hasattr(lh, 'is_MatMul') and lh.is_MatMul: + args += list(lh.args) + else: + args += [lh] + if hasattr(rh, 'is_MatMul') and rh.is_MatMul: + args = args + list(rh.args) + else: + args += [rh] + return sympy.MatMul(*[arg for arg in args], evaluate=False) + else: + # We don't have to doit there + # if hasattr(lh, 'is_Matrix'): + # lh = lh.doit() + # if hasattr(rh, 'is_Matrix'): + # rh = rh.doit() + return sympy.MatMul(lh, rh, evaluate=False) + + + def convert_add(self, add): + if add.ADD(): + lh = self.convert_add(add.additive(0)) + rh = self.convert_add(add.additive(1)) + + if (hasattr(lh, 'is_Matrix') and lh.is_Matrix) or (hasattr(rh, 'is_Matrix') and rh.is_Matrix): + return self.mat_add_flat(lh, rh) + else: + return self.add_flat(lh, rh) + elif add.SUB(): + lh = self.convert_add(add.additive(0)) + rh = self.convert_add(add.additive(1)) + + if (hasattr(lh, 'is_Matrix') and lh.is_Matrix) or (hasattr(rh, 'is_Matrix') and rh.is_Matrix): + return self.mat_add_flat(lh, self.mat_mul_flat(-1, rh)) + else: + # If we want to force ordering for variables this should be: + # return Sub(lh, rh, evaluate=False) + if not (hasattr(rh, 'is_Matrix') and rh.is_Matrix) and (hasattr(rh, 'func') and rh.func.is_Number): + rh = -rh + else: + rh = self.mul_flat(-1, rh) + return self.add_flat(lh, rh) + else: + return self.convert_mp(add.mp()) + + + def convert_mp(self, mp): + if hasattr(mp, 'mp'): + mp_left = mp.mp(0) + mp_right = mp.mp(1) + else: + mp_left = mp.mp_nofunc(0) + mp_right = mp.mp_nofunc(1) + + if mp.MUL() or mp.CMD_TIMES() or mp.CMD_CDOT(): + lh = self.convert_mp(mp_left) + rh = self.convert_mp(mp_right) + + if (hasattr(lh, 'is_Matrix') and lh.is_Matrix) or (hasattr(rh, 'is_Matrix') and rh.is_Matrix): + return self.mat_mul_flat(lh, rh) + else: + return self.mul_flat(lh, rh) + elif mp.DIV() or mp.CMD_DIV() or mp.COLON(): + lh = self.convert_mp(mp_left) + rh = self.convert_mp(mp_right) + if (hasattr(lh, 'is_Matrix') and lh.is_Matrix) or (hasattr(rh, 'is_Matrix') and rh.is_Matrix): + return sympy.MatMul(lh, sympy.Pow(rh, -1, evaluate=False), evaluate=False) + + # If both are numbers, we convert to sympy.Rational + elif hasattr(lh, 'is_Integer') and lh.is_Integer and hasattr(rh, 'is_Integer') and rh.is_Integer: + return sympy.Rational(lh, rh) + else: + return sympy.Mul(lh, sympy.Pow(rh, -1, evaluate=False), evaluate=False) + elif mp.CMD_MOD(): + lh = self.convert_mp(mp_left) + rh = self.convert_mp(mp_right) + if (hasattr(rh, 'is_Matrix') and rh.is_Matrix): + raise Exception("Cannot perform modulo operation with a matrix as an operand") + else: + return sympy.Mod(lh, rh, evaluate=False) + else: + if hasattr(mp, 'unary'): + return self.convert_unary(mp.unary()) + else: + return self.convert_unary(mp.unary_nofunc()) + + + def convert_unary(self, unary): + if hasattr(unary, 'unary'): + nested_unary = unary.unary() + else: + nested_unary = unary.unary_nofunc() + if hasattr(unary, 'postfix_nofunc'): + first = unary.postfix() + tail = unary.postfix_nofunc() + postfix = [first] + tail + else: + postfix = unary.postfix() + + if unary.ADD(): + return self.convert_unary(nested_unary) + elif unary.SUB(): + tmp_convert_nested_unary = self.convert_unary(nested_unary) + if (hasattr(tmp_convert_nested_unary, 'is_Matrix') and tmp_convert_nested_unary.is_Matrix): + return self.mat_mul_flat(-1, tmp_convert_nested_unary) + else: + if (hasattr(tmp_convert_nested_unary, 'func') and tmp_convert_nested_unary.func.is_Number): + return -tmp_convert_nested_unary + + elif hasattr(tmp_convert_nested_unary, 'is_Number') and tmp_convert_nested_unary.is_Number: + return -tmp_convert_nested_unary + else: + return self.mul_flat(-1, tmp_convert_nested_unary) + elif postfix: + return self.convert_postfix_list(postfix) + + + def convert_postfix_list(self, arr, i=0): + if i >= len(arr): + raise Exception("Index out of bounds") + + res = self.convert_postfix(arr[i]) + + if isinstance(res, sympy.Expr) or isinstance(res, sympy.Matrix): + if i == len(arr) - 1: + return res # nothing to multiply by + else: + # multiply by next + rh = self.convert_postfix_list(arr, i + 1) + + if (hasattr(res, 'is_Matrix') and res.is_Matrix) or (hasattr(rh, 'is_Matrix') and rh.is_Matrix): + return self.mat_mul_flat(res, rh) + # Support for mixed fractions, 2 \frac{1}{2} + elif hasattr(res, 'is_Integer') and res.is_Integer and hasattr(rh, 'is_Rational') and rh.is_Rational and rh.p > 0 and rh.q > 0: + if res < 0: + return sympy.Rational(res*rh.q - rh.p, rh.q) + else: + return sympy.Rational(res*rh.q + rh.p, rh.q) + else: + return self.mul_flat(res, rh) + elif isinstance(res, list) and len(res) == 1: # must be derivative + wrt = res[0] + if i == len(arr) - 1: + raise Exception("Expected expression for derivative") + else: + expr = self.convert_postfix_list(arr, i + 1) + return sympy.Derivative(expr, wrt) + + return res + + + def do_subs(self, expr, at): + if at.expr(): + at_expr = self.convert_expr(at.expr()) + syms = at_expr.atoms(sympy.Symbol) + if len(syms) == 0: + return expr + elif len(syms) > 0: + sym = next(iter(syms)) + return expr.subs(sym, at_expr) + elif at.equality(): + lh = self.convert_expr(at.equality().expr(0)) + rh = self.convert_expr(at.equality().expr(1)) + return expr.subs(lh, rh) + + + def convert_postfix(self, postfix): + if hasattr(postfix, 'exp'): + exp_nested = postfix.exp() + else: + exp_nested = postfix.exp_nofunc() + + exp = self.convert_exp(exp_nested) + for op in postfix.postfix_op(): + if op.BANG(): + if isinstance(exp, list): + raise Exception("Cannot apply postfix to derivative") + exp = sympy.factorial(exp, evaluate=False) + elif op.eval_at(): + ev = op.eval_at() + at_b = None + at_a = None + if ev.eval_at_sup(): + at_b = self.do_subs(exp, ev.eval_at_sup()) + if ev.eval_at_sub(): + at_a = self.do_subs(exp, ev.eval_at_sub()) + if at_b is not None and at_a is not None: + exp = self.add_flat(at_b, self.mul_flat(at_a, -1)) + elif at_b is not None: + exp = at_b + elif at_a is not None: + exp = at_a + elif op.transpose(): + try: + exp = exp.T + except Exception: + try: + exp = sympy.transpose(exp) + except Exception: + pass + pass + elif op.degree() and self.convert_degrees: + try: + exp = sympy.Mul(exp, sympy.pi/180) + except Exception: + pass + + return exp + + + def convert_exp(self, exp): + if hasattr(exp, 'exp'): + exp_nested = exp.exp() + else: + exp_nested = exp.exp_nofunc() + + if exp_nested: + base = self.convert_exp(exp_nested) + if isinstance(base, list): + raise Exception("Cannot raise derivative to power") + if exp.atom(): + exponent = self.convert_atom(exp.atom()) + else: + # It's expr + exponent = self.convert_expr(exp.expr()) + + return sympy.Pow(base, exponent, evaluate=False) + else: + if hasattr(exp, 'comp'): + return self.convert_comp(exp.comp()) + else: + return self.convert_comp(exp.comp_nofunc()) + + + def convert_comp(self, comp): + if comp.group(): + return self.convert_expr(comp.group().expr()) + elif comp.formatting_group(): + return self.convert_expr(comp.formatting_group().expr()) + elif comp.norm_group(): + return self.convert_expr(comp.norm_group().expr()).norm() + elif comp.abs_group(): + return sympy.Abs(self.convert_expr(comp.abs_group().expr()), evaluate=False) + elif comp.floor_group(): + return self.handle_floor(self.convert_expr(comp.floor_group().expr())) + elif comp.ceil_group(): + return self.handle_ceil(self.convert_expr(comp.ceil_group().expr())) + elif comp.atom(): + return self.convert_atom(comp.atom()) + elif comp.frac(): + return self.convert_frac(comp.frac()) + elif comp.binom(): + return self.convert_binom(comp.binom()) + elif comp.matrix(): + return self.convert_matrix(comp.matrix()) + elif comp.det(): + # !Use Global variances + return self.convert_matrix(comp.det()).subs(self.variances).det() + elif comp.func(): + return self.convert_func(comp.func()) + + + def convert_atom_expr(self, atom_expr): + # find the atom's text + atom_text = '' + if atom_expr.LETTER_NO_E(): + atom_text = atom_expr.LETTER_NO_E().getText() + if atom_text == "I": + return sympy.I + elif atom_expr.GREEK_CMD(): + atom_text = atom_expr.GREEK_CMD().getText() + elif atom_expr.OTHER_SYMBOL_CMD(): + atom_text = atom_expr.OTHER_SYMBOL_CMD().getText() + elif atom_expr.ACCENT(): + atom_text = atom_expr.ACCENT().getText() + # Remove the command by striping first { and last } + text_start = atom_text.index('{') + accent_name = atom_text[1:text_start] + accent_text = atom_text[text_start + 1:-1].replace(" ", "") + # exception: check if bar or overline which are treated both as bar + if accent_name in ["bar", "overline"]: + accent_name = "bar" + elif accent_name in ["vec", "overrightarrow"]: + accent_name = "vec" + elif accent_name in ["tilde", "widetilde"]: + accent_name = "tilde" + elif "text" in accent_name or "mbox" in accent_name: + # We ignore text accents so that $C$ == $\\text{C}$ + accent_name = "" + # Remove the parentheses + accent_text = accent_text.replace("(", "").replace(")", "") + elif "math" in accent_name: + accent_name = "math" + + if accent_name: + atom_text = f"{accent_name}{{{accent_text}}}" + else: + atom_text = accent_text + + # find atom's subscript, if any + subscript_text = '' + if atom_expr.subexpr(): + subexpr = atom_expr.subexpr() + subscript = None + if subexpr.expr(): # subscript is expr + subscript = subexpr.expr().getText().strip() + elif subexpr.atom(): # subscript is atom + subscript = subexpr.atom().getText().strip() + elif subexpr.args(): # subscript is args + subscript = subexpr.args().getText().strip() + subscript_inner_text = StrPrinter().doprint(subscript) + if len(subscript_inner_text) > 1: + subscript_text = '_{' + subscript_inner_text + '}' + else: + subscript_text = '_' + subscript_inner_text + + # construct the symbol using the text and optional subscript + atom_symbol = get_symbol(atom_text.strip() + subscript_text, self.is_real, self.config.lowercase_symbols) + # for matrix symbol + matrix_symbol = None + if atom_text + subscript_text in self.var: + try: + rh = self.var[atom_text + subscript_text] + shape = sympy.shape(rh) + matrix_symbol = sympy.MatrixSymbol(atom_text + subscript_text, shape[0], shape[1]) + self.variances[matrix_symbol] = self.variances[atom_symbol] + except Exception: + pass + + # find the atom's superscript, and return as a Pow if found + if atom_expr.supexpr(): + supexpr = atom_expr.supexpr() + func_pow = None + if supexpr.expr(): + func_pow = self.convert_expr(supexpr.expr()) + else: + func_pow = self.convert_atom(supexpr.atom()) + return sympy.Pow(atom_symbol, func_pow, evaluate=False) + + return atom_symbol if not matrix_symbol else matrix_symbol + + def convert_atom_expr_list(self, atom_expr_list): + converted_atoms = [self.convert_atom_expr(atom_expr) for atom_expr in atom_expr_list.atom_expr()] + if len(converted_atoms) == 1: + return converted_atoms[0] + return sympy.Tuple(*converted_atoms) + + def create_symbol(self, text, enforce_case=False): + if self.config.lowercase_symbols and not enforce_case: + return sympy.Symbol(text.lower(), real=self.is_real) + else: + return sympy.Symbol(text, real=self.is_real) + + def convert_atom(self, atom): + if atom.atom_expr(): + return self.convert_atom_expr(atom.atom_expr()) + elif atom.SYMBOL(): + s = atom.SYMBOL().getText().replace("\\$", "").replace("\\%", "") + if s == "\\infty": + return sympy.oo + else: + raise Exception("Unrecognized symbol") + elif atom.number_subexpr(): + # We just ignore the subexpr right now + s = atom.number_subexpr().NUMBER().getText() + number = self.parse_number(s) + return number + elif atom.E_NOTATION(): + s = atom.E_NOTATION().getText() + return self.parse_number(s) + elif atom.E_NOTATION_E(): + return self.create_symbol('E') + elif atom.DIFFERENTIAL(): + diff_var = self.get_differential_var(atom.DIFFERENTIAL()) + return sympy.Symbol('d' + diff_var.name, real=self.is_real) + elif atom.VARIABLE(): + text = atom.VARIABLE().getText() + is_percent = text.endswith("\\%") + trim_amount = 3 if is_percent else 1 + atom_text = text[10:] + atom_text = atom_text[0:len(atom_text) - trim_amount] + + # Hynek: I don't think we want this to happen + # replace the variable for already known variable values + # if atom_text in self.var: + # symbol = self.var[atom_text] + # else: + symbol = self.create_symbol(atom_text) + + if is_percent: + return convert_to_pct(symbol) + + # return the symbol + return symbol + + elif atom.PERCENT_NUMBER(): + text = atom.PERCENT_NUMBER().getText().replace("\\%", "").replace("%", "").replace(",", "") + number = self.parse_number(text) + percent = sympy.Mul(number, Rational(1, 100), evaluate=False) + return percent + def parse_number(self, text): + text = text.replace(",", "") + # If it's made only of digits, remove the starting 0 + if text.isdigit(): + while len(text) > 1 and text[0] == '0': + text = text[1:] + return sympy.Number(text) + + + def rule2text(self, ctx): + stream = ctx.start.getInputStream() + # starting index of starting token + startIdx = ctx.start.start + # stopping index of stopping token + stopIdx = ctx.stop.stop + + return stream.getText(startIdx, stopIdx) + + + def convert_frac(self, frac): + diff_op = False + partial_op = False + lower_itv = frac.lower.getSourceInterval() + lower_itv_len = lower_itv[1] - lower_itv[0] + 1 + wrt = None + if (frac.lower.start == frac.lower.stop and + frac.lower.start.type == PSLexer.DIFFERENTIAL): + wrt = self.get_differential_var_str(frac.lower.start.text) + diff_op = True + elif (lower_itv_len == 2 and + frac.lower.start.type == PSLexer.SYMBOL and + frac.lower.start.text == '\\partial' and + (frac.lower.stop.type == PSLexer.LETTER_NO_E or frac.lower.stop.type == PSLexer.SYMBOL)): + partial_op = True + wrt = frac.lower.stop.text + if frac.lower.stop.type == PSLexer.SYMBOL: + wrt = wrt[1:] + + if diff_op or partial_op: + wrt = self.create_symbol(wrt, enforce_case=True) + if (diff_op and frac.upper.start == frac.upper.stop and + frac.upper.start.type == PSLexer.LETTER_NO_E and + frac.upper.start.text == 'd'): + return [wrt] + elif (partial_op and frac.upper.start == frac.upper.stop and + frac.upper.start.type == PSLexer.SYMBOL and + frac.upper.start.text == '\\partial'): + return [wrt] + upper_text = self.rule2text(frac.upper) + + expr_top = None + if diff_op and upper_text.startswith('d'): + expr_top = self.parse(upper_text[1:]) + elif partial_op and frac.upper.start.text == '\\partial': + expr_top = self.parse(upper_text[len('\\partial'):]) + if expr_top: + return sympy.Derivative(expr_top, wrt) + + expr_top = self.convert_expr(frac.upper) + expr_bot = self.convert_expr(frac.lower) + if hasattr(expr_top, 'is_Matrix') and expr_top.is_Matrix or hasattr(expr_bot, 'is_Matrix') and expr_bot.is_Matrix: + return sympy.MatMul(expr_top, sympy.Pow(expr_bot, -1, evaluate=False), evaluate=False) + + elif hasattr(expr_top, 'is_Integer') and expr_top.is_Integer and hasattr(expr_bot, 'is_Integer') and expr_bot.is_Integer: + return sympy.Rational(expr_top, expr_bot) + else: + return sympy.Mul(expr_top, sympy.Pow(expr_bot, -1, evaluate=False), evaluate=False) + + + def convert_binom(self, binom): + expr_top = self.convert_expr(binom.upper) + expr_bot = self.convert_expr(binom.lower) + return sympy.binomial(expr_top, expr_bot) + + + def convert_func(self, func): + if func.func_normal_single_arg(): + if func.func_single_arg(): # function called with parenthesis + arg = self.convert_func_arg(func.func_single_arg()) + else: + arg = self.convert_func_arg(func.func_single_arg_noparens()) + + name = func.func_normal_single_arg().start.text[1:] + + + # get pow + func_pow = None + if func.supexpr(): + if func.supexpr().expr(): + func_pow = self.convert_expr(func.supexpr().expr()) + else: + func_pow = self.convert_atom(func.supexpr().atom()) + + # change arc -> a + if name in ["arcsin", "arccos", "arctan", "arccsc", "arcsec", + "arccot"]: + name = "a" + name[3:] + expr = getattr(sympy_trig, name)(arg, evaluate=False) + elif name in ["arsinh", "arcosh", "artanh"]: + name = "a" + name[2:] + expr = getattr(sympy_hyperbolic, name)(arg, evaluate=False) + elif name in ["arcsinh", "arccosh", "arctanh"]: + name = "a" + name[3:] + expr = getattr(sympy_hyperbolic, name)(arg, evaluate=False) + elif name == "operatorname": + operatorname = func.func_normal_single_arg().func_operator_name.getText() + + if operatorname in ["arsinh", "arcosh", "artanh"]: + operatorname = "a" + operatorname[2:] + expr = getattr(sympy_hyperbolic, operatorname)(arg, evaluate=False) + elif operatorname in ["arcsinh", "arccosh", "arctanh"]: + operatorname = "a" + operatorname[3:] + expr = getattr(sympy_hyperbolic, operatorname)(arg, evaluate=False) + elif operatorname == "floor": + expr = self.handle_floor(arg) + elif operatorname == "ceil": + expr = self.handle_ceil(arg) + elif operatorname == 'eye': + expr = sympy.eye(arg) + elif operatorname == 'rank': + expr = sympy.Integer(arg.rank()) + elif operatorname in ['trace', 'tr']: + expr = arg.trace() + elif operatorname == 'rref': + expr = arg.rref()[0] + elif operatorname == 'nullspace': + expr = arg.nullspace() + elif operatorname == 'norm': + expr = arg.norm() + elif operatorname == 'cols': + expr = [arg.col(i) for i in range(arg.cols)] + elif operatorname == 'rows': + expr = [arg.row(i) for i in range(arg.rows)] + elif operatorname in ['eig', 'eigen', 'diagonalize']: + expr = arg.diagonalize() + elif operatorname in ['eigenvals', 'eigenvalues']: + expr = arg.eigenvals() + elif operatorname in ['eigenvects', 'eigenvectors']: + expr = arg.eigenvects() + elif operatorname in ['svd', 'SVD']: + expr = arg.singular_value_decomposition() + else: + expr = sympy.Function(operatorname)(arg, evaluate=False) + elif name in ["log", "ln"]: + base = 10 + if func.subexpr(): + if func.subexpr().atom(): + base = self.convert_atom(func.subexpr().atom()) + else: + base = self.convert_expr(func.subexpr().expr()) + elif name == "log": + base = 10 + else: + # it's ln + base = sympy.E + expr = sympy.log(arg, base, evaluate=False) + elif name in ["exp", "exponentialE"]: + expr = sympy.exp(arg, evaluate=False) + elif name == "floor": + expr = self.handle_floor(arg) + elif name == "ceil": + expr = self.handle_ceil(arg) + elif name == 'det': + expr = arg.det() + + elif name in ["sin", "cos", "tan", "csc", "sec", "cot"]: + if func_pow == -1: + name = "a" + name + func_pow = None + expr = getattr(sympy_trig, name)(arg, evaluate=False) + + elif name in ["sinh", "cosh", "tanh"]: + if func_pow == -1: + name = "a" + name + func_pow = None + expr = getattr(sympy_hyperbolic, name)(arg, evaluate=False) + + else: + expr = sympy.Function(name)(arg, evaluate=False) + + if func_pow: + expr = sympy.Pow(expr, func_pow, evaluate=False) + + return expr + + elif func.func_normal_multi_arg(): + if func.func_multi_arg(): # function called with parenthesis + args = func.func_multi_arg().getText().split(",") + else: + args = func.func_multi_arg_noparens().split(",") + + args = list(map(lambda arg: self.parse(arg), args)) + name = func.func_normal_multi_arg().start.text[1:] + + if name == "operatorname": + operatorname = func.func_normal_multi_arg().func_operator_name.getText() + if operatorname in ["gcd", "lcm"]: + expr = self.handle_gcd_lcm(operatorname, args) + elif operatorname == 'zeros': + expr = sympy.zeros(*args) + elif operatorname == 'ones': + expr = sympy.ones(*args) + elif operatorname == 'diag': + expr = sympy.diag(*args) + elif operatorname == 'hstack': + expr = sympy.Matrix.hstack(*args) + elif operatorname == 'vstack': + expr = sympy.Matrix.vstack(*args) + elif operatorname in ['orth', 'ortho', 'orthogonal', 'orthogonalize']: + if len(args) == 1: + arg = args[0] + expr = GramSchmidt([arg.col(i) for i in range(arg.cols)], True) + else: + expr = GramSchmidt(args, True) + else: + expr = sympy.Function(operatorname)(*args, evaluate=False) + elif name in ["gcd", "lcm"]: + expr = self.handle_gcd_lcm(name, args) + elif name in ["max", "min"]: + name = name[0].upper() + name[1:] + expr = getattr(sympy_misc, name)(*args, evaluate=False) + else: + expr = sympy.Function(name)(*args, evaluate=False) + + func_pow = None + should_pow = True + if func.supexpr(): + if func.supexpr().expr(): + func_pow = self.convert_expr(func.supexpr().expr()) + else: + func_pow = self.convert_atom(func.supexpr().atom()) + + if func_pow and should_pow: + expr = sympy.Pow(expr, func_pow, evaluate=False) + + return expr + + elif func.atom_expr_no_supexpr(): + # define a function + f = sympy.Function(func.atom_expr_no_supexpr().getText()) + # args + args = func.func_common_args().getText().split(",") + if args[-1] == '': + args = args[:-1] + args = [self.parse(arg) for arg in args] + # supexpr + if func.supexpr(): + if func.supexpr().expr(): + expr = self.convert_expr(func.supexpr().expr()) + else: + expr = self.convert_atom(func.supexpr().atom()) + return sympy.Pow(f(*args), expr, evaluate=False) + else: + return f(*args) + elif func.FUNC_INT(): + return self.handle_integral(func) + elif func.FUNC_SQRT(): + expr = self.convert_expr(func.base) + if func.root: + r = self.convert_expr(func.root) + return sympy.Pow(expr, 1 / r, evaluate=False) + else: + return sympy.Pow(expr, sympy.S.Half, evaluate=False) + elif func.FUNC_SUM(): + return self.handle_sum_or_prod(func, "summation") + elif func.FUNC_PROD(): + return self.handle_sum_or_prod(func, "product") + elif func.FUNC_LIM(): + return self.handle_limit(func) + elif func.EXP_E(): + return self.handle_exp(func) + + + def convert_func_arg(self, arg): + if hasattr(arg, 'expr'): + return self.convert_expr(arg.expr()) + else: + return self.convert_mp(arg.mp_nofunc()) + + + def handle_integral(self, func): + if func.additive(): + integrand = self.convert_add(func.additive()) + elif func.frac(): + integrand = self.convert_frac(func.frac()) + else: + integrand = 1 + + int_var = None + if func.DIFFERENTIAL(): + int_var = self.get_differential_var(func.DIFFERENTIAL()) + else: + for sym in integrand.atoms(sympy.Symbol): + s = str(sym) + if len(s) > 1 and s[0] == 'd': + if s[1] == '\\': + int_var = self.create_symbol(s[2:], enforce_case=True) + else: + int_var = self.create_symbol(s[1:], enforce_case=True) + int_sym = sym + if int_var: + integrand = integrand.subs(int_sym, 1) + else: + # Assume dx by default + int_var = self.create_symbol('x', enforce_case=True) + + if func.subexpr(): + if func.subexpr().atom(): + lower = self.convert_atom(func.subexpr().atom()) + else: + lower = self.convert_expr(func.subexpr().expr()) + if func.supexpr().atom(): + upper = self.convert_atom(func.supexpr().atom()) + else: + upper = self.convert_expr(func.supexpr().expr()) + return sympy.Integral(integrand, (int_var, lower, upper)) + else: + return sympy.Integral(integrand, int_var) + + + def handle_sum_or_prod(self, func, name): + val = self.convert_mp(func.mp()) + iter_var = self.convert_expr(func.subeq().equality().expr(0)) + start = self.convert_expr(func.subeq().equality().expr(1)) + if func.supexpr().expr(): # ^{expr} + end = self.convert_expr(func.supexpr().expr()) + else: # ^atom + end = self.convert_atom(func.supexpr().atom()) + + if name == "summation": + return sympy.Sum(val, (iter_var, start, end)) + elif name == "product": + return sympy.Product(val, (iter_var, start, end)) + + + def handle_limit(self, func): + sub = func.limit_sub() + if sub.LETTER_NO_E(): + sub_var = self.create_symbol(sub.LETTER_NO_E().getText(), enforce_case=True) + elif sub.GREEK_CMD(): + sub_var = get_symbol(sub.GREEK_CMD().getText().strip(), self.is_real) + elif sub.OTHER_SYMBOL_CMD(): + sub_var = get_symbol(sub.OTHER_SYMBOL_CMD().getText().strip(), self.is_real) + else: + sub_var = self.create_symbol('x', enforce_case=True) + if sub.SUB(): + direction = "-" + else: + direction = "+" + approaching = self.convert_expr(sub.expr()) + content = self.convert_mp(func.mp()) + + return sympy.Limit(content, sub_var, approaching, direction) + + + def handle_exp(self, func): + if func.supexpr(): + if func.supexpr().expr(): # ^{expr} + exp_arg = self.convert_expr(func.supexpr().expr()) + else: # ^atom + exp_arg = self.convert_atom(func.supexpr().atom()) + else: + exp_arg = 1 + return sympy.exp(exp_arg) + + + def handle_gcd_lcm(self, f, args): + """ + Return the result of gcd() or lcm(), as UnevaluatedExpr + + f: str - name of function ("gcd" or "lcm") + args: List[Expr] - list of function arguments + """ + + args = tuple(map(sympy.nsimplify, args)) + + # gcd() and lcm() don't support evaluate=False + return sympy.UnevaluatedExpr(getattr(sympy, f)(args)) + + + def handle_floor(self, expr): + """ + Apply floor() then return the floored expression. + + expr: Expr - sympy expression as an argument to floor() + """ + return sympy_integers.floor(expr, evaluate=False) + + + def handle_ceil(self, expr): + """ + Apply ceil() then return the ceil-ed expression. + + expr: Expr - sympy expression as an argument to ceil() + """ + return sympy_integers.ceiling(expr, evaluate=False) + + + + def get_differential_var(self, d): + text = self.get_differential_var_str(d.getText()) + return self.create_symbol(text, enforce_case=True) + + + def get_differential_var_str(self, text): + for i in range(1, len(text)): + c = text[i] + if not (c == " " or c == "\r" or c == "\n" or c == "\t"): + idx = i + break + text = text[idx:] + if text[0] == "\\": + text = text[1:] + return text + +# # Set image value +# latex2latex('i=I') +# latex2latex('j=I') +# # set Identity(i) +# for i in range(1, 10): +# lh = sympy.Symbol(r'\bm{I}_' + str(i), real=False) +# lh_m = sympy.MatrixSymbol(r'\bm{I}_' + str(i), i, i) +# rh = sympy.Identity(i).as_mutable() +# variances[lh] = rh +# variances[lh_m] = rh +# var[str(lh)] = rh + +# Common regex + +def convert_to_pct(number: Number): + return sympy.Mul(number, sympy.Rational(1, 100), evaluate=False) + +def latex2sympy(latex_str: str, variable_values: dict | None = None, is_real=None, convert_degrees: bool = False, normalization_config: NormalizationConfig | None = NormalizationConfig(), conversion_config: ConversionConfig = ConversionConfig()): + converter = _Latex2Sympy(variable_values, is_real, convert_degrees, config=conversion_config) + if normalization_config is not None: + latex_str = normalize_latex(latex_str, normalization_config) + return converter.parse(latex_str) + + +if __name__ == "__main__": + # print(normalize_latex("20 \\%", NormalizationConfig(basic_latex=True, units=True, malformed_operators=False, nits=True, boxed=False, equations=True))) + print(latex2sympy(r"\boxed{\text{C, E}}")) + print(latex2sympy(r"0.111")) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/logic.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/logic.py new file mode 100644 index 000000000..bb2ced8bb --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/logic.py @@ -0,0 +1,14 @@ +from sympy import And as SympyAnd +from sympy.core.sympify import sympify + +class And(SympyAnd): + """ + Patched version of And that keeps the _unsorted_args attribute + """ + def __new__(cls, *args, **kwargs): + args = [sympify(arg) for arg in args] + obj = super().__new__(cls, *args, **kwargs) + obj._unsorted_args = args + return obj + + diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/math_normalization.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/math_normalization.py new file mode 100644 index 000000000..551f7ffad --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/math_normalization.py @@ -0,0 +1,512 @@ +import re +from dataclasses import dataclass +from typing import Literal +import logging + +logger = logging.getLogger(__name__) + +@dataclass(frozen=True) +class NormalizationConfig: + """Configuration for latex normalization. + + Each field controls a group of related normalizations: + - basic_latex: Basic latex command replacements (mathrm, displaystyle, etc.) + - units: Remove units and their variations + - malformed_operators: Fix malformed operators (sqrt, frac, etc.) + - nits: Small formatting fixes (spaces, dots, etc.) + - boxed: Extract content from boxed environments + - equations: Handle equation splitting and approximations (deprecated) + """ + basic_latex: bool = True + units: bool = False + malformed_operators: bool = False + nits: bool = False + boxed: Literal["all", "last", "none"] = "all" + equations: bool = False + +# Compile all regex patterns once at module level +r_left = re.compile(r"\\m?left(\\\{|\{|\\\||\||\[|\(|\\rbracl|\\lgroup|\\lbrace|\\lbrack|\\vert|\\lvert|\\lceil|\\lfloor|\\vert|\\lvert|\\langle|\\llcorner|\\ulcorner)") +r_right = re.compile(r"\\m?right(\\\}|\}|\\\||\||\]|\)|\\rbrack|\\rgroup|\\rbrace|\\rbrack|\\vert|\\rvert|\\rceil|\\rfloor|\\vert|\\rvert|\\rangle|\\lrcorner|\\urcorner)") + +# Units regex +units = [ + "integer" "point", + "feet", + "sue", + "digit", + "pound", + "meal", + "edge", + "student", + "children ticket", + "multiple", + "east", + "degree", + "mph", + "kmph", + "ft", + "m square", + " m east", + "sq m", + "deg", + "mile", + "monkey", + "prime", + "ratio", + "profit of rs", + "rd", + "o", + "gm", + "p . m", + "lb", + "tile", + "per", + "dm", + "lt", + "gain", + "ab", + "way", + "west", + "no change", + "men", + "soldier", + "pie", + "bc", + "excess", + "st", + "inches", + "noon", + "cent", + "by", + "gal", + "kmh", + "acre", + "rise", + "a . m", + "th", + "π r 2", + "sq", + "mark", + "l", + "toy", + "coin", + "sq . m", + "gallon", + "° f", + "profit", + "minw", + "yr", + "women", + "am", + "pm", + "hr", + "cu cm", + "square", + "v â € ™", + "are", + "rupee", + "rounds", + "cubic", + "cc", + "mtr", + "s", + "ohm", + "number", + "kmph", + "day", + "hour", + "minute", + "min", + "second", + "man", + "woman", + "sec", + "cube", + "mt", + "sq inch", + "mp", + "∏ cm ³", + "hectare", + "more", + "sec", + "unit", + "cu . m", + "cm 2", + "rs .", + "rs", + "kg", + "month", + "km", + "cm", + "mm", + "apple", + "liter", + "loss", + "yard", + "pure", + "year", + "increase", + "decrease", + "less", + "Surface", + "litre", + "pi sq m", + "s .", + "metre", + "meter", + "inch", + "kilogram", + "second", + "ampere", + "A", + "K", + "mol", + "cd", + "N", + "J", + "W", + "Pa", + "Hz", + "C", + "V", + "Ω", + "F", + "T", + "H", + "eV", + "kW·h", + "atm", + "bar", + "°C" +] + +# We sort here to that when matching from right the longest units are matched first +# E.g "percent" is matched before "cent" + +units_regex_pattern = f"(?:{'|'.join(units)})(?:s|es)?" +units_regex = re.compile(f"(\\d|\\}}|\\s)\\s*(?:{units_regex_pattern})\\s*$") + +# Basic latex regex +to_remove_regex = re.compile( + r"\\mathrm\{th\}|" # "th" + r"\\!\s*|" # comma with inverse space + r"\\text\s*\{\s*\}|" # text with empty braces + r"\\text\s*\{\s*\}|" # text with empty braces + r"\\\$|\$|" # dollar signs + r"(?{pattern})" for name, pattern, _ in to_replace_patterns) +to_replace_regex = re.compile(pattern) + +# Create lookup dictionary for replacements +replacements = {name: replacement for name, _, replacement in to_replace_patterns} + +command_slash_fix_regex = re.compile(r"\\\\(?=[a-zA-Z])") +permutation_regex = re.compile(r"\(([a-zA-Z0-9+\-*/\\ ]+?)\)_{([a-zA-Z0-9+\-*/\\ ]+?)}") +equation_split_regex = re.compile(r"(?)=") +unit_superscript_regex = re.compile(r"(\\(?:text|mbox){.*?})(\^\d|\{\^\d\})?$") +approx_split_regex = re.compile(r"\\approx") + +# Malformed operators regex +malformed_operators_patterns = [ + (re.compile(r"\^\s?\((.*?)\)"), r"^{\1}"), + (re.compile(r"sqrt\s?\((.*?)\)"), r"\\sqrt{\1}"), + (re.compile(r"\\frac\s?(\d)\s?(\d+)"), r"\\frac{\1}{\2}"), + (re.compile(r"\\log_\s?(\d)\s?(\d+)"), r"\\log_{\1}{\2}"), + (re.compile(r"\\frac\s?{(.*?)}\s?(\d)"), r"\\frac{\1}{\2}"), + (re.compile(r"\\frac\s?(\d)\s?{(.*?)}"), r"\\frac{\1}{\2}"), + (re.compile(r"\\sqrt\s?(\d)"), r"\\sqrt{\1}") +] + +def _fix_malformed_operators(text: str) -> str: + """Fix malformed operators in the given text.""" + expr_str = text + for pattern, replacement in malformed_operators_patterns: + expr_str = pattern.sub(replacement, expr_str) + expr_str = expr_str.replace(" sqrt", "\\sqrt") + return expr_str + +def replace(match): + # Find which group matched + # Get corresponding replacement from dict + return replacements[match.lastgroup] + +def replace_in_latex(text: str) -> str: + return to_replace_regex.sub(replace, text) + +VALID_SEPARATOR_PATTERN = re.compile(r'and|or|,|;') +def extract_boxed_content(text: str, mode: Literal["last", "all"] = "last") -> str: + """ + Find and extract all \\boxed{...} or \\fbox{...} elements from a string, searching from right to left. + If mode is "last", return content up to the last valid separator. + If mode is "all", return all boxed contents joined by commas. + """ + + def find_content_boundaries(text: str, opening_brace_pos: int, max_pos: int) -> tuple[int, int] | None: + # Start searching for closing brace from the opening brace position + i = opening_brace_pos + num_left_braces_open = 1 # We start after the opening brace + + while i + 1 < max_pos: # Check if next position is within bounds and max_pos + i += 1 + if text[i] == "{": + num_left_braces_open += 1 + elif text[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + return opening_brace_pos, i + return None + + def has_valid_separator(text: str, content_end: int, next_boxed_start: int) -> bool: + between_text = text[content_end + 1:next_boxed_start] + # Making regex for it not worth it so this works + return len(between_text) < 70 and bool(VALID_SEPARATOR_PATTERN.search(between_text)) + + results = [] + current_pos = len(text) + last_boxed_start = None + + max_pos = len(text) + while True: + boxed_idx = text.rfind("\\boxed", 0, current_pos) + fbox_idx = text.rfind("\\fbox", 0, current_pos) + + if boxed_idx < 0 and fbox_idx < 0: + break + + start_idx = max(boxed_idx, fbox_idx) + command_end = start_idx + (6 if boxed_idx > fbox_idx else 5) + + # Find opening brace + next_char_pos = command_end + while next_char_pos < max_pos and text[next_char_pos].isspace(): + next_char_pos += 1 + + if next_char_pos >= max_pos: + break + + if text[next_char_pos] == "{": + boundaries = find_content_boundaries(text, next_char_pos, max_pos) + if not boundaries: + # This is our last box + if len(results) == 0: + results.append(text[next_char_pos:]) + break + content_start, content_end = boundaries + content = text[content_start + 1:content_end].strip() + + if mode == "last" and last_boxed_start is not None: + if not has_valid_separator(text, content_end, last_boxed_start): + break + + results.append(content) + last_boxed_start = start_idx + max_pos = start_idx + else: + # This is our last box + if len(results) == 0: + results.append(text[next_char_pos:]) + # Otherwise we just ignore it + break + + + current_pos = start_idx + + if not results: + return text + + return ",".join(reversed(results)) + +def _fix_fracs(text: str) -> str: + """ + Fix the formatting of fractions in the given text. + Copied from: https://github.com/hendrycks/math/blob/357963a7f5501a6c1708cf3f3fb0cdf525642761/modeling/math_equivalence.py#L1 + + Args: + text (str): The input text. + + Returns: + str: The text with properly formatted fractions. + + Examples: + >>> _fix_fracs("\\frac12") + "\\frac{1}{2}" + >>> _fix_fracs("\\frac{3}{4}") + "\\frac{3}{4}" + >>> _fix_fracs("\\frac1{2}") + "\\frac{1}{2}" + """ + substrs = text.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + for substr in substrs[1:]: + # This allows use to have \\frac{1}{2} and \\ frac1{2} + substr = substr.lstrip() + new_str += "\\frac" + if len(substr) > 0 and substr[0] == "{": + new_str += substr + + elif len(substr) < 2: + return text + else: + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + text = new_str + return text + +def _fix_a_slash_b(text: str) -> str: + """Source: https://github.com/hendrycks/math + Reformat fractions formatted as a/b to \\frac{a}{b}. + Example: + >>> _fix_a_slash_b("2/3") + \frac{2}{3} + """ + if len(text.split("/")) != 2: + return text + a_str = text.split("/")[0] + b_str = text.split("/")[1] + try: + a = int(a_str) + b = int(b_str) + assert text == "{}/{}".format(a, b) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except Exception: + return text + + +def _fix_sqrt(text: str) -> str: + """Source: https://github.com/hendrycks/math + Reformat square roots. + Example: + >>> _fix_sqrt("\\sqrt3") + \\sqrt{3} + """ + if "\\sqrt" not in text: + return text + splits = text.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + split = split.lstrip() + if len(split) > 0 and split[0] not in ["{", "["]: + a = split[0] + new_substr = "\\sqrt{" + a + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + +def normalize_latex(text: str, config: NormalizationConfig) -> str: + """Normalize latex string according to the provided configuration. + + Args: + text: The latex string to normalize + config: Configuration controlling which normalizations to apply + + Returns: + The normalized latex string + """ + if config.boxed == "all" or config.boxed == "last": + text = extract_boxed_content(text, mode=config.boxed) + + if config.basic_latex: + # Basic latex command replacements + text = text.replace(r'\mathrm{T}', 'T') + text = text.replace(r'\mathrm{d}', 'd').replace(r'{\rm d}', 'd') + text = text.replace(r'\left[\begin{matrix}', r'\begin{bmatrix}').replace(r'\end{matrix}\right]', r'\end{bmatrix}') + text = r_left.sub(r'\1', text) + text = r_right.sub(r'\1', text) + text = permutation_regex.sub(r"\\frac{(\1)!}{((\1)-(\2))!}", text) + + # Remove useless latex commands + text = to_remove_regex.sub("", text) + text = replace_in_latex(text) + + # Remove new lines and simplify tabs + text = text.replace("\n", " ").replace("\t", " ") + + # Fix doubled backslashes in commands + if "matrix" not in text: + text = command_slash_fix_regex.sub(r"\\", text) + + if config.equations: + logger.warning("equations is deprecated, as it handled by the parser now") + # This is to ensure that a=1,b=2 is not splitted + if not "," in text and not ";" in text: + eq_parts = equation_split_regex.split(text) + # We only shorten if there are more than 2 parts, otherwise we keep equation as is + if len(eq_parts) > 2: + text = eq_parts[-1] + + if config.units: + # Remove the units and possibly the superscript + _text = unit_superscript_regex.sub("", text).strip() + if _text != "" and _text != text: + text = _text + + # Remove unit texts + for _ in range(2): + _text = units_regex.sub(r"\1", text) + if _text != "" and _text != text: + text = _text + + # This can trigger empty \text{...} + # Make sure not to remove space this created + + if config.nits: + # Fix leading decimal + if len(text) > 0 and text[0] == ".": + text = "0" + text + + # Fix 0.5 to fraction + if text == "0.5": + text = "\\frac{1}{2}" + + if config.malformed_operators: + # Fix malformed operators + text = _fix_malformed_operators(text) + text = _fix_sqrt(text) + text = _fix_fracs(text) + text = _fix_a_slash_b(text) + + return text.strip() \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/sets.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/sets.py new file mode 100644 index 000000000..60f757b6e --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/sets.py @@ -0,0 +1,36 @@ +from sympy import S, Basic, Set, ordered, sympify +from sympy.sets import FiniteSet as SympyFiniteSet +from sympy.core.parameters import global_parameters + +class FiniteSet(SympyFiniteSet): + """ + FiniteSet which keeps the _unsorted_args attribute, only available till the first evaluation + """ + def __new__(cls, *args, **kwargs): + evaluate = kwargs.get('evaluate', global_parameters.evaluate) + if evaluate: + args = list(map(sympify, args)) + + if len(args) == 0: + return S.EmptySet + else: + args = list(map(sympify, args)) + unsorted_args = args + + # keep the form of the first canonical arg + dargs = {} + for i in reversed(list(ordered(args))): + if i.is_Symbol: + dargs[i] = i + else: + try: + dargs[i.as_dummy()] = i + except TypeError: + # e.g. i = class without args like `Interval` + dargs[i] = i + _args_set = set(dargs.values()) + args = list(ordered(_args_set, Set._infimum_key)) + obj = Basic.__new__(cls, *args) + obj._args_set = _args_set + obj._unsorted_args = unsorted_args + return obj diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/symbols.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/symbols.py new file mode 100644 index 000000000..ec3d18e22 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/src/latex2sympy2_extended/symbols.py @@ -0,0 +1,168 @@ +import sympy + + +GREEK_LETTER_MAP = { + # Alpha + 'α': 'alpha', + '\\alpha': 'alpha', + '\\char000391': 'alpha', + + # Beta + 'β': 'beta', + '\\beta': 'beta', + '\\char000392': 'beta', + + # Gamma + 'γ': 'gamma', + '\\gamma': 'gamma', + 'Γ': 'Gamma', + '\\Gamma': 'Gamma', + + # Delta + 'δ': 'delta', + '\\delta': 'delta', + 'Δ': 'Delta', + '\\Delta': 'Delta', + + # Epsilon + 'ε': 'epsilon', + '\\epsilon': 'epsilon', + '\\char000190': 'epsilon', + 'ϵ': 'varepsilon', + '\\varepsilon': 'varepsilon', + + # Zeta + 'ζ': 'zeta', + '\\zeta': 'zeta', + '\\char000396': 'zeta', + + # Eta + 'η': 'eta', + '\\eta': 'eta', + '\\char000397': 'eta', + + # Theta + 'θ': 'theta', + '\\theta': 'theta', + 'Θ': 'Theta', + '\\Theta': 'Theta', + 'ϑ': 'vartheta', + '\\vartheta': 'vartheta', + + # Iota + 'ι': 'iota', + '\\iota': 'iota', + '\\char000399': 'iota', + + # Kappa + 'κ': 'kappa', + '\\kappa': 'kappa', + '\\char00039A': 'kappa', + + # Lambda + 'λ': 'lambda', + '\\lambda': 'lambda', + 'Λ': 'Lambda', + '\\Lambda': 'Lambda', + + # Mu + 'μ': 'mu', + '\\mu': 'mu', + '\\char00039C': 'mu', + + # Nu + 'ν': 'nu', + '\\nu': 'nu', + '\\char00039D': 'nu', + + # Xi + 'ξ': 'xi', + '\\xi': 'xi', + 'Ξ': 'Xi', + '\\Xi': 'Xi', + + # Omicron + 'ο': 'omicron', + '\\omicron': 'omicron', + '\\char00039F': 'omicron', + + # Pi + 'π': 'pi', + '\\pi': 'pi', + 'Π': 'Pi', + '\\Pi': 'Pi', + 'ϖ': 'varpi', + '\\varpi': 'varpi', + + # Rho + 'ρ': 'rho', + '\\rho': 'rho', + '\\char0003A1': 'rho', + 'ϱ': 'varrho', + '\\varrho': 'varrho', + + # Sigma + 'σ': 'sigma', + '\\sigma': 'sigma', + 'Σ': 'Sigma', + '\\Sigma': 'Sigma', + 'ς': 'varsigma', + '\\varsigma': 'varsigma', + + # Tau + 'τ': 'tau', + '\\tau': 'tau', + '\\char0003A4': 'tau', + + # Upsilon + 'υ': 'upsilon', + '\\upsilon': 'upsilon', + 'Υ': 'Upsilon', + '\\Upsilon': 'Upsilon', + + # Phi + 'φ': 'phi', + '\\phi': 'phi', + 'Φ': 'Phi', + '\\Phi': 'Phi', + 'ϕ': 'varphi', + '\\varphi': 'varphi', + + # Chi + 'χ': 'chi', + '\\chi': 'chi', + '\\char0003A7': 'chi', + + # Psi + 'ψ': 'psi', + '\\psi': 'psi', + 'Ψ': 'Psi', + '\\Psi': 'Psi', + + # Omega + 'ω': 'omega', + '\\omega': 'omega', + 'Ω': 'Omega', + '\\Omega': 'Omega' +} + +sympy_singleton_map = { + 'pi': sympy.S.Pi, +} + +def get_symbol(latex_str: str, is_real: bool | None = True, lowercase_symbols: bool = False): + latex_str = latex_str.strip() + letter = GREEK_LETTER_MAP.get(latex_str.replace('"', '')) + if letter is None: + letter = latex_str + + if letter.startswith('\\'): + letter = letter[1:] + + if lowercase_symbols: + letter = letter.lower() + + if letter in sympy_singleton_map: + return sympy_singleton_map[letter] + else: + return sympy.Symbol(letter, real=is_real) \ No newline at end of file diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/__init__.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/abs_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/abs_test.py new file mode 100644 index 000000000..559e31f5a --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/abs_test.py @@ -0,0 +1,20 @@ +import pytest +from latex2sympy2_extended import latex2sympy +from sympy import Abs + +from tests.context import assert_equal, get_simple_examples + +examples = get_simple_examples(Abs) + +delimiter_pairs = { + '|': '|', + '\\vert': '\\vert', + '\\lvert': '\\rvert' +} + +@pytest.mark.parametrize('input, output, symbolically', examples) +def test_abs(input, output, symbolically): + for left, right in delimiter_pairs.items(): + assert_equal("{left}{input}{right}".format(left=left, right=right, input=input), output, symbolically=symbolically) + assert_equal("\\left{left}{input}\\right{right}".format(left=left, right=right, input=input), output, symbolically=symbolically) + assert_equal("\\mleft{left}{input}\\mright{right}".format(left=left, right=right, input=input), output, symbolically=symbolically) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/all_bad_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/all_bad_test.py new file mode 100644 index 000000000..19d4d6fc5 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/all_bad_test.py @@ -0,0 +1,69 @@ +import pytest +from latex2sympy2_extended import latex2sympy +from tests.context import assert_equal + + +def pytest_generate_tests(metafunc): + metafunc.parametrize('s', metafunc.cls.BAD_STRINGS) + + +class TestAllBad(object): + # These bad latex strings should raise an exception when parsed + BAD_STRINGS = [ + "(", + ")", + # "a / b /", + "\\frac{d}{dx}", + "(\\frac{d}{dx})" + "\\sqrt{}", + "\\sqrt", + "{", + "}", + # "1.1.1", + "\\mathit{TEST}" + "\\frac{2}{}", + "\\frac{}{2}", + "\\int", + # "1 +", + # "a +", + "!", + "!0", + "_", + "^", + # "a // b", + # "a \\cdot \\cdot b", + # "a \\div \\div b", + "a\\mod \\begin{matrix}b\\end{matrix}" + "|", + "||x|", + "\\lfloor x", + "\\lfloor a \\rceil", + "\\operatorname{floor}(12.3, 123.4)", + "()", + "((((((((((((((((()))))))))))))))))", + "-", + "\\frac{d}{dx} + \\frac{d}{dt}", + # "f()", + # "f(,", + # "f(x,,y)", + # "f(x,y,", + "\\sin^x", + "\\cos^2", + # "\\cos 1 \\cos", + # "\\gcd(3)", + # "\\lcm(2)", + "@", "#", "$", "%", "&", "*", + "\\", + "~", + "\\frac{(2 + x}{1 - x)}", + # percentages without numbers before-hand + "a\\%", + "\\%100", + # dollar signs without numbers after + "\\$" + ] + + def test_bad_string(self, s): + with pytest.raises(Exception): + latex2sympy(s) + pass diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/all_good_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/all_good_test.py new file mode 100644 index 000000000..0696df71e --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/all_good_test.py @@ -0,0 +1,284 @@ +import pytest +import sympy +from latex2sympy2_extended import latex2sympy +from sympy import ( + E, I, FiniteSet, Function, oo, pi, sqrt, root, Symbol, Add, Mul, Pow, Abs, factorial, log, Eq, Ne, S, + Rational, Integer, UnevaluatedExpr, sin, cos, tan, sinh, cosh, tanh, asin, acos, + atan, asinh, acosh, atanh, csc, sec, Sum, Product, Limit, Integral, Derivative, + LessThan, StrictLessThan, GreaterThan, StrictGreaterThan, exp, binomial, Matrix, + MatMul, MatAdd, Mod, gcd, lcm, floor, ceiling, Max, Min +) +from tests.context import assert_equal, _Add, _Mul, _Pow +import hashlib + +x = Symbol('x') +y = Symbol('y') +z = Symbol('z') +a = Symbol('a') +b = Symbol('b') +c = Symbol('c') +f = Symbol('f') +t = Symbol('t') +k = Symbol('k') +n = Symbol('n') +theta = Symbol('theta') + +# shorthand definitions + + +def _Abs(a): + return Abs(a, evaluate=False) + + +def _factorial(a): + return factorial(a, evaluate=False) + + +def _log(a, b): + return log(a, b, evaluate=False) + + +def pytest_generate_tests(metafunc): + metafunc.parametrize('s, eq', metafunc.cls.GOOD_PAIRS) + + +class TestAllGood(object): + # These latex strings should parse to the corresponding SymPy expression + GOOD_PAIRS = [ + ("0", 0), + ("1", 1), + ("-3.14", -3.14), + ("5-3", _Add(5, -3)), + ("(-7.13)(1.5)", _Mul(Rational('-7.13'), Rational('1.5'))), + ("\\left(-7.13\\right)\\left(1.5\\right)", _Mul(Rational('-7.13'), Rational('1.5'))), + ("x", x), + ("2x", 2 * x), + ("x^2", x**2), + ("x^{3 + 1}", x**_Add(3, 1)), + ("x^{\\left{3 + 1\\right}}", x**_Add(3, 1)), + ("-3y + 2x", _Add(_Mul(2, x), Mul(-1, 3, y, evaluate=False))), + ("-c", -c), + ("a \\cdot b", a * b), + ("a / b", a / b), + ("a \\div b", a / b), + ("a + b", a + b), + ("a + b - a", Add(a, b, _Mul(-1, a), evaluate=False)), + ("a^2 + b^2 = c^2", Eq(a**2 + b**2, c**2)), + ("a^2 + b^2 != 2c^2", Ne(a**2 + b**2, 2 * c**2)), + ("a\\mod b", Mod(a, b)), + ("\\sin \\theta", sin(theta)), + ("\\sin(\\theta)", sin(theta)), + ("\\sin\\left(\\theta\\right)", sin(theta)), + ("\\sin^{-1} a", asin(a)), + ("\\sin a \\cos b", _Mul(sin(a), cos(b))), + ("\\sin \\cos \\theta", sin(cos(theta))), + ("\\sin(\\cos \\theta)", sin(cos(theta))), + ("\\arcsin(a)", asin(a)), + ("\\arccos(a)", acos(a)), + ("\\arctan(a)", atan(a)), + ("\\sinh(a)", sinh(a)), + ("\\cosh(a)", cosh(a)), + ("\\tanh(a)", tanh(a)), + ("\\sinh^{-1}(a)", asinh(a)), + ("\\cosh^{-1}(a)", acosh(a)), + ("\\tanh^{-1}(a)", atanh(a)), + ("\\arcsinh(a)", asinh(a)), + ("\\arccosh(a)", acosh(a)), + ("\\arctanh(a)", atanh(a)), + ("\\arsinh(a)", asinh(a)), + ("\\arcosh(a)", acosh(a)), + ("\\artanh(a)", atanh(a)), + ("\\operatorname{arcsinh}(a)", asinh(a)), + ("\\operatorname{arccosh}(a)", acosh(a)), + ("\\operatorname{arctanh}(a)", atanh(a)), + ("\\operatorname{arsinh}(a)", asinh(a)), + ("\\operatorname{arcosh}(a)", acosh(a)), + ("\\operatorname{artanh}(a)", atanh(a)), + ("\\operatorname{gcd}(a, b)", UnevaluatedExpr(gcd(a, b))), + ("\\operatorname{lcm}(a, b)", UnevaluatedExpr(lcm(a, b))), + ("\\operatorname{gcd}(a,b)", UnevaluatedExpr(gcd(a, b))), + ("\\operatorname{lcm}(a,b)", UnevaluatedExpr(lcm(a, b))), + ("\\operatorname{floor}(a)", floor(a)), + ("\\operatorname{ceil}(b)", ceiling(b)), + ("\\cos^2(x)", cos(x)**2), + ("\\cos(x)^2", cos(x)**2), + ("\\gcd(a, b)", UnevaluatedExpr(gcd(a, b))), + ("\\lcm(a, b)", UnevaluatedExpr(lcm(a, b))), + ("\\gcd(a,b)", UnevaluatedExpr(gcd(a, b))), + ("\\lcm(a,b)", UnevaluatedExpr(lcm(a, b))), + ("\\floor(a)", floor(a)), + ("\\ceil(b)", ceiling(b)), + ("\\max(a, b)", Max(a, b)), + ("\\min(a, b)", Min(a, b)), + ("\\frac{a}{b}", a / b), + ("\\frac{a + b}{c}", _Mul(a + b, _Pow(c, -1))), + ("\\frac{7}{3}", _Mul(7, _Pow(3, -1))), + ("(\\csc x)(\\sec y)", csc(x) * sec(y)), + ("\\lim_{x \\to 3} a", Limit(a, x, 3)), + ("\\lim_{x \\rightarrow 3} a", Limit(a, x, 3)), + ("\\lim_{x \\Rightarrow 3} a", Limit(a, x, 3)), + ("\\lim_{x \\longrightarrow 3} a", Limit(a, x, 3)), + ("\\lim_{x \\Longrightarrow 3} a", Limit(a, x, 3)), + ("\\lim_{x \\to 3^{+}} a", Limit(a, x, 3, dir='+')), + ("\\lim_{x \\to 3^{-}} a", Limit(a, x, 3, dir='-')), + ("\\infty", oo), + ("\\infty\\%", oo), + ("\\$\\infty", oo), + ("-\\infty", -oo), + ("-\\infty\\%", -oo), + ("-\\$\\infty", -oo), + ("\\lim_{x \\to \\infty} \\frac{1}{x}", Limit(_Mul(1, _Pow(x, -1)), x, oo)), + ("\\frac{d}{dx} x", Derivative(x, x)), + ("\\frac{d}{dt} x", Derivative(x, t)), + # ("f(x)", f(x)), + # ("f(x, y)", f(x, y)), + # ("f(x, y, z)", f(x, y, z)), + # ("\\frac{d f(x)}{dx}", Derivative(f(x), x)), + # ("\\frac{d\\theta(x)}{dx}", Derivative(theta(x), x)), + ("|x|", _Abs(x)), + ("\\left|x\\right|", _Abs(x)), + ("||x||", _Abs(Abs(x))), + ("|x||y|", _Abs(x) * _Abs(y)), + ("||x||y||", _Abs(_Abs(x) * _Abs(y))), + ("\\lfloor x\\rfloor", floor(x)), + ("\\lceil y\\rceil", ceiling(y)), + ("\\pi^{|xy|}", pi**_Abs(x * y)), + ("\\frac{\\pi}{3}", _Mul(pi, _Pow(3, -1))), + ("\\sin{\\frac{\\pi}{2}}", sin(_Mul(pi, _Pow(2, -1)), evaluate=False)), + ("a+bI", a + I * b), + ("e^{I\\pi}", -1), + ("\\int x dx", Integral(x, x)), + ("\\int x d\\theta", Integral(x, theta)), + ("\\int (x^2 - y)dx", Integral(x**2 - y, x)), + ("\\int x + a dx", Integral(_Add(x, a), x)), + ("\\int da", Integral(1, a)), + ("\\int_0^7 dx", Integral(1, (x, 0, 7))), + ("\\int_a^b x dx", Integral(x, (x, a, b))), + ("\\int^b_a x dx", Integral(x, (x, a, b))), + ("\\int_{a}^b x dx", Integral(x, (x, a, b))), + ("\\int^{b}_a x dx", Integral(x, (x, a, b))), + ("\\int_{a}^{b} x dx", Integral(x, (x, a, b))), + ("\\int_{ }^{}x dx", Integral(x, x)), + ("\\int^{ }_{ }x dx", Integral(x, x)), + ("\\int^{b}_{a} x dx", Integral(x, (x, a, b))), + # ("\\int_{f(a)}^{f(b)} f(z) dz", Integral(f(z), (z, f(a), f(b)))), + ("\\int (x+a)", Integral(_Add(x, a), x)), + ("\\int a + b + c dx", Integral(Add(a, b, c, evaluate=False), x)), + ("\\int \\frac{dz}{z}", Integral(Pow(z, -1), z)), + ("\\int \\frac{3 dz}{z}", Integral(3 * Pow(z, -1), z)), + ("\\int \\frac{1}{x} dx", Integral(Mul(Integer(1), Pow(x, -1), evaluate=False), x)), + ("\\int \\frac{1}{a} + \\frac{1}{b} dx", Integral(_Add(Mul(1, Pow(a, -1), evaluate=False), Mul(1, Pow(b, -1), evaluate=False)), x)), + ("\\int \\frac{3 \\cdot d\\theta}{\\theta}", Integral(Mul(3, _Pow(theta, -1), evaluate=False), theta)), + ("\\int \\frac{1}{x} + 1 dx", Integral(_Add(Mul(1, Pow(x, -1), evaluate=False), 1), x)), + ("x_0", Symbol('x_0')), + ("x_{1}", Symbol('x_1')), + ("x_a", Symbol('x_a')), + ("x_{b}", Symbol('x_b')), + ("h_\\theta", Symbol('h_{\\theta}')), + ("h_\\theta ", Symbol('h_{\\theta}')), + ("h_{\\theta}", Symbol('h_{\\theta}')), + # ("h_{\\theta}(x_0, x_1)", Symbol('h_{theta}', real=True)(Symbol('x_{0}', real=True), Symbol('x_{1}', real=True))), + ("x!", _factorial(x)), + ("100!", _factorial(100)), + ("\\theta!", _factorial(theta)), + ("(x + 1)!", _factorial(_Add(x, 1))), + ("\\left(x + 1\\right)!", _factorial(_Add(x, 1))), + ("(x!)!", _factorial(_factorial(x))), + ("x!!!", _factorial(_factorial(_factorial(x)))), + ("5!7!", _Mul(_factorial(5), _factorial(7))), + ("\\sqrt{x}", sqrt(x)), + ("\\sqrt{x + b}", sqrt(_Add(x, b))), + ("\\sqrt[3]{\\sin x}", root(sin(x), 3)), + ("\\sqrt[y]{\\sin x}", root(sin(x), y)), + ("\\sqrt[\\theta]{\\sin x}", root(sin(x), theta)), + ("x < y", StrictLessThan(x, y)), + ("x \\leq y", LessThan(x, y)), + ("x > y", StrictGreaterThan(x, y)), + ("x \\geq y", GreaterThan(x, y)), + ("\\sum_{k = 1}^{3} c", Sum(c, (k, 1, 3))), + ("\\sum_{k = 1}^3 c", Sum(c, (k, 1, 3))), + ("\\sum^{3}_{k = 1} c", Sum(c, (k, 1, 3))), + ("\\sum^3_{k = 1} c", Sum(c, (k, 1, 3))), + ("\\sum_{k = 1}^{10} k^2", Sum(k**2, (k, 1, 10))), + ("\\sum_{n = 0}^{\\infty} \\frac{1}{n!}", Sum(_Pow(_factorial(n), -1), (n, 0, oo))), + ("\\prod_{a = b}^{c} x", Product(x, (a, b, c))), + ("\\prod_{a = b}^c x", Product(x, (a, b, c))), + ("\\prod^{c}_{a = b} x", Product(x, (a, b, c))), + ("\\prod^c_{a = b} x", Product(x, (a, b, c))), + ("\\ln x", _log(x, E)), + ("\\ln xy", _log(x * y, E)), + ("\\log x", _log(x, 10)), + ("\\log xy", _log(x * y, 10)), + # ("\\log_2 x", _log(x, 2)), + ("\\log_{2} x", _log(x, 2)), + # ("\\log_a x", _log(x, a)), + ("\\log_{a} x", _log(x, a)), + ("\\log_{11} x", _log(x, 11)), + ("\\log_{a^2} x", _log(x, _Pow(a, 2))), + ("[x]", x), + ("[a + b]", _Add(a, b)), + ("\\frac{d}{dx} [ \\tan x ]", Derivative(tan(x), x)), + ("2\\overline{x}", 2 * Symbol('bar{x}')), + ("2\\overline{x}_n", 2 * Symbol('bar{x}_n')), + ("\\frac{x}{\\overline{x}_n}", x / Symbol('bar{x}_n')), + ("\\frac{\\sin(x)}{\\overline{x}_n}", sin(Symbol('x')) / Symbol('bar{x}_n')), + ("2\\bar{x}", 2 * Symbol('bar{x}')), + ("2\\bar{x}_n", 2 * Symbol('bar{x}_n')), + ("\\sin\\left(\\theta\\right) \\cdot4", sin(theta) * 4), + ("\\ln\\left(\\theta\\right)", _log(theta, E)), + ("\\ln\\left(x-\\theta\\right)", _log(x - theta, E)), + ("\\ln\\left(\\left(x-\\theta\\right)\\right)", _log(x - theta, E)), + # ("\\ln\\left(\\left\\{x-\\theta\\right\\}\\right)", _log(x - theta, E)), + # ("\\ln\\left(\\left|x-\\theta\\right|\\right)", _log(_Abs(x - theta), E)), + ("\\frac{1}{2}xy(x+y)", Mul(Rational(1, 2), x, y, (x + y), evaluate=False)), + ("\\frac{1}{2}\\theta(x+y)", Mul(Rational(1, 2), theta, (x + y), evaluate=False)), + ("1-f(x)", 1 - Function('f')(x)), + + ("\\begin{matrix}1&2\\\\3&4\\end{matrix}", Matrix([[1, 2], [3, 4]])), + ("\\begin{matrix}x&x^2\\\\\\sqrt{x}&x\\end{matrix}", Matrix([[x, x**2], [_Pow(x, S.Half), x]])), + ("\\begin{matrix}\\sqrt{x}\\\\\\sin(\\theta)\\end{matrix}", Matrix([_Pow(x, S.Half), sin(theta)])), + ("\\begin{pmatrix}1&2\\\\3&4\\end{pmatrix}", Matrix([[1, 2], [3, 4]])), + ("\\begin{bmatrix}1&2\\\\3&4\\end{bmatrix}", Matrix([[1, 2], [3, 4]])), + + # scientific notation + ("2.5\\times 10^2", 250), + ("1500\\times 10^{-1}", 150), + + # e notation + ("2.5E2", 250), + ("1500E-1", 150), + + # multiplication without cmd + ("2x2y", Mul(2, x, 2, y, evaluate=False)), + ("2x2", Mul(2, x, 2, evaluate=False)), + ("x2", x * 2), + + # lin alg processing + ("\\theta\\begin{matrix}1&2\\\\3&4\\end{matrix}", MatMul(theta, Matrix([[1, 2], [3, 4]]), evaluate=False)), + ("\\theta\\begin{matrix}1\\\\3\\end{matrix} - \\begin{matrix}-1\\\\2\\end{matrix}", MatAdd(MatMul(theta, Matrix([[1], [3]]), evaluate=False), MatMul(-1, Matrix([[-1], [2]]), evaluate=False), evaluate=False)), + ("\\theta\\begin{matrix}1&0\\\\0&1\\end{matrix}*\\begin{matrix}3\\\\-2\\end{matrix}", MatMul(theta, Matrix([[1, 0], [0, 1]]), Matrix([3, -2]), evaluate=False)), + # ("\\frac{1}{9}\\theta\\begin{matrix}1&2\\\\3&4\\end{matrix}", MatMul(Mul(Integer(1), _Pow(9, -1)), theta, Matrix([[1, 2], [3, 4]]), evaluate=False)), + ("\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix};\\begin{pmatrix}4\\\\3\\\\1\\end{pmatrix}", FiniteSet(Matrix([1, 2, 3]), Matrix([4, 3, 1]))), + ("\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix};\\begin{pmatrix}4\\\\3\\\\1\\end{pmatrix}", FiniteSet(Matrix([1, 2, 3]), Matrix([4, 3, 1]))), + ("\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix};\\begin{pmatrix}4\\\\3\\\\1\\end{pmatrix}", FiniteSet(Matrix([1, 2, 3]), Matrix([4, 3, 1]))), + ("\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix};\\begin{pmatrix}4\\\\3\\\\1\\end{pmatrix};\\begin{pmatrix}1\\\\1\\\\1\\end{pmatrix}", FiniteSet(Matrix([1, 2, 3]), Matrix([4, 3, 1]), Matrix([1, 1, 1]))), + ("\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix}", Matrix([1, 2, 3])), + ("\\left{\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix}\\right}", FiniteSet(Matrix([1, 2, 3]))), + ("{\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix}}", FiniteSet(Matrix([1, 2, 3]))), + + # us dollars + ("\\$1,000.00", 1000), + ("\\$543.21", 543.21), + ("\\$0.009", 0.009), + + # percentages + ("100\\%", 1), + ("1.5\\%", 0.015), + ("0.05\\%", 0.0005), + + # empty set + ("\\emptyset", S.EmptySet) + ] + + def test_good_pair(self, s, eq): + assert_equal(s, eq) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/atom_expr_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/atom_expr_test.py new file mode 100644 index 000000000..9b2601bd9 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/atom_expr_test.py @@ -0,0 +1,59 @@ +import pytest +from latex2sympy2_extended import latex2sympy +from sympy import Symbol, Integer, Pow +from tests.context import assert_equal + +# label, text, symbol_text +symbols = [ + ('letter', 'x', 'x'), + ('greek letter', '\\lambda', 'lambda'), + ('greek letter w/ space', '\\alpha ', 'alpha'), + ('accented letter', '\\overline{x}', 'bar{x}') +] + +subscripts = [ + ('2'), + ('{23}'), + ('i'), + ('{ij}'), + ('{i,j}'), + ('{good}'), + ('{x^2}') +] + +examples = [] +for symbol in symbols: + for subscript in subscripts: + examples.append(tuple(list(symbol) + [subscript])) + + +@pytest.mark.parametrize('label, text, symbol_text, subscript', examples) +def test_with_supexpr(label, text, symbol_text, subscript): + assert_equal(text + '^2', Pow(Symbol(symbol_text), Integer(2))) + + +@pytest.mark.parametrize('label, text, symbol_text, subscript', examples) +def test_with_subexpr(label, text, symbol_text, subscript): + assert_equal(text + '_' + subscript, Symbol(symbol_text + '_' + subscript)) + + +@pytest.mark.parametrize('label, text, symbol_text, subscript', examples) +def test_with_subexpr_before_supexpr(label, text, symbol_text, subscript): + assert_equal(text + '_' + subscript + '^2', Pow(Symbol(symbol_text + '_' + subscript), Integer(2))) + + +@pytest.mark.parametrize('label, text, symbol_text, subscript', examples) +def test_with_subexpr_before_supexpr_with_braces(label, text, symbol_text, subscript): + wrapped_subscript = subscript if '{' in subscript else '{' + subscript + '}' + assert_equal(text + '_' + wrapped_subscript + '^{2}', Pow(Symbol(symbol_text + '_' + subscript), Integer(2))) + + +@pytest.mark.parametrize('label, text, symbol_text, subscript', examples) +def test_with_supexpr_before_subexpr(label, text, symbol_text, subscript): + assert_equal(text + '^2_' + subscript, Pow(Symbol(symbol_text + '_' + subscript), Integer(2))) + + +@pytest.mark.parametrize('label, text, symbol_text, subscript', examples) +def test_with_supexpr_before_subexpr_with_braces(label, text, symbol_text, subscript): + wrapped_subscript = subscript if '{' in subscript else '{' + subscript + '}' + assert_equal(text + '^{2}_' + wrapped_subscript, Pow(Symbol(symbol_text + '_' + subscript), Integer(2))) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/binomial_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/binomial_test.py new file mode 100644 index 000000000..38b88f0fa --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/binomial_test.py @@ -0,0 +1,37 @@ +import pytest +from latex2sympy2_extended import latex2sympy +from sympy import binomial, Symbol +from tests.context import assert_equal, _Add, _Mul, _Pow + +x = Symbol('x') +y = Symbol('y') +theta = Symbol('theta') +gamma = Symbol('gamma') + + +def test_binomial_numeric(): + assert_equal("\\binom{16}{2}", binomial(16, 2)) + + +def test_binomial_symbols(): + assert_equal("\\binom{x}{y}", binomial(x, y)) + + +def test_binomial_greek_symbols(): + assert_equal("\\binom{\\theta}{\\gamma}", binomial(theta, gamma)) + + +def test_binomial_expr(): + assert_equal("\\binom{16+2}{\\frac{4}{2}}", binomial(_Add(16, 2), _Mul(4, _Pow(2, -1)), evaluate=False)) + + +def test_choose_numeric(): + assert_equal("{16 \\choose 2}", binomial(16, 2)) + + +def test_choose_symbols(): + assert_equal("{x \\choose y}", binomial(x, y)) + + +def test_choose_greek_symbols(): + assert_equal("{\\theta \\choose \\gamma}", binomial(theta, gamma)) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/boxed_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/boxed_test.py new file mode 100644 index 000000000..e40519c72 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/boxed_test.py @@ -0,0 +1,21 @@ +import pytest +from latex2sympy2_extended import latex2sympy +from sympy import FiniteSet, Tuple +from tests.context import assert_equal, get_simple_examples + +examples = get_simple_examples(lambda x: x) + + +@pytest.mark.parametrize('input, output, symbolically', examples) +def test_boxed_func(input, output, symbolically): + assert_equal("\\boxed{{{input}}}".format(input=input), output, symbolically=symbolically) + + +@pytest.mark.parametrize('input, output, symbolically', [ + ("\\boxed{1,2,3}", FiniteSet(1, 2, 3), False), + ("\\boxed{(1,2,3)}", Tuple(1, 2, 3), False), + ("\\boxed{\\{1,2,3\\}}", FiniteSet(1, 2, 3), False), +]) + +def test_boxed_func_with_braces(input, output, symbolically): + assert_equal(input, output, symbolically=symbolically) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/ceil_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/ceil_test.py new file mode 100644 index 000000000..a93e494be --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/ceil_test.py @@ -0,0 +1,30 @@ +import pytest +from latex2sympy2_extended import latex2sympy +from sympy import ceiling +from tests.context import assert_equal, get_simple_examples + +examples = get_simple_examples(ceiling) + + +@pytest.mark.parametrize('input, output, symbolically', examples) +def test_ceil_func(input, output, symbolically): + assert_equal("\\ceil({input})".format(input=input), output, symbolically=symbolically) + + +@pytest.mark.parametrize('input, output, symbolically', examples) +def test_ceil_operatorname(input, output, symbolically): + assert_equal("\\operatorname{{ceil}}({input})".format(input=input), output, symbolically=symbolically) + + +@pytest.mark.parametrize('input, output, symbolically', examples) +def test_ceil_cmd(input, output, symbolically): + assert_equal("\\lceil {input}\\rceil".format(input=input), output, symbolically=symbolically) + assert_equal("\\left\\lceil {input}\\right\\rceil".format(input=input), output, symbolically=symbolically) + assert_equal("\\mleft\\lceil {input}\\mright\\rceil".format(input=input), output, symbolically=symbolically) + + +@pytest.mark.parametrize('input, output, symbolically', examples) +def test_ceil_corners(input, output, symbolically): + assert_equal("\\ulcorner {input}\\urcorner".format(input=input), output, symbolically=symbolically) + assert_equal("\\left\\ulcorner {input}\\right\\urcorner".format(input=input), output, symbolically=symbolically) + assert_equal("\\mleft\\ulcorner {input}\\mright\\urcorner".format(input=input), output, symbolically=symbolically) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/complex_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/complex_test.py new file mode 100644 index 000000000..1566a1918 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/complex_test.py @@ -0,0 +1,22 @@ +import pytest +from latex2sympy2_extended import latex2sympy +from sympy import Sum, I, Symbol, Integer +from tests.context import assert_equal + +a = Symbol('a') +b = Symbol('b') +i = Symbol('i') +n = Symbol('n') +x = Symbol('x') + + +def test_complex(): + assert_equal("a+Ib", a + I * b) + + +def test_complex_e(): + assert_equal("e^{I\\pi}", Integer(-1)) + + +def test_complex_sum(): + assert_equal("\\sum_{i=0}^{n} i \\cdot x", Sum(i * x, (i, 0, n))) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/context.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/context.py new file mode 100644 index 000000000..b2ccd7e05 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/context.py @@ -0,0 +1,50 @@ +from latex2sympy2_extended import latex2sympy +from sympy import ( + Add, Mul, Pow, Symbol, Rational, pi, sqrt, + simplify, srepr +) + +def _Add(a, b): + return Add(a, b, evaluate=False) + +def _Mul(a, b): + return Mul(a, b, evaluate=False) + +def _Pow(a, b): + return Pow(a, b, evaluate=False) + +def assert_equal(latex, expr, symbolically=False, variables: dict | None = None): + parsed = latex2sympy(latex, variable_values=variables) + if symbolically: + assert simplify(parsed - expr) == 0 + else: + actual_exp_tree = srepr(parsed) + expected_exp_tree = srepr(expr) + try: + assert actual_exp_tree == expected_exp_tree + except Exception: + if (isinstance(parsed, (int, float)) or parsed.is_number) and \ + (isinstance(expr, (int, float)) or expr.is_number): + assert parsed == expr or parsed - expr == 0 or simplify(parsed - expr) == 0 + else: + print('expected_exp_tree = ', expected_exp_tree) + print('actual exp tree = ', actual_exp_tree) + raise + +def get_simple_examples(func): + ''' + Returns an array of tuples, containing the string `input`, sympy `output` using the provided sympy `func`, and `symbolically` boolean + for calling `compare`. + ''' + x = Symbol('x', real=None) + y = Symbol('y', real=None) + return [ + ("1.1", func(1.1), False), + ("6.9", func(6.9), False), + ("3.5", func(3.5), False), + ("8", func(8), False), + ("0", func(0), False), + ("x", func(x), True), + ("x + y", func(x + y), True), + ("2y-y-y", func(2 * y - y - y), True) + ] diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/exp_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/exp_test.py new file mode 100644 index 000000000..b79ac773f --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/exp_test.py @@ -0,0 +1,58 @@ +import pytest +from latex2sympy2_extended import latex2sympy +from sympy import exp, sin, Symbol, E +from tests.context import assert_equal + +x = Symbol('x') +y = Symbol('y') + + +def test_exp_letter(): + assert_equal("e", E) + assert_equal("e", exp(1)) + + +def test_exp_func(): + assert_equal("\\exp(3)", exp(3)) + + +def test_exp_func_no_delim(): + assert_equal("\\exp3", exp(3)) + + +def test_exp_command_symbol(): + assert_equal("\\exponentialE", E) + assert_equal("\\exponentialE", exp(1)) + + +def test_exp_command_symbol_expression(): + assert_equal("\\exponentialE^{3}", exp(3)) + + +def test_exp_command_symbol_multiplied(): + ''' + \\exponentialE is NOT a function, so using the following notation equates to multiplication + ''' + assert_equal("\\exponentialE (3)", E * 3) + assert_equal("\\exponentialE \\left( 3\\right)", E * 3) + assert_equal("\\exponentialE \\times 3", E * 3) + + +def test_exp_numeric(): + assert_equal("e^3", exp(3)) + + +def test_exp_symbol(): + assert_equal("e^x", exp(x)) + + +def test_exp_symbol_expr(): + assert_equal("e^{x+y}", exp(x + y)) + + +def test_exp_symbol_expr_group(): + assert_equal("e^{(x+y)}", exp(x + y)) + + +def test_exp_expr(): + assert_equal("\\sin(x)*e^x", sin(x) * exp(x)) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/floor_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/floor_test.py new file mode 100644 index 000000000..b709ceb6d --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/floor_test.py @@ -0,0 +1,30 @@ +import pytest +from latex2sympy2_extended import latex2sympy +from sympy import floor +from tests.context import assert_equal, get_simple_examples + +examples = get_simple_examples(floor) + + +@pytest.mark.parametrize('input, output, symbolically', examples) +def test_floor_func(input, output, symbolically): + assert_equal("\\floor({input})".format(input=input), output, symbolically=symbolically) + + +@pytest.mark.parametrize('input, output, symbolically', examples) +def test_floor_operatorname(input, output, symbolically): + assert_equal("\\operatorname{{floor}}({input})".format(input=input), output, symbolically=symbolically) + + +@pytest.mark.parametrize('input, output, symbolically', examples) +def test_floor_cmd(input, output, symbolically): + assert_equal("\\lfloor {input}\\rfloor".format(input=input), output, symbolically=symbolically) + assert_equal("\\left\\lfloor {input}\\right\\rfloor".format(input=input), output, symbolically=symbolically) + assert_equal("\\mleft\\lfloor {input}\\mright\\rfloor".format(input=input), output, symbolically=symbolically) + + +@pytest.mark.parametrize('input, output, symbolically', examples) +def test_floor_corners(input, output, symbolically): + assert_equal("\\llcorner {input}\\lrcorner".format(input=input), output, symbolically=symbolically) + assert_equal("\\left\\llcorner {input}\\right\\lrcorner".format(input=input), output, symbolically=symbolically) + assert_equal("\\mleft\\llcorner {input}\\mright\\lrcorner".format(input=input), output, symbolically=symbolically) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/gcd_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/gcd_test.py new file mode 100644 index 000000000..943ceebc8 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/gcd_test.py @@ -0,0 +1,162 @@ +import pytest +from latex2sympy2_extended import latex2sympy +from sympy import Symbol, Rational, UnevaluatedExpr, gcd, igcd, sqrt, pi +from tests.context import assert_equal, _Add, _Mul, _Pow + +x = Symbol('x') +y = Symbol('y') +z = Symbol('z') + + +def test_gcd_usual(): + assert_equal("\\gcd(18, 3)", gcd(18, 3)) + assert_equal("\\gcd(3, 18)", gcd(3, 18)) + assert_equal("\\gcd(2, 2)", gcd(2, 2)) + assert_equal("\\gcd(0, 21)", UnevaluatedExpr(gcd(0, 21))) + assert_equal("\\gcd(21, 0)", UnevaluatedExpr(gcd(21, 0))) + assert_equal("\\gcd(0, 0)", UnevaluatedExpr(gcd(0, 0))) + assert_equal("\\gcd(6128, 24)", gcd(6128, 24)) + assert_equal("\\gcd(24, 6128)", gcd(24, 6128)) + assert_equal("\\gcd(1E20, 1000000)", gcd(Rational('1E20'), 1000000)) + assert_equal("\\gcd(128*10^32, 1)", gcd(Rational('128E32'), 1)) + + assert_equal("\\operatorname{gcd}(18, 3)", gcd(18, 3)) + assert_equal("\\operatorname{gcd}(3, 18)", gcd(3, 18)) + assert_equal("\\operatorname{gcd}(2, 2)", gcd(2, 2)) + assert_equal("\\operatorname{gcd}(0, 21)", UnevaluatedExpr(gcd(0, 21))) + assert_equal("\\operatorname{gcd}(21, 0)", UnevaluatedExpr(gcd(21, 0))) + assert_equal("\\operatorname{gcd}(0, 0)", UnevaluatedExpr(gcd(0, 0))) + assert_equal("\\operatorname{gcd}(6128, 24)", gcd(6128, 24)) + assert_equal("\\operatorname{gcd}(24, 6128)", gcd(24, 6128)) + assert_equal("\\operatorname{gcd}(1E20, 1000000)", gcd(Rational('1E20'), 1000000)) + assert_equal("\\operatorname{gcd}(128*10^32, 1)", gcd(Rational('128E32'), 1)) + + +def test_gcd_negative(): + assert_equal("\\gcd(-12, 4)", gcd(-12, 4)) + assert_equal("\\gcd(219, -9)", gcd(219, -9)) + assert_equal("\\gcd(-8, -64)", gcd(-8, -64)) + assert_equal("\\gcd(-5, -5)", gcd(-5, -5)) + assert_equal("\\gcd(-1, 182033)", gcd(-1, 182033)) + assert_equal("\\gcd(25, -6125)", gcd(25, -6125)) + assert_equal("\\gcd(243, -2.9543127E21)", gcd(243, Rational('-2.9543127E21'))) + + assert_equal("\\operatorname{gcd}(-12, 4)", gcd(-12, 4)) + assert_equal("\\operatorname{gcd}(219, -9)", gcd(219, -9)) + assert_equal("\\operatorname{gcd}(-8, -64)", gcd(-8, -64)) + assert_equal("\\operatorname{gcd}(-5, -5)", gcd(-5, -5)) + assert_equal("\\operatorname{gcd}(-1, 182033)", gcd(-1, 182033)) + assert_equal("\\operatorname{gcd}(25, -6125)", gcd(25, -6125)) + assert_equal("\\operatorname{gcd}(243, -2.9543127E21)", gcd(243, Rational('-2.9543127E21'))) + + +def test_gcd_float(): + assert_equal("\\gcd(2.4, 3.6)", gcd(Rational('2.4'), Rational('3.6'))) + assert_equal("\\gcd(3.6, 2.4)", gcd(Rational('3.6'), Rational('2.4'))) + assert_equal("\\gcd(\\pi, 3)", gcd(pi, 3)) + assert_equal("\\gcd(618, 1.5)", gcd(618, Rational('1.5'))) + assert_equal("\\gcd(-1.5, 618)", gcd(Rational('-1.5'), 618)) + assert_equal("\\gcd(0.42, 2)", gcd(Rational('0.42'), 2)) + assert_equal("\\gcd(1.43E-13, 21)", gcd(Rational('1.43E-13'), 21)) + assert_equal("\\gcd(21, -143E-13)", gcd(21, Rational('-143E-13'))) + assert_equal("\\gcd(9.80655, 9.80655)", gcd(Rational('9.80655'), Rational('9.80655'))) + assert_equal("\\gcd(0.0000923423, -8341.234802909)", gcd(Rational('0.0000923423'), Rational('-8341.234802909'))) + assert_equal("\\gcd(\\sqrt{5}, \\sqrt{2})", gcd(sqrt(5), sqrt(2))) + + assert_equal("\\operatorname{gcd}(2.4, 3.6)", gcd(Rational('2.4'), Rational('3.6'))) + assert_equal("\\operatorname{gcd}(3.6, 2.4)", gcd(Rational('3.6'), Rational('2.4'))) + assert_equal("\\operatorname{gcd}(\\pi, 3)", gcd(pi, 3)) + assert_equal("\\operatorname{gcd}(618, 1.5)", gcd(618, Rational('1.5'))) + assert_equal("\\operatorname{gcd}(-1.5, 618)", gcd(Rational('-1.5'), 618)) + assert_equal("\\operatorname{gcd}(0.42, 2)", gcd(Rational('0.42'), 2)) + assert_equal("\\operatorname{gcd}(1.43E-13, 21)", gcd(Rational('1.43E-13'), 21)) + assert_equal("\\operatorname{gcd}(21, -143E-13)", gcd(21, Rational('-143E-13'))) + assert_equal("\\operatorname{gcd}(9.80655, 9.80655)", gcd(Rational('9.80655'), Rational('9.80655'))) + assert_equal("\\operatorname{gcd}(0.0000923423, -8341.234802909)", gcd(Rational('0.0000923423'), Rational('-8341.234802909'))) + assert_equal("\\operatorname{gcd}(\\sqrt{5}, \\sqrt{2})", gcd(sqrt(5), sqrt(2))) + + +def test_gcd_fraction(): + assert_equal("\\gcd(1/2, 3)", gcd(Rational('1/2'), 3)) + assert_equal("\\gcd(3, 1/2)", gcd(3, Rational('1/2'))) + assert_equal("\\gcd(6/2, 3)", gcd(Rational('6/2'), 3)) + assert_equal("\\gcd(1/10, 1/10)", gcd(Rational('1/10'), Rational('1/10'))) + assert_equal("\\gcd(42, 42/6)", gcd(42, Rational('42/6'))) + assert_equal("\\gcd(10000000/10, 10000)", gcd(Rational('10000000/10'), 10000)) + + assert_equal("\\operatorname{gcd}(1/2, 3)", gcd(Rational('1/2'), 3)) + assert_equal("\\operatorname{gcd}(3, 1/2)", gcd(3, Rational('1/2'))) + assert_equal("\\operatorname{gcd}(6/2, 3)", gcd(Rational('6/2'), 3)) + assert_equal("\\operatorname{gcd}(1/10, 1/10)", gcd(Rational('1/10'), Rational('1/10'))) + assert_equal("\\operatorname{gcd}(42, 42/6)", gcd(42, Rational('42/6'))) + assert_equal("\\operatorname{gcd}(10000000/10, 10000)", gcd(Rational('10000000/10'), 10000)) + + +def test_gcd_expr(): + assert_equal("\\gcd(1+1, 8)", gcd(1 + 1, 8)) + assert_equal("920*\\gcd(9, 12*4/2)", 920 * gcd(9, 12 * Rational('4/2'))) + assert_equal("\\gcd(32-128, 10)*22", gcd(32 - 128, 10) * 22) + assert_equal("\\sqrt{\\gcd(1.25E24, 1E12)}", sqrt(gcd(Rational('1.25E24'), Rational('1E12')))) + assert_equal("\\gcd(92.0, 000+2)", gcd(Rational('92.0'), 000 + 2)) + + assert_equal("\\operatorname{gcd}(1+1, 8)", gcd(1 + 1, 8)) + assert_equal("920*\\operatorname{gcd}(9, 12*4/2)", 920 * gcd(9, 12 * Rational('4/2'))) + assert_equal("\\operatorname{gcd}(32-128, 10)*22", gcd(32 - 128, 10) * 22) + assert_equal("\\sqrt{\\operatorname{gcd}(1.25E24, 1E12)}", sqrt(gcd(Rational('1.25E24'), Rational('1E12')))) + assert_equal("\\operatorname{gcd}(92.0, 000+2)", gcd(Rational('92.0'), 000 + 2)) + + +def test_gcd_symbol(): + assert_equal("\\gcd(x, y)", gcd(x, y), symbolically=True) + assert_equal("\\gcd(y, -x)", gcd(y, -x), symbolically=True) + assert_equal("\\gcd(2y, x)", gcd(2 * y, x), symbolically=True) + assert_equal("\\gcd(125, 50x)", gcd(125, 50 * x), symbolically=True) + assert_equal("\\gcd(x + 76, \\sqrt{x} * 4)", gcd(x + 76, sqrt(x) * 4), symbolically=True) + assert_equal("\\gcd(y, y)", gcd(y, y), symbolically=True) + assert_equal("y + \\gcd(0.4x, 8/3) / 2", y + gcd(Rational('0.4') * x, Rational('8/3')) / 2, symbolically=True) + assert_equal("6.673E-11 * (\\gcd(8.85418782E-12, 9x) + 4) / 8y", Rational('6.673E-11') * (gcd(Rational('8.85418782E-12'), 9 * x) + 4) / (8 * y), symbolically=True) + + assert_equal("\\operatorname{gcd}(x, y)", gcd(x, y), symbolically=True) + assert_equal("\\operatorname{gcd}(y, -x)", gcd(y, -x), symbolically=True) + assert_equal("\\operatorname{gcd}(2y, x)", gcd(2 * y, x), symbolically=True) + assert_equal("\\operatorname{gcd}(125, 50x)", gcd(125, 50 * x), symbolically=True) + assert_equal("\\operatorname{gcd}(x + 76, \\sqrt{x} * 4)", gcd(x + 76, sqrt(x) * 4), symbolically=True) + assert_equal("\\operatorname{gcd}(y, y)", gcd(y, y), symbolically=True) + assert_equal("y + \\operatorname{gcd}(0.4x, 8/3) / 2", y + gcd(Rational('0.4') * x, Rational('8/3')) / 2, symbolically=True) + assert_equal("6.673E-11 * (\\operatorname{gcd}(8.85418782E-12, 9x) + 4) / 8y", Rational('6.673E-11') * (gcd(Rational('8.85418782E-12'), 9 * x) + 4) / (8 * y), symbolically=True) + + +def test_multiple_parameters(): + assert_equal("\\gcd(830,450)", gcd(830, 450)) + assert_equal("\\gcd(6,321,429)", igcd(6, 321, 429)) + assert_equal("\\gcd(14,2324)", gcd(14, 2324)) + assert_equal("\\gcd(3, 6, 2)", igcd(3, 6, 2)) + assert_equal("\\gcd(144, 2988, 37116)", igcd(144, 2988, 37116)) + assert_equal("\\gcd(144,2988, 37116,18, 72)", igcd(144, 2988, 37116, 18, 72)) + assert_equal("\\gcd(144, 2988, 37116, 18, 72, 12, 6)", igcd(144, 2988, 37116, 18, 72, 12, 6)) + assert_equal("\\gcd(32)", gcd(32, 32)) + assert_equal("\\gcd(-8, 4,-2)", gcd(-8, gcd(4, -2))) + assert_equal("\\gcd(x, y,z)", gcd(x, gcd(y, z)), symbolically=True) + assert_equal("\\gcd(6*4,48, 3)", igcd(6 * 4, 48, 3)) + assert_equal("\\gcd(6*4,48,3)", igcd(6 * 4, 48, 3)) + assert_equal("\\gcd(2.4,3.6, 0.6)", gcd(Rational('2.4'), gcd(Rational('3.6'), Rational('0.6')))) + assert_equal("\\gcd(2.4,3.6,0.6)", gcd(Rational('2.4'), gcd(Rational('3.6'), Rational('0.6')))) + assert_equal("\\gcd(\\sqrt{3},\\sqrt{2}, \\sqrt{100})", gcd(sqrt(3), gcd(sqrt(2), sqrt(100)))) + assert_equal("\\gcd(1E12, 1E6,1E3, 10)", igcd(Rational('1E12'), Rational('1E6'), Rational('1E3'), 10)) + + assert_equal("\\operatorname{gcd}(830,450)", gcd(830, 450)) + assert_equal("\\operatorname{gcd}(6,321,429)", igcd(6, 321, 429)) + assert_equal("\\operatorname{gcd}(14,2324)", gcd(14, 2324)) + assert_equal("\\operatorname{gcd}(3, 6, 2)", igcd(3, 6, 2)) + assert_equal("\\operatorname{gcd}(144, 2988, 37116)", igcd(144, 2988, 37116)) + assert_equal("\\operatorname{gcd}(144,2988, 37116,18, 72)", igcd(144, 2988, 37116, 18, 72)) + assert_equal("\\operatorname{gcd}(144, 2988, 37116, 18, 72, 12, 6)", igcd(144, 2988, 37116, 18, 72, 12, 6)) + assert_equal("\\operatorname{gcd}(32)", gcd(32, 32)) + assert_equal("\\operatorname{gcd}(-8, 4,-2)", gcd(-8, gcd(4, -2))) + assert_equal("\\operatorname{gcd}(x, y,z)", gcd(x, gcd(y, z)), symbolically=True) + assert_equal("\\operatorname{gcd}(6*4,48, 3)", igcd(6 * 4, 48, 3)) + assert_equal("\\operatorname{gcd}(6*4,48,3)", igcd(6 * 4, 48, 3)) + assert_equal("\\operatorname{gcd}(2.4,3.6, 0.6)", gcd(Rational('2.4'), gcd(Rational('3.6'), Rational('0.6')))) + assert_equal("\\operatorname{gcd}(2.4,3.6,0.6)", gcd(Rational('2.4'), gcd(Rational('3.6'), Rational('0.6')))) + assert_equal("\\operatorname{gcd}(\\sqrt{3},\\sqrt{2}, \\sqrt{100})", gcd(sqrt(3), gcd(sqrt(2), sqrt(100)))) + assert_equal("\\operatorname{gcd}(1E12, 1E6,1E3, 10)", igcd(Rational('1E12'), Rational('1E6'), Rational('1E3'), 10)) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/greek_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/greek_test.py new file mode 100644 index 000000000..a12f11816 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/greek_test.py @@ -0,0 +1,19 @@ +import pytest +from latex2sympy2_extended import latex2sympy +from sympy import Symbol +from tests.context import assert_equal + +epsilon_lower = Symbol('epsilon') +varepsilon = Symbol('varepsilon') + + +def test_greek_epsilon(): + assert_equal("\\epsilon", epsilon_lower) + + +def test_greek_epsilon_upper(): + assert_equal('\\char"000190', epsilon_lower) + + +def test_greek_varepsilon(): + assert_equal('\\varepsilon', varepsilon) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/grouping_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/grouping_test.py new file mode 100644 index 000000000..a5d0977ff --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/grouping_test.py @@ -0,0 +1,52 @@ +from latex2sympy2_extended import latex2sympy +import pytest +from sympy import Integral, sin, Symbol, Mul, Integer, Pow +from tests.context import assert_equal, _Add, _Mul, _Pow + +a = Symbol('a') +b = Symbol('b') +x = Symbol('x') +theta = Symbol('theta') + + +func_arg_examples = [ + ('\\int ', 'x dx', Integral(x, x)), + ('\\sin', '\\theta ', sin(theta)) +] + +example_groups = [ + ('1+2', '3-4', _Mul(_Add(1, 2), _Add(3, _Mul(-1, 4)))) +] + +modifiable_delimiter_pairs = { + '(': ')', + '\\lgroup': '\\rgroup', + '\\{': '\\}', + '\\lbrace': '\\rbrace', + '[': ']', + '\\lbrack': '\\rbrack', +} + + +@pytest.mark.parametrize('func, args, output', func_arg_examples) +def test_func_arg_groupings(func, args, output): + # none + assert_equal("{func} {args}".format(func=func, args=args), output) + # normal brace (not modifiable) + assert_equal("{func}{{{args}}}".format(func=func, args=args), output) + # rest of delimiters, with modifications + for left, right in modifiable_delimiter_pairs.items(): + assert_equal("{func}{left}{args}{right}".format(left=left, right=right, func=func, args=args), output) + assert_equal("{func}\\left{left}{args}\\right{right}".format(left=left, right=right, func=func, args=args), output) + assert_equal("{func}\\mleft{left}{args}\\mright{right}".format(left=left, right=right, func=func, args=args), output) + + +@pytest.mark.parametrize('group1, group2, output', example_groups) +def test_delimiter_groupings(group1, group2, output): + # normal brace (not modifiable) + assert_equal("{{{group1}}}{{{group2}}}".format(group1=group1, group2=group2), output) + # rest of delimiters, with modifications + for left, right in modifiable_delimiter_pairs.items(): + assert_equal("{left}{group1}{right}{left}{group2}{right}".format(left=left, right=right, group1=group1, group2=group2), output) + assert_equal("\\left{left}{group1}\\right{right}\\left{left}{group2}\\right{right}".format(left=left, right=right, group1=group1, group2=group2), output) + assert_equal("\\mleft{left}{group1}\\mright{right}\\mleft{left}{group2}\\mright{right}".format(left=left, right=right, group1=group1, group2=group2), output) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/lcm_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/lcm_test.py new file mode 100644 index 000000000..7fef3d003 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/lcm_test.py @@ -0,0 +1,161 @@ +from tests.context import assert_equal +import pytest +from sympy import Number, Symbol, Rational, UnevaluatedExpr, lcm, ilcm, sqrt, pi + +x = Symbol('x') +y = Symbol('y') +z = Symbol('z') + + +def test_lcm_usual(): + assert_equal("\\lcm(6, 4)", lcm(6, 4)) + assert_equal("\\lcm(4, 6)", lcm(4, 6)) + assert_equal("\\lcm(2, 2)", lcm(2, 2)) + assert_equal("\\lcm(0, 21)", UnevaluatedExpr(lcm(0, 21))) + assert_equal("\\lcm(21, 0)", UnevaluatedExpr(lcm(21, 0))) + assert_equal("\\lcm(0, 0)", UnevaluatedExpr(lcm(0, 0))) + assert_equal("\\lcm(9, 21)", lcm(9, 21)) + assert_equal("\\lcm(6128, 24)", lcm(6128, 24)) + assert_equal("\\lcm(24, 6128)", lcm(24, 6128)) + assert_equal("\\lcm(1E20, 1000000)", lcm(Rational('1E20'), 1000000)) + assert_equal("\\lcm(128*10^32, 1)", lcm(Rational('128E32'), 1)) + + assert_equal("\\operatorname{lcm}(6, 4)", lcm(6, 4)) + assert_equal("\\operatorname{lcm}(4, 6)", lcm(4, 6)) + assert_equal("\\operatorname{lcm}(2, 2)", lcm(2, 2)) + assert_equal("\\operatorname{lcm}(0, 21)", UnevaluatedExpr(lcm(0, 21))) + assert_equal("\\operatorname{lcm}(21, 0)", UnevaluatedExpr(lcm(21, 0))) + assert_equal("\\operatorname{lcm}(0, 0)", UnevaluatedExpr(lcm(0, 0))) + assert_equal("\\operatorname{lcm}(9, 21)", lcm(9, 21)) + assert_equal("\\operatorname{lcm}(6128, 24)", lcm(6128, 24)) + assert_equal("\\operatorname{lcm}(24, 6128)", lcm(24, 6128)) + assert_equal("\\operatorname{lcm}(1E20, 1000000)", lcm(Rational('1E20'), 1000000)) + assert_equal("\\operatorname{lcm}(128*10^32, 1)", lcm(Rational('128E32'), 1)) + + +def test_lcm_negative(): + assert_equal("\\lcm(-12, 4)", lcm(-12, 4)) + assert_equal("\\lcm(219, -9)", lcm(219, -9)) + assert_equal("\\lcm(-8, -12)", lcm(-8, -12)) + assert_equal("\\lcm(-5, -5)", lcm(-5, -5)) + assert_equal("\\lcm(-1, 182033)", lcm(-1, 182033)) + assert_equal("\\lcm(25, -30)", lcm(25, -30)) + assert_equal("\\lcm(243, -2.9543127E21)", lcm(243, Rational('-2.9543127E21'))) + + assert_equal("\\operatorname{lcm}(-12, 4)", lcm(-12, 4)) + assert_equal("\\operatorname{lcm}(219, -9)", lcm(219, -9)) + assert_equal("\\operatorname{lcm}(-8, -12)", lcm(-8, -12)) + assert_equal("\\operatorname{lcm}(-5, -5)", lcm(-5, -5)) + assert_equal("\\operatorname{lcm}(-1, 182033)", lcm(-1, 182033)) + assert_equal("\\operatorname{lcm}(25, -30)", lcm(25, -30)) + assert_equal("\\operatorname{lcm}(243, -2.9543127E21)", lcm(243, Rational('-2.9543127E21'))) + + +def test_lcm_float(): + assert_equal("\\lcm(2.4, 3.6)", lcm(Rational('2.4'), Rational('3.6'))) + assert_equal("\\lcm(3.6, 2.4)", lcm(Rational('3.6'), Rational('2.4'))) + assert_equal("\\lcm(\\pi, 3)", lcm(pi, 3)) + assert_equal("\\lcm(618, 1.5)", lcm(618, Rational('1.5'))) + assert_equal("\\lcm(-1.5, 618)", lcm(Rational('-1.5'), 618)) + assert_equal("\\lcm(0.42, 2)", lcm(Rational('0.42'), 2)) + assert_equal("\\lcm(1.43E-13, 21)", lcm(Rational('1.43E-13'), 21)) + assert_equal("\\lcm(21, -143E-13)", lcm(21, Rational('-143E-13'))) + assert_equal("\\lcm(9.80655, 9.80655)", lcm(Rational('9.80655'), Rational('9.80655'))) + assert_equal("\\lcm(0.0000923423, -8341.234802909)", lcm(Rational('0.0000923423'), Rational('-8341.234802909'))) + assert_equal("\\lcm(\\sqrt{5}, \\sqrt{2})", lcm(sqrt(5), sqrt(2))) + + assert_equal("\\operatorname{lcm}(2.4, 3.6)", lcm(Rational('2.4'), Rational('3.6'))) + assert_equal("\\operatorname{lcm}(3.6, 2.4)", lcm(Rational('3.6'), Rational('2.4'))) + assert_equal("\\operatorname{lcm}(\\pi, 3)", lcm(pi, 3)) + assert_equal("\\operatorname{lcm}(618, 1.5)", lcm(618, Rational('1.5'))) + assert_equal("\\operatorname{lcm}(-1.5, 618)", lcm(Rational('-1.5'), 618)) + assert_equal("\\operatorname{lcm}(0.42, 2)", lcm(Rational('0.42'), 2)) + assert_equal("\\operatorname{lcm}(1.43E-13, 21)", lcm(Rational('1.43E-13'), 21)) + assert_equal("\\operatorname{lcm}(21, -143E-13)", lcm(21, Rational('-143E-13'))) + assert_equal("\\operatorname{lcm}(9.80655, 9.80655)", lcm(Rational('9.80655'), Rational('9.80655'))) + assert_equal("\\operatorname{lcm}(0.0000923423, -8341.234802909)", lcm(Rational('0.0000923423'), Rational('-8341.234802909'))) + assert_equal("\\operatorname{lcm}(\\sqrt{5}, \\sqrt{2})", lcm(sqrt(5), sqrt(2))) + + +def test_lcm_fraction(): + assert_equal("\\lcm(1/2, 3)", lcm(Rational('1/2'), 3)) + assert_equal("\\lcm(3, 1/2)", lcm(3, Rational('1/2'))) + assert_equal("\\lcm(6/2, 3)", lcm(Rational('6/2'), 3)) + assert_equal("\\lcm(1/10, 1/10)", lcm(Rational('1/10'), Rational('1/10'))) + assert_equal("\\lcm(42, 42/6)", lcm(42, Rational('42/6'))) + assert_equal("\\lcm(10000000/10, 10000)", lcm(Rational('10000000/10'), 10000)) + + assert_equal("\\operatorname{lcm}(1/2, 3)", lcm(Rational('1/2'), 3)) + assert_equal("\\operatorname{lcm}(3, 1/2)", lcm(3, Rational('1/2'))) + assert_equal("\\operatorname{lcm}(6/2, 3)", lcm(Rational('6/2'), 3)) + assert_equal("\\operatorname{lcm}(1/10, 1/10)", lcm(Rational('1/10'), Rational('1/10'))) + assert_equal("\\operatorname{lcm}(42, 42/6)", lcm(42, Rational('42/6'))) + assert_equal("\\operatorname{lcm}(10000000/10, 10000)", lcm(Rational('10000000/10'), 10000)) + + +def test_lcm_expr(): + assert_equal("\\lcm(1+1, 8)", lcm(1 + 1, 8)) + assert_equal("920*\\lcm(9, 12*4/2)", 920 * lcm(9, 12 * Rational('4/2'))) + assert_equal("\\lcm(32-128, 10)*22", lcm(32 - 128, 10) * 22) + assert_equal("\\sqrt{\\lcm(1.25E24, 1E12)}", sqrt(lcm(Rational('1.25E24'), Rational('1E12')))) + assert_equal("\\lcm(92.0, 000+2)", lcm(Rational('92.0'), 000 + 2)) + + assert_equal("\\operatorname{lcm}(1+1, 8)", lcm(1 + 1, 8)) + assert_equal("920*\\operatorname{lcm}(9, 12*4/2)", 920 * lcm(9, 12 * Rational('4/2'))) + assert_equal("\\operatorname{lcm}(32-128, 10)*22", lcm(32 - 128, 10) * 22) + assert_equal("\\sqrt{\\operatorname{lcm}(1.25E24, 1E12)}", sqrt(lcm(Rational('1.25E24'), Rational('1E12')))) + assert_equal("\\operatorname{lcm}(92.0, 000+2)", lcm(Rational('92.0'), 000 + 2)) + + +def test_lcm_symbol(): + assert_equal("\\lcm(x, y)", lcm(x, y), symbolically=True) + assert_equal("\\lcm(y, -x)", lcm(y, -x), symbolically=True) + assert_equal("\\lcm(2y, x)", lcm(2 * y, x), symbolically=True) + assert_equal("\\lcm(125, 50x)", lcm(125, 50 * x), symbolically=True) + assert_equal("\\lcm(x + 76, \\sqrt{x} * 4)", lcm(x + 76, sqrt(x) * 4), symbolically=True) + assert_equal("\\lcm(y, y)", lcm(y, y), symbolically=True) + assert_equal("y + \\lcm(0.4x, 8/3) / 2", y + lcm(Number('0.4') * x, Rational('8/3')) / 2, symbolically=True) + assert_equal("6.673E-11 * (\\lcm(8.85418782E-12, 9x) + 4) / 8y", Rational('6.673E-11') * (lcm(Number('8.85418782E-12'), 9 * x) + 4) / (8 * y), symbolically=True) + + assert_equal("\\operatorname{lcm}(x, y)", lcm(x, y), symbolically=True) + assert_equal("\\operatorname{lcm}(y, -x)", lcm(y, -x), symbolically=True) + assert_equal("\\operatorname{lcm}(2y, x)", lcm(2 * y, x), symbolically=True) + assert_equal("\\operatorname{lcm}(125, 50x)", lcm(125, 50 * x), symbolically=True) + assert_equal("\\operatorname{lcm}(x + 76, \\sqrt{x} * 4)", lcm(x + 76, sqrt(x) * 4), symbolically=True) + assert_equal("\\operatorname{lcm}(y, y)", lcm(y, y), symbolically=True) + assert_equal("y + \\operatorname{lcm}(0.4x, 8/3) / 2", y + lcm(Rational('0.4') * x, Rational('8/3')) / 2, symbolically=True) + assert_equal("6.673E-11 * (\\operatorname{lcm}(8.85418782E-12, 9x) + 4) / 8y", Rational('6.673E-11') * (lcm(Rational('8.85418782E-12'), 9 * x) + 4) / (8 * y), symbolically=True) + + +def test_multiple_parameters(): + assert_equal("\\lcm(830,450)", lcm(830, 450)) + assert_equal("\\lcm(6,321,429)", ilcm(6, 321, 429)) + assert_equal("\\lcm(14,2324)", lcm(14, 2324)) + assert_equal("\\lcm(3, 6, 2)", ilcm(3, 6, 2)) + assert_equal("\\lcm(8, 9, 21)", ilcm(8, 9, 21)) + assert_equal("\\lcm(144, 2988, 37116)", ilcm(144, 2988, 37116)) + assert_equal("\\lcm(144,2988,37116,18,72)", ilcm(144, 2988, 37116, 18, 72)) + assert_equal("\\lcm(144, 2988, 37116, 18, 72, 12, 6)", ilcm(144, 2988, 37116, 18, 72, 12, 6)) + assert_equal("\\lcm(32)", lcm(32, 32)) + assert_equal("\\lcm(-8, 4, -2)", lcm(-8, lcm(4, -2))) + assert_equal("\\lcm(x, y, z)", lcm(x, lcm(y, z)), symbolically=True) + assert_equal("\\lcm(6*4, 48, 3)", ilcm(6 * 4, 48, 3)) + assert_equal("\\lcm(2.4, 3.6, 0.6)", lcm(Rational('2.4'), lcm(Rational('3.6'), Rational('0.6')))) + assert_equal("\\lcm(\\sqrt{3}, \\sqrt{2},\\sqrt{100})", lcm(sqrt(3), lcm(sqrt(2), sqrt(100)))) + assert_equal("\\lcm(1E12, 1E6, 1E3, 10)", ilcm(Rational('1E12'), Rational('1E6'), Rational('1E3'), 10)) + + assert_equal("\\operatorname{lcm}(830,450)", lcm(830, 450)) + assert_equal("\\operatorname{lcm}(6,321,429)", ilcm(6, 321, 429)) + assert_equal("\\operatorname{lcm}(14,2324)", lcm(14, 2324)) + assert_equal("\\operatorname{lcm}(3, 6, 2)", ilcm(3, 6, 2)) + assert_equal("\\operatorname{lcm}(8, 9, 21)", ilcm(8, 9, 21)) + assert_equal("\\operatorname{lcm}(144, 2988, 37116)", ilcm(144, 2988, 37116)) + assert_equal("\\operatorname{lcm}(144,2988,37116,18,72)", ilcm(144, 2988, 37116, 18, 72)) + assert_equal("\\operatorname{lcm}(144, 2988, 37116, 18, 72, 12, 6)", ilcm(144, 2988, 37116, 18, 72, 12, 6)) + assert_equal("\\operatorname{lcm}(32)", lcm(32, 32)) + assert_equal("\\operatorname{lcm}(-8, 4, -2)", lcm(-8, lcm(4, -2))) + assert_equal("\\operatorname{lcm}(x, y, z)", lcm(x, lcm(y, z)), symbolically=True) + assert_equal("\\operatorname{lcm}(6*4,48, 3)", ilcm(6 * 4, 48, 3)) + assert_equal("\\operatorname{lcm}(2.4, 3.6,0.6)", lcm(Rational('2.4'), lcm(Rational('3.6'), Rational('0.6')))) + assert_equal("\\operatorname{lcm}(\\sqrt{3}, \\sqrt{2},\\sqrt{100})", lcm(sqrt(3), lcm(sqrt(2), sqrt(100)))) + assert_equal("\\operatorname{lcm}(1E12,1E6, 1E3, 10)", ilcm(Rational('1E12'), Rational('1E6'), Rational('1E3'), 10)) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/left_right_cdot_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/left_right_cdot_test.py new file mode 100644 index 000000000..935b19bed --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/left_right_cdot_test.py @@ -0,0 +1,10 @@ +import pytest +from latex2sympy2_extended import latex2sympy +from sympy import sin, Symbol +from tests.context import assert_equal + +x = Symbol('x') + + +def test_left_right_cdot(): + assert_equal("\\sin\\left(x\\right)\\cdot x", sin(x) * x) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/linalg_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/linalg_test.py new file mode 100644 index 000000000..7d6a5a431 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/linalg_test.py @@ -0,0 +1,15 @@ +from tests.context import assert_equal +import pytest +from sympy import MatMul, Matrix + + +# def test_linalg_placeholder(): +# assert_equal("\\begin{pmatrix}1&2\\\\3&4\\end{pmatrix}\\cdot\\variable{v}", MatMul(Matrix([[1, 2], [3, 4]]), Matrix([1, 2])), variables={'v': Matrix([1, 2])}) + + +# def test_linalg_placeholder_multiple(): +# assert_equal("\\variable{M}\\cdot\\variable{v}", MatMul(Matrix([[1, 2], [3, 4]]), Matrix([1, 2])), variables={'M': Matrix([[1, 2], [3, 4]]), 'v': Matrix([1, 2])}) + + +# def test_linalg_placeholder_multiple_mul(): +# assert_equal("\\begin{pmatrix}3&-1\\end{pmatrix}\\cdot\\variable{M}\\cdot\\variable{v}", MatMul(Matrix([[3, -1]]), Matrix([[1, 2], [3, 4]]), Matrix([1, 2])), variables={'M': Matrix([[1, 2], [3, 4]]), 'v': Matrix([1, 2])}) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/max_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/max_test.py new file mode 100644 index 000000000..d02e5beba --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/max_test.py @@ -0,0 +1,79 @@ +from tests.context import assert_equal +import pytest +from sympy import Symbol, Rational, Float, Max, sqrt, exp, pi, nsimplify + +x = Symbol('x') +y = Symbol('y') +z = Symbol('z') + + +def test_max_usual(): + assert_equal("\\max(1, 5)", Max(1, 5)) + assert_equal("\\max(12, 4)", Max(12, 4)) + assert_equal("\\max(109, 120)", Max(109, 120)) + assert_equal("\\max(3, 3)", Max(3, 3)) + assert_equal("\\max(0, 0)", Max(0, 0)) + assert_equal("\\max(1)", Max(1)) + assert_equal("\\max(1092198374, 290348E32)", Max(1092198374, Rational('290348E32'))) + assert_equal("\\max(5, 2, 17, 4)", Max(5, 2, 17, 4)) + + +def test_max_negative(): + assert_equal("\\max(-9, 4)", Max(-9, 4)) + assert_equal("\\max(4, -9)", Max(4, -9)) + assert_equal("\\max(-7)", Max(-7)) + assert_equal("\\max(-2, -2)", Max(-2, -2)) + assert_equal("\\max(-324E-3, -58)", Max(Rational('-324E-3'), -58)) + assert_equal("\\max(-1, 0, 1, -37, 42)", Max(-1, 0, 1, -37, 42)) + + +def test_max_float(): + # assert_equal("\\max(\\pi, 3)", Max(pi, 3, evaluate=False)) + assert_equal("\\max(1234.56789, 1234.5678901)", Max(Float('1234.56789'), Float('1234.5678901'), evaluate=False)) + assert_equal("\\max(12.4, 9.5)", Max(12.4, 9.5, evaluate=False)) + assert_equal("\\max(6, 6.2)", Max(6, 6.2, evaluate=False)) + assert_equal("\\max(-98.7)", Max(-98.7, evaluate=False)) + assert_equal("\\max(7.1, 9)", Max(7.1, 9, evaluate=False)) + assert_equal("\\max(-21E-12, 0.00005)", Max(nsimplify(Rational('-21E-12')), Rational('0.00005')), symbolically=True) + assert_equal("\\max(\\sqrt{3}, 0, 1)", Max(sqrt(3), 0, 1)) + + +def test_max_fraction(): + assert_equal("\\max(1/2, 1/4)", Max(Rational('1/2'), Rational('1/4'))) + assert_equal("\\max(6/2, 3)", Max(Rational('6/2'), 3)) + assert_equal("\\max(2/4, 1/2)", Max(Rational('2/4'), Rational('1/2'))) + assert_equal("\\max(-12/5, 6.4)", Max(Rational('-12/5'), Rational('6.4'))) + assert_equal("\\max(1/10)", Max(Rational('1/10'))) + assert_equal("\\max(1.5, \\pi/2)", Max(Rational('1.5'), pi / 2, evaluate=False)) + assert_equal("\\max(-4/3, -2/1, 0/9, -3)", Max(Rational('-4/3'), Rational('-2/1'), Rational('0/9'), -3)) + + +def test_max_expr(): + assert_equal("\\max((1+6)/3, 7)", Max(Rational(1 + 6, 3), 7, evaluate=False)) + assert_equal("\\max(58*9)", Max(58 * 9, evaluate=False)) + assert_equal("\\max(1+6/3, -5)", Max(1 + Rational('6/3'), -5, evaluate=False)) + assert_equal("\\max(7*4/5, 092) * 2", Max(7 * 4 / 5, 92, evaluate=False) * 2) + assert_equal("38+\\max(13, 15-2.3)", 38 + Max(13, 15 - Float('2.3'))) + assert_equal("\\sqrt{\\max(99.9999999999999, 100)}", sqrt(Max(Float('99.9999999999999'), 100))) + assert_equal("\\max(274/(5+2), \\exp(12.4), 1.4E2)", Max(Rational(274, 5 + 2), exp(Float('12.4')), Float('1.4E2'))) + + +def test_max_symbol(): + assert_equal("\\max(x)", Max(x), symbolically=True) + assert_equal("\\max(x, y)", Max(x, y), symbolically=True) + assert_equal("\\max(y, x)", Max(y, x), symbolically=True) + assert_equal("\\max(x+y, y+x)", Max(x + y, y + x), symbolically=True) + assert_equal("\\max(9x/4, z)", Max(9 * x / 4, z), symbolically=True) + assert_equal("\\max(y\\pi, 9)", Max(y * pi, 9), symbolically=True) + assert_equal("\\max(2y-y, y + 1)", Max(2 * y - y, y + 1), symbolically=True) + assert_equal("\\max(z, y, x)", Max(z, y, x), symbolically=True) + + +def test_max_multiarg(): + assert_equal("\\max(1,2)", Max(1, 2)) + assert_equal("\\max(9,876,543)", Max(9, 876, 543)) + assert_equal("\\max(x, y,z)", Max(x, y, z), symbolically=True) + assert_equal("\\max(5.8,7.4, 2.2,-10)", Max(Rational('5.8'), Rational('7.4'), Rational('2.2'), -10)) + assert_equal("\\max(\\pi,12E2,84,\\sqrt{5},12/5)", Max(pi, Rational('12E2'), 84, sqrt(5), Rational('12/5'))) + assert_equal("\\max(823,51)", Max(823, 51)) + assert_equal("\\max(72*4,23, 9)", Max(72 * 4, 23, 9)) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/min_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/min_test.py new file mode 100644 index 000000000..fc6f8e5a2 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/min_test.py @@ -0,0 +1,79 @@ +from tests.context import assert_equal +import pytest +from sympy import Number, Symbol, Rational, Float, Min, sqrt, exp, pi, nsimplify + +x = Symbol('x') +y = Symbol('y') +z = Symbol('z') + + +def test_min_usual(): + assert_equal("\\min(1, 5)", Min(1, 5)) + assert_equal("\\min(12, 4)", Min(12, 4)) + assert_equal("\\min(109, 120)", Min(109, 120)) + assert_equal("\\min(3, 3)", Min(3, 3)) + assert_equal("\\min(0, 0)", Min(0, 0)) + assert_equal("\\min(1)", Min(1)) + assert_equal("\\min(1092198374, 290348E32)", Min(1092198374, Rational('290348E32'))) + assert_equal("\\min(5, 2, 17, 4)", Min(5, 2, 17, 4)) + + +def test_min_negative(): + assert_equal("\\min(-9, 4)", Min(-9, 4)) + assert_equal("\\min(4, -9)", Min(4, -9)) + assert_equal("\\min(-7)", Min(-7)) + assert_equal("\\min(-2, -2)", Min(-2, -2)) + assert_equal("\\min(-324E-3, -58)", Min(Rational('-324E-3'), -58)) + assert_equal("\\min(-1, 0, 1, -37, 42)", Min(-1, 0, 1, -37, 42)) + + +def test_min_float(): + assert_equal("\\min(\\pi, 3)", Min(pi, 3)) + assert_equal("\\min(1234.56789, 1234.5678901)", Min(Rational('1234.56789'), Rational('1234.5678901'))) + assert_equal("\\min(12.4, 9.5)", Min(12.4, 9.5)) + assert_equal("\\min(6, 6.2)", Min(6, 6.2)) + assert_equal("\\min(-98.7)", Min(-98.7)) + assert_equal("\\min(7.1, 9)", Min(7.1, 9)) + assert_equal("\\min(-21E-12, 0.00005)", Min(nsimplify(Rational('-21E-12')), Rational('0.00005')), symbolically=True) + assert_equal("\\min(\\sqrt{3}, 0, 1)", Min(sqrt(3), 0, 1)) + + +def test_min_fraction(): + assert_equal("\\min(1/2, 1/4)", Min(Rational('1/2'), Rational('1/4'))) + assert_equal("\\min(6/2, 3)", Min(Rational('6/2'), 3)) + assert_equal("\\min(2/4, 1/2)", Min(Rational('2/4'), Rational('1/2'))) + assert_equal("\\min(-12/5, 6.4)", Min(Rational('-12/5'), Rational('6.4'))) + assert_equal("\\min(1/10)", Min(Rational('1/10'))) + assert_equal("\\min(1.5, \\pi/2)", Min(Rational('1.5'), pi / 2, evaluate=False)) + assert_equal("\\min(-4/3, -2/1, 0/9, -3)", Min(Rational('-4/3'), Rational('-2/1'), Rational('0/9'), -3)) + + +def test_min_expr(): + assert_equal("\\min((1+6)/3, 7)", Min(Rational(1 + 6, 3), 7)) + assert_equal("\\min(58*9)", Min(58 * 9)) + assert_equal("\\min(1+6/3, -5)", Min(1 + Rational('6/3'), -5)) + assert_equal("\\min(7*4/5, 092) * 2", Min(7 * 4 / 5, 92) * 2) + assert_equal("38+\\min(13, 15-2.3)", 38 + Min(13, 15 - Float('2.3'), evaluate=False)) + assert_equal("\\sqrt{\\min(99.9999999999999, 100)}", sqrt(Min(Float('99.9999999999999'), 100))) + assert_equal("\\min(274/(5+2), \\exp(12.4), 1.4E2)", Min(Rational(274, 5 + 2), exp(Float('12.4')), Number('1.4E2'))) + + +def test_min_symbol(): + assert_equal("\\min(x)", Min(x), symbolically=True) + assert_equal("\\min(x, y)", Min(x, y), symbolically=True) + assert_equal("\\min(y, x)", Min(y, x), symbolically=True) + assert_equal("\\min(x+y, y+x)", Min(x + y, y + x), symbolically=True) + assert_equal("\\min(9x/4, z)", Min(9 * x / 4, z), symbolically=True) + assert_equal("\\min(y\\pi, 9)", Min(y * pi, 9), symbolically=True) + assert_equal("\\min(2y-y, y + 1)", Min(2 * y - y, y + 1), symbolically=True) + assert_equal("\\min(z, y, x)", Min(z, y, x), symbolically=True) + + +def test_min_multiarg(): + assert_equal("\\min(1,2)", Min(1, 2)) + assert_equal("\\min(9,876,543)", Min(9, 876, 543)) + assert_equal("\\min(x, y,z)", Min(x, y, z), symbolically=True) + assert_equal("\\min(5.8,7.4, 2.2,-10)", Min(Rational('5.8'), Rational('7.4'), Rational('2.2'), -10)) + assert_equal("\\min(\\pi,12E2,84,\\sqrt{5},12/5)", Min(pi, Rational('12E2'), 84, sqrt(5), Rational('12/5'))) + assert_equal("\\min(823,51)", Min(823, 51)) + assert_equal("\\min(72*4,23, 9)", Min(72 * 4, 23, 9)) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/mixed_fraction_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/mixed_fraction_test.py new file mode 100644 index 000000000..3008bebd5 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/mixed_fraction_test.py @@ -0,0 +1,19 @@ +from sympy import Add, Float, Mul, Rational, Symbol +from tests.context import assert_equal +import pytest + +@pytest.mark.parametrize('latex, latex2sympy', [ + ('1 \\frac{1}{2}', Rational(3, 2)), + ('1 \\frac{1}{2} + 3', Add(Rational(3, 2), 3)), + ('3 1\\frac{1}{2}', Rational(9, 2)), + # For float's we interpret it as multiplication + ('3.1 \\frac{1}{2}', Mul(Float(3.1), Rational(1, 2))), + # Negative numbers + ('-3 \\frac{1}{2}', Rational(-7, 2)), + + # We didn't break other postfix + ('ab', Mul(Symbol('a'), Symbol('b'))), + +]) +def test_mixed_fraction(latex, latex2sympy): + assert_equal(latex, latex2sympy) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/mod_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/mod_test.py new file mode 100644 index 000000000..4b0348aa3 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/mod_test.py @@ -0,0 +1,70 @@ +from tests.context import assert_equal +import pytest +from sympy import Float, Number, Symbol, Rational, Mod, sqrt, nsimplify, pi, GoldenRatio + +x = Symbol('x') +y = Symbol('y') + + +def test_mod_usual(): + assert_equal("128\\mod 3", Mod(128, 3)) + assert_equal("7\\mod 128", Mod(7, 128)) + assert_equal("5\\mod 10", Mod(5, 10)) + assert_equal("5\\mod 5", Mod(5, 5)) + assert_equal("3\\mod 2", Mod(3, 2)) + assert_equal("0 \\mod 6", Mod(0, 6)) + assert_equal("6109\\mod 28", Mod(6109, 28)) + assert_equal("4000000000\\mod 28791", Mod(4000000000, 28791)) + assert_equal("128*10^300\\mod 876123", Mod(Rational('128E300'), 876123)) + assert_equal("876123\\mod 128E300", Mod(876123, Rational('128E300'))) + + +def test_mod_negative(): + assert_equal("-1\\mod 2", Mod(-1, 2)) + assert_equal("-3\\mod 3", Mod(-3, 3)) + assert_equal("-12\\mod -12", Mod(-12, -12)) + assert_equal("-128\\mod 4", Mod(-128, 4)) + assert_equal("9\\mod -213", Mod(9, -213)) + assert_equal("123123\\mod -541", Mod(123123, -541)) + assert_equal("-123123\\mod 541", Mod(-123123, 541)) + assert_equal("-97E34\\mod 7", Mod(Rational('-97E34'), 7)) + + +def test_mod_fraction(): + assert_equal("1/2\\mod 3", Mod(Rational(1, 2), 3)) + assert_equal("6/2\\mod 3", Mod(Rational(6, 2), 3)) + assert_equal("-14/2\\mod 5", Mod(Rational(-14, 2), 5)) + assert_equal("123\\mod (42/6)", Mod(123, Rational(42, 6))) + assert_equal("431\\mod (2/123)", Mod(431, Rational(2, 123))) + assert_equal("5/5\\mod (5/5)", Mod(Rational(5, 5), Rational(5, 5))) + assert_equal("849/-21\\mod (092/2)", Mod(Rational(849, -21), Rational(92, 2))) + assert_equal("13*10^9\\mod (21/-2)", Mod(13E9, Rational(21, -2))) + + +def test_mod_float(): + assert_equal("0.41\\mod 2", Mod(Float("0.41"), 2)) + assert_equal("143E-13\\mod 21", Mod(Number('143E-13'), 21)) + assert_equal("-9.80665\\mod 9.80665", Mod(-9.80665, 9.80665)) + assert_equal("0.0000923423\\mod -8341.234802909", Mod(Float("0.0000923423"), Float("-8341.234802909"))) + assert_equal("\\sqrt{5}\\mod \\sqrt{2}", Mod(sqrt(5), sqrt(2))) + assert_equal("987\\mod \\pi", Mod(987, pi)) + assert_equal("\\pi\\mod ((1+\\sqrt{5})/2)", Mod(pi, nsimplify(GoldenRatio)), symbolically=True) + assert_equal("1234\\mod 1E-29", Mod(1234, Number('1E-29'), evaluate=False)) + + +def test_mod_expr(): + assert_equal("1+1\\mod 2", 1 + Mod(1, 2)) + assert_equal("876123\\mod 128\\times 10^300", Mod(876123, 128) * 1E300) + assert_equal("141\\mod 9/3", Rational(Mod(141, 9) / 3)) + assert_equal("872 / (12\\mod 9 * 4) * 2", Rational(2 * 872, (Mod(12, 9) * 4))) + assert_equal("1E-32 * (1E29\\mod 74)", Float('1E-32') * Mod(Float('1E29'), 74)) + assert_equal("299792458\\mod 9.81", Mod(299792458, Float('9.81'))) + + +def test_mod_symbol(): + assert_equal("x\\mod y", Mod(x, y)) + assert_equal("2x\\mod y", Mod(2 * x, y)) + assert_equal("y + 3\\mod 2 / 4", y + Rational(Mod(3, 2), 4), symbolically=True) + assert_equal("0.5x * 2 + \\sqrt{x}\\mod 8y", 0.5 * x * 2 + Mod(sqrt(x), 8 * y), symbolically=True) + assert_equal("6.673E-11 * ((8.85418782E-12\\mod 9x) + 4) / 2y", Float('6.673E-11') * (Mod(Float('8.85418782E-12'), 9 * x) + 4) / (2 * y), symbolically=True) + diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/normalization_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/normalization_test.py new file mode 100644 index 000000000..9a8a0e455 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/normalization_test.py @@ -0,0 +1,33 @@ +from latex2sympy2_extended.math_normalization import NormalizationConfig, normalize_latex + + +def test_units_normalization(): + config = NormalizationConfig( + basic_latex=False, + units=True, + malformed_operators=False, + nits=False, + boxed="none", + equations=False + ) + + # Test basic unit removal + assert normalize_latex("865 \\mbox{ inches}^2", config) == "865" + assert normalize_latex("\\mbox{hello}", config) == "\\mbox{hello}" + assert normalize_latex("1cm", config) == "1" + assert normalize_latex("\\frac{1}{2} cm", config) == "\\frac{1}{2}" + assert normalize_latex("\\frac{1}{2}cm", config) == "\\frac{1}{2}" + assert normalize_latex("1 < ms < 2", config) == "1 < ms < 2" + assert normalize_latex("3s", config) == "3" + +def test_boxed_normalization(): + config = NormalizationConfig( + basic_latex=False, + units=False, + malformed_operators=False, + nits=False, + equations=False, + boxed="all" + ) + + assert normalize_latex("\\boxed{\\left( 3, \\frac{\\pi}{2} \\right)}.", config) == "\\left( 3, \\frac{\\pi}{2} \\right)" diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/overline_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/overline_test.py new file mode 100644 index 000000000..73eb6befe --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/overline_test.py @@ -0,0 +1,9 @@ +from tests.context import assert_equal +import pytest +from sympy import sin, Symbol + +x = Symbol('x') + + +def test_overline(): + assert_equal("\\frac{\\sin(x)}{\\overline{x}_n}", sin(x) / Symbol('bar{x}_n')) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/pi_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/pi_test.py new file mode 100644 index 000000000..d87c5b3ec --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/pi_test.py @@ -0,0 +1,15 @@ +from tests.context import assert_equal, _Mul, _Pow +import pytest +from sympy import pi, Symbol, acos, cos + + +def test_pi_frac(): + assert_equal("\\frac{\\pi}{3}", _Mul(pi, _Pow(3, -1))) + + +def test_pi_nested(): + assert_equal("\\arccos{\\cos{\\frac{\\pi}{3}}}", acos(cos(_Mul(pi, _Pow(3, -1)), evaluate=False), evaluate=False)) + + +def test_pi_arccos(): + assert_equal("\\arccos{-1}", pi, symbolically=True) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/relation_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/relation_test.py new file mode 100644 index 000000000..76d4644cc --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/relation_test.py @@ -0,0 +1,83 @@ +# import pytest +# from sympy import ( +# Symbol, StrictLessThan, LessThan, StrictGreaterThan, GreaterThan, +# Eq, Ne, Contains, And, S +# ) +# from tests.context import assert_equal + +# x = Symbol('x') +# y = Symbol('y') +# z = Symbol('z') +# a = Symbol('a') +# b = Symbol('b') + +# @pytest.mark.parametrize("latex,expected", [ +# # Test less than +# ("x < y", StrictLessThan(x, y)), +# ("x \\lt y", StrictLessThan(x, y)), + +# # Test less than or equal +# ("x \\leq y", LessThan(x, y)), +# ("x \\le y", LessThan(x, y)), + +# # Test greater than +# ("x > y", StrictGreaterThan(x, y)), +# ("x \\gt y", StrictGreaterThan(x, y)), + +# # Test greater than or equal +# ("x \\geq y", GreaterThan(x, y)), +# ("x \\ge y", GreaterThan(x, y)), + +# # Test equality +# ("x = y", Eq(x, y)), +# ("x == y", Eq(x, y)), + +# # Test inequality +# ("x \\neq y", Ne(x, y)), +# ("x \\ne y", Ne(x, y)) +# ]) +# def test_basic_relations(latex, expected): +# assert_equal(latex, expected) + +# @pytest.mark.parametrize("latex,expected", [ +# # Test chained inequalities +# ("x < y < z", And(StrictLessThan(x, y), StrictLessThan(y, z))), +# ("a \\leq x \\leq b", And(LessThan(a, x), LessThan(x, b))), + +# # Test mixed chains +# ("x < y = z", And(StrictLessThan(x, y), Eq(y, z))), +# ("a \\geq x > b", And(GreaterThan(a, x), StrictGreaterThan(x, b))) +# ]) +# def test_chained_relations(latex, expected): +# assert_equal(latex, expected) + +# @pytest.mark.parametrize("latex,expected", [ +# # Test element membership +# ("x \\in {1,2,3}", Contains(x, S.Reals)), +# ("y \\notin {a,b}", Not(Contains(y, S.Reals))), + +# # Test assignment with sets +# ("x = {1,2,3}", Eq(x, S.Reals)) +# ]) +# def test_set_relations(latex, expected): +# assert_equal(latex, expected) + +# @pytest.mark.parametrize("latex,expected", [ +# # Test relations with arithmetic expressions +# ("2x + 1 < 3y", StrictLessThan(2*x + 1, 3*y)), +# ("x^2 \\geq y^2 + 1", GreaterThan(x**2, y**2 + 1)), + +# # Test with fractions and complex expressions +# ("\\frac{x}{y} = z", Eq(x/y, z)), +# ("\\sqrt{x} < y", StrictLessThan(x**0.5, y)) +# ]) +# def test_relations_with_expressions(latex, expected): +# assert_equal(latex, expected) + +# @pytest.mark.parametrize("invalid_latex", [ +# "x < y > z", +# "x = y < z > a" +# ]) +# def test_invalid_relations(invalid_latex): +# with pytest.raises(Exception): +# assert_equal(invalid_latex, None) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/set_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/set_test.py new file mode 100644 index 000000000..c21ad3db2 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/set_test.py @@ -0,0 +1,136 @@ +import pytest +from sympy import ( + Eq, Symbol, FiniteSet, Interval, S, Union, Intersection, + Complement, Contains, Not, Add, Mul, Pow, UnevaluatedExpr, Rational +) +import sympy +from tests.context import assert_equal, _Add, _Mul, _Pow + +x = Symbol('x') +y = Symbol('y') +z = Symbol('z') +a = Symbol('a') +b = Symbol('b') + +def test_literal_sets(): + # Test built-in sets + assert_equal("\\mathbb{N}", S.Naturals) + assert_equal("\\mathbb{Z}", S.Integers) + assert_equal("\\mathbb{Q}", S.Rationals) + assert_equal("\\mathbb{R}", S.Reals) + assert_equal("\\mathbb{C}", S.Complexes) + # Test empty set + assert_equal("\\emptyset", S.EmptySet) + assert_equal("\\{\\}", S.EmptySet) + assert_equal("∅", S.EmptySet) + +def test_finite_sets(): + # Test basic finite sets + assert_equal("{1,2,3}", FiniteSet(1, 2, 3)) + assert_equal("{x,y,z}", FiniteSet(x, y, z)) + # Test sets with expressions + assert_equal("{x+y, 2x, y^2}", FiniteSet(_Add(x, y), _Mul(2, x), _Pow(y, 2))) + # Test sets with plus-minus notation + assert_equal("{x \\pm y}", FiniteSet(_Add(x, y), _Add(x, _Mul(-1, y)))) + assert_equal("{2 \\pm 1}", FiniteSet(Add(2, -1, evaluate=False), Add(2, 1, evaluate=False))) + assert_equal("1,2,3,4", FiniteSet(1, 2, 3, 4)) + +def test_intervals(): + # Test closed intervals + assert_equal("[a,b]", Interval(a, b)) + assert_equal("[1,2]", Interval(1, 2)) + # # Test open intervals + assert_equal("(a,b)", Interval(a, b, left_open=True, right_open=True)) + assert_equal("(1,2)", Interval(1, 2, left_open=True, right_open=True)) + # # Test half-open intervals + assert_equal("[a,b)", Interval(a, b, right_open=True)) + assert_equal("(a,b]", Interval(a, b, left_open=True)) + +def test_set_operations(): + # Test union + assert_equal("{1,2} \\cup {2,3}", Union(FiniteSet(1, 2), FiniteSet(2, 3), evaluate=False)) + assert_equal("{1,2} ∪ {2,3}", Union(FiniteSet(1, 2), FiniteSet(2, 3), evaluate=False)) + # Test intersection + assert_equal("{1,2} \\cap {2,3}", Intersection(FiniteSet(1, 2), FiniteSet(2, 3), evaluate=False)) + assert_equal("{1,2} ∩ {2,3}", Intersection(FiniteSet(1, 2), FiniteSet(2, 3), evaluate=False)) + # Test set difference + assert_equal("{1,2} \\setminus {2}", Complement(FiniteSet(1, 2), FiniteSet(2), evaluate=False)) + assert_equal("{1,2} ∖ {2}", Complement(FiniteSet(1, 2), FiniteSet(2), evaluate=False)) + +def test_set_relations(): + # Test element membership + assert_equal("x \\in {1,2}", Eq(Symbol("x"), FiniteSet(1, 2), evaluate=False)) + assert_equal("x \\in {1,2,3}", Eq(Symbol("x"), FiniteSet(1, 2, 3), evaluate=False)) + assert_equal("x \\notin {1,2,3}", Not(Eq(Symbol("x"), FiniteSet(1, 2, 3), evaluate=False), evaluate=False)) + # # Test subset relations + assert_equal("{1} \\subseteq {1,2}", FiniteSet(1).is_subset(FiniteSet(1, 2))) + assert_equal("{1,2} \\supseteq {1}", FiniteSet(1).is_subset(FiniteSet(1, 2))) + +def test_complex_set_operations(): + # Test nested set operations + assert_equal("({1,2} \\cup {3,4}) \\cap {2,3}", + Intersection(Union(FiniteSet(1, 2), FiniteSet(3, 4), evaluate=False), + FiniteSet(2, 3), evaluate=False)) + # Test multiple operations + assert_equal("{1,2} \\cup {3,4} \\cup {5,6}", + Union(Union(FiniteSet(1, 2), FiniteSet(3, 4), evaluate=False), FiniteSet(5, 6), evaluate=False)) + assert_equal("{1,2} \\cap {2,3} \\cap {2,4}", + Intersection(Intersection(FiniteSet(1, 2), FiniteSet(2, 3), evaluate=False), FiniteSet(2, 4), evaluate=False)) + # Test mixed operations + assert_equal("({1,2} \\cup {3,4}) \\setminus {2,3}", + Complement(Union(FiniteSet(1, 2), FiniteSet(3, 4), evaluate=False), + FiniteSet(2, 3), evaluate=False)) + +def test_interval_operations(): + # Test interval operations + assert_equal("[0,1] \\cup [2,3]", Union(Interval(0, 1), Interval(2, 3), evaluate=False)) + assert_equal("[0,2] \\cap (1,3)", Intersection(Interval(0, 2), + Interval(1, 3, left_open=True, right_open=True), + evaluate=False)) + assert_equal("(0,1)", Interval(0, 1, left_open=True, right_open=True)) + assert_equal("[0,2] \\setminus (0,1)", Complement(Interval(0, 2), + Interval(0, 1, left_open=True, right_open=True), + evaluate=False)) + +def test_mixed_set_types(): + # Test operations between different set types + assert_equal("{1,2} \\cup [0,3]", Union(FiniteSet(1, 2), Interval(0, 3), evaluate=False)) + assert_equal("(0,1) \\cap {0.5}", Intersection(Interval(0, 1, left_open=True, right_open=True), + FiniteSet(0.5), evaluate=False)) + assert_equal("[0,1] \\setminus {0.5}", Complement(Interval(0, 1), FiniteSet(0.5), evaluate=False)) + +def test_empty_set(): + assert_equal("\\emptyset", S.EmptySet) + assert_equal("\\{\\}", S.EmptySet) + +@pytest.mark.parametrize('input, output', [ + ('$S_{MBCN}:S=7:32$', Eq(Mul(Symbol('s_{mbcn}'), Pow(Symbol('s'), -1)), Rational(7, 32), evaluate=False)), + (r"$(37,3,3,13),(17,3,3,7),(3,37,3,13),(3,17,3,7),(3,3,2,3)$", FiniteSet(sympy.Tuple(37, 3, 3, 13), sympy.Tuple(17, 3, 3, 7), sympy.Tuple(3, 37, 3, 13), sympy.Tuple(3, 17, 3, 7), sympy.Tuple(3, 3, 2, 3))), + ('$(0;0;0),(0;-2;0),(0;0;6),(0;-2;6),(4;0;0),(4;-2;0),(4;0;6),(4;-2;6)$', FiniteSet(sympy.Tuple(0, 0, 0), sympy.Tuple(0, -2, 0), sympy.Tuple(0, 0, 6), sympy.Tuple(0, -2, 6), sympy.Tuple(4, 0, 0), sympy.Tuple(4, -2, 0), sympy.Tuple(4, 0, 6), sympy.Tuple(4, -2, 6))), + ('$1,2;3,4;5,6$', FiniteSet(sympy.Tuple(1, 2), sympy.Tuple(3, 4), sympy.Tuple(5, 6))), + ('$(1,1;2,2)$', sympy.Tuple(1,1,2,2)), + ('${1,2,3}$', FiniteSet(1, 2, 3)), + ('${{1},{2},{3}}$', FiniteSet(FiniteSet(1), FiniteSet(2), FiniteSet(3))), + ('$k = 1,2,3$', Eq(Symbol('k'), FiniteSet(1, 2, 3), evaluate=False)), + ('$k \\in {1,2,3}$', Eq(Symbol('k'), FiniteSet(1, 2, 3), evaluate=False)), + ('$k \\in 1,2,3$', Eq(Symbol('k'), FiniteSet(1, 2, 3), evaluate=False)), + ('$k=1, b=2$', FiniteSet(Eq(Symbol('k'), 1), Eq(Symbol('b'), 2))), + ('$k=1; b=2$', FiniteSet(Eq(Symbol('k'), 1), Eq(Symbol('b'), 2))), +]) +def test_set_of_sets(input, output): + assert_equal(input, output) + + +@pytest.mark.parametrize('input, output', [ + ('a=1 and b=2', FiniteSet(Eq(a, 1), Eq(b, 2))), + ('a=1 or b=2', FiniteSet(Eq(a, 1), Eq(b, 2))), + ('a=1 \\text{ and } b=2', FiniteSet(Eq(a, 1), Eq(b, 2))), + ('a=1 \\text{ or } b=2', FiniteSet(Eq(a, 1), Eq(b, 2))), + (r"$(11,7)or(7,11)$", FiniteSet(sympy.Tuple(11, 7), sympy.Interval(7, 11, left_open=True, right_open=True))), + (r"$z \\in \left[-\frac{3}{2}, -1\right] \cup \left[1, \\frac{3}{2}\right]$", Eq(z, Union(Interval(-Rational(3, 2), -1), Interval(1, Rational(3, 2)), evaluate=False), evaluate=False)), + (r"\boxed{ p = 5, q = 2 \quad ; \quad p = 7, q = 2}", FiniteSet(sympy.Tuple(Eq(Symbol('p'), 5), Eq(Symbol('q'), 2)), sympy.Tuple(Eq(Symbol('p'), 7), Eq(Symbol('q'), 2)))), + (r"\boxed{1}", 1), +]) +def test_and_or_text(input, output): + assert_equal(input, output) + diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/symbols_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/symbols_test.py new file mode 100644 index 000000000..fbe09db27 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/symbols_test.py @@ -0,0 +1,22 @@ +import pytest +import sympy +from latex2sympy2_extended.latex2sympy2 import latex2sympy +from tests.context import assert_equal, get_simple_examples + + +@pytest.mark.parametrize('input, output, symbolically', [ + ('\\text{a}', 'a', True), + ('\\text{(b)}', 'b', True), + ('\\textit{c}', '\\text{c}', True), + ('\\textbf{i}', 'i', True), + ('\\mbox{hello}', '\\text{hello}', True), + ('E', 'E', False), + ('e', 'e', False), + # Would be awesome to get this working one day + # ('\\text{Even}', 'even', True), +]) + +def test_symbol(input, output, symbolically): + input_parsed = latex2sympy(input) + output_parsed = latex2sympy(output) + assert input_parsed == output_parsed diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/trig_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/trig_test.py new file mode 100644 index 000000000..e22d3f117 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/trig_test.py @@ -0,0 +1,21 @@ +from tests.context import assert_equal +import pytest +from sympy import asinh, Symbol + +# x = Symbol('x', real=True); + +# latex = "\\sinh(x)" +# math = process_sympy(latex) +# print("latex: %s to math: %s" %(latex,math)) +# +# latex = "\\arcsinh(x)" +# math = process_sympy(latex) +# print("latex: %s to math: %s" %(latex,math)) +# +# latex = "\\arsinh(x)" +# math = process_sympy(latex) +# print("latex: %s to math: %s" %(latex,math)) + + +def test_arcsinh(): + assert_equal("\\operatorname{arcsinh}\\left(1\\right)", asinh(1, evaluate=False)) diff --git a/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/variable_test.py b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/variable_test.py new file mode 100644 index 000000000..65535e551 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/latex2sympy2_extended/tests/variable_test.py @@ -0,0 +1,91 @@ +from tests.context import assert_equal +import pytest +from sympy import Rational, UnevaluatedExpr, Symbol, Mul, Pow, Max, Min, gcd, lcm, floor, ceiling + +x = Symbol('x') +y = Symbol('y') + + +def test_variable_letter(): + assert_equal("\\variable{x}", Symbol('x')) + + +def test_variable_digit(): + assert_equal("\\variable{1}", Symbol('1')) + + +def test_variable_letter_subscript(): + assert_equal("\\variable{x_y}", Symbol('x_y')) + + +def test_variable_letter_comma_subscript(): + assert_equal("\\variable{x_{i,j}}", Symbol('x_{i,j}')) + + +def test_variable_digit_subscript(): + assert_equal("\\variable{x_1}", Symbol('x_1')) + + +def test_variable_after_subscript_required(): + with pytest.raises(Exception): + assert_equal("\\variable{x_}", Symbol('x_')) + + +def test_variable_before_subscript_required(): + with pytest.raises(Exception): + assert_equal("\\variable{_x}", Symbol('_x')) + + +def test_variable_bad_name(): + with pytest.raises(Exception): + assert_equal("\\variable{\\sin xy}", None) + + +def test_variable_in_expr(): + assert_equal("4\\cdot\\variable{x}", 4 * Symbol('x')) + + +def test_variable_greek_letter(): + assert_equal("\\variable{\\alpha }\\alpha", Symbol('\\alpha ') * Symbol('alpha')) + + +def test_variable_greek_letter_subscript(): + assert_equal("\\variable{\\alpha _{\\beta }}\\alpha ", Symbol('\\alpha _{\\beta }') * Symbol('alpha')) + + +def test_variable_bad_unbraced_long_subscript(): + with pytest.raises(Exception): + assert_equal("\\variable{x_yz}", None) + + +def test_variable_bad_unbraced_long_complex_subscript(): + with pytest.raises(Exception): + assert_equal("\\variable{x\\beta 10_y\\alpha 20}", None) + + +def test_variable_braced_subscript(): + assert_equal("\\variable{x\\beta 10_{y\\alpha 20}}", Symbol('x\\beta 10_{y\\alpha 20}')) + + +def test_variable_complex_expr(): + assert_equal("4\\cdot\\variable{value1}\\frac{\\variable{value_2}}{\\variable{a}}\\cdot x^2", 4 * Symbol('value1') * Symbol('value_2') / Symbol('a') * x**2) + + +def test_variable_dollars(): + assert_equal("\\$\\variable{x}", Symbol('x')) + + +def test_variable_percentage(): + assert_equal("\\variable{x}\\%", Symbol('x') * Rational(1, 100)) + + +def test_variable_single_arg_func(): + assert_equal("\\floor(\\variable{x})", floor(Symbol('x'))) + assert_equal("\\ceil(\\variable{x})", ceiling(Symbol('x'))) + + +def test_variable_multi_arg_func(): + assert_equal("\\gcd(\\variable{x}, \\variable{y})", UnevaluatedExpr(gcd(Symbol('x'), Symbol('y')))) + assert_equal("\\lcm(\\variable{x}, \\variable{y})", UnevaluatedExpr(lcm(Symbol('x'), Symbol('y')))) + assert_equal("\\max(\\variable{x}, \\variable{y})", Max(Symbol('x'), Symbol('y'), evaluate=False)) + assert_equal("\\min(\\variable{x}, \\variable{y})", Min(Symbol('x'), Symbol('y'), evaluate=False)) diff --git a/vlmeval/dataset/PHYSICS/reward_manager.py b/vlmeval/dataset/PHYSICS/reward_manager.py new file mode 100644 index 000000000..f4d661ed8 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/reward_manager.py @@ -0,0 +1,117 @@ +import asyncio +from concurrent.futures import ProcessPoolExecutor +from functools import partial +import psutil +from collections import defaultdict +from reward_score import compute_score + + +async def single_compute_score(pred, gold, problem, executor, timeout): + loop = asyncio.get_running_loop() + try: + future = loop.run_in_executor( + executor, + partial(compute_score, pred, gold, problem) + ) + return await asyncio.wait_for(future, timeout=timeout) + except asyncio.TimeoutError: + print(f"[Timeout] Task timeout after {timeout}s: {pred[:80]}") + return None + except Exception as e: + print(f"[Error] Task failed: {e}, pred: {pred[:80]}") + return None + +async def parallel_compute_score_async(preds, golds, problems, num_processes, timeout): + results = [] + with ProcessPoolExecutor(max_workers=num_processes) as executor: + try: + tasks_async = [ + single_compute_score(pred, gold, problem, executor, timeout) + for pred, gold, problem in zip(preds, golds, problems) + ] + results = await asyncio.gather(*tasks_async, return_exceptions=False) + print("[Success] All tasks gathered.") + except Exception as e: + print(f"[Exception] async gather failed: {e}") + raise + finally: + print("[Shutdown] Cleaning up remaining subprocesses...") + terminated_count = 0 + for pid, proc in executor._processes.items(): + try: + p = psutil.Process(pid) + p.terminate() + try: + p.wait(timeout=5) + except psutil.TimeoutExpired: + p.kill() + terminated_count += 1 + except Exception: + pass + print(f"[Shutdown] {terminated_count} subprocess(es) terminated.") + + # Format results + formatted = [] + for result in results: + if isinstance(result, Exception) or result is None: + formatted.append({ + "score": 0., + "acc": False, + "extracted_gt": None, + "extracted_pred": None, + }) + elif isinstance(result, dict): + formatted.append(result) + else: + formatted.append(result[0]) + return formatted + +def run_reward_scoring(preds, golds, problems, num_processes=64, timeout=300.): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(parallel_compute_score_async( + preds, golds, problems, num_processes, timeout + )) + finally: + loop.close() + + + +def verifier_manager(data, return_dict: bool = True): + assert isinstance(data, list), f"data should be a list, but got {type(data)}" + + problems = [item['problem'] for item in data] + preds = [item['response'] for item in data] + golds = [item['answer'] for item in data] + + + try: + results = run_reward_scoring( + preds=preds, + golds=golds, + problems=problems, + num_processes=64, + timeout=300., + ) + except asyncio.TimeoutError as e: + print('Global timeout in reward computing! Setting all as 0.') + results = [{ + "score": 0., + "acc": False, + "extracted_gt": None, + "extracted_pred": None, + } for _ in range(len(data))] + except Exception as e: + print(f"Unexpected error in batched reward computing. Setting all as 0.: {e}") + results = [{ + "score": 0., + "acc": False, + "extracted_gt": None, + "extracted_pred": None, + } for _ in range(len(data))] + + if return_dict: + return results + else: + return [result['score'] for result in results] diff --git a/vlmeval/dataset/PHYSICS/reward_score.py b/vlmeval/dataset/PHYSICS/reward_score.py new file mode 100644 index 000000000..802ebeb52 --- /dev/null +++ b/vlmeval/dataset/PHYSICS/reward_score.py @@ -0,0 +1,1153 @@ +# Copyright 2025 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Provides a math answer grading function with high recall. +Based on HF math_verify, verl, open reasoner zero, etc. +""" + +import os +import re +import signal +import math +import time +import traceback +from openai import OpenAI +from functools import wraps, partial +from itertools import islice, zip_longest +from typing import Optional, Union +from pylatexenc import latex2text +from decimal import Decimal, localcontext +import sympy +from sympy import N, Pow, Mul +from sympy.parsing import sympy_parser +from math_verify import (ExprExtractionConfig, LatexExtractionConfig, parse, verify) + + +def timeout(timeout_seconds: int = 10): + if os.name == "posix": + import signal + def decorator(func): + def handler(signum, frame): + raise TimeoutError("verify timed out!") + def wrapper(*args, **kwargs): + old_handler = signal.getsignal(signal.SIGALRM) + signal.signal(signal.SIGALRM, handler) + signal.alarm(timeout_seconds) + try: + return func(*args, **kwargs) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + return wrapper + return decorator + + +# units mainly from MathQA +unit_texts = [ + "east", + "degree", + "mph", + "kmph", + "ft", + "m sqaure", + " m east", + "sq m", + "deg", + "mile", + "q .", + "monkey", + "prime", + "ratio", + "profit of rs", + "rd", + "o", + "p . m", + "lb", + "tile", + "per", + "lt", + "gain", + "ab", + "way", + "west", + "no change", + "men", + "soldier", + "pie", + "bc", + "excess", + "st", + "inches", + "noon", + "percent", + "by", + "gal", + "kmh", + "acre", + "rise", + "a . m", + "th", + "π r 2", + "sq", + "mark", + "toy", + "coin", + "sq . m", + "gallon", + "° f", + "profit", + "minw", + "yr", + "women", + "feet", + "am", + "pm", + "hr", + "cu cm", + "square", + "v â € ™", + "are", + "rupee", + "rounds", + "cubic", + "cc", + "mtr", + "ohm", + "number", + "kmph", + "day", + "hour", + "minute", + "min", + "second", + "man", + "woman", + "sec", + "cube", + "mt", + "sq inch", + "mp", + "∏ cm ³", + "hectare", + "more", + "sec", + "unit", + "cu . m", + "cm 2", + "rs .", + "rs", + "kg", + "month", + "cm", + "mm", + "apple", + "liter", + "loss", + "yard", + "pure", + "year", + "increase", + "decrease", + "less", + "Surface", + "litre", + "pi sq m", + "s .", + "metre", + "meter", + "inch", + "kilogram", + "second", + "ampere", + "A", + "K", + "mol", + "cd", + "N", + "J", + "W", + "Pa", + "Hz", + "C", + "V", + "Ω", + "F", + "T", + "H", + "eV", + "kW·h", + "atm", + "bar", + "°C" +] +unit_texts.extend([t + "s" for t in unit_texts]) + +def _strip_string(string): + def _fix_fracs(string): + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + def _fix_a_slash_b(string): + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + a = int(a) + b = int(b) + assert string == "{}/{}".format(a, b) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except: + return string + + def _remove_right_units(string): + # "\\text{ " only ever occurs (at least in the val set) when describing units + if "\\text{ " in string: + splits = string.split("\\text{ ") + assert len(splits) == 2 + return splits[0] + else: + return string + + def _fix_sqrt(string): + if "\\sqrt" not in string: + return string + splits = string.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + if split[0] != "{": + a = split[0] + new_substr = "\\sqrt{" + a + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + + # linebreaks + string = string.replace("\n", "") + # print(string) + + # remove inverse spaces + string = string.replace("\\!", "") + # print(string) + + # replace \\ with \ + string = string.replace("\\\\", "\\") + # print(string) + + # matrix + string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string) + string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string) + string = string.replace("bmatrix", "pmatrix") + + # replace tfrac and dfrac with frac + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + string = ( + string.replace("\\neq", "\\ne") + .replace("\\leq", "\\le") + .replace("\\geq", "\\ge") + ) + # print(string) + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\right", "") + # print(string) + + # Remove unit: miles, dollars if after is not none + _string = re.sub(r"\\text{.*?}$", "", string).strip() + if _string != "" and _string != string: + # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) + string = _string + + # Remove unit: texts + for _ in range(2): + for unit_text in unit_texts: + # use regex, the prefix should be either the start of the string or a non-alphanumeric character + # the suffix should be either the end of the string or a non-alphanumeric character + _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string) + if _string != "": + string = _string + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # remove dollar signs + string = string.replace("\\$", "") + + # remove units (on the right) + string = _remove_right_units(string) + + # remove percentage + string = string.replace("\\%", "") + string = string.replace("\%", "") + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + if len(string.split("=")) == 2: + if len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + # fix sqrt3 --> sqrt{3} + string = _fix_sqrt(string) + + # remove spaces + string = string.replace(" ", "") + + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} + string = _fix_fracs(string) + + # manually change 0.5 --> \frac{1}{2} + if string == "0.5": + string = "\\frac{1}{2}" + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = _fix_a_slash_b(string) + + return string + + +def _strip_properly_formatted_commas(expr: str): + # We want to be careful because we don't want to strip tuple commas + p1 = re.compile("(\d)(,)(\d\d\d)($|\D)") + while True: + next_expr = p1.sub("\\1\\3\\4", expr) + if next_expr == expr: + break + expr = next_expr + return next_expr + +def _is_float(num: str) -> bool: + try: + float(num) + return True + except ValueError: + return False + +def _is_int(x: float) -> bool: + try: + return abs(x - int(round(x))) <= 1e-7 + except: + return False + +def _is_frac(expr: str) -> bool: + return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr)) + +def _str_is_int(x: str) -> bool: + try: + x = _strip_properly_formatted_commas(x) + x = float(x) + return abs(x - int(round(x))) <= 1e-7 + except: + return False + +def _str_to_int(x: str) -> bool: + x = x.replace(",", "") + x = float(x) + return int(x) + +def _inject_implicit_mixed_number(step: str): + """ + Automatically make a mixed number evalable + e.g. 7 3/4 => 7+3/4 + """ + p1 = re.compile("([0-9]) +([0-9])") + step = p1.sub("\\1+\\2", step) ## implicit mults + return step + +def _parse_latex(expr: str) -> str: + """Attempts to parse latex to an expression sympy can read.""" + expr = expr.replace("\\tfrac", "\\frac") + expr = expr.replace("\\dfrac", "\\frac") + expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. + expr = latex2text.LatexNodes2Text().latex_to_text(expr) + + # Replace the specific characters that this parser uses. + expr = expr.replace("√", "sqrt") + expr = expr.replace("π", "pi") + expr = expr.replace("∞", "inf") + expr = expr.replace("∪", "U") + expr = expr.replace("·", "*") + expr = expr.replace("×", "*") + + return expr.strip() + + +# Dan Hendrycks' code +def mathd_normalize_answer(answer: Optional[str]) -> Optional[str]: + if answer is None: + return None + + try: + # Remove enclosing `\text{}`. + m = re.search("^\\\\text\{(?P.+?)\}$", answer) + if m is not None: + answer = m.group("text").strip() + return _strip_string(answer) + except: + return answer + +def sympy_normalize_answer(expr: str) -> str: + """Normalize answer expressions.""" + if expr is None: + return None + + # Remove enclosing `\text{}`. + m = re.search("^\\\\text\{(?P.+?)\}$", expr) + if m is not None: + expr = m.group("text") + + expr = expr.replace("\\%", "%") + expr = expr.replace("\\$", "$") + expr = expr.replace("$", "") + expr = expr.replace("%", "") + expr = expr.replace(" or ", " , ") + expr = expr.replace(" and ", " , ") + + expr = expr.replace("million", "*10^6") + expr = expr.replace("billion", "*10^9") + expr = expr.replace("trillion", "*10^12") + + for _ in range(2): + for unit_text in unit_texts: + # use regex, the prefix should be either the start of the string or a non-alphanumeric character + # the suffix should be either the end of the string or a non-alphanumeric character + _expr = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", expr) + if _expr != "": + expr = _expr + + expr = re.sub(f"\^ *\\\\circ", "", expr) + + if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": + expr = expr[1:-1] + + expr = re.sub(",\\\\! *", "", expr) + if _is_float(expr) and _is_int(float(expr)): + expr = str(int(round(float(expr)))) + if "\\" in expr: + try: + expr = _parse_latex(expr) + except: + pass + + # edge case with mixed numbers and negative signs + expr = re.sub("- *", "-", expr) + + expr = _inject_implicit_mixed_number(expr) + expr = expr.replace(" ", "") + + # if we somehow still have latex braces here, just drop them + expr = expr.replace("{", "") + expr = expr.replace("}", "") + + # don't be case sensitive for text answers + expr = expr.lower() + + if _str_is_int(expr): + expr = str(_str_to_int(expr)) + + return expr + + +def judge_MC(pred, gold): + common_answer = [chr(i) for i in range(65, 91)] # 'A'~'Z' + if pred == gold: + return True + else: + if pred.startswith("[") and pred.endswith("]"): + pred = pred.strip("[]") + if not pred: + return False + if pred[0] in common_answer and (len(pred) > 1 and (pred[1] == "." or pred[1] == ":")): + return pred[0] == gold + if f"'{gold}'" in pred: + return True + else: + return False + +def judge_TF(pred, gold): + def contains_chinese(d): + def is_chinese_char(ch): + return '\u4e00' <= ch <= '\u9fff' + + def check(value): + if isinstance(value, str): + return any(is_chinese_char(ch) for ch in value) + elif isinstance(value, dict): + return any(check(v) for v in value.values()) + elif isinstance(value, list): + return any(check(item) for item in value) + return False + + return check(d) + + if contains_chinese(pred): + if pred in ["是", "对", "正确", "能"]: + pred = "TRUE" + elif pred in ["否", "错", "错误", "不能"]: + pred = "FALSE" + else: + pred = pred.upper() + answers = ["TRUE", "FALSE", "T", "F", "YES", "NO", "Y", "N"] + gold = gold.upper() + if gold not in answers or pred not in answers: + return False + if gold in ["TRUE", "YES", "T", "Y"]: + gold = "TRUE" + if gold in ["FALSE", "NO", "F", "N"]: + gold = "FALSE" + if pred in ["TRUE", "YES", "T", "Y"]: + pred = "TRUE" + if pred in ["FALSE", "NO", "F", "N"]: + pred = "FALSE" + return pred == gold + + +def grade_answer_mathd(given_answer: str, ground_truth: str) -> bool: + ground_truth_normalized_mathd = mathd_normalize_answer(ground_truth) + given_answer_normalized_mathd = mathd_normalize_answer(given_answer) + # be at least as lenient as mathd + if ground_truth_normalized_mathd == given_answer_normalized_mathd: + return True, given_answer_normalized_mathd, ground_truth_normalized_mathd + return False, given_answer_normalized_mathd, ground_truth_normalized_mathd + + +# sympy might hang -- we don't care about trying to be lenient in these cases +BAD_SUBSTRINGS = ["^{", "^("] +BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"] +TUPLE_CHARS = "()[]" +def split_tuple(expr: str): + """ + Split the elements in a tuple/interval, while handling well-formatted commas in large numbers + """ + expr = _strip_properly_formatted_commas(expr) + if len(expr) == 0: + return [] + if ( + len(expr) > 2 + and expr[0] in TUPLE_CHARS + and expr[-1] in TUPLE_CHARS + and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]) + ): + elems = [elem.strip() for elem in expr[1:-1].split(",")] + else: + elems = [expr] + return elems + + +def _sympy_parse(expr: str): + """Parses an expression with sympy.""" + py_expr = expr.replace("^", "**") + return sympy_parser.parse_expr( + py_expr, + transformations=( + sympy_parser.standard_transformations + + (sympy_parser.implicit_multiplication_application,) + ), + evaluate=False, + ) + +def should_allow_eval(expr: str): + def count_unknown_letters_in_expr(expr: str): + expr = expr.replace("sqrt", "") + expr = expr.replace("frac", "") + letters_in_expr = set([x for x in expr if x.isalpha()]) + return len(letters_in_expr) + # we don't want to try parsing unknown text or functions of more than two variables + if count_unknown_letters_in_expr(expr) > 2: + return False + + for bad_string in BAD_SUBSTRINGS: + if bad_string in expr: + return False + + for bad_regex in BAD_REGEXES: + if re.search(bad_regex, expr) is not None: + return False + + return True + +def handle_pi(string, pi): + if isinstance(string, str) and "pi" in string: + # Find the first occurrence of "\pi" + idx = string.find("pi") + + # Iterate over the string and find all occurrences of "\pi" with a valid previous character + while idx != -1: + + if idx > 0 and string[idx - 1].isdigit(): + # Replace "\pi" with "*math.pi" if the previous character is a digit + string = string[:idx] + f"*{pi}" + string[idx + 2:] + else: + # Replace "\pi" with "1*math.pi" if the previous character is not a digit + string = string[:idx] + f"1*{pi}" + string[idx + 2:] + + # Find the next occurrence of "\pi" + idx = string.find("pi", idx + 1) + + # Evaluate the expression using eval() function + try: + string = eval(string) + except: + pass + + return string + +@timeout(timeout_seconds=30) +def are_equal_under_sympy(gold: str, pred: str, precision: float = 5e-3): + def is_scientific_notation(expr): + return ( + isinstance(expr, Mul) + and isinstance(expr.args[1], Pow) + and expr.args[1].args[0] == 10 + ) + + def to_scientific_notation_sympy(num): + num_sci = f"{num:.2e}" # e.g., "1.23e-5" + base, exponent = num_sci.split("e") + return f"{base}*10**{int(exponent)}" + + def count_decimal_places(x, tol=1e-6): + """ + 返回浮点数 x 的有效小数位数,只保留重要前几位,忽略接近 0 的浮点尾巴。 + """ + with localcontext() as ctx: + ctx.prec = 20 # 高精度防止误差 + d = Decimal(str(x)).normalize() + s = format(d, "f") # 固定点格式 + if "." not in s: + return 0 + integer_part, decimal_part = s.split(".") + + # 去掉右侧全是0或接近0的部分(人为容差) + clean_decimal = "" + for i, ch in enumerate(decimal_part): + test_val = float("0." + decimal_part[:i+1]) + if abs(x - round(x, i+1)) <= tol: + break + clean_decimal += ch + return len(clean_decimal) + + try: + if pred == gold: + return True + + # 尝试转为 float 后做相对误差比较 + pred_value = float(pred) + gold_value = float(gold) + min_decimal_places = min(count_decimal_places(gold_value), count_decimal_places(pred_value)) + + pred_value = round(pred_value, min_decimal_places) + gold_value = round(gold_value, min_decimal_places) + if abs((pred_value - gold_value) / gold_value) <= precision * 1.01: + return True + + # 转为科学记数法后转 sympy 表达式 + spred = _sympy_parse(to_scientific_notation_sympy(float(pred))) + sgold = _sympy_parse(to_scientific_notation_sympy(float(gold))) + if is_scientific_notation(spred) and is_scientific_notation(sgold): + base_pred, exponent_pred = N(spred.args[0]), N(spred.args[1].args[1]) + base_gold, exponent_gold = N(sgold.args[0]), N(sgold.args[1].args[1]) + min_decimal_places = min(count_decimal_places(base_gold), count_decimal_places(base_pred)) + base_pred = round(base_pred, min_decimal_places) + base_gold = round(base_gold, min_decimal_places) + if exponent_pred == exponent_gold and abs(base_pred - base_gold) <= precision * 1.01: + print(1.0 / (10. ** min_decimal_places)) + return True + except Exception: + pass + + # 如果上面都失败,退回原始符号化处理(但注意保留结构) + try: + if should_allow_eval(gold) and should_allow_eval(pred): + exp_gold = _sympy_parse(gold) + exp_pred = _sympy_parse(pred) + + expr = (exp_gold - exp_pred) / (exp_gold if exp_gold != 0 else 1) + simplified = sympy.simplify(expr) + if abs(N(simplified)) <= precision * 1.01: + return True + if is_scientific_notation(exp_pred) != is_scientific_notation(exp_gold): + if is_scientific_notation(exp_pred): + gold = to_scientific_notation_sympy(float(gold)) + exp_gold = _sympy_parse(gold) + else: + pred = to_scientific_notation_sympy(float(pred)) + exp_pred = _sympy_parse(pred) + + if is_scientific_notation(exp_pred) and is_scientific_notation(exp_gold): + base_pred, exponent_pred = N(exp_pred.args[0]), N(exp_pred.args[1].args[1]) + base_gold, exponent_gold = N(exp_gold.args[0]), N(exp_gold.args[1].args[1]) + min_decimal_places = min(count_decimal_places(base_gold), count_decimal_places(base_pred)) + base_pred = round(base_pred, min_decimal_places) + base_gold = round(base_gold, min_decimal_places) + + if exponent_pred == exponent_gold and abs(base_pred - base_gold) <= precision * 1.01: + return True + else: + if N(exp_pred) == N(exp_gold): + return True + except Exception: + pass + + return False + +def grade_answer_sympy(given_answer: str, ground_truth: str) -> bool: + ground_truth_normalized = sympy_normalize_answer(ground_truth) + given_normalized = sympy_normalize_answer(given_answer) + if ground_truth_normalized is None: + return False, given_normalized, ground_truth_normalized + + if ground_truth_normalized == given_normalized: + return True, given_normalized, ground_truth_normalized + + if len(given_normalized) == 0: + return False, given_normalized, ground_truth_normalized + + ground_truth_elems = split_tuple(ground_truth_normalized) + given_elems = split_tuple(given_normalized) + + if len(ground_truth_elems) > 1 and ( + ground_truth_normalized[0] != given_normalized[0] + or ground_truth_normalized[-1] != given_normalized[-1] + ): + is_correct = False + elif len(ground_truth_elems) != len(given_elems): + is_correct = False + else: + for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems): + if _is_frac(ground_truth_elem) and _is_frac(given_elem): + # if fractions aren't reduced, then shouldn't be marked as correct + # so, we don't want to allow sympy.simplify in this case + is_correct = ground_truth_elem == given_elem + # elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem): + # # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify) + # is_correct = False + else: + is_correct = judge_MC(given_elem, ground_truth_elem) or judge_TF(given_elem, ground_truth_elem) + if not is_correct: + if "pi" in given_elem or "pi" in ground_truth_elem: + equivs = [] + for pi in [math.pi, 3.14, 180]: + given_elem_pi = handle_pi(given_elem, pi) + ground_truth_elem_pi = handle_pi(ground_truth_elem, pi) + try: + equivs.append(are_equal_under_sympy(ground_truth_elem_pi, given_elem_pi)) + except TimeoutError: + equivs.append(False) + is_correct = any(equivs) + else: + try: + is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) + except TimeoutError: + is_correct = False + if not is_correct: + break + + return is_correct, given_normalized, ground_truth_normalized + + +def repeatness(s: str): + def ranks(l): + index = {v: i for i, v in enumerate(sorted(set(l)))} + return [index[v] for v in l] + + def suffixArray(s): + line = ranks(s) + n, k, ans, sa = len(s), 1, line, [0] * len(s) + while k < n - 1: + line = ranks(list(zip_longest(line, islice(line, k, None), fillvalue=-1))) + ans, k = line, k << 1 + for i, k in enumerate(ans): + sa[k] = i + return ans, sa + + def lcp(arr, suffixArr, inv_suff): + n, ans, k = len(arr), [0] * len(arr), 0 + + for i in range(n): + if inv_suff[i] == n - 1: + k = 0 + continue + + j = suffixArr[inv_suff[i] + 1] + while i + k < n and j + k < n and arr[i + k] == arr[j + k]: + k += 1 + + ans[inv_suff[i]] = k + if k > 0: + k -= 1 + + return ans + + arr = [ord(i) for i in s] + n = len(arr) + if n <= 1: + return 0 + c, sa = suffixArray(arr) + cnt = sum(lcp(arr, sa, c)) + + return (cnt * 2 / (n * (n + 1))) > 0.2 + +class timeout: + def __init__(self, seconds=1, error_message="Timeout"): + self.seconds = seconds + self.error_message = error_message + + def handle_timeout(self, signum, frame): + raise TimeoutError(self.error_message) + + def __enter__(self): + signal.signal(signal.SIGALRM, self.handle_timeout) + signal.alarm(self.seconds) + + def __exit__(self, type, value, traceback): + signal.alarm(0) + +def grade_answer_math_verify(given_answer: str, ground_truth: str) -> bool: + try: + with timeout(1): + try: + if (len(given_answer) > 128 and repeatness(given_answer)) or ( + len(ground_truth) > 128 and repeatness(ground_truth) + ): + return False, given_answer, ground_truth + + # Next call math verify. + given_answer.replace("\n", "") + ground_truth.replace("\n", "") + if not "$" in given_answer: + given_answer = f"${given_answer}$" + if not "$" in ground_truth: + ground_truth = f"${ground_truth}$" + given_answer = parse( + given_answer, + extraction_config=( + LatexExtractionConfig(boxed_match_priority=0), + ExprExtractionConfig(), + ), + fallback_mode="no_fallback", + extraction_mode=["first_match"], + parsing_timeout=1, + ) + ground_truth = parse( + ground_truth, + extraction_config=( + LatexExtractionConfig(boxed_match_priority=0), + ExprExtractionConfig(), + ), + fallback_mode="no_fallback", + extraction_mode=["first_match"], + parsing_timeout=1, + ) + return verify( + ground_truth, + given_answer, + numeric_precision=3, + timeout_seconds=1, + ), given_answer, ground_truth + # or symbolic_equal(ground_truth, given_answer) + except Exception: + return False, given_answer, ground_truth + except TimeoutError: + return False, given_answer, ground_truth + + +def attach_wrapper(obj, func=None): + if func is None: + return partial(attach_wrapper, obj) + setattr(obj, func.__name__, func) + return func + +def retry(max_attempts:int=3, delay:int=1, print_trace_back=False, return_error_info=False): + assert isinstance(max_attempts, int) and isinstance(delay, int), '参数必须是整数' + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + attempts = 0 + while attempts < max_attempts: + try: + return func(*args, **kwargs) + except Exception: + if print_trace_back: + e = traceback.format_exc() + error_info = f">>>函数{func.__name__}第{attempts + 1}次尝试失败,报错信息为: {e}" + print(error_info) + time.sleep(delay) + attempts += 1 + if return_error_info: + return error_info + else: + return None + + @attach_wrapper(wrapper) + def set_max_attempts(new_max_attempts): + nonlocal max_attempts + max_attempts = new_max_attempts + + @attach_wrapper(wrapper) + def set_delay(new_delay): + nonlocal delay + delay = new_delay + + wrapper.get_attempts = lambda: max_attempts + wrapper.get_delay = lambda: delay + return wrapper + return decorator + +class Model_args: + use_model: bool = True + model_name = '/your_model_path/Physics-xVerify' + api_key = 'your-api-key' + base_url = 'http://localhost:8000/v1' + temperature = 0.1 +def grade_answer_xverify(given_answer: str, ground_truth: str, problem: str, **judge_kwargs) -> bool: + @retry(max_attempts=2, delay=1, print_trace_back=True, return_error_info=True) + def call_api(prompt:str, + system_prompt:Optional[str]=None, + client=None, + base_url:Optional[str]=None, + model:str="gpt-4o", # default + api_key:Union[None,str]=None, + max_tokens:int=None, + temperature:float=0.7, + logprobs:bool=False, + top_logprobs:int=1, + **kwargs) -> str: + if not client: + assert api_key is not None,'Please input your api key' + client = OpenAI( + api_key=api_key, + base_url=base_url + ) + if not logprobs: + top_logprobs = None + + messages = [{"role": "system", "content": system_prompt}] if system_prompt is not None else [] + if prompt: + messages.append({"role": "user", "content": prompt}) + model = judge_kwargs.get('model', "gpt-4o") + response = client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + logprobs=logprobs, + top_logprobs=top_logprobs, + **kwargs + ) + return response.choices[0].message.content + + client = OpenAI(api_key=model_args.api_key, base_url=model_args.base_url) + prompt = f''' +You are a diligent and precise assistant tasked with evaluating the correctness of responses. You will receive a question, an output sentence, and the correct answer. Your task is to determine if the output sentence accurately answers the question based on the provided correct answer. Respond with either [Correct] or [Incorrect]. + +Special considerations: +1. **Multiple Answers**: If the output contains multiple answers, evaluate whether later answers modify or correct earlier ones. In such cases, compare the final answer with the correct answer. If the final answer is unclear or incorrect, respond with [Incorrect]. +2. **Mathematical Problems**: If the formats differ but the answers are mathematically equivalent, respond with [Correct]. +3. **Explicit Options**: If the question provides explicit candidate answers, the output will be considered correct if it clearly indicates the correct option’s code or the correct option’s content. +4. **No Explicit Options**: If the question does not provide explicit options, the output must align with the correct answer in content and meaning to be considered [Correct]. +Question: {problem}, Output sentence: {given_answer}, Correct answer: {ground_truth}, Judgement: +''' + correct = call_api(prompt=prompt, + client=client, + model=model_args.model_name, + temperature=model_args.temperature) + return correct.strip() == "Correct" + + +def last_boxed_only_string(string): + idx = string.rfind("\\boxed") + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx == None: + retval = None + else: + retval = string[idx : right_brace_idx + 1] + + return retval + +def remove_boxed(s): + left = "\\boxed{" + try: + assert s[: len(left)] == left + assert s[-1] == "}" + return s[len(left) : -1] + except: + return None + +def extract_boxed_answer(solution: str) -> str: + """Extract the answer from inside a LaTeX \\boxed{} command""" + solution = last_boxed_only_string(solution) + solution = remove_boxed(solution) + return solution + +def grade(model_answer: str, gt_answer: str): + if "\\boxed" in gt_answer: + gt_answer = extract_boxed_answer(gt_answer) + if model_answer is None or gt_answer is None: + return False, model_answer, gt_answer + + correct, pred, gold = grade_answer_mathd(model_answer, gt_answer) + extracted_pred = pred + extracted_gt = gold + + if not correct: + correct, pred, gold = grade_answer_sympy(model_answer, gt_answer) + extracted_pred = pred + extracted_gt = gold + + if not correct: + correct, pred, gold = grade_answer_math_verify(model_answer, gt_answer) + extracted_pred = pred + extracted_gt = gold + + return correct, extracted_pred, extracted_gt + + +def answer_tag_reward_fn_for_r1(model_output: str, ground_truth: str, problem=None, **judge_kwargs): + extracted_pred = model_output + is_matched = False + + pattern = re.compile(r"(.*?)", re.DOTALL) + matches = re.findall(pattern, extracted_pred) + if len(matches) > 0: + is_matched = True + extracted_pred = matches[-1] + + if "" in extracted_pred: + is_matched = True + extracted_pred = extracted_pred.split("")[-1] + + for ans_marker in ["**answer:**", 'answer:', "answer is", "answers are"]: + ans_idx = extracted_pred.lower().rfind(ans_marker) + if ans_idx != -1: + is_matched = True + extracted_pred = extracted_pred[ans_idx + len(ans_marker):].strip() + if extracted_pred.endswith("\n"): + extracted_pred = extracted_pred[:-2] + + for ans_marker in ["is answer", "is the answer", "are answers", "are the answers"]: + ans_idx = extracted_pred.lower().rfind(ans_marker) + if ans_idx != -1: + is_matched = True + extracted_pred = extracted_pred[:ans_idx].strip() + if extracted_pred.endswith("\n"): + extracted_pred = extracted_pred[:-2] + + if "\\boxed" in extracted_pred: + is_matched = True + extracted_pred = extract_boxed_answer(extracted_pred) + + if extracted_pred is not None: + bold_pattern = re.compile(r"\*\*(.+?)\*\*", re.DOTALL) + bold_matches = re.findall(bold_pattern, extracted_pred) + if bold_matches: + is_matched = True + extracted_pred = bold_matches[-1] + + if "\\boxed" in ground_truth: + gt = extract_boxed_answer(ground_truth) + if gt is not None: + ground_truth = gt + + try: + is_correct, extracted_pred, extracted_gt = grade(extracted_pred, ground_truth) + except Exception as e: + is_correct = False + extracted_gt = ground_truth + + rule_based_acc = is_correct + + if not is_correct: + is_correct = grade_answer_xverify(model_output, ground_truth, problem, **judge_kwargs) + return rule_based_acc, is_correct, extracted_pred, extracted_gt + + +def compute_score(model_output: str, ground_truth: str, question: str, **judge_kwargs): + try: + if isinstance(ground_truth, list): + gold = '' + for a in ground_truth: + for b in a: + if '\\boxed' in b: + gold += extract_boxed_answer(b) + ', ' + else: + gold += b + ', ' + ground_truth = gold[:-2] if gold.endswith(", ") else gold + except Exception as e: + ground_truth = ground_truth[0] + + rule_based_acc, acc, extracted_pred, extracted_gt = answer_tag_reward_fn_for_r1(model_output, ground_truth, question, **judge_kwargs) + return { + "rule_based_acc": rule_based_acc, + "acc": acc, + "extracted_gt": str(extracted_gt), + "extracted_pred": str(extracted_pred) + } \ No newline at end of file diff --git a/vlmeval/dataset/__init__.py b/vlmeval/dataset/__init__.py index d52fe0d89..77c4ed7f9 100644 --- a/vlmeval/dataset/__init__.py +++ b/vlmeval/dataset/__init__.py @@ -21,6 +21,7 @@ from .image_shortqa import ImageShortQADataset, PathVQA_VAL, PathVQA_TEST from .text_mcq import CustomTextMCQDataset, TextMCQDataset +from .PHYSICS import PHYSICS from .vcr import VCRDataset from .mmlongbench import MMLongBench from .dude import DUDE @@ -228,7 +229,7 @@ def evaluate(self, eval_file, **judge_kwargs): ] TEXT_DATASET = [ - TextMCQDataset + TextMCQDataset, PHYSICS ] CUSTOM_DATASET = [