Skip to content

Commit 152df44

Browse files
authored
Merge pull request #11 from esa/batching_improvement
batching
2 parents 69ba9e3 + d1bcea1 commit 152df44

File tree

9 files changed

+613
-409
lines changed

9 files changed

+613
-409
lines changed

doc/notebooks/covariance_propagation.ipynb

Lines changed: 12 additions & 11 deletions
Large diffs are not rendered by default.

doc/notebooks/sgp4_partial_derivatives.ipynb

Lines changed: 273 additions & 272 deletions
Large diffs are not rendered by default.

doc/notebooks/tle_propagation.ipynb

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
},
4444
{
4545
"cell_type": "code",
46-
"execution_count": 10,
46+
"execution_count": 3,
4747
"metadata": {},
4848
"outputs": [
4949
{
@@ -53,7 +53,7 @@
5353
" 5.5809e-01, 6.2651e-02, 4.8993e+00])"
5454
]
5555
},
56-
"execution_count": 10,
56+
"execution_count": 3,
5757
"metadata": {},
5858
"output_type": "execute_result"
5959
}
@@ -126,34 +126,33 @@
126126
},
127127
{
128128
"cell_type": "code",
129-
"execution_count": 6,
129+
"execution_count": 14,
130130
"metadata": {},
131131
"outputs": [],
132132
"source": [
133133
"#we first need to prepare the data, the API requires that there are as many TLEs as times. Let us assume we want to\n",
134134
"#propagate each of the \n",
135+
"tles_=[]\n",
136+
"for tle in tles:\n",
137+
" tles_+=[tle]*10000\n",
135138
"tsinces = torch.cat([torch.linspace(0,24*60,10000)]*len(tles))\n",
136139
"#first let's initialize them:\n",
137-
"dsgp4.initialize_tle(tles)\n",
138-
"#then let's construct the TLEs batch by making sure there are as many TLEs as times:\n",
139-
"tles_batch=[]\n",
140-
"for tle in tles:\n",
141-
" tles_batch+=[tle]*10000"
140+
"_,tle_batch=dsgp4.initialize_tle(tles_)"
142141
]
143142
},
144143
{
145144
"cell_type": "code",
146-
"execution_count": 7,
145+
"execution_count": 15,
147146
"metadata": {},
148147
"outputs": [],
149148
"source": [
150149
"#we propagate the batch of 3,000 TLEs for 1 day:\n",
151-
"states_teme=dsgp4.propagate_batch(tles_batch,tsinces)"
150+
"states_teme=dsgp4.propagate_batch(tle_batch,tsinces)"
152151
]
153152
},
154153
{
155154
"cell_type": "code",
156-
"execution_count": 8,
155+
"execution_count": 16,
157156
"metadata": {},
158157
"outputs": [
159158
{

dsgp4/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
__version__ = '0.1.2'
1+
__version__ = '1.0.0'
22

33
import torch
44
torch.set_default_dtype(torch.float64)
55
from .sgp4 import sgp4
66
from .initl import initl
77
from .sgp4init import sgp4init
8+
from .sgp4init_batch import sgp4init_batch
89
from .newton_method import newton_method, update_TLE
910
from .sgp4_batched import sgp4_batched
1011
from .util import propagate, initialize_tle, propagate_batch

dsgp4/sgp4_batched.py

Lines changed: 10 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy
33
from .tle import TLE
44

5-
def sgp4_batched(satellite, tsince):
5+
def sgp4_batched(satellite_batch, tsince):
66
"""
77
This function represents the batch SGP4 propagator.
88
It resembles `sgp4`, but accepts batches of TLEs.
@@ -12,99 +12,25 @@ def sgp4_batched(satellite, tsince):
1212
in km and km/s, respectively, after `tsince` minutes.
1313
1414
Args:
15-
- satellite (``dsgp4.tle.TLE``): TLE object
16-
- tsince (``torch.tensor``): time to propagate, since the TLE epoch, in minutes
15+
- satellite (``dsgp4.tle.TLE``): TLE batch object (with attributes that are N-dimensional tensors)
16+
- tsince (``torch.tensor``): time to propagate, since the TLE epoch, in minutes (also an N-dimensional tensor)
1717
1818
Returns:
1919
- batch_state (``torch.tensor``): a batch of 2x3 tensors, where the first row represents the spacecraft
2020
position (in km) and the second the spacecraft velocity (in km/s)
2121
"""
22-
if not isinstance(satellite, list):
23-
raise ValueError("satellite should be a list of TLE objects.")
24-
if not isinstance(satellite[0],TLE):
25-
raise ValueError("satellite should be a list of TLE objects.")
22+
if not isinstance(satellite_batch, TLE):
23+
raise ValueError("satellite_batch should be a TLE object.")
2624
if not torch.is_tensor(tsince):
2725
raise ValueError("tsince must be a tensor.")
2826
if tsince.ndim!=1:
2927
raise ValueError("tsince should be a one dimensional tensor.")
30-
if len(tsince)!=len(satellite):
31-
raise ValueError("in batch mode, tsince and satellite shall be of same length.")
32-
if not hasattr(satellite[0], '_radiusearthkm'):
33-
raise AttributeError('It looks like the satellite has not been initialized. Please use the `initialize_tle` method or directly `sgp4init` to initialize the satellite. Otherwise, if you are propagating, another option is to use `dsgp4.propagate` and pass `initialized=True` in the arguments.')
28+
if len(tsince)!=len(satellite_batch._argpo):
29+
raise ValueError(f"in batch mode, tsince and satellite_batch shall have attributes of same length. Instead {len(tsince)} for time, and {len(satellite_batch._argpo)} for satellites' attributes found")
30+
if not hasattr(satellite_batch, '_radiusearthkm'):
31+
raise AttributeError('It looks like the satellite_batch has not been initialized. Please use the `initialize_tle` method or directly `sgp4init` to initialize the satellite_batch. Otherwise, if you are propagating, another option is to use `dsgp4.propagate` and pass `initialized=True` in the arguments.')
3432

35-
batch_size = len(satellite)
36-
37-
satellite_batch=satellite[0].copy()
38-
satellite_batch._bstar=torch.stack([s._bstar for s in satellite])
39-
satellite_batch._ndot=torch.stack([s._ndot for s in satellite])
40-
satellite_batch._nddot=torch.stack([s._nddot for s in satellite])
41-
satellite_batch._ecco=torch.stack([s._ecco for s in satellite])
42-
satellite_batch._argpo=torch.stack([s._argpo for s in satellite])
43-
satellite_batch._inclo=torch.stack([s._inclo for s in satellite])
44-
satellite_batch._mo=torch.stack([s._mo for s in satellite])
45-
46-
satellite_batch._no_kozai=torch.stack([s._no_kozai for s in satellite])
47-
satellite_batch._nodeo=torch.stack([s._nodeo for s in satellite])
48-
satellite_batch.satellite_catalog_number=torch.tensor([s.satellite_catalog_number for s in satellite])
49-
satellite_batch._jdsatepoch=torch.stack([s._jdsatepoch for s in satellite])
50-
satellite_batch._jdsatepochF=torch.stack([s._jdsatepochF for s in satellite])
51-
satellite_batch._isimp=torch.tensor([s._isimp for s in satellite])
52-
satellite_batch._method=[s._method for s in satellite]
53-
54-
satellite_batch._mdot=torch.stack([s._mdot for s in satellite])
55-
satellite_batch._argpdot=torch.stack([s._argpdot for s in satellite])
56-
satellite_batch._nodedot=torch.stack([s._nodedot for s in satellite])
57-
satellite_batch._nodecf=torch.stack([s._nodecf for s in satellite])
58-
satellite_batch._cc1=torch.stack([s._cc1 for s in satellite])
59-
satellite_batch._cc4=torch.stack([s._cc4 for s in satellite])
60-
satellite_batch._cc5=torch.stack([s._cc5 for s in satellite])
61-
satellite_batch._t2cof=torch.stack([s._t2cof for s in satellite])
62-
63-
satellite_batch._omgcof=torch.stack([s._omgcof for s in satellite])
64-
satellite_batch._eta=torch.stack([s._eta for s in satellite])
65-
satellite_batch._xmcof=torch.stack([s._xmcof for s in satellite])
66-
satellite_batch._delmo=torch.stack([s._delmo for s in satellite])
67-
satellite_batch._d2=torch.stack([s._d2 for s in satellite])
68-
satellite_batch._d3=torch.stack([s._d3 for s in satellite])
69-
satellite_batch._d4=torch.stack([s._d4 for s in satellite])
70-
satellite_batch._cc5=torch.stack([s._cc5 for s in satellite])
71-
satellite_batch._sinmao=torch.stack([s._sinmao for s in satellite])
72-
satellite_batch._t3cof=torch.stack([s._t3cof for s in satellite])
73-
satellite_batch._t4cof=torch.stack([s._t4cof for s in satellite])
74-
satellite_batch._t5cof=torch.stack([s._t5cof for s in satellite])
75-
76-
satellite_batch._xke=torch.stack([s._xke for s in satellite])
77-
satellite_batch._radiusearthkm=torch.stack([s._radiusearthkm for s in satellite])
78-
satellite_batch._t=torch.stack([s._t for s in satellite])
79-
satellite_batch._aycof=torch.stack([s._aycof for s in satellite])
80-
satellite_batch._x1mth2=torch.stack([s._x1mth2 for s in satellite])
81-
satellite_batch._con41=torch.stack([s._con41 for s in satellite])
82-
satellite_batch._x7thm1=torch.stack([s._x7thm1 for s in satellite])
83-
satellite_batch._xlcof=torch.stack([s._xlcof for s in satellite])
84-
satellite_batch._tumin=torch.stack([s._tumin for s in satellite])
85-
satellite_batch._mu=torch.stack([s._mu for s in satellite])
86-
satellite_batch._j2=torch.stack([s._j2 for s in satellite])
87-
satellite_batch._j3=torch.stack([s._j3 for s in satellite])
88-
satellite_batch._j4=torch.stack([s._j4 for s in satellite])
89-
satellite_batch._j3oj2=torch.stack([s._j3oj2 for s in satellite])
90-
satellite_batch._error=torch.stack([s._error for s in satellite])
91-
satellite_batch._operationmode=[s._operationmode for s in satellite]
92-
satellite_batch._satnum=torch.tensor([s._satnum for s in satellite])
93-
satellite_batch._am=torch.stack([s._am for s in satellite])
94-
satellite_batch._em=torch.stack([s._em for s in satellite])
95-
satellite_batch._im=torch.stack([s._im for s in satellite])
96-
satellite_batch._Om=torch.stack([s._Om for s in satellite])
97-
satellite_batch._mm=torch.stack([s._mm for s in satellite])
98-
satellite_batch._nm=torch.stack([s._nm for s in satellite])
99-
satellite_batch._init=[s._init for s in satellite]
100-
101-
satellite_batch._no_unkozai=torch.stack([s._no_unkozai for s in satellite])
102-
satellite_batch._a=torch.stack([s._a for s in satellite])
103-
satellite_batch._alta=torch.stack([s._altp for s in satellite])
104-
105-
106-
107-
33+
batch_size = len(tsince)
10834
mrt = torch.zeros(batch_size)
10935
x2o3 = torch.tensor(2.0 / 3.0)
11036

@@ -125,7 +51,6 @@ def sgp4_batched(satellite, tsince):
12551
tempe1 = satellite_batch._bstar * satellite_batch._cc4 * satellite_batch._t
12652
templ1 = satellite_batch._t2cof * t2
12753

128-
12954

13055
delomg = satellite_batch._omgcof * satellite_batch._t
13156

dsgp4/sgp4init.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def sgp4init(
4242
temp4 = torch.tensor(1.5e-12)
4343

4444
# ----------- set all near earth variables to zero ------------
45-
satellite._isimp = torch.tensor(0); satellite._method = 'n'; satellite._aycof = torch.tensor(0.0);
45+
satellite._isimp = torch.tensor(0); satellite._method = 'n'; satellite._aycof = torch.tensor(0.0);
4646
satellite._con41 = torch.tensor(0.0); satellite._cc1 = torch.tensor(0.0); satellite._cc4 = torch.tensor(0.0);
4747
satellite._cc5 = torch.tensor(0.0); satellite._d2 = torch.tensor(0.0); satellite._d3 = torch.tensor(0.0);
4848
satellite._d4 = torch.tensor(0.0); satellite._delmo = torch.tensor(0.0); satellite._eta = torch.tensor(0.0);
@@ -198,6 +198,6 @@ def sgp4init(
198198
12.0 * satellite._cc1 * satellite._d3 +
199199
6.0 * satellite._d2 * satellite._d2 +
200200
15.0 * cc1sq * (2.0 * satellite._d2 + cc1sq))
201-
sgp4(satellite, torch.zeros(1,1))
201+
sgp4(satellite, torch.zeros(1,1));
202202

203-
satellite._init = 'n'
203+
satellite._init = 'n'

0 commit comments

Comments
 (0)