Skip to content

Commit dec6098

Browse files
authored
Merge pull request #3408 from Matiiss/matiiss-fix-vector-init-from-numpy-arrays
Fix `pygame.Vector{2,3}` initialization from numpy arrays
2 parents 6c5a50a + 92dd2d6 commit dec6098

File tree

2 files changed

+129
-17
lines changed

2 files changed

+129
-17
lines changed

src_c/math.c

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2280,19 +2280,23 @@ static int
22802280
_vector2_set(pgVector *self, PyObject *xOrSequence, PyObject *y)
22812281
{
22822282
if (xOrSequence) {
2283-
if (RealNumber_Check(xOrSequence)) {
2284-
self->coords[0] = PyFloat_AsDouble(xOrSequence);
2285-
/* scalar constructor. */
2286-
if (y == NULL) {
2287-
self->coords[1] = self->coords[0];
2283+
if (pgVectorCompatible_Check(xOrSequence, self->dim)) {
2284+
if (!PySequence_AsVectorCoords(xOrSequence, self->coords, 2)) {
2285+
return -1;
2286+
}
2287+
else {
22882288
return 0;
22892289
}
22902290
}
2291-
else if (pgVectorCompatible_Check(xOrSequence, self->dim)) {
2292-
if (!PySequence_AsVectorCoords(xOrSequence, self->coords, 2)) {
2291+
else if (RealNumber_Check(xOrSequence)) {
2292+
self->coords[0] = PyFloat_AsDouble(xOrSequence);
2293+
if (self->coords[0] == -1.0 && PyErr_Occurred()) {
22932294
return -1;
22942295
}
2295-
else {
2296+
2297+
/* scalar constructor. */
2298+
if (y == NULL) {
2299+
self->coords[1] = self->coords[0];
22962300
return 0;
22972301
}
22982302
}
@@ -2323,6 +2327,9 @@ _vector2_set(pgVector *self, PyObject *xOrSequence, PyObject *y)
23232327

23242328
if (RealNumber_Check(y)) {
23252329
self->coords[1] = PyFloat_AsDouble(y);
2330+
if (self->coords[1] == -1.0 && PyErr_Occurred()) {
2331+
return -1;
2332+
}
23262333
}
23272334
else {
23282335
goto error;
@@ -2718,20 +2725,24 @@ static int
27182725
_vector3_set(pgVector *self, PyObject *xOrSequence, PyObject *y, PyObject *z)
27192726
{
27202727
if (xOrSequence) {
2721-
if (RealNumber_Check(xOrSequence)) {
2722-
self->coords[0] = PyFloat_AsDouble(xOrSequence);
2723-
/* scalar constructor. */
2724-
if (y == NULL && z == NULL) {
2725-
self->coords[1] = self->coords[0];
2726-
self->coords[2] = self->coords[0];
2728+
if (pgVectorCompatible_Check(xOrSequence, self->dim)) {
2729+
if (!PySequence_AsVectorCoords(xOrSequence, self->coords, 3)) {
2730+
return -1;
2731+
}
2732+
else {
27272733
return 0;
27282734
}
27292735
}
2730-
else if (pgVectorCompatible_Check(xOrSequence, self->dim)) {
2731-
if (!PySequence_AsVectorCoords(xOrSequence, self->coords, 3)) {
2736+
else if (RealNumber_Check(xOrSequence)) {
2737+
self->coords[0] = PyFloat_AsDouble(xOrSequence);
2738+
if (self->coords[0] == -1.0 && PyErr_Occurred()) {
27322739
return -1;
27332740
}
2734-
else {
2741+
2742+
/* scalar constructor. */
2743+
if (y == NULL && z == NULL) {
2744+
self->coords[1] = self->coords[0];
2745+
self->coords[2] = self->coords[0];
27352746
return 0;
27362747
}
27372748
}
@@ -2764,7 +2775,14 @@ _vector3_set(pgVector *self, PyObject *xOrSequence, PyObject *y, PyObject *z)
27642775
else if (y && z) {
27652776
if (RealNumber_Check(y) && RealNumber_Check(z)) {
27662777
self->coords[1] = PyFloat_AsDouble(y);
2778+
if (self->coords[1] == -1.0 && PyErr_Occurred()) {
2779+
return -1;
2780+
}
2781+
27672782
self->coords[2] = PyFloat_AsDouble(z);
2783+
if (self->coords[2] == -1.0 && PyErr_Occurred()) {
2784+
return -1;
2785+
}
27682786
}
27692787
else {
27702788
goto error;

test/math_test.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
import pygame.math
77
from pygame.math import Vector2, Vector3
88

9+
try:
10+
import numpy
11+
except ModuleNotFoundError:
12+
numpy = None
13+
914
IS_PYPY = "PyPy" == platform.python_implementation()
1015

1116

@@ -255,6 +260,48 @@ def testConstructionVector2(self):
255260
self.assertEqual(v.x, 1.2)
256261
self.assertEqual(v.y, 3.4)
257262

263+
def testConstructionNumericSequence(self):
264+
class NumericSequence:
265+
# PyFloat_AsDouble will use this to convert to a float
266+
# so this is testing the implementation a bit
267+
def __float__(self):
268+
raise TypeError("Cannot convert to float")
269+
270+
def __getitem__(self, index):
271+
return [1, 0][index]
272+
273+
def __len__(self):
274+
return 2
275+
276+
v = Vector2(NumericSequence())
277+
self.assertEqual(v.x, 1.0)
278+
self.assertEqual(v.y, 0.0)
279+
280+
def testConstructionNumericNonFloat(self):
281+
class NumericNonFloat:
282+
# PyFloat_AsDouble will use this to convert to a float
283+
# so this is testing the implementation a bit
284+
def __float__(self):
285+
raise TypeError("Cannot convert to float")
286+
287+
with self.assertRaises(TypeError):
288+
Vector2(NumericNonFloat())
289+
290+
with self.assertRaises(TypeError):
291+
Vector2(NumericNonFloat(), NumericNonFloat())
292+
293+
with self.assertRaises(TypeError):
294+
Vector2(1.0, NumericNonFloat())
295+
296+
@unittest.skipIf(numpy is None, "numpy not available")
297+
def testConstructionNumpyArray(self):
298+
assert numpy is not None
299+
300+
arr = numpy.array([1.2, 3.4])
301+
v = Vector2(arr)
302+
self.assertEqual(v.x, 1.2)
303+
self.assertEqual(v.y, 3.4)
304+
258305
def testAttributeAccess(self):
259306
tmp = self.v1.x
260307
self.assertEqual(tmp, self.v1.x)
@@ -1431,6 +1478,53 @@ def testConstructionMissing(self):
14311478
self.assertRaises(ValueError, Vector3, 1, 2)
14321479
self.assertRaises(ValueError, Vector3, x=1, y=2)
14331480

1481+
def testConstructionNumericSequence(self):
1482+
class NumericSequence:
1483+
# PyFloat_AsDouble will use this to convert to a float
1484+
# so this is testing the implementation a bit
1485+
def __float__(self):
1486+
raise TypeError("Cannot convert to float")
1487+
1488+
def __getitem__(self, index):
1489+
return [1, 0, 5][index]
1490+
1491+
def __len__(self):
1492+
return 3
1493+
1494+
v = Vector3(NumericSequence())
1495+
self.assertEqual(v.x, 1.0)
1496+
self.assertEqual(v.y, 0.0)
1497+
self.assertEqual(v.z, 5.0)
1498+
1499+
def testConstructionNumericNonFloat(self):
1500+
class NumericNonFloat:
1501+
# PyFloat_AsDouble will use this to convert to a float
1502+
# so this is testing the implementation a bit
1503+
def __float__(self):
1504+
raise TypeError("Cannot convert to float")
1505+
1506+
with self.assertRaises(TypeError):
1507+
Vector3(NumericNonFloat())
1508+
1509+
with self.assertRaises(TypeError):
1510+
Vector3(NumericNonFloat(), NumericNonFloat(), NumericNonFloat())
1511+
1512+
with self.assertRaises(TypeError):
1513+
Vector3(1.0, NumericNonFloat(), 5.0)
1514+
1515+
with self.assertRaises(TypeError):
1516+
Vector3(1.0, 0.0, NumericNonFloat())
1517+
1518+
@unittest.skipIf(numpy is None, "numpy not available")
1519+
def testConstructionNumpyArray(self):
1520+
assert numpy is not None
1521+
1522+
arr = numpy.array([1.2, 3.4, 5.6], dtype=float)
1523+
v = Vector3(arr)
1524+
self.assertEqual(v.x, 1.2)
1525+
self.assertEqual(v.y, 3.4)
1526+
self.assertEqual(v.z, 5.6)
1527+
14341528
def testAttributeAccess(self):
14351529
tmp = self.v1.x
14361530
self.assertEqual(tmp, self.v1.x)

0 commit comments

Comments
 (0)